feat(workers): Extract Coder and Executor nodes into Data Analyst worker subgraph

This commit is contained in:
Yunxiao Xu
2026-02-23 04:58:46 -08:00
parent 5324cbe851
commit cb045504d1
6 changed files with 330 additions and 0 deletions

View File

@@ -0,0 +1,61 @@
from typing import Dict, Any, List, Optional
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils import database_inspection
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT
from ea_chatbot.schemas import CodeGenerationResponse
def coder_node(state: WorkerState) -> dict:
    """Generate Python code based on the sub-task assigned to the worker.

    Feeds the task, the previous execution output/error (so the LLM can
    self-correct on retries) and a listing of in-memory VFS files to the
    coder LLM via the shared CODE_GENERATOR_PROMPT.

    Returns:
        dict with the generated 'code', 'error' cleared to None, and the
        'iterations' counter incremented.

    Raises:
        Exception: re-raised if the structured LLM call fails.
    """
    task = state["task"]
    # "None" placeholder strings keep the prompt well-formed on the first
    # iteration, before any execution has happened.
    output = state.get("output", "None")
    error = state.get("error", "None")
    vfs_state = state.get("vfs_state", {})
    settings = Settings()
    logger = get_logger("data_analyst_worker:coder")
    logger.info(f"Generating Python code for task: {task[:50]}...")
    # We can use the configured 'coder_llm' for this node
    llm = get_llm_model(
        settings.coder_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(CodeGenerationResponse)
    # Data summary for context
    database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."
    # VFS Summary: Let the LLM know which files are available in-memory
    vfs_summary = "Available in-memory files (VFS):\n"
    if vfs_state:
        for filename, data in vfs_state.items():
            meta = data.get("metadata", {})
            # Fix: interpolate the actual VFS filename — it was previously a
            # hard-coded "(unknown)" literal, leaving the loop variable unused
            # and the listing useless to the LLM.
            vfs_summary += f"- {filename} ({meta.get('type', 'unknown')})\n"
    else:
        vfs_summary += "- None"
    # Reuse the global prompt but adapt 'question' and 'plan' labels
    # For a sub-task worker, 'task' effectively replaces the high-level 'plan'
    messages = CODE_GENERATOR_PROMPT.format_messages(
        question=task,
        plan="Focus on the specific task below.",
        database_description=database_description,
        code_exec_results=f"Output: {output}\nError: {error}\n\n{vfs_summary}",
        example_code=""
    )
    try:
        response = structured_llm.invoke(messages)
        logger.info("[bold green]Code generated.[/bold green]")
        return {
            "code": response.parsed_code,
            "error": None,  # Clear previous errors
            "iterations": state.get("iterations", 0) + 1
        }
    except Exception as e:
        logger.error(f"Failed to generate code: {str(e)}")
        raise e

View File

@@ -0,0 +1,96 @@
import io
import sys
import traceback
from contextlib import redirect_stdout
from typing import TYPE_CHECKING
import pandas as pd
from matplotlib.figure import Figure
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
from ea_chatbot.utils.db_client import DBClient
from ea_chatbot.utils.vfs import VFSHelper
from ea_chatbot.utils.logging import get_logger
from ea_chatbot.config import Settings
if TYPE_CHECKING:
from ea_chatbot.types import DBSettings
def executor_node(state: WorkerState) -> dict:
    """Execute the generated Python code in the Data Analyst worker's sandbox.

    The code runs with a DB client, pandas, a plot collector list and the VFS
    helper injected into its namespace; stdout is captured, matplotlib Figures
    are collected from 'plots', and failures are reduced to a filtered
    traceback string.

    Returns:
        dict with 'output' (stdout, truncated to 32 lines), 'error' (filtered
        traceback string or None), 'plots' (Figure objects) and 'vfs_state'
        (the VFS snapshot, including any mutations made via the helper).
    """
    code = state.get("code")
    logger = get_logger("data_analyst_worker:executor")
    if not code:
        logger.error("No code provided to executor.")
        return {"error": "No code provided to executor."}
    logger.info("Executing Python code...")
    settings = Settings()
    db_settings: "DBSettings" = {
        "host": settings.db_host,
        "port": settings.db_port,
        "user": settings.db_user,
        "pswd": settings.db_pswd,
        "db": settings.db_name,
        "table": settings.db_table
    }
    db_client = DBClient(settings=db_settings)
    # Initialize the Virtual File System (VFS) helper with the snapshot from state
    vfs_state = dict(state.get("vfs_state", {}))
    vfs_helper = VFSHelper(vfs_state)
    # Namespace injected into the executed code.
    local_vars = {
        'db': db_client,
        'plots': [],
        'pd': pd,
        'vfs': vfs_helper
    }
    stdout_buffer = io.StringIO()
    error = None
    output = ""
    plots = []
    try:
        with redirect_stdout(stdout_buffer):
            # Fix: pass local_vars as the *globals* dict. The previous form
            # exec(code, {}, local_vars) gave functions defined by the
            # generated code an empty globals dict, so they raised NameError
            # when calling other top-level names the code itself defined.
            # NOTE(security): exec of LLM-generated code is inherently unsafe;
            # this should only run in a sandboxed environment.
            exec(code, local_vars)
        output = stdout_buffer.getvalue()
        # Limit the output length if it's too long
        if output.count('\n') > 32:
            output = '\n'.join(output.split('\n')[:32]) + '\n...'
        # Extract plots
        raw_plots = local_vars.get('plots', [])
        if isinstance(raw_plots, list):
            plots = [p for p in raw_plots if isinstance(p, Figure)]
        logger.info(f"[bold green]Execution complete.[/bold green] Captured {len(plots)} plots.")
    except Exception as e:
        # Capture the traceback
        exc_type, exc_value, tb = sys.exc_info()
        full_traceback = traceback.format_exc()
        # Filter traceback to show only the relevant part (the executed string)
        filtered_tb_lines = [line for line in full_traceback.split('\n') if '<string>' in line]
        error = '\n'.join(filtered_tb_lines)
        if error:
            error += '\n'
        error += f"{exc_type.__name__ if exc_type else 'Exception'}: {exc_value}"
        logger.error(f"Execution failed: {str(e)}")
        # Preserve whatever was printed before the failure.
        output = stdout_buffer.getvalue()
    return {
        "output": output,
        "error": error,
        "plots": plots,
        "vfs_state": vfs_state
    }

View File

@@ -0,0 +1,42 @@
from typing import Dict, Any, List, Optional
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
def summarizer_node(state: WorkerState) -> dict:
    """Condense the worker's analysis results into a summary for the Orchestrator.

    Returns:
        dict with a single 'result' key holding the LLM-written summary.

    Raises:
        Exception: re-raised after logging if the LLM call fails.
    """
    task = state["task"]
    out_text = state.get("output", "")
    err_text = state.get("error")
    settings = Settings()
    logger = get_logger("data_analyst_worker:summarizer")
    logger.info("Summarizing analysis results for the Orchestrator...")
    # A lighter/faster model is fine for summarization; reuse the planner LLM.
    llm = get_llm_model(
        settings.planner_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)],
    )
    prompt = f"""You are a data analyst sub-agent. You have completed a sub-task.
Task: {task}
Execution Results: {out_text}
Error Log (if any): {err_text}
Provide a concise summary of the findings or status for the top-level Orchestrator.
If the execution failed after multiple retries, explain why concisely.
Do NOT include the raw Python code, just the results of the analysis."""
    try:
        reply = llm.invoke(prompt)
        # Chat models return a message object with .content; fall back to str.
        summary = reply.content if hasattr(reply, "content") else str(reply)
        logger.info("[bold green]Analysis results summarized.[/bold green]")
        return {"result": summary}
    except Exception as exc:
        logger.error(f"Failed to summarize results: {str(exc)}")
        raise exc

