fix(orchestrator): Apply refinements from code review

This commit is contained in:
Yunxiao Xu
2026-02-23 15:46:21 -08:00
parent c5cf4b38a1
commit 2cfbc5d1d0
19 changed files with 252 additions and 33 deletions

View File

@@ -123,3 +123,20 @@ def test_get_me_success(client):
assert response.status_code == 200
assert response.json()["email"] == "test@example.com"
assert response.json()["id"] == "123"
def test_get_me_rejects_refresh_token(client):
"""Test that /auth/me rejects refresh tokens for authentication."""
from ea_chatbot.api.utils import create_refresh_token
token = create_refresh_token(data={"sub": "123"})
with patch("ea_chatbot.api.dependencies.history_manager") as mock_hm:
# Even if the user exists, the dependency should reject the token type
mock_hm.get_user_by_id.return_value = User(id="123", username="test@example.com")
response = client.get(
"/api/v1/auth/me",
headers={"Authorization": f"Bearer {token}"}
)
assert response.status_code == 401
assert "Cannot use refresh token" in response.json()["detail"]

View File

@@ -0,0 +1,42 @@
import pytest
from ea_chatbot.graph.workflow import create_workflow, data_analyst_worker_node
from ea_chatbot.graph.state import AgentState
from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage
def test_worker_merge_sets_summary_for_reflector():
"""Verify that worker node (wrapper) sets the 'summary' field for the Reflector."""
state = AgentState(
messages=[HumanMessage(content="test")],
question="test",
checklist=[{"task": "Analyze data", "worker": "data_analyst"}],
current_step=0,
iterations=0,
vfs={},
plots=[],
dfs={},
next_action="",
analysis={},
summary="Initial Planner Summary" # Stale summary
)
# Mock the compiled worker subgraph to return a specific result
with patch("ea_chatbot.graph.workflow._DATA_ANALYST_WORKER") as mock_worker:
mock_worker.invoke.return_value = {
"result": "Actual Worker Findings",
"messages": [AIMessage(content="Internal")],
"vfs_state": {},
"plots": []
}
# Execute the wrapper node
updates = data_analyst_worker_node(state)
# Verify that 'summary' is in updates and has the worker result
assert "summary" in updates
assert updates["summary"] == "Actual Worker Findings"
# When applied to state, it should overwrite the stale summary
state.update(updates)
assert state["summary"] == "Actual Worker Findings"

View File

@@ -0,0 +1,33 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.reflector import reflector_node
from ea_chatbot.graph.state import AgentState
def test_reflector_does_not_advance_on_failure():
"""Verify that reflector does not increment current_step if not satisfied."""
state = AgentState(
checklist=[{"task": "Task 1", "worker": "data_analyst"}],
current_step=0,
messages=[],
question="test",
analysis={},
next_action="",
iterations=0,
vfs={},
plots=[],
dfs={},
summary="Failed output"
)
with patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_get_llm:
mock_llm = MagicMock()
# Mark as NOT satisfied
mock_llm.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=False, reasoning="Incomplete.")
mock_get_llm.return_value = mock_llm
result = reflector_node(state)
# Should NOT increment
assert result["current_step"] == 0
# Should probably route to planner or retry
assert result["next_action"] == "delegate" # Or 'planner' if we want re-planning

View File

@@ -0,0 +1,51 @@
import pytest
from ea_chatbot.graph.workers.data_analyst.mapping import merge_worker_output as merge_analyst
from ea_chatbot.graph.workers.researcher.mapping import merge_researcher_output as merge_researcher
from ea_chatbot.graph.workers.data_analyst.state import WorkerState as AnalystState
from ea_chatbot.graph.workers.researcher.state import WorkerState as ResearcherState
def test_analyst_merge_updates_summary():
"""Verify that analyst merge updates the global summary."""
worker_state = AnalystState(
result="Actual Worker Result",
messages=[],
task="test",
iterations=1,
vfs_state={},
plots=[]
)
updates = merge_analyst(worker_state)
assert "summary" in updates
assert updates["summary"] == "Actual Worker Result"
def test_researcher_merge_updates_summary():
"""Verify that researcher merge updates the global summary."""
worker_state = ResearcherState(
result="Actual Research Result",
messages=[],
task="test",
iterations=1,
queries=[],
raw_results=[]
)
updates = merge_researcher(worker_state)
assert "summary" in updates
assert updates["summary"] == "Actual Research Result"
def test_merge_handles_none_result():
"""Verify that merge functions handle None results gracefully."""
worker_state = AnalystState(
result=None,
messages=[],
task="test",
iterations=1,
vfs_state={},
plots=[]
)
updates = merge_analyst(worker_state)
assert updates["summary"] is not None
assert isinstance(updates["messages"][0].content, str)
assert len(updates["messages"][0].content) > 0

