feat(api): Enhance persistence logic and refactor codebase structure

This commit is contained in:
Yunxiao Xu
2026-02-11 15:33:56 -08:00
parent 371582dcd1
commit 85329cffda
7 changed files with 244 additions and 95 deletions

View File

@@ -25,7 +25,7 @@ def test_stream_agent_unauthorized():
assert response.status_code == 401
@pytest.mark.asyncio
async def test_stream_agent_success(auth_header):
async def test_stream_agent_success(auth_header, mock_user):
"""Test successful agent streaming with SSE."""
# We need to mock the LangGraph app.astream_events
mock_events = [
@@ -39,10 +39,18 @@ async def test_stream_agent_success(auth_header):
yield event
with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events), \
patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp:
patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
mock_cp.return_value.__aenter__.return_value = AsyncMock()
# Mock session and DB objects
mock_session = MagicMock()
mock_hm.get_session.return_value.__enter__.return_value = mock_session
from ea_chatbot.history.models import Conversation
mock_conv = Conversation(id="t1", user_id=mock_user.id)
mock_session.get.return_value = mock_conv
# Using TestClient with a stream context
with client.stream("POST", "/chat/stream",
json={"message": "hello", "thread_id": "t1"},

View File

@@ -0,0 +1,63 @@
import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch, AsyncMock
from ea_chatbot.api.main import app
from ea_chatbot.api.dependencies import get_current_user
from ea_chatbot.history.models import User, Conversation, Message, Plot
from ea_chatbot.api.utils import create_access_token
from datetime import datetime, timezone
import json
client = TestClient(app)
@pytest.fixture
def mock_user():
return User(id="user-123", username="test@example.com", display_name="Test User")
@pytest.fixture
def auth_header(mock_user):
app.dependency_overrides[get_current_user] = lambda: mock_user
token = create_access_token(data={"sub": mock_user.username, "user_id": mock_user.id})
yield {"Authorization": f"Bearer {token}"}
app.dependency_overrides.clear()
def test_persistence_integration_success(auth_header, mock_user):
"""Test that messages and plots are persisted correctly during streaming."""
mock_events = [
{"event": "on_chat_model_stream", "name": "summarizer", "data": {"chunk": "Final answer"}},
{"event": "on_chain_end", "name": "summarizer", "data": {"output": {"messages": [{"content": "Final answer"}]}}},
{"event": "on_chain_end", "name": "summarize_conversation", "data": {"output": {"summary": "New summary"}}}
]
async def mock_astream_events(*args, **kwargs):
for event in mock_events:
yield event
with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events), \
patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
mock_cp.return_value.__aenter__.return_value = AsyncMock()
# Mock session and DB objects
mock_session = MagicMock()
mock_hm.get_session.return_value.__enter__.return_value = mock_session
mock_conv = Conversation(id="t1", user_id=mock_user.id)
mock_session.get.return_value = mock_conv
# Act
with client.stream("POST", "/chat/stream",
json={"message": "persistence test", "thread_id": "t1"},
headers=auth_header) as response:
assert response.status_code == 200
list(response.iter_lines()) # Consume stream
# Assertions
# 1. User message should be saved immediately
mock_hm.add_message.assert_any_call("t1", "user", "persistence test")
# 2. Assistant message should be saved at the end
mock_hm.add_message.assert_any_call("t1", "assistant", "Final answer", plots=[])
# 3. Summary should be updated
mock_hm.update_conversation_summary.assert_called_once_with("t1", "New summary")