test(api): Add extended agent stream tests and fix type annotations
This commit is contained in:
@@ -14,6 +14,8 @@ import base64
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["agent"])
|
||||
|
||||
async def stream_agent_events(
|
||||
@@ -27,7 +29,7 @@ async def stream_agent_events(
|
||||
Generator that invokes the LangGraph agent and yields SSE formatted events.
|
||||
Persists assistant responses and plots to the database.
|
||||
"""
|
||||
initial_state = {
|
||||
initial_state: AgentState = {
|
||||
"messages": messages,
|
||||
"question": message,
|
||||
"summary": summary,
|
||||
@@ -38,7 +40,8 @@ async def stream_agent_events(
|
||||
"code_output": None,
|
||||
"error": None,
|
||||
"plots": [],
|
||||
"dfs": {}
|
||||
"dfs": {},
|
||||
"iterations": 0
|
||||
}
|
||||
|
||||
config: RunnableConfig = {"configurable": {"thread_id": thread_id}}
|
||||
|
||||
157
backend/tests/api/test_agent_stream.py
Normal file
157
backend/tests/api/test_agent_stream.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import pytest
|
||||
import json
|
||||
import base64
|
||||
import io
|
||||
import matplotlib.pyplot as plt
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
from ea_chatbot.api.main import app
|
||||
from ea_chatbot.api.dependencies import get_current_user
|
||||
from ea_chatbot.history.models import User, Conversation
|
||||
from ea_chatbot.api.routers.agent import stream_agent_events
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
return TestClient(app)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
return User(id="user-123", username="test@example.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_agent_events_all_features():
|
||||
"""
|
||||
Test stream_agent_events generator logic with:
|
||||
- stream chunks
|
||||
- plots
|
||||
- final response
|
||||
- summary update
|
||||
"""
|
||||
# Prepare mock events
|
||||
fig, ax = plt.subplots()
|
||||
ax.plot([1, 2], [1, 2])
|
||||
|
||||
mock_events = [
|
||||
# Stream chunk
|
||||
{
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {"langgraph_node": "summarizer"},
|
||||
"data": {"chunk": AIMessage(content="Hello ")}
|
||||
},
|
||||
{
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {"langgraph_node": "summarizer"},
|
||||
"data": {"chunk": AIMessage(content="world")}
|
||||
},
|
||||
# Plot event
|
||||
{
|
||||
"event": "on_chain_end",
|
||||
"name": "executor",
|
||||
"data": {"output": {"plots": [fig]}}
|
||||
},
|
||||
# Final response
|
||||
{
|
||||
"event": "on_chain_end",
|
||||
"name": "summarizer",
|
||||
"data": {"output": {"messages": [AIMessage(content="Hello world final")]}}
|
||||
},
|
||||
# Summary update
|
||||
{
|
||||
"event": "on_chain_end",
|
||||
"name": "summarize_conversation",
|
||||
"data": {"output": {"summary": "A conversation about hello world"}}
|
||||
}
|
||||
]
|
||||
|
||||
async def mock_astream_events(*args, **kwargs):
|
||||
for event in mock_events:
|
||||
yield event
|
||||
|
||||
with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events), \
|
||||
patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
|
||||
patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
|
||||
|
||||
mock_cp.return_value.__aenter__.return_value = AsyncMock()
|
||||
|
||||
gen = stream_agent_events(
|
||||
message="hi",
|
||||
thread_id="t1",
|
||||
user_id="u1",
|
||||
summary="old",
|
||||
messages=[]
|
||||
)
|
||||
|
||||
results = []
|
||||
async for item in gen:
|
||||
if item.startswith("data: "):
|
||||
results.append(json.loads(item[6:]))
|
||||
|
||||
# Verify results
|
||||
assert any(r.get("type") == "on_chat_model_stream" for r in results)
|
||||
|
||||
# Verify plot was encoded
|
||||
plot_event = next(r for r in results if r.get("name") == "executor")
|
||||
assert "encoded_plots" in plot_event["data"]
|
||||
assert len(plot_event["data"]["encoded_plots"]) == 1
|
||||
|
||||
# Verify summary update was called
|
||||
mock_hm.update_conversation_summary.assert_called_once_with("t1", "A conversation about hello world")
|
||||
|
||||
# Verify final message saved to DB
|
||||
# Final response should be "Hello world final" from summarizer node
|
||||
mock_hm.add_message.assert_called()
|
||||
args, kwargs = mock_hm.add_message.call_args
|
||||
assert args[1] == "assistant"
|
||||
assert args[2] == "Hello world final"
|
||||
assert len(kwargs.get("plots", [])) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_agent_events_exception():
|
||||
"""Test exception handling in stream_agent_events."""
|
||||
async def mock_astream_events_fail(*args, **kwargs):
|
||||
raise ValueError("Something went wrong")
|
||||
yield {} # Never reached
|
||||
|
||||
with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events_fail), \
|
||||
patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
|
||||
patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
|
||||
|
||||
mock_cp.return_value.__aenter__.return_value = AsyncMock()
|
||||
|
||||
gen = stream_agent_events("hi", "t1", "u1", "old", [])
|
||||
|
||||
results = []
|
||||
async for item in gen:
|
||||
if item.startswith("data: "):
|
||||
results.append(json.loads(item[6:]))
|
||||
|
||||
assert any(r.get("type") == "error" for r in results)
|
||||
assert "Something went wrong" in results[-1]["data"]["message"]
|
||||
|
||||
# Verify error message saved to DB
|
||||
mock_hm.add_message.assert_called()
|
||||
args, kwargs = mock_hm.add_message.call_args
|
||||
assert "Something went wrong" in args[2]
|
||||
|
||||
def test_chat_stream_404_403(client, mock_user):
|
||||
"""Test 404 and 403 cases in chat_stream router."""
|
||||
from ea_chatbot.api.dependencies import get_current_user
|
||||
app.dependency_overrides[get_current_user] = lambda: mock_user
|
||||
|
||||
try:
|
||||
with patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
|
||||
mock_session = MagicMock()
|
||||
mock_hm.get_session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Case 1: 404 Not Found
|
||||
mock_session.get.return_value = None
|
||||
response = client.post("/api/v1/chat/stream", json={"message": "hi", "thread_id": "none"})
|
||||
assert response.status_code == 404
|
||||
|
||||
# Case 2: 403 Forbidden
|
||||
mock_session.get.return_value = Conversation(id="c1", user_id="other-user")
|
||||
response = client.post("/api/v1/chat/stream", json={"message": "hi", "thread_id": "c1"})
|
||||
assert response.status_code == 403
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Reference in New Issue
Block a user