chore(graph): Relocate QueryAnalysis schema and update existing tests for Orchestrator architecture

This commit is contained in:
Yunxiao Xu
2026-02-23 05:58:58 -08:00
parent ad7845cc6a
commit f4d09c07c4
7 changed files with 199 additions and 227 deletions

View File

@@ -1,92 +1,80 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis
from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse
from ea_chatbot.graph.workflow import create_workflow
from ea_chatbot.graph.state import AgentState
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.graph.nodes.coder.get_llm_model")
@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model")
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
@patch("ea_chatbot.graph.nodes.executor.Settings")
@patch("ea_chatbot.graph.nodes.executor.DBClient")
def test_workflow_full_flow(mock_db_client, mock_settings, mock_get_summary, mock_researcher_llm, mock_summarizer_llm, mock_coder_llm, mock_planner_llm, mock_qa_llm):
"""Test the flow from query_analyzer through planner to coder."""
def test_workflow_full_flow():
"""Test the full Orchestrator-Workers flow using node injection."""
# Mock Settings for Executor
mock_settings_instance = MagicMock()
mock_settings_instance.db_host = "localhost"
mock_settings_instance.db_port = 5432
mock_settings_instance.db_user = "user"
mock_settings_instance.db_pswd = "pass"
mock_settings_instance.db_name = "test_db"
mock_settings_instance.db_table = "test_table"
mock_settings.return_value = mock_settings_instance
mock_analyzer = MagicMock()
mock_planner = MagicMock()
mock_delegate = MagicMock()
mock_worker = MagicMock()
mock_reflector = MagicMock()
mock_synthesizer = MagicMock()
mock_summarize_conv = MagicMock()
# Mock DBClient
mock_client_instance = MagicMock()
mock_db_client.return_value = mock_client_instance
# 1. Analyzer: Proceed to planning
mock_analyzer.return_value = {"next_action": "plan"}
# 1. Mock Query Analyzer
mock_qa_instance = MagicMock()
mock_qa_llm.return_value = mock_qa_instance
mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
data_required=["2024 results"],
unknowns=[],
ambiguities=[],
conditions=[],
next_action="plan"
# 2. Planner: Generate checklist
mock_planner.return_value = {
"checklist": [{"task": "Step 1", "worker": "data_analyst"}],
"current_step": 0
}
# 3. Delegate: Route to data_analyst
mock_delegate.side_effect = [
{"next_action": "data_analyst"},
{"next_action": "summarize"}
]
# 4. Worker: Success
mock_worker.return_value = {
"messages": [AIMessage(content="Worker Summary")],
"vfs": {}
}
# 5. Reflector: Advance
mock_reflector.return_value = {
"current_step": 1,
"next_action": "delegate"
}
# 6. Synthesizer: Final answer
mock_synthesizer.return_value = {
"messages": [AIMessage(content="Final Summary")],
"next_action": "end"
}
# 7. Summarize Conv: End
mock_summarize_conv.return_value = {"summary": "Done"}
app = create_workflow(
query_analyzer=mock_analyzer,
planner=mock_planner,
delegate=mock_delegate,
data_analyst_worker=mock_worker,
reflector=mock_reflector,
synthesizer=mock_synthesizer,
summarize_conversation=mock_summarize_conv
)
# 2. Mock Planner
mock_planner_instance = MagicMock()
mock_planner_llm.return_value = mock_planner_instance
mock_get_summary.return_value = "Data summary"
mock_planner_instance.with_structured_output.return_value.invoke.return_value = TaskPlanResponse(
goal="Task Goal",
reflection="Reflection",
context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]),
steps=["Step 1"]
)
# 3. Mock Coder
mock_coder_instance = MagicMock()
mock_coder_llm.return_value = mock_coder_instance
mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse(
code="print('Hello')",
explanation="Explanation"
)
# 4. Mock Summarizer
mock_summarizer_instance = MagicMock()
mock_summarizer_llm.return_value = mock_summarizer_instance
mock_summarizer_instance.invoke.return_value = AIMessage(content="Summary")
# 5. Mock Researcher (just in case)
mock_researcher_instance = MagicMock()
mock_researcher_llm.return_value = mock_researcher_instance
# Initial state
initial_state = {
"messages": [],
"question": "Show me the 2024 results",
"messages": [HumanMessage(content="Show me results")],
"question": "Show me results",
"analysis": None,
"next_action": "",
"plan": None,
"code": None,
"error": None,
"iterations": 0,
"checklist": [],
"current_step": 0,
"vfs": {},
"plots": [],
"dfs": {}
}
# Run the graph
# We use recursion_limit to avoid infinite loops in placeholders if any
result = app.invoke(initial_state, config={"recursion_limit": 10})
result = app.invoke(initial_state, config={"recursion_limit": 20})
assert result["next_action"] == "plan"
assert "plan" in result and result["plan"] is not None
assert "code" in result and "print('Hello')" in result["code"]
assert "analysis" in result
assert "Final Summary" in [m.content for m in result["messages"]]
assert result["current_step"] == 1