View File

@@ -0,0 +1,50 @@
from langgraph.graph import StateGraph, END
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
from ea_chatbot.graph.workers.data_analyst.nodes.coder import coder_node
from ea_chatbot.graph.workers.data_analyst.nodes.executor import executor_node
from ea_chatbot.graph.workers.data_analyst.nodes.summarizer import summarizer_node
def router(state: WorkerState) -> str:
    """Pick the next node after execution: retry the coder or summarize.

    Loops back to "coder" while an execution error is present and fewer than
    three generation attempts have been made; otherwise (success or retry
    budget exhausted) hands off to "summarizer".
    """
    should_retry = bool(state.get("error")) and state.get("iterations", 0) < 3
    return "coder" if should_retry else "summarizer"
def create_data_analyst_worker(
    coder=coder_node,
    executor=executor_node,
    summarizer=summarizer_node
) -> StateGraph:
    """Build and compile the Data Analyst worker subgraph.

    Node callables are injectable (defaulting to the real implementations) so
    tests can substitute mocks. Flow:
    coder -> executor -> (retry coder | summarizer) -> END.
    """
    graph = StateGraph(WorkerState)
    # Register the three stages of the worker.
    for node_name, node_fn in (
        ("coder", coder),
        ("executor", executor),
        ("summarizer", summarizer),
    ):
        graph.add_node(node_name, node_fn)
    # The worker always starts by generating code, then executing it.
    graph.set_entry_point("coder")
    graph.add_edge("coder", "executor")
    # After execution, either loop back for error correction or summarize.
    graph.add_conditional_edges(
        "executor",
        router,
        {"coder": "coder", "summarizer": "summarizer"},
    )
    graph.add_edge("summarizer", END)
    return graph.compile()

