158 lines
5.6 KiB
Python
158 lines
5.6 KiB
Python
import pytest
|
|
import json
|
|
import base64
|
|
import io
|
|
import matplotlib.pyplot as plt
|
|
from fastapi.testclient import TestClient
|
|
from unittest.mock import MagicMock, patch, AsyncMock
|
|
from ea_chatbot.api.main import app
|
|
from ea_chatbot.api.dependencies import get_current_user
|
|
from ea_chatbot.history.models import User, Conversation
|
|
from ea_chatbot.api.routers.agent import stream_agent_events
|
|
from langchain_core.messages import AIMessage
|
|
|
|
@pytest.fixture
def client():
    """Provide a FastAPI ``TestClient`` wrapping the application under test."""
    test_client = TestClient(app)
    return test_client
|
|
|
|
@pytest.fixture
def mock_user():
    """Provide a ``User`` instance to act as the authenticated caller in router tests."""
    fake_user = User(id="user-123", username="test@example.com")
    return fake_user
|
|
|
|
@pytest.mark.asyncio
async def test_stream_agent_events_all_features():
    """
    Test stream_agent_events generator logic with:
    - stream chunks
    - plots
    - final response
    - summary update

    A scripted sequence of LangGraph-style events is fed through a patched
    ``app.astream_events``. The test then checks that:

    - ``on_chat_model_stream`` events are emitted to the client;
    - the ``executor`` node's figure comes back as ``encoded_plots``;
    - the ``summarizer`` node's final message is saved as the assistant
      reply, together with the plot;
    - the ``summarize_conversation`` output is written via
      ``history_manager.update_conversation_summary``.
    """
    # A real matplotlib figure, so the plot-encoding path has something to render.
    fig, ax = plt.subplots()
    ax.plot([1, 2], [1, 2])

    try:
        mock_events = [
            # Stream chunks
            {
                "event": "on_chat_model_stream",
                "metadata": {"langgraph_node": "summarizer"},
                "data": {"chunk": AIMessage(content="Hello ")},
            },
            {
                "event": "on_chat_model_stream",
                "metadata": {"langgraph_node": "summarizer"},
                "data": {"chunk": AIMessage(content="world")},
            },
            # Plot event
            {
                "event": "on_chain_end",
                "name": "executor",
                "data": {"output": {"plots": [fig]}},
            },
            # Final response
            {
                "event": "on_chain_end",
                "name": "summarizer",
                "data": {"output": {"messages": [AIMessage(content="Hello world final")]}},
            },
            # Summary update
            {
                "event": "on_chain_end",
                "name": "summarize_conversation",
                "data": {"output": {"summary": "A conversation about hello world"}},
            },
        ]

        async def mock_astream_events(*args, **kwargs):
            for event in mock_events:
                yield event

        with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events), \
                patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
                patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:

            mock_cp.return_value.__aenter__.return_value = AsyncMock()

            gen = stream_agent_events(
                message="hi",
                thread_id="t1",
                user_id="u1",
                summary="old",
                messages=[],
            )

            # Keep only the "data: " payload lines and decode them as JSON.
            results = []
            async for item in gen:
                if item.startswith("data: "):
                    results.append(json.loads(item[6:]))

            # Verify results
            assert any(r.get("type") == "on_chat_model_stream" for r in results)

            # Verify plot was encoded. A default of None avoids StopIteration
            # escaping the coroutine (it would surface as a RuntimeError).
            plot_event = next((r for r in results if r.get("name") == "executor"), None)
            assert plot_event is not None, "no 'executor' event was streamed"
            assert "encoded_plots" in plot_event["data"]
            assert len(plot_event["data"]["encoded_plots"]) == 1

            # Verify summary update was called
            mock_hm.update_conversation_summary.assert_called_once_with("t1", "A conversation about hello world")

            # Verify final message saved to DB.
            # Final response should be "Hello world final" from summarizer node.
            mock_hm.add_message.assert_called()
            args, kwargs = mock_hm.add_message.call_args
            assert args[1] == "assistant"
            assert args[2] == "Hello world final"
            assert len(kwargs.get("plots", [])) == 1
    finally:
        # pyplot keeps figures registered until they are closed explicitly;
        # close this one so figures don't accumulate across the test session.
        plt.close(fig)
|
|
|
|
@pytest.mark.asyncio
async def test_stream_agent_events_exception():
    """Test exception handling in stream_agent_events.

    When the patched ``app.astream_events`` raises, the streamed output must
    contain an ``error`` event. The last streamed event must carry the
    exception text, and that text must also be passed to
    ``history_manager.add_message``.
    """

    async def mock_astream_events_fail(*args, **kwargs):
        raise ValueError("Something went wrong")
        yield {}  # Never reached; its presence makes this an async generator.

    with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events_fail), \
            patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
            patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:

        mock_cp.return_value.__aenter__.return_value = AsyncMock()

        gen = stream_agent_events("hi", "t1", "u1", "old", [])

        # Keep only the "data: " payload lines and decode them as JSON.
        results = []
        async for item in gen:
            if item.startswith("data: "):
                results.append(json.loads(item[6:]))

        assert results, "stream_agent_events emitted no data events"
        assert any(r.get("type") == "error" for r in results)
        assert "Something went wrong" in results[-1]["data"]["message"]

        # Verify error message saved to DB (only positional args are inspected).
        mock_hm.add_message.assert_called()
        saved_args = mock_hm.add_message.call_args.args
        assert "Something went wrong" in saved_args[2]
|
|
|
|
def test_chat_stream_404_403(client, mock_user):
    """Test 404 and 403 cases in chat_stream router.

    ``get_current_user`` is overridden to return ``mock_user``.
    ``history_manager`` is patched so that ``session.get`` returns:

    - ``None``, expecting a 404;
    - a ``Conversation`` owned by another user, expecting a 403.
    """
    from ea_chatbot.api.dependencies import get_current_user
    app.dependency_overrides[get_current_user] = lambda: mock_user

    try:
        with patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
            mock_session = MagicMock()
            mock_hm.get_session.return_value.__enter__.return_value = mock_session

            # Case 1: 404 Not Found
            mock_session.get.return_value = None
            response = client.post("/api/v1/chat/stream", json={"message": "hi", "thread_id": "none"})
            assert response.status_code == 404

            # Case 2: 403 Forbidden
            mock_session.get.return_value = Conversation(id="c1", user_id="other-user")
            response = client.post("/api/v1/chat/stream", json={"message": "hi", "thread_id": "c1"})
            assert response.status_code == 403
    finally:
        # Remove only the override installed above; clear() would also drop
        # overrides registered by other fixtures or tests sharing this app.
        app.dependency_overrides.pop(get_current_user, None)
|