"""Tests for the agent chat streaming router (``ea_chatbot.api.routers.agent``)."""

import json
import base64
import io
from unittest.mock import MagicMock, patch, AsyncMock

import pytest
import matplotlib.pyplot as plt
from fastapi.testclient import TestClient
from langchain_core.messages import AIMessage

from ea_chatbot.api.main import app
from ea_chatbot.api.dependencies import get_current_user
from ea_chatbot.history.models import User, Conversation
from ea_chatbot.api.routers.agent import stream_agent_events


@pytest.fixture
def client():
    """Return a TestClient bound to the FastAPI application."""
    return TestClient(app)


@pytest.fixture
def mock_user():
    """Return an in-memory User used as the authenticated caller."""
    return User(id="user-123", username="test@example.com")


async def _collect_sse_payloads(gen):
    """Drain the async generator ``gen`` and return decoded JSON payloads.

    Only items starting with the ``data: `` prefix are decoded; every other
    item yielded by ``gen`` is ignored.
    """
    prefix = "data: "
    return [
        json.loads(item[len(prefix):])
        async for item in gen
        if item.startswith(prefix)
    ]


@pytest.mark.asyncio
async def test_stream_agent_events_all_features():
    """
    Test stream_agent_events generator logic with:
    - stream chunks
    - plots
    - final response
    - summary update
    """
    # Figure fed through the mocked data_analyst_worker output. pyplot keeps a
    # reference to every figure it creates, so it is closed in `finally`.
    fig, ax = plt.subplots()
    ax.plot([1, 2], [1, 2])
    try:
        mock_events = [
            # Stream chunks
            {
                "event": "on_chat_model_stream",
                "metadata": {"langgraph_node": "synthesizer"},
                "data": {"chunk": AIMessage(content="Hello ")},
            },
            {
                "event": "on_chat_model_stream",
                "metadata": {"langgraph_node": "synthesizer"},
                "data": {"chunk": AIMessage(content="world")},
            },
            # Plot event - with a nested subgraph it might bubble up or come
            # directly from data_analyst_worker; here it comes from the
            # data_analyst_worker on_chain_end.
            {
                "event": "on_chain_end",
                "name": "data_analyst_worker",
                "data": {"output": {"plots": [fig]}},
            },
            # Final response
            {
                "event": "on_chain_end",
                "name": "synthesizer",
                "data": {"output": {"messages": [AIMessage(content="Hello world final")]}},
            },
            # Summary update
            {
                "event": "on_chain_end",
                "name": "summarize_conversation",
                "data": {"output": {"summary": "A conversation about hello world"}},
            },
        ]

        async def mock_astream_events(*args, **kwargs):
            for event in mock_events:
                yield event

        with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events), \
             patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
             patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
            # get_checkpointer() is used as an async context manager.
            mock_cp.return_value.__aenter__.return_value = AsyncMock()

            gen = stream_agent_events(
                message="hi",
                thread_id="t1",
                user_id="u1",
                summary="old",
                messages=[],
            )
            results = await _collect_sse_payloads(gen)

            # Verify results
            assert any(r.get("type") == "on_chat_model_stream" for r in results)

            # Verify plot was encoded. A default is passed to next() because a
            # StopIteration escaping a coroutine surfaces as an opaque
            # RuntimeError instead of a readable assertion failure.
            plot_event = next(
                (r for r in results if r.get("name") == "data_analyst_worker"), None
            )
            assert plot_event is not None, "no event emitted for data_analyst_worker"
            assert "encoded_plots" in plot_event["data"]
            assert len(plot_event["data"]["encoded_plots"]) == 1

            # Verify summary update was called
            mock_hm.update_conversation_summary.assert_called_once_with(
                "t1", "A conversation about hello world"
            )

            # Verify final message saved to DB: the most recent add_message call
            # must carry "Hello world final" from the synthesizer node.
            mock_hm.add_message.assert_called()
            args, kwargs = mock_hm.add_message.call_args
            assert args[1] == "assistant"
            assert args[2] == "Hello world final"
            assert len(kwargs.get("plots", [])) == 1
    finally:
        plt.close(fig)


@pytest.mark.asyncio
async def test_stream_agent_events_exception():
    """Test exception handling in stream_agent_events."""

    async def mock_astream_events_fail(*args, **kwargs):
        raise ValueError("Something went wrong")
        yield {}  # Never reached; the `yield` makes this an async generator.

    with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events_fail), \
         patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
         patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
        mock_cp.return_value.__aenter__.return_value = AsyncMock()

        gen = stream_agent_events("hi", "t1", "u1", "old", [])
        results = await _collect_sse_payloads(gen)

        # Assert on the error event itself rather than whatever event happens
        # to be last in the stream.
        error_events = [r for r in results if r.get("type") == "error"]
        assert error_events, "expected at least one error event"
        assert "Something went wrong" in error_events[-1]["data"]["message"]

        # Verify error message saved to DB
        mock_hm.add_message.assert_called()
        args, kwargs = mock_hm.add_message.call_args
        assert "Something went wrong" in args[2]


def test_chat_stream_404_403(client, mock_user):
    """Test 404 and 403 cases in chat_stream router."""
    # patch.dict restores dependency_overrides to its previous contents on
    # exit, instead of clearing overrides that other fixtures may have set.
    with patch.dict(app.dependency_overrides, {get_current_user: lambda: mock_user}), \
         patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
        mock_session = MagicMock()
        # history_manager.get_session() is used as a context manager.
        mock_hm.get_session.return_value.__enter__.return_value = mock_session

        # Case 1: 404 Not Found - no conversation exists for the thread id.
        mock_session.get.return_value = None
        response = client.post("/api/v1/chat/stream", json={"message": "hi", "thread_id": "none"})
        assert response.status_code == 404

        # Case 2: 403 Forbidden - conversation belongs to a different user.
        mock_session.get.return_value = Conversation(id="c1", user_id="other-user")
        response = client.post("/api/v1/chat/stream", json={"message": "hi", "thread_id": "c1"})
        assert response.status_code == 403