"""Unit tests for the per-field merge rules of ``merge_agent_state``."""

import pytest
from langchain_core.messages import AIMessage, HumanMessage

from ea_chatbot.utils.helpers import merge_agent_state


def test_merge_agent_state_list_accumulation():
    """List fields (``messages``, ``plots``) grow by appending the update's items."""
    initial = {
        "messages": [HumanMessage(content="hello")],
        "plots": ["plot1"],
    }
    delta = {
        "messages": [AIMessage(content="hi")],
        "plots": ["plot2"],
    }

    result = merge_agent_state(initial, delta)

    # Existing entries must come first, followed by the update's entries.
    assert [msg.content for msg in result["messages"]] == ["hello", "hi"]
    assert result["plots"] == ["plot1", "plot2"]


def test_merge_agent_state_dict_update():
    """Dict fields (``dfs``) are shallow-merged, with the update's keys winning."""
    first = merge_agent_state(
        {"dfs": {"df1": "data1"}},
        {"dfs": {"df2": "data2"}},
    )
    assert first["dfs"] == {"df1": "data1", "df2": "data2"}

    # A key already present in the state is replaced by the update's value,
    # while sibling keys are kept.
    second = merge_agent_state(first, {"dfs": {"df1": "new_data1"}})
    assert second["dfs"] == {"df1": "new_data1", "df2": "data2"}


def test_merge_agent_state_standard_overwrite():
    """Ordinary fields are replaced outright by the update's values."""
    stale = {
        "question": "old question",
        "next_action": "old action",
        "plan": "old plan",
    }
    fresh = {
        "question": "new question",
        "next_action": "new action",
        "plan": "new plan",
    }

    result = merge_agent_state(stale, fresh)

    # Every field named in the update must carry the update's value.
    assert {field: result[field] for field in fresh} == fresh


def test_merge_agent_state_none_handling():
    """Empty updates and explicit ``None`` values must not break merging."""
    base = {
        "question": "test",
        "messages": ["msg1"],
    }

    # An empty update yields a state equal to the original.
    assert merge_agent_state(base, {}) == base

    # ``None`` overwrites an ordinary field; fields absent from the update survive.
    nulled = merge_agent_state(base, {"question": None})
    assert nulled["question"] is None
    assert nulled["messages"] == ["msg1"]