54 lines
1.8 KiB
Python
54 lines
1.8 KiB
Python
from langchain_core.messages import HumanMessage, AIMessage
|
|
from ea_chatbot.graph.state import AgentState
|
|
from ea_chatbot.graph.workers.data_analyst.mapping import (
|
|
prepare_worker_input,
|
|
merge_worker_output
|
|
)
|
|
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
|
|
|
def test_prepare_worker_input():
    """Check the global-state -> worker-input mapping.

    The worker input produced by ``prepare_worker_input`` must carry:
      * the text of the pending checklist item as ``task``,
      * the global VFS contents under ``vfs_state``,
      * a fresh message history holding only the task text, rather than
        the global conversation ("global message").
    """
    state = AgentState(
        messages=[HumanMessage(content="global message")],
        question="original question",
        checklist=[{"task": "Worker Task", "status": "pending"}],
        current_step=0,
        vfs={"old.txt": "old data"},
        plots=[],
        dfs={},
        next_action="test",
        iterations=0,
    )

    mapped = prepare_worker_input(state)

    # The checklist has a single item and current_step is 0, so that item's
    # text is the expected task.
    assert mapped["task"] == "Worker Task"
    # Files already present in the global VFS are handed to the worker.
    assert "old.txt" in mapped["vfs_state"]

    # Internal worker messages should start fresh or with the task: exactly
    # one message, whose content is the task text.
    worker_messages = mapped["messages"]
    assert len(worker_messages) == 1
    assert worker_messages[0].content == "Worker Task"
|
|
|
|
def test_merge_worker_output():
    """Verify that we correctly merge worker results back to global state.

    The worker's internal history (two messages here) must not be copied
    into the global history. Only the worker's ``result`` is surfaced, as a
    single AI message. The worker's VFS files and plots are bubbled up into
    the global ``vfs`` and ``plots`` updates.
    """
    worker_state = WorkerState(
        messages=[HumanMessage(content="internal"), AIMessage(content="summary")],
        task="Worker Task",
        result="Finished analysis",
        plots=["plot1"],
        vfs_state={"new.txt": "new data"},
        iterations=2,
    )

    updates = merge_worker_output(worker_state)

    # We expect the 'result' to be added as an AI message to global history.
    # The worker's two internal messages must not appear, so exactly one
    # message comes back.
    assert len(updates["messages"]) == 1
    merged_message = updates["messages"][0]
    # Pin the message type as well as its text: a non-AI message carrying
    # the same content would otherwise satisfy the check.
    assert isinstance(merged_message, AIMessage)
    assert merged_message.content == "Finished analysis"

    # VFS should be updated with the worker's files.
    assert "new.txt" in updates["vfs"]

    # Plots should be bubbled up.
    assert len(updates["plots"]) == 1
    assert updates["plots"][0] == "plot1"
|