feat(workers): Implement input/output mapping for Data Analyst subgraph
This commit is contained in:
39
backend/src/ea_chatbot/graph/workers/data_analyst/mapping.py
Normal file
39
backend/src/ea_chatbot/graph/workers/data_analyst/mapping.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import Dict, Any, List
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
||||
|
||||
def prepare_worker_input(state: AgentState) -> Dict[str, Any]:
|
||||
"""Prepare the initial state for the Data Analyst worker."""
|
||||
checklist = state.get("checklist", [])
|
||||
current_step = state.get("current_step", 0)
|
||||
|
||||
# Get the current task description
|
||||
task_desc = "Analyze data" # Default
|
||||
if 0 <= current_step < len(checklist):
|
||||
task_desc = checklist[current_step].get("task", task_desc)
|
||||
|
||||
return {
|
||||
"task": task_desc,
|
||||
"messages": [HumanMessage(content=task_desc)], # Start worker loop with the task
|
||||
"vfs_state": dict(state.get("vfs", {})),
|
||||
"iterations": 0,
|
||||
"plots": [],
|
||||
"code": None,
|
||||
"output": None,
|
||||
"error": None,
|
||||
"result": None
|
||||
}
|
||||
|
||||
def merge_worker_output(worker_state: WorkerState) -> Dict[str, Any]:
|
||||
"""Map worker results back to the global AgentState."""
|
||||
result = worker_state.get("result", "Analysis complete.")
|
||||
|
||||
# We bubble up the summary as an AI message
|
||||
updates = {
|
||||
"messages": [AIMessage(content=result)],
|
||||
"vfs": worker_state.get("vfs_state", {}),
|
||||
"plots": worker_state.get("plots", [])
|
||||
}
|
||||
|
||||
return updates
|
||||
53
backend/tests/test_data_analyst_mapping.py
Normal file
53
backend/tests/test_data_analyst_mapping.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.graph.workers.data_analyst.mapping import (
|
||||
prepare_worker_input,
|
||||
merge_worker_output
|
||||
)
|
||||
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
||||
|
||||
def test_prepare_worker_input():
|
||||
"""Verify that we correctly map global state to worker input."""
|
||||
global_state = AgentState(
|
||||
messages=[HumanMessage(content="global message")],
|
||||
question="original question",
|
||||
checklist=[{"task": "Worker Task", "status": "pending"}],
|
||||
current_step=0,
|
||||
vfs={"old.txt": "old data"},
|
||||
plots=[],
|
||||
dfs={},
|
||||
next_action="test",
|
||||
iterations=0
|
||||
)
|
||||
|
||||
worker_input = prepare_worker_input(global_state)
|
||||
|
||||
assert worker_input["task"] == "Worker Task"
|
||||
assert "old.txt" in worker_input["vfs_state"]
|
||||
# Internal worker messages should start fresh or with the task
|
||||
assert len(worker_input["messages"]) == 1
|
||||
assert worker_input["messages"][0].content == "Worker Task"
|
||||
|
||||
def test_merge_worker_output():
|
||||
"""Verify that we correctly merge worker results back to global state."""
|
||||
worker_state = WorkerState(
|
||||
messages=[HumanMessage(content="internal"), AIMessage(content="summary")],
|
||||
task="Worker Task",
|
||||
result="Finished analysis",
|
||||
plots=["plot1"],
|
||||
vfs_state={"new.txt": "new data"},
|
||||
iterations=2
|
||||
)
|
||||
|
||||
updates = merge_worker_output(worker_state)
|
||||
|
||||
# We expect the 'result' to be added as an AI message to global history
|
||||
assert len(updates["messages"]) == 1
|
||||
assert updates["messages"][0].content == "Finished analysis"
|
||||
|
||||
# VFS should be updated
|
||||
assert "new.txt" in updates["vfs"]
|
||||
|
||||
# Plots should be bubbled up
|
||||
assert len(updates["plots"]) == 1
|
||||
assert updates["plots"][0] == "plot1"
|
||||
Reference in New Issue
Block a user