import pytest
from unittest.mock import MagicMock

from ea_chatbot.graph.workers.data_analyst.workflow import create_data_analyst_worker, WorkerState


def _make_initial_state(task: str) -> WorkerState:
    """Build a fresh, empty ``WorkerState`` for *task*.

    Every field other than ``task`` starts at its "nothing has run yet" value
    (no code, no output, no error, zero iterations, no result). New list/dict
    objects are created on every call, so tests never share mutable state.
    """
    return WorkerState(
        messages=[],
        task=task,
        code=None,
        output=None,
        error=None,
        iterations=0,
        plots=[],
        vfs_state={},
        result=None,
    )


def test_data_analyst_worker_one_shot():
    """Verify a successful one-shot execution of the worker subgraph."""
    mock_coder = MagicMock()
    mock_executor = MagicMock()
    mock_summarizer = MagicMock()

    # Scenario: Coder -> Executor (Success) -> Summarizer -> END
    # The mocks return canned values regardless of input, so the task text is
    # irrelevant here; only the graph's routing is under test.
    mock_coder.return_value = {"code": "print(1)", "error": None, "iterations": 1}
    mock_executor.return_value = {"output": "1\n", "error": None, "plots": []}
    mock_summarizer.return_value = {"result": "Result is 1"}

    graph = create_data_analyst_worker(
        coder=mock_coder,
        executor=mock_executor,
        summarizer=mock_summarizer,
    )

    final_state = graph.invoke(_make_initial_state("Calculate 1+1"))

    assert final_state["result"] == "Result is 1"
    # A clean execution must not loop back to the coder: each node runs once.
    assert mock_coder.call_count == 1
    assert mock_executor.call_count == 1
    assert mock_summarizer.call_count == 1


def test_data_analyst_worker_retry():
    """Verify that the worker retries on error."""
    mock_coder = MagicMock()
    mock_executor = MagicMock()
    mock_summarizer = MagicMock()

    # Scenario: Coder (1) -> Executor (Error) -> Router (coder) -> Coder (2) -> Executor (Success) -> Summarizer -> END
    # side_effect lists yield one item per call, in order.
    mock_coder.side_effect = [
        {"code": "error_code", "error": None, "iterations": 1},
        {"code": "fixed_code", "error": None, "iterations": 2},
    ]
    mock_executor.side_effect = [
        {"output": "", "error": "NameError", "plots": []},
        {"output": "Success", "error": None, "plots": []},
    ]
    mock_summarizer.return_value = {"result": "Fixed Result"}

    graph = create_data_analyst_worker(
        coder=mock_coder,
        executor=mock_executor,
        summarizer=mock_summarizer,
    )

    final_state = graph.invoke(_make_initial_state("Retry Task"))

    assert final_state["result"] == "Fixed Result"
    # Exactly one retry loop: coder and executor each run twice. The summarizer
    # runs only once, after the successful execution.
    assert mock_coder.call_count == 2
    assert mock_executor.call_count == 2
    assert mock_summarizer.call_count == 1