feat(api): Enhance persistence logic and refactor codebase structure

This commit is contained in:
Yunxiao Xu
2026-02-11 15:33:56 -08:00
parent 371582dcd1
commit 85329cffda
7 changed files with 244 additions and 95 deletions

View File

@@ -1,41 +1,20 @@
import json import json
import asyncio import asyncio
from typing import AsyncGenerator, Optional from typing import AsyncGenerator, Optional, List
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from ea_chatbot.api.dependencies import get_current_user, history_manager from ea_chatbot.api.dependencies import get_current_user, history_manager
from ea_chatbot.api.utils import convert_to_json_compatible
from ea_chatbot.graph.workflow import app from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.checkpoint import get_checkpointer from ea_chatbot.graph.checkpoint import get_checkpointer
from ea_chatbot.history.models import User as UserDB, Conversation from ea_chatbot.history.models import User as UserDB, Conversation
from ea_chatbot.api.schemas import ChatRequest
import io import io
import base64 import base64
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
router = APIRouter(prefix="/chat", tags=["agent"]) router = APIRouter(prefix="/chat", tags=["agent"])
def convert_to_json_compatible(obj):
"""Recursively convert LangChain objects to JSON compatible formats."""
if isinstance(obj, list):
return [convert_to_json_compatible(item) for item in obj]
elif isinstance(obj, dict):
return {k: convert_to_json_compatible(v) for k, v in obj.items()}
elif isinstance(obj, BaseMessage):
return {"type": obj.type, "content": obj.content, **convert_to_json_compatible(obj.additional_kwargs)}
elif isinstance(obj, BaseModel) or hasattr(obj, "model_dump"):
try:
return convert_to_json_compatible(obj.model_dump())
except Exception:
return str(obj)
elif hasattr(obj, "content"):
return str(obj.content)
return obj
class ChatRequest(BaseModel):
message: str
thread_id: str # This maps to the conversation_id in our DB
summary: Optional[str] = ""
async def stream_agent_events( async def stream_agent_events(
message: str, message: str,
thread_id: str, thread_id: str,
@@ -44,6 +23,7 @@ async def stream_agent_events(
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
""" """
Generator that invokes the LangGraph agent and yields SSE formatted events. Generator that invokes the LangGraph agent and yields SSE formatted events.
Persists assistant responses and plots to the database.
""" """
initial_state = { initial_state = {
"messages": [], "messages": [],
@@ -61,39 +41,92 @@ async def stream_agent_events(
config = {"configurable": {"thread_id": thread_id}} config = {"configurable": {"thread_id": thread_id}}
async with get_checkpointer() as checkpointer: assistant_chunks: List[str] = []
async for event in app.astream_events( assistant_plots: List[bytes] = []
initial_state, final_response: str = ""
config, new_summary: str = ""
version="v2",
checkpointer=checkpointer
):
kind = event.get("event")
name = event.get("name")
output_event = { try:
"type": kind, async with get_checkpointer() as checkpointer:
"name": name, async for event in app.astream_events(
"data": event.get("data", {}) initial_state,
} config,
version="v2",
checkpointer=checkpointer
):
kind = event.get("event")
name = event.get("name")
data = event.get("data", {})
if kind == "on_chain_end" and name == "executor": # Standardize event for frontend
output = event.get("data", {}).get("output", {}) output_event = {
if isinstance(output, dict) and "plots" in output: "type": kind,
plots = output["plots"] "name": name,
encoded_plots = [] "data": data
for fig in plots: }
buf = io.BytesIO()
fig.savefig(buf, format="png")
encoded_plots.append(base64.b64encode(buf.getvalue()).decode('utf-8'))
output_event["data"]["encoded_plots"] = encoded_plots
# Convert to JSON compatible format to avoid serialization errors # Buffer assistant chunks (summarizer and researcher might stream)
compatible_output = convert_to_json_compatible(output_event) if kind == "on_chat_model_stream" and name in ["summarizer", "researcher"]:
yield f"data: {json.dumps(compatible_output)}\n\n" chunk = data.get("chunk", "")
if hasattr(chunk, "content"):
chunk = chunk.content
assistant_chunks.append(str(chunk))
# Buffer and encode plots
if kind == "on_chain_end" and name == "executor":
output = data.get("output", {})
if isinstance(output, dict) and "plots" in output:
plots = output["plots"]
encoded_plots = []
for fig in plots:
buf = io.BytesIO()
fig.savefig(buf, format="png")
plot_bytes = buf.getvalue()
assistant_plots.append(plot_bytes)
encoded_plots.append(base64.b64encode(plot_bytes).decode('utf-8'))
output_event["data"]["encoded_plots"] = encoded_plots
# Collect final response from terminal nodes
if kind == "on_chain_end" and name in ["summarizer", "researcher", "clarification"]:
output = data.get("output", {})
if isinstance(output, dict) and "messages" in output:
last_msg = output["messages"][-1]
if hasattr(last_msg, "content"):
final_response = last_msg.content
elif isinstance(last_msg, dict) and "content" in last_msg:
final_response = last_msg["content"]
else:
final_response = str(last_msg)
# Collect new summary
if kind == "on_chain_end" and name == "summarize_conversation":
output = data.get("output", {})
if isinstance(output, dict) and "summary" in output:
new_summary = output["summary"]
# Convert to JSON compatible format to avoid serialization errors
compatible_output = convert_to_json_compatible(output_event)
yield f"data: {json.dumps(compatible_output)}\n\n"
# If we didn't get a final_response from node output, use buffered chunks
if not final_response and assistant_chunks:
final_response = "".join(assistant_chunks)
# Save assistant message to DB
if final_response:
history_manager.add_message(thread_id, "assistant", final_response, plots=assistant_plots)
# Update summary in DB
if new_summary:
history_manager.update_conversation_summary(thread_id, new_summary)
yield "data: {\"type\": \"done\"}\n\n" yield "data: {\"type\": \"done\"}\n\n"
except Exception as e:
error_msg = f"Agent execution failed: {str(e)}"
history_manager.add_message(thread_id, "assistant", error_msg)
yield f"data: {json.dumps({'type': 'error', 'message': error_msg})}\n\n"
@router.post("/stream") @router.post("/stream")
async def chat_stream( async def chat_stream(
request: ChatRequest, request: ChatRequest,
@@ -104,9 +137,14 @@ async def chat_stream(
""" """
with history_manager.get_session() as session: with history_manager.get_session() as session:
conv = session.get(Conversation, request.thread_id) conv = session.get(Conversation, request.thread_id)
if conv and conv.user_id != current_user.id: if not conv:
raise HTTPException(status_code=404, detail="Conversation not found")
if conv.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Not authorized to access this conversation") raise HTTPException(status_code=403, detail="Not authorized to access this conversation")
# Save user message immediately
history_manager.add_message(request.thread_id, "user", request.message)
return StreamingResponse( return StreamingResponse(
stream_agent_events( stream_agent_events(
request.message, request.message,

View File

@@ -1,27 +1,12 @@
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel, EmailStr
from typing import Optional
from ea_chatbot.api.utils import create_access_token from ea_chatbot.api.utils import create_access_token
from ea_chatbot.api.dependencies import history_manager, oidc_client, get_current_user from ea_chatbot.api.dependencies import history_manager, oidc_client, get_current_user
from ea_chatbot.history.models import User as UserDB from ea_chatbot.history.models import User as UserDB
from ea_chatbot.api.schemas import Token, UserCreate, UserResponse
router = APIRouter(prefix="/auth", tags=["auth"]) router = APIRouter(prefix="/auth", tags=["auth"])
class Token(BaseModel):
access_token: str
token_type: str
class UserCreate(BaseModel):
email: EmailStr
password: str
display_name: Optional[str] = None
class UserResponse(BaseModel):
id: str
email: str
display_name: str
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED) @router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(user_in: UserCreate): async def register(user_in: UserCreate):
"""Register a new user.""" """Register a new user."""

View File

@@ -1,30 +1,11 @@
from fastapi import APIRouter, Depends, HTTPException, status, Response from fastapi import APIRouter, Depends, HTTPException, status, Response
from pydantic import BaseModel
from typing import List, Optional from typing import List, Optional
from datetime import datetime
from ea_chatbot.api.dependencies import get_current_user, history_manager, settings from ea_chatbot.api.dependencies import get_current_user, history_manager, settings
from ea_chatbot.history.models import User as UserDB from ea_chatbot.history.models import User as UserDB
from ea_chatbot.api.schemas import ConversationResponse, MessageResponse, ConversationUpdate
router = APIRouter(prefix="/conversations", tags=["history"]) router = APIRouter(prefix="/conversations", tags=["history"])
class ConversationResponse(BaseModel):
id: str
name: str
summary: Optional[str] = None
created_at: datetime
data_state: str
class MessageResponse(BaseModel):
id: str
role: str
content: str
created_at: datetime
# We don't include plots directly here, they'll be fetched via artifact endpoints
class ConversationUpdate(BaseModel):
name: Optional[str] = None
summary: Optional[str] = None
@router.get("", response_model=List[ConversationResponse]) @router.get("", response_model=List[ConversationResponse])
async def list_conversations( async def list_conversations(
current_user: UserDB = Depends(get_current_user), current_user: UserDB = Depends(get_current_user),

View File

@@ -0,0 +1,46 @@
from pydantic import BaseModel, EmailStr
from typing import List, Optional
from datetime import datetime
# --- Auth Schemas ---
class Token(BaseModel):
    """Bearer token payload returned by the auth endpoints."""
    # Encoded access token string (issued by create_access_token in api.utils).
    access_token: str
    # Token scheme for the Authorization header, e.g. "bearer".
    token_type: str
class UserCreate(BaseModel):
    """Request body for registering a new user."""
    # Validated e-mail address; doubles as the login identifier.
    email: EmailStr
    # Plain-text password as submitted; hashing happens server-side.
    password: str
    # Optional human-readable name shown in the UI.
    display_name: Optional[str] = None
class UserResponse(BaseModel):
    """Public user representation returned by auth endpoints (no password)."""
    id: str
    email: str
    display_name: str
# --- History Schemas ---
class ConversationResponse(BaseModel):
    """Conversation summary row returned by the history endpoints."""
    id: str
    name: str
    # May be absent for conversations that have not been summarized yet.
    summary: Optional[str] = None
    created_at: datetime
    # NOTE(review): allowed values of data_state are not visible in this
    # module — confirm against the Conversation ORM model.
    data_state: str
class MessageResponse(BaseModel):
    """Single chat message within a conversation."""
    id: str
    # Sender role, e.g. "user" or "assistant" as persisted by history_manager.
    role: str
    content: str
    created_at: datetime
    # Plots are fetched separately via artifact endpoints
class ConversationUpdate(BaseModel):
    """Partial update payload for a conversation.

    Both fields are optional; presumably a None field is left unchanged by
    the endpoint — verify against the PATCH handler.
    """
    name: Optional[str] = None
    summary: Optional[str] = None
# --- Agent Schemas ---
class ChatRequest(BaseModel):
    """Request body for the streaming chat endpoint."""
    # User's message text to send to the agent.
    message: str
    thread_id: str  # Maps to conversation_id
    # Optional running conversation summary carried into the agent state.
    summary: Optional[str] = ""

View File

@@ -1,6 +1,8 @@
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Optional, Union from typing import Optional, Union, Any, List, Dict
from jose import JWTError, jwt from jose import JWTError, jwt
from pydantic import BaseModel
from langchain_core.messages import BaseMessage
from ea_chatbot.config import Settings from ea_chatbot.config import Settings
settings = Settings() settings = Settings()
@@ -41,3 +43,29 @@ def decode_access_token(token: str) -> Optional[dict]:
return payload return payload
except JWTError: except JWTError:
return None return None
def convert_to_json_compatible(obj: Any) -> Any:
    """Recursively convert LangChain messages, Pydantic models, datetimes,
    and other non-serializable values into JSON-compatible structures.

    Args:
        obj: Any value; containers are converted element-by-element.

    Returns:
        A structure composed only of JSON-serializable primitives where
        possible; unknown leaf objects are returned unchanged.
    """
    if isinstance(obj, list):
        return [convert_to_json_compatible(item) for item in obj]
    elif isinstance(obj, dict):
        return {k: convert_to_json_compatible(v) for k, v in obj.items()}
    # BUGFIX: the original tested isinstance(obj, (datetime, timezone)) and
    # called obj.isoformat() — timezone objects have no isoformat(), which
    # raised AttributeError. datetime is handled here (before the generic
    # attribute fallbacks), and timezone is stringified instead.
    elif isinstance(obj, datetime):
        return obj.isoformat()
    elif isinstance(obj, timezone):
        return str(obj)
    elif isinstance(obj, BaseMessage):
        return {"type": obj.type, "content": obj.content, **convert_to_json_compatible(obj.additional_kwargs)}
    elif isinstance(obj, BaseModel):
        return convert_to_json_compatible(obj.model_dump())
    elif hasattr(obj, "model_dump"):  # For Pydantic v2 if not caught by BaseModel
        try:
            return convert_to_json_compatible(obj.model_dump())
        except Exception:
            return str(obj)
    elif hasattr(obj, "dict"):  # Fallback for Pydantic v1 or other objects
        try:
            return convert_to_json_compatible(obj.dict())
        except Exception:
            return str(obj)
    elif hasattr(obj, "content"):
        return str(obj.content)
    return obj

View File

@@ -25,7 +25,7 @@ def test_stream_agent_unauthorized():
assert response.status_code == 401 assert response.status_code == 401
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stream_agent_success(auth_header): async def test_stream_agent_success(auth_header, mock_user):
"""Test successful agent streaming with SSE.""" """Test successful agent streaming with SSE."""
# We need to mock the LangGraph app.astream_events # We need to mock the LangGraph app.astream_events
mock_events = [ mock_events = [
@@ -39,10 +39,18 @@ async def test_stream_agent_success(auth_header):
yield event yield event
with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events), \ with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events), \
patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp: patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
mock_cp.return_value.__aenter__.return_value = AsyncMock() mock_cp.return_value.__aenter__.return_value = AsyncMock()
# Mock session and DB objects
mock_session = MagicMock()
mock_hm.get_session.return_value.__enter__.return_value = mock_session
from ea_chatbot.history.models import Conversation
mock_conv = Conversation(id="t1", user_id=mock_user.id)
mock_session.get.return_value = mock_conv
# Using TestClient with a stream context # Using TestClient with a stream context
with client.stream("POST", "/chat/stream", with client.stream("POST", "/chat/stream",
json={"message": "hello", "thread_id": "t1"}, json={"message": "hello", "thread_id": "t1"},

View File

@@ -0,0 +1,63 @@
import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch, AsyncMock
from ea_chatbot.api.main import app
from ea_chatbot.api.dependencies import get_current_user
from ea_chatbot.history.models import User, Conversation, Message, Plot
from ea_chatbot.api.utils import create_access_token
from datetime import datetime, timezone
import json
client = TestClient(app)
@pytest.fixture
def mock_user():
    """Build a detached ``User`` record for tests (never hits the DB)."""
    fake_user = User(
        id="user-123",
        username="test@example.com",
        display_name="Test User",
    )
    return fake_user
@pytest.fixture
def auth_header(mock_user):
    """Yield an ``Authorization`` header for *mock_user*.

    Also overrides the ``get_current_user`` dependency so the app resolves
    the current user to *mock_user*, and clears all overrides on teardown.
    """
    # Bypass real authentication: FastAPI returns mock_user directly.
    app.dependency_overrides[get_current_user] = lambda: mock_user
    token = create_access_token(data={"sub": mock_user.username, "user_id": mock_user.id})
    yield {"Authorization": f"Bearer {token}"}
    # Teardown: remove the override so later tests use real dependencies.
    app.dependency_overrides.clear()
def test_persistence_integration_success(auth_header, mock_user):
    """Test that messages and plots are persisted correctly during streaming.

    Patches the LangGraph app, the checkpointer, and the history manager so
    no real agent, model, or DB is involved, then asserts the persistence
    calls made by the /chat/stream endpoint.
    """
    # Scripted agent events: one streamed chunk, the terminal node output,
    # and a conversation-summary update.
    mock_events = [
        {"event": "on_chat_model_stream", "name": "summarizer", "data": {"chunk": "Final answer"}},
        {"event": "on_chain_end", "name": "summarizer", "data": {"output": {"messages": [{"content": "Final answer"}]}}},
        {"event": "on_chain_end", "name": "summarize_conversation", "data": {"output": {"summary": "New summary"}}}
    ]
    async def mock_astream_events(*args, **kwargs):
        for event in mock_events:
            yield event
    with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events), \
         patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
         patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
        mock_cp.return_value.__aenter__.return_value = AsyncMock()
        # Mock session and DB objects
        mock_session = MagicMock()
        mock_hm.get_session.return_value.__enter__.return_value = mock_session
        # Conversation owned by mock_user so the endpoint's ownership check passes.
        mock_conv = Conversation(id="t1", user_id=mock_user.id)
        mock_session.get.return_value = mock_conv
        # Act
        with client.stream("POST", "/chat/stream",
                           json={"message": "persistence test", "thread_id": "t1"},
                           headers=auth_header) as response:
            assert response.status_code == 200
            list(response.iter_lines())  # Consume stream
        # Assertions
        # 1. User message should be saved immediately
        mock_hm.add_message.assert_any_call("t1", "user", "persistence test")
        # 2. Assistant message should be saved at the end
        mock_hm.add_message.assert_any_call("t1", "assistant", "Final answer", plots=[])
        # 3. Summary should be updated
        mock_hm.update_conversation_summary.assert_called_once_with("t1", "New summary")