"""End-to-end tests for the compiled workflow graph with every node's LLM mocked out."""

import pytest
from contextlib import ExitStack
from unittest.mock import MagicMock, patch

from ea_chatbot.graph.workflow import app
from ea_chatbot.schemas import QueryAnalysis, ChecklistResponse, ChecklistTask, CodeGenerationResponse
from ea_chatbot.graph.state import AgentState
from langchain_core.messages import AIMessage

# Fixture key -> the ``get_llm_model`` reference each node module uses.
# It is patched where it is looked up (in the node's module) so no real model is built.
_LLM_FACTORY_TARGETS = {
    "qa": "ea_chatbot.graph.nodes.query_analyzer.get_llm_model",
    "planner": "ea_chatbot.graph.nodes.planner.get_llm_model",
    "coder": "ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model",
    "worker_summarizer": "ea_chatbot.graph.workers.data_analyst.nodes.summarizer.get_llm_model",
    "synthesizer": "ea_chatbot.graph.nodes.synthesizer.get_llm_model",
    "researcher": "ea_chatbot.graph.nodes.researcher.get_llm_model",
    "reflector": "ea_chatbot.graph.nodes.reflector.get_llm_model",
}

# Passed as LangGraph's ``recursion_limit`` so an unexpected routing loop errors out
# instead of running indefinitely.
_RECURSION_LIMIT = 20


@pytest.fixture
def mock_llms():
    """Patch every node's ``get_llm_model`` factory for the duration of a test.

    Yields:
        A dict mapping a short node key (``"qa"``, ``"planner"``, ``"coder"``,
        ``"worker_summarizer"``, ``"synthesizer"``, ``"researcher"``,
        ``"reflector"``) to the ``MagicMock`` that replaced that node's
        ``get_llm_model``. Tests configure ``return_value`` on these mocks.
    """
    # ExitStack enters the patches in table order and undoes them in reverse,
    # matching a nested ``with`` chain without the backslash continuations.
    with ExitStack() as stack:
        yield {
            key: stack.enter_context(patch(target))
            for key, target in _LLM_FACTORY_TARGETS.items()
        }


def _mock_structured_llm(factory, result):
    """Make ``factory()`` return an LLM whose ``with_structured_output(...).invoke(...)`` returns *result*.

    Args:
        factory: A patched ``get_llm_model`` mock from the ``mock_llms`` fixture.
        result: The structured object the node should receive.

    Returns:
        The mock LLM instance, for further configuration or inspection.
    """
    llm = MagicMock()
    factory.return_value = llm
    llm.with_structured_output.return_value.invoke.return_value = result
    return llm


def _mock_chat_llm(factory, content):
    """Make ``factory()`` return an LLM whose ``invoke(...)`` returns ``AIMessage(content=content)``.

    Args:
        factory: A patched ``get_llm_model`` mock from the ``mock_llms`` fixture.
        content: Text of the ``AIMessage`` the node should receive.

    Returns:
        The mock LLM instance, for further configuration or inspection.
    """
    llm = MagicMock()
    factory.return_value = llm
    llm.invoke.return_value = AIMessage(content=content)
    return llm


def _satisfied_reflection():
    """Return a stand-in reflector verdict with ``satisfied=True`` and reasoning ``"Good."``."""
    return MagicMock(satisfied=True, reasoning="Good.")


def _initial_state(question):
    """Build a fresh initial graph state for *question*; every other field starts empty.

    A new dict (with new inner containers) is returned on each call, so tests
    never share mutable state.
    """
    return {
        "messages": [],
        "question": question,
        "analysis": None,
        "next_action": "",
        "iterations": 0,
        "checklist": [],
        "current_step": 0,
        "vfs": {},
        "plots": [],
        "dfs": {},
    }


def _message_contents(result):
    """Return the ``content`` of every message in the graph's final state."""
    return [m.content for m in result["messages"]]


def test_workflow_data_analysis_flow(mock_llms):
    """Test full flow: QueryAnalyzer -> Planner -> Delegate -> DataAnalyst -> Reflector -> Synthesizer."""
    # 1. Query analyzer routes straight to planning.
    _mock_structured_llm(
        mock_llms["qa"],
        QueryAnalysis(
            data_required=["2024 results"],
            unknowns=[],
            ambiguities=[],
            conditions=[],
            next_action="plan",
        ),
    )
    # 2. Planner emits a single data-analyst task.
    _mock_structured_llm(
        mock_llms["planner"],
        ChecklistResponse(
            goal="Get results",
            reflection="Reflect",
            checklist=[ChecklistTask(task="Query Data", worker="data_analyst")],
        ),
    )
    # 3. Coder (worker) generates trivially runnable code.
    _mock_structured_llm(
        mock_llms["coder"],
        CodeGenerationResponse(code="print('Execution Success')", explanation="Explain"),
    )
    # 4. Worker summarizer.
    _mock_chat_llm(mock_llms["worker_summarizer"], "Worker Summary")
    # 5. Reflector accepts the worker's output.
    _mock_structured_llm(mock_llms["reflector"], _satisfied_reflection())
    # 6. Synthesizer produces the final answer.
    _mock_chat_llm(mock_llms["synthesizer"], "Final Summary: Success")

    result = app.invoke(
        _initial_state("Show me 2024 results"),
        config={"recursion_limit": _RECURSION_LIMIT},
    )

    assert "Final Summary: Success" in _message_contents(result)
    # The checklist has exactly one task, so the step counter should end at 1.
    assert result["current_step"] == 1


def test_workflow_research_flow(mock_llms):
    """Test flow with research task."""
    # 1. Query analyzer routes to research.
    _mock_structured_llm(
        mock_llms["qa"],
        QueryAnalysis(
            data_required=[],
            unknowns=[],
            ambiguities=[],
            conditions=[],
            next_action="research",
        ),
    )
    # 2. Planner emits a single researcher task.
    _mock_structured_llm(
        mock_llms["planner"],
        ChecklistResponse(
            goal="Search",
            reflection="Reflect",
            checklist=[ChecklistTask(task="Search Web", worker="researcher")],
        ),
    )
    # 3. Researcher.
    _mock_chat_llm(mock_llms["researcher"], "Research Result")
    # 4. Reflector accepts the research output.
    _mock_structured_llm(mock_llms["reflector"], _satisfied_reflection())
    # 5. Synthesizer produces the final answer.
    _mock_chat_llm(mock_llms["synthesizer"], "Final Research Summary")

    result = app.invoke(
        _initial_state("Who is the governor?"),
        config={"recursion_limit": _RECURSION_LIMIT},
    )

    assert "Final Research Summary" in _message_contents(result)
    # The checklist has exactly one task, so the step counter should end at 1.
    assert result["current_step"] == 1