diff --git a/backend/tests/test_node_model_selection.py b/backend/tests/test_node_model_selection.py new file mode 100644 index 0000000..bfb3f76 --- /dev/null +++ b/backend/tests/test_node_model_selection.py @@ -0,0 +1,126 @@ +import pytest +from unittest.mock import MagicMock, patch +from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node +from ea_chatbot.graph.nodes.planner import planner_node +from ea_chatbot.graph.nodes.reflector import reflector_node +from ea_chatbot.graph.nodes.synthesizer import synthesizer_node +from ea_chatbot.graph.workers.data_analyst.nodes.coder import coder_node as analyst_coder +from ea_chatbot.graph.workers.researcher.nodes.searcher import searcher_node as researcher_searcher +from ea_chatbot.graph.state import AgentState +from ea_chatbot.graph.workers.data_analyst.state import WorkerState as AnalystState +from ea_chatbot.graph.workers.researcher.state import WorkerState as ResearcherState +from ea_chatbot.config import Settings + +@pytest.fixture +def mock_settings(): + settings = Settings() + settings.query_analyzer_llm.model = "model-qa" + settings.planner_llm.model = "model-planner" + settings.reflector_llm.model = "model-reflector" + settings.synthesizer_llm.model = "model-synthesizer" + settings.coder_llm.model = "model-coder" + settings.researcher_llm.model = "model-researcher" + return settings + +@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") +@patch("ea_chatbot.graph.nodes.query_analyzer.Settings") +def test_query_analyzer_model(mock_settings_cls, mock_get_llm, mock_settings): + mock_settings_cls.return_value = mock_settings + mock_llm = MagicMock() + mock_get_llm.return_value = mock_llm + + state = AgentState(question="test", messages=[], checklist=[], current_step=0, iterations=0, vfs={}, plots=[], dfs={}, next_action="", analysis={}) + + try: + query_analyzer_node(state) + except: + pass # We only care about the call + + # Verify it was called with the QA config + args = mock_get_llm.call_args[0][0] + assert args.model == "model-qa" + +@patch("ea_chatbot.graph.nodes.planner.get_llm_model") +@patch("ea_chatbot.graph.nodes.planner.Settings") +def test_planner_model(mock_settings_cls, mock_get_llm, mock_settings): + mock_settings_cls.return_value = mock_settings + mock_llm = MagicMock() + mock_get_llm.return_value = mock_llm + + state = AgentState(question="test", messages=[], checklist=[], current_step=0, iterations=0, vfs={}, plots=[], dfs={}, next_action="", analysis={}) + + try: + planner_node(state) + except: + pass + + args = mock_get_llm.call_args[0][0] + assert args.model == "model-planner" + +@patch("ea_chatbot.graph.nodes.reflector.get_llm_model") +@patch("ea_chatbot.graph.nodes.reflector.Settings") +def test_reflector_model(mock_settings_cls, mock_get_llm, mock_settings): + mock_settings_cls.return_value = mock_settings + mock_llm = MagicMock() + mock_get_llm.return_value = mock_llm + + state = AgentState(question="test", messages=[], checklist=[{"task": "T1"}], current_step=0, iterations=0, vfs={}, plots=[], dfs={}, next_action="", analysis={}, summary="Worker did X") + + try: + reflector_node(state) + except: + pass + + args = mock_get_llm.call_args[0][0] + assert args.model == "model-reflector" + +@patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model") +@patch("ea_chatbot.graph.nodes.synthesizer.Settings") +def test_synthesizer_model(mock_settings_cls, mock_get_llm, mock_settings): + mock_settings_cls.return_value = mock_settings + mock_llm = MagicMock() + mock_get_llm.return_value = mock_llm + + state = AgentState(question="test", messages=[], checklist=[], current_step=0, iterations=0, vfs={}, plots=[], dfs={}, next_action="", analysis={}) + + try: + synthesizer_node(state) + except: + pass + + args = mock_get_llm.call_args[0][0] + assert args.model == "model-synthesizer" + +@patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model") +@patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.Settings") +def test_analyst_coder_model(mock_settings_cls, mock_get_llm, mock_settings): + mock_settings_cls.return_value = mock_settings + mock_llm = MagicMock() + mock_get_llm.return_value = mock_llm + + state = AnalystState(task="test", messages=[], iterations=0, vfs_state={}, plots=[], result=None, code=None, output=None, error=None) + + try: + analyst_coder(state) + except: + pass + + args = mock_get_llm.call_args[0][0] + assert args.model == "model-coder" + +@patch("ea_chatbot.graph.workers.researcher.nodes.searcher.get_llm_model") +@patch("ea_chatbot.graph.workers.researcher.nodes.searcher.Settings") +def test_researcher_searcher_model(mock_settings_cls, mock_get_llm, mock_settings): + mock_settings_cls.return_value = mock_settings + mock_llm = MagicMock() + mock_get_llm.return_value = mock_llm + + state = ResearcherState(task="test", messages=[], iterations=0, result=None, queries=[], raw_results=[]) + + try: + researcher_searcher(state) + except: + pass + + args = mock_get_llm.call_args[0][0] + assert args.model == "model-researcher"