feat(workers): Extract Coder and Executor nodes into Data Analyst worker subgraph
This commit is contained in:
@@ -0,0 +1,61 @@
|
|||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
||||||
|
from ea_chatbot.config import Settings
|
||||||
|
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||||
|
from ea_chatbot.utils import database_inspection
|
||||||
|
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||||
|
from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT
|
||||||
|
from ea_chatbot.schemas import CodeGenerationResponse
|
||||||
|
|
||||||
|
def coder_node(state: WorkerState) -> dict:
    """Generate Python code for the worker's assigned sub-task.

    Feeds the task plus any prior execution output/error back into the LLM so
    it can self-correct on retry iterations, and lists the in-memory VFS files
    so generated code can reference them.

    Returns:
        Partial state update with the generated ``code``, a cleared ``error``,
        and an incremented ``iterations`` counter.

    Raises:
        Exception: re-raised from the structured LLM call on failure.
    """
    task = state["task"]
    # Previous execution results ("None" on the first pass) give the model
    # the feedback it needs for error correction on retries.
    output = state.get("output", "None")
    error = state.get("error", "None")
    vfs_state = state.get("vfs_state", {})

    settings = Settings()
    logger = get_logger("data_analyst_worker:coder")

    logger.info(f"Generating Python code for task: {task[:50]}...")

    # We can use the configured 'coder_llm' for this node
    llm = get_llm_model(
        settings.coder_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(CodeGenerationResponse)

    # Data summary for context
    database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."

    # VFS Summary: Let the LLM know which files are available in-memory
    vfs_summary = "Available in-memory files (VFS):\n"
    if vfs_state:
        for filename, data in vfs_state.items():
            meta = data.get("metadata", {})
            # FIX: the filename was previously hard-coded as "(unknown)", so the
            # LLM could never see which VFS files actually exist.
            vfs_summary += f"- {filename} ({meta.get('type', 'unknown')})\n"
    else:
        vfs_summary += "- None"

    # Reuse the global prompt but adapt 'question' and 'plan' labels
    # For a sub-task worker, 'task' effectively replaces the high-level 'plan'
    messages = CODE_GENERATOR_PROMPT.format_messages(
        question=task,
        plan="Focus on the specific task below.",
        database_description=database_description,
        code_exec_results=f"Output: {output}\nError: {error}\n\n{vfs_summary}",
        example_code=""
    )

    try:
        response = structured_llm.invoke(messages)
        logger.info("[bold green]Code generated.[/bold green]")
        return {
            "code": response.parsed_code,
            "error": None,  # Clear previous errors
            "iterations": state.get("iterations", 0) + 1
        }
    except Exception as e:
        logger.error(f"Failed to generate code: {str(e)}")
        raise  # bare raise preserves the original traceback
|
||||||
@@ -0,0 +1,96 @@
|
|||||||
|
import io
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
from contextlib import redirect_stdout
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
import pandas as pd
|
||||||
|
from matplotlib.figure import Figure
|
||||||
|
|
||||||
|
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
||||||
|
from ea_chatbot.utils.db_client import DBClient
|
||||||
|
from ea_chatbot.utils.vfs import VFSHelper
|
||||||
|
from ea_chatbot.utils.logging import get_logger
|
||||||
|
from ea_chatbot.config import Settings
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ea_chatbot.types import DBSettings
|
||||||
|
|
||||||
|
def _truncate_output(text: str, max_lines: int = 32) -> str:
    """Cap captured stdout at *max_lines* lines to keep LLM context small."""
    if text.count('\n') > max_lines:
        return '\n'.join(text.split('\n')[:max_lines]) + '\n...'
    return text


def executor_node(state: WorkerState) -> dict:
    """Execute the Python code in the context of the Data Analyst worker.

    Runs the generated code with a DB client, pandas, a plot collector, and
    the VFS helper injected into its namespace, capturing stdout and any
    matplotlib Figures appended to ``plots``.

    Returns:
        Partial state update with ``output`` (truncated stdout), ``error``
        (filtered traceback or None), ``plots``, and the ``vfs_state``
        snapshot (possibly mutated in place by the executed code via ``vfs``).
    """
    code = state.get("code")
    logger = get_logger("data_analyst_worker:executor")

    if not code:
        logger.error("No code provided to executor.")
        return {"error": "No code provided to executor."}

    logger.info("Executing Python code...")
    settings = Settings()

    db_settings: "DBSettings" = {
        "host": settings.db_host,
        "port": settings.db_port,
        "user": settings.db_user,
        "pswd": settings.db_pswd,
        "db": settings.db_name,
        "table": settings.db_table
    }

    db_client = DBClient(settings=db_settings)

    # Initialize the Virtual File System (VFS) helper with the snapshot from state
    vfs_state = dict(state.get("vfs_state", {}))
    vfs_helper = VFSHelper(vfs_state)

    # Namespace exposed to the executed code.
    exec_namespace = {
        'db': db_client,
        'plots': [],
        'pd': pd,
        'vfs': vfs_helper
    }

    stdout_buffer = io.StringIO()
    error = None
    output = ""
    plots = []

    try:
        with redirect_stdout(stdout_buffer):
            # SECURITY: exec of LLM-generated code is inherently unsafe; this
            # relies on upstream trust of the code generator.
            # FIX: pass the same dict as globals AND locals — with separate
            # dicts, functions defined by the executed code could not resolve
            # other top-level names (function bodies only look up globals).
            exec(code, exec_namespace, exec_namespace)

        # Limit the output length if it's too long
        output = _truncate_output(stdout_buffer.getvalue())

        # Extract plots (only genuine matplotlib Figures are kept)
        raw_plots = exec_namespace.get('plots', [])
        if isinstance(raw_plots, list):
            plots = [p for p in raw_plots if isinstance(p, Figure)]

        logger.info(f"[bold green]Execution complete.[/bold green] Captured {len(plots)} plots.")

    except Exception as e:
        # Filter traceback to show only the relevant part (the executed string)
        full_traceback = traceback.format_exc()
        filtered_tb_lines = [line for line in full_traceback.split('\n') if '<string>' in line]
        error = '\n'.join(filtered_tb_lines)
        if error:
            error += '\n'
        error += f"{type(e).__name__}: {e}"

        logger.error(f"Execution failed: {str(e)}")
        # Keep partial stdout so the coder sees how far execution got;
        # truncated for parity with the success path.
        output = _truncate_output(stdout_buffer.getvalue())

    return {
        "output": output,
        "error": error,
        "plots": plots,
        "vfs_state": vfs_state
    }
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
||||||
|
from ea_chatbot.config import Settings
|
||||||
|
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||||
|
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||||
|
|
||||||
|
def summarizer_node(state: WorkerState) -> dict:
    """Summarize the data analysis results for the Orchestrator.

    Condenses the sub-task's execution output (and error log, if any) into a
    concise natural-language status via the planner LLM.

    Returns:
        Partial state update with the ``result`` summary string.

    Raises:
        Exception: re-raised from the LLM call on failure.
    """
    task = state["task"]
    output = state.get("output", "")
    error = state.get("error")

    settings = Settings()
    logger = get_logger("data_analyst_worker:summarizer")

    logger.info("Summarizing analysis results for the Orchestrator...")

    # We can use a smaller/faster model for this summary if needed
    llm = get_llm_model(
        settings.planner_llm,  # Using planner model for summary logic
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )

    prompt = f"""You are a data analyst sub-agent. You have completed a sub-task.
Task: {task}
Execution Results: {output}
Error Log (if any): {error}

Provide a concise summary of the findings or status for the top-level Orchestrator.
If the execution failed after multiple retries, explain why concisely.
Do NOT include the raw Python code, just the results of the analysis."""

    try:
        response = llm.invoke(prompt)
        # Chat models return a message object with .content; fall back to str
        # for plain-string LLM interfaces.
        result = response.content if hasattr(response, "content") else str(response)
        logger.info("[bold green]Analysis results summarized.[/bold green]")
        return {
            "result": result
        }
    except Exception as e:
        logger.error(f"Failed to summarize results: {str(e)}")
        raise  # bare raise preserves the original traceback
|
||||||
@@ -0,0 +1,50 @@
|
|||||||
|
from langgraph.graph import StateGraph, END
|
||||||
|
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
||||||
|
from ea_chatbot.graph.workers.data_analyst.nodes.coder import coder_node
|
||||||
|
from ea_chatbot.graph.workers.data_analyst.nodes.executor import executor_node
|
||||||
|
from ea_chatbot.graph.workers.data_analyst.nodes.summarizer import summarizer_node
|
||||||
|
|
||||||
|
def router(state: "WorkerState", max_retries: int = 3) -> str:
    """Routes the subgraph between coding, execution, and summarization.

    Args:
        state: Worker state; only ``error`` and ``iterations`` are read.
        max_retries: Maximum coder iterations before giving up and
            summarizing anyway (default 3, the previously hard-coded cap).

    Returns:
        "coder" to retry with error correction, "summarizer" otherwise.
    """
    error = state.get("error")
    iterations = state.get("iterations", 0)

    if error and iterations < max_retries:
        # Retry with error correction
        return "coder"

    # Either success or max retries reached
    return "summarizer"
|
||||||
|
|
||||||
|
def create_data_analyst_worker(
    coder=coder_node,
    executor=executor_node,
    summarizer=summarizer_node
):
    """Create and compile the Data Analyst worker subgraph.

    Wiring: coder -> executor, then a conditional edge loops back to the
    coder on execution errors (bounded by ``router``'s retry cap) or proceeds
    to the summarizer, which ends the graph.

    Args:
        coder, executor, summarizer: Node callables; injectable so the wiring
            can be unit-tested with mocks.

    Returns:
        The compiled, runnable graph. (The previous ``-> StateGraph``
        annotation was incorrect: ``StateGraph.compile()`` returns a compiled
        graph object, not a ``StateGraph``.)
    """
    workflow = StateGraph(WorkerState)

    # Add Nodes
    workflow.add_node("coder", coder)
    workflow.add_node("executor", executor)
    workflow.add_node("summarizer", summarizer)

    # Set entry point
    workflow.set_entry_point("coder")

    # Add Edges
    workflow.add_edge("coder", "executor")

    # Add Conditional Edges: retry (coder) or finish (summarizer)
    workflow.add_conditional_edges(
        "executor",
        router,
        {
            "coder": "coder",
            "summarizer": "summarizer"
        }
    )

    workflow.add_edge("summarizer", END)

    return workflow.compile()
|
||||||
81
backend/tests/test_data_analyst_worker.py
Normal file
81
backend/tests/test_data_analyst_worker.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from ea_chatbot.graph.workers.data_analyst.workflow import create_data_analyst_worker, WorkerState
|
||||||
|
|
||||||
|
def test_data_analyst_worker_one_shot():
    """Verify a successful one-shot execution of the worker subgraph."""
    # Happy path: Coder -> Executor (Success) -> Summarizer -> END
    coder_stub = MagicMock(
        return_value={"code": "print(1)", "error": None, "iterations": 1}
    )
    executor_stub = MagicMock(
        return_value={"output": "1\n", "error": None, "plots": []}
    )
    summarizer_stub = MagicMock(return_value={"result": "Result is 1"})

    worker = create_data_analyst_worker(
        coder=coder_stub,
        executor=executor_stub,
        summarizer=summarizer_stub
    )

    start_state = WorkerState(
        messages=[],
        task="Calculate 1+1",
        code=None,
        output=None,
        error=None,
        iterations=0,
        plots=[],
        vfs_state={},
        result=None
    )

    end_state = worker.invoke(start_state)

    assert end_state["result"] == "Result is 1"
    # No retry loop on success: every node fires exactly once.
    for stub in (coder_stub, executor_stub, summarizer_stub):
        assert stub.call_count == 1
|
||||||
|
|
||||||
|
def test_data_analyst_worker_retry():
    """Verify that the worker retries on error."""
    # Scenario: Coder (1) -> Executor (Error) -> Router (coder) -> Coder (2)
    # -> Executor (Success) -> Summarizer -> END
    coder_stub = MagicMock(side_effect=[
        {"code": "error_code", "error": None, "iterations": 1},
        {"code": "fixed_code", "error": None, "iterations": 2}
    ])
    executor_stub = MagicMock(side_effect=[
        {"output": "", "error": "NameError", "plots": []},
        {"output": "Success", "error": None, "plots": []}
    ])
    summarizer_stub = MagicMock(return_value={"result": "Fixed Result"})

    worker = create_data_analyst_worker(
        coder=coder_stub,
        executor=executor_stub,
        summarizer=summarizer_stub
    )

    start_state = WorkerState(
        messages=[],
        task="Retry Task",
        code=None,
        output=None,
        error=None,
        iterations=0,
        plots=[],
        vfs_state={},
        result=None
    )

    end_state = worker.invoke(start_state)

    assert end_state["result"] == "Fixed Result"
    # One retry: coder and executor run twice, summarizer once.
    assert coder_stub.call_count == 2
    assert executor_stub.call_count == 2
    assert summarizer_stub.call_count == 1
|
||||||
Reference in New Issue
Block a user