feat(workers): Extract Coder and Executor nodes into Data Analyst worker subgraph
This commit is contained in:
81
backend/tests/test_data_analyst_worker.py
Normal file
81
backend/tests/test_data_analyst_worker.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from ea_chatbot.graph.workers.data_analyst.workflow import create_data_analyst_worker, WorkerState
|
||||
|
||||
def test_data_analyst_worker_one_shot():
|
||||
"""Verify a successful one-shot execution of the worker subgraph."""
|
||||
mock_coder = MagicMock()
|
||||
mock_executor = MagicMock()
|
||||
mock_summarizer = MagicMock()
|
||||
|
||||
# Scenario: Coder -> Executor (Success) -> Summarizer -> END
|
||||
mock_coder.return_value = {"code": "print(1)", "error": None, "iterations": 1}
|
||||
mock_executor.return_value = {"output": "1\n", "error": None, "plots": []}
|
||||
mock_summarizer.return_value = {"result": "Result is 1"}
|
||||
|
||||
graph = create_data_analyst_worker(
|
||||
coder=mock_coder,
|
||||
executor=mock_executor,
|
||||
summarizer=mock_summarizer
|
||||
)
|
||||
|
||||
initial_state = WorkerState(
|
||||
messages=[],
|
||||
task="Calculate 1+1",
|
||||
code=None,
|
||||
output=None,
|
||||
error=None,
|
||||
iterations=0,
|
||||
plots=[],
|
||||
vfs_state={},
|
||||
result=None
|
||||
)
|
||||
|
||||
final_state = graph.invoke(initial_state)
|
||||
|
||||
assert final_state["result"] == "Result is 1"
|
||||
assert mock_coder.call_count == 1
|
||||
assert mock_executor.call_count == 1
|
||||
assert mock_summarizer.call_count == 1
|
||||
|
||||
def test_data_analyst_worker_retry():
|
||||
"""Verify that the worker retries on error."""
|
||||
mock_coder = MagicMock()
|
||||
mock_executor = MagicMock()
|
||||
mock_summarizer = MagicMock()
|
||||
|
||||
# Scenario: Coder (1) -> Executor (Error) -> Router (coder) -> Coder (2) -> Executor (Success) -> Summarizer -> END
|
||||
mock_coder.side_effect = [
|
||||
{"code": "error_code", "error": None, "iterations": 1},
|
||||
{"code": "fixed_code", "error": None, "iterations": 2}
|
||||
]
|
||||
mock_executor.side_effect = [
|
||||
{"output": "", "error": "NameError", "plots": []},
|
||||
{"output": "Success", "error": None, "plots": []}
|
||||
]
|
||||
mock_summarizer.return_value = {"result": "Fixed Result"}
|
||||
|
||||
graph = create_data_analyst_worker(
|
||||
coder=mock_coder,
|
||||
executor=mock_executor,
|
||||
summarizer=mock_summarizer
|
||||
)
|
||||
|
||||
initial_state = WorkerState(
|
||||
messages=[],
|
||||
task="Retry Task",
|
||||
code=None,
|
||||
output=None,
|
||||
error=None,
|
||||
iterations=0,
|
||||
plots=[],
|
||||
vfs_state={},
|
||||
result=None
|
||||
)
|
||||
|
||||
final_state = graph.invoke(initial_state)
|
||||
|
||||
assert final_state["result"] == "Fixed Result"
|
||||
assert mock_coder.call_count == 2
|
||||
assert mock_executor.call_count == 2
|
||||
assert mock_summarizer.call_count == 1
|
||||
Reference in New Issue
Block a user