chore(graph): Relocate QueryAnalysis schema and update existing tests for Orchestrator architecture

This commit is contained in:
Yunxiao Xu
2026-02-23 05:58:58 -08:00
parent ad7845cc6a
commit f4d09c07c4
7 changed files with 199 additions and 227 deletions

View File

@@ -1,18 +1,10 @@
from typing import List, Literal from typing import List, Literal
from pydantic import BaseModel, Field
from ea_chatbot.graph.state import AgentState from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.query_analyzer import QUERY_ANALYZER_PROMPT from ea_chatbot.graph.prompts.query_analyzer import QUERY_ANALYZER_PROMPT
from ea_chatbot.schemas import QueryAnalysis
class QueryAnalysis(BaseModel):
"""Analysis of the user's query."""
data_required: List[str] = Field(description="List of data points or entities mentioned (e.g., ['2024 results', 'Florida']).")
unknowns: List[str] = Field(description="List of target information the user wants to know or needed for final answer (e.g., 'who won', 'total votes').")
ambiguities: List[str] = Field(description="List of CRITICAL missing details that prevent ANY analysis. Do NOT include database names or plot types if defaults can be used.")
conditions: List[str] = Field(description="List of any filters or constraints (e.g., ['year=2024', 'state=Florida']). Include context resolved from history.")
next_action: Literal["plan", "clarify", "research"] = Field(description="The next action to take. 'plan' for data analysis (even with defaults), 'research' for general knowledge, or 'clarify' ONLY for critical ambiguities.")
def query_analyzer_node(state: AgentState) -> dict: def query_analyzer_node(state: AgentState) -> dict:
"""Analyze the user's question and determine the next course of action.""" """Analyze the user's question and determine the next course of action."""

View File

@@ -1,7 +1,15 @@
from pydantic import BaseModel, Field, computed_field from pydantic import BaseModel, Field, computed_field
from typing import Sequence, Optional, List, Dict, Any from typing import Sequence, Optional, List, Dict, Any, Literal
import re import re
class QueryAnalysis(BaseModel):
"""Analysis of the user's query."""
data_required: List[str] = Field(description="List of data points or entities mentioned (e.g., ['2024 results', 'Florida']).")
unknowns: List[str] = Field(description="List of target information the user wants to know or needed for final answer (e.g., 'who won', 'total votes').")
ambiguities: List[str] = Field(description="List of CRITICAL missing details that prevent ANY analysis. Do NOT include database names or plot types if defaults can be used.")
conditions: List[str] = Field(description="List of any filters or constraints (e.g., ['year=2024', 'state=Florida']). Include context resolved from history.")
next_action: Literal["plan", "clarify", "research"] = Field(description="The next action to take. 'plan' for data analysis (even with defaults), 'research' for general knowledge, or 'clarify' ONLY for critical ambiguities.")
class TaskPlanContext(BaseModel): class TaskPlanContext(BaseModel):
'''Background context relevant to the task plan''' '''Background context relevant to the task plan'''
initial_context: str = Field( initial_context: str = Field(

View File

@@ -2,7 +2,7 @@ import json
import pytest import pytest
import logging import logging
from unittest.mock import patch from unittest.mock import patch
from ea_chatbot.graph.workflow import app from ea_chatbot.graph.workflow import create_workflow
from ea_chatbot.graph.state import AgentState from ea_chatbot.graph.state import AgentState
from ea_chatbot.utils.logging import get_logger from ea_chatbot.utils.logging import get_logger
from langchain_community.chat_models import FakeListChatModel from langchain_community.chat_models import FakeListChatModel
@@ -43,10 +43,10 @@ def test_logging_e2e_json_output(tmp_path):
"question": "Who won in 2024?", "question": "Who won in 2024?",
"analysis": None, "analysis": None,
"next_action": "", "next_action": "",
"plan": None, "iterations": 0,
"code": None, "checklist": [],
"code_output": None, "current_step": 0,
"error": None, "vfs": {},
"plots": [], "plots": [],
"dfs": {} "dfs": {}
} }
@@ -57,6 +57,20 @@ def test_logging_e2e_json_output(tmp_path):
fake_clarify = FakeListChatModel(responses=["Please specify."]) fake_clarify = FakeListChatModel(responses=["Please specify."])
# Create a test app without interrupts
# We need to manually compile it here to avoid the global 'app' which has interrupts
from langgraph.graph import StateGraph, END
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node
from ea_chatbot.graph.nodes.clarification import clarification_node
workflow = StateGraph(AgentState)
workflow.add_node("query_analyzer", query_analyzer_node)
workflow.add_node("clarification", clarification_node)
workflow.set_entry_point("query_analyzer")
workflow.add_edge("query_analyzer", "clarification")
workflow.add_edge("clarification", END)
test_app = workflow.compile()
with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_llm_factory: with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_llm_factory:
mock_llm_factory.return_value = fake_analyzer mock_llm_factory.return_value = fake_analyzer
@@ -64,7 +78,7 @@ def test_logging_e2e_json_output(tmp_path):
mock_clarify_llm_factory.return_value = fake_clarify mock_clarify_llm_factory.return_value = fake_clarify
# Run the graph # Run the graph
list(app.stream(initial_state)) test_app.invoke(initial_state)
# Verify file content # Verify file content
assert log_file.exists() assert log_file.exists()

