diff --git a/backend/src/ea_chatbot/graph/nodes/clarification.py b/backend/src/ea_chatbot/graph/nodes/clarification.py index 58e887d..fa953aa 100644 --- a/backend/src/ea_chatbot/graph/nodes/clarification.py +++ b/backend/src/ea_chatbot/graph/nodes/clarification.py @@ -36,8 +36,7 @@ Please ask the user for the necessary details.""" response = llm.invoke(messages) logger.info("[bold green]Clarification generated.[/bold green]") return { - "messages": [response], - "next_action": "end" # To indicate we are done for now + "messages": [response] } except Exception as e: logger.error(f"Failed to generate clarification: {str(e)}") diff --git a/backend/src/ea_chatbot/graph/workflow.py b/backend/src/ea_chatbot/graph/workflow.py index 25df66b..ba0722e 100644 --- a/backend/src/ea_chatbot/graph/workflow.py +++ b/backend/src/ea_chatbot/graph/workflow.py @@ -106,10 +106,8 @@ def create_workflow( workflow.add_edge("summarize_conversation", END) workflow.add_edge("clarification", END) - # Compile the graph with human-in-the-loop interrupt - app = workflow.compile( - interrupt_before=["clarification"] - ) + # Compile the graph + app = workflow.compile() return app diff --git a/backend/tests/test_review_fix_clarification.py b/backend/tests/test_review_fix_clarification.py new file mode 100644 index 0000000..7ab8306 --- /dev/null +++ b/backend/tests/test_review_fix_clarification.py @@ -0,0 +1,55 @@ +import pytest +from unittest.mock import MagicMock, patch +from ea_chatbot.graph.workflow import create_workflow +from ea_chatbot.graph.state import AgentState +from langchain_core.messages import HumanMessage, AIMessage + +def test_clarification_flow_immediate_execution(): + """Verify that an ambiguous query immediately executes the clarification node without interruption.""" + + mock_analyzer = MagicMock() + mock_clarification = MagicMock() + + # 1. Analyzer returns 'clarify' + mock_analyzer.return_value = {"next_action": "clarify"} + + # 2. Clarification node returns a question + mock_clarification.return_value = {"messages": [AIMessage(content="What year?")]} + + # Create workflow without other nodes since they won't be reached + # We still need to provide mock planners etc. to create_workflow + app = create_workflow( + query_analyzer=mock_analyzer, + clarification=mock_clarification, + planner=MagicMock(), + delegate=MagicMock(), + data_analyst_worker=MagicMock(), + researcher_worker=MagicMock(), + reflector=MagicMock(), + synthesizer=MagicMock(), + summarize_conversation=MagicMock() + ) + + initial_state = AgentState( + messages=[HumanMessage(content="Who won?")], + question="Who won?", + analysis={}, + next_action="", + iterations=0, + checklist=[], + current_step=0, + vfs={}, + plots=[], + dfs={} + ) + + # Run the graph + final_state = app.invoke(initial_state) + + # Assertions + assert mock_analyzer.called + assert mock_clarification.called + + # Verify the state contains the clarification message + assert len(final_state["messages"]) > 0 + assert "What year?" in [m.content for m in final_state["messages"]] diff --git a/backend/tests/test_review_fix_reflector.py b/backend/tests/test_review_fix_reflector.py index 4b9d858..9419b7a 100644 --- a/backend/tests/test_review_fix_reflector.py +++ b/backend/tests/test_review_fix_reflector.py @@ -27,7 +27,9 @@ def test_reflector_does_not_advance_on_failure(): result = reflector_node(state) - # Should NOT increment - assert result["current_step"] == 0 + # Should NOT increment (therefore not in the updates dict) + assert "current_step" not in result + # Should increment iterations + assert result["iterations"] == 1 # Should probably route to planner or retry assert result["next_action"] == "delegate" # Or 'planner' if we want re-planning