From f4d09c07c46034589b5c6ccb6eee62021dd9797d Mon Sep 17 00:00:00 2001 From: Yunxiao Xu Date: Mon, 23 Feb 2026 05:58:58 -0800 Subject: [PATCH] chore(graph): Relocate QueryAnalysis schema and update existing tests for Orchestrator architecture --- .../ea_chatbot/graph/nodes/query_analyzer.py | 10 +- backend/src/ea_chatbot/schemas.py | 10 +- backend/tests/test_logging_e2e.py | 26 +++- .../test_multi_turn_planner_researcher.py | 73 +++------ backend/tests/test_planner.py | 26 ++-- backend/tests/test_workflow.py | 142 ++++++++---------- backend/tests/test_workflow_e2e.py | 139 +++++++++-------- 7 files changed, 199 insertions(+), 227 deletions(-) diff --git a/backend/src/ea_chatbot/graph/nodes/query_analyzer.py b/backend/src/ea_chatbot/graph/nodes/query_analyzer.py index 349bcc5..22f58c5 100644 --- a/backend/src/ea_chatbot/graph/nodes/query_analyzer.py +++ b/backend/src/ea_chatbot/graph/nodes/query_analyzer.py @@ -1,18 +1,10 @@ from typing import List, Literal -from pydantic import BaseModel, Field from ea_chatbot.graph.state import AgentState from ea_chatbot.config import Settings from ea_chatbot.utils.llm_factory import get_llm_model from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler from ea_chatbot.graph.prompts.query_analyzer import QUERY_ANALYZER_PROMPT - -class QueryAnalysis(BaseModel): - """Analysis of the user's query.""" - data_required: List[str] = Field(description="List of data points or entities mentioned (e.g., ['2024 results', 'Florida']).") - unknowns: List[str] = Field(description="List of target information the user wants to know or needed for final answer (e.g., 'who won', 'total votes').") - ambiguities: List[str] = Field(description="List of CRITICAL missing details that prevent ANY analysis. Do NOT include database names or plot types if defaults can be used.") - conditions: List[str] = Field(description="List of any filters or constraints (e.g., ['year=2024', 'state=Florida']). Include context resolved from history.") - next_action: Literal["plan", "clarify", "research"] = Field(description="The next action to take. 'plan' for data analysis (even with defaults), 'research' for general knowledge, or 'clarify' ONLY for critical ambiguities.") +from ea_chatbot.schemas import QueryAnalysis def query_analyzer_node(state: AgentState) -> dict: """Analyze the user's question and determine the next course of action.""" diff --git a/backend/src/ea_chatbot/schemas.py b/backend/src/ea_chatbot/schemas.py index 0f0fc24..19adba8 100644 --- a/backend/src/ea_chatbot/schemas.py +++ b/backend/src/ea_chatbot/schemas.py @@ -1,7 +1,15 @@ from pydantic import BaseModel, Field, computed_field -from typing import Sequence, Optional, List, Dict, Any +from typing import Sequence, Optional, List, Dict, Any, Literal import re +class QueryAnalysis(BaseModel): + """Analysis of the user's query.""" + data_required: List[str] = Field(description="List of data points or entities mentioned (e.g., ['2024 results', 'Florida']).") + unknowns: List[str] = Field(description="List of target information the user wants to know or needed for final answer (e.g., 'who won', 'total votes').") + ambiguities: List[str] = Field(description="List of CRITICAL missing details that prevent ANY analysis. Do NOT include database names or plot types if defaults can be used.") + conditions: List[str] = Field(description="List of any filters or constraints (e.g., ['year=2024', 'state=Florida']). Include context resolved from history.") + next_action: Literal["plan", "clarify", "research"] = Field(description="The next action to take. 'plan' for data analysis (even with defaults), 'research' for general knowledge, or 'clarify' ONLY for critical ambiguities.") + class TaskPlanContext(BaseModel): '''Background context relevant to the task plan''' initial_context: str = Field( diff --git a/backend/tests/test_logging_e2e.py b/backend/tests/test_logging_e2e.py index 028eb34..e17333b 100644 --- a/backend/tests/test_logging_e2e.py +++ b/backend/tests/test_logging_e2e.py @@ -2,7 +2,7 @@ import json import pytest import logging from unittest.mock import patch -from ea_chatbot.graph.workflow import app +from ea_chatbot.graph.workflow import create_workflow from ea_chatbot.graph.state import AgentState from ea_chatbot.utils.logging import get_logger from langchain_community.chat_models import FakeListChatModel @@ -43,10 +43,10 @@ def test_logging_e2e_json_output(tmp_path): "question": "Who won in 2024?", "analysis": None, "next_action": "", - "plan": None, - "code": None, - "code_output": None, - "error": None, + "iterations": 0, + "checklist": [], + "current_step": 0, + "vfs": {}, "plots": [], "dfs": {} } @@ -57,6 +57,20 @@ def test_logging_e2e_json_output(tmp_path): fake_clarify = FakeListChatModel(responses=["Please specify."]) + # Create a test app without interrupts + # We need to manually compile it here to avoid the global 'app' which has interrupts + from langgraph.graph import StateGraph, END + from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node + from ea_chatbot.graph.nodes.clarification import clarification_node + + workflow = StateGraph(AgentState) + workflow.add_node("query_analyzer", query_analyzer_node) + workflow.add_node("clarification", clarification_node) + workflow.set_entry_point("query_analyzer") + workflow.add_edge("query_analyzer", "clarification") + workflow.add_edge("clarification", END) + test_app = workflow.compile() + with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_llm_factory: mock_llm_factory.return_value = fake_analyzer @@ -64,7 +78,7 @@ def test_logging_e2e_json_output(tmp_path): mock_clarify_llm_factory.return_value = fake_clarify # Run the graph - list(app.stream(initial_state)) + test_app.invoke(initial_state) # Verify file content assert log_file.exists() diff --git a/backend/tests/test_multi_turn_planner_researcher.py b/backend/tests/test_multi_turn_planner_researcher.py index af85e5e..9cf0cc9 100644 --- a/backend/tests/test_multi_turn_planner_researcher.py +++ b/backend/tests/test_multi_turn_planner_researcher.py @@ -1,24 +1,27 @@ import pytest from unittest.mock import MagicMock, patch -from langchain_core.messages import HumanMessage, AIMessage from ea_chatbot.graph.nodes.planner import planner_node -from ea_chatbot.graph.nodes.researcher import researcher_node -from ea_chatbot.graph.nodes.summarizer import summarizer_node -from ea_chatbot.schemas import TaskPlanResponse +from ea_chatbot.schemas import ChecklistResponse, ChecklistTask +from langchain_core.messages import HumanMessage, AIMessage @pytest.fixture def mock_state_with_history(): return { "messages": [ - HumanMessage(content="Show me the 2024 results for Florida"), - AIMessage(content="Here are the results for Florida in 2024...") + HumanMessage(content="What about NJ?"), + AIMessage(content="NJ has 9 million voters.") ], - "question": "What about in New Jersey?", - "analysis": {"data_required": ["2024 results", "New Jersey"], "unknowns": [], "ambiguities": [], "conditions": []}, + "question": "Show me the breakdown by county for 2024", + "analysis": { + "data_required": ["2024 results", "New Jersey"], + "unknowns": [], + "ambiguities": [], + "conditions": [] + }, "next_action": "plan", - "summary": "The user is asking about 2024 election results.", - "plan": "Plan steps...", - "code_output": "Code output..." + "summary": "The user is asking about NJ 2024 results.", + "checklist": [], + "current_step": 0 } @patch("ea_chatbot.graph.nodes.planner.get_llm_model") @@ -31,49 +34,17 @@ def test_planner_uses_history_and_summary(mock_prompt, mock_get_summary, mock_ge mock_structured_llm = MagicMock() mock_llm_instance.with_structured_output.return_value = mock_structured_llm - mock_structured_llm.invoke.return_value = TaskPlanResponse( + mock_structured_llm.invoke.return_value = ChecklistResponse( goal="goal", reflection="reflection", - context={ - "initial_context": "context", - "assumptions": [], - "constraints": [] - }, - steps=["Step 1: test"] + checklist=[ChecklistTask(task="Step 1: test", worker="data_analyst")] ) planner_node(mock_state_with_history) - mock_prompt.format_messages.assert_called_once() - kwargs = mock_prompt.format_messages.call_args[1] - assert kwargs["question"] == "What about in New Jersey?" - assert kwargs["summary"] == mock_state_with_history["summary"] - assert len(kwargs["history"]) == 2 - -@patch("ea_chatbot.graph.nodes.researcher.get_llm_model") -@patch("ea_chatbot.graph.nodes.researcher.RESEARCHER_PROMPT") -def test_researcher_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history): - mock_llm_instance = MagicMock() - mock_get_llm.return_value = mock_llm_instance - - researcher_node(mock_state_with_history) - - mock_prompt.format_messages.assert_called_once() - kwargs = mock_prompt.format_messages.call_args[1] - assert kwargs["question"] == "What about in New Jersey?" - assert kwargs["summary"] == mock_state_with_history["summary"] - assert len(kwargs["history"]) == 2 - -@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") -@patch("ea_chatbot.graph.nodes.summarizer.SUMMARIZER_PROMPT") -def test_summarizer_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history): - mock_llm_instance = MagicMock() - mock_get_llm.return_value = mock_llm_instance - - summarizer_node(mock_state_with_history) - - mock_prompt.format_messages.assert_called_once() - kwargs = mock_prompt.format_messages.call_args[1] - assert kwargs["question"] == "What about in New Jersey?" - assert kwargs["summary"] == mock_state_with_history["summary"] - assert len(kwargs["history"]) == 2 + # Verify history and summary were passed to prompt format + # We check the arguments passed to the mock_prompt.format_messages + call_args = mock_prompt.format_messages.call_args[1] + assert call_args["summary"] == "The user is asking about NJ 2024 results." + assert len(call_args["history"]) == 2 + assert "breakdown by county" in call_args["question"] diff --git a/backend/tests/test_planner.py b/backend/tests/test_planner.py index 472fd81..c72f37a 100644 --- a/backend/tests/test_planner.py +++ b/backend/tests/test_planner.py @@ -1,6 +1,7 @@ import pytest from unittest.mock import MagicMock, patch from ea_chatbot.graph.nodes.planner import planner_node +from ea_chatbot.schemas import ChecklistResponse, ChecklistTask @pytest.fixture def mock_state(): @@ -8,39 +9,40 @@ def mock_state(): "messages": [], "question": "Show me results for New Jersey", "analysis": { - # "requires_dataset" removed as it's no longer used "expert": "Data Analyst", "data": "NJ data", "unknown": "results", "condition": "state=NJ" }, "next_action": "plan", - "plan": None + "checklist": [], + "current_step": 0 } @patch("ea_chatbot.graph.nodes.planner.get_llm_model") @patch("ea_chatbot.utils.database_inspection.get_data_summary") def test_planner_node(mock_get_summary, mock_get_llm, mock_state): - """Test planner node with unified prompt.""" + """Test planner node with checklist generation.""" mock_get_summary.return_value = "Column: Name, Type: text" mock_llm = MagicMock() mock_get_llm.return_value = mock_llm - from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext - mock_plan = TaskPlanResponse( + mock_response = ChecklistResponse( goal="Get NJ results", reflection="The user wants NJ results", - context=TaskPlanContext(initial_context="NJ data", assumptions=[], constraints=[]), - steps=["Step 1: Load data", "Step 2: Filter by NJ"] + checklist=[ + ChecklistTask(task="Query NJ data", worker="data_analyst") + ] ) - mock_llm.with_structured_output.return_value.invoke.return_value = mock_plan + mock_llm.with_structured_output.return_value.invoke.return_value = mock_response result = planner_node(mock_state) - assert "plan" in result - assert "Step 1: Load data" in result["plan"] - assert "Step 2: Filter by NJ" in result["plan"] + assert "checklist" in result + assert result["checklist"][0]["task"] == "Query NJ data" + assert result["current_step"] == 0 + assert result["summary"] == "The user wants NJ results" # Verify helper was called - mock_get_summary.assert_called_once() \ No newline at end of file + mock_get_summary.assert_called_once() diff --git a/backend/tests/test_workflow.py b/backend/tests/test_workflow.py index c3ea36e..a1ba40f 100644 --- a/backend/tests/test_workflow.py +++ b/backend/tests/test_workflow.py @@ -1,92 +1,80 @@ +import pytest from unittest.mock import MagicMock, patch -from ea_chatbot.graph.workflow import app -from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis -from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse +from ea_chatbot.graph.workflow import create_workflow +from ea_chatbot.graph.state import AgentState +from langchain_core.messages import AIMessage, HumanMessage -from langchain_core.messages import AIMessage - -@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") -@patch("ea_chatbot.graph.nodes.planner.get_llm_model") -@patch("ea_chatbot.graph.nodes.coder.get_llm_model") -@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") -@patch("ea_chatbot.graph.nodes.researcher.get_llm_model") -@patch("ea_chatbot.utils.database_inspection.get_data_summary") -@patch("ea_chatbot.graph.nodes.executor.Settings") -@patch("ea_chatbot.graph.nodes.executor.DBClient") -def test_workflow_full_flow(mock_db_client, mock_settings, mock_get_summary, mock_researcher_llm, mock_summarizer_llm, mock_coder_llm, mock_planner_llm, mock_qa_llm): - """Test the flow from query_analyzer through planner to coder.""" +def test_workflow_full_flow(): + """Test the full Orchestrator-Workers flow using node injection.""" - # Mock Settings for Executor - mock_settings_instance = MagicMock() - mock_settings_instance.db_host = "localhost" - mock_settings_instance.db_port = 5432 - mock_settings_instance.db_user = "user" - mock_settings_instance.db_pswd = "pass" - mock_settings_instance.db_name = "test_db" - mock_settings_instance.db_table = "test_table" - mock_settings.return_value = mock_settings_instance + mock_analyzer = MagicMock() + mock_planner = MagicMock() + mock_delegate = MagicMock() + mock_worker = MagicMock() + mock_reflector = MagicMock() + mock_synthesizer = MagicMock() + mock_summarize_conv = MagicMock() - # Mock DBClient - mock_client_instance = MagicMock() - mock_db_client.return_value = mock_client_instance + # 1. Analyzer: Proceed to planning + mock_analyzer.return_value = {"next_action": "plan"} - # 1. Mock Query Analyzer - mock_qa_instance = MagicMock() - mock_qa_llm.return_value = mock_qa_instance - mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis( - data_required=["2024 results"], - unknowns=[], - ambiguities=[], - conditions=[], - next_action="plan" + # 2. Planner: Generate checklist + mock_planner.return_value = { + "checklist": [{"task": "Step 1", "worker": "data_analyst"}], + "current_step": 0 + } + + # 3. Delegate: Route to data_analyst + mock_delegate.side_effect = [ + {"next_action": "data_analyst"}, + {"next_action": "summarize"} + ] + + # 4. Worker: Success + mock_worker.return_value = { + "messages": [AIMessage(content="Worker Summary")], + "vfs": {} + } + + # 5. Reflector: Advance + mock_reflector.return_value = { + "current_step": 1, + "next_action": "delegate" + } + + # 6. Synthesizer: Final answer + mock_synthesizer.return_value = { + "messages": [AIMessage(content="Final Summary")], + "next_action": "end" + } + + # 7. Summarize Conv: End + mock_summarize_conv.return_value = {"summary": "Done"} + + app = create_workflow( + query_analyzer=mock_analyzer, + planner=mock_planner, + delegate=mock_delegate, + data_analyst_worker=mock_worker, + reflector=mock_reflector, + synthesizer=mock_synthesizer, + summarize_conversation=mock_summarize_conv ) - # 2. Mock Planner - mock_planner_instance = MagicMock() - mock_planner_llm.return_value = mock_planner_instance - mock_get_summary.return_value = "Data summary" - mock_planner_instance.with_structured_output.return_value.invoke.return_value = TaskPlanResponse( - goal="Task Goal", - reflection="Reflection", - context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]), - steps=["Step 1"] - ) - - # 3. Mock Coder - mock_coder_instance = MagicMock() - mock_coder_llm.return_value = mock_coder_instance - mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse( - code="print('Hello')", - explanation="Explanation" - ) - - # 4. Mock Summarizer - mock_summarizer_instance = MagicMock() - mock_summarizer_llm.return_value = mock_summarizer_instance - mock_summarizer_instance.invoke.return_value = AIMessage(content="Summary") - - # 5. Mock Researcher (just in case) - mock_researcher_instance = MagicMock() - mock_researcher_llm.return_value = mock_researcher_instance - - # Initial state initial_state = { - "messages": [], - "question": "Show me the 2024 results", + "messages": [HumanMessage(content="Show me results")], + "question": "Show me results", "analysis": None, "next_action": "", - "plan": None, - "code": None, - "error": None, + "iterations": 0, + "checklist": [], + "current_step": 0, + "vfs": {}, "plots": [], "dfs": {} } - # Run the graph - # We use recursion_limit to avoid infinite loops in placeholders if any - result = app.invoke(initial_state, config={"recursion_limit": 10}) + result = app.invoke(initial_state, config={"recursion_limit": 20}) - assert result["next_action"] == "plan" - assert "plan" in result and result["plan"] is not None - assert "code" in result and "print('Hello')" in result["code"] - assert "analysis" in result \ No newline at end of file + assert "Final Summary" in [m.content for m in result["messages"]] + assert result["current_step"] == 1 diff --git a/backend/tests/test_workflow_e2e.py b/backend/tests/test_workflow_e2e.py index ef76d1f..e5367ca 100644 --- a/backend/tests/test_workflow_e2e.py +++ b/backend/tests/test_workflow_e2e.py @@ -1,60 +1,51 @@ import pytest from unittest.mock import MagicMock, patch -from langchain_core.messages import AIMessage from ea_chatbot.graph.workflow import app -from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis -from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse +from ea_chatbot.schemas import QueryAnalysis, ChecklistResponse, ChecklistTask, CodeGenerationResponse +from ea_chatbot.graph.state import AgentState +from langchain_core.messages import AIMessage @pytest.fixture def mock_llms(): - with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_qa_llm, \ - patch("ea_chatbot.graph.nodes.planner.get_llm_model") as mock_planner_llm, \ - patch("ea_chatbot.graph.nodes.coder.get_llm_model") as mock_coder_llm, \ - patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") as mock_summarizer_llm, \ - patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_researcher_llm, \ - patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model") as mock_summary_llm, \ - patch("ea_chatbot.utils.database_inspection.get_data_summary") as mock_get_summary: - mock_get_summary.return_value = "Data summary" - - # Mock summary LLM to return a simple response - mock_summary_instance = MagicMock() - mock_summary_llm.return_value = mock_summary_instance - mock_summary_instance.invoke.return_value = AIMessage(content="Turn summary") - + with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_qa, \ + patch("ea_chatbot.graph.nodes.planner.get_llm_model") as mock_planner, \ + patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model") as mock_coder, \ + patch("ea_chatbot.graph.workers.data_analyst.nodes.summarizer.get_llm_model") as mock_worker_summarizer, \ + patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model") as mock_synthesizer, \ + patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_researcher: yield { - "qa": mock_qa_llm, - "planner": mock_planner_llm, - "coder": mock_coder_llm, - "summarizer": mock_summarizer_llm, - "researcher": mock_researcher_llm, - "summary": mock_summary_llm + "qa": mock_qa, + "planner": mock_planner, + "coder": mock_coder, + "worker_summarizer": mock_worker_summarizer, + "synthesizer": mock_synthesizer, + "researcher": mock_researcher } def test_workflow_data_analysis_flow(mock_llms): - """Test full flow: QueryAnalyzer -> Planner -> Coder -> Executor -> Summarizer.""" + """Test full flow: QueryAnalyzer -> Planner -> Delegate -> DataAnalyst -> Reflector -> Synthesizer.""" - # 1. Mock Query Analyzer (routes to plan) + # 1. Mock Query Analyzer mock_qa_instance = MagicMock() mock_llms["qa"].return_value = mock_qa_instance mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis( - data_required=["2024 results"], - unknowns=[], + data_required=["2024 results"], + unknowns=[], ambiguities=[], - conditions=[], + conditions=[], next_action="plan" ) # 2. Mock Planner mock_planner_instance = MagicMock() mock_llms["planner"].return_value = mock_planner_instance - mock_planner_instance.with_structured_output.return_value.invoke.return_value = TaskPlanResponse( + mock_planner_instance.with_structured_output.return_value.invoke.return_value = ChecklistResponse( goal="Get results", reflection="Reflect", - context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]), - steps=["Step 1"] + checklist=[ChecklistTask(task="Query Data", worker="data_analyst")] ) - # 3. Mock Coder + # 3. Mock Coder (Worker) mock_coder_instance = MagicMock() mock_llms["coder"].return_value = mock_coder_instance mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse( @@ -62,10 +53,15 @@ def test_workflow_data_analysis_flow(mock_llms): explanation="Explain" ) - # 4. Mock Summarizer - mock_summarizer_instance = MagicMock() - mock_llms["summarizer"].return_value = mock_summarizer_instance - mock_summarizer_instance.invoke.return_value = AIMessage(content="Final Summary: Success") + # 4. Mock Worker Summarizer + mock_ws_instance = MagicMock() + mock_llms["worker_summarizer"].return_value = mock_ws_instance + mock_ws_instance.invoke.return_value = AIMessage(content="Worker Summary") + + # 5. Mock Synthesizer + mock_syn_instance = MagicMock() + mock_llms["synthesizer"].return_value = mock_syn_instance + mock_syn_instance.invoke.return_value = AIMessage(content="Final Summary: Success") # Initial state initial_state = { @@ -73,66 +69,67 @@ def test_workflow_data_analysis_flow(mock_llms): "question": "Show me 2024 results", "analysis": None, "next_action": "", - "plan": None, - "code": None, - "error": None, + "iterations": 0, + "checklist": [], + "current_step": 0, + "vfs": {}, "plots": [], "dfs": {} } # Run the graph - result = app.invoke(initial_state, config={"recursion_limit": 15}) + result = app.invoke(initial_state, config={"recursion_limit": 20}) - assert result["next_action"] == "plan" - assert "Execution Success" in result["code_output"] - assert "Final Summary: Success" in result["messages"][-1].content + assert "Final Summary: Success" in [m.content for m in result["messages"]] + assert result["current_step"] == 1 def test_workflow_research_flow(mock_llms): - """Test flow: QueryAnalyzer -> Researcher -> Summarizer.""" + """Test flow with research task.""" - # 1. Mock Query Analyzer (routes to research) + # 1. Mock Query Analyzer mock_qa_instance = MagicMock() mock_llms["qa"].return_value = mock_qa_instance mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis( - data_required=[], - unknowns=[], + data_required=[], + unknowns=[], ambiguities=[], - conditions=[], + conditions=[], next_action="research" ) - # 2. Mock Researcher - mock_researcher_instance = MagicMock() - mock_llms["researcher"].return_value = mock_researcher_instance - # Researcher node uses bind_tools if it's ChatOpenAI/ChatGoogleGenerativeAI - # Since it's a MagicMock, it will fallback to using the base instance - mock_researcher_instance.invoke.return_value = AIMessage(content="Research Results") + # 2. Mock Planner + mock_planner_instance = MagicMock() + mock_llms["planner"].return_value = mock_planner_instance + mock_planner_instance.with_structured_output.return_value.invoke.return_value = ChecklistResponse( + goal="Search", + reflection="Reflect", + checklist=[ChecklistTask(task="Search Web", worker="researcher")] + ) - # Also mock bind_tools just in case we ever use spec - mock_llm_with_tools = MagicMock() - mock_researcher_instance.bind_tools.return_value = mock_llm_with_tools - mock_llm_with_tools.invoke.return_value = AIMessage(content="Research Results") + # 3. Mock Researcher + mock_res_instance = MagicMock() + mock_llms["researcher"].return_value = mock_res_instance + mock_res_instance.invoke.return_value = AIMessage(content="Research Result") - # 3. Mock Summarizer (not used in this flow, but kept for completeness) - mock_summarizer_instance = MagicMock() - mock_llms["summarizer"].return_value = mock_summarizer_instance - mock_summarizer_instance.invoke.return_value = AIMessage(content="Final Summary: Research Success") + # 4. Mock Synthesizer + mock_syn_instance = MagicMock() + mock_llms["synthesizer"].return_value = mock_syn_instance + mock_syn_instance.invoke.return_value = AIMessage(content="Final Research Summary") - # Initial state initial_state = { "messages": [], - "question": "Who is the governor of Florida?", + "question": "Who is the governor?", "analysis": None, "next_action": "", - "plan": None, - "code": None, - "error": None, + "iterations": 0, + "checklist": [], + "current_step": 0, + "vfs": {}, "plots": [], "dfs": {} } - # Run the graph - result = app.invoke(initial_state, config={"recursion_limit": 10}) + result = app.invoke(initial_state, config={"recursion_limit": 20}) - assert result["next_action"] == "research" - assert "Research Results" in result["messages"][-1].content + assert "Final Research Summary" in [m.content for m in result["messages"]] + assert result["current_step"] == 1