View File

@@ -0,0 +1,50 @@
import pytest
from ea_chatbot.utils.vfs import VFSHelper
from ea_chatbot.graph.workers.data_analyst.mapping import prepare_worker_input
from ea_chatbot.graph.state import AgentState
def test_vfs_isolation_deep_copy():
"""Verify that worker VFS is deep-copied from global state."""
global_vfs = {
"data.txt": {
"content": "original",
"metadata": {"tags": ["a"]}
}
}
state = AgentState(
checklist=[{"task": "test", "worker": "data_analyst"}],
current_step=0,
messages=[],
question="test",
analysis={},
next_action="",
iterations=0,
vfs=global_vfs,
plots=[],
dfs={}
)
worker_input = prepare_worker_input(state)
worker_vfs = worker_input["vfs_state"]
# Mutate worker VFS nested object
worker_vfs["data.txt"]["metadata"]["tags"].append("b")
# Global VFS should remain unchanged
assert global_vfs["data.txt"]["metadata"]["tags"] == ["a"]
def test_vfs_schema_normalization():
"""Verify that VFSHelper handles inconsistent VFS entries."""
vfs = {
"raw.txt": "just a string", # Inconsistent with standard Dict[str, Any] entry
"valid.txt": {"content": "data", "metadata": {}}
}
helper = VFSHelper(vfs)
# Should not crash during read
content, metadata = helper.read("raw.txt")
assert content == "just a string"
assert metadata == {}
# Should not crash during list/other ops if they assume dict
assert "raw.txt" in helper.list()

View File

@@ -12,14 +12,16 @@ def mock_llms():
patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model") as mock_coder, \
patch("ea_chatbot.graph.workers.data_analyst.nodes.summarizer.get_llm_model") as mock_worker_summarizer, \
patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model") as mock_synthesizer, \
patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_researcher:
patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_researcher, \
patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_reflector:
yield {
"qa": mock_qa,
"planner": mock_planner,
"coder": mock_coder,
"worker_summarizer": mock_worker_summarizer,
"synthesizer": mock_synthesizer,
"researcher": mock_researcher
"researcher": mock_researcher,
"reflector": mock_reflector
}
def test_workflow_data_analysis_flow(mock_llms):
@@ -58,7 +60,12 @@ def test_workflow_data_analysis_flow(mock_llms):
mock_llms["worker_summarizer"].return_value = mock_ws_instance
mock_ws_instance.invoke.return_value = AIMessage(content="Worker Summary")
# 5. Mock Synthesizer
# 5. Mock Reflector
mock_reflector_instance = MagicMock()
mock_llms["reflector"].return_value = mock_reflector_instance
mock_reflector_instance.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=True, reasoning="Good.")
# 6. Mock Synthesizer
mock_syn_instance = MagicMock()
mock_llms["synthesizer"].return_value = mock_syn_instance
mock_syn_instance.invoke.return_value = AIMessage(content="Final Summary: Success")
@@ -111,7 +118,12 @@ def test_workflow_research_flow(mock_llms):
mock_llms["researcher"].return_value = mock_res_instance
mock_res_instance.invoke.return_value = AIMessage(content="Research Result")
# 4. Mock Synthesizer
# 4. Mock Reflector
mock_reflector_instance = MagicMock()
mock_llms["reflector"].return_value = mock_reflector_instance
mock_reflector_instance.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=True, reasoning="Good.")
# 5. Mock Synthesizer
mock_syn_instance = MagicMock()
mock_llms["synthesizer"].return_value = mock_syn_instance
mock_syn_instance.invoke.return_value = AIMessage(content="Final Research Summary")