feat(history): Implement map_db_messages_to_langchain utility
This commit is contained in:
29
backend/src/ea_chatbot/history/utils.py
Normal file
29
backend/src/ea_chatbot/history/utils.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from typing import List
|
||||||
|
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
|
||||||
|
from ea_chatbot.history.models import Message
|
||||||
|
|
||||||
|
def map_db_messages_to_langchain(db_messages: List[Message]) -> List[BaseMessage]:
|
||||||
|
"""
|
||||||
|
Converts a list of database Message models to LangChain BaseMessage objects.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_messages: List of Message objects from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of HumanMessage, AIMessage, or SystemMessage objects.
|
||||||
|
"""
|
||||||
|
lc_messages: List[BaseMessage] = []
|
||||||
|
|
||||||
|
for m in db_messages:
|
||||||
|
role = m.role.lower()
|
||||||
|
if role == "user":
|
||||||
|
lc_messages.append(HumanMessage(content=m.content))
|
||||||
|
elif role == "assistant":
|
||||||
|
lc_messages.append(AIMessage(content=m.content))
|
||||||
|
elif role == "system":
|
||||||
|
lc_messages.append(SystemMessage(content=m.content))
|
||||||
|
else:
|
||||||
|
# Default to HumanMessage for unknown roles
|
||||||
|
lc_messages.append(HumanMessage(content=m.content))
|
||||||
|
|
||||||
|
return lc_messages
|
||||||
43
backend/tests/test_history_mapping.py
Normal file
43
backend/tests/test_history_mapping.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import pytest
|
||||||
|
from ea_chatbot.history.models import Message
|
||||||
|
from ea_chatbot.history.utils import map_db_messages_to_langchain
|
||||||
|
from langchain_core.messages import HumanMessage, AIMessage
|
||||||
|
|
||||||
|
def test_map_db_messages_to_langchain():
|
||||||
|
"""
|
||||||
|
Test that DB Message objects are correctly mapped to LangChain message objects.
|
||||||
|
"""
|
||||||
|
# 1. Setup mock DB messages
|
||||||
|
m1 = Message(role="user", content="Hello")
|
||||||
|
m2 = Message(role="assistant", content="Hi there!")
|
||||||
|
m3 = Message(role="user", content="How are you?")
|
||||||
|
|
||||||
|
db_messages = [m1, m2, m3]
|
||||||
|
|
||||||
|
# 2. Map
|
||||||
|
lc_messages = map_db_messages_to_langchain(db_messages)
|
||||||
|
|
||||||
|
# 3. Assertions
|
||||||
|
assert len(lc_messages) == 3
|
||||||
|
|
||||||
|
assert isinstance(lc_messages[0], HumanMessage)
|
||||||
|
assert lc_messages[0].content == "Hello"
|
||||||
|
|
||||||
|
assert isinstance(lc_messages[1], AIMessage)
|
||||||
|
assert lc_messages[1].content == "Hi there!"
|
||||||
|
|
||||||
|
assert isinstance(lc_messages[2], HumanMessage)
|
||||||
|
assert lc_messages[2].content == "How are you?"
|
||||||
|
|
||||||
|
def test_map_db_messages_unknown_role():
|
||||||
|
"""
|
||||||
|
Test that unknown roles are handled gracefully (e.g., defaulted to HumanMessage or raising error).
|
||||||
|
Let's assume we default to HumanMessage for safety, or just skip.
|
||||||
|
"""
|
||||||
|
m = Message(role="system", content="System instruction")
|
||||||
|
lc_messages = map_db_messages_to_langchain([m])
|
||||||
|
|
||||||
|
# Depending on implementation choice. Let's say we support 'system' too.
|
||||||
|
from langchain_core.messages import SystemMessage
|
||||||
|
assert isinstance(lc_messages[0], SystemMessage)
|
||||||
|
assert lc_messages[0].content == "System instruction"
|
||||||
Reference in New Issue
Block a user