test(api): Add extended agent stream tests and fix type annotations

This commit is contained in:
Yunxiao Xu
2026-02-15 18:52:26 -08:00
parent 5b9d644fe5
commit 1cf00d0b3f
2 changed files with 162 additions and 2 deletions

View File

@@ -14,6 +14,8 @@ import base64
from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.config import RunnableConfig
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
from ea_chatbot.graph.state import AgentState
router = APIRouter(prefix="/chat", tags=["agent"]) router = APIRouter(prefix="/chat", tags=["agent"])
async def stream_agent_events( async def stream_agent_events(
@@ -27,7 +29,7 @@ async def stream_agent_events(
Generator that invokes the LangGraph agent and yields SSE formatted events. Generator that invokes the LangGraph agent and yields SSE formatted events.
Persists assistant responses and plots to the database. Persists assistant responses and plots to the database.
""" """
initial_state = { initial_state: AgentState = {
"messages": messages, "messages": messages,
"question": message, "question": message,
"summary": summary, "summary": summary,
@@ -38,7 +40,8 @@ async def stream_agent_events(
"code_output": None, "code_output": None,
"error": None, "error": None,
"plots": [], "plots": [],
"dfs": {} "dfs": {},
"iterations": 0
} }
config: RunnableConfig = {"configurable": {"thread_id": thread_id}} config: RunnableConfig = {"configurable": {"thread_id": thread_id}}

View File

@@ -0,0 +1,157 @@
import pytest
import json
import base64
import io
import matplotlib.pyplot as plt
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch, AsyncMock
from ea_chatbot.api.main import app
from ea_chatbot.api.dependencies import get_current_user
from ea_chatbot.history.models import User, Conversation
from ea_chatbot.api.routers.agent import stream_agent_events
from langchain_core.messages import AIMessage
@pytest.fixture
def client():
return TestClient(app)
@pytest.fixture
def mock_user():
return User(id="user-123", username="test@example.com")
@pytest.mark.asyncio
async def test_stream_agent_events_all_features():
"""
Test stream_agent_events generator logic with:
- stream chunks
- plots
- final response
- summary update
"""
# Prepare mock events
fig, ax = plt.subplots()
ax.plot([1, 2], [1, 2])
mock_events = [
# Stream chunk
{
"event": "on_chat_model_stream",
"metadata": {"langgraph_node": "summarizer"},
"data": {"chunk": AIMessage(content="Hello ")}
},
{
"event": "on_chat_model_stream",
"metadata": {"langgraph_node": "summarizer"},
"data": {"chunk": AIMessage(content="world")}
},
# Plot event
{
"event": "on_chain_end",
"name": "executor",
"data": {"output": {"plots": [fig]}}
},
# Final response
{
"event": "on_chain_end",
"name": "summarizer",
"data": {"output": {"messages": [AIMessage(content="Hello world final")]}}
},
# Summary update
{
"event": "on_chain_end",
"name": "summarize_conversation",
"data": {"output": {"summary": "A conversation about hello world"}}
}
]
async def mock_astream_events(*args, **kwargs):
for event in mock_events:
yield event
with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events), \
patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
mock_cp.return_value.__aenter__.return_value = AsyncMock()
gen = stream_agent_events(
message="hi",
thread_id="t1",
user_id="u1",
summary="old",
messages=[]
)
results = []
async for item in gen:
if item.startswith("data: "):
results.append(json.loads(item[6:]))
# Verify results
assert any(r.get("type") == "on_chat_model_stream" for r in results)
# Verify plot was encoded
plot_event = next(r for r in results if r.get("name") == "executor")
assert "encoded_plots" in plot_event["data"]
assert len(plot_event["data"]["encoded_plots"]) == 1
# Verify summary update was called
mock_hm.update_conversation_summary.assert_called_once_with("t1", "A conversation about hello world")
# Verify final message saved to DB
# Final response should be "Hello world final" from summarizer node
mock_hm.add_message.assert_called()
args, kwargs = mock_hm.add_message.call_args
assert args[1] == "assistant"
assert args[2] == "Hello world final"
assert len(kwargs.get("plots", [])) == 1
@pytest.mark.asyncio
async def test_stream_agent_events_exception():
"""Test exception handling in stream_agent_events."""
async def mock_astream_events_fail(*args, **kwargs):
raise ValueError("Something went wrong")
yield {} # Never reached
with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events_fail), \
patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
mock_cp.return_value.__aenter__.return_value = AsyncMock()
gen = stream_agent_events("hi", "t1", "u1", "old", [])
results = []
async for item in gen:
if item.startswith("data: "):
results.append(json.loads(item[6:]))
assert any(r.get("type") == "error" for r in results)
assert "Something went wrong" in results[-1]["data"]["message"]
# Verify error message saved to DB
mock_hm.add_message.assert_called()
args, kwargs = mock_hm.add_message.call_args
assert "Something went wrong" in args[2]
def test_chat_stream_404_403(client, mock_user):
"""Test 404 and 403 cases in chat_stream router."""
from ea_chatbot.api.dependencies import get_current_user
app.dependency_overrides[get_current_user] = lambda: mock_user
try:
with patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
mock_session = MagicMock()
mock_hm.get_session.return_value.__enter__.return_value = mock_session
# Case 1: 404 Not Found
mock_session.get.return_value = None
response = client.post("/api/v1/chat/stream", json={"message": "hi", "thread_id": "none"})
assert response.status_code == 404
# Case 2: 403 Forbidden
mock_session.get.return_value = Conversation(id="c1", user_id="other-user")
response = client.post("/api/v1/chat/stream", json={"message": "hi", "thread_id": "c1"})
assert response.status_code == 403
finally:
app.dependency_overrides.clear()