import asyncio
import base64
import io
import json
import logging
from typing import Any, AsyncGenerator, List, Optional

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from langchain_core.runnables.config import RunnableConfig

from ea_chatbot.api.dependencies import get_current_user, history_manager
from ea_chatbot.api.schemas import ChatRequest
from ea_chatbot.api.utils import convert_to_json_compatible
from ea_chatbot.graph.checkpoint import get_checkpointer
from ea_chatbot.graph.workflow import app
from ea_chatbot.history.models import User as UserDB, Conversation

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/chat", tags=["agent"])

# Graph nodes whose output is treated as the assistant's reply.
_RESPONSE_NODES = ("summarizer", "researcher", "clarification")


def _content_or_str(value: Any) -> Any:
    """Return the ``content`` entry of the JSON-compatible form of *value*.

    Falls back to ``str()`` of the converted value when the conversion does
    not produce a dict carrying a ``content`` key.
    """
    converted = convert_to_json_compatible(value)
    if isinstance(converted, dict) and "content" in converted:
        return converted["content"]
    return str(converted)


def _render_plots_as_png(figures: Any) -> List[bytes]:
    """Render each figure to PNG bytes via its ``savefig`` method."""
    rendered: List[bytes] = []
    # A missing/None plot list is treated as "no plots" rather than an error.
    for fig in figures or []:
        # Context manager so the in-memory buffer is released promptly.
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png")
            rendered.append(buf.getvalue())
    return rendered


async def stream_agent_events(
    message: str,
    thread_id: str,
    user_id: str,
    summary: str
) -> AsyncGenerator[str, None]:
    """
    Generator that invokes the LangGraph agent and yields SSE formatted events.
    Persists assistant responses and plots to the database.

    Args:
        message: The user's question; stored under ``question`` in the
            initial graph state.
        thread_id: Conversation id; used as the graph ``thread_id`` and as
            the key for every ``history_manager`` write.
        user_id: Accepted for interface compatibility; not read in this
            function.
        summary: Existing conversation summary seeded into the graph state.

    Yields:
        ``data: <json>\\n\\n`` strings: one per graph event, followed by a
        ``{"type": "done"}`` event on success or a single
        ``{"type": "error", ...}`` event if anything raises.
    """
    initial_state = {
        "messages": [],
        "question": message,
        "summary": summary,
        "analysis": None,
        "next_action": "",
        "plan": None,
        "code": None,
        "code_output": None,
        "error": None,
        "plots": [],
        "dfs": {}
    }

    config: RunnableConfig = {"configurable": {"thread_id": thread_id}}

    assistant_chunks: List[str] = []
    assistant_plots: List[bytes] = []
    final_response: Any = ""
    new_summary: str = ""

    try:
        async with get_checkpointer() as checkpointer:
            async for event in app.astream_events(
                initial_state,
                config,
                version="v2",
                checkpointer=checkpointer
            ):
                kind = event.get("event")
                name = event.get("name")
                # Fall back to the event name when no node metadata is present.
                node_name = event.get("metadata", {}).get("langgraph_node", name)
                data = event.get("data", {})

                # Standardize event for frontend
                output_event = {
                    "type": kind,
                    "name": name,
                    "node": node_name,
                    "data": data
                }

                # Buffer assistant chunks (summarizer and researcher might stream)
                if kind == "on_chat_model_stream" and node_name in _RESPONSE_NODES:
                    chunk = data.get("chunk", "")
                    assistant_chunks.append(str(_content_or_str(chunk)))

                # Buffer and encode plots
                if kind == "on_chain_end" and name == "executor":
                    output = data.get("output", {})
                    if isinstance(output, dict) and "plots" in output:
                        png_blobs = _render_plots_as_png(output["plots"])
                        assistant_plots.extend(png_blobs)
                        output_event["data"]["encoded_plots"] = [
                            base64.b64encode(blob).decode('utf-8')
                            for blob in png_blobs
                        ]

                # Collect final response from terminal nodes
                if kind == "on_chain_end" and name in _RESPONSE_NODES:
                    output = data.get("output", {})
                    if isinstance(output, dict) and "messages" in output:
                        messages = output["messages"]
                        # An empty message list must not abort the stream;
                        # the buffered chunks below serve as the fallback.
                        if messages:
                            final_response = _content_or_str(messages[-1])

                # Collect new summary
                if kind == "on_chain_end" and name == "summarize_conversation":
                    output = data.get("output", {})
                    if isinstance(output, dict) and "summary" in output:
                        new_summary = output["summary"]

                # Convert to JSON compatible format to avoid serialization errors
                compatible_output = convert_to_json_compatible(output_event)
                yield f"data: {json.dumps(compatible_output)}\n\n"

        # If we didn't get a final_response from node output, use buffered chunks
        if not final_response and assistant_chunks:
            final_response = "".join(assistant_chunks)

        # Save assistant message to DB
        if final_response:
            history_manager.add_message(
                thread_id, "assistant", final_response, plots=assistant_plots
            )

        # Update summary in DB
        if new_summary:
            history_manager.update_conversation_summary(thread_id, new_summary)

        yield "data: {\"type\": \"done\"}\n\n"

    except Exception as e:
        error_msg = f"Agent execution failed: {str(e)}"
        logger.exception("Agent execution failed for thread %s", thread_id)
        # Persisting the error is best-effort: a failing DB write must not
        # prevent the client from receiving the error event below.
        try:
            history_manager.add_message(thread_id, "assistant", error_msg)
        except Exception:
            logger.exception(
                "Could not persist error message for thread %s", thread_id
            )
        yield f"data: {json.dumps({'type': 'error', 'data': {'message': error_msg}})}\n\n"


@router.post("/stream")
async def chat_stream(
    request: ChatRequest,
    current_user: UserDB = Depends(get_current_user)
):
    """
    Stream agent execution events via SSE.

    Raises:
        HTTPException: 404 if the conversation does not exist, 403 if it
            belongs to a different user.
    """
    with history_manager.get_session() as session:
        conv = session.get(Conversation, request.thread_id)
        if not conv:
            raise HTTPException(status_code=404, detail="Conversation not found")
        if conv.user_id != current_user.id:
            raise HTTPException(status_code=403, detail="Not authorized to access this conversation")

    # Save user message immediately
    history_manager.add_message(request.thread_id, "user", request.message)

    return StreamingResponse(
        stream_agent_events(
            request.message,
            request.thread_id,
            current_user.id,
            request.summary or ""
        ),
        media_type="text/event-stream"
    )