fix(orchestrator): Apply refinements from code review

This commit is contained in:
Yunxiao Xu
2026-02-23 15:46:21 -08:00
parent c5cf4b38a1
commit 2cfbc5d1d0
19 changed files with 252 additions and 33 deletions

View File

@@ -1,6 +1,7 @@
import io
import sys
import traceback
import copy
from contextlib import redirect_stdout
from typing import TYPE_CHECKING
import pandas as pd
@@ -39,7 +40,7 @@ def executor_node(state: AgentState) -> dict:
db_client = DBClient(settings=db_settings)
# Initialize the Virtual File System (VFS) helper
vfs_state = dict(state.get("vfs", {}))
vfs_state = copy.deepcopy(state.get("vfs", {}))
vfs_helper = VFSHelper(vfs_state)
# Initialize local variables for execution

View File

@@ -1,4 +1,3 @@
import yaml
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model

View File

@@ -1,4 +1,3 @@
from typing import List, Literal
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model

View File

@@ -48,16 +48,16 @@ If there were major errors or the output is missing critical data requested in t
}
else:
logger.warning(f"[bold yellow]Sub-task NOT satisfied.[/bold yellow] Reason: {response.reasoning}")
# For now, we'll still advance to avoid infinite loops, but a more complex orchestrator
# would trigger a retry or adjustment.
# Do NOT advance the step. This triggers a retry of the same task.
# In a more advanced version, we might route to a 'planner' for revision.
return {
"current_step": current_step + 1,
"current_step": current_step,
"next_action": "delegate"
}
except Exception as e:
logger.error(f"Failed to reflect: {str(e)}")
# Fallback: advance anyway
# On error, do not advance to be safe
return {
"current_step": current_step + 1,
"current_step": current_step,
"next_action": "delegate"
}

View File

@@ -12,12 +12,12 @@ class AgentState(AS):
# Query Analysis (Decomposition results)
analysis: Optional[Dict[str, Any]]
# Expected keys: "requires_dataset", "expert", "data", "unknown", "condition"
# Expected keys: "data_required", "unknowns", "ambiguities", "conditions"
# Step-by-step reasoning
# Step-by-step reasoning (Legacy, use checklist for new flows)
plan: Optional[str]
# Code execution context
# Code execution context (Legacy, use workers for new flows)
code: Optional[str]
code_output: Optional[str]
error: Optional[str]
@@ -26,10 +26,10 @@ class AgentState(AS):
plots: Annotated[List[Any], operator.add] # Matplotlib figures
dfs: Dict[str, Any] # Pandas DataFrames
# Conversation summary
# Conversation summary / Latest worker result
summary: Optional[str]
# Routing hint: "clarify", "plan", "research", "end"
# Routing hint: "clarify", "plan", "end"
next_action: str
# Number of execution attempts

View File

@@ -1,4 +1,5 @@
from typing import Dict, Any, List
import copy
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.state import AgentState
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
@@ -16,7 +17,7 @@ def prepare_worker_input(state: AgentState) -> Dict[str, Any]:
return {
"task": task_desc,
"messages": [HumanMessage(content=task_desc)], # Start worker loop with the task
"vfs_state": dict(state.get("vfs", {})),
"vfs_state": copy.deepcopy(state.get("vfs", {})),
"iterations": 0,
"plots": [],
"code": None,
@@ -27,11 +28,12 @@ def prepare_worker_input(state: AgentState) -> Dict[str, Any]:
def merge_worker_output(worker_state: WorkerState) -> Dict[str, Any]:
"""Map worker results back to the global AgentState."""
result = worker_state.get("result", "Analysis complete.")
result = worker_state.get("result") or "No result produced by data analyst worker."
# We bubble up the summary as an AI message
# We bubble up the summary as an AI message and update the global summary field
updates = {
"messages": [AIMessage(content=result)],
"summary": result,
"vfs": worker_state.get("vfs_state", {}),
"plots": worker_state.get("plots", [])
}

View File

@@ -33,8 +33,11 @@ def coder_node(state: WorkerState) -> dict:
vfs_summary = "Available in-memory files (VFS):\n"
if vfs_state:
for filename, data in vfs_state.items():
if isinstance(data, dict):
meta = data.get("metadata", {})
vfs_summary += f"- {filename} ({meta.get('type', 'unknown')})\n"
else:
vfs_summary += f"- {filename} (raw data)\n"
else:
vfs_summary += "- None"

View File

@@ -1,6 +1,7 @@
import io
import sys
import traceback
import copy
from contextlib import redirect_stdout
from typing import TYPE_CHECKING
import pandas as pd
@@ -39,7 +40,7 @@ def executor_node(state: WorkerState) -> dict:
db_client = DBClient(settings=db_settings)
# Initialize the Virtual File System (VFS) helper with the snapshot from state
vfs_state = dict(state.get("vfs_state", {}))
vfs_state = copy.deepcopy(state.get("vfs_state", {}))
vfs_helper = VFSHelper(vfs_state)
# Initialize local variables for execution

View File

@@ -1,4 +1,5 @@
from langgraph.graph import StateGraph, END
from langgraph.graph.state import CompiledStateGraph
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
from ea_chatbot.graph.workers.data_analyst.nodes.coder import coder_node
from ea_chatbot.graph.workers.data_analyst.nodes.executor import executor_node
@@ -20,7 +21,7 @@ def create_data_analyst_worker(
coder=coder_node,
executor=executor_node,
summarizer=summarizer_node
) -> StateGraph:
) -> CompiledStateGraph:
"""Create the Data Analyst worker subgraph."""
workflow = StateGraph(WorkerState)

