feat(api): Enhance persistence logic and refactor codebase structure
This commit is contained in:
@@ -1,41 +1,20 @@
|
|||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import AsyncGenerator, Optional
|
from typing import AsyncGenerator, Optional, List
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
|
||||||
from ea_chatbot.api.dependencies import get_current_user, history_manager
|
from ea_chatbot.api.dependencies import get_current_user, history_manager
|
||||||
|
from ea_chatbot.api.utils import convert_to_json_compatible
|
||||||
from ea_chatbot.graph.workflow import app
|
from ea_chatbot.graph.workflow import app
|
||||||
from ea_chatbot.graph.checkpoint import get_checkpointer
|
from ea_chatbot.graph.checkpoint import get_checkpointer
|
||||||
from ea_chatbot.history.models import User as UserDB, Conversation
|
from ea_chatbot.history.models import User as UserDB, Conversation
|
||||||
|
from ea_chatbot.api.schemas import ChatRequest
|
||||||
import io
|
import io
|
||||||
import base64
|
import base64
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["agent"])
|
router = APIRouter(prefix="/chat", tags=["agent"])
|
||||||
|
|
||||||
def convert_to_json_compatible(obj):
|
|
||||||
"""Recursively convert LangChain objects to JSON compatible formats."""
|
|
||||||
if isinstance(obj, list):
|
|
||||||
return [convert_to_json_compatible(item) for item in obj]
|
|
||||||
elif isinstance(obj, dict):
|
|
||||||
return {k: convert_to_json_compatible(v) for k, v in obj.items()}
|
|
||||||
elif isinstance(obj, BaseMessage):
|
|
||||||
return {"type": obj.type, "content": obj.content, **convert_to_json_compatible(obj.additional_kwargs)}
|
|
||||||
elif isinstance(obj, BaseModel) or hasattr(obj, "model_dump"):
|
|
||||||
try:
|
|
||||||
return convert_to_json_compatible(obj.model_dump())
|
|
||||||
except Exception:
|
|
||||||
return str(obj)
|
|
||||||
elif hasattr(obj, "content"):
|
|
||||||
return str(obj.content)
|
|
||||||
return obj
|
|
||||||
|
|
||||||
class ChatRequest(BaseModel):
|
|
||||||
message: str
|
|
||||||
thread_id: str # This maps to the conversation_id in our DB
|
|
||||||
summary: Optional[str] = ""
|
|
||||||
|
|
||||||
async def stream_agent_events(
|
async def stream_agent_events(
|
||||||
message: str,
|
message: str,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
@@ -44,6 +23,7 @@ async def stream_agent_events(
|
|||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""
|
"""
|
||||||
Generator that invokes the LangGraph agent and yields SSE formatted events.
|
Generator that invokes the LangGraph agent and yields SSE formatted events.
|
||||||
|
Persists assistant responses and plots to the database.
|
||||||
"""
|
"""
|
||||||
initial_state = {
|
initial_state = {
|
||||||
"messages": [],
|
"messages": [],
|
||||||
@@ -61,6 +41,12 @@ async def stream_agent_events(
|
|||||||
|
|
||||||
config = {"configurable": {"thread_id": thread_id}}
|
config = {"configurable": {"thread_id": thread_id}}
|
||||||
|
|
||||||
|
assistant_chunks: List[str] = []
|
||||||
|
assistant_plots: List[bytes] = []
|
||||||
|
final_response: str = ""
|
||||||
|
new_summary: str = ""
|
||||||
|
|
||||||
|
try:
|
||||||
async with get_checkpointer() as checkpointer:
|
async with get_checkpointer() as checkpointer:
|
||||||
async for event in app.astream_events(
|
async for event in app.astream_events(
|
||||||
initial_state,
|
initial_state,
|
||||||
@@ -70,30 +56,77 @@ async def stream_agent_events(
|
|||||||
):
|
):
|
||||||
kind = event.get("event")
|
kind = event.get("event")
|
||||||
name = event.get("name")
|
name = event.get("name")
|
||||||
|
data = event.get("data", {})
|
||||||
|
|
||||||
|
# Standardize event for frontend
|
||||||
output_event = {
|
output_event = {
|
||||||
"type": kind,
|
"type": kind,
|
||||||
"name": name,
|
"name": name,
|
||||||
"data": event.get("data", {})
|
"data": data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Buffer assistant chunks (summarizer and researcher might stream)
|
||||||
|
if kind == "on_chat_model_stream" and name in ["summarizer", "researcher"]:
|
||||||
|
chunk = data.get("chunk", "")
|
||||||
|
if hasattr(chunk, "content"):
|
||||||
|
chunk = chunk.content
|
||||||
|
assistant_chunks.append(str(chunk))
|
||||||
|
|
||||||
|
# Buffer and encode plots
|
||||||
if kind == "on_chain_end" and name == "executor":
|
if kind == "on_chain_end" and name == "executor":
|
||||||
output = event.get("data", {}).get("output", {})
|
output = data.get("output", {})
|
||||||
if isinstance(output, dict) and "plots" in output:
|
if isinstance(output, dict) and "plots" in output:
|
||||||
plots = output["plots"]
|
plots = output["plots"]
|
||||||
encoded_plots = []
|
encoded_plots = []
|
||||||
for fig in plots:
|
for fig in plots:
|
||||||
buf = io.BytesIO()
|
buf = io.BytesIO()
|
||||||
fig.savefig(buf, format="png")
|
fig.savefig(buf, format="png")
|
||||||
encoded_plots.append(base64.b64encode(buf.getvalue()).decode('utf-8'))
|
plot_bytes = buf.getvalue()
|
||||||
|
assistant_plots.append(plot_bytes)
|
||||||
|
encoded_plots.append(base64.b64encode(plot_bytes).decode('utf-8'))
|
||||||
output_event["data"]["encoded_plots"] = encoded_plots
|
output_event["data"]["encoded_plots"] = encoded_plots
|
||||||
|
|
||||||
|
# Collect final response from terminal nodes
|
||||||
|
if kind == "on_chain_end" and name in ["summarizer", "researcher", "clarification"]:
|
||||||
|
output = data.get("output", {})
|
||||||
|
if isinstance(output, dict) and "messages" in output:
|
||||||
|
last_msg = output["messages"][-1]
|
||||||
|
if hasattr(last_msg, "content"):
|
||||||
|
final_response = last_msg.content
|
||||||
|
elif isinstance(last_msg, dict) and "content" in last_msg:
|
||||||
|
final_response = last_msg["content"]
|
||||||
|
else:
|
||||||
|
final_response = str(last_msg)
|
||||||
|
|
||||||
|
# Collect new summary
|
||||||
|
if kind == "on_chain_end" and name == "summarize_conversation":
|
||||||
|
output = data.get("output", {})
|
||||||
|
if isinstance(output, dict) and "summary" in output:
|
||||||
|
new_summary = output["summary"]
|
||||||
|
|
||||||
# Convert to JSON compatible format to avoid serialization errors
|
# Convert to JSON compatible format to avoid serialization errors
|
||||||
compatible_output = convert_to_json_compatible(output_event)
|
compatible_output = convert_to_json_compatible(output_event)
|
||||||
yield f"data: {json.dumps(compatible_output)}\n\n"
|
yield f"data: {json.dumps(compatible_output)}\n\n"
|
||||||
|
|
||||||
|
# If we didn't get a final_response from node output, use buffered chunks
|
||||||
|
if not final_response and assistant_chunks:
|
||||||
|
final_response = "".join(assistant_chunks)
|
||||||
|
|
||||||
|
# Save assistant message to DB
|
||||||
|
if final_response:
|
||||||
|
history_manager.add_message(thread_id, "assistant", final_response, plots=assistant_plots)
|
||||||
|
|
||||||
|
# Update summary in DB
|
||||||
|
if new_summary:
|
||||||
|
history_manager.update_conversation_summary(thread_id, new_summary)
|
||||||
|
|
||||||
yield "data: {\"type\": \"done\"}\n\n"
|
yield "data: {\"type\": \"done\"}\n\n"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Agent execution failed: {str(e)}"
|
||||||
|
history_manager.add_message(thread_id, "assistant", error_msg)
|
||||||
|
yield f"data: {json.dumps({'type': 'error', 'message': error_msg})}\n\n"
|
||||||
|
|
||||||
@router.post("/stream")
|
@router.post("/stream")
|
||||||
async def chat_stream(
|
async def chat_stream(
|
||||||
request: ChatRequest,
|
request: ChatRequest,
|
||||||
@@ -104,9 +137,14 @@ async def chat_stream(
|
|||||||
"""
|
"""
|
||||||
with history_manager.get_session() as session:
|
with history_manager.get_session() as session:
|
||||||
conv = session.get(Conversation, request.thread_id)
|
conv = session.get(Conversation, request.thread_id)
|
||||||
if conv and conv.user_id != current_user.id:
|
if not conv:
|
||||||
|
raise HTTPException(status_code=404, detail="Conversation not found")
|
||||||
|
if conv.user_id != current_user.id:
|
||||||
raise HTTPException(status_code=403, detail="Not authorized to access this conversation")
|
raise HTTPException(status_code=403, detail="Not authorized to access this conversation")
|
||||||
|
|
||||||
|
# Save user message immediately
|
||||||
|
history_manager.add_message(request.thread_id, "user", request.message)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
stream_agent_events(
|
stream_agent_events(
|
||||||
request.message,
|
request.message,
|
||||||
|
|||||||
@@ -1,27 +1,12 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from pydantic import BaseModel, EmailStr
|
|
||||||
from typing import Optional
|
|
||||||
from ea_chatbot.api.utils import create_access_token
|
from ea_chatbot.api.utils import create_access_token
|
||||||
from ea_chatbot.api.dependencies import history_manager, oidc_client, get_current_user
|
from ea_chatbot.api.dependencies import history_manager, oidc_client, get_current_user
|
||||||
from ea_chatbot.history.models import User as UserDB
|
from ea_chatbot.history.models import User as UserDB
|
||||||
|
from ea_chatbot.api.schemas import Token, UserCreate, UserResponse
|
||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
class Token(BaseModel):
|
|
||||||
access_token: str
|
|
||||||
token_type: str
|
|
||||||
|
|
||||||
class UserCreate(BaseModel):
|
|
||||||
email: EmailStr
|
|
||||||
password: str
|
|
||||||
display_name: Optional[str] = None
|
|
||||||
|
|
||||||
class UserResponse(BaseModel):
|
|
||||||
id: str
|
|
||||||
email: str
|
|
||||||
display_name: str
|
|
||||||
|
|
||||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||||
async def register(user_in: UserCreate):
|
async def register(user_in: UserCreate):
|
||||||
"""Register a new user."""
|
"""Register a new user."""
|
||||||
|
|||||||
@@ -1,30 +1,11 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, status, Response
|
from fastapi import APIRouter, Depends, HTTPException, status, Response
|
||||||
from pydantic import BaseModel
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from datetime import datetime
|
|
||||||
from ea_chatbot.api.dependencies import get_current_user, history_manager, settings
|
from ea_chatbot.api.dependencies import get_current_user, history_manager, settings
|
||||||
from ea_chatbot.history.models import User as UserDB
|
from ea_chatbot.history.models import User as UserDB
|
||||||
|
from ea_chatbot.api.schemas import ConversationResponse, MessageResponse, ConversationUpdate
|
||||||
|
|
||||||
router = APIRouter(prefix="/conversations", tags=["history"])
|
router = APIRouter(prefix="/conversations", tags=["history"])
|
||||||
|
|
||||||
class ConversationResponse(BaseModel):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
summary: Optional[str] = None
|
|
||||||
created_at: datetime
|
|
||||||
data_state: str
|
|
||||||
|
|
||||||
class MessageResponse(BaseModel):
|
|
||||||
id: str
|
|
||||||
role: str
|
|
||||||
content: str
|
|
||||||
created_at: datetime
|
|
||||||
# We don't include plots directly here, they'll be fetched via artifact endpoints
|
|
||||||
|
|
||||||
class ConversationUpdate(BaseModel):
|
|
||||||
name: Optional[str] = None
|
|
||||||
summary: Optional[str] = None
|
|
||||||
|
|
||||||
@router.get("", response_model=List[ConversationResponse])
|
@router.get("", response_model=List[ConversationResponse])
|
||||||
async def list_conversations(
|
async def list_conversations(
|
||||||
current_user: UserDB = Depends(get_current_user),
|
current_user: UserDB = Depends(get_current_user),
|
||||||
|
|||||||
46
src/ea_chatbot/api/schemas.py
Normal file
46
src/ea_chatbot/api/schemas.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
from pydantic import BaseModel, EmailStr
|
||||||
|
from typing import List, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# --- Auth Schemas ---
|
||||||
|
|
||||||
|
class Token(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
token_type: str
|
||||||
|
|
||||||
|
class UserCreate(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
password: str
|
||||||
|
display_name: Optional[str] = None
|
||||||
|
|
||||||
|
class UserResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
email: str
|
||||||
|
display_name: str
|
||||||
|
|
||||||
|
# --- History Schemas ---
|
||||||
|
|
||||||
|
class ConversationResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
summary: Optional[str] = None
|
||||||
|
created_at: datetime
|
||||||
|
data_state: str
|
||||||
|
|
||||||
|
class MessageResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
role: str
|
||||||
|
content: str
|
||||||
|
created_at: datetime
|
||||||
|
# Plots are fetched separately via artifact endpoints
|
||||||
|
|
||||||
|
class ConversationUpdate(BaseModel):
|
||||||
|
name: Optional[str] = None
|
||||||
|
summary: Optional[str] = None
|
||||||
|
|
||||||
|
# --- Agent Schemas ---
|
||||||
|
|
||||||
|
class ChatRequest(BaseModel):
|
||||||
|
message: str
|
||||||
|
thread_id: str # Maps to conversation_id
|
||||||
|
summary: Optional[str] = ""
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union, Any, List, Dict
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from langchain_core.messages import BaseMessage
|
||||||
from ea_chatbot.config import Settings
|
from ea_chatbot.config import Settings
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
@@ -41,3 +43,29 @@ def decode_access_token(token: str) -> Optional[dict]:
|
|||||||
return payload
|
return payload
|
||||||
except JWTError:
|
except JWTError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def convert_to_json_compatible(obj: Any) -> Any:
|
||||||
|
"""Recursively convert LangChain objects, Pydantic models, and others to JSON compatible formats."""
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return [convert_to_json_compatible(item) for item in obj]
|
||||||
|
elif isinstance(obj, dict):
|
||||||
|
return {k: convert_to_json_compatible(v) for k, v in obj.items()}
|
||||||
|
elif isinstance(obj, BaseMessage):
|
||||||
|
return {"type": obj.type, "content": obj.content, **convert_to_json_compatible(obj.additional_kwargs)}
|
||||||
|
elif isinstance(obj, BaseModel):
|
||||||
|
return convert_to_json_compatible(obj.model_dump())
|
||||||
|
elif hasattr(obj, "model_dump"): # For Pydantic v2 if not caught by BaseModel
|
||||||
|
try:
|
||||||
|
return convert_to_json_compatible(obj.model_dump())
|
||||||
|
except Exception:
|
||||||
|
return str(obj)
|
||||||
|
elif hasattr(obj, "dict"): # Fallback for Pydantic v1 or other objects
|
||||||
|
try:
|
||||||
|
return convert_to_json_compatible(obj.dict())
|
||||||
|
except Exception:
|
||||||
|
return str(obj)
|
||||||
|
elif hasattr(obj, "content"):
|
||||||
|
return str(obj.content)
|
||||||
|
elif isinstance(obj, (datetime, timezone)):
|
||||||
|
return obj.isoformat()
|
||||||
|
return obj
|
||||||
@@ -25,7 +25,7 @@ def test_stream_agent_unauthorized():
|
|||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stream_agent_success(auth_header):
|
async def test_stream_agent_success(auth_header, mock_user):
|
||||||
"""Test successful agent streaming with SSE."""
|
"""Test successful agent streaming with SSE."""
|
||||||
# We need to mock the LangGraph app.astream_events
|
# We need to mock the LangGraph app.astream_events
|
||||||
mock_events = [
|
mock_events = [
|
||||||
@@ -39,10 +39,18 @@ async def test_stream_agent_success(auth_header):
|
|||||||
yield event
|
yield event
|
||||||
|
|
||||||
with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events), \
|
with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events), \
|
||||||
patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp:
|
patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
|
||||||
|
patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
|
||||||
|
|
||||||
mock_cp.return_value.__aenter__.return_value = AsyncMock()
|
mock_cp.return_value.__aenter__.return_value = AsyncMock()
|
||||||
|
|
||||||
|
# Mock session and DB objects
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_hm.get_session.return_value.__enter__.return_value = mock_session
|
||||||
|
from ea_chatbot.history.models import Conversation
|
||||||
|
mock_conv = Conversation(id="t1", user_id=mock_user.id)
|
||||||
|
mock_session.get.return_value = mock_conv
|
||||||
|
|
||||||
# Using TestClient with a stream context
|
# Using TestClient with a stream context
|
||||||
with client.stream("POST", "/chat/stream",
|
with client.stream("POST", "/chat/stream",
|
||||||
json={"message": "hello", "thread_id": "t1"},
|
json={"message": "hello", "thread_id": "t1"},
|
||||||
|
|||||||
63
tests/api/test_persistence.py
Normal file
63
tests/api/test_persistence.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from unittest.mock import MagicMock, patch, AsyncMock
|
||||||
|
from ea_chatbot.api.main import app
|
||||||
|
from ea_chatbot.api.dependencies import get_current_user
|
||||||
|
from ea_chatbot.history.models import User, Conversation, Message, Plot
|
||||||
|
from ea_chatbot.api.utils import create_access_token
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
import json
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_user():
|
||||||
|
return User(id="user-123", username="test@example.com", display_name="Test User")
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def auth_header(mock_user):
|
||||||
|
app.dependency_overrides[get_current_user] = lambda: mock_user
|
||||||
|
token = create_access_token(data={"sub": mock_user.username, "user_id": mock_user.id})
|
||||||
|
yield {"Authorization": f"Bearer {token}"}
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
def test_persistence_integration_success(auth_header, mock_user):
|
||||||
|
"""Test that messages and plots are persisted correctly during streaming."""
|
||||||
|
mock_events = [
|
||||||
|
{"event": "on_chat_model_stream", "name": "summarizer", "data": {"chunk": "Final answer"}},
|
||||||
|
{"event": "on_chain_end", "name": "summarizer", "data": {"output": {"messages": [{"content": "Final answer"}]}}},
|
||||||
|
{"event": "on_chain_end", "name": "summarize_conversation", "data": {"output": {"summary": "New summary"}}}
|
||||||
|
]
|
||||||
|
|
||||||
|
async def mock_astream_events(*args, **kwargs):
|
||||||
|
for event in mock_events:
|
||||||
|
yield event
|
||||||
|
|
||||||
|
with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events), \
|
||||||
|
patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
|
||||||
|
patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
|
||||||
|
|
||||||
|
mock_cp.return_value.__aenter__.return_value = AsyncMock()
|
||||||
|
|
||||||
|
# Mock session and DB objects
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_hm.get_session.return_value.__enter__.return_value = mock_session
|
||||||
|
mock_conv = Conversation(id="t1", user_id=mock_user.id)
|
||||||
|
mock_session.get.return_value = mock_conv
|
||||||
|
|
||||||
|
# Act
|
||||||
|
with client.stream("POST", "/chat/stream",
|
||||||
|
json={"message": "persistence test", "thread_id": "t1"},
|
||||||
|
headers=auth_header) as response:
|
||||||
|
assert response.status_code == 200
|
||||||
|
list(response.iter_lines()) # Consume stream
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
# 1. User message should be saved immediately
|
||||||
|
mock_hm.add_message.assert_any_call("t1", "user", "persistence test")
|
||||||
|
|
||||||
|
# 2. Assistant message should be saved at the end
|
||||||
|
mock_hm.add_message.assert_any_call("t1", "assistant", "Final answer", plots=[])
|
||||||
|
|
||||||
|
# 3. Summary should be updated
|
||||||
|
mock_hm.update_conversation_summary.assert_called_once_with("t1", "New summary")
|
||||||
Reference in New Issue
Block a user