56 lines
1.8 KiB
Python
56 lines
1.8 KiB
Python
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
from ea_chatbot.graph.workflow import create_workflow
|
|
from ea_chatbot.graph.state import AgentState
|
|
from langchain_core.messages import HumanMessage, AIMessage
|
|
|
|
def test_clarification_flow_immediate_execution():
    """Check that a 'clarify' decision leads to the clarification node's reply.

    The query analyzer stub answers with ``next_action == "clarify"`` and the
    clarification stub answers with a single AI message. After one
    ``invoke`` call, both stubs must have been called and the clarification
    question must be present in the resulting state's messages.
    """
    # Stubs for the two nodes under observation, with their canned outputs.
    analyzer_stub = MagicMock(return_value={"next_action": "clarify"})
    clarifier_stub = MagicMock(
        return_value={"messages": [AIMessage(content="What year?")]}
    )

    # create_workflow is given a callable for every node; the ones this test
    # does not observe are plain mocks (presumably never reached on the
    # clarify path - verify against the workflow's routing).
    node_overrides = {
        "query_analyzer": analyzer_stub,
        "clarification": clarifier_stub,
        "planner": MagicMock(),
        "delegate": MagicMock(),
        "data_analyst_worker": MagicMock(),
        "researcher_worker": MagicMock(),
        "reflector": MagicMock(),
        "synthesizer": MagicMock(),
        "summarize_conversation": MagicMock(),
    }
    graph = create_workflow(**node_overrides)

    user_query = "Who won?"
    starting_state = AgentState(
        messages=[HumanMessage(content=user_query)],
        question=user_query,
        analysis={},
        next_action="",
        iterations=0,
        checklist=[],
        current_step=0,
        vfs={},
        plots=[],
        dfs={},
    )

    # Single synchronous run of the compiled graph.
    outcome = graph.invoke(starting_state)

    # Both observed nodes must have been called.
    assert analyzer_stub.called
    assert clarifier_stub.called

    # The clarification question must be among the final messages.
    replies = outcome["messages"]
    assert len(replies) > 0
    assert "What year?" in [msg.content for msg in replies]
|