diff --git a/src/ea_chatbot/api/routers/agent.py b/src/ea_chatbot/api/routers/agent.py index df5b544..fdeeb53 100644 --- a/src/ea_chatbot/api/routers/agent.py +++ b/src/ea_chatbot/api/routers/agent.py @@ -68,9 +68,12 @@ async def stream_agent_events( # Buffer assistant chunks (summarizer and researcher might stream) if kind == "on_chat_model_stream" and name in ["summarizer", "researcher"]: chunk = data.get("chunk", "") - if hasattr(chunk, "content"): - chunk = chunk.content - assistant_chunks.append(str(chunk)) + # Use utility to safely extract text content from the chunk + chunk_data = convert_to_json_compatible(chunk) + if isinstance(chunk_data, dict) and "content" in chunk_data: + assistant_chunks.append(str(chunk_data["content"])) + else: + assistant_chunks.append(str(chunk_data)) # Buffer and encode plots if kind == "on_chain_end" and name == "executor": @@ -91,12 +94,15 @@ async def stream_agent_events( output = data.get("output", {}) if isinstance(output, dict) and "messages" in output: last_msg = output["messages"][-1] - if hasattr(last_msg, "content"): - final_response = last_msg.content - elif isinstance(last_msg, dict) and "content" in last_msg: - final_response = last_msg["content"] + + # Use centralized utility to extract clean text content + # Since convert_to_json_compatible returns a dict for BaseMessage, + # we can extract 'content' from it. + msg_data = convert_to_json_compatible(last_msg) + if isinstance(msg_data, dict) and "content" in msg_data: + final_response = msg_data["content"] else: - final_response = str(last_msg) + final_response = str(msg_data) # Collect new summary if kind == "on_chain_end" and name == "summarize_conversation": diff --git a/src/ea_chatbot/api/utils.py b/src/ea_chatbot/api/utils.py index b7a176e..a775095 100644 --- a/src/ea_chatbot/api/utils.py +++ b/src/ea_chatbot/api/utils.py @@ -51,7 +51,24 @@ def convert_to_json_compatible(obj: Any) -> Any: elif isinstance(obj, dict): return {k: convert_to_json_compatible(v) for k, v in obj.items()} elif isinstance(obj, BaseMessage): - return {"type": obj.type, "content": obj.content, **convert_to_json_compatible(obj.additional_kwargs)} + # Handle content that might be a list of blocks (e.g. from Gemini or OpenAI tools) + content = obj.content + if isinstance(content, list): + text_parts = [] + for block in content: + if isinstance(block, str): + text_parts.append(block) + elif isinstance(block, dict): + if block.get("type") == "text": + text_parts.append(block.get("text", "")) + # You could also handle other block types if needed + content = "".join(text_parts) + + # Prefer .text property if available (common in some message types) + if hasattr(obj, "text") and isinstance(obj.text, str) and obj.text: + content = obj.text + + return {"type": obj.type, "content": content, **convert_to_json_compatible(obj.additional_kwargs)} elif isinstance(obj, BaseModel): return convert_to_json_compatible(obj.model_dump()) elif hasattr(obj, "model_dump"): # For Pydantic v2 if not caught by BaseModel diff --git a/tests/api/test_utils.py b/tests/api/test_utils.py index 9fb1d26..75be2f1 100644 --- a/tests/api/test_utils.py +++ b/tests/api/test_utils.py @@ -1,5 +1,6 @@ from datetime import timedelta -from ea_chatbot.api.utils import create_access_token, decode_access_token +from ea_chatbot.api.utils import create_access_token, decode_access_token, convert_to_json_compatible +from langchain_core.messages import AIMessage def test_create_and_decode_access_token(): """Test that a token can be created and then decoded.""" @@ -22,3 +23,29 @@ def test_expired_token(): token = create_access_token(data, expires_delta=timedelta(minutes=-1)) assert decode_access_token(token) is None + +def test_convert_to_json_compatible_complex_message(): + """Test that list-based message content is handled correctly.""" + # Mock a message with list-based content (blocks) + msg = AIMessage(content=[ + {"type": "text", "text": "Hello "}, + {"type": "text", "text": "world!"}, + {"type": "other", "data": "ignore me"} + ]) + + result = convert_to_json_compatible(msg) + assert result["content"] == "Hello world!" + assert result["type"] == "ai" + +def test_convert_to_json_compatible_message_with_text_prop(): + """Test that .text property is prioritized if available.""" + # Using a MagicMock to simulate the property safely + from unittest.mock import MagicMock + msg = MagicMock(spec=AIMessage) + msg.content = "Raw content" + msg.text = "Just the text" + msg.type = "ai" + msg.additional_kwargs = {} + + result = convert_to_json_compatible(msg) + assert result["content"] == "Just the text"