import pytest
from unittest.mock import MagicMock, patch

from langchain_core.messages import HumanMessage, AIMessage

from ea_chatbot.graph.nodes.summarize_conversation import summarize_conversation_node


@pytest.fixture
def mock_state_with_history():
    """Conversation state with one completed turn and a pre-existing summary."""
    prior_turn = [
        HumanMessage(content="Show me the 2024 results for Florida"),
        AIMessage(content="Here are the results for Florida in 2024..."),
    ]
    return {
        "messages": prior_turn,
        "summary": "The user is asking about 2024 election results.",
    }
@patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model")
def test_summarize_conversation_node_updates_summary(mock_get_llm, mock_state_with_history):
    """The node should fold the latest turn into the existing running summary."""
    llm = MagicMock()
    llm.invoke.return_value = AIMessage(content="Updated summary including NJ results.")
    mock_get_llm.return_value = llm

    # Simulate a completed follow-up turn appended to the history.
    mock_state_with_history["messages"] += [
        HumanMessage(content="What about in New Jersey?"),
        AIMessage(content="In New Jersey, the 2024 results were..."),
    ]

    result = summarize_conversation_node(mock_state_with_history)

    assert "summary" in result
    assert result["summary"] == "Updated summary including NJ results."

    # The prompt sent to the LLM must embed the prior summary as context.
    call_messages = llm.invoke.call_args[0][0]
    assert "Current summary: The user is asking about 2024 election results." in call_messages[0].content
@patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model")
def test_summarize_conversation_node_initial_summary(mock_get_llm):
    """With an empty prior summary, the node should produce one from scratch."""
    llm = MagicMock()
    llm.invoke.return_value = AIMessage(content="Initial greeting.")
    mock_get_llm.return_value = llm

    state = {
        "messages": [
            HumanMessage(content="Hi"),
            AIMessage(content="Hello! How can I help you today?"),
        ],
        "summary": "",
    }

    result = summarize_conversation_node(state)

    assert result["summary"] == "Initial greeting."