View File

@@ -1,24 +1,27 @@
import pytest import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.nodes.planner import planner_node from ea_chatbot.graph.nodes.planner import planner_node
from ea_chatbot.graph.nodes.researcher import researcher_node from ea_chatbot.schemas import ChecklistResponse, ChecklistTask
from ea_chatbot.graph.nodes.summarizer import summarizer_node from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.schemas import TaskPlanResponse
@pytest.fixture @pytest.fixture
def mock_state_with_history(): def mock_state_with_history():
return { return {
"messages": [ "messages": [
HumanMessage(content="Show me the 2024 results for Florida"), HumanMessage(content="What about NJ?"),
AIMessage(content="Here are the results for Florida in 2024...") AIMessage(content="NJ has 9 million voters.")
], ],
"question": "What about in New Jersey?", "question": "Show me the breakdown by county for 2024",
"analysis": {"data_required": ["2024 results", "New Jersey"], "unknowns": [], "ambiguities": [], "conditions": []}, "analysis": {
"data_required": ["2024 results", "New Jersey"],
"unknowns": [],
"ambiguities": [],
"conditions": []
},
"next_action": "plan", "next_action": "plan",
"summary": "The user is asking about 2024 election results.", "summary": "The user is asking about NJ 2024 results.",
"plan": "Plan steps...", "checklist": [],
"code_output": "Code output..." "current_step": 0
} }
@patch("ea_chatbot.graph.nodes.planner.get_llm_model") @patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@@ -31,49 +34,17 @@ def test_planner_uses_history_and_summary(mock_prompt, mock_get_summary, mock_ge
mock_structured_llm = MagicMock() mock_structured_llm = MagicMock()
mock_llm_instance.with_structured_output.return_value = mock_structured_llm mock_llm_instance.with_structured_output.return_value = mock_structured_llm
mock_structured_llm.invoke.return_value = TaskPlanResponse( mock_structured_llm.invoke.return_value = ChecklistResponse(
goal="goal", goal="goal",
reflection="reflection", reflection="reflection",
context={ checklist=[ChecklistTask(task="Step 1: test", worker="data_analyst")]
"initial_context": "context",
"assumptions": [],
"constraints": []
},
steps=["Step 1: test"]
) )
planner_node(mock_state_with_history) planner_node(mock_state_with_history)
mock_prompt.format_messages.assert_called_once() # Verify history and summary were passed to prompt format
kwargs = mock_prompt.format_messages.call_args[1] # We check the arguments passed to the mock_prompt.format_messages
assert kwargs["question"] == "What about in New Jersey?" call_args = mock_prompt.format_messages.call_args[1]
assert kwargs["summary"] == mock_state_with_history["summary"] assert call_args["summary"] == "The user is asking about NJ 2024 results."
assert len(kwargs["history"]) == 2 assert len(call_args["history"]) == 2
assert "breakdown by county" in call_args["question"]
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
@patch("ea_chatbot.graph.nodes.researcher.RESEARCHER_PROMPT")
def test_researcher_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
mock_llm_instance = MagicMock()
mock_get_llm.return_value = mock_llm_instance
researcher_node(mock_state_with_history)
mock_prompt.format_messages.assert_called_once()
kwargs = mock_prompt.format_messages.call_args[1]
assert kwargs["question"] == "What about in New Jersey?"
assert kwargs["summary"] == mock_state_with_history["summary"]
assert len(kwargs["history"]) == 2
@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model")
@patch("ea_chatbot.graph.nodes.summarizer.SUMMARIZER_PROMPT")
def test_summarizer_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
mock_llm_instance = MagicMock()
mock_get_llm.return_value = mock_llm_instance
summarizer_node(mock_state_with_history)
mock_prompt.format_messages.assert_called_once()
kwargs = mock_prompt.format_messages.call_args[1]
assert kwargs["question"] == "What about in New Jersey?"
assert kwargs["summary"] == mock_state_with_history["summary"]
assert len(kwargs["history"]) == 2

View File

@@ -1,6 +1,7 @@
import pytest import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.planner import planner_node from ea_chatbot.graph.nodes.planner import planner_node
from ea_chatbot.schemas import ChecklistResponse, ChecklistTask
@pytest.fixture @pytest.fixture
def mock_state(): def mock_state():
@@ -8,39 +9,40 @@ def mock_state():
"messages": [], "messages": [],
"question": "Show me results for New Jersey", "question": "Show me results for New Jersey",
"analysis": { "analysis": {
# "requires_dataset" removed as it's no longer used
"expert": "Data Analyst", "expert": "Data Analyst",
"data": "NJ data", "data": "NJ data",
"unknown": "results", "unknown": "results",
"condition": "state=NJ" "condition": "state=NJ"
}, },
"next_action": "plan", "next_action": "plan",
"plan": None "checklist": [],
"current_step": 0
} }
@patch("ea_chatbot.graph.nodes.planner.get_llm_model") @patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary") @patch("ea_chatbot.utils.database_inspection.get_data_summary")
def test_planner_node(mock_get_summary, mock_get_llm, mock_state): def test_planner_node(mock_get_summary, mock_get_llm, mock_state):
"""Test planner node with unified prompt.""" """Test planner node with checklist generation."""
mock_get_summary.return_value = "Column: Name, Type: text" mock_get_summary.return_value = "Column: Name, Type: text"
mock_llm = MagicMock() mock_llm = MagicMock()
mock_get_llm.return_value = mock_llm mock_get_llm.return_value = mock_llm
from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext mock_response = ChecklistResponse(
mock_plan = TaskPlanResponse(
goal="Get NJ results", goal="Get NJ results",
reflection="The user wants NJ results", reflection="The user wants NJ results",
context=TaskPlanContext(initial_context="NJ data", assumptions=[], constraints=[]), checklist=[
steps=["Step 1: Load data", "Step 2: Filter by NJ"] ChecklistTask(task="Query NJ data", worker="data_analyst")
]
) )
mock_llm.with_structured_output.return_value.invoke.return_value = mock_plan mock_llm.with_structured_output.return_value.invoke.return_value = mock_response
result = planner_node(mock_state) result = planner_node(mock_state)
assert "plan" in result assert "checklist" in result
assert "Step 1: Load data" in result["plan"] assert result["checklist"][0]["task"] == "Query NJ data"
assert "Step 2: Filter by NJ" in result["plan"] assert result["current_step"] == 0
assert result["summary"] == "The user wants NJ results"
# Verify helper was called # Verify helper was called
mock_get_summary.assert_called_once() mock_get_summary.assert_called_once()

