test(api): Add extended agent stream tests and fix type annotations

This commit is contained in:
Yunxiao Xu
2026-02-15 18:52:26 -08:00
parent 5b9d644fe5
commit 1cf00d0b3f
2 changed files with 162 additions and 2 deletions

View File

@@ -14,6 +14,8 @@ import base64
from langchain_core.runnables.config import RunnableConfig
from langchain_core.messages import BaseMessage
from ea_chatbot.graph.state import AgentState
router = APIRouter(prefix="/chat", tags=["agent"])
async def stream_agent_events(
@@ -27,7 +29,7 @@ async def stream_agent_events(
Generator that invokes the LangGraph agent and yields SSE formatted events.
Persists assistant responses and plots to the database.
"""
initial_state = {
initial_state: AgentState = {
"messages": messages,
"question": message,
"summary": summary,
@@ -38,7 +40,8 @@ async def stream_agent_events(
"code_output": None,
"error": None,
"plots": [],
"dfs": {}
"dfs": {},
"iterations": 0
}
config: RunnableConfig = {"configurable": {"thread_id": thread_id}}

View File

@@ -0,0 +1,157 @@
import pytest
import json
import base64
import io
import matplotlib.pyplot as plt
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch, AsyncMock
from ea_chatbot.api.main import app
from ea_chatbot.api.dependencies import get_current_user
from ea_chatbot.history.models import User, Conversation
from ea_chatbot.api.routers.agent import stream_agent_events
from langchain_core.messages import AIMessage
@pytest.fixture
def client():
return TestClient(app)
@pytest.fixture
def mock_user():
return User(id="user-123", username="test@example.com")
@pytest.mark.asyncio
async def test_stream_agent_events_all_features():
"""
Test stream_agent_events generator logic with:
- stream chunks
- plots
- final response
- summary update
"""
# Prepare mock events
fig, ax = plt.subplots()
ax.plot([1, 2], [1, 2])
mock_events = [
# Stream chunk
{
"event": "on_chat_model_stream",
"metadata": {"langgraph_node": "summarizer"},
"data": {"chunk": AIMessage(content="Hello ")}
},
{
"event": "on_chat_model_stream",
"metadata": {"langgraph_node": "summarizer"},
"data": {"chunk": AIMessage(content="world")}
},
# Plot event
{
"event": "on_chain_end",
"name": "executor",
"data": {"output": {"plots": [fig]}}
},
# Final response
{
"event": "on_chain_end",
"name": "summarizer",
"data": {"output": {"messages": [AIMessage(content="Hello world final")]}}
},
# Summary update
{
"event": "on_chain_end",
"name": "summarize_conversation",
"data": {"output": {"summary": "A conversation about hello world"}}
}
]
async def mock_astream_events(*args, **kwargs):
for event in mock_events:
yield event
with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events), \
patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
mock_cp.return_value.__aenter__.return_value = AsyncMock()
gen = stream_agent_events(
message="hi",
thread_id="t1",
user_id="u1",
summary="old",
messages=[]
)
results = []
async for item in gen:
if item.startswith("data: "):
results.append(json.loads(item[6:]))
# Verify results
assert any(r.get("type") == "on_chat_model_stream" for r in results)
# Verify plot was encoded
plot_event = next(r for r in results if r.get("name") == "executor")
assert "encoded_plots" in plot_event["data"]
assert len(plot_event["data"]["encoded_plots"]) == 1
# Verify summary update was called
mock_hm.update_conversation_summary.assert_called_once_with("t1", "A conversation about hello world")
# Verify final message saved to DB
# Final response should be "Hello world final" from summarizer node
mock_hm.add_message.assert_called()
args, kwargs = mock_hm.add_message.call_args
assert args[1] == "assistant"
assert args[2] == "Hello world final"
assert len(kwargs.get("plots", [])) == 1
@pytest.mark.asyncio
async def test_stream_agent_events_exception():
"""Test exception handling in stream_agent_events."""
async def mock_astream_events_fail(*args, **kwargs):
raise ValueError("Something went wrong")
yield {} # Never reached
with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events_fail), \
patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
mock_cp.return_value.__aenter__.return_value = AsyncMock()
gen = stream_agent_events("hi", "t1", "u1", "old", [])
results = []
async for item in gen:
if item.startswith("data: "):
results.append(json.loads(item[6:]))
assert any(r.get("type") == "error" for r in results)
assert "Something went wrong" in results[-1]["data"]["message"]
# Verify error message saved to DB
mock_hm.add_message.assert_called()
args, kwargs = mock_hm.add_message.call_args
assert "Something went wrong" in args[2]
def test_chat_stream_404_403(client, mock_user):
"""Test 404 and 403 cases in chat_stream router."""
from ea_chatbot.api.dependencies import get_current_user
app.dependency_overrides[get_current_user] = lambda: mock_user
try:
with patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
mock_session = MagicMock()
mock_hm.get_session.return_value.__enter__.return_value = mock_session
# Case 1: 404 Not Found
mock_session.get.return_value = None
response = client.post("/api/v1/chat/stream", json={"message": "hi", "thread_id": "none"})
assert response.status_code == 404
# Case 2: 403 Forbidden
mock_session.get.return_value = Conversation(id="c1", user_id="other-user")
response = client.post("/api/v1/chat/stream", json={"message": "hi", "thread_id": "c1"})
assert response.status_code == 403
finally:
app.dependency_overrides.clear()