diff --git a/src/ea_chatbot/api/routers/agent.py b/src/ea_chatbot/api/routers/agent.py
index 24038f8..df5b544 100644
--- a/src/ea_chatbot/api/routers/agent.py
+++ b/src/ea_chatbot/api/routers/agent.py
@@ -1,41 +1,20 @@
 import json
 import asyncio
-from typing import AsyncGenerator, Optional
+from typing import AsyncGenerator, Optional, List
 from fastapi import APIRouter, Depends, HTTPException, status
 from fastapi.responses import StreamingResponse
-from pydantic import BaseModel
 from ea_chatbot.api.dependencies import get_current_user, history_manager
+from ea_chatbot.api.utils import convert_to_json_compatible
 from ea_chatbot.graph.workflow import app
 from ea_chatbot.graph.checkpoint import get_checkpointer
 from ea_chatbot.history.models import User as UserDB, Conversation
+from ea_chatbot.api.schemas import ChatRequest
 import io
 import base64
 from langchain_core.messages import BaseMessage
 
 router = APIRouter(prefix="/chat", tags=["agent"])
 
-def convert_to_json_compatible(obj):
-    """Recursively convert LangChain objects to JSON compatible formats."""
-    if isinstance(obj, list):
-        return [convert_to_json_compatible(item) for item in obj]
-    elif isinstance(obj, dict):
-        return {k: convert_to_json_compatible(v) for k, v in obj.items()}
-    elif isinstance(obj, BaseMessage):
-        return {"type": obj.type, "content": obj.content, **convert_to_json_compatible(obj.additional_kwargs)}
-    elif isinstance(obj, BaseModel) or hasattr(obj, "model_dump"):
-        try:
-            return convert_to_json_compatible(obj.model_dump())
-        except Exception:
-            return str(obj)
-    elif hasattr(obj, "content"):
-        return str(obj.content)
-    return obj
-
-class ChatRequest(BaseModel):
-    message: str
-    thread_id: str  # This maps to the conversation_id in our DB
-    summary: Optional[str] = ""
-
 async def stream_agent_events(
     message: str,
     thread_id: str,
@@ -44,6 +23,7 @@ async def stream_agent_events(
 ) -> AsyncGenerator[str, None]:
     """
     Generator that invokes the LangGraph agent and yields SSE formatted events.
+    Persists assistant responses and plots to the database.
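+    Each SSE frame is a single line of the form "data: <json>"; the stream
+    ends with a {"type": "done"} frame, or {"type": "error", ...} on failure.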
""" initial_state = { "messages": [], @@ -61,38 +41,91 @@ async def stream_agent_events( config = {"configurable": {"thread_id": thread_id}} - async with get_checkpointer() as checkpointer: - async for event in app.astream_events( - initial_state, - config, - version="v2", - checkpointer=checkpointer - ): - kind = event.get("event") - name = event.get("name") - - output_event = { - "type": kind, - "name": name, - "data": event.get("data", {}) - } - - if kind == "on_chain_end" and name == "executor": - output = event.get("data", {}).get("output", {}) - if isinstance(output, dict) and "plots" in output: - plots = output["plots"] - encoded_plots = [] - for fig in plots: - buf = io.BytesIO() - fig.savefig(buf, format="png") - encoded_plots.append(base64.b64encode(buf.getvalue()).decode('utf-8')) - output_event["data"]["encoded_plots"] = encoded_plots + assistant_chunks: List[str] = [] + assistant_plots: List[bytes] = [] + final_response: str = "" + new_summary: str = "" + + try: + async with get_checkpointer() as checkpointer: + async for event in app.astream_events( + initial_state, + config, + version="v2", + checkpointer=checkpointer + ): + kind = event.get("event") + name = event.get("name") + data = event.get("data", {}) + + # Standardize event for frontend + output_event = { + "type": kind, + "name": name, + "data": data + } + + # Buffer assistant chunks (summarizer and researcher might stream) + if kind == "on_chat_model_stream" and name in ["summarizer", "researcher"]: + chunk = data.get("chunk", "") + if hasattr(chunk, "content"): + chunk = chunk.content + assistant_chunks.append(str(chunk)) + + # Buffer and encode plots + if kind == "on_chain_end" and name == "executor": + output = data.get("output", {}) + if isinstance(output, dict) and "plots" in output: + plots = output["plots"] + encoded_plots = [] + for fig in plots: + buf = io.BytesIO() + fig.savefig(buf, format="png") + plot_bytes = buf.getvalue() + assistant_plots.append(plot_bytes) + encoded_plots.append(base64.b64encode(plot_bytes).decode('utf-8')) + output_event["data"]["encoded_plots"] = encoded_plots - # Convert to JSON compatible format to avoid serialization errors - compatible_output = convert_to_json_compatible(output_event) - yield f"data: {json.dumps(compatible_output)}\n\n" + # Collect final response from terminal nodes + if kind == "on_chain_end" and name in ["summarizer", "researcher", "clarification"]: + output = data.get("output", {}) + if isinstance(output, dict) and "messages" in output: + last_msg = output["messages"][-1] + if hasattr(last_msg, "content"): + final_response = last_msg.content + elif isinstance(last_msg, dict) and "content" in last_msg: + final_response = last_msg["content"] + else: + final_response = str(last_msg) + + # Collect new summary + if kind == "on_chain_end" and name == "summarize_conversation": + output = data.get("output", {}) + if isinstance(output, dict) and "summary" in output: + new_summary = output["summary"] + + # Convert to JSON compatible format to avoid serialization errors + compatible_output = convert_to_json_compatible(output_event) + yield f"data: {json.dumps(compatible_output)}\n\n" + + # If we didn't get a final_response from node output, use buffered chunks + if not final_response and assistant_chunks: + final_response = "".join(assistant_chunks) + + # Save assistant message to DB + if final_response: + history_manager.add_message(thread_id, "assistant", final_response, plots=assistant_plots) + + # Update summary in DB + if new_summary: + 
+            history_manager.update_conversation_summary(thread_id, new_summary)
 
         yield "data: {\"type\": \"done\"}\n\n"
+
+    except Exception as e:
+        error_msg = f"Agent execution failed: {str(e)}"
+        history_manager.add_message(thread_id, "assistant", error_msg)
+        yield f"data: {json.dumps({'type': 'error', 'message': error_msg})}\n\n"
 
 @router.post("/stream")
 async def chat_stream(
@@ -104,9 +137,14 @@
     """
     with history_manager.get_session() as session:
         conv = session.get(Conversation, request.thread_id)
-        if conv and conv.user_id != current_user.id:
+        if not conv:
+            raise HTTPException(status_code=404, detail="Conversation not found")
+        if conv.user_id != current_user.id:
             raise HTTPException(status_code=403, detail="Not authorized to access this conversation")
 
+    # Save user message immediately
+    history_manager.add_message(request.thread_id, "user", request.message)
+
     return StreamingResponse(
         stream_agent_events(
             request.message,
diff --git a/src/ea_chatbot/api/routers/auth.py b/src/ea_chatbot/api/routers/auth.py
index ad4eb19..28d4953 100644
--- a/src/ea_chatbot/api/routers/auth.py
+++ b/src/ea_chatbot/api/routers/auth.py
@@ -1,27 +1,12 @@
 from fastapi import APIRouter, Depends, HTTPException, status
 from fastapi.security import OAuth2PasswordRequestForm
-from pydantic import BaseModel, EmailStr
-from typing import Optional
 from ea_chatbot.api.utils import create_access_token
 from ea_chatbot.api.dependencies import history_manager, oidc_client, get_current_user
 from ea_chatbot.history.models import User as UserDB
+from ea_chatbot.api.schemas import Token, UserCreate, UserResponse
 
 router = APIRouter(prefix="/auth", tags=["auth"])
 
-class Token(BaseModel):
-    access_token: str
-    token_type: str
-
-class UserCreate(BaseModel):
-    email: EmailStr
-    password: str
-    display_name: Optional[str] = None
-
-class UserResponse(BaseModel):
-    id: str
-    email: str
-    display_name: str
-
 @router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
 async def register(user_in: UserCreate):
     """Register a new user."""
@@ -98,4 +83,4 @@ async def get_me(current_user: UserDB = Depends(get_current_user)):
         "id": str(current_user.id),
         "email": current_user.username,
         "display_name": current_user.display_name
-    }
\ No newline at end of file
+    }
diff --git a/src/ea_chatbot/api/routers/history.py b/src/ea_chatbot/api/routers/history.py
index cc0741a..46a2e93 100644
--- a/src/ea_chatbot/api/routers/history.py
+++ b/src/ea_chatbot/api/routers/history.py
@@ -1,30 +1,11 @@
 from fastapi import APIRouter, Depends, HTTPException, status, Response
-from pydantic import BaseModel
 from typing import List, Optional
-from datetime import datetime
 from ea_chatbot.api.dependencies import get_current_user, history_manager, settings
 from ea_chatbot.history.models import User as UserDB
+from ea_chatbot.api.schemas import ConversationResponse, MessageResponse, ConversationUpdate
 
 router = APIRouter(prefix="/conversations", tags=["history"])
 
-class ConversationResponse(BaseModel):
-    id: str
-    name: str
-    summary: Optional[str] = None
-    created_at: datetime
-    data_state: str
-
-class MessageResponse(BaseModel):
-    id: str
-    role: str
-    content: str
-    created_at: datetime
-    # We don't include plots directly here, they'll be fetched via artifact endpoints
-
-class ConversationUpdate(BaseModel):
-    name: Optional[str] = None
-    summary: Optional[str] = None
-
 @router.get("", response_model=List[ConversationResponse])
 async def list_conversations(
     current_user: UserDB = Depends(get_current_user),
@@ -94,4 +75,4 @@ async def delete_conversation(
     success = history_manager.delete_conversation(conversation_id)
     if not success:
         raise HTTPException(status_code=404, detail="Conversation not found")
-    return Response(status_code=status.HTTP_204_NO_CONTENT)
+    return Response(status_code=status.HTTP_204_NO_CONTENT)
\ No newline at end of file
diff --git a/src/ea_chatbot/api/schemas.py b/src/ea_chatbot/api/schemas.py
new file mode 100644
index 0000000..3d7cb3f
--- /dev/null
+++ b/src/ea_chatbot/api/schemas.py
@@ -0,0 +1,46 @@
+from pydantic import BaseModel, EmailStr
+from typing import List, Optional
+from datetime import datetime
+
+# --- Auth Schemas ---
+
+class Token(BaseModel):
+    access_token: str
+    token_type: str
+
+class UserCreate(BaseModel):
+    email: EmailStr
+    password: str
+    display_name: Optional[str] = None
+
+class UserResponse(BaseModel):
+    id: str
+    email: str
+    display_name: str
+
+# --- History Schemas ---
+
+class ConversationResponse(BaseModel):
+    id: str
+    name: str
+    summary: Optional[str] = None
+    created_at: datetime
+    data_state: str
+
+class MessageResponse(BaseModel):
+    id: str
+    role: str
+    content: str
+    created_at: datetime
+    # Plots are fetched separately via artifact endpoints
+
+class ConversationUpdate(BaseModel):
+    name: Optional[str] = None
+    summary: Optional[str] = None
+
+# --- Agent Schemas ---
+
+class ChatRequest(BaseModel):
+    message: str
+    thread_id: str  # Maps to conversation_id
+    summary: Optional[str] = ""
diff --git a/src/ea_chatbot/api/utils.py b/src/ea_chatbot/api/utils.py
index d30ba9c..b7a176e 100644
--- a/src/ea_chatbot/api/utils.py
+++ b/src/ea_chatbot/api/utils.py
@@ -1,6 +1,8 @@
 from datetime import datetime, timedelta, timezone
-from typing import Optional, Union
+from typing import Optional, Union, Any, List, Dict
 from jose import JWTError, jwt
+from pydantic import BaseModel
+from langchain_core.messages import BaseMessage
 from ea_chatbot.config import Settings
 
 settings = Settings()
@@ -41,3 +43,29 @@ def decode_access_token(token: str) -> Optional[dict]:
         return payload
     except JWTError:
         return None
+
+def convert_to_json_compatible(obj: Any) -> Any:
+    """Recursively convert LangChain objects, Pydantic models, and others to JSON compatible formats."""
+    if isinstance(obj, list):
+        return [convert_to_json_compatible(item) for item in obj]
+    elif isinstance(obj, dict):
+        return {k: convert_to_json_compatible(v) for k, v in obj.items()}
+    elif isinstance(obj, BaseMessage):
+        return {"type": obj.type, "content": obj.content, **convert_to_json_compatible(obj.additional_kwargs)}
+    elif isinstance(obj, BaseModel):
+        return convert_to_json_compatible(obj.model_dump())
+    elif hasattr(obj, "model_dump"):  # For Pydantic v2 if not caught by BaseModel
+        try:
+            return convert_to_json_compatible(obj.model_dump())
+        except Exception:
+            return str(obj)
+    elif hasattr(obj, "dict"):  # Fallback for Pydantic v1 or other objects
+        try:
+            return convert_to_json_compatible(obj.dict())
+        except Exception:
+            return str(obj)
+    elif hasattr(obj, "content"):
+        return str(obj.content)
+    elif isinstance(obj, datetime):  # only datetime has isoformat(); a bare timezone does not
+        return obj.isoformat()
+    return obj
\ No newline at end of file
diff --git a/tests/api/test_agent.py b/tests/api/test_agent.py
index c9d3b63..ae4f21b 100644
--- a/tests/api/test_agent.py
+++ b/tests/api/test_agent.py
@@ -25,7 +25,7 @@ def test_stream_agent_unauthorized():
     assert response.status_code == 401
 
 @pytest.mark.asyncio
-async def test_stream_agent_success(auth_header):
+async def test_stream_agent_success(auth_header, mock_user):
     """Test successful agent streaming with SSE."""
     # We need to mock the LangGraph app.astream_events
     mock_events = [
@@ -39,10 +39,18 @@ async def test_stream_agent_success(auth_header):
         yield event
 
     with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events), \
-         patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp:
+         patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
+         patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
         mock_cp.return_value.__aenter__.return_value = AsyncMock()
 
+        # Mock session and DB objects
+        mock_session = MagicMock()
+        mock_hm.get_session.return_value.__enter__.return_value = mock_session
+        from ea_chatbot.history.models import Conversation
+        mock_conv = Conversation(id="t1", user_id=mock_user.id)
+        mock_session.get.return_value = mock_conv
+
         # Using TestClient with a stream context
         with client.stream("POST", "/chat/stream",
                            json={"message": "hello", "thread_id": "t1"},
diff --git a/tests/api/test_persistence.py b/tests/api/test_persistence.py
new file mode 100644
index 0000000..fac2640
--- /dev/null
+++ b/tests/api/test_persistence.py
@@ -0,0 +1,63 @@
+import pytest
+from fastapi.testclient import TestClient
+from unittest.mock import MagicMock, patch, AsyncMock
+from ea_chatbot.api.main import app
+from ea_chatbot.api.dependencies import get_current_user
+from ea_chatbot.history.models import User, Conversation, Message, Plot
+from ea_chatbot.api.utils import create_access_token
+from datetime import datetime, timezone
+import json
+
+client = TestClient(app)
+
+@pytest.fixture
+def mock_user():
+    return User(id="user-123", username="test@example.com", display_name="Test User")
+
+@pytest.fixture
+def auth_header(mock_user):
+    app.dependency_overrides[get_current_user] = lambda: mock_user
+    token = create_access_token(data={"sub": mock_user.username, "user_id": mock_user.id})
+    yield {"Authorization": f"Bearer {token}"}
+    app.dependency_overrides.clear()
+
+def test_persistence_integration_success(auth_header, mock_user):
+    """Test that messages and plots are persisted correctly during streaming."""
+    mock_events = [
+        {"event": "on_chat_model_stream", "name": "summarizer", "data": {"chunk": "Final answer"}},
+        {"event": "on_chain_end", "name": "summarizer", "data": {"output": {"messages": [{"content": "Final answer"}]}}},
+        {"event": "on_chain_end", "name": "summarize_conversation", "data": {"output": {"summary": "New summary"}}}
+    ]
+
+    async def mock_astream_events(*args, **kwargs):
+        for event in mock_events:
+            yield event
+
+    with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events), \
+         patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
+         patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
+
+        mock_cp.return_value.__aenter__.return_value = AsyncMock()
+
+        # Mock session and DB objects
+        mock_session = MagicMock()
+        mock_hm.get_session.return_value.__enter__.return_value = mock_session
+        mock_conv = Conversation(id="t1", user_id=mock_user.id)
+        mock_session.get.return_value = mock_conv
+
+        # Act
+        with client.stream("POST", "/chat/stream",
+                           json={"message": "persistence test", "thread_id": "t1"},
+                           headers=auth_header) as response:
+            assert response.status_code == 200
+            list(response.iter_lines())  # Consume stream
+
+        # Assertions
+        # 1. User message should be saved immediately
+        mock_hm.add_message.assert_any_call("t1", "user", "persistence test")
+
+        # 2. Assistant message should be saved at the end
+        mock_hm.add_message.assert_any_call("t1", "assistant", "Final answer", plots=[])
+
+        # 3. Summary should be updated
+        mock_hm.update_conversation_summary.assert_called_once_with("t1", "New summary")
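
For reference, a minimal client-side sketch of consuming the /chat/stream endpoint added above. The base URL, token handling, and printing logic are illustrative assumptions, not part of this change; it only handles the frame shapes that stream_agent_events actually emits (regular events, {"type": "done"}, and {"type": "error", ...}).

import json
import httpx  # assumed HTTP client; any SSE-capable client works

def stream_chat(message: str, thread_id: str, token: str) -> None:
    """Consume the /chat/stream SSE endpoint and react to each frame."""
    headers = {"Authorization": f"Bearer {token}"}
    payload = {"message": message, "thread_id": thread_id}  # matches ChatRequest
    # base_url is a placeholder for wherever the API is deployed
    with httpx.Client(base_url="http://localhost:8000", timeout=None) as client:
        with client.stream("POST", "/chat/stream", json=payload, headers=headers) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if not line.startswith("data: "):
                    continue  # this API emits single-line "data:" frames
                event = json.loads(line[len("data: "):])
                if event.get("type") == "error":
                    print("agent error:", event.get("message"))
                    break
                if event.get("type") == "done":
                    break
                print(event.get("type"), event.get("name"))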