feat(chat): Implement real-time SSE streaming with reasoning steps and improved UI indicators.
This commit is contained in:
22
backend/nj_voter_counts_by_county.csv
Normal file
22
backend/nj_voter_counts_by_county.csv
Normal file
@@ -0,0 +1,22 @@
|
||||
county,voters,pct_total,rank_by_size
|
||||
Bergen,637753,10.46,1
|
||||
Middlesex,531951,8.72,2
|
||||
Essex,499446,8.19,3
|
||||
Monmouth,472627,7.75,4
|
||||
Ocean,453981,7.45,5
|
||||
Hudson,374651,6.14,6
|
||||
Morris,368252,6.04,7
|
||||
Camden,359742,5.90,8
|
||||
Union,354205,5.81,9
|
||||
Burlington,340761,5.59,10
|
||||
Passaic,313061,5.13,11
|
||||
Somerset,241463,3.96,12
|
||||
Mercer,241236,3.96,13
|
||||
Gloucester,217083,3.56,14
|
||||
Atlantic,189627,3.11,15
|
||||
Sussex,110789,1.82,16
|
||||
Hunterdon,100606,1.65,17
|
||||
Cumberland,90934,1.49,18
|
||||
Warren,81642,1.34,19
|
||||
Cape May,72299,1.19,20
|
||||
Salem,45018,0.74,21
|
||||
|
@@ -11,7 +11,7 @@ from ea_chatbot.history.models import User as UserDB, Conversation
|
||||
from ea_chatbot.api.schemas import ChatRequest
|
||||
import io
|
||||
import base64
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["agent"])
|
||||
|
||||
@@ -39,7 +39,7 @@ async def stream_agent_events(
|
||||
"dfs": {}
|
||||
}
|
||||
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
config: RunnableConfig = {"configurable": {"thread_id": thread_id}}
|
||||
|
||||
assistant_chunks: List[str] = []
|
||||
assistant_plots: List[bytes] = []
|
||||
@@ -56,23 +56,26 @@ async def stream_agent_events(
|
||||
):
|
||||
kind = event.get("event")
|
||||
name = event.get("name")
|
||||
node_name = event.get("metadata", {}).get("langgraph_node", name)
|
||||
data = event.get("data", {})
|
||||
|
||||
# Standardize event for frontend
|
||||
output_event = {
|
||||
"type": kind,
|
||||
"name": name,
|
||||
"node": node_name,
|
||||
"data": data
|
||||
}
|
||||
|
||||
# Buffer assistant chunks (summarizer and researcher might stream)
|
||||
if kind == "on_chat_model_stream" and name in ["summarizer", "researcher"]:
|
||||
if kind == "on_chat_model_stream" and node_name in ["summarizer", "researcher", "clarification"]:
|
||||
chunk = data.get("chunk", "")
|
||||
# Use utility to safely extract text content from the chunk
|
||||
chunk_data = convert_to_json_compatible(chunk)
|
||||
if isinstance(chunk_data, dict) and "content" in chunk_data:
|
||||
assistant_chunks.append(str(chunk_data["content"]))
|
||||
else:
|
||||
# TODO: need better way to handle this
|
||||
assistant_chunks.append(str(chunk_data))
|
||||
|
||||
# Buffer and encode plots
|
||||
@@ -80,7 +83,7 @@ async def stream_agent_events(
|
||||
output = data.get("output", {})
|
||||
if isinstance(output, dict) and "plots" in output:
|
||||
plots = output["plots"]
|
||||
encoded_plots = []
|
||||
encoded_plots: list[str] = []
|
||||
for fig in plots:
|
||||
buf = io.BytesIO()
|
||||
fig.savefig(buf, format="png")
|
||||
@@ -131,7 +134,7 @@ async def stream_agent_events(
|
||||
except Exception as e:
|
||||
error_msg = f"Agent execution failed: {str(e)}"
|
||||
history_manager.add_message(thread_id, "assistant", error_msg)
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': error_msg})}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'error', 'data': {'message': error_msg}})}\n\n"
|
||||
|
||||
@router.post("/stream")
|
||||
async def chat_stream(
|
||||
|
||||
@@ -52,6 +52,11 @@ def decode_access_token(token: str) -> Optional[dict]:
|
||||
|
||||
def convert_to_json_compatible(obj: Any) -> Any:
|
||||
"""Recursively convert LangChain objects, Pydantic models, and others to JSON compatible formats."""
|
||||
# Handle known non-serializable types first to avoid recursion
|
||||
type_name = type(obj).__name__
|
||||
if type_name == "Figure" or type_name == "DataFrame":
|
||||
return f"<{type_name} object>"
|
||||
|
||||
if isinstance(obj, list):
|
||||
return [convert_to_json_compatible(item) for item in obj]
|
||||
elif isinstance(obj, dict):
|
||||
@@ -91,4 +96,11 @@ def convert_to_json_compatible(obj: Any) -> Any:
|
||||
return str(obj.content)
|
||||
elif isinstance(obj, (datetime, timezone)):
|
||||
return obj.isoformat()
|
||||
return obj
|
||||
|
||||
# Final fallback for any other types that might not be JSON serializable
|
||||
import json
|
||||
try:
|
||||
json.dumps(obj)
|
||||
return obj
|
||||
except (TypeError, OverflowError):
|
||||
return str(obj)
|
||||
Reference in New Issue
Block a user