"""SSE chat endpoint: streams LangGraph agent events and persists results."""
# Standard library
import asyncio
import base64
import io
import json
from typing import AsyncGenerator, List, Optional

# Third-party
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from langchain_core.runnables.config import RunnableConfig

# Local application
from ea_chatbot.api.dependencies import get_current_user, history_manager
from ea_chatbot.api.schemas import ChatRequest
from ea_chatbot.api.utils import convert_to_json_compatible
from ea_chatbot.graph.checkpoint import get_checkpointer
from ea_chatbot.graph.workflow import app
from ea_chatbot.history.models import User as UserDB, Conversation

router = APIRouter(prefix="/chat", tags=["agent"])
|
async def stream_agent_events(
    message: str,
    thread_id: str,
    user_id: str,
    summary: str
) -> AsyncGenerator[str, None]:
    """
    Invoke the LangGraph agent and yield SSE-formatted event strings.

    Every raw graph event is forwarded to the client as a ``data: <json>``
    line; along the way, assistant text chunks, rendered plots, and the
    updated conversation summary are buffered so they can be persisted to
    the database when the run finishes.

    Args:
        message: The user's latest question, injected as ``question`` in
            the initial graph state.
        thread_id: Conversation id, used both as the LangGraph checkpoint
            thread and as the DB conversation key.
        user_id: Authenticated caller's id. Currently unused in this body;
            kept for interface stability with the route handler.
        summary: Prior conversation summary carried into the graph state.

    Yields:
        ``"data: {...}\\n\\n"`` strings suitable for a ``text/event-stream``
        response. Ends with ``{"type": "done"}`` on success, or a single
        ``{"type": "error", ...}`` event on failure.
    """
    initial_state = {
        "messages": [],
        "question": message,
        "summary": summary,
        "analysis": None,
        "next_action": "",
        "plan": None,
        "code": None,
        "code_output": None,
        "error": None,
        "plots": [],
        "dfs": {}
    }

    config: RunnableConfig = {"configurable": {"thread_id": thread_id}}

    assistant_chunks: List[str] = []   # streamed text fragments; joined as a fallback response
    assistant_plots: List[bytes] = []  # raw PNG bytes saved alongside the assistant message
    final_response: str = ""
    new_summary: str = ""

    try:
        async with get_checkpointer() as checkpointer:
            async for event in app.astream_events(
                initial_state,
                config,
                version="v2",
                checkpointer=checkpointer
            ):
                kind = event.get("event")
                name = event.get("name")
                # Prefer the graph-node name from metadata; fall back to the event name.
                node_name = event.get("metadata", {}).get("langgraph_node", name)
                data = event.get("data", {})

                # Standardize event shape for the frontend
                output_event = {
                    "type": kind,
                    "name": name,
                    "node": node_name,
                    "data": data
                }

                # Buffer assistant chunks (summarizer and researcher might stream)
                if kind == "on_chat_model_stream" and node_name in ["summarizer", "researcher", "clarification"]:
                    chunk = data.get("chunk", "")
                    # Use utility to safely extract text content from the chunk
                    chunk_data = convert_to_json_compatible(chunk)
                    if isinstance(chunk_data, dict) and "content" in chunk_data:
                        assistant_chunks.append(str(chunk_data["content"]))
                    else:
                        # TODO: need better way to handle this
                        assistant_chunks.append(str(chunk_data))

                # Buffer and base64-encode plots produced by the executor node
                if kind == "on_chain_end" and name == "executor":
                    output = data.get("output", {})
                    if isinstance(output, dict) and "plots" in output:
                        plots = output["plots"]
                        encoded_plots: list[str] = []
                        for fig in plots:
                            # NOTE(review): presumably matplotlib figures — they are
                            # serialized but never closed here, which may leak memory
                            # over long sessions; confirm the executor closes them.
                            buf = io.BytesIO()
                            fig.savefig(buf, format="png")
                            plot_bytes = buf.getvalue()
                            assistant_plots.append(plot_bytes)
                            encoded_plots.append(base64.b64encode(plot_bytes).decode('utf-8'))
                        # Attach the encodings so the frontend can render inline.
                        output_event["data"]["encoded_plots"] = encoded_plots

                # Collect final response from terminal nodes
                if kind == "on_chain_end" and name in ["summarizer", "researcher", "clarification"]:
                    output = data.get("output", {})
                    if isinstance(output, dict) and "messages" in output:
                        last_msg = output["messages"][-1]

                        # convert_to_json_compatible returns a dict for BaseMessage,
                        # so extract 'content' when present.
                        msg_data = convert_to_json_compatible(last_msg)
                        if isinstance(msg_data, dict) and "content" in msg_data:
                            final_response = msg_data["content"]
                        else:
                            final_response = str(msg_data)

                # Collect new summary
                if kind == "on_chain_end" and name == "summarize_conversation":
                    output = data.get("output", {})
                    if isinstance(output, dict) and "summary" in output:
                        new_summary = output["summary"]

                # Convert to JSON compatible format to avoid serialization errors
                compatible_output = convert_to_json_compatible(output_event)
                yield f"data: {json.dumps(compatible_output)}\n\n"

        # If we didn't get a final_response from node output, use buffered chunks
        if not final_response and assistant_chunks:
            final_response = "".join(assistant_chunks)

        # Save assistant message to DB
        if final_response:
            history_manager.add_message(thread_id, "assistant", final_response, plots=assistant_plots)

        # Update summary in DB
        if new_summary:
            history_manager.update_conversation_summary(thread_id, new_summary)

        yield "data: {\"type\": \"done\"}\n\n"

    except Exception as e:
        error_msg = f"Agent execution failed: {str(e)}"
        # Persisting the failure is best-effort: if the DB write itself raises
        # (e.g. the DB outage is what broke the run), the client must still
        # receive the error event below instead of a silently dead stream.
        try:
            history_manager.add_message(thread_id, "assistant", error_msg)
        except Exception:
            pass
        yield f"data: {json.dumps({'type': 'error', 'data': {'message': error_msg}})}\n\n"
|
@router.post("/stream")
async def chat_stream(
    request: ChatRequest,
    current_user: UserDB = Depends(get_current_user)
):
    """
    Stream agent execution events via SSE.

    Validates that the target conversation exists and belongs to the
    authenticated user, persists the incoming user message, then hands
    the connection over to the agent event generator.
    """
    # Authorization gate: 404 for a missing conversation, 403 for one
    # owned by a different user.
    with history_manager.get_session() as session:
        conversation = session.get(Conversation, request.thread_id)
        if not conversation:
            raise HTTPException(status_code=404, detail="Conversation not found")
        if conversation.user_id != current_user.id:
            raise HTTPException(status_code=403, detail="Not authorized to access this conversation")

    # Save user message immediately
    history_manager.add_message(request.thread_id, "user", request.message)

    event_stream = stream_agent_events(
        request.message,
        request.thread_id,
        current_user.id,
        request.summary or "",
    )
    return StreamingResponse(event_stream, media_type="text/event-stream")