81 lines
2.3 KiB
Python
81 lines
2.3 KiB
Python
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
from ea_chatbot.graph.workflow import create_workflow
|
|
from ea_chatbot.graph.state import AgentState
|
|
from langchain_core.messages import AIMessage, HumanMessage
|
|
|
|
def test_workflow_full_flow():
    """Test the full Orchestrator-Workers flow using node injection.

    Every node passed to ``create_workflow`` is a ``MagicMock`` with a scripted
    return value, so only the graph's wiring/routing is exercised. The
    scripted path is:

        query_analyzer ("plan") -> planner (1-step checklist)
        -> delegate ("data_analyst") -> data_analyst_worker
        -> reflector (advance to step 1, back to "delegate")
        -> delegate ("summarize") -> ... -> synthesizer ("end")

    The test asserts that the synthesizer's answer reaches the final state,
    that the reflector's step advance is persisted, and that each scripted
    node ran the expected number of times.
    """
    mock_analyzer = MagicMock()
    mock_planner = MagicMock()
    mock_delegate = MagicMock()
    mock_worker = MagicMock()
    mock_reflector = MagicMock()
    mock_synthesizer = MagicMock()
    mock_summarize_conv = MagicMock()

    # 1. Analyzer: Proceed to planning
    mock_analyzer.return_value = {"next_action": "plan"}

    # 2. Planner: Generate a single-step checklist
    mock_planner.return_value = {
        "checklist": [{"task": "Step 1", "worker": "data_analyst"}],
        "current_step": 0,
    }

    # 3. Delegate: first route to data_analyst, then to summarize.
    # A list side_effect yields one item per call and raises StopIteration
    # on a third call, so an unexpected extra delegation fails loudly.
    mock_delegate.side_effect = [
        {"next_action": "data_analyst"},
        {"next_action": "summarize"},
    ]

    # 4. Worker: Success
    mock_worker.return_value = {
        "messages": [AIMessage(content="Worker Summary")],
        "vfs": {},
    }

    # 5. Reflector: Advance past the only checklist step
    mock_reflector.return_value = {
        "current_step": 1,
        "next_action": "delegate",
    }

    # 6. Synthesizer: Final answer
    mock_synthesizer.return_value = {
        "messages": [AIMessage(content="Final Summary")],
        "next_action": "end",
    }

    # 7. Summarize Conv: End
    mock_summarize_conv.return_value = {"summary": "Done"}

    app = create_workflow(
        query_analyzer=mock_analyzer,
        planner=mock_planner,
        delegate=mock_delegate,
        data_analyst_worker=mock_worker,
        reflector=mock_reflector,
        synthesizer=mock_synthesizer,
        summarize_conversation=mock_summarize_conv,
    )

    initial_state = {
        "messages": [HumanMessage(content="Show me results")],
        "question": "Show me results",
        "analysis": None,
        "next_action": "",
        "iterations": 0,
        "checklist": [],
        "current_step": 0,
        "vfs": {},
        "plots": [],
        "dfs": {},
    }

    # A small recursion limit turns an accidental routing loop into a fast
    # failure instead of a long-running test.
    result = app.invoke(initial_state, config={"recursion_limit": 20})

    # Final state carries the synthesizer's answer and the reflector's advance.
    assert "Final Summary" in [m.content for m in result["messages"]]
    assert result["current_step"] == 1

    # The scripted Orchestrator-Workers path was actually taken.
    mock_analyzer.assert_called_once()
    mock_planner.assert_called_once()
    assert mock_delegate.call_count == 2  # both routing decisions consumed
    mock_worker.assert_called_once()
    mock_reflector.assert_called_once()
    mock_synthesizer.assert_called_once()