feat(api): Synchronize history and summary from DB in chat stream
This commit is contained in:
@@ -7,10 +7,12 @@ from ea_chatbot.api.utils import convert_to_json_compatible
|
|||||||
from ea_chatbot.graph.workflow import app
|
from ea_chatbot.graph.workflow import app
|
||||||
from ea_chatbot.graph.checkpoint import get_checkpointer
|
from ea_chatbot.graph.checkpoint import get_checkpointer
|
||||||
from ea_chatbot.history.models import User as UserDB, Conversation
|
from ea_chatbot.history.models import User as UserDB, Conversation
|
||||||
|
from ea_chatbot.history.utils import map_db_messages_to_langchain
|
||||||
from ea_chatbot.api.schemas import ChatRequest
|
from ea_chatbot.api.schemas import ChatRequest
|
||||||
import io
|
import io
|
||||||
import base64
|
import base64
|
||||||
from langchain_core.runnables.config import RunnableConfig
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
|
from langchain_core.messages import BaseMessage
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["agent"])
|
router = APIRouter(prefix="/chat", tags=["agent"])
|
||||||
|
|
||||||
@@ -18,14 +20,15 @@ async def stream_agent_events(
|
|||||||
message: str,
|
message: str,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
summary: str
|
summary: str,
|
||||||
|
messages: List[BaseMessage] = []
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""
|
"""
|
||||||
Generator that invokes the LangGraph agent and yields SSE formatted events.
|
Generator that invokes the LangGraph agent and yields SSE formatted events.
|
||||||
Persists assistant responses and plots to the database.
|
Persists assistant responses and plots to the database.
|
||||||
"""
|
"""
|
||||||
initial_state = {
|
initial_state = {
|
||||||
"messages": [],
|
"messages": messages,
|
||||||
"question": message,
|
"question": message,
|
||||||
"summary": summary,
|
"summary": summary,
|
||||||
"analysis": None,
|
"analysis": None,
|
||||||
@@ -150,7 +153,15 @@ async def chat_stream(
|
|||||||
if conv.user_id != current_user.id:
|
if conv.user_id != current_user.id:
|
||||||
raise HTTPException(status_code=403, detail="Not authorized to access this conversation")
|
raise HTTPException(status_code=403, detail="Not authorized to access this conversation")
|
||||||
|
|
||||||
# Save user message immediately
|
# Load existing summary from DB if not provided in request
|
||||||
|
db_summary = conv.summary or ""
|
||||||
|
|
||||||
|
# Load last 10 messages for context (BEFORE saving the current user message)
|
||||||
|
# This ensures we don't include the current message twice if the graph reduces it.
|
||||||
|
db_messages = history_manager.get_messages_by_window(request.thread_id, window_size=10)
|
||||||
|
lc_messages = map_db_messages_to_langchain(db_messages)
|
||||||
|
|
||||||
|
# Save user message immediately to DB
|
||||||
history_manager.add_message(request.thread_id, "user", request.message)
|
history_manager.add_message(request.thread_id, "user", request.message)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
@@ -158,7 +169,8 @@ async def chat_stream(
|
|||||||
request.message,
|
request.message,
|
||||||
request.thread_id,
|
request.thread_id,
|
||||||
current_user.id,
|
current_user.id,
|
||||||
request.summary or ""
|
db_summary,
|
||||||
|
lc_messages
|
||||||
),
|
),
|
||||||
media_type="text/event-stream"
|
media_type="text/event-stream"
|
||||||
)
|
)
|
||||||
79
backend/tests/api/test_history_sync.py
Normal file
79
backend/tests/api/test_history_sync.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from ea_chatbot.api.main import app
|
||||||
|
from ea_chatbot.history.models import User, Conversation, Message
|
||||||
|
from langchain_core.messages import HumanMessage, AIMessage
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client():
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_user():
|
||||||
|
return User(id="user-123", username="test@test.com")
|
||||||
|
|
||||||
|
def test_chat_stream_loads_history(client, mock_user):
|
||||||
|
"""
|
||||||
|
Test that the /chat/stream endpoint loads history from the DB
|
||||||
|
and passes it to stream_agent_events.
|
||||||
|
"""
|
||||||
|
# 1. Setup mocks
|
||||||
|
from ea_chatbot.api.dependencies import get_current_user
|
||||||
|
app.dependency_overrides[get_current_user] = lambda: mock_user
|
||||||
|
|
||||||
|
try:
|
||||||
|
with patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm, \
|
||||||
|
patch("ea_chatbot.api.routers.agent.stream_agent_events") as mock_stream:
|
||||||
|
|
||||||
|
# Mock conversation exists and belongs to user
|
||||||
|
mock_conv = Conversation(id="conv-123", user_id=mock_user.id, summary="Old summary")
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.get.return_value = mock_conv
|
||||||
|
mock_hm.get_session.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
|
# Mock history window returns some messages
|
||||||
|
mock_db_messages = [
|
||||||
|
Message(role="user", content="Question 1"),
|
||||||
|
Message(role="assistant", content="Answer 1")
|
||||||
|
]
|
||||||
|
mock_hm.get_messages_by_window.return_value = mock_db_messages
|
||||||
|
|
||||||
|
# Mock stream_agent_events to return an empty generator
|
||||||
|
async def empty_gen(*args, **kwargs):
|
||||||
|
yield 'data: {"type": "done"}\n\n'
|
||||||
|
|
||||||
|
mock_stream.side_effect = empty_gen
|
||||||
|
|
||||||
|
# 2. Execute request
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/chat/stream",
|
||||||
|
json={"message": "New question", "thread_id": "conv-123"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Assertions
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Verify history_manager was called to get history
|
||||||
|
mock_hm.get_messages_by_window.assert_called_once_with("conv-123", window_size=10)
|
||||||
|
|
||||||
|
# Verify stream_agent_events was called with loaded history and summary
|
||||||
|
args, kwargs = mock_stream.call_args
|
||||||
|
# args: message, thread_id, user_id, summary, messages
|
||||||
|
assert args[0] == "New question"
|
||||||
|
assert args[1] == "conv-123"
|
||||||
|
assert args[2] == mock_user.id
|
||||||
|
assert args[3] == "Old summary"
|
||||||
|
|
||||||
|
passed_messages = args[4]
|
||||||
|
assert len(passed_messages) == 2
|
||||||
|
assert isinstance(passed_messages[0], HumanMessage)
|
||||||
|
assert passed_messages[0].content == "Question 1"
|
||||||
|
assert isinstance(passed_messages[1], AIMessage)
|
||||||
|
assert passed_messages[1].content == "Answer 1"
|
||||||
|
|
||||||
|
# Verify current message was saved AFTER loading history (based on code logic)
|
||||||
|
mock_hm.add_message.assert_called_with("conv-123", "user", "New question")
|
||||||
|
finally:
|
||||||
|
app.dependency_overrides.clear()
|
||||||
Reference in New Issue
Block a user