fix(api): Robustly handle complex LangChain message content in API responses and persistence
This commit is contained in:
@@ -68,9 +68,12 @@ async def stream_agent_events(
|
|||||||
# Buffer assistant chunks (summarizer and researcher might stream)
|
# Buffer assistant chunks (summarizer and researcher might stream)
|
||||||
if kind == "on_chat_model_stream" and name in ["summarizer", "researcher"]:
|
if kind == "on_chat_model_stream" and name in ["summarizer", "researcher"]:
|
||||||
chunk = data.get("chunk", "")
|
chunk = data.get("chunk", "")
|
||||||
if hasattr(chunk, "content"):
|
# Use utility to safely extract text content from the chunk
|
||||||
chunk = chunk.content
|
chunk_data = convert_to_json_compatible(chunk)
|
||||||
assistant_chunks.append(str(chunk))
|
if isinstance(chunk_data, dict) and "content" in chunk_data:
|
||||||
|
assistant_chunks.append(str(chunk_data["content"]))
|
||||||
|
else:
|
||||||
|
assistant_chunks.append(str(chunk_data))
|
||||||
|
|
||||||
# Buffer and encode plots
|
# Buffer and encode plots
|
||||||
if kind == "on_chain_end" and name == "executor":
|
if kind == "on_chain_end" and name == "executor":
|
||||||
@@ -91,12 +94,15 @@ async def stream_agent_events(
|
|||||||
output = data.get("output", {})
|
output = data.get("output", {})
|
||||||
if isinstance(output, dict) and "messages" in output:
|
if isinstance(output, dict) and "messages" in output:
|
||||||
last_msg = output["messages"][-1]
|
last_msg = output["messages"][-1]
|
||||||
if hasattr(last_msg, "content"):
|
|
||||||
final_response = last_msg.content
|
# Use centralized utility to extract clean text content
|
||||||
elif isinstance(last_msg, dict) and "content" in last_msg:
|
# Since convert_to_json_compatible returns a dict for BaseMessage,
|
||||||
final_response = last_msg["content"]
|
# we can extract 'content' from it.
|
||||||
|
msg_data = convert_to_json_compatible(last_msg)
|
||||||
|
if isinstance(msg_data, dict) and "content" in msg_data:
|
||||||
|
final_response = msg_data["content"]
|
||||||
else:
|
else:
|
||||||
final_response = str(last_msg)
|
final_response = str(msg_data)
|
||||||
|
|
||||||
# Collect new summary
|
# Collect new summary
|
||||||
if kind == "on_chain_end" and name == "summarize_conversation":
|
if kind == "on_chain_end" and name == "summarize_conversation":
|
||||||
|
|||||||
@@ -51,7 +51,24 @@ def convert_to_json_compatible(obj: Any) -> Any:
|
|||||||
elif isinstance(obj, dict):
|
elif isinstance(obj, dict):
|
||||||
return {k: convert_to_json_compatible(v) for k, v in obj.items()}
|
return {k: convert_to_json_compatible(v) for k, v in obj.items()}
|
||||||
elif isinstance(obj, BaseMessage):
|
elif isinstance(obj, BaseMessage):
|
||||||
return {"type": obj.type, "content": obj.content, **convert_to_json_compatible(obj.additional_kwargs)}
|
# Handle content that might be a list of blocks (e.g. from Gemini or OpenAI tools)
|
||||||
|
content = obj.content
|
||||||
|
if isinstance(content, list):
|
||||||
|
text_parts = []
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, str):
|
||||||
|
text_parts.append(block)
|
||||||
|
elif isinstance(block, dict):
|
||||||
|
if block.get("type") == "text":
|
||||||
|
text_parts.append(block.get("text", ""))
|
||||||
|
# You could also handle other block types if needed
|
||||||
|
content = "".join(text_parts)
|
||||||
|
|
||||||
|
# Prefer .text property if available (common in some message types)
|
||||||
|
if hasattr(obj, "text") and isinstance(obj.text, str) and obj.text:
|
||||||
|
content = obj.text
|
||||||
|
|
||||||
|
return {"type": obj.type, "content": content, **convert_to_json_compatible(obj.additional_kwargs)}
|
||||||
elif isinstance(obj, BaseModel):
|
elif isinstance(obj, BaseModel):
|
||||||
return convert_to_json_compatible(obj.model_dump())
|
return convert_to_json_compatible(obj.model_dump())
|
||||||
elif hasattr(obj, "model_dump"): # For Pydantic v2 if not caught by BaseModel
|
elif hasattr(obj, "model_dump"): # For Pydantic v2 if not caught by BaseModel
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from ea_chatbot.api.utils import create_access_token, decode_access_token
|
from ea_chatbot.api.utils import create_access_token, decode_access_token, convert_to_json_compatible
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
def test_create_and_decode_access_token():
|
def test_create_and_decode_access_token():
|
||||||
"""Test that a token can be created and then decoded."""
|
"""Test that a token can be created and then decoded."""
|
||||||
@@ -22,3 +23,29 @@ def test_expired_token():
|
|||||||
token = create_access_token(data, expires_delta=timedelta(minutes=-1))
|
token = create_access_token(data, expires_delta=timedelta(minutes=-1))
|
||||||
|
|
||||||
assert decode_access_token(token) is None
|
assert decode_access_token(token) is None
|
||||||
|
|
||||||
|
def test_convert_to_json_compatible_complex_message():
    """Check that block-list message content is flattened to its joined text."""
    blocks = [
        {"type": "text", "text": "Hello "},
        {"type": "text", "text": "world!"},
        # A non-text block: the expected content below contains none of it.
        {"type": "other", "data": "ignore me"},
    ]
    message = AIMessage(content=blocks)

    converted = convert_to_json_compatible(message)

    assert converted["content"] == "Hello world!"
    assert converted["type"] == "ai"
|
||||||
|
|
||||||
|
def test_convert_to_json_compatible_message_with_text_prop():
    """Check that a string ``.text`` attribute is used in place of ``.content``."""
    from unittest.mock import MagicMock

    # A spec'd mock stands in for an AIMessage so that ``text`` can be set to
    # a plain string regardless of how the real class exposes it.
    fake_message = MagicMock(spec=AIMessage)
    fake_message.type = "ai"
    fake_message.additional_kwargs = {}
    fake_message.content = "Raw content"
    fake_message.text = "Just the text"

    converted = convert_to_json_compatible(fake_message)

    assert converted["content"] == "Just the text"
|
||||||
|
|||||||
Reference in New Issue
Block a user