# File: ea-chatbot-lg/backend/tests/test_node_model_selection.py
# (127 lines, 4.9 KiB, Python)

import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node
from ea_chatbot.graph.nodes.planner import planner_node
from ea_chatbot.graph.nodes.reflector import reflector_node
from ea_chatbot.graph.nodes.synthesizer import synthesizer_node
from ea_chatbot.graph.workers.data_analyst.nodes.coder import coder_node as analyst_coder
from ea_chatbot.graph.workers.researcher.nodes.searcher import searcher_node as researcher_searcher
from ea_chatbot.graph.state import AgentState
from ea_chatbot.graph.workers.data_analyst.state import WorkerState as AnalystState
from ea_chatbot.graph.workers.researcher.state import WorkerState as ResearcherState
from ea_chatbot.config import Settings
@pytest.fixture
def mock_settings():
    """Return a ``Settings`` object whose per-node LLM configs have distinct model names.

    Each node-specific config gets a unique ``model`` string so a test can
    tell which config was handed to ``get_llm_model``.
    """
    configured = Settings()
    # Settings attribute -> sentinel model name, applied in declaration order.
    model_by_config = {
        "query_analyzer_llm": "model-qa",
        "planner_llm": "model-planner",
        "reflector_llm": "model-reflector",
        "synthesizer_llm": "model-synthesizer",
        "coder_llm": "model-coder",
        "researcher_llm": "model-researcher",
    }
    for config_name, model_name in model_by_config.items():
        getattr(configured, config_name).model = model_name
    return configured
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.query_analyzer.Settings")
def test_query_analyzer_model(mock_settings_cls, mock_get_llm, mock_settings):
    """Check that the query analyzer node builds its LLM from the QA config.

    ``Settings`` and ``get_llm_model`` are patched inside the query analyzer
    module; the test inspects the first positional argument passed to
    ``get_llm_model`` and expects its ``model`` to be ``"model-qa"``.
    """
    mock_settings_cls.return_value = mock_settings
    mock_get_llm.return_value = MagicMock()
    state = AgentState(
        question="test",
        messages=[],
        checklist=[],
        current_step=0,
        iterations=0,
        vfs={},
        plots=[],
        dfs={},
        next_action="",
        analysis={},
    )
    try:
        query_analyzer_node(state)
    except Exception:
        # Failures inside the node are irrelevant here; only the
        # get_llm_model call is checked. ``Exception`` (not a bare except)
        # lets KeyboardInterrupt/SystemExit propagate.
        pass
    # Fail with a clear message instead of a TypeError on ``None[0]`` when the
    # node never requested a model.
    mock_get_llm.assert_called()
    # Verify it was called with the QA config.
    llm_config = mock_get_llm.call_args[0][0]
    assert llm_config.model == "model-qa"
@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.graph.nodes.planner.Settings")
def test_planner_model(mock_settings_cls, mock_get_llm, mock_settings):
    """Check that the planner node builds its LLM from the planner config.

    ``Settings`` and ``get_llm_model`` are patched inside the planner module;
    the test inspects the first positional argument passed to
    ``get_llm_model`` and expects its ``model`` to be ``"model-planner"``.
    """
    mock_settings_cls.return_value = mock_settings
    mock_get_llm.return_value = MagicMock()
    state = AgentState(
        question="test",
        messages=[],
        checklist=[],
        current_step=0,
        iterations=0,
        vfs={},
        plots=[],
        dfs={},
        next_action="",
        analysis={},
    )
    try:
        planner_node(state)
    except Exception:
        # Only the get_llm_model call matters; node failures are ignored.
        # ``Exception`` (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate.
        pass
    # Fail with a clear message instead of a TypeError on ``None[0]`` when the
    # node never requested a model.
    mock_get_llm.assert_called()
    llm_config = mock_get_llm.call_args[0][0]
    assert llm_config.model == "model-planner"