View File

@@ -1,92 +1,80 @@
import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from ea_chatbot.graph.workflow import app from ea_chatbot.graph.workflow import create_workflow
from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis from ea_chatbot.graph.state import AgentState
from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage def test_workflow_full_flow():
"""Test the full Orchestrator-Workers flow using node injection."""
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.graph.nodes.coder.get_llm_model")
@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model")
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
@patch("ea_chatbot.graph.nodes.executor.Settings")
@patch("ea_chatbot.graph.nodes.executor.DBClient")
def test_workflow_full_flow(mock_db_client, mock_settings, mock_get_summary, mock_researcher_llm, mock_summarizer_llm, mock_coder_llm, mock_planner_llm, mock_qa_llm):
"""Test the flow from query_analyzer through planner to coder."""
# Mock Settings for Executor mock_analyzer = MagicMock()
mock_settings_instance = MagicMock() mock_planner = MagicMock()
mock_settings_instance.db_host = "localhost" mock_delegate = MagicMock()
mock_settings_instance.db_port = 5432 mock_worker = MagicMock()
mock_settings_instance.db_user = "user" mock_reflector = MagicMock()
mock_settings_instance.db_pswd = "pass" mock_synthesizer = MagicMock()
mock_settings_instance.db_name = "test_db" mock_summarize_conv = MagicMock()
mock_settings_instance.db_table = "test_table"
mock_settings.return_value = mock_settings_instance
# Mock DBClient # 1. Analyzer: Proceed to planning
mock_client_instance = MagicMock() mock_analyzer.return_value = {"next_action": "plan"}
mock_db_client.return_value = mock_client_instance
# 1. Mock Query Analyzer # 2. Planner: Generate checklist
mock_qa_instance = MagicMock() mock_planner.return_value = {
mock_qa_llm.return_value = mock_qa_instance "checklist": [{"task": "Step 1", "worker": "data_analyst"}],
mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis( "current_step": 0
data_required=["2024 results"], }
unknowns=[],
ambiguities=[], # 3. Delegate: Route to data_analyst
conditions=[], mock_delegate.side_effect = [
next_action="plan" {"next_action": "data_analyst"},
{"next_action": "summarize"}
]
# 4. Worker: Success
mock_worker.return_value = {
"messages": [AIMessage(content="Worker Summary")],
"vfs": {}
}
# 5. Reflector: Advance
mock_reflector.return_value = {
"current_step": 1,
"next_action": "delegate"
}
# 6. Synthesizer: Final answer
mock_synthesizer.return_value = {
"messages": [AIMessage(content="Final Summary")],
"next_action": "end"
}
# 7. Summarize Conv: End
mock_summarize_conv.return_value = {"summary": "Done"}
app = create_workflow(
query_analyzer=mock_analyzer,
planner=mock_planner,
delegate=mock_delegate,
data_analyst_worker=mock_worker,
reflector=mock_reflector,
synthesizer=mock_synthesizer,
summarize_conversation=mock_summarize_conv
) )
# 2. Mock Planner
mock_planner_instance = MagicMock()
mock_planner_llm.return_value = mock_planner_instance
mock_get_summary.return_value = "Data summary"
mock_planner_instance.with_structured_output.return_value.invoke.return_value = TaskPlanResponse(
goal="Task Goal",
reflection="Reflection",
context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]),
steps=["Step 1"]
)
# 3. Mock Coder
mock_coder_instance = MagicMock()
mock_coder_llm.return_value = mock_coder_instance
mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse(
code="print('Hello')",
explanation="Explanation"
)
# 4. Mock Summarizer
mock_summarizer_instance = MagicMock()
mock_summarizer_llm.return_value = mock_summarizer_instance
mock_summarizer_instance.invoke.return_value = AIMessage(content="Summary")
# 5. Mock Researcher (just in case)
mock_researcher_instance = MagicMock()
mock_researcher_llm.return_value = mock_researcher_instance
# Initial state
initial_state = { initial_state = {
"messages": [], "messages": [HumanMessage(content="Show me results")],
"question": "Show me the 2024 results", "question": "Show me results",
"analysis": None, "analysis": None,
"next_action": "", "next_action": "",
"plan": None, "iterations": 0,
"code": None, "checklist": [],
"error": None, "current_step": 0,
"vfs": {},
"plots": [], "plots": [],
"dfs": {} "dfs": {}
} }
# Run the graph result = app.invoke(initial_state, config={"recursion_limit": 20})
# We use recursion_limit to avoid infinite loops in placeholders if any
result = app.invoke(initial_state, config={"recursion_limit": 10})
assert result["next_action"] == "plan" assert "Final Summary" in [m.content for m in result["messages"]]
assert "plan" in result and result["plan"] is not None assert result["current_step"] == 1
assert "code" in result and "print('Hello')" in result["code"]
assert "analysis" in result

