78 lines
2.4 KiB
Python
78 lines
2.4 KiB
Python
import pytest
|
|
from langchain_core.messages import HumanMessage, AIMessage
|
|
from ea_chatbot.utils.helpers import merge_agent_state
|
|
|
|
def test_merge_agent_state_list_accumulation():
    """List-valued fields (messages, plots) are extended, not replaced.

    Items from the update must appear after the existing items, in order.
    """
    existing = {
        "messages": [HumanMessage(content="hello")],
        "plots": ["plot1"],
    }
    incoming = {
        "messages": [AIMessage(content="hi")],
        "plots": ["plot2"],
    }

    result = merge_agent_state(existing, incoming)

    # The original message comes first, followed by the appended one.
    message_texts = [msg.content for msg in result["messages"]]
    assert message_texts == ["hello", "hi"]

    # Exact list equality also pins the length and the append order.
    assert result["plots"] == ["plot1", "plot2"]
|
|
|
|
def test_merge_agent_state_dict_update():
    """Dict-valued fields (dfs) are shallow-merged key by key.

    New keys are added alongside existing ones. A key present in both the
    state and the update takes the value from the update.
    """
    base = {"dfs": {"df1": "data1"}}

    # A new key is added next to the existing one.
    combined = merge_agent_state(base, {"dfs": {"df2": "data2"}})
    assert combined["dfs"] == {"df1": "data1", "df2": "data2"}

    # A repeated key is overwritten, and keys not in the update survive.
    replaced = merge_agent_state(combined, {"dfs": {"df1": "new_data1"}})
    assert replaced["dfs"] == {"df1": "new_data1", "df2": "data2"}
|
|
|
|
def test_merge_agent_state_standard_overwrite():
    """Plain scalar fields take the value supplied by the update."""
    state_before = {
        "question": "old question",
        "next_action": "old action",
        "plan": "old plan",
    }
    expected = {
        "question": "new question",
        "next_action": "new action",
        "plan": "new plan",
    }

    # Pass a copy so the expectations cannot be affected by the merge call.
    state_after = merge_agent_state(state_before, dict(expected))

    for field, value in expected.items():
        assert state_after[field] == value, field
|
|
|
|
def test_merge_agent_state_none_handling():
    """An empty update, or a None value in the update, must not break merging.

    - An empty update yields a result equal to the input state.
    - A None value for a plain field overwrites that field with None, while
      fields absent from the update (here, messages) are left as they were.
    """
    state = {
        "question": "test",
        "messages": ["msg1"],
    }

    # No keys in the update: the result compares equal to the input state.
    unchanged = merge_agent_state(state, {})
    assert unchanged == state

    # None is applied as a real value for an overwritable field.
    nulled = merge_agent_state(state, {"question": None})
    assert nulled["question"] is None
    # The list field that the update did not mention is left intact.
    assert nulled["messages"] == ["msg1"]
|