diff --git a/backend/src/ea_chatbot/graph/nodes/coder.py b/backend/src/ea_chatbot/graph/nodes/coder.py deleted file mode 100644 index 8a3bd84..0000000 --- a/backend/src/ea_chatbot/graph/nodes/coder.py +++ /dev/null @@ -1,47 +0,0 @@ -from ea_chatbot.graph.state import AgentState -from ea_chatbot.config import Settings -from ea_chatbot.utils.llm_factory import get_llm_model -from ea_chatbot.utils import database_inspection -from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler -from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT -from ea_chatbot.schemas import CodeGenerationResponse - -def coder_node(state: AgentState) -> dict: - """Generate Python code based on the plan and data summary.""" - question = state["question"] - plan = state.get("plan", "") - code_output = state.get("code_output", "None") - - settings = Settings() - logger = get_logger("coder") - - logger.info("Generating Python code...") - - llm = get_llm_model( - settings.coder_llm, - callbacks=[LangChainLoggingHandler(logger=logger)] - ) - structured_llm = llm.with_structured_output(CodeGenerationResponse) - - # Always provide data summary - database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available." - example_code = "" # Placeholder - - messages = CODE_GENERATOR_PROMPT.format_messages( - question=question, - plan=plan, - database_description=database_description, - code_exec_results=code_output, - example_code=example_code - ) - - try: - response = structured_llm.invoke(messages) - logger.info("[bold green]Code generated.[/bold green]") - return { - "code": response.parsed_code, - "error": None # Clear previous errors on new code generation - } - except Exception as e: - logger.error(f"Failed to generate code: {str(e)}") - raise e diff --git a/backend/src/ea_chatbot/graph/nodes/error_corrector.py b/backend/src/ea_chatbot/graph/nodes/error_corrector.py deleted file mode 100644 index a2aed0b..0000000 --- a/backend/src/ea_chatbot/graph/nodes/error_corrector.py +++ /dev/null @@ -1,44 +0,0 @@ -from ea_chatbot.graph.state import AgentState -from ea_chatbot.config import Settings -from ea_chatbot.utils.llm_factory import get_llm_model -from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler -from ea_chatbot.graph.prompts.coder import ERROR_CORRECTOR_PROMPT -from ea_chatbot.schemas import CodeGenerationResponse - -def error_corrector_node(state: AgentState) -> dict: - """Fix the code based on the execution error.""" - code = state.get("code", "") - error = state.get("error", "Unknown error") - - settings = Settings() - logger = get_logger("error_corrector") - - logger.warning(f"[bold red]Execution error detected:[/bold red] {error[:100]}...") - logger.info("Attempting to correct the code...") - - # Reuse coder LLM config or add a new one. Using coder_llm for now. - llm = get_llm_model( - settings.coder_llm, - callbacks=[LangChainLoggingHandler(logger=logger)] - ) - structured_llm = llm.with_structured_output(CodeGenerationResponse) - - messages = ERROR_CORRECTOR_PROMPT.format_messages( - code=code, - error=error - ) - - try: - response = structured_llm.invoke(messages) - logger.info("[bold green]Correction generated.[/bold green]") - - current_iterations = state.get("iterations", 0) - - return { - "code": response.parsed_code, - "error": None, # Clear error after fix attempt - "iterations": current_iterations + 1 - } - except Exception as e: - logger.error(f"Failed to correct code: {str(e)}") - raise e diff --git a/backend/src/ea_chatbot/graph/nodes/executor.py b/backend/src/ea_chatbot/graph/nodes/executor.py deleted file mode 100644 index 28db069..0000000 --- a/backend/src/ea_chatbot/graph/nodes/executor.py +++ /dev/null @@ -1,110 +0,0 @@ -import io -import sys -import traceback -import copy -from contextlib import redirect_stdout -from typing import TYPE_CHECKING -import pandas as pd -from matplotlib.figure import Figure - -from ea_chatbot.graph.state import AgentState -from ea_chatbot.utils.db_client import DBClient -from ea_chatbot.utils.vfs import VFSHelper, safe_vfs_copy -from ea_chatbot.utils.logging import get_logger -from ea_chatbot.config import Settings - -if TYPE_CHECKING: - from ea_chatbot.types import DBSettings - -def executor_node(state: AgentState) -> dict: - """Execute the Python code and capture output, plots, and dataframes.""" - code = state.get("code") - logger = get_logger("executor") - - if not code: - logger.error("No code provided to executor.") - return {"error": "No code provided to executor."} - - logger.info("Executing Python code...") - settings = Settings() - - db_settings: "DBSettings" = { - "host": settings.db_host, - "port": settings.db_port, - "user": settings.db_user, - "pswd": settings.db_pswd, - "db": settings.db_name, - "table": settings.db_table - } - - db_client = DBClient(settings=db_settings) - - # Initialize the Virtual File System (VFS) helper - vfs_state = safe_vfs_copy(state.get("vfs", {})) - vfs_helper = VFSHelper(vfs_state) - - # Initialize local variables for execution - # 'db' is the DBClient instance, 'plots' is for matplotlib figures - local_vars = { - 'db': db_client, - 'plots': [], - 'pd': pd, - 'vfs': vfs_helper - } - - stdout_buffer = io.StringIO() - error = None - code_output = "" - plots = [] - dfs = {} - - try: - with redirect_stdout(stdout_buffer): - # Execute the code in the context of local_vars - exec(code, {}, local_vars) - - code_output = stdout_buffer.getvalue() - - # Limit the output length if it's too long - if code_output.count('\n') > 32: - code_output = '\n'.join(code_output.split('\n')[:32]) + '\n...' - - # Extract plots - raw_plots = local_vars.get('plots', []) - if isinstance(raw_plots, list): - plots = [p for p in raw_plots if isinstance(p, Figure)] - - # Extract DataFrames that were likely intended for display - # We look for DataFrames in local_vars that were mentioned in the code - for key, value in local_vars.items(): - if isinstance(value, pd.DataFrame): - # Heuristic: if the variable name is in the code, it might be a result DF - if key in code: - dfs[key] = value - - logger.info(f"[bold green]Execution complete.[/bold green] Captured {len(plots)} plots and {len(dfs)} dataframes.") - - except Exception as e: - # Capture the traceback - exc_type, exc_value, tb = sys.exc_info() - full_traceback = traceback.format_exc() - - # Filter traceback to show only the relevant part (the executed string) - filtered_tb_lines = [line for line in full_traceback.split('\n') if '' in line] - error = '\n'.join(filtered_tb_lines) - if error: - error += '\n' - error += f"{exc_type.__name__ if exc_type else 'Exception'}: {exc_value}" - - logger.error(f"Execution failed: {str(e)}") - - # If we have an error, we still might want to see partial stdout - code_output = stdout_buffer.getvalue() - - return { - "code_output": code_output, - "error": error, - "plots": plots, - "dfs": dfs, - "vfs": vfs_state - } diff --git a/backend/src/ea_chatbot/graph/nodes/planner.py b/backend/src/ea_chatbot/graph/nodes/planner.py index de7e858..c4a4e00 100644 --- a/backend/src/ea_chatbot/graph/nodes/planner.py +++ b/backend/src/ea_chatbot/graph/nodes/planner.py @@ -38,7 +38,7 @@ def planner_node(state: AgentState) -> dict: ) try: - response = structured_llm.invoke(messages) + response = ChecklistResponse.model_validate(structured_llm.invoke(messages)) # Convert ChecklistTask objects to dicts for state checklist = [task.model_dump() for task in response.checklist] logger.info(f"[bold green]Checklist generated with {len(checklist)} tasks.[/bold green]") diff --git a/backend/src/ea_chatbot/graph/nodes/researcher.py b/backend/src/ea_chatbot/graph/nodes/researcher.py deleted file mode 100644 index 84aa53b..0000000 --- a/backend/src/ea_chatbot/graph/nodes/researcher.py +++ /dev/null @@ -1,59 +0,0 @@ -from langchain_openai import ChatOpenAI -from langchain_google_genai import ChatGoogleGenerativeAI -from ea_chatbot.graph.state import AgentState -from ea_chatbot.config import Settings -from ea_chatbot.utils.llm_factory import get_llm_model -from ea_chatbot.utils import helpers -from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler -from ea_chatbot.graph.prompts.researcher import RESEARCHER_PROMPT - -def researcher_node(state: AgentState) -> dict: - """Handle general research queries or web searches.""" - question = state["question"] - history = state.get("messages", [])[-6:] - summary = state.get("summary", "") - - settings = Settings() - logger = get_logger("researcher") - - logger.info(f"Researching question: [italic]\"{question}\"[/italic]") - - # Use researcher_llm from settings - llm = get_llm_model( - settings.researcher_llm, - callbacks=[LangChainLoggingHandler(logger=logger)] - ) - - date_str = helpers.get_readable_date() - - messages = RESEARCHER_PROMPT.format_messages( - date=date_str, - question=question, - history=history, - summary=summary - ) - - # Provider-aware tool binding - try: - if isinstance(llm, ChatGoogleGenerativeAI): - # Native Google Search for Gemini - llm_with_tools = llm.bind_tools([{"google_search": {}}]) - elif isinstance(llm, ChatOpenAI): - # Native Web Search for OpenAI (built-in tool) - llm_with_tools = llm.bind_tools([{"type": "web_search"}]) - else: - # Fallback for other providers that might not support these specific search tools - llm_with_tools = llm - except Exception as e: - logger.warning(f"Failed to bind search tools: {str(e)}. Falling back to base LLM.") - llm_with_tools = llm - - try: - response = llm_with_tools.invoke(messages) - logger.info("[bold green]Research complete.[/bold green]") - return { - "messages": [response] - } - except Exception as e: - logger.error(f"Research failed: {str(e)}") - raise e \ No newline at end of file diff --git a/backend/src/ea_chatbot/graph/nodes/summarizer.py b/backend/src/ea_chatbot/graph/nodes/summarizer.py deleted file mode 100644 index 3796e9d..0000000 --- a/backend/src/ea_chatbot/graph/nodes/summarizer.py +++ /dev/null @@ -1,43 +0,0 @@ -from ea_chatbot.graph.state import AgentState -from ea_chatbot.config import Settings -from ea_chatbot.utils.llm_factory import get_llm_model -from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler -from ea_chatbot.graph.prompts.summarizer import SUMMARIZER_PROMPT - -def summarizer_node(state: AgentState) -> dict: - """Summarize the code execution results into a final answer.""" - question = state["question"] - plan = state.get("plan", "") - code_output = state.get("code_output", "") - history = state.get("messages", [])[-6:] - summary = state.get("summary", "") - - settings = Settings() - logger = get_logger("summarizer") - - logger.info("Generating final summary...") - - llm = get_llm_model( - settings.summarizer_llm, - callbacks=[LangChainLoggingHandler(logger=logger)] - ) - - messages = SUMMARIZER_PROMPT.format_messages( - question=question, - plan=plan, - code_output=code_output, - history=history, - summary=summary - ) - - try: - response = llm.invoke(messages) - logger.info("[bold green]Summary generated.[/bold green]") - - # Return the final message to be added to the state - return { - "messages": [response] - } - except Exception as e: - logger.error(f"Failed to generate summary: {str(e)}") - raise e diff --git a/backend/tests/test_coder.py b/backend/tests/test_coder.py deleted file mode 100644 index 2bb907f..0000000 --- a/backend/tests/test_coder.py +++ /dev/null @@ -1,62 +0,0 @@ -import pytest -from unittest.mock import MagicMock, patch -from ea_chatbot.graph.nodes.coder import coder_node -from ea_chatbot.graph.nodes.error_corrector import error_corrector_node - -@pytest.fixture -def mock_state(): - return { - "messages": [], - "question": "Show me results for New Jersey", - "plan": "Step 1: Load data\nStep 2: Filter by NJ", - "code": None, - "error": None, - "plots": [], - "dfs": {}, - "next_action": "plan" - } - -@patch("ea_chatbot.graph.nodes.coder.get_llm_model") -@patch("ea_chatbot.utils.database_inspection.get_data_summary") -def test_coder_node(mock_get_summary, mock_get_llm, mock_state): - """Test coder node generates code from plan.""" - mock_get_summary.return_value = "Column: Name, Type: text" - - mock_llm = MagicMock() - mock_get_llm.return_value = mock_llm - - from ea_chatbot.schemas import CodeGenerationResponse - mock_response = CodeGenerationResponse( - code="import pandas as pd\nprint('Hello')", - explanation="Generated code" - ) - mock_llm.with_structured_output.return_value.invoke.return_value = mock_response - - result = coder_node(mock_state) - - assert "code" in result - assert "import pandas as pd" in result["code"] - assert "error" in result - assert result["error"] is None - -@patch("ea_chatbot.graph.nodes.error_corrector.get_llm_model") -def test_error_corrector_node(mock_get_llm, mock_state): - """Test error corrector node fixes code.""" - mock_state["code"] = "import pandas as pd\nprint(undefined_var)" - mock_state["error"] = "NameError: name 'undefined_var' is not defined" - - mock_llm = MagicMock() - mock_get_llm.return_value = mock_llm - - from ea_chatbot.schemas import CodeGenerationResponse - mock_response = CodeGenerationResponse( - code="import pandas as pd\nprint('Defined')", - explanation="Fixed variable" - ) - mock_llm.with_structured_output.return_value.invoke.return_value = mock_response - - result = error_corrector_node(mock_state) - - assert "code" in result - assert "print('Defined')" in result["code"] - assert result["error"] is None \ No newline at end of file diff --git a/backend/tests/test_executor.py b/backend/tests/test_executor.py deleted file mode 100644 index 3f9d68b..0000000 --- a/backend/tests/test_executor.py +++ /dev/null @@ -1,122 +0,0 @@ -import pytest -import pandas as pd -from unittest.mock import MagicMock, patch -from matplotlib.figure import Figure -from ea_chatbot.graph.nodes.executor import executor_node - -@pytest.fixture -def mock_settings(): - with patch("ea_chatbot.graph.nodes.executor.Settings") as MockSettings: - mock_settings_instance = MagicMock() - mock_settings_instance.db_host = "localhost" - mock_settings_instance.db_port = 5432 - mock_settings_instance.db_user = "user" - mock_settings_instance.db_pswd = "pass" - mock_settings_instance.db_name = "test_db" - mock_settings_instance.db_table = "test_table" - MockSettings.return_value = mock_settings_instance - yield mock_settings_instance - -@pytest.fixture -def mock_db_client(): - with patch("ea_chatbot.graph.nodes.executor.DBClient") as MockDBClient: - mock_client_instance = MagicMock() - MockDBClient.return_value = mock_client_instance - yield mock_client_instance - -def test_executor_node_success_simple_print(mock_settings, mock_db_client): - """Test executing simple code that prints to stdout.""" - state = { - "code": "print('Hello, World!')", - "question": "test", - "messages": [] - } - - result = executor_node(state) - - assert "code_output" in result - assert "Hello, World!" in result["code_output"] - assert result["error"] is None - assert result["plots"] == [] - assert result["dfs"] == {} - -def test_executor_node_success_dataframe(mock_settings, mock_db_client): - """Test executing code that creates a DataFrame.""" - code = """ -import pandas as pd -df = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) -print(df) -""" - state = { - "code": code, - "question": "test", - "messages": [] - } - - result = executor_node(state) - - assert "code_output" in result - assert "a b" in result["code_output"] # Check part of DF string representation - assert "dfs" in result - assert "df" in result["dfs"] - assert isinstance(result["dfs"]["df"], pd.DataFrame) - -def test_executor_node_success_plot(mock_settings, mock_db_client): - """Test executing code that generates a plot.""" - code = """ -import matplotlib.pyplot as plt -fig = plt.figure() -plots.append(fig) -print('Plot generated') -""" - state = { - "code": code, - "question": "test", - "messages": [] - } - - result = executor_node(state) - - assert "Plot generated" in result["code_output"] - assert "plots" in result - assert len(result["plots"]) == 1 - assert isinstance(result["plots"][0], Figure) - -def test_executor_node_error_syntax(mock_settings, mock_db_client): - """Test executing code with a syntax error.""" - state = { - "code": "print('Hello World", # Missing closing quote - "question": "test", - "messages": [] - } - - result = executor_node(state) - - assert result["error"] is not None - assert "SyntaxError" in result["error"] - -def test_executor_node_error_runtime(mock_settings, mock_db_client): - """Test executing code with a runtime error.""" - state = { - "code": "print(1 / 0)", - "question": "test", - "messages": [] - } - - result = executor_node(state) - - assert result["error"] is not None - assert "ZeroDivisionError" in result["error"] - -def test_executor_node_no_code(mock_settings, mock_db_client): - """Test handling when no code is provided.""" - state = { - "code": None, - "question": "test", - "messages": [] - } - - result = executor_node(state) - - assert "error" in result - assert "No code provided" in result["error"] diff --git a/backend/tests/test_executor_vfs.py b/backend/tests/test_executor_vfs.py deleted file mode 100644 index 3f67368..0000000 --- a/backend/tests/test_executor_vfs.py +++ /dev/null @@ -1,36 +0,0 @@ -from ea_chatbot.graph.nodes.executor import executor_node -from ea_chatbot.graph.state import AgentState - -def test_executor_with_vfs(): - """Verify that the executor node provides VFS access to the code.""" - state = AgentState( - messages=[], - question="test", - analysis={}, - next_action="test", - iterations=0, - checklist=[], - current_step=0, - vfs={}, - plots=[], - dfs={} - ) - - # Code that uses the 'vfs' helper - code = """ -vfs.write("output.txt", "Execution Result", metadata={"type": "text"}) -print("VFS Write Complete") -""" - state["code"] = code - - result = executor_node(state) - - # Check if the execution was successful - assert result["error"] is None - assert "VFS Write Complete" in result["code_output"] - - # Verify that the VFS state was updated - # Note: executor_node returns a dict of updates, which should include the updated 'vfs' - assert "vfs" in result - assert "output.txt" in result["vfs"] - assert result["vfs"]["output.txt"]["content"] == "Execution Result" diff --git a/backend/tests/test_researcher.py b/backend/tests/test_researcher.py deleted file mode 100644 index 1226d81..0000000 --- a/backend/tests/test_researcher.py +++ /dev/null @@ -1,34 +0,0 @@ -import pytest -from unittest.mock import MagicMock, patch -from langchain_core.messages import AIMessage -from langchain_openai import ChatOpenAI -from ea_chatbot.graph.nodes.researcher import researcher_node - -@pytest.fixture -def mock_llm(): - with patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_get_llm: - mock_llm_instance = MagicMock(spec=ChatOpenAI) - mock_get_llm.return_value = mock_llm_instance - yield mock_llm_instance - -def test_researcher_node_success(mock_llm): - """Test that researcher_node invokes LLM with web_search tool and returns messages.""" - state = { - "question": "What is the capital of France?", - "messages": [] - } - - mock_llm_with_tools = MagicMock() - mock_llm.bind_tools.return_value = mock_llm_with_tools - mock_llm_with_tools.invoke.return_value = AIMessage(content="The capital of France is Paris.") - - result = researcher_node(state) - - assert mock_llm.bind_tools.called - # Check that it was called with web_search - args, kwargs = mock_llm.bind_tools.call_args - assert {"type": "web_search"} in args[0] - - assert mock_llm_with_tools.invoke.called - assert "messages" in result - assert result["messages"][0].content == "The capital of France is Paris." diff --git a/backend/tests/test_researcher_search_tools.py b/backend/tests/test_researcher_search_tools.py deleted file mode 100644 index 42e82cd..0000000 --- a/backend/tests/test_researcher_search_tools.py +++ /dev/null @@ -1,62 +0,0 @@ -import pytest -from unittest.mock import MagicMock, patch -from langchain_core.messages import AIMessage -from langchain_openai import ChatOpenAI -from langchain_google_genai import ChatGoogleGenerativeAI -from ea_chatbot.graph.nodes.researcher import researcher_node - -@pytest.fixture -def base_state(): - return { - "question": "Who won the 2024 election?", - "messages": [], - "summary": "" - } - -@patch("ea_chatbot.graph.nodes.researcher.get_llm_model") -def test_researcher_binds_openai_search(mock_get_llm, base_state): - """Test that OpenAI LLM binds 'web_search' tool.""" - mock_llm = MagicMock(spec=ChatOpenAI) - mock_get_llm.return_value = mock_llm - - mock_llm_with_tools = MagicMock() - mock_llm.bind_tools.return_value = mock_llm_with_tools - mock_llm_with_tools.invoke.return_value = AIMessage(content="OpenAI Search Result") - - result = researcher_node(base_state) - - # Verify bind_tools called with correct OpenAI tool - mock_llm.bind_tools.assert_called_once_with([{"type": "web_search"}]) - assert result["messages"][0].content == "OpenAI Search Result" - -@patch("ea_chatbot.graph.nodes.researcher.get_llm_model") -def test_researcher_binds_google_search(mock_get_llm, base_state): - """Test that Google LLM binds 'google_search' tool.""" - mock_llm = MagicMock(spec=ChatGoogleGenerativeAI) - mock_get_llm.return_value = mock_llm - - mock_llm_with_tools = MagicMock() - mock_llm.bind_tools.return_value = mock_llm_with_tools - mock_llm_with_tools.invoke.return_value = AIMessage(content="Google Search Result") - - result = researcher_node(base_state) - - # Verify bind_tools called with correct Google tool - mock_llm.bind_tools.assert_called_once_with([{"google_search": {}}]) - assert result["messages"][0].content == "Google Search Result" - -@patch("ea_chatbot.graph.nodes.researcher.get_llm_model") -def test_researcher_fallback_on_bind_error(mock_get_llm, base_state): - """Test that researcher falls back to basic LLM if bind_tools fails.""" - mock_llm = MagicMock(spec=ChatOpenAI) - mock_get_llm.return_value = mock_llm - - # Simulate bind_tools failing (e.g. model doesn't support it) - mock_llm.bind_tools.side_effect = Exception("Not supported") - mock_llm.invoke.return_value = AIMessage(content="Basic Result") - - result = researcher_node(base_state) - - # Should still succeed using the base LLM - assert result["messages"][0].content == "Basic Result" - mock_llm.invoke.assert_called_once() diff --git a/backend/tests/test_summarizer.py b/backend/tests/test_summarizer.py deleted file mode 100644 index b152fdb..0000000 --- a/backend/tests/test_summarizer.py +++ /dev/null @@ -1,47 +0,0 @@ -import pytest -from unittest.mock import MagicMock, patch -from langchain_core.messages import AIMessage -from ea_chatbot.graph.nodes.summarizer import summarizer_node - -@pytest.fixture -def mock_llm(): - with patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") as mock_get_llm: - mock_llm_instance = MagicMock() - mock_get_llm.return_value = mock_llm_instance - yield mock_llm_instance - -def test_summarizer_node_success(mock_llm): - """Test that summarizer_node invokes LLM with correct inputs and returns messages.""" - state = { - "question": "What is the total count?", - "plan": "1. Run query\n2. Sum results", - "code_output": "The total is 100", - "messages": [] - } - - mock_llm.invoke.return_value = AIMessage(content="The final answer is 100.") - - result = summarizer_node(state) - - # Verify LLM was called - assert mock_llm.invoke.called - - # Verify result structure - assert "messages" in result - assert len(result["messages"]) == 1 - assert isinstance(result["messages"][0], AIMessage) - assert result["messages"][0].content == "The final answer is 100." - -def test_summarizer_node_empty_state(mock_llm): - """Test handling of empty or minimal state.""" - state = { - "question": "Empty?", - "messages": [] - } - - mock_llm.invoke.return_value = AIMessage(content="No data provided.") - - result = summarizer_node(state) - - assert "messages" in result - assert result["messages"][0].content == "No data provided." diff --git a/backend/tests/test_workflow_e2e.py b/backend/tests/test_workflow_e2e.py index 29762e8..e5035b1 100644 --- a/backend/tests/test_workflow_e2e.py +++ b/backend/tests/test_workflow_e2e.py @@ -12,7 +12,8 @@ def mock_llms(): patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model") as mock_coder, \ patch("ea_chatbot.graph.workers.data_analyst.nodes.summarizer.get_llm_model") as mock_worker_summarizer, \ patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model") as mock_synthesizer, \ - patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_researcher, \ + patch("ea_chatbot.graph.workers.researcher.nodes.searcher.get_llm_model") as mock_researcher, \ + patch("ea_chatbot.graph.workers.researcher.nodes.summarizer.get_llm_model") as mock_res_summarizer, \ patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_reflector: yield { "qa": mock_qa, @@ -21,6 +22,7 @@ def mock_llms(): "worker_summarizer": mock_worker_summarizer, "synthesizer": mock_synthesizer, "researcher": mock_researcher, + "res_summarizer": mock_res_summarizer, "reflector": mock_reflector } @@ -113,17 +115,22 @@ def test_workflow_research_flow(mock_llms): checklist=[ChecklistTask(task="Search Web", worker="researcher")] ) - # 3. Mock Researcher + # 3. Mock Researcher Searcher mock_res_instance = MagicMock() mock_llms["researcher"].return_value = mock_res_instance mock_res_instance.invoke.return_value = AIMessage(content="Research Result") - # 4. Mock Reflector + # 4. Mock Researcher Summarizer + mock_rs_instance = MagicMock() + mock_llms["res_summarizer"].return_value = mock_rs_instance + mock_rs_instance.invoke.return_value = AIMessage(content="Researcher Summary") + + # 5. Mock Reflector mock_reflector_instance = MagicMock() mock_llms["reflector"].return_value = mock_reflector_instance mock_reflector_instance.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=True, reasoning="Good.") - # 5. Mock Synthesizer + # 6. Mock Synthesizer mock_syn_instance = MagicMock() mock_llms["synthesizer"].return_value = mock_syn_instance mock_syn_instance.invoke.return_value = AIMessage(content="Final Research Summary")