feat(chat): Implement real-time SSE streaming with reasoning steps and improved UI indicators.
This commit is contained in:
@@ -11,7 +11,7 @@ from ea_chatbot.history.models import User as UserDB, Conversation
|
||||
from ea_chatbot.api.schemas import ChatRequest
|
||||
import io
|
||||
import base64
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["agent"])
|
||||
|
||||
@@ -39,7 +39,7 @@ async def stream_agent_events(
|
||||
"dfs": {}
|
||||
}
|
||||
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
config: RunnableConfig = {"configurable": {"thread_id": thread_id}}
|
||||
|
||||
assistant_chunks: List[str] = []
|
||||
assistant_plots: List[bytes] = []
|
||||
@@ -56,23 +56,26 @@ async def stream_agent_events(
|
||||
):
|
||||
kind = event.get("event")
|
||||
name = event.get("name")
|
||||
node_name = event.get("metadata", {}).get("langgraph_node", name)
|
||||
data = event.get("data", {})
|
||||
|
||||
# Standardize event for frontend
|
||||
output_event = {
|
||||
"type": kind,
|
||||
"name": name,
|
||||
"node": node_name,
|
||||
"data": data
|
||||
}
|
||||
|
||||
# Buffer assistant chunks (summarizer and researcher might stream)
|
||||
if kind == "on_chat_model_stream" and name in ["summarizer", "researcher"]:
|
||||
if kind == "on_chat_model_stream" and node_name in ["summarizer", "researcher", "clarification"]:
|
||||
chunk = data.get("chunk", "")
|
||||
# Use utility to safely extract text content from the chunk
|
||||
chunk_data = convert_to_json_compatible(chunk)
|
||||
if isinstance(chunk_data, dict) and "content" in chunk_data:
|
||||
assistant_chunks.append(str(chunk_data["content"]))
|
||||
else:
|
||||
# TODO: need better way to handle this
|
||||
assistant_chunks.append(str(chunk_data))
|
||||
|
||||
# Buffer and encode plots
|
||||
@@ -80,7 +83,7 @@ async def stream_agent_events(
|
||||
output = data.get("output", {})
|
||||
if isinstance(output, dict) and "plots" in output:
|
||||
plots = output["plots"]
|
||||
encoded_plots = []
|
||||
encoded_plots: list[str] = []
|
||||
for fig in plots:
|
||||
buf = io.BytesIO()
|
||||
fig.savefig(buf, format="png")
|
||||
@@ -131,7 +134,7 @@ async def stream_agent_events(
|
||||
except Exception as e:
|
||||
error_msg = f"Agent execution failed: {str(e)}"
|
||||
history_manager.add_message(thread_id, "assistant", error_msg)
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': error_msg})}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'error', 'data': {'message': error_msg}})}\n\n"
|
||||
|
||||
@router.post("/stream")
|
||||
async def chat_stream(
|
||||
|
||||
Reference in New Issue
Block a user