127 lines
4.9 KiB
Python
127 lines
4.9 KiB
Python
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node
|
|
from ea_chatbot.graph.nodes.planner import planner_node
|
|
from ea_chatbot.graph.nodes.reflector import reflector_node
|
|
from ea_chatbot.graph.nodes.synthesizer import synthesizer_node
|
|
from ea_chatbot.graph.workers.data_analyst.nodes.coder import coder_node as analyst_coder
|
|
from ea_chatbot.graph.workers.researcher.nodes.searcher import searcher_node as researcher_searcher
|
|
from ea_chatbot.graph.state import AgentState
|
|
from ea_chatbot.graph.workers.data_analyst.state import WorkerState as AnalystState
|
|
from ea_chatbot.graph.workers.researcher.state import WorkerState as ResearcherState
|
|
from ea_chatbot.config import Settings
|
|
|
|
@pytest.fixture
def mock_settings():
    """Return a real ``Settings`` whose per-node LLM configs use distinct model names.

    The node tests below make their module's ``Settings`` return this object,
    then check the ``model`` of the config handed to ``get_llm_model``.
    """
    cfg = Settings()
    model_by_section = {
        "query_analyzer_llm": "model-qa",
        "planner_llm": "model-planner",
        "reflector_llm": "model-reflector",
        "synthesizer_llm": "model-synthesizer",
        "coder_llm": "model-coder",
        "researcher_llm": "model-researcher",
    }
    # Dict insertion order keeps the assignments in the original sequence.
    for section, model_name in model_by_section.items():
        getattr(cfg, section).model = model_name
    return cfg
|
|
|
|
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.query_analyzer.Settings")
def test_query_analyzer_model(mock_settings_cls, mock_get_llm, mock_settings):
    """Check the query analyzer asks ``get_llm_model`` for the ``model-qa`` config."""
    mock_settings_cls.return_value = mock_settings
    mock_get_llm.return_value = MagicMock()

    state = AgentState(
        question="test",
        messages=[],
        checklist=[],
        current_step=0,
        iterations=0,
        vfs={},
        plots=[],
        dfs={},
        next_action="",
        analysis={},
    )

    # Only the get_llm_model call matters, so any failure later in the node
    # (e.g. from the MagicMock LLM) is ignored. Catch Exception rather than
    # using a bare except so KeyboardInterrupt/SystemExit and pytest's own
    # outcome exceptions still propagate.
    try:
        query_analyzer_node(state)
    except Exception:
        pass

    # Fail with a clear message if the LLM was never requested, instead of
    # a TypeError from subscripting a None call_args.
    mock_get_llm.assert_called()
    config = mock_get_llm.call_args[0][0]
    assert config.model == "model-qa"
|
|
|
|
@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.graph.nodes.planner.Settings")
def test_planner_model(mock_settings_cls, mock_get_llm, mock_settings):
    """Check the planner asks ``get_llm_model`` for the ``model-planner`` config."""
    mock_settings_cls.return_value = mock_settings
    mock_get_llm.return_value = MagicMock()

    state = AgentState(
        question="test",
        messages=[],
        checklist=[],
        current_step=0,
        iterations=0,
        vfs={},
        plots=[],
        dfs={},
        next_action="",
        analysis={},
    )

    # Only the get_llm_model call matters, so any failure later in the node
    # is ignored. Catch Exception rather than using a bare except so
    # KeyboardInterrupt/SystemExit and pytest's own outcome exceptions
    # still propagate.
    try:
        planner_node(state)
    except Exception:
        pass

    # Fail with a clear message if the LLM was never requested, instead of
    # a TypeError from subscripting a None call_args.
    mock_get_llm.assert_called()
    config = mock_get_llm.call_args[0][0]
    assert config.model == "model-planner"
