From 1cf00d0b3f68748743a21660db883d87ceb71d38 Mon Sep 17 00:00:00 2001 From: Yunxiao Xu Date: Sun, 15 Feb 2026 18:52:26 -0800 Subject: [PATCH] test(api): Add extended agent stream tests and fix type annotations --- backend/src/ea_chatbot/api/routers/agent.py | 7 +- backend/tests/api/test_agent_stream.py | 157 ++++++++++++++++++++ 2 files changed, 162 insertions(+), 2 deletions(-) create mode 100644 backend/tests/api/test_agent_stream.py diff --git a/backend/src/ea_chatbot/api/routers/agent.py b/backend/src/ea_chatbot/api/routers/agent.py index 7e2aee4..349dc6b 100644 --- a/backend/src/ea_chatbot/api/routers/agent.py +++ b/backend/src/ea_chatbot/api/routers/agent.py @@ -14,6 +14,8 @@ import base64 from langchain_core.runnables.config import RunnableConfig from langchain_core.messages import BaseMessage +from ea_chatbot.graph.state import AgentState + router = APIRouter(prefix="/chat", tags=["agent"]) async def stream_agent_events( @@ -27,7 +29,7 @@ async def stream_agent_events( Generator that invokes the LangGraph agent and yields SSE formatted events. Persists assistant responses and plots to the database. """ - initial_state = { + initial_state: AgentState = { "messages": messages, "question": message, "summary": summary, @@ -38,7 +40,8 @@ async def stream_agent_events( "code_output": None, "error": None, "plots": [], - "dfs": {} + "dfs": {}, + "iterations": 0 } config: RunnableConfig = {"configurable": {"thread_id": thread_id}} diff --git a/backend/tests/api/test_agent_stream.py b/backend/tests/api/test_agent_stream.py new file mode 100644 index 0000000..f81c5d9 --- /dev/null +++ b/backend/tests/api/test_agent_stream.py @@ -0,0 +1,157 @@ +import pytest +import json +import base64 +import io +import matplotlib.pyplot as plt +from fastapi.testclient import TestClient +from unittest.mock import MagicMock, patch, AsyncMock +from ea_chatbot.api.main import app +from ea_chatbot.api.dependencies import get_current_user +from ea_chatbot.history.models import User, Conversation +from ea_chatbot.api.routers.agent import stream_agent_events +from langchain_core.messages import AIMessage + +@pytest.fixture +def client(): + return TestClient(app) + +@pytest.fixture +def mock_user(): + return User(id="user-123", username="test@example.com") + +@pytest.mark.asyncio +async def test_stream_agent_events_all_features(): + """ + Test stream_agent_events generator logic with: + - stream chunks + - plots + - final response + - summary update + """ + # Prepare mock events + fig, ax = plt.subplots() + ax.plot([1, 2], [1, 2]) + + mock_events = [ + # Stream chunk + { + "event": "on_chat_model_stream", + "metadata": {"langgraph_node": "summarizer"}, + "data": {"chunk": AIMessage(content="Hello ")} + }, + { + "event": "on_chat_model_stream", + "metadata": {"langgraph_node": "summarizer"}, + "data": {"chunk": AIMessage(content="world")} + }, + # Plot event + { + "event": "on_chain_end", + "name": "executor", + "data": {"output": {"plots": [fig]}} + }, + # Final response + { + "event": "on_chain_end", + "name": "summarizer", + "data": {"output": {"messages": [AIMessage(content="Hello world final")]}} + }, + # Summary update + { + "event": "on_chain_end", + "name": "summarize_conversation", + "data": {"output": {"summary": "A conversation about hello world"}} + } + ] + + async def mock_astream_events(*args, **kwargs): + for event in mock_events: + yield event + + with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events), \ + patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \ + patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm: + + mock_cp.return_value.__aenter__.return_value = AsyncMock() + + gen = stream_agent_events( + message="hi", + thread_id="t1", + user_id="u1", + summary="old", + messages=[] + ) + + results = [] + async for item in gen: + if item.startswith("data: "): + results.append(json.loads(item[6:])) + + # Verify results + assert any(r.get("type") == "on_chat_model_stream" for r in results) + + # Verify plot was encoded + plot_event = next(r for r in results if r.get("name") == "executor") + assert "encoded_plots" in plot_event["data"] + assert len(plot_event["data"]["encoded_plots"]) == 1 + + # Verify summary update was called + mock_hm.update_conversation_summary.assert_called_once_with("t1", "A conversation about hello world") + + # Verify final message saved to DB + # Final response should be "Hello world final" from summarizer node + mock_hm.add_message.assert_called() + args, kwargs = mock_hm.add_message.call_args + assert args[1] == "assistant" + assert args[2] == "Hello world final" + assert len(kwargs.get("plots", [])) == 1 + +@pytest.mark.asyncio +async def test_stream_agent_events_exception(): + """Test exception handling in stream_agent_events.""" + async def mock_astream_events_fail(*args, **kwargs): + raise ValueError("Something went wrong") + yield {} # Never reached + + with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events_fail), \ + patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \ + patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm: + + mock_cp.return_value.__aenter__.return_value = AsyncMock() + + gen = stream_agent_events("hi", "t1", "u1", "old", []) + + results = [] + async for item in gen: + if item.startswith("data: "): + results.append(json.loads(item[6:])) + + assert any(r.get("type") == "error" for r in results) + assert "Something went wrong" in results[-1]["data"]["message"] + + # Verify error message saved to DB + mock_hm.add_message.assert_called() + args, kwargs = mock_hm.add_message.call_args + assert "Something went wrong" in args[2] + +def test_chat_stream_404_403(client, mock_user): + """Test 404 and 403 cases in chat_stream router.""" + from ea_chatbot.api.dependencies import get_current_user + app.dependency_overrides[get_current_user] = lambda: mock_user + + try: + with patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm: + mock_session = MagicMock() + mock_hm.get_session.return_value.__enter__.return_value = mock_session + + # Case 1: 404 Not Found + mock_session.get.return_value = None + response = client.post("/api/v1/chat/stream", json={"message": "hi", "thread_id": "none"}) + assert response.status_code == 404 + + # Case 2: 403 Forbidden + mock_session.get.return_value = Conversation(id="c1", user_id="other-user") + response = client.post("/api/v1/chat/stream", json={"message": "hi", "thread_id": "c1"}) + assert response.status_code == 403 + finally: + app.dependency_overrides.clear()