feat(api): Synchronize history and summary from DB in chat stream

This commit is contained in:
Yunxiao Xu
2026-02-15 04:11:42 -08:00
parent 398aad4280
commit 5b9d644fe5
2 changed files with 95 additions and 4 deletions

View File

@@ -7,10 +7,12 @@ from ea_chatbot.api.utils import convert_to_json_compatible
from ea_chatbot.graph.workflow import app from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.checkpoint import get_checkpointer from ea_chatbot.graph.checkpoint import get_checkpointer
from ea_chatbot.history.models import User as UserDB, Conversation from ea_chatbot.history.models import User as UserDB, Conversation
from ea_chatbot.history.utils import map_db_messages_to_langchain
from ea_chatbot.api.schemas import ChatRequest from ea_chatbot.api.schemas import ChatRequest
import io import io
import base64 import base64
from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.config import RunnableConfig
from langchain_core.messages import BaseMessage
router = APIRouter(prefix="/chat", tags=["agent"]) router = APIRouter(prefix="/chat", tags=["agent"])
@@ -18,14 +20,15 @@ async def stream_agent_events(
message: str, message: str,
thread_id: str, thread_id: str,
user_id: str, user_id: str,
summary: str summary: str,
messages: List[BaseMessage] = []
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
""" """
Generator that invokes the LangGraph agent and yields SSE formatted events. Generator that invokes the LangGraph agent and yields SSE formatted events.
Persists assistant responses and plots to the database. Persists assistant responses and plots to the database.
""" """
initial_state = { initial_state = {
"messages": [], "messages": messages,
"question": message, "question": message,
"summary": summary, "summary": summary,
"analysis": None, "analysis": None,
@@ -150,7 +153,15 @@ async def chat_stream(
if conv.user_id != current_user.id: if conv.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Not authorized to access this conversation") raise HTTPException(status_code=403, detail="Not authorized to access this conversation")
# Save user message immediately # Load existing summary from DB if not provided in request
db_summary = conv.summary or ""
# Load last 10 messages for context (BEFORE saving the current user message)
# This ensures we don't include the current message twice if the graph reduces it.
db_messages = history_manager.get_messages_by_window(request.thread_id, window_size=10)
lc_messages = map_db_messages_to_langchain(db_messages)
# Save user message immediately to DB
history_manager.add_message(request.thread_id, "user", request.message) history_manager.add_message(request.thread_id, "user", request.message)
return StreamingResponse( return StreamingResponse(
@@ -158,7 +169,8 @@ async def chat_stream(
request.message, request.message,
request.thread_id, request.thread_id,
current_user.id, current_user.id,
request.summary or "" db_summary,
lc_messages
), ),
media_type="text/event-stream" media_type="text/event-stream"
) )

View File

@@ -0,0 +1,79 @@
import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch
from ea_chatbot.api.main import app
from ea_chatbot.history.models import User, Conversation, Message
from langchain_core.messages import HumanMessage, AIMessage
@pytest.fixture
def client():
return TestClient(app)
@pytest.fixture
def mock_user():
return User(id="user-123", username="test@test.com")
def test_chat_stream_loads_history(client, mock_user):
"""
Test that the /chat/stream endpoint loads history from the DB
and passes it to stream_agent_events.
"""
# 1. Setup mocks
from ea_chatbot.api.dependencies import get_current_user
app.dependency_overrides[get_current_user] = lambda: mock_user
try:
with patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm, \
patch("ea_chatbot.api.routers.agent.stream_agent_events") as mock_stream:
# Mock conversation exists and belongs to user
mock_conv = Conversation(id="conv-123", user_id=mock_user.id, summary="Old summary")
mock_session = MagicMock()
mock_session.get.return_value = mock_conv
mock_hm.get_session.return_value.__enter__.return_value = mock_session
# Mock history window returns some messages
mock_db_messages = [
Message(role="user", content="Question 1"),
Message(role="assistant", content="Answer 1")
]
mock_hm.get_messages_by_window.return_value = mock_db_messages
# Mock stream_agent_events to return an empty generator
async def empty_gen(*args, **kwargs):
yield 'data: {"type": "done"}\n\n'
mock_stream.side_effect = empty_gen
# 2. Execute request
response = client.post(
"/api/v1/chat/stream",
json={"message": "New question", "thread_id": "conv-123"}
)
# 3. Assertions
assert response.status_code == 200
# Verify history_manager was called to get history
mock_hm.get_messages_by_window.assert_called_once_with("conv-123", window_size=10)
# Verify stream_agent_events was called with loaded history and summary
args, kwargs = mock_stream.call_args
# args: message, thread_id, user_id, summary, messages
assert args[0] == "New question"
assert args[1] == "conv-123"
assert args[2] == mock_user.id
assert args[3] == "Old summary"
passed_messages = args[4]
assert len(passed_messages) == 2
assert isinstance(passed_messages[0], HumanMessage)
assert passed_messages[0].content == "Question 1"
assert isinstance(passed_messages[1], AIMessage)
assert passed_messages[1].content == "Answer 1"
# Verify current message was saved AFTER loading history (based on code logic)
mock_hm.add_message.assert_called_with("conv-123", "user", "New question")
finally:
app.dependency_overrides.clear()