"""Tests that each graph node requests the LLM configured for it in ``Settings``.

Every test patches the node module's ``Settings`` and ``get_llm_model``, runs the
node once, and checks which config object was handed to ``get_llm_model`` by
asserting on its ``model`` attribute.
"""

import pytest
from unittest.mock import MagicMock, patch

from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node
from ea_chatbot.graph.nodes.planner import planner_node
from ea_chatbot.graph.nodes.reflector import reflector_node
from ea_chatbot.graph.nodes.synthesizer import synthesizer_node
from ea_chatbot.graph.workers.data_analyst.nodes.coder import coder_node as analyst_coder
from ea_chatbot.graph.workers.researcher.nodes.searcher import searcher_node as researcher_searcher
from ea_chatbot.graph.state import AgentState
from ea_chatbot.graph.workers.data_analyst.state import WorkerState as AnalystState
from ea_chatbot.graph.workers.researcher.state import WorkerState as ResearcherState
from ea_chatbot.config import Settings


@pytest.fixture
def mock_settings():
    """Return a ``Settings`` whose per-node LLM configs carry distinct model names.

    Each test asserts on one of these sentinel names, so a node that reads the
    wrong ``*_llm`` section is detected.
    """
    settings = Settings()
    settings.query_analyzer_llm.model = "model-qa"
    settings.planner_llm.model = "model-planner"
    settings.reflector_llm.model = "model-reflector"
    settings.synthesizer_llm.model = "model-synthesizer"
    settings.coder_llm.model = "model-coder"
    settings.researcher_llm.model = "model-researcher"
    return settings


def _agent_state(**overrides):
    """Build a minimal ``AgentState`` for the orchestrator nodes.

    Keyword arguments override the default field values. Fresh containers are
    created on every call, so tests never share mutable state.
    """
    fields = {
        "question": "test",
        "messages": [],
        "checklist": [],
        "current_step": 0,
        "iterations": 0,
        "vfs": {},
        "plots": [],
        "dfs": {},
        "next_action": "",
        "analysis": {},
    }
    fields.update(overrides)
    return AgentState(**fields)


def _llm_config_used_by(node, state, mock_settings_cls, mock_get_llm, settings):
    """Run *node* on *state* and return the first positional arg given to ``get_llm_model``.

    The patched ``Settings`` class is made to return *settings*. The patched
    ``get_llm_model`` is made to return a ``MagicMock`` LLM.

    Exceptions raised by the node are tolerated, because only the
    ``get_llm_model`` call matters here. If the node failed before it ever
    requested an LLM, the assertion reports the node's exception. Without it,
    the test would fail with an unrelated ``TypeError`` from indexing
    ``call_args``, which is ``None`` when the mock was never called.
    """
    mock_settings_cls.return_value = settings
    mock_get_llm.return_value = MagicMock()

    node_error = None
    try:
        node(state)
    except Exception as exc:  # not bare: KeyboardInterrupt/SystemExit still propagate
        # The node presumably breaks once it consumes MagicMock LLM output; that
        # happens after the call under test, so the error is only kept for diagnostics.
        node_error = exc

    assert mock_get_llm.called, (
        f"get_llm_model was never called; node raised {node_error!r}"
    )
    # If the node calls get_llm_model more than once, this is the most recent call.
    return mock_get_llm.call_args[0][0]


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.query_analyzer.Settings")
def test_query_analyzer_model(mock_settings_cls, mock_get_llm, mock_settings):
    config = _llm_config_used_by(
        query_analyzer_node, _agent_state(), mock_settings_cls, mock_get_llm, mock_settings
    )
    assert config.model == "model-qa"


@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.graph.nodes.planner.Settings")
def test_planner_model(mock_settings_cls, mock_get_llm, mock_settings):
    config = _llm_config_used_by(
        planner_node, _agent_state(), mock_settings_cls, mock_get_llm, mock_settings
    )
    assert config.model == "model-planner"


@patch("ea_chatbot.graph.nodes.reflector.get_llm_model")
@patch("ea_chatbot.graph.nodes.reflector.Settings")
def test_reflector_model(mock_settings_cls, mock_get_llm, mock_settings):
    # The reflector is given a non-empty checklist and a worker summary to reflect on.
    state = _agent_state(checklist=[{"task": "T1"}], summary="Worker did X")
    config = _llm_config_used_by(
        reflector_node, state, mock_settings_cls, mock_get_llm, mock_settings
    )
    assert config.model == "model-reflector"


@patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model")
@patch("ea_chatbot.graph.nodes.synthesizer.Settings")
def test_synthesizer_model(mock_settings_cls, mock_get_llm, mock_settings):
    config = _llm_config_used_by(
        synthesizer_node, _agent_state(), mock_settings_cls, mock_get_llm, mock_settings
    )
    assert config.model == "model-synthesizer"


@patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model")
@patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.Settings")
def test_analyst_coder_model(mock_settings_cls, mock_get_llm, mock_settings):
    state = AnalystState(
        task="test",
        messages=[],
        iterations=0,
        vfs_state={},
        plots=[],
        result=None,
        code=None,
        output=None,
        error=None,
    )
    config = _llm_config_used_by(
        analyst_coder, state, mock_settings_cls, mock_get_llm, mock_settings
    )
    assert config.model == "model-coder"


@patch("ea_chatbot.graph.workers.researcher.nodes.searcher.get_llm_model")
@patch("ea_chatbot.graph.workers.researcher.nodes.searcher.Settings")
def test_researcher_searcher_model(mock_settings_cls, mock_get_llm, mock_settings):
    state = ResearcherState(
        task="test",
        messages=[],
        iterations=0,
        result=None,
        queries=[],
        raw_results=[],
    )
    config = _llm_config_used_by(
        researcher_searcher, state, mock_settings_cls, mock_get_llm, mock_settings
    )
    assert config.model == "model-researcher"