View File

@@ -23,9 +23,10 @@ def prepare_researcher_input(state: AgentState) -> Dict[str, Any]:
def merge_researcher_output(worker_state: WorkerState) -> Dict[str, Any]:
"""Map researcher results back to the global AgentState."""
result = worker_state.get("result", "Research complete.")
result = worker_state.get("result") or "No result produced by researcher worker."
return {
"messages": [AIMessage(content=result)],
"summary": result,
# Researcher doesn't usually update VFS or Plots, but we keep the structure
}

View File

@@ -1,4 +1,5 @@
from langgraph.graph import StateGraph, END
from langgraph.graph.state import CompiledStateGraph
from ea_chatbot.graph.workers.researcher.state import WorkerState
from ea_chatbot.graph.workers.researcher.nodes.searcher import searcher_node
from ea_chatbot.graph.workers.researcher.nodes.summarizer import summarizer_node
@@ -6,7 +7,7 @@ from ea_chatbot.graph.workers.researcher.nodes.summarizer import summarizer_node
def create_researcher_worker(
searcher=searcher_node,
summarizer=summarizer_node
) -> StateGraph:
) -> CompiledStateGraph:
"""Create the Researcher worker subgraph."""
workflow = StateGraph(WorkerState)

View File

@@ -9,22 +9,23 @@ from ea_chatbot.graph.workers.data_analyst.workflow import create_data_analyst_w
from ea_chatbot.graph.workers.data_analyst.mapping import prepare_worker_input, merge_worker_output
from ea_chatbot.graph.workers.researcher.workflow import create_researcher_worker
from ea_chatbot.graph.workers.researcher.mapping import prepare_researcher_input, merge_researcher_output
from ea_chatbot.graph.nodes.researcher import researcher_node
from ea_chatbot.graph.nodes.clarification import clarification_node
from ea_chatbot.graph.nodes.summarize_conversation import summarize_conversation_node
# Global cache for compiled subgraphs
_DATA_ANALYST_WORKER = create_data_analyst_worker()
_RESEARCHER_WORKER = create_researcher_worker()
def data_analyst_worker_node(state: AgentState) -> dict:
"""Wrapper node for the Data Analyst subgraph with state mapping."""
worker_graph = create_data_analyst_worker()
worker_input = prepare_worker_input(state)
worker_result = worker_graph.invoke(worker_input)
worker_result = _DATA_ANALYST_WORKER.invoke(worker_input)
return merge_worker_output(worker_result)
def researcher_worker_node(state: AgentState) -> dict:
"""Wrapper node for the Researcher subgraph with state mapping."""
worker_graph = create_researcher_worker()
worker_input = prepare_researcher_input(state)
worker_result = worker_graph.invoke(worker_input)
worker_result = _RESEARCHER_WORKER.invoke(worker_input)
return merge_researcher_output(worker_result)
def main_router(state: AgentState) -> str:
@@ -32,6 +33,8 @@ def main_router(state: AgentState) -> str:
next_action = state.get("next_action")
if next_action == "clarify":
return "clarification"
# Even if QA suggests 'research', we now go through 'planner' for orchestration
# aligning with the new hierarchical architecture.
return "planner"
def delegation_router(state: AgentState) -> str:

View File

