def coder_node(state: WorkerState) -> dict:
    """Generate Python code for the worker's assigned sub-task.

    Reads the task, the previous execution output/error, and the VFS
    snapshot from state, then asks the configured coder LLM for a
    structured code-generation response.

    Returns a partial state update:
        code:       the generated Python source
        error:      None (clears any previous execution error)
        iterations: attempt counter incremented by one

    Raises:
        Exception: re-raised unchanged if the LLM call fails.
    """
    task = state["task"]
    output = state.get("output", "None")
    error = state.get("error", "None")
    vfs_state = state.get("vfs_state", {})

    settings = Settings()
    logger = get_logger("data_analyst_worker:coder")

    logger.info(f"Generating Python code for task: {task[:50]}...")

    # We can use the configured 'coder_llm' for this node
    llm = get_llm_model(
        settings.coder_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(CodeGenerationResponse)

    # Data summary for context
    database_description = database_inspection.get_data_summary(
        data_dir=settings.data_dir
    ) or "No data available."

    # VFS Summary: Let the LLM know which files are available in-memory
    vfs_summary = "Available in-memory files (VFS):\n"
    if vfs_state:
        for filename, data in vfs_state.items():
            meta = data.get("metadata", {})
            # BUG FIX: the listing previously emitted the literal text
            # "(unknown)" instead of the file name, leaving the LLM no way
            # to reference VFS files by name.
            vfs_summary += f"- {filename} ({meta.get('type', 'unknown')})\n"
    else:
        vfs_summary += "- None"

    # Reuse the global prompt but adapt 'question' and 'plan' labels.
    # For a sub-task worker, 'task' effectively replaces the high-level 'plan'.
    messages = CODE_GENERATOR_PROMPT.format_messages(
        question=task,
        plan="Focus on the specific task below.",
        database_description=database_description,
        code_exec_results=f"Output: {output}\nError: {error}\n\n{vfs_summary}",
        example_code=""
    )

    try:
        response = structured_llm.invoke(messages)
        logger.info("[bold green]Code generated.[/bold green]")
        return {
            "code": response.parsed_code,
            "error": None,  # Clear previous errors
            "iterations": state.get("iterations", 0) + 1
        }
    except Exception as e:
        logger.error(f"Failed to generate code: {str(e)}")
        # Bare raise preserves the original traceback (raise e would rebuild it).
        raise
def executor_node(state: WorkerState) -> dict:
    """Execute the worker's generated Python code and capture its results.

    The code runs with `db` (a DBClient), `pd` (pandas), `vfs` (the VFS
    helper), and a `plots` list pre-bound in its namespace. Stdout is
    captured (truncated to 32 lines), matplotlib Figures appended to
    `plots` are collected, and any exception is reported as a filtered
    traceback in `error`.

    Returns a partial state update with `output`, `error`, `plots`,
    and the (possibly code-mutated) `vfs_state` snapshot.
    """
    code = state.get("code")
    logger = get_logger("data_analyst_worker:executor")

    if not code:
        logger.error("No code provided to executor.")
        return {"error": "No code provided to executor."}

    logger.info("Executing Python code...")
    settings = Settings()

    db_settings: "DBSettings" = {
        "host": settings.db_host,
        "port": settings.db_port,
        "user": settings.db_user,
        "pswd": settings.db_pswd,
        "db": settings.db_name,
        "table": settings.db_table
    }

    db_client = DBClient(settings=db_settings)

    # Initialize the Virtual File System (VFS) helper with the snapshot from state
    vfs_state = dict(state.get("vfs_state", {}))
    vfs_helper = VFSHelper(vfs_state)

    # BUG FIX: use a single namespace for exec(). With separate globals and
    # locals dicts — exec(code, {}, local_vars) — any function or comprehension
    # defined inside the executed code cannot resolve `db`, `pd`, `vfs`, etc.,
    # because its name lookup only consults the (empty) globals dict.
    exec_ns = {
        'db': db_client,
        'plots': [],
        'pd': pd,
        'vfs': vfs_helper
    }

    stdout_buffer = io.StringIO()
    error = None
    output = ""
    plots = []

    try:
        with redirect_stdout(stdout_buffer):
            # NOTE(security): exec() runs LLM-generated code with no sandboxing;
            # acceptable only while inputs come from the trusted model pipeline.
            exec(code, exec_ns)

        output = stdout_buffer.getvalue()

        # Limit the output length if it's too long
        if output.count('\n') > 32:
            output = '\n'.join(output.split('\n')[:32]) + '\n...'

        # Extract plots appended by the executed code
        raw_plots = exec_ns.get('plots', [])
        if isinstance(raw_plots, list):
            plots = [p for p in raw_plots if isinstance(p, Figure)]

        logger.info(f"[bold green]Execution complete.[/bold green] Captured {len(plots)} plots.")

    except Exception as e:
        # Capture the traceback
        exc_type, exc_value, tb = sys.exc_info()
        full_traceback = traceback.format_exc()

        # BUG FIX: the filter previously tested `'' in line`, which is True for
        # every line and therefore kept the entire traceback. Frames from code
        # run via exec() appear as File "<string>", so filter on that marker.
        filtered_tb_lines = [line for line in full_traceback.split('\n') if '<string>' in line]
        error = '\n'.join(filtered_tb_lines)
        if error:
            error += '\n'
        error += f"{exc_type.__name__ if exc_type else 'Exception'}: {exc_value}"

        logger.error(f"Execution failed: {str(e)}")
        output = stdout_buffer.getvalue()

    return {
        "output": output,
        "error": error,
        "plots": plots,
        "vfs_state": vfs_state
    }
def summarizer_node(state: WorkerState) -> dict:
    """Summarize the worker's analysis results for the top-level Orchestrator.

    Feeds the task, execution output, and any error log into an LLM and
    returns a partial state update containing the summary under `result`.
    """
    task = state["task"]
    output = state.get("output", "")
    error = state.get("error")

    settings = Settings()
    logger = get_logger("data_analyst_worker:summarizer")

    logger.info("Summarizing analysis results for the Orchestrator...")

    # The planner model is sufficient here; a smaller/faster model could be
    # swapped in later if summarization cost becomes a concern.
    llm = get_llm_model(
        settings.planner_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )

    prompt = f"""You are a data analyst sub-agent. You have completed a sub-task.
Task: {task}
Execution Results: {output}
Error Log (if any): {error}

Provide a concise summary of the findings or status for the top-level Orchestrator.
If the execution failed after multiple retries, explain why concisely.
Do NOT include the raw Python code, just the results of the analysis."""

    try:
        response = llm.invoke(prompt)
        # LangChain message objects expose .content; fall back to str() for
        # plain-string returns.
        if hasattr(response, "content"):
            result = response.content
        else:
            result = str(response)
        logger.info("[bold green]Analysis results summarized.[/bold green]")
        return {"result": result}
    except Exception as e:
        logger.error(f"Failed to summarize results: {str(e)}")
        raise e
def router(state: WorkerState) -> str:
    """Route the subgraph after execution: retry coding or move to summary."""
    # Retry with error correction while under the attempt budget.
    if state.get("error") and state.get("iterations", 0) < 3:
        return "coder"
    # Either success or max retries reached
    return "summarizer"

def create_data_analyst_worker(
    coder=coder_node,
    executor=executor_node,
    summarizer=summarizer_node
) -> StateGraph:
    """Create the Data Analyst worker subgraph.

    Node callables are injectable parameters so tests can substitute mocks.
    Flow: coder -> executor -> (router: retry coder | summarizer) -> END.
    """
    workflow = StateGraph(WorkerState)

    # Register the three worker nodes.
    for node_name, node_fn in (
        ("coder", coder),
        ("executor", executor),
        ("summarizer", summarizer),
    ):
        workflow.add_node(node_name, node_fn)

    # Entry point and the fixed coder -> executor edge.
    workflow.set_entry_point("coder")
    workflow.add_edge("coder", "executor")

    # After execution, the router decides between a retry and summarization.
    workflow.add_conditional_edges(
        "executor",
        router,
        {
            "coder": "coder",
            "summarizer": "summarizer"
        }
    )

    workflow.add_edge("summarizer", END)

    return workflow.compile()
def test_data_analyst_worker_one_shot():
    """Verify a successful one-shot execution of the worker subgraph."""
    # Scenario: Coder -> Executor (Success) -> Summarizer -> END
    coder_stub = MagicMock(
        return_value={"code": "print(1)", "error": None, "iterations": 1}
    )
    executor_stub = MagicMock(
        return_value={"output": "1\n", "error": None, "plots": []}
    )
    summarizer_stub = MagicMock(return_value={"result": "Result is 1"})

    graph = create_data_analyst_worker(
        coder=coder_stub,
        executor=executor_stub,
        summarizer=summarizer_stub
    )

    final_state = graph.invoke(
        WorkerState(
            messages=[],
            task="Calculate 1+1",
            code=None,
            output=None,
            error=None,
            iterations=0,
            plots=[],
            vfs_state={},
            result=None
        )
    )

    assert final_state["result"] == "Result is 1"
    # Each node should have been visited exactly once — no retry loop.
    assert coder_stub.call_count == 1
    assert executor_stub.call_count == 1
    assert summarizer_stub.call_count == 1
def test_data_analyst_worker_retry():
    """Verify that the worker retries on error."""
    # Scenario: Coder (1) -> Executor (Error) -> Router (coder) ->
    #           Coder (2) -> Executor (Success) -> Summarizer -> END
    coder_stub = MagicMock(
        side_effect=[
            {"code": "error_code", "error": None, "iterations": 1},
            {"code": "fixed_code", "error": None, "iterations": 2},
        ]
    )
    executor_stub = MagicMock(
        side_effect=[
            {"output": "", "error": "NameError", "plots": []},
            {"output": "Success", "error": None, "plots": []},
        ]
    )
    summarizer_stub = MagicMock(return_value={"result": "Fixed Result"})

    graph = create_data_analyst_worker(
        coder=coder_stub,
        executor=executor_stub,
        summarizer=summarizer_stub
    )

    final_state = graph.invoke(
        WorkerState(
            messages=[],
            task="Retry Task",
            code=None,
            output=None,
            error=None,
            iterations=0,
            plots=[],
            vfs_state={},
            result=None
        )
    )

    assert final_state["result"] == "Fixed Result"
    # One failed attempt plus one successful retry, then a single summary.
    assert coder_stub.call_count == 2
    assert executor_stub.call_count == 2
    assert summarizer_stub.call_count == 1