82 lines
2.5 KiB
Python
82 lines
2.5 KiB
Python
import pytest
|
|
from unittest.mock import MagicMock
|
|
from ea_chatbot.graph.workers.data_analyst.workflow import create_data_analyst_worker, WorkerState
|
|
|
|
def test_data_analyst_worker_one_shot():
    """Run the worker subgraph once when the executor reports no error.

    Scenario under test: Coder -> Executor (Success) -> Summarizer -> END.
    Each mocked node is expected to be invoked exactly once and the
    summarizer's result must appear in the final state.
    """
    # Every node is a mock with a fixed reply, so the run is deterministic.
    coder_stub = MagicMock(
        return_value={"code": "print(1)", "error": None, "iterations": 1}
    )
    executor_stub = MagicMock(
        return_value={"output": "1\n", "error": None, "plots": []}
    )
    summarizer_stub = MagicMock(return_value={"result": "Result is 1"})

    worker = create_data_analyst_worker(
        coder=coder_stub,
        executor=executor_stub,
        summarizer=summarizer_stub,
    )

    # Blank starting state: nothing generated or executed yet.
    start_state = WorkerState(
        messages=[],
        task="Calculate 1+1",
        code=None,
        output=None,
        error=None,
        iterations=0,
        plots=[],
        vfs_state={},
        result=None,
    )

    end_state = worker.invoke(start_state)

    assert end_state["result"] == "Result is 1"
    # No retry happened: each node ran a single time.
    assert coder_stub.call_count == 1
    assert executor_stub.call_count == 1
    assert summarizer_stub.call_count == 1
|
|
|
|
def test_data_analyst_worker_retry():
    """Check that an executor error makes the worker run the coder again.

    Scenario under test: Coder (1) -> Executor (Error) -> Router (coder)
    -> Coder (2) -> Executor (Success) -> Summarizer -> END.
    """
    # Scripted replies, consumed one per call in order.
    coder_replies = [
        {"code": "error_code", "error": None, "iterations": 1},
        {"code": "fixed_code", "error": None, "iterations": 2},
    ]
    # The first execution reports an error; the second one is clean.
    executor_replies = [
        {"output": "", "error": "NameError", "plots": []},
        {"output": "Success", "error": None, "plots": []},
    ]

    coder_stub = MagicMock(side_effect=coder_replies)
    executor_stub = MagicMock(side_effect=executor_replies)
    summarizer_stub = MagicMock(return_value={"result": "Fixed Result"})

    worker = create_data_analyst_worker(
        coder=coder_stub,
        executor=executor_stub,
        summarizer=summarizer_stub,
    )

    # Blank starting state: nothing generated or executed yet.
    start_state = WorkerState(
        messages=[],
        task="Retry Task",
        code=None,
        output=None,
        error=None,
        iterations=0,
        plots=[],
        vfs_state={},
        result=None,
    )

    end_state = worker.invoke(start_state)

    assert end_state["result"] == "Fixed Result"
    # One failed attempt plus one successful attempt, then a single summary.
    assert coder_stub.call_count == 2
    assert executor_stub.call_count == 2
    assert summarizer_stub.call_count == 1
|