@@ -17,8 +17,11 @@ class VFSHelper:
def read(self, filename: str) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]:
"""Read a file and its metadata from the VFS. Returns (None, None) if not found."""
file_data = self._vfs.get(filename)
if file_data:
return file_data["content"], file_data["metadata"]
if file_data is not None:
# Handle raw values (backwards compatibility or inconsistent schema)
if not isinstance(file_data, dict) or "content" not in file_data:
return file_data, {}
return file_data["content"], file_data.get("metadata", {})
return None, None
def list(self) -> List[str]:

View File

@@ -123,3 +123,20 @@ def test_get_me_success(client):
assert response.status_code == 200
assert response.json()["email"] == "test@example.com"
assert response.json()["id"] == "123"
def test_get_me_rejects_refresh_token(client):
"""Test that /auth/me rejects refresh tokens for authentication."""
from ea_chatbot.api.utils import create_refresh_token
token = create_refresh_token(data={"sub": "123"})
with patch("ea_chatbot.api.dependencies.history_manager") as mock_hm:
# Even if the user exists, the dependency should reject the token type
mock_hm.get_user_by_id.return_value = User(id="123", username="test@example.com")
response = client.get(
"/api/v1/auth/me",
headers={"Authorization": f"Bearer {token}"}
)
assert response.status_code == 401
assert "Cannot use refresh token" in response.json()["detail"]

View File

@@ -0,0 +1,42 @@
import pytest
from ea_chatbot.graph.workflow import create_workflow, data_analyst_worker_node
from ea_chatbot.graph.state import AgentState
from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage
def test_worker_merge_sets_summary_for_reflector():
"""Verify that worker node (wrapper) sets the 'summary' field for the Reflector."""
# NOTE(review): indentation is flattened by this diff rendering; structure
# follows the original source.
state = AgentState(
messages=[HumanMessage(content="test")],
question="test",
checklist=[{"task": "Analyze data", "worker": "data_analyst"}],
current_step=0,
iterations=0,
vfs={},
plots=[],
dfs={},
next_action="",
analysis={},
summary="Initial Planner Summary" # Stale summary
)
# Mock the compiled worker subgraph to return a specific result
with patch("ea_chatbot.graph.workflow._DATA_ANALYST_WORKER") as mock_worker:
mock_worker.invoke.return_value = {
"result": "Actual Worker Findings",
"messages": [AIMessage(content="Internal")],
"vfs_state": {},
"plots": []
}
# Execute the wrapper node
updates = data_analyst_worker_node(state)
# Verify that 'summary' is in updates and has the worker result
assert "summary" in updates
assert updates["summary"] == "Actual Worker Findings"
# When applied to state, it should overwrite the stale summary
# (AgentState supports dict-style .update — presumably a TypedDict)
state.update(updates)
assert state["summary"] == "Actual Worker Findings"

View File

@@ -0,0 +1,33 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.reflector import reflector_node
from ea_chatbot.graph.state import AgentState
def test_reflector_does_not_advance_on_failure():
"""Verify that reflector does not increment current_step if not satisfied."""
state = AgentState(
checklist=[{"task": "Task 1", "worker": "data_analyst"}],
current_step=0,
messages=[],
question="test",
analysis={},
next_action="",
iterations=0,
vfs={},
plots=[],
dfs={},
summary="Failed output"
)
with patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_get_llm:
mock_llm = MagicMock()
# Mark as NOT satisfied
# (the structured-output chain — with_structured_output().invoke() —
# returns a response with satisfied=False to drive the retry branch)
mock_llm.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=False, reasoning="Incomplete.")
mock_get_llm.return_value = mock_llm
result = reflector_node(state)
# Should NOT increment
assert result["current_step"] == 0
# Should probably route to planner or retry
assert result["next_action"] == "delegate" # Or 'planner' if we want re-planning

View File

@@ -0,0 +1,51 @@
import pytest
from ea_chatbot.graph.workers.data_analyst.mapping import merge_worker_output as merge_analyst
from ea_chatbot.graph.workers.researcher.mapping import merge_researcher_output as merge_researcher
from ea_chatbot.graph.workers.data_analyst.state import WorkerState as AnalystState
from ea_chatbot.graph.workers.researcher.state import WorkerState as ResearcherState
def test_analyst_merge_updates_summary():
"""Verify that analyst merge updates the global summary."""
worker_state = AnalystState(
result="Actual Worker Result",
messages=[],
task="test",
iterations=1,
vfs_state={},
plots=[]
)
updates = merge_analyst(worker_state)
# merge_worker_output should surface the worker 'result' as the global summary
assert "summary" in updates
assert updates["summary"] == "Actual Worker Result"
def test_researcher_merge_updates_summary():
"""Verify that researcher merge updates the global summary."""
worker_state = ResearcherState(
result="Actual Research Result",
messages=[],
task="test",
iterations=1,
queries=[],
raw_results=[]
)
updates = merge_researcher(worker_state)
# merge_researcher_output should surface the worker 'result' as the global summary
assert "summary" in updates
assert updates["summary"] == "Actual Research Result"
def test_merge_handles_none_result():
"""Verify that merge functions handle None results gracefully."""
worker_state = AnalystState(
result=None,
messages=[],
task="test",
iterations=1,
vfs_state={},
plots=[]
)
updates = merge_analyst(worker_state)
# A None result must be replaced by a non-empty fallback string
# (merge uses `get("result") or <fallback>` rather than a default arg)
assert updates["summary"] is not None
assert isinstance(updates["messages"][0].content, str)
assert len(updates["messages"][0].content) > 0