View File

@@ -0,0 +1,81 @@
import pytest
from unittest.mock import MagicMock
from ea_chatbot.graph.workers.data_analyst.workflow import create_data_analyst_worker, WorkerState
def test_data_analyst_worker_one_shot():
    """Verify a successful one-shot execution of the worker subgraph."""
    # Scenario: Coder -> Executor (Success) -> Summarizer -> END
    coder_stub = MagicMock(
        return_value={"code": "print(1)", "error": None, "iterations": 1}
    )
    executor_stub = MagicMock(
        return_value={"output": "1\n", "error": None, "plots": []}
    )
    summarizer_stub = MagicMock(return_value={"result": "Result is 1"})
    graph = create_data_analyst_worker(
        coder=coder_stub,
        executor=executor_stub,
        summarizer=summarizer_stub
    )
    initial_state = WorkerState(
        messages=[],
        task="Calculate 1+1",
        code=None,
        output=None,
        error=None,
        iterations=0,
        plots=[],
        vfs_state={},
        result=None
    )
    final_state = graph.invoke(initial_state)
    # No error on the first execution, so each node runs exactly once.
    assert final_state["result"] == "Result is 1"
    assert coder_stub.call_count == 1
    assert executor_stub.call_count == 1
    assert summarizer_stub.call_count == 1
def test_data_analyst_worker_retry():
    """Verify that the worker retries on error."""
    # Scenario: Coder (1) -> Executor (Error) -> Router (coder) -> Coder (2) -> Executor (Success) -> Summarizer -> END
    coder_stub = MagicMock(side_effect=[
        {"code": "error_code", "error": None, "iterations": 1},
        {"code": "fixed_code", "error": None, "iterations": 2},
    ])
    executor_stub = MagicMock(side_effect=[
        {"output": "", "error": "NameError", "plots": []},
        {"output": "Success", "error": None, "plots": []},
    ])
    summarizer_stub = MagicMock(return_value={"result": "Fixed Result"})
    graph = create_data_analyst_worker(
        coder=coder_stub,
        executor=executor_stub,
        summarizer=summarizer_stub
    )
    initial_state = WorkerState(
        messages=[],
        task="Retry Task",
        code=None,
        output=None,
        error=None,
        iterations=0,
        plots=[],
        vfs_state={},
        result=None
    )
    final_state = graph.invoke(initial_state)
    # The first execution fails, so coder and executor each run twice before
    # a single summarization.
    assert final_state["result"] == "Fixed Result"
    assert coder_stub.call_count == 2
    assert executor_stub.call_count == 2
    assert summarizer_stub.call_count == 1