feat(workers): Extract Coder and Executor nodes into Data Analyst worker subgraph

This commit is contained in:
Yunxiao Xu
2026-02-23 04:58:46 -08:00
parent 5324cbe851
commit cb045504d1
6 changed files with 330 additions and 0 deletions

View File

@@ -0,0 +1,81 @@
import pytest
from unittest.mock import MagicMock
from ea_chatbot.graph.workers.data_analyst.workflow import create_data_analyst_worker, WorkerState
def test_data_analyst_worker_one_shot():
"""Verify a successful one-shot execution of the worker subgraph."""
mock_coder = MagicMock()
mock_executor = MagicMock()
mock_summarizer = MagicMock()
# Scenario: Coder -> Executor (Success) -> Summarizer -> END
mock_coder.return_value = {"code": "print(1)", "error": None, "iterations": 1}
mock_executor.return_value = {"output": "1\n", "error": None, "plots": []}
mock_summarizer.return_value = {"result": "Result is 1"}
graph = create_data_analyst_worker(
coder=mock_coder,
executor=mock_executor,
summarizer=mock_summarizer
)
initial_state = WorkerState(
messages=[],
task="Calculate 1+1",
code=None,
output=None,
error=None,
iterations=0,
plots=[],
vfs_state={},
result=None
)
final_state = graph.invoke(initial_state)
assert final_state["result"] == "Result is 1"
assert mock_coder.call_count == 1
assert mock_executor.call_count == 1
assert mock_summarizer.call_count == 1
def test_data_analyst_worker_retry():
"""Verify that the worker retries on error."""
mock_coder = MagicMock()
mock_executor = MagicMock()
mock_summarizer = MagicMock()
# Scenario: Coder (1) -> Executor (Error) -> Router (coder) -> Coder (2) -> Executor (Success) -> Summarizer -> END
mock_coder.side_effect = [
{"code": "error_code", "error": None, "iterations": 1},
{"code": "fixed_code", "error": None, "iterations": 2}
]
mock_executor.side_effect = [
{"output": "", "error": "NameError", "plots": []},
{"output": "Success", "error": None, "plots": []}
]
mock_summarizer.return_value = {"result": "Fixed Result"}
graph = create_data_analyst_worker(
coder=mock_coder,
executor=mock_executor,
summarizer=mock_summarizer
)
initial_state = WorkerState(
messages=[],
task="Retry Task",
code=None,
output=None,
error=None,
iterations=0,
plots=[],
vfs_state={},
result=None
)
final_state = graph.invoke(initial_state)
assert final_state["result"] == "Fixed Result"
assert mock_coder.call_count == 2
assert mock_executor.call_count == 2
assert mock_summarizer.call_count == 1