feat(api): Update Chat Stream Protocol for Orchestrator Architecture

This commit is contained in:
Yunxiao Xu
2026-02-23 16:44:06 -08:00
parent 92c30d217e
commit 02d93120e0
2 changed files with 13 additions and 11 deletions

View File

@@ -56,7 +56,8 @@ async def stream_agent_events(
initial_state, initial_state,
config, config,
version="v2", version="v2",
checkpointer=checkpointer checkpointer=checkpointer,
subgraphs=True
): ):
kind = event.get("event") kind = event.get("event")
name = event.get("name") name = event.get("name")
@@ -71,8 +72,8 @@ async def stream_agent_events(
"data": data "data": data
} }
# Buffer assistant chunks (summarizer and researcher might stream) # Buffer assistant chunks (synthesizer and clarification might stream)
if kind == "on_chat_model_stream" and node_name in ["summarizer", "researcher", "clarification"]: if kind == "on_chat_model_stream" and node_name in ["synthesizer", "clarification"]:
chunk = data.get("chunk", "") chunk = data.get("chunk", "")
# Use utility to safely extract text content from the chunk # Use utility to safely extract text content from the chunk
chunk_data = convert_to_json_compatible(chunk) chunk_data = convert_to_json_compatible(chunk)
@@ -83,7 +84,7 @@ async def stream_agent_events(
assistant_chunks.append(str(chunk_data)) assistant_chunks.append(str(chunk_data))
# Buffer and encode plots # Buffer and encode plots
if kind == "on_chain_end" and name == "executor": if kind == "on_chain_end" and name == "data_analyst_worker":
output = data.get("output", {}) output = data.get("output", {})
if isinstance(output, dict) and "plots" in output: if isinstance(output, dict) and "plots" in output:
plots = output["plots"] plots = output["plots"]
@@ -95,7 +96,7 @@ async def stream_agent_events(
output_event["data"]["encoded_plots"] = encoded_plots output_event["data"]["encoded_plots"] = encoded_plots
# Collect final response from terminal nodes # Collect final response from terminal nodes
if kind == "on_chain_end" and name in ["summarizer", "researcher", "clarification"]: if kind == "on_chain_end" and name in ["synthesizer", "clarification"]:
output = data.get("output", {}) output = data.get("output", {})
if isinstance(output, dict) and "messages" in output: if isinstance(output, dict) and "messages" in output:
last_msg = output["messages"][-1] last_msg = output["messages"][-1]

View File

@@ -36,24 +36,25 @@ async def test_stream_agent_events_all_features():
# Stream chunk # Stream chunk
{ {
"event": "on_chat_model_stream", "event": "on_chat_model_stream",
"metadata": {"langgraph_node": "summarizer"}, "metadata": {"langgraph_node": "synthesizer"},
"data": {"chunk": AIMessage(content="Hello ")} "data": {"chunk": AIMessage(content="Hello ")}
}, },
{ {
"event": "on_chat_model_stream", "event": "on_chat_model_stream",
"metadata": {"langgraph_node": "summarizer"}, "metadata": {"langgraph_node": "synthesizer"},
"data": {"chunk": AIMessage(content="world")} "data": {"chunk": AIMessage(content="world")}
}, },
# Plot event # Plot event - with nested subgraph it might bubble up or come directly from data_analyst_worker
# Let's mock it coming from the data_analyst_worker on_chain_end
{ {
"event": "on_chain_end", "event": "on_chain_end",
"name": "executor", "name": "data_analyst_worker",
"data": {"output": {"plots": [fig]}} "data": {"output": {"plots": [fig]}}
}, },
# Final response # Final response
{ {
"event": "on_chain_end", "event": "on_chain_end",
"name": "summarizer", "name": "synthesizer",
"data": {"output": {"messages": [AIMessage(content="Hello world final")]}} "data": {"output": {"messages": [AIMessage(content="Hello world final")]}}
}, },
# Summary update # Summary update
@@ -91,7 +92,7 @@ async def test_stream_agent_events_all_features():
assert any(r.get("type") == "on_chat_model_stream" for r in results) assert any(r.get("type") == "on_chat_model_stream" for r in results)
# Verify plot was encoded # Verify plot was encoded
plot_event = next(r for r in results if r.get("name") == "executor") plot_event = next(r for r in results if r.get("name") == "data_analyst_worker")
assert "encoded_plots" in plot_event["data"] assert "encoded_plots" in plot_event["data"]
assert len(plot_event["data"]["encoded_plots"]) == 1 assert len(plot_event["data"]["encoded_plots"]) == 1