From b8fa60962e3b7f48973876752a81f69598cef293 Mon Sep 17 00:00:00 2001 From: Yunxiao Xu Date: Tue, 10 Feb 2026 15:49:15 -0800 Subject: [PATCH] feat(api): Implement agent streaming router with SSE --- src/ea_chatbot/api/main.py | 3 +- src/ea_chatbot/api/routers/agent.py | 98 +++++++++++++++++++++++++++++ tests/api/test_agent.py | 2 +- 3 files changed, 101 insertions(+), 2 deletions(-) create mode 100644 src/ea_chatbot/api/routers/agent.py diff --git a/src/ea_chatbot/api/main.py b/src/ea_chatbot/api/main.py index a9cb837..7a8a838 100644 --- a/src/ea_chatbot/api/main.py +++ b/src/ea_chatbot/api/main.py @@ -1,6 +1,6 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from ea_chatbot.api.routers import auth, history, artifacts +from ea_chatbot.api.routers import auth, history, artifacts, agent app = FastAPI( title="Election Analytics Chatbot API", @@ -20,6 +20,7 @@ app.add_middleware( app.include_router(auth.router) app.include_router(history.router) app.include_router(artifacts.router) +app.include_router(agent.router) @app.get("/health") async def health_check(): diff --git a/src/ea_chatbot/api/routers/agent.py b/src/ea_chatbot/api/routers/agent.py new file mode 100644 index 0000000..d2ecf13 --- /dev/null +++ b/src/ea_chatbot/api/routers/agent.py @@ -0,0 +1,98 @@ +import json +import asyncio +from typing import AsyncGenerator, Optional +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.responses import StreamingResponse +from pydantic import BaseModel +from ea_chatbot.api.dependencies import get_current_user, history_manager +from ea_chatbot.graph.workflow import app +from ea_chatbot.graph.checkpoint import get_checkpointer +from ea_chatbot.history.models import User as UserDB, Conversation +import io +import base64 + +router = APIRouter(prefix="/chat", tags=["agent"]) + +class ChatRequest(BaseModel): + message: str + thread_id: str # This maps to the conversation_id in our DB + summary: Optional[str] = "" + +async def stream_agent_events( + message: str, + thread_id: str, + user_id: str, + summary: str +) -> AsyncGenerator[str, None]: + """ + Generator that invokes the LangGraph agent and yields SSE formatted events. + """ + initial_state = { + "messages": [], + "question": message, + "summary": summary, + "analysis": None, + "next_action": "", + "plan": None, + "code": None, + "code_output": None, + "error": None, + "plots": [], + "dfs": {} + } + + config = {"configurable": {"thread_id": thread_id}} + + async with get_checkpointer() as checkpointer: + async for event in app.astream_events( + initial_state, + config, + version="v2", + checkpointer=checkpointer + ): + kind = event.get("event") + name = event.get("name") + + output_event = { + "type": kind, + "name": name, + "data": event.get("data", {}) + } + + if kind == "on_chain_end" and name == "executor": + output = event.get("data", {}).get("output", {}) + if isinstance(output, dict) and "plots" in output: + plots = output["plots"] + encoded_plots = [] + for fig in plots: + buf = io.BytesIO() + fig.savefig(buf, format="png") + encoded_plots.append(base64.b64encode(buf.getvalue()).decode('utf-8')) + output_event["data"]["encoded_plots"] = encoded_plots + + yield f"data: {json.dumps(output_event)}\n\n" + + yield "data: {\"type\": \"done\"}\n\n" + +@router.post("/stream") +async def chat_stream( + request: ChatRequest, + current_user: UserDB = Depends(get_current_user) +): + """ + Stream agent execution events via SSE. + """ + with history_manager.get_session() as session: + conv = session.get(Conversation, request.thread_id) + if conv and conv.user_id != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized to access this conversation") + + return StreamingResponse( + stream_agent_events( + request.message, + request.thread_id, + current_user.id, + request.summary or "" + ), + media_type="text/event-stream" + ) \ No newline at end of file diff --git a/tests/api/test_agent.py b/tests/api/test_agent.py index 8eb0ea7..c9d3b63 100644 --- a/tests/api/test_agent.py +++ b/tests/api/test_agent.py @@ -48,7 +48,7 @@ async def test_stream_agent_success(auth_header): json={"message": "hello", "thread_id": "t1"}, headers=auth_header) as response: assert response.status_code == 200 - assert response.headers["content-type"] == "text/event-stream" + assert "text/event-stream" in response.headers["content-type"] lines = list(response.iter_lines()) # Each event should start with 'data: ' and be valid JSON