from unittest.mock import MagicMock, patch

import pytest
from langchain_core.messages import AIMessage

from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis
from ea_chatbot.graph.workflow import app
from ea_chatbot.schemas import (
    CodeGenerationResponse,
    TaskPlanContext,
    TaskPlanResponse,
)


def _structured_llm(response):
    """Return a mock LLM whose structured-output chain yields *response*.

    The stub answers ``llm.with_structured_output(...).invoke(...)`` with the
    given object, whatever arguments are passed along the way.
    """
    llm = MagicMock()
    llm.with_structured_output.return_value.invoke.return_value = response
    return llm


# NOTE: stacked ``patch`` decorators are applied bottom-up, so the lowest
# decorator supplies the first positional mock argument.
# NOTE(review): ``get_data_summary`` is patched where it is defined; this only
# takes effect if the consumer resolves it through that module at call time --
# confirm against the planner's import style.
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.graph.nodes.coder.get_llm_model")
@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model")
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
@patch("ea_chatbot.graph.nodes.executor.Settings")
@patch("ea_chatbot.graph.nodes.executor.DBClient")
def test_workflow_full_flow(
    mock_db_client,
    mock_settings,
    mock_get_summary,
    mock_researcher_llm,
    mock_summarizer_llm,
    mock_coder_llm,
    mock_planner_llm,
    mock_qa_llm,
):
    """Invoke the compiled graph end to end with every external piece mocked.

    The LLM factories of the query analyzer, planner, coder, summarizer and
    researcher nodes are replaced with stubs, as are the executor's
    ``Settings`` and ``DBClient`` and the data-summary helper. The test then
    checks that the final state carries the "plan" action, a plan, the
    generated code and an ``analysis`` entry.
    """
    # --- Executor dependencies: settings object and database client ---------
    mock_settings.return_value = MagicMock(
        db_host="localhost",
        db_port=5432,
        db_user="user",
        db_pswd="pass",
        db_name="test_db",
        db_table="test_table",
    )
    mock_db_client.return_value = MagicMock()

    # --- Query analyzer: routes the question to planning --------------------
    mock_qa_llm.return_value = _structured_llm(
        QueryAnalysis(
            data_required=["2024 results"],
            unknowns=[],
            ambiguities=[],
            conditions=[],
            next_action="plan",
        )
    )

    # --- Planner: canned data summary plus a one-step plan ------------------
    mock_get_summary.return_value = "Data summary"
    plan_context = TaskPlanContext(
        initial_context="Ctx",
        assumptions=[],
        constraints=[],
    )
    mock_planner_llm.return_value = _structured_llm(
        TaskPlanResponse(
            goal="Task Goal",
            reflection="Reflection",
            context=plan_context,
            steps=["Step 1"],
        )
    )

    # --- Coder: trivial generated snippet -----------------------------------
    mock_coder_llm.return_value = _structured_llm(
        CodeGenerationResponse(
            code="print('Hello')",
            explanation="Explanation",
        )
    )

    # --- Summarizer: plain (non-structured) invoke returning a message ------
    summarizer_stub = MagicMock()
    summarizer_stub.invoke.return_value = AIMessage(content="Summary")
    mock_summarizer_llm.return_value = summarizer_stub

    # --- Researcher: stubbed only as a safety net ---------------------------
    mock_researcher_llm.return_value = MagicMock()

    # State handed to the graph's entry point.
    initial_state = {
        "messages": [],
        "question": "Show me the 2024 results",
        "analysis": None,
        "next_action": "",
        "plan": None,
        "code": None,
        "error": None,
        "plots": [],
        "dfs": {},
    }

    # A small recursion limit guards against placeholder nodes looping forever.
    run_config = {"recursion_limit": 10}
    result = app.invoke(initial_state, config=run_config)

    assert result["next_action"] == "plan"

    assert "plan" in result
    assert result["plan"] is not None

    assert "code" in result
    assert "print('Hello')" in result["code"]

    assert "analysis" in result