|
|
|
|
@patch("ea_chatbot.graph.nodes.reflector.get_llm_model")
@patch("ea_chatbot.graph.nodes.reflector.Settings")
def test_reflector_model(mock_settings_cls, mock_get_llm, mock_settings):
    """Check the reflector asks ``get_llm_model`` for the ``model-reflector`` config."""
    mock_settings_cls.return_value = mock_settings
    mock_get_llm.return_value = MagicMock()

    # The reflector gets a one-task checklist and a worker summary to reflect on.
    state = AgentState(
        question="test",
        messages=[],
        checklist=[{"task": "T1"}],
        current_step=0,
        iterations=0,
        vfs={},
        plots=[],
        dfs={},
        next_action="",
        analysis={},
        summary="Worker did X",
    )

    # Only the get_llm_model call matters, so any failure later in the node
    # is ignored. Catch Exception rather than using a bare except so
    # KeyboardInterrupt/SystemExit and pytest's own outcome exceptions
    # still propagate.
    try:
        reflector_node(state)
    except Exception:
        pass

    # Fail with a clear message if the LLM was never requested, instead of
    # a TypeError from subscripting a None call_args.
    mock_get_llm.assert_called()
    config = mock_get_llm.call_args[0][0]
    assert config.model == "model-reflector"
|
|
|
|
@patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model")
@patch("ea_chatbot.graph.nodes.synthesizer.Settings")
def test_synthesizer_model(mock_settings_cls, mock_get_llm, mock_settings):
    """Check the synthesizer asks ``get_llm_model`` for the ``model-synthesizer`` config."""
    mock_settings_cls.return_value = mock_settings
    mock_get_llm.return_value = MagicMock()

    state = AgentState(
        question="test",
        messages=[],
        checklist=[],
        current_step=0,
        iterations=0,
        vfs={},
        plots=[],
        dfs={},
        next_action="",
        analysis={},
    )

    # Only the get_llm_model call matters, so any failure later in the node
    # is ignored. Catch Exception rather than using a bare except so
    # KeyboardInterrupt/SystemExit and pytest's own outcome exceptions
    # still propagate.
    try:
        synthesizer_node(state)
    except Exception:
        pass

    # Fail with a clear message if the LLM was never requested, instead of
    # a TypeError from subscripting a None call_args.
    mock_get_llm.assert_called()
    config = mock_get_llm.call_args[0][0]
    assert config.model == "model-synthesizer"
|
|
|
|
@patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model")
@patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.Settings")
def test_analyst_coder_model(mock_settings_cls, mock_get_llm, mock_settings):
    """Check the data-analyst coder asks ``get_llm_model`` for the ``model-coder`` config."""
    mock_settings_cls.return_value = mock_settings
    mock_get_llm.return_value = MagicMock()

    state = AnalystState(
        task="test",
        messages=[],
        iterations=0,
        vfs_state={},
        plots=[],
        result=None,
        code=None,
        output=None,
        error=None,
    )

    # Only the get_llm_model call matters, so any failure later in the node
    # is ignored. Catch Exception rather than using a bare except so
    # KeyboardInterrupt/SystemExit and pytest's own outcome exceptions
    # still propagate.
    try:
        analyst_coder(state)
    except Exception:
        pass

    # Fail with a clear message if the LLM was never requested, instead of
    # a TypeError from subscripting a None call_args.
    mock_get_llm.assert_called()
    config = mock_get_llm.call_args[0][0]
    assert config.model == "model-coder"
|
|
|
|
@patch("ea_chatbot.graph.workers.researcher.nodes.searcher.get_llm_model")
@patch("ea_chatbot.graph.workers.researcher.nodes.searcher.Settings")
def test_researcher_searcher_model(mock_settings_cls, mock_get_llm, mock_settings):
    """Check the researcher searcher asks ``get_llm_model`` for the ``model-researcher`` config."""
    mock_settings_cls.return_value = mock_settings
    mock_get_llm.return_value = MagicMock()

    state = ResearcherState(
        task="test",
        messages=[],
        iterations=0,
        result=None,
        queries=[],
        raw_results=[],
    )

    # Only the get_llm_model call matters, so any failure later in the node
    # is ignored. Catch Exception rather than using a bare except so
    # KeyboardInterrupt/SystemExit and pytest's own outcome exceptions
    # still propagate.
    try:
        researcher_searcher(state)
    except Exception:
        pass

    # Fail with a clear message if the LLM was never requested, instead of
    # a TypeError from subscripting a None call_args.
    mock_get_llm.assert_called()
    config = mock_get_llm.call_args[0][0]
    assert config.model == "model-researcher"
|