136 lines
4.9 KiB
Python
136 lines
4.9 KiB
Python
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
from ea_chatbot.graph.workflow import app
|
|
from ea_chatbot.schemas import QueryAnalysis, ChecklistResponse, ChecklistTask, CodeGenerationResponse
|
|
from ea_chatbot.graph.state import AgentState
|
|
from langchain_core.messages import AIMessage
|
|
|
|
@pytest.fixture
def mock_llms():
    """Swap every graph node's ``get_llm_model`` factory for a MagicMock.

    Yields:
        A dict mapping a short role name ("qa", "planner", "coder",
        "worker_summarizer", "synthesizer", "researcher") to the MagicMock
        patched over that module's ``get_llm_model``. Tests configure each
        mock's ``return_value`` to control what the corresponding node gets.

    All patches are reverted when the consuming test finishes.
    """
    factory_paths = {
        "qa": "ea_chatbot.graph.nodes.query_analyzer.get_llm_model",
        "planner": "ea_chatbot.graph.nodes.planner.get_llm_model",
        "coder": "ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model",
        "worker_summarizer": "ea_chatbot.graph.workers.data_analyst.nodes.summarizer.get_llm_model",
        "synthesizer": "ea_chatbot.graph.nodes.synthesizer.get_llm_model",
        "researcher": "ea_chatbot.graph.nodes.researcher.get_llm_model",
    }
    # Only patchers that started successfully are tracked, so a failure part
    # way through still rolls back the ones already applied.
    active_patchers = []
    try:
        patched = {}
        for role, dotted_path in factory_paths.items():
            patcher = patch(dotted_path)
            patched[role] = patcher.start()
            active_patchers.append(patcher)
        yield patched
    finally:
        # Undo in reverse order, mirroring how nested ``with`` blocks unwind.
        for patcher in reversed(active_patchers):
            patcher.stop()
|
|
|
|
def test_workflow_data_analysis_flow(mock_llms):
    """Test full flow: QueryAnalyzer -> Planner -> Delegate -> DataAnalyst -> Reflector -> Synthesizer.

    Every LLM factory patched by ``mock_llms`` is given a canned response:
    the query analyzer asks for a plan, the planner emits a single
    ``data_analyst`` task, the coder returns a trivial script, and the
    worker summarizer and synthesizer return fixed messages.
    """
    # 1. Mock Query Analyzer: classify the question as needing a plan.
    mock_qa_instance = MagicMock()
    mock_llms["qa"].return_value = mock_qa_instance
    mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
        data_required=["2024 results"],
        unknowns=[],
        ambiguities=[],
        conditions=[],
        next_action="plan",
    )

    # 2. Mock Planner: a one-item checklist routed to the data analyst worker.
    mock_planner_instance = MagicMock()
    mock_llms["planner"].return_value = mock_planner_instance
    mock_planner_instance.with_structured_output.return_value.invoke.return_value = ChecklistResponse(
        goal="Get results",
        reflection="Reflect",
        checklist=[ChecklistTask(task="Query Data", worker="data_analyst")],
    )

    # 3. Mock Coder (Worker): structured code-generation response.
    mock_coder_instance = MagicMock()
    mock_llms["coder"].return_value = mock_coder_instance
    mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse(
        code="print('Execution Success')",
        explanation="Explain",
    )

    # 4. Mock Worker Summarizer: plain-message summary of the worker's output.
    mock_ws_instance = MagicMock()
    mock_llms["worker_summarizer"].return_value = mock_ws_instance
    mock_ws_instance.invoke.return_value = AIMessage(content="Worker Summary")

    # 5. Mock Synthesizer: the final answer the test looks for.
    mock_syn_instance = MagicMock()
    mock_llms["synthesizer"].return_value = mock_syn_instance
    mock_syn_instance.invoke.return_value = AIMessage(content="Final Summary: Success")

    # Initial state
    initial_state = {
        "messages": [],
        "question": "Show me 2024 results",
        "analysis": None,
        "next_action": "",
        "iterations": 0,
        "checklist": [],
        "current_step": 0,
        "vfs": {},
        "plots": [],
        "dfs": {},
    }

    # Run the graph. recursion_limit caps the number of graph steps so a
    # routing loop fails the test instead of spinning.
    result = app.invoke(initial_state, config={"recursion_limit": 20})

    assert "Final Summary: Success" in [m.content for m in result["messages"]]
    # The plan has a single task; current_step is expected to advance past it.
    assert result["current_step"] == 1
    # Without this, a graph that jumped straight to the synthesizer would
    # still pass: confirm the data analyst's coder actually requested its LLM.
    mock_llms["coder"].assert_called()
|
|
|
|
def test_workflow_research_flow(mock_llms):
    """Test flow with research task.

    The query analyzer's response sets ``next_action="research"``. The
    planner emits a single ``researcher`` task, the researcher returns a
    fixed message, and the synthesizer returns the final answer checked
    below.
    """
    # 1. Mock Query Analyzer: structured analysis pointing at research.
    mock_qa_instance = MagicMock()
    mock_llms["qa"].return_value = mock_qa_instance
    mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
        data_required=[],
        unknowns=[],
        ambiguities=[],
        conditions=[],
        next_action="research",
    )

    # 2. Mock Planner: a one-item checklist routed to the researcher worker.
    mock_planner_instance = MagicMock()
    mock_llms["planner"].return_value = mock_planner_instance
    mock_planner_instance.with_structured_output.return_value.invoke.return_value = ChecklistResponse(
        goal="Search",
        reflection="Reflect",
        checklist=[ChecklistTask(task="Search Web", worker="researcher")],
    )

    # 3. Mock Researcher: plain-message research result.
    mock_res_instance = MagicMock()
    mock_llms["researcher"].return_value = mock_res_instance
    mock_res_instance.invoke.return_value = AIMessage(content="Research Result")

    # 4. Mock Synthesizer: the final answer the test looks for.
    mock_syn_instance = MagicMock()
    mock_llms["synthesizer"].return_value = mock_syn_instance
    mock_syn_instance.invoke.return_value = AIMessage(content="Final Research Summary")

    initial_state = {
        "messages": [],
        "question": "Who is the governor?",
        "analysis": None,
        "next_action": "",
        "iterations": 0,
        "checklist": [],
        "current_step": 0,
        "vfs": {},
        "plots": [],
        "dfs": {},
    }

    # recursion_limit caps the number of graph steps so a routing loop
    # fails the test instead of spinning.
    result = app.invoke(initial_state, config={"recursion_limit": 20})

    assert "Final Research Summary" in [m.content for m in result["messages"]]
    # The plan has a single task; current_step is expected to advance past it.
    assert result["current_step"] == 1
    # Confirm the planner's task was really delegated to the researcher node
    # rather than the graph reaching the synthesizer by another route.
    mock_llms["researcher"].assert_called()
|