@patch("ea_chatbot.graph.nodes.reflector.get_llm_model")
@patch("ea_chatbot.graph.nodes.reflector.Settings")
def test_reflector_model(mock_settings_cls, mock_get_llm, mock_settings):
    """Check that the reflector node builds its LLM from the reflector config.

    ``Settings`` and ``get_llm_model`` are patched inside the reflector
    module; the state carries one checklist task and a worker summary. The
    test inspects the first positional argument passed to ``get_llm_model``
    and expects its ``model`` to be ``"model-reflector"``.
    """
    mock_settings_cls.return_value = mock_settings
    mock_get_llm.return_value = MagicMock()
    state = AgentState(
        question="test",
        messages=[],
        checklist=[{"task": "T1"}],
        current_step=0,
        iterations=0,
        vfs={},
        plots=[],
        dfs={},
        next_action="",
        analysis={},
        summary="Worker did X",
    )
    try:
        reflector_node(state)
    except Exception:
        # Only the get_llm_model call matters; node failures are ignored.
        # ``Exception`` (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate.
        pass
    # Fail with a clear message instead of a TypeError on ``None[0]`` when the
    # node never requested a model.
    mock_get_llm.assert_called()
    llm_config = mock_get_llm.call_args[0][0]
    assert llm_config.model == "model-reflector"
@patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model")
@patch("ea_chatbot.graph.nodes.synthesizer.Settings")
def test_synthesizer_model(mock_settings_cls, mock_get_llm, mock_settings):
    """Check that the synthesizer node builds its LLM from the synthesizer config.

    ``Settings`` and ``get_llm_model`` are patched inside the synthesizer
    module; the test inspects the first positional argument passed to
    ``get_llm_model`` and expects its ``model`` to be ``"model-synthesizer"``.
    """
    mock_settings_cls.return_value = mock_settings
    mock_get_llm.return_value = MagicMock()
    state = AgentState(
        question="test",
        messages=[],
        checklist=[],
        current_step=0,
        iterations=0,
        vfs={},
        plots=[],
        dfs={},
        next_action="",
        analysis={},
    )
    try:
        synthesizer_node(state)
    except Exception:
        # Only the get_llm_model call matters; node failures are ignored.
        # ``Exception`` (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate.
        pass
    # Fail with a clear message instead of a TypeError on ``None[0]`` when the
    # node never requested a model.
    mock_get_llm.assert_called()
    llm_config = mock_get_llm.call_args[0][0]
    assert llm_config.model == "model-synthesizer"
@patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model")
@patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.Settings")
def test_analyst_coder_model(mock_settings_cls, mock_get_llm, mock_settings):
    """Check that the data-analyst coder node builds its LLM from the coder config.

    ``Settings`` and ``get_llm_model`` are patched inside the data-analyst
    coder module; the test inspects the first positional argument passed to
    ``get_llm_model`` and expects its ``model`` to be ``"model-coder"``.
    """
    mock_settings_cls.return_value = mock_settings
    mock_get_llm.return_value = MagicMock()
    state = AnalystState(
        task="test",
        messages=[],
        iterations=0,
        vfs_state={},
        plots=[],
        result=None,
        code=None,
        output=None,
        error=None,
    )
    try:
        analyst_coder(state)
    except Exception:
        # Only the get_llm_model call matters; node failures are ignored.
        # ``Exception`` (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate.
        pass
    # Fail with a clear message instead of a TypeError on ``None[0]`` when the
    # node never requested a model.
    mock_get_llm.assert_called()
    llm_config = mock_get_llm.call_args[0][0]
    assert llm_config.model == "model-coder"
@patch("ea_chatbot.graph.workers.researcher.nodes.searcher.get_llm_model")
@patch("ea_chatbot.graph.workers.researcher.nodes.searcher.Settings")
def test_researcher_searcher_model(mock_settings_cls, mock_get_llm, mock_settings):
    """Check that the researcher searcher node builds its LLM from the researcher config.

    ``Settings`` and ``get_llm_model`` are patched inside the researcher
    searcher module; the test inspects the first positional argument passed
    to ``get_llm_model`` and expects its ``model`` to be ``"model-researcher"``.
    """
    mock_settings_cls.return_value = mock_settings
    mock_get_llm.return_value = MagicMock()
    state = ResearcherState(
        task="test",
        messages=[],
        iterations=0,
        result=None,
        queries=[],
        raw_results=[],
    )
    try:
        researcher_searcher(state)
    except Exception:
        # Only the get_llm_model call matters; node failures are ignored.
        # ``Exception`` (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate.
        pass
    # Fail with a clear message instead of a TypeError on ``None[0]`` when the
    # node never requested a model.
    mock_get_llm.assert_called()
    llm_config = mock_get_llm.call_args[0][0]
    assert llm_config.model == "model-researcher"