feat(workers): Extract Coder and Executor nodes into Data Analyst worker subgraph
This commit is contained in:
@@ -0,0 +1,61 @@
|
||||
from typing import Dict, Any, List, Optional
|
||||
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils import database_inspection
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT
|
||||
from ea_chatbot.schemas import CodeGenerationResponse
|
||||
|
||||
def coder_node(state: WorkerState) -> dict:
    """Generate Python code based on the sub-task assigned to the worker.

    Reads the task plus the previous execution output/error from state so the
    LLM can self-correct on retries, and returns the generated code together
    with an incremented iteration counter for the retry router.

    Args:
        state: Worker subgraph state; reads "task", "output", "error",
            "vfs_state" and "iterations".

    Returns:
        Partial state update with keys "code", "error" (cleared) and
        "iterations".

    Raises:
        Exception: re-raised if the structured LLM call fails.
    """
    task = state["task"]
    # Default to the string "None" (not None) so the prompt template always
    # renders a readable placeholder on the first iteration.
    output = state.get("output", "None")
    error = state.get("error", "None")
    vfs_state = state.get("vfs_state", {})

    settings = Settings()
    logger = get_logger("data_analyst_worker:coder")

    logger.info(f"Generating Python code for task: {task[:50]}...")

    # We can use the configured 'coder_llm' for this node
    llm = get_llm_model(
        settings.coder_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(CodeGenerationResponse)

    # Data summary for context
    database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."

    # VFS Summary: Let the LLM know which files are available in-memory
    vfs_summary = "Available in-memory files (VFS):\n"
    if vfs_state:
        for filename, data in vfs_state.items():
            meta = data.get("metadata", {})
            # BUG FIX: previously emitted the literal text "(unknown)" instead
            # of the actual filename, so the LLM never saw which VFS files
            # were available.
            vfs_summary += f"- {filename} ({meta.get('type', 'unknown')})\n"
    else:
        vfs_summary += "- None"

    # Reuse the global prompt but adapt 'question' and 'plan' labels
    # For a sub-task worker, 'task' effectively replaces the high-level 'plan'
    messages = CODE_GENERATOR_PROMPT.format_messages(
        question=task,
        plan="Focus on the specific task below.",
        database_description=database_description,
        code_exec_results=f"Output: {output}\nError: {error}\n\n{vfs_summary}",
        example_code=""
    )

    try:
        response = structured_llm.invoke(messages)
        logger.info("[bold green]Code generated.[/bold green]")
        return {
            "code": response.parsed_code,
            "error": None,  # Clear previous errors
            "iterations": state.get("iterations", 0) + 1
        }
    except Exception as e:
        logger.error(f"Failed to generate code: {str(e)}")
        raise e
|
||||
@@ -0,0 +1,96 @@
|
||||
import io
|
||||
import sys
|
||||
import traceback
|
||||
from contextlib import redirect_stdout
|
||||
from typing import TYPE_CHECKING
|
||||
import pandas as pd
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
||||
from ea_chatbot.utils.db_client import DBClient
|
||||
from ea_chatbot.utils.vfs import VFSHelper
|
||||
from ea_chatbot.utils.logging import get_logger
|
||||
from ea_chatbot.config import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ea_chatbot.types import DBSettings
|
||||
|
||||
def executor_node(state: WorkerState) -> dict:
    """Execute the worker's generated Python code and capture its results.

    Runs the code from state["code"] with a DB client, pandas, and a VFS
    helper injected into its namespace, capturing stdout, matplotlib figures
    appended to `plots`, and any exception as a filtered traceback string.

    Args:
        state: Worker subgraph state; reads "code" and "vfs_state".

    Returns:
        Partial state update with keys "output", "error", "plots" and
        "vfs_state" (the mutated VFS snapshot).
    """
    code = state.get("code")
    logger = get_logger("data_analyst_worker:executor")

    # Guard clause: nothing to run — report the problem to the router.
    if not code:
        logger.error("No code provided to executor.")
        return {"error": "No code provided to executor."}

    logger.info("Executing Python code...")
    settings = Settings()

    # Database connection settings assembled from app config.
    db_settings: "DBSettings" = {
        "host": settings.db_host,
        "port": settings.db_port,
        "user": settings.db_user,
        "pswd": settings.db_pswd,
        "db": settings.db_name,
        "table": settings.db_table
    }

    db_client = DBClient(settings=db_settings)

    # Initialize the Virtual File System (VFS) helper with the snapshot from state
    # (copied so in-place mutations don't alias the incoming state object).
    vfs_state = dict(state.get("vfs_state", {}))
    vfs_helper = VFSHelper(vfs_state)

    # Names exposed to the generated code.
    exec_namespace = {
        'db': db_client,
        'plots': [],
        'pd': pd,
        'vfs': vfs_helper
    }

    stdout_buffer = io.StringIO()
    error = None
    output = ""
    plots = []

    try:
        with redirect_stdout(stdout_buffer):
            # SECURITY NOTE: exec runs LLM-generated code with full interpreter
            # privileges; acceptable only in a trusted deployment.
            # BUG FIX: use a single dict as both globals and locals. With the
            # previous exec(code, {}, local_vars), separate globals/locals make
            # exec behave like a class body (per the Python `exec` docs), so
            # functions defined by the generated code could not resolve
            # 'db'/'pd'/'vfs' or each other and raised NameError.
            exec(code, exec_namespace)

        output = stdout_buffer.getvalue()

        # Limit the output length if it's too long
        if output.count('\n') > 32:
            output = '\n'.join(output.split('\n')[:32]) + '\n...'

        # Extract plots: only keep genuine matplotlib Figures the code appended.
        raw_plots = exec_namespace.get('plots', [])
        if isinstance(raw_plots, list):
            plots = [p for p in raw_plots if isinstance(p, Figure)]

        logger.info(f"[bold green]Execution complete.[/bold green] Captured {len(plots)} plots.")

    except Exception as e:
        # Capture the traceback
        exc_type, exc_value, _tb = sys.exc_info()
        full_traceback = traceback.format_exc()

        # Filter traceback to show only the relevant part (the executed string)
        filtered_tb_lines = [line for line in full_traceback.split('\n') if '<string>' in line]
        error = '\n'.join(filtered_tb_lines)
        if error:
            error += '\n'
        error += f"{exc_type.__name__ if exc_type else 'Exception'}: {exc_value}"

        logger.error(f"Execution failed: {str(e)}")
        # Preserve whatever was printed before the failure for the retry prompt.
        output = stdout_buffer.getvalue()

    return {
        "output": output,
        "error": error,
        "plots": plots,
        "vfs_state": vfs_state
    }
|
||||
@@ -0,0 +1,42 @@
|
||||
from typing import Dict, Any, List, Optional
|
||||
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
|
||||
def summarizer_node(state: WorkerState) -> dict:
    """Summarize the data analysis results for the Orchestrator.

    Builds a short natural-language prompt from the task, execution output and
    error log, and asks the configured planner LLM for a concise summary.
    """
    logger = get_logger("data_analyst_worker:summarizer")
    settings = Settings()

    task = state["task"]
    run_output = state.get("output", "")
    run_error = state.get("error")

    logger.info("Summarizing analysis results for the Orchestrator...")

    # We can use a smaller/faster model for this summary if needed
    model = get_llm_model(
        settings.planner_llm,  # Using planner model for summary logic
        callbacks=[LangChainLoggingHandler(logger=logger)],
    )

    prompt = (
        f"""You are a data analyst sub-agent. You have completed a sub-task.
Task: {task}
Execution Results: {run_output}
Error Log (if any): {run_error}

Provide a concise summary of the findings or status for the top-level Orchestrator.
If the execution failed after multiple retries, explain why concisely.
Do NOT include the raw Python code, just the results of the analysis."""
    )

    try:
        answer = model.invoke(prompt)
        summary = answer.content if hasattr(answer, "content") else str(answer)
        logger.info("[bold green]Analysis results summarized.[/bold green]")
        return {"result": summary}
    except Exception as e:
        logger.error(f"Failed to summarize results: {str(e)}")
        raise e
|
||||
@@ -0,0 +1,50 @@
|
||||
from langgraph.graph import StateGraph, END
|
||||
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
||||
from ea_chatbot.graph.workers.data_analyst.nodes.coder import coder_node
|
||||
from ea_chatbot.graph.workers.data_analyst.nodes.executor import executor_node
|
||||
from ea_chatbot.graph.workers.data_analyst.nodes.summarizer import summarizer_node
|
||||
|
||||
def router(state: WorkerState) -> str:
    """Route the subgraph: retry coding on error (max 3 attempts) or summarize.

    Returns "coder" when the last execution failed and fewer than three
    coding attempts have been made; otherwise returns "summarizer".
    """
    attempts = state.get("iterations", 0)
    failed = bool(state.get("error"))

    # Success, or retry budget exhausted -> hand off to the summarizer.
    if not failed or attempts >= 3:
        return "summarizer"

    # Retry with error correction
    return "coder"
|
||||
|
||||
def create_data_analyst_worker(
    coder=coder_node,
    executor=executor_node,
    summarizer=summarizer_node
) -> StateGraph:
    """Create the Data Analyst worker subgraph.

    Wires coder -> executor, then routes from the executor either back to the
    coder (retry on error) or on to the summarizer, which ends the graph.
    Node callables are injectable for testing.
    """
    workflow = StateGraph(WorkerState)

    # Register all nodes, starting at the coder.
    for node_name, node_fn in (
        ("coder", coder),
        ("executor", executor),
        ("summarizer", summarizer),
    ):
        workflow.add_node(node_name, node_fn)
    workflow.set_entry_point("coder")

    # Linear edge into execution, then a conditional retry/finish split.
    workflow.add_edge("coder", "executor")
    workflow.add_conditional_edges(
        "executor",
        router,
        {"coder": "coder", "summarizer": "summarizer"},
    )
    workflow.add_edge("summarizer", END)

    return workflow.compile()
|
||||
81
backend/tests/test_data_analyst_worker.py
Normal file
81
backend/tests/test_data_analyst_worker.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from ea_chatbot.graph.workers.data_analyst.workflow import create_data_analyst_worker, WorkerState
|
||||
|
||||
def test_data_analyst_worker_one_shot():
    """Verify a successful one-shot execution of the worker subgraph."""
    # Scenario: Coder -> Executor (Success) -> Summarizer -> END
    coder_stub = MagicMock(
        return_value={"code": "print(1)", "error": None, "iterations": 1}
    )
    executor_stub = MagicMock(
        return_value={"output": "1\n", "error": None, "plots": []}
    )
    summarizer_stub = MagicMock(return_value={"result": "Result is 1"})

    graph = create_data_analyst_worker(
        coder=coder_stub,
        executor=executor_stub,
        summarizer=summarizer_stub,
    )

    final_state = graph.invoke(
        WorkerState(
            messages=[],
            task="Calculate 1+1",
            code=None,
            output=None,
            error=None,
            iterations=0,
            plots=[],
            vfs_state={},
            result=None,
        )
    )

    # Each node should have run exactly once, ending with the summary.
    assert final_state["result"] == "Result is 1"
    for stub in (coder_stub, executor_stub, summarizer_stub):
        assert stub.call_count == 1
|
||||
|
||||
def test_data_analyst_worker_retry():
    """Verify that the worker retries on error."""
    # Scenario: Coder (1) -> Executor (Error) -> Router (coder) -> Coder (2)
    # -> Executor (Success) -> Summarizer -> END
    coder_stub = MagicMock(
        side_effect=[
            {"code": "error_code", "error": None, "iterations": 1},
            {"code": "fixed_code", "error": None, "iterations": 2},
        ]
    )
    executor_stub = MagicMock(
        side_effect=[
            {"output": "", "error": "NameError", "plots": []},
            {"output": "Success", "error": None, "plots": []},
        ]
    )
    summarizer_stub = MagicMock(return_value={"result": "Fixed Result"})

    graph = create_data_analyst_worker(
        coder=coder_stub,
        executor=executor_stub,
        summarizer=summarizer_stub,
    )

    final_state = graph.invoke(
        WorkerState(
            messages=[],
            task="Retry Task",
            code=None,
            output=None,
            error=None,
            iterations=0,
            plots=[],
            vfs_state={},
            result=None,
        )
    )

    # One retry loop: coder and executor run twice, summarizer once.
    assert final_state["result"] == "Fixed Result"
    assert coder_stub.call_count == 2
    assert executor_stub.call_count == 2
    assert summarizer_stub.call_count == 1
|
||||
Reference in New Issue
Block a user