View File

@@ -0,0 +1,50 @@
import pytest
from ea_chatbot.utils.vfs import VFSHelper
from ea_chatbot.graph.workers.data_analyst.mapping import prepare_worker_input
from ea_chatbot.graph.state import AgentState
def test_vfs_isolation_deep_copy():
"""Verify that worker VFS is deep-copied from global state."""
global_vfs = {
"data.txt": {
"content": "original",
"metadata": {"tags": ["a"]}
}
}
state = AgentState(
checklist=[{"task": "test", "worker": "data_analyst"}],
current_step=0,
messages=[],
question="test",
analysis={},
next_action="",
iterations=0,
vfs=global_vfs,
plots=[],
dfs={}
)
worker_input = prepare_worker_input(state)
worker_vfs = worker_input["vfs_state"]
# Mutate worker VFS nested object
# (a shallow dict() copy would share this nested list with global_vfs,
# so this mutation distinguishes deepcopy from dict())
worker_vfs["data.txt"]["metadata"]["tags"].append("b")
# Global VFS should remain unchanged
assert global_vfs["data.txt"]["metadata"]["tags"] == ["a"]
def test_vfs_schema_normalization():
"""Verify that VFSHelper handles inconsistent VFS entries."""
vfs = {
"raw.txt": "just a string", # Inconsistent with standard Dict[str, Any] entry
"valid.txt": {"content": "data", "metadata": {}}
}
helper = VFSHelper(vfs)
# Should not crash during read
content, metadata = helper.read("raw.txt")
# Raw (non-dict) entries are returned as-is with empty metadata
assert content == "just a string"
assert metadata == {}
# Should not crash during list/other ops if they assume dict
assert "raw.txt" in helper.list()

View File

@@ -12,14 +12,16 @@ def mock_llms():
patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model") as mock_coder, \
patch("ea_chatbot.graph.workers.data_analyst.nodes.summarizer.get_llm_model") as mock_worker_summarizer, \
patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model") as mock_synthesizer, \
patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_researcher:
patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_researcher, \
patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_reflector:
yield {
"qa": mock_qa,
"planner": mock_planner,
"coder": mock_coder,
"worker_summarizer": mock_worker_summarizer,
"synthesizer": mock_synthesizer,
"researcher": mock_researcher
"researcher": mock_researcher,
"reflector": mock_reflector
}
def test_workflow_data_analysis_flow(mock_llms):
@@ -58,7 +60,12 @@ def test_workflow_data_analysis_flow(mock_llms):
mock_llms["worker_summarizer"].return_value = mock_ws_instance
mock_ws_instance.invoke.return_value = AIMessage(content="Worker Summary")
# 5. Mock Synthesizer
# 5. Mock Reflector
mock_reflector_instance = MagicMock()
mock_llms["reflector"].return_value = mock_reflector_instance
mock_reflector_instance.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=True, reasoning="Good.")
# 6. Mock Synthesizer
mock_syn_instance = MagicMock()
mock_llms["synthesizer"].return_value = mock_syn_instance
mock_syn_instance.invoke.return_value = AIMessage(content="Final Summary: Success")
@@ -111,7 +118,12 @@ def test_workflow_research_flow(mock_llms):
mock_llms["researcher"].return_value = mock_res_instance
mock_res_instance.invoke.return_value = AIMessage(content="Research Result")
# 4. Mock Synthesizer
# 4. Mock Reflector
mock_reflector_instance = MagicMock()
mock_llms["reflector"].return_value = mock_reflector_instance
mock_reflector_instance.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=True, reasoning="Good.")
# 5. Mock Synthesizer
mock_syn_instance = MagicMock()
mock_llms["synthesizer"].return_value = mock_syn_instance
mock_syn_instance.invoke.return_value = AIMessage(content="Final Research Summary")