feat(api): Implement agent streaming router with SSE

This commit is contained in:
Yunxiao Xu
2026-02-10 15:49:15 -08:00
parent 057278a1c5
commit b8fa60962e
3 changed files with 101 additions and 2 deletions

View File

@@ -1,6 +1,6 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from ea_chatbot.api.routers import auth, history, artifacts
from ea_chatbot.api.routers import auth, history, artifacts, agent
app = FastAPI(
title="Election Analytics Chatbot API",
@@ -20,6 +20,7 @@ app.add_middleware(
app.include_router(auth.router)
app.include_router(history.router)
app.include_router(artifacts.router)
app.include_router(agent.router)
@app.get("/health")
async def health_check():

View File

@@ -0,0 +1,98 @@
import json
import asyncio
from typing import AsyncGenerator, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from ea_chatbot.api.dependencies import get_current_user, history_manager
from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.checkpoint import get_checkpointer
from ea_chatbot.history.models import User as UserDB, Conversation
import io
import base64
router = APIRouter(prefix="/chat", tags=["agent"])
class ChatRequest(BaseModel):
message: str
thread_id: str # This maps to the conversation_id in our DB
summary: Optional[str] = ""
async def stream_agent_events(
message: str,
thread_id: str,
user_id: str,
summary: str
) -> AsyncGenerator[str, None]:
"""
Generator that invokes the LangGraph agent and yields SSE formatted events.
"""
initial_state = {
"messages": [],
"question": message,
"summary": summary,
"analysis": None,
"next_action": "",
"plan": None,
"code": None,
"code_output": None,
"error": None,
"plots": [],
"dfs": {}
}
config = {"configurable": {"thread_id": thread_id}}
async with get_checkpointer() as checkpointer:
async for event in app.astream_events(
initial_state,
config,
version="v2",
checkpointer=checkpointer
):
kind = event.get("event")
name = event.get("name")
output_event = {
"type": kind,
"name": name,
"data": event.get("data", {})
}
if kind == "on_chain_end" and name == "executor":
output = event.get("data", {}).get("output", {})
if isinstance(output, dict) and "plots" in output:
plots = output["plots"]
encoded_plots = []
for fig in plots:
buf = io.BytesIO()
fig.savefig(buf, format="png")
encoded_plots.append(base64.b64encode(buf.getvalue()).decode('utf-8'))
output_event["data"]["encoded_plots"] = encoded_plots
yield f"data: {json.dumps(output_event)}\n\n"
yield "data: {\"type\": \"done\"}\n\n"
@router.post("/stream")
async def chat_stream(
request: ChatRequest,
current_user: UserDB = Depends(get_current_user)
):
"""
Stream agent execution events via SSE.
"""
with history_manager.get_session() as session:
conv = session.get(Conversation, request.thread_id)
if conv and conv.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Not authorized to access this conversation")
return StreamingResponse(
stream_agent_events(
request.message,
request.thread_id,
current_user.id,
request.summary or ""
),
media_type="text/event-stream"
)

View File

@@ -48,7 +48,7 @@ async def test_stream_agent_success(auth_header):
json={"message": "hello", "thread_id": "t1"},
headers=auth_header) as response:
assert response.status_code == 200
assert response.headers["content-type"] == "text/event-stream"
assert "text/event-stream" in response.headers["content-type"]
lines = list(response.iter_lines())
# Each event should start with 'data: ' and be valid JSON