def map_db_messages_to_langchain(db_messages: List[Message]) -> List[BaseMessage]:
    """
    Translate stored chat-history records into LangChain message objects.

    The role of each record is compared case-insensitively. "user" becomes a
    HumanMessage, "assistant" an AIMessage and "system" a SystemMessage; any
    other role is treated as a HumanMessage.

    Args:
        db_messages: Message records loaded from the database.

    Returns:
        One LangChain message per input record, in the same order.
    """
    # Role name -> LangChain message class. Roles missing from this table
    # fall back to HumanMessage (see the .get() default below).
    message_types = {
        "user": HumanMessage,
        "assistant": AIMessage,
        "system": SystemMessage,
    }

    return [
        message_types.get(record.role.lower(), HumanMessage)(content=record.content)
        for record in db_messages
    ]
def test_map_db_messages_unknown_role():
    """
    Test role handling beyond the plain user/assistant exchange.

    A 'system' record must map to SystemMessage, and a record whose role the
    mapper does not recognise must fall back to HumanMessage instead of being
    dropped or raising.
    """
    # Local import: SystemMessage is only needed by this test.
    from langchain_core.messages import SystemMessage

    system_msg = Message(role="system", content="System instruction")
    # "moderator" is deliberately not one of user/assistant/system, so it
    # exercises the mapper's fallback branch.
    unknown_msg = Message(role="moderator", content="Unrecognised role")

    lc_messages = map_db_messages_to_langchain([system_msg, unknown_msg])

    # Nothing is skipped: one output per input, order preserved.
    assert len(lc_messages) == 2

    assert isinstance(lc_messages[0], SystemMessage)
    assert lc_messages[0].content == "System instruction"

    assert isinstance(lc_messages[1], HumanMessage)
    assert lc_messages[1].content == "Unrecognised role"