diff --git a/backend/src/ea_chatbot/graph/workers/data_analyst/mapping.py b/backend/src/ea_chatbot/graph/workers/data_analyst/mapping.py
new file mode 100644
index 0000000..99851be
--- /dev/null
+++ b/backend/src/ea_chatbot/graph/workers/data_analyst/mapping.py
@@ -0,0 +1,39 @@
+from typing import Dict, Any
+from langchain_core.messages import HumanMessage, AIMessage
+from ea_chatbot.graph.state import AgentState
+from ea_chatbot.graph.workers.data_analyst.state import WorkerState
+
+def prepare_worker_input(state: AgentState) -> Dict[str, Any]:
+    """Prepare the initial state for the Data Analyst worker."""
+    checklist = state.get("checklist", [])
+    current_step = state.get("current_step", 0)
+
+    # Get the current task description
+    task_desc = "Analyze data" # Default
+    if 0 <= current_step < len(checklist):
+        task_desc = checklist[current_step].get("task", task_desc)
+
+    return {
+        "task": task_desc,
+        "messages": [HumanMessage(content=task_desc)], # Start worker loop with the task
+        "vfs_state": dict(state.get("vfs", {})),
+        "iterations": 0,
+        "plots": [],
+        "code": None,
+        "output": None,
+        "error": None,
+        "result": None
+    }
+
+def merge_worker_output(worker_state: WorkerState) -> Dict[str, Any]:
+    """Map worker results back to the global AgentState."""
+    # 'result' is initialized to None by prepare_worker_input, so a plain
+    # .get() default never fires; use `or` to cover the explicit-None case.
+    result = worker_state.get("result") or "Analysis complete."
+
+    # We bubble up the summary as an AI message
+    updates = {
+        "messages": [AIMessage(content=result)],
+        "vfs": worker_state.get("vfs_state", {}),
+        "plots": worker_state.get("plots", [])
+    }
+
+    return updates
diff --git a/backend/tests/test_data_analyst_mapping.py b/backend/tests/test_data_analyst_mapping.py
new file mode 100644
index 0000000..e4fdc87
--- /dev/null
+++ b/backend/tests/test_data_analyst_mapping.py
@@ -0,0 +1,59 @@
+from langchain_core.messages import HumanMessage, AIMessage
+from ea_chatbot.graph.state import AgentState
+from ea_chatbot.graph.workers.data_analyst.mapping import (
+    prepare_worker_input,
+    merge_worker_output
+)
+from ea_chatbot.graph.workers.data_analyst.state import WorkerState
+
+def test_prepare_worker_input():
+    """Verify that we correctly map global state to worker input."""
+    global_state = AgentState(
+        messages=[HumanMessage(content="global message")],
+        question="original question",
+        checklist=[{"task": "Worker Task", "status": "pending"}],
+        current_step=0,
+        vfs={"old.txt": "old data"},
+        plots=[],
+        dfs={},
+        next_action="test",
+        iterations=0
+    )
+
+    worker_input = prepare_worker_input(global_state)
+
+    assert worker_input["task"] == "Worker Task"
+    assert "old.txt" in worker_input["vfs_state"]
+    # Internal worker messages should start fresh or with the task
+    assert len(worker_input["messages"]) == 1
+    assert worker_input["messages"][0].content == "Worker Task"
+
+def test_merge_worker_output():
+    """Verify that we correctly merge worker results back to global state."""
+    worker_state = WorkerState(
+        messages=[HumanMessage(content="internal"), AIMessage(content="summary")],
+        task="Worker Task",
+        result="Finished analysis",
+        plots=["plot1"],
+        vfs_state={"new.txt": "new data"},
+        iterations=2
+    )
+
+    updates = merge_worker_output(worker_state)
+
+    # We expect the 'result' to be added as an AI message to global history
+    assert len(updates["messages"]) == 1
+    assert updates["messages"][0].content == "Finished analysis"
+
+    # VFS should be updated
+    assert "new.txt" in updates["vfs"]
+
+    # Plots should be bubbled up
+    assert len(updates["plots"]) == 1
+    assert updates["plots"][0] == "plot1"
+
+def test_merge_worker_output_none_result():
+    """'result' left at its prepare-time default of None must still yield a message."""
+    worker_state = WorkerState(result=None, vfs_state={}, plots=[])
+    updates = merge_worker_output(worker_state)
+    assert updates["messages"][0].content == "Analysis complete."