View File

@@ -1,60 +1,51 @@
import pytest import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from ea_chatbot.graph.workflow import app from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis from ea_chatbot.schemas import QueryAnalysis, ChecklistResponse, ChecklistTask, CodeGenerationResponse
from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse from ea_chatbot.graph.state import AgentState
from langchain_core.messages import AIMessage
@pytest.fixture @pytest.fixture
def mock_llms(): def mock_llms():
with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_qa_llm, \ with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_qa, \
patch("ea_chatbot.graph.nodes.planner.get_llm_model") as mock_planner_llm, \ patch("ea_chatbot.graph.nodes.planner.get_llm_model") as mock_planner, \
patch("ea_chatbot.graph.nodes.coder.get_llm_model") as mock_coder_llm, \ patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model") as mock_coder, \
patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") as mock_summarizer_llm, \ patch("ea_chatbot.graph.workers.data_analyst.nodes.summarizer.get_llm_model") as mock_worker_summarizer, \
patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_researcher_llm, \ patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model") as mock_synthesizer, \
patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model") as mock_summary_llm, \ patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_researcher:
patch("ea_chatbot.utils.database_inspection.get_data_summary") as mock_get_summary:
mock_get_summary.return_value = "Data summary"
# Mock summary LLM to return a simple response
mock_summary_instance = MagicMock()
mock_summary_llm.return_value = mock_summary_instance
mock_summary_instance.invoke.return_value = AIMessage(content="Turn summary")
yield { yield {
"qa": mock_qa_llm, "qa": mock_qa,
"planner": mock_planner_llm, "planner": mock_planner,
"coder": mock_coder_llm, "coder": mock_coder,
"summarizer": mock_summarizer_llm, "worker_summarizer": mock_worker_summarizer,
"researcher": mock_researcher_llm, "synthesizer": mock_synthesizer,
"summary": mock_summary_llm "researcher": mock_researcher
} }
def test_workflow_data_analysis_flow(mock_llms): def test_workflow_data_analysis_flow(mock_llms):
"""Test full flow: QueryAnalyzer -> Planner -> Coder -> Executor -> Summarizer.""" """Test full flow: QueryAnalyzer -> Planner -> Delegate -> DataAnalyst -> Reflector -> Synthesizer."""
# 1. Mock Query Analyzer (routes to plan) # 1. Mock Query Analyzer
mock_qa_instance = MagicMock() mock_qa_instance = MagicMock()
mock_llms["qa"].return_value = mock_qa_instance mock_llms["qa"].return_value = mock_qa_instance
mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis( mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
data_required=["2024 results"], data_required=["2024 results"],
unknowns=[], unknowns=[],
ambiguities=[], ambiguities=[],
conditions=[], conditions=[],
next_action="plan" next_action="plan"
) )
# 2. Mock Planner # 2. Mock Planner
mock_planner_instance = MagicMock() mock_planner_instance = MagicMock()
mock_llms["planner"].return_value = mock_planner_instance mock_llms["planner"].return_value = mock_planner_instance
mock_planner_instance.with_structured_output.return_value.invoke.return_value = TaskPlanResponse( mock_planner_instance.with_structured_output.return_value.invoke.return_value = ChecklistResponse(
goal="Get results", goal="Get results",
reflection="Reflect", reflection="Reflect",
context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]), checklist=[ChecklistTask(task="Query Data", worker="data_analyst")]
steps=["Step 1"]
) )
# 3. Mock Coder # 3. Mock Coder (Worker)
mock_coder_instance = MagicMock() mock_coder_instance = MagicMock()
mock_llms["coder"].return_value = mock_coder_instance mock_llms["coder"].return_value = mock_coder_instance
mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse( mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse(
@@ -62,10 +53,15 @@ def test_workflow_data_analysis_flow(mock_llms):
explanation="Explain" explanation="Explain"
) )
# 4. Mock Summarizer # 4. Mock Worker Summarizer
mock_summarizer_instance = MagicMock() mock_ws_instance = MagicMock()
mock_llms["summarizer"].return_value = mock_summarizer_instance mock_llms["worker_summarizer"].return_value = mock_ws_instance
mock_summarizer_instance.invoke.return_value = AIMessage(content="Final Summary: Success") mock_ws_instance.invoke.return_value = AIMessage(content="Worker Summary")
# 5. Mock Synthesizer
mock_syn_instance = MagicMock()
mock_llms["synthesizer"].return_value = mock_syn_instance
mock_syn_instance.invoke.return_value = AIMessage(content="Final Summary: Success")
# Initial state # Initial state
initial_state = { initial_state = {
@@ -73,66 +69,67 @@ def test_workflow_data_analysis_flow(mock_llms):
"question": "Show me 2024 results", "question": "Show me 2024 results",
"analysis": None, "analysis": None,
"next_action": "", "next_action": "",
"plan": None, "iterations": 0,
"code": None, "checklist": [],
"error": None, "current_step": 0,
"vfs": {},
"plots": [], "plots": [],
"dfs": {} "dfs": {}
} }
# Run the graph # Run the graph
result = app.invoke(initial_state, config={"recursion_limit": 15}) result = app.invoke(initial_state, config={"recursion_limit": 20})
assert result["next_action"] == "plan" assert "Final Summary: Success" in [m.content for m in result["messages"]]
assert "Execution Success" in result["code_output"] assert result["current_step"] == 1
assert "Final Summary: Success" in result["messages"][-1].content
def test_workflow_research_flow(mock_llms): def test_workflow_research_flow(mock_llms):
"""Test flow: QueryAnalyzer -> Researcher -> Summarizer.""" """Test flow with research task."""
# 1. Mock Query Analyzer (routes to research) # 1. Mock Query Analyzer
mock_qa_instance = MagicMock() mock_qa_instance = MagicMock()
mock_llms["qa"].return_value = mock_qa_instance mock_llms["qa"].return_value = mock_qa_instance
mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis( mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
data_required=[], data_required=[],
unknowns=[], unknowns=[],
ambiguities=[], ambiguities=[],
conditions=[], conditions=[],
next_action="research" next_action="research"
) )
# 2. Mock Researcher # 2. Mock Planner
mock_researcher_instance = MagicMock() mock_planner_instance = MagicMock()
mock_llms["researcher"].return_value = mock_researcher_instance mock_llms["planner"].return_value = mock_planner_instance
# Researcher node uses bind_tools if it's ChatOpenAI/ChatGoogleGenerativeAI mock_planner_instance.with_structured_output.return_value.invoke.return_value = ChecklistResponse(
# Since it's a MagicMock, it will fallback to using the base instance goal="Search",
mock_researcher_instance.invoke.return_value = AIMessage(content="Research Results") reflection="Reflect",
checklist=[ChecklistTask(task="Search Web", worker="researcher")]
)
# Also mock bind_tools just in case we ever use spec # 3. Mock Researcher
mock_llm_with_tools = MagicMock() mock_res_instance = MagicMock()
mock_researcher_instance.bind_tools.return_value = mock_llm_with_tools mock_llms["researcher"].return_value = mock_res_instance
mock_llm_with_tools.invoke.return_value = AIMessage(content="Research Results") mock_res_instance.invoke.return_value = AIMessage(content="Research Result")
# 3. Mock Summarizer (not used in this flow, but kept for completeness) # 4. Mock Synthesizer
mock_summarizer_instance = MagicMock() mock_syn_instance = MagicMock()
mock_llms["summarizer"].return_value = mock_summarizer_instance mock_llms["synthesizer"].return_value = mock_syn_instance
mock_summarizer_instance.invoke.return_value = AIMessage(content="Final Summary: Research Success") mock_syn_instance.invoke.return_value = AIMessage(content="Final Research Summary")
# Initial state
initial_state = { initial_state = {
"messages": [], "messages": [],
"question": "Who is the governor of Florida?", "question": "Who is the governor?",
"analysis": None, "analysis": None,
"next_action": "", "next_action": "",
"plan": None, "iterations": 0,
"code": None, "checklist": [],
"error": None, "current_step": 0,
"vfs": {},
"plots": [], "plots": [],
"dfs": {} "dfs": {}
} }
# Run the graph result = app.invoke(initial_state, config={"recursion_limit": 20})
result = app.invoke(initial_state, config={"recursion_limit": 10})
assert result["next_action"] == "research" assert "Final Research Summary" in [m.content for m in result["messages"]]
assert "Research Results" in result["messages"][-1].content assert result["current_step"] == 1