chore(graph): Relocate QueryAnalysis schema and update existing tests for Orchestrator architecture
This commit is contained in:
@@ -2,7 +2,7 @@ import json
|
||||
import pytest
|
||||
import logging
|
||||
from unittest.mock import patch
|
||||
from ea_chatbot.graph.workflow import app
|
||||
from ea_chatbot.graph.workflow import create_workflow
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.utils.logging import get_logger
|
||||
from langchain_community.chat_models import FakeListChatModel
|
||||
@@ -43,10 +43,10 @@ def test_logging_e2e_json_output(tmp_path):
|
||||
"question": "Who won in 2024?",
|
||||
"analysis": None,
|
||||
"next_action": "",
|
||||
"plan": None,
|
||||
"code": None,
|
||||
"code_output": None,
|
||||
"error": None,
|
||||
"iterations": 0,
|
||||
"checklist": [],
|
||||
"current_step": 0,
|
||||
"vfs": {},
|
||||
"plots": [],
|
||||
"dfs": {}
|
||||
}
|
||||
@@ -57,6 +57,20 @@ def test_logging_e2e_json_output(tmp_path):
|
||||
|
||||
fake_clarify = FakeListChatModel(responses=["Please specify."])
|
||||
|
||||
# Create a test app without interrupts
|
||||
# We need to manually compile it here to avoid the global 'app' which has interrupts
|
||||
from langgraph.graph import StateGraph, END
|
||||
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node
|
||||
from ea_chatbot.graph.nodes.clarification import clarification_node
|
||||
|
||||
workflow = StateGraph(AgentState)
|
||||
workflow.add_node("query_analyzer", query_analyzer_node)
|
||||
workflow.add_node("clarification", clarification_node)
|
||||
workflow.set_entry_point("query_analyzer")
|
||||
workflow.add_edge("query_analyzer", "clarification")
|
||||
workflow.add_edge("clarification", END)
|
||||
test_app = workflow.compile()
|
||||
|
||||
with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_llm_factory:
|
||||
mock_llm_factory.return_value = fake_analyzer
|
||||
|
||||
@@ -64,7 +78,7 @@ def test_logging_e2e_json_output(tmp_path):
|
||||
mock_clarify_llm_factory.return_value = fake_clarify
|
||||
|
||||
# Run the graph
|
||||
list(app.stream(initial_state))
|
||||
test_app.invoke(initial_state)
|
||||
|
||||
# Verify file content
|
||||
assert log_file.exists()
|
||||
|
||||
@@ -1,24 +1,27 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from ea_chatbot.graph.nodes.planner import planner_node
|
||||
from ea_chatbot.graph.nodes.researcher import researcher_node
|
||||
from ea_chatbot.graph.nodes.summarizer import summarizer_node
|
||||
from ea_chatbot.schemas import TaskPlanResponse
|
||||
from ea_chatbot.schemas import ChecklistResponse, ChecklistTask
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
|
||||
@pytest.fixture
|
||||
def mock_state_with_history():
|
||||
return {
|
||||
"messages": [
|
||||
HumanMessage(content="Show me the 2024 results for Florida"),
|
||||
AIMessage(content="Here are the results for Florida in 2024...")
|
||||
HumanMessage(content="What about NJ?"),
|
||||
AIMessage(content="NJ has 9 million voters.")
|
||||
],
|
||||
"question": "What about in New Jersey?",
|
||||
"analysis": {"data_required": ["2024 results", "New Jersey"], "unknowns": [], "ambiguities": [], "conditions": []},
|
||||
"question": "Show me the breakdown by county for 2024",
|
||||
"analysis": {
|
||||
"data_required": ["2024 results", "New Jersey"],
|
||||
"unknowns": [],
|
||||
"ambiguities": [],
|
||||
"conditions": []
|
||||
},
|
||||
"next_action": "plan",
|
||||
"summary": "The user is asking about 2024 election results.",
|
||||
"plan": "Plan steps...",
|
||||
"code_output": "Code output..."
|
||||
"summary": "The user is asking about NJ 2024 results.",
|
||||
"checklist": [],
|
||||
"current_step": 0
|
||||
}
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
|
||||
@@ -31,49 +34,17 @@ def test_planner_uses_history_and_summary(mock_prompt, mock_get_summary, mock_ge
|
||||
mock_structured_llm = MagicMock()
|
||||
mock_llm_instance.with_structured_output.return_value = mock_structured_llm
|
||||
|
||||
mock_structured_llm.invoke.return_value = TaskPlanResponse(
|
||||
mock_structured_llm.invoke.return_value = ChecklistResponse(
|
||||
goal="goal",
|
||||
reflection="reflection",
|
||||
context={
|
||||
"initial_context": "context",
|
||||
"assumptions": [],
|
||||
"constraints": []
|
||||
},
|
||||
steps=["Step 1: test"]
|
||||
checklist=[ChecklistTask(task="Step 1: test", worker="data_analyst")]
|
||||
)
|
||||
|
||||
planner_node(mock_state_with_history)
|
||||
|
||||
mock_prompt.format_messages.assert_called_once()
|
||||
kwargs = mock_prompt.format_messages.call_args[1]
|
||||
assert kwargs["question"] == "What about in New Jersey?"
|
||||
assert kwargs["summary"] == mock_state_with_history["summary"]
|
||||
assert len(kwargs["history"]) == 2
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
|
||||
@patch("ea_chatbot.graph.nodes.researcher.RESEARCHER_PROMPT")
|
||||
def test_researcher_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
|
||||
mock_llm_instance = MagicMock()
|
||||
mock_get_llm.return_value = mock_llm_instance
|
||||
|
||||
researcher_node(mock_state_with_history)
|
||||
|
||||
mock_prompt.format_messages.assert_called_once()
|
||||
kwargs = mock_prompt.format_messages.call_args[1]
|
||||
assert kwargs["question"] == "What about in New Jersey?"
|
||||
assert kwargs["summary"] == mock_state_with_history["summary"]
|
||||
assert len(kwargs["history"]) == 2
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model")
|
||||
@patch("ea_chatbot.graph.nodes.summarizer.SUMMARIZER_PROMPT")
|
||||
def test_summarizer_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
|
||||
mock_llm_instance = MagicMock()
|
||||
mock_get_llm.return_value = mock_llm_instance
|
||||
|
||||
summarizer_node(mock_state_with_history)
|
||||
|
||||
mock_prompt.format_messages.assert_called_once()
|
||||
kwargs = mock_prompt.format_messages.call_args[1]
|
||||
assert kwargs["question"] == "What about in New Jersey?"
|
||||
assert kwargs["summary"] == mock_state_with_history["summary"]
|
||||
assert len(kwargs["history"]) == 2
|
||||
# Verify history and summary were passed to prompt format
|
||||
# We check the arguments passed to the mock_prompt.format_messages
|
||||
call_args = mock_prompt.format_messages.call_args[1]
|
||||
assert call_args["summary"] == "The user is asking about NJ 2024 results."
|
||||
assert len(call_args["history"]) == 2
|
||||
assert "breakdown by county" in call_args["question"]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.graph.nodes.planner import planner_node
|
||||
from ea_chatbot.schemas import ChecklistResponse, ChecklistTask
|
||||
|
||||
@pytest.fixture
|
||||
def mock_state():
|
||||
@@ -8,39 +9,40 @@ def mock_state():
|
||||
"messages": [],
|
||||
"question": "Show me results for New Jersey",
|
||||
"analysis": {
|
||||
# "requires_dataset" removed as it's no longer used
|
||||
"expert": "Data Analyst",
|
||||
"data": "NJ data",
|
||||
"unknown": "results",
|
||||
"condition": "state=NJ"
|
||||
},
|
||||
"next_action": "plan",
|
||||
"plan": None
|
||||
"checklist": [],
|
||||
"current_step": 0
|
||||
}
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
|
||||
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
|
||||
def test_planner_node(mock_get_summary, mock_get_llm, mock_state):
|
||||
"""Test planner node with unified prompt."""
|
||||
"""Test planner node with checklist generation."""
|
||||
mock_get_summary.return_value = "Column: Name, Type: text"
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_get_llm.return_value = mock_llm
|
||||
|
||||
from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext
|
||||
mock_plan = TaskPlanResponse(
|
||||
mock_response = ChecklistResponse(
|
||||
goal="Get NJ results",
|
||||
reflection="The user wants NJ results",
|
||||
context=TaskPlanContext(initial_context="NJ data", assumptions=[], constraints=[]),
|
||||
steps=["Step 1: Load data", "Step 2: Filter by NJ"]
|
||||
checklist=[
|
||||
ChecklistTask(task="Query NJ data", worker="data_analyst")
|
||||
]
|
||||
)
|
||||
mock_llm.with_structured_output.return_value.invoke.return_value = mock_plan
|
||||
mock_llm.with_structured_output.return_value.invoke.return_value = mock_response
|
||||
|
||||
result = planner_node(mock_state)
|
||||
|
||||
assert "plan" in result
|
||||
assert "Step 1: Load data" in result["plan"]
|
||||
assert "Step 2: Filter by NJ" in result["plan"]
|
||||
assert "checklist" in result
|
||||
assert result["checklist"][0]["task"] == "Query NJ data"
|
||||
assert result["current_step"] == 0
|
||||
assert result["summary"] == "The user wants NJ results"
|
||||
|
||||
# Verify helper was called
|
||||
mock_get_summary.assert_called_once()
|
||||
mock_get_summary.assert_called_once()
|
||||
|
||||
@@ -1,92 +1,80 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.graph.workflow import app
|
||||
from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis
|
||||
from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse
|
||||
from ea_chatbot.graph.workflow import create_workflow
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
|
||||
@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
|
||||
@patch("ea_chatbot.graph.nodes.coder.get_llm_model")
|
||||
@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model")
|
||||
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
|
||||
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
|
||||
@patch("ea_chatbot.graph.nodes.executor.Settings")
|
||||
@patch("ea_chatbot.graph.nodes.executor.DBClient")
|
||||
def test_workflow_full_flow(mock_db_client, mock_settings, mock_get_summary, mock_researcher_llm, mock_summarizer_llm, mock_coder_llm, mock_planner_llm, mock_qa_llm):
|
||||
"""Test the flow from query_analyzer through planner to coder."""
|
||||
def test_workflow_full_flow():
|
||||
"""Test the full Orchestrator-Workers flow using node injection."""
|
||||
|
||||
# Mock Settings for Executor
|
||||
mock_settings_instance = MagicMock()
|
||||
mock_settings_instance.db_host = "localhost"
|
||||
mock_settings_instance.db_port = 5432
|
||||
mock_settings_instance.db_user = "user"
|
||||
mock_settings_instance.db_pswd = "pass"
|
||||
mock_settings_instance.db_name = "test_db"
|
||||
mock_settings_instance.db_table = "test_table"
|
||||
mock_settings.return_value = mock_settings_instance
|
||||
mock_analyzer = MagicMock()
|
||||
mock_planner = MagicMock()
|
||||
mock_delegate = MagicMock()
|
||||
mock_worker = MagicMock()
|
||||
mock_reflector = MagicMock()
|
||||
mock_synthesizer = MagicMock()
|
||||
mock_summarize_conv = MagicMock()
|
||||
|
||||
# Mock DBClient
|
||||
mock_client_instance = MagicMock()
|
||||
mock_db_client.return_value = mock_client_instance
|
||||
# 1. Analyzer: Proceed to planning
|
||||
mock_analyzer.return_value = {"next_action": "plan"}
|
||||
|
||||
# 1. Mock Query Analyzer
|
||||
mock_qa_instance = MagicMock()
|
||||
mock_qa_llm.return_value = mock_qa_instance
|
||||
mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
|
||||
data_required=["2024 results"],
|
||||
unknowns=[],
|
||||
ambiguities=[],
|
||||
conditions=[],
|
||||
next_action="plan"
|
||||
# 2. Planner: Generate checklist
|
||||
mock_planner.return_value = {
|
||||
"checklist": [{"task": "Step 1", "worker": "data_analyst"}],
|
||||
"current_step": 0
|
||||
}
|
||||
|
||||
# 3. Delegate: Route to data_analyst
|
||||
mock_delegate.side_effect = [
|
||||
{"next_action": "data_analyst"},
|
||||
{"next_action": "summarize"}
|
||||
]
|
||||
|
||||
# 4. Worker: Success
|
||||
mock_worker.return_value = {
|
||||
"messages": [AIMessage(content="Worker Summary")],
|
||||
"vfs": {}
|
||||
}
|
||||
|
||||
# 5. Reflector: Advance
|
||||
mock_reflector.return_value = {
|
||||
"current_step": 1,
|
||||
"next_action": "delegate"
|
||||
}
|
||||
|
||||
# 6. Synthesizer: Final answer
|
||||
mock_synthesizer.return_value = {
|
||||
"messages": [AIMessage(content="Final Summary")],
|
||||
"next_action": "end"
|
||||
}
|
||||
|
||||
# 7. Summarize Conv: End
|
||||
mock_summarize_conv.return_value = {"summary": "Done"}
|
||||
|
||||
app = create_workflow(
|
||||
query_analyzer=mock_analyzer,
|
||||
planner=mock_planner,
|
||||
delegate=mock_delegate,
|
||||
data_analyst_worker=mock_worker,
|
||||
reflector=mock_reflector,
|
||||
synthesizer=mock_synthesizer,
|
||||
summarize_conversation=mock_summarize_conv
|
||||
)
|
||||
|
||||
# 2. Mock Planner
|
||||
mock_planner_instance = MagicMock()
|
||||
mock_planner_llm.return_value = mock_planner_instance
|
||||
mock_get_summary.return_value = "Data summary"
|
||||
mock_planner_instance.with_structured_output.return_value.invoke.return_value = TaskPlanResponse(
|
||||
goal="Task Goal",
|
||||
reflection="Reflection",
|
||||
context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]),
|
||||
steps=["Step 1"]
|
||||
)
|
||||
|
||||
# 3. Mock Coder
|
||||
mock_coder_instance = MagicMock()
|
||||
mock_coder_llm.return_value = mock_coder_instance
|
||||
mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse(
|
||||
code="print('Hello')",
|
||||
explanation="Explanation"
|
||||
)
|
||||
|
||||
# 4. Mock Summarizer
|
||||
mock_summarizer_instance = MagicMock()
|
||||
mock_summarizer_llm.return_value = mock_summarizer_instance
|
||||
mock_summarizer_instance.invoke.return_value = AIMessage(content="Summary")
|
||||
|
||||
# 5. Mock Researcher (just in case)
|
||||
mock_researcher_instance = MagicMock()
|
||||
mock_researcher_llm.return_value = mock_researcher_instance
|
||||
|
||||
# Initial state
|
||||
initial_state = {
|
||||
"messages": [],
|
||||
"question": "Show me the 2024 results",
|
||||
"messages": [HumanMessage(content="Show me results")],
|
||||
"question": "Show me results",
|
||||
"analysis": None,
|
||||
"next_action": "",
|
||||
"plan": None,
|
||||
"code": None,
|
||||
"error": None,
|
||||
"iterations": 0,
|
||||
"checklist": [],
|
||||
"current_step": 0,
|
||||
"vfs": {},
|
||||
"plots": [],
|
||||
"dfs": {}
|
||||
}
|
||||
|
||||
# Run the graph
|
||||
# We use recursion_limit to avoid infinite loops in placeholders if any
|
||||
result = app.invoke(initial_state, config={"recursion_limit": 10})
|
||||
result = app.invoke(initial_state, config={"recursion_limit": 20})
|
||||
|
||||
assert result["next_action"] == "plan"
|
||||
assert "plan" in result and result["plan"] is not None
|
||||
assert "code" in result and "print('Hello')" in result["code"]
|
||||
assert "analysis" in result
|
||||
assert "Final Summary" in [m.content for m in result["messages"]]
|
||||
assert result["current_step"] == 1
|
||||
|
||||
@@ -1,60 +1,51 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import AIMessage
|
||||
from ea_chatbot.graph.workflow import app
|
||||
from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis
|
||||
from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse
|
||||
from ea_chatbot.schemas import QueryAnalysis, ChecklistResponse, ChecklistTask, CodeGenerationResponse
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llms():
|
||||
with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_qa_llm, \
|
||||
patch("ea_chatbot.graph.nodes.planner.get_llm_model") as mock_planner_llm, \
|
||||
patch("ea_chatbot.graph.nodes.coder.get_llm_model") as mock_coder_llm, \
|
||||
patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") as mock_summarizer_llm, \
|
||||
patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_researcher_llm, \
|
||||
patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model") as mock_summary_llm, \
|
||||
patch("ea_chatbot.utils.database_inspection.get_data_summary") as mock_get_summary:
|
||||
mock_get_summary.return_value = "Data summary"
|
||||
|
||||
# Mock summary LLM to return a simple response
|
||||
mock_summary_instance = MagicMock()
|
||||
mock_summary_llm.return_value = mock_summary_instance
|
||||
mock_summary_instance.invoke.return_value = AIMessage(content="Turn summary")
|
||||
|
||||
with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_qa, \
|
||||
patch("ea_chatbot.graph.nodes.planner.get_llm_model") as mock_planner, \
|
||||
patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model") as mock_coder, \
|
||||
patch("ea_chatbot.graph.workers.data_analyst.nodes.summarizer.get_llm_model") as mock_worker_summarizer, \
|
||||
patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model") as mock_synthesizer, \
|
||||
patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_researcher:
|
||||
yield {
|
||||
"qa": mock_qa_llm,
|
||||
"planner": mock_planner_llm,
|
||||
"coder": mock_coder_llm,
|
||||
"summarizer": mock_summarizer_llm,
|
||||
"researcher": mock_researcher_llm,
|
||||
"summary": mock_summary_llm
|
||||
"qa": mock_qa,
|
||||
"planner": mock_planner,
|
||||
"coder": mock_coder,
|
||||
"worker_summarizer": mock_worker_summarizer,
|
||||
"synthesizer": mock_synthesizer,
|
||||
"researcher": mock_researcher
|
||||
}
|
||||
|
||||
def test_workflow_data_analysis_flow(mock_llms):
|
||||
"""Test full flow: QueryAnalyzer -> Planner -> Coder -> Executor -> Summarizer."""
|
||||
"""Test full flow: QueryAnalyzer -> Planner -> Delegate -> DataAnalyst -> Reflector -> Synthesizer."""
|
||||
|
||||
# 1. Mock Query Analyzer (routes to plan)
|
||||
# 1. Mock Query Analyzer
|
||||
mock_qa_instance = MagicMock()
|
||||
mock_llms["qa"].return_value = mock_qa_instance
|
||||
mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
|
||||
data_required=["2024 results"],
|
||||
unknowns=[],
|
||||
data_required=["2024 results"],
|
||||
unknowns=[],
|
||||
ambiguities=[],
|
||||
conditions=[],
|
||||
conditions=[],
|
||||
next_action="plan"
|
||||
)
|
||||
|
||||
# 2. Mock Planner
|
||||
mock_planner_instance = MagicMock()
|
||||
mock_llms["planner"].return_value = mock_planner_instance
|
||||
mock_planner_instance.with_structured_output.return_value.invoke.return_value = TaskPlanResponse(
|
||||
mock_planner_instance.with_structured_output.return_value.invoke.return_value = ChecklistResponse(
|
||||
goal="Get results",
|
||||
reflection="Reflect",
|
||||
context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]),
|
||||
steps=["Step 1"]
|
||||
checklist=[ChecklistTask(task="Query Data", worker="data_analyst")]
|
||||
)
|
||||
|
||||
# 3. Mock Coder
|
||||
# 3. Mock Coder (Worker)
|
||||
mock_coder_instance = MagicMock()
|
||||
mock_llms["coder"].return_value = mock_coder_instance
|
||||
mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse(
|
||||
@@ -62,10 +53,15 @@ def test_workflow_data_analysis_flow(mock_llms):
|
||||
explanation="Explain"
|
||||
)
|
||||
|
||||
# 4. Mock Summarizer
|
||||
mock_summarizer_instance = MagicMock()
|
||||
mock_llms["summarizer"].return_value = mock_summarizer_instance
|
||||
mock_summarizer_instance.invoke.return_value = AIMessage(content="Final Summary: Success")
|
||||
# 4. Mock Worker Summarizer
|
||||
mock_ws_instance = MagicMock()
|
||||
mock_llms["worker_summarizer"].return_value = mock_ws_instance
|
||||
mock_ws_instance.invoke.return_value = AIMessage(content="Worker Summary")
|
||||
|
||||
# 5. Mock Synthesizer
|
||||
mock_syn_instance = MagicMock()
|
||||
mock_llms["synthesizer"].return_value = mock_syn_instance
|
||||
mock_syn_instance.invoke.return_value = AIMessage(content="Final Summary: Success")
|
||||
|
||||
# Initial state
|
||||
initial_state = {
|
||||
@@ -73,66 +69,67 @@ def test_workflow_data_analysis_flow(mock_llms):
|
||||
"question": "Show me 2024 results",
|
||||
"analysis": None,
|
||||
"next_action": "",
|
||||
"plan": None,
|
||||
"code": None,
|
||||
"error": None,
|
||||
"iterations": 0,
|
||||
"checklist": [],
|
||||
"current_step": 0,
|
||||
"vfs": {},
|
||||
"plots": [],
|
||||
"dfs": {}
|
||||
}
|
||||
|
||||
# Run the graph
|
||||
result = app.invoke(initial_state, config={"recursion_limit": 15})
|
||||
result = app.invoke(initial_state, config={"recursion_limit": 20})
|
||||
|
||||
assert result["next_action"] == "plan"
|
||||
assert "Execution Success" in result["code_output"]
|
||||
assert "Final Summary: Success" in result["messages"][-1].content
|
||||
assert "Final Summary: Success" in [m.content for m in result["messages"]]
|
||||
assert result["current_step"] == 1
|
||||
|
||||
def test_workflow_research_flow(mock_llms):
|
||||
"""Test flow: QueryAnalyzer -> Researcher -> Summarizer."""
|
||||
"""Test flow with research task."""
|
||||
|
||||
# 1. Mock Query Analyzer (routes to research)
|
||||
# 1. Mock Query Analyzer
|
||||
mock_qa_instance = MagicMock()
|
||||
mock_llms["qa"].return_value = mock_qa_instance
|
||||
mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
|
||||
data_required=[],
|
||||
unknowns=[],
|
||||
data_required=[],
|
||||
unknowns=[],
|
||||
ambiguities=[],
|
||||
conditions=[],
|
||||
conditions=[],
|
||||
next_action="research"
|
||||
)
|
||||
|
||||
# 2. Mock Researcher
|
||||
mock_researcher_instance = MagicMock()
|
||||
mock_llms["researcher"].return_value = mock_researcher_instance
|
||||
# Researcher node uses bind_tools if it's ChatOpenAI/ChatGoogleGenerativeAI
|
||||
# Since it's a MagicMock, it will fallback to using the base instance
|
||||
mock_researcher_instance.invoke.return_value = AIMessage(content="Research Results")
|
||||
# 2. Mock Planner
|
||||
mock_planner_instance = MagicMock()
|
||||
mock_llms["planner"].return_value = mock_planner_instance
|
||||
mock_planner_instance.with_structured_output.return_value.invoke.return_value = ChecklistResponse(
|
||||
goal="Search",
|
||||
reflection="Reflect",
|
||||
checklist=[ChecklistTask(task="Search Web", worker="researcher")]
|
||||
)
|
||||
|
||||
# Also mock bind_tools just in case we ever use spec
|
||||
mock_llm_with_tools = MagicMock()
|
||||
mock_researcher_instance.bind_tools.return_value = mock_llm_with_tools
|
||||
mock_llm_with_tools.invoke.return_value = AIMessage(content="Research Results")
|
||||
# 3. Mock Researcher
|
||||
mock_res_instance = MagicMock()
|
||||
mock_llms["researcher"].return_value = mock_res_instance
|
||||
mock_res_instance.invoke.return_value = AIMessage(content="Research Result")
|
||||
|
||||
# 3. Mock Summarizer (not used in this flow, but kept for completeness)
|
||||
mock_summarizer_instance = MagicMock()
|
||||
mock_llms["summarizer"].return_value = mock_summarizer_instance
|
||||
mock_summarizer_instance.invoke.return_value = AIMessage(content="Final Summary: Research Success")
|
||||
# 4. Mock Synthesizer
|
||||
mock_syn_instance = MagicMock()
|
||||
mock_llms["synthesizer"].return_value = mock_syn_instance
|
||||
mock_syn_instance.invoke.return_value = AIMessage(content="Final Research Summary")
|
||||
|
||||
# Initial state
|
||||
initial_state = {
|
||||
"messages": [],
|
||||
"question": "Who is the governor of Florida?",
|
||||
"question": "Who is the governor?",
|
||||
"analysis": None,
|
||||
"next_action": "",
|
||||
"plan": None,
|
||||
"code": None,
|
||||
"error": None,
|
||||
"iterations": 0,
|
||||
"checklist": [],
|
||||
"current_step": 0,
|
||||
"vfs": {},
|
||||
"plots": [],
|
||||
"dfs": {}
|
||||
}
|
||||
|
||||
# Run the graph
|
||||
result = app.invoke(initial_state, config={"recursion_limit": 10})
|
||||
result = app.invoke(initial_state, config={"recursion_limit": 20})
|
||||
|
||||
assert result["next_action"] == "research"
|
||||
assert "Research Results" in result["messages"][-1].content
|
||||
assert "Final Research Summary" in [m.content for m in result["messages"]]
|
||||
assert result["current_step"] == 1
|
||||
|
||||
Reference in New Issue
Block a user