fix(orchestrator): Apply refinements from code review
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import io
|
||||
import sys
|
||||
import traceback
|
||||
import copy
|
||||
from contextlib import redirect_stdout
|
||||
from typing import TYPE_CHECKING
|
||||
import pandas as pd
|
||||
@@ -39,7 +40,7 @@ def executor_node(state: AgentState) -> dict:
|
||||
db_client = DBClient(settings=db_settings)
|
||||
|
||||
# Initialize the Virtual File System (VFS) helper
|
||||
vfs_state = dict(state.get("vfs", {}))
|
||||
vfs_state = copy.deepcopy(state.get("vfs", {}))
|
||||
vfs_helper = VFSHelper(vfs_state)
|
||||
|
||||
# Initialize local variables for execution
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import yaml
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from typing import List, Literal
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
|
||||
@@ -48,16 +48,16 @@ If there were major errors or the output is missing critical data requested in t
|
||||
}
|
||||
else:
|
||||
logger.warning(f"[bold yellow]Sub-task NOT satisfied.[/bold yellow] Reason: {response.reasoning}")
|
||||
# For now, we'll still advance to avoid infinite loops, but a more complex orchestrator
|
||||
# would trigger a retry or adjustment.
|
||||
# Do NOT advance the step. This triggers a retry of the same task.
|
||||
# In a more advanced version, we might route to a 'planner' for revision.
|
||||
return {
|
||||
"current_step": current_step + 1,
|
||||
"current_step": current_step,
|
||||
"next_action": "delegate"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reflect: {str(e)}")
|
||||
# Fallback: advance anyway
|
||||
# On error, do not advance to be safe
|
||||
return {
|
||||
"current_step": current_step + 1,
|
||||
"current_step": current_step,
|
||||
"next_action": "delegate"
|
||||
}
|
||||
|
||||
@@ -12,12 +12,12 @@ class AgentState(AS):
|
||||
|
||||
# Query Analysis (Decomposition results)
|
||||
analysis: Optional[Dict[str, Any]]
|
||||
# Expected keys: "requires_dataset", "expert", "data", "unknown", "condition"
|
||||
# Expected keys: "data_required", "unknowns", "ambiguities", "conditions"
|
||||
|
||||
# Step-by-step reasoning
|
||||
# Step-by-step reasoning (Legacy, use checklist for new flows)
|
||||
plan: Optional[str]
|
||||
|
||||
# Code execution context
|
||||
# Code execution context (Legacy, use workers for new flows)
|
||||
code: Optional[str]
|
||||
code_output: Optional[str]
|
||||
error: Optional[str]
|
||||
@@ -26,10 +26,10 @@ class AgentState(AS):
|
||||
plots: Annotated[List[Any], operator.add] # Matplotlib figures
|
||||
dfs: Dict[str, Any] # Pandas DataFrames
|
||||
|
||||
# Conversation summary
|
||||
# Conversation summary / Latest worker result
|
||||
summary: Optional[str]
|
||||
|
||||
# Routing hint: "clarify", "plan", "research", "end"
|
||||
# Routing hint: "clarify", "plan", "end"
|
||||
next_action: str
|
||||
|
||||
# Number of execution attempts
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Dict, Any, List
|
||||
import copy
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
||||
@@ -16,7 +17,7 @@ def prepare_worker_input(state: AgentState) -> Dict[str, Any]:
|
||||
return {
|
||||
"task": task_desc,
|
||||
"messages": [HumanMessage(content=task_desc)], # Start worker loop with the task
|
||||
"vfs_state": dict(state.get("vfs", {})),
|
||||
"vfs_state": copy.deepcopy(state.get("vfs", {})),
|
||||
"iterations": 0,
|
||||
"plots": [],
|
||||
"code": None,
|
||||
@@ -27,11 +28,12 @@ def prepare_worker_input(state: AgentState) -> Dict[str, Any]:
|
||||
|
||||
def merge_worker_output(worker_state: WorkerState) -> Dict[str, Any]:
|
||||
"""Map worker results back to the global AgentState."""
|
||||
result = worker_state.get("result", "Analysis complete.")
|
||||
result = worker_state.get("result") or "No result produced by data analyst worker."
|
||||
|
||||
# We bubble up the summary as an AI message
|
||||
# We bubble up the summary as an AI message and update the global summary field
|
||||
updates = {
|
||||
"messages": [AIMessage(content=result)],
|
||||
"summary": result,
|
||||
"vfs": worker_state.get("vfs_state", {}),
|
||||
"plots": worker_state.get("plots", [])
|
||||
}
|
||||
|
||||
@@ -33,8 +33,11 @@ def coder_node(state: WorkerState) -> dict:
|
||||
vfs_summary = "Available in-memory files (VFS):\n"
|
||||
if vfs_state:
|
||||
for filename, data in vfs_state.items():
|
||||
meta = data.get("metadata", {})
|
||||
vfs_summary += f"- {filename} ({meta.get('type', 'unknown')})\n"
|
||||
if isinstance(data, dict):
|
||||
meta = data.get("metadata", {})
|
||||
vfs_summary += f"- {filename} ({meta.get('type', 'unknown')})\n"
|
||||
else:
|
||||
vfs_summary += f"- {filename} (raw data)\n"
|
||||
else:
|
||||
vfs_summary += "- None"
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import io
|
||||
import sys
|
||||
import traceback
|
||||
import copy
|
||||
from contextlib import redirect_stdout
|
||||
from typing import TYPE_CHECKING
|
||||
import pandas as pd
|
||||
@@ -39,7 +40,7 @@ def executor_node(state: WorkerState) -> dict:
|
||||
db_client = DBClient(settings=db_settings)
|
||||
|
||||
# Initialize the Virtual File System (VFS) helper with the snapshot from state
|
||||
vfs_state = dict(state.get("vfs_state", {}))
|
||||
vfs_state = copy.deepcopy(state.get("vfs_state", {}))
|
||||
vfs_helper = VFSHelper(vfs_state)
|
||||
|
||||
# Initialize local variables for execution
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from langgraph.graph import StateGraph, END
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
||||
from ea_chatbot.graph.workers.data_analyst.nodes.coder import coder_node
|
||||
from ea_chatbot.graph.workers.data_analyst.nodes.executor import executor_node
|
||||
@@ -20,7 +21,7 @@ def create_data_analyst_worker(
|
||||
coder=coder_node,
|
||||
executor=executor_node,
|
||||
summarizer=summarizer_node
|
||||
) -> StateGraph:
|
||||
) -> CompiledStateGraph:
|
||||
"""Create the Data Analyst worker subgraph."""
|
||||
workflow = StateGraph(WorkerState)
|
||||
|
||||
|
||||
@@ -23,9 +23,10 @@ def prepare_researcher_input(state: AgentState) -> Dict[str, Any]:
|
||||
|
||||
def merge_researcher_output(worker_state: WorkerState) -> Dict[str, Any]:
|
||||
"""Map researcher results back to the global AgentState."""
|
||||
result = worker_state.get("result", "Research complete.")
|
||||
result = worker_state.get("result") or "No result produced by researcher worker."
|
||||
|
||||
return {
|
||||
"messages": [AIMessage(content=result)],
|
||||
"summary": result,
|
||||
# Researcher doesn't usually update VFS or Plots, but we keep the structure
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from langgraph.graph import StateGraph, END
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from ea_chatbot.graph.workers.researcher.state import WorkerState
|
||||
from ea_chatbot.graph.workers.researcher.nodes.searcher import searcher_node
|
||||
from ea_chatbot.graph.workers.researcher.nodes.summarizer import summarizer_node
|
||||
@@ -6,7 +7,7 @@ from ea_chatbot.graph.workers.researcher.nodes.summarizer import summarizer_node
|
||||
def create_researcher_worker(
|
||||
searcher=searcher_node,
|
||||
summarizer=summarizer_node
|
||||
) -> StateGraph:
|
||||
) -> CompiledStateGraph:
|
||||
"""Create the Researcher worker subgraph."""
|
||||
workflow = StateGraph(WorkerState)
|
||||
|
||||
|
||||
@@ -9,22 +9,23 @@ from ea_chatbot.graph.workers.data_analyst.workflow import create_data_analyst_w
|
||||
from ea_chatbot.graph.workers.data_analyst.mapping import prepare_worker_input, merge_worker_output
|
||||
from ea_chatbot.graph.workers.researcher.workflow import create_researcher_worker
|
||||
from ea_chatbot.graph.workers.researcher.mapping import prepare_researcher_input, merge_researcher_output
|
||||
from ea_chatbot.graph.nodes.researcher import researcher_node
|
||||
from ea_chatbot.graph.nodes.clarification import clarification_node
|
||||
from ea_chatbot.graph.nodes.summarize_conversation import summarize_conversation_node
|
||||
|
||||
# Global cache for compiled subgraphs
|
||||
_DATA_ANALYST_WORKER = create_data_analyst_worker()
|
||||
_RESEARCHER_WORKER = create_researcher_worker()
|
||||
|
||||
def data_analyst_worker_node(state: AgentState) -> dict:
|
||||
"""Wrapper node for the Data Analyst subgraph with state mapping."""
|
||||
worker_graph = create_data_analyst_worker()
|
||||
worker_input = prepare_worker_input(state)
|
||||
worker_result = worker_graph.invoke(worker_input)
|
||||
worker_result = _DATA_ANALYST_WORKER.invoke(worker_input)
|
||||
return merge_worker_output(worker_result)
|
||||
|
||||
def researcher_worker_node(state: AgentState) -> dict:
|
||||
"""Wrapper node for the Researcher subgraph with state mapping."""
|
||||
worker_graph = create_researcher_worker()
|
||||
worker_input = prepare_researcher_input(state)
|
||||
worker_result = worker_graph.invoke(worker_input)
|
||||
worker_result = _RESEARCHER_WORKER.invoke(worker_input)
|
||||
return merge_researcher_output(worker_result)
|
||||
|
||||
def main_router(state: AgentState) -> str:
|
||||
@@ -32,6 +33,8 @@ def main_router(state: AgentState) -> str:
|
||||
next_action = state.get("next_action")
|
||||
if next_action == "clarify":
|
||||
return "clarification"
|
||||
# Even if QA suggests 'research', we now go through 'planner' for orchestration
|
||||
# aligning with the new hierarchical architecture.
|
||||
return "planner"
|
||||
|
||||
def delegation_router(state: AgentState) -> str:
|
||||
|
||||
@@ -17,8 +17,11 @@ class VFSHelper:
|
||||
def read(self, filename: str) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]:
|
||||
"""Read a file and its metadata from the VFS. Returns (None, None) if not found."""
|
||||
file_data = self._vfs.get(filename)
|
||||
if file_data:
|
||||
return file_data["content"], file_data["metadata"]
|
||||
if file_data is not None:
|
||||
# Handle raw values (backwards compatibility or inconsistent schema)
|
||||
if not isinstance(file_data, dict) or "content" not in file_data:
|
||||
return file_data, {}
|
||||
return file_data["content"], file_data.get("metadata", {})
|
||||
return None, None
|
||||
|
||||
def list(self) -> List[str]:
|
||||
|
||||
@@ -123,3 +123,20 @@ def test_get_me_success(client):
|
||||
assert response.status_code == 200
|
||||
assert response.json()["email"] == "test@example.com"
|
||||
assert response.json()["id"] == "123"
|
||||
|
||||
def test_get_me_rejects_refresh_token(client):
|
||||
"""Test that /auth/me rejects refresh tokens for authentication."""
|
||||
from ea_chatbot.api.utils import create_refresh_token
|
||||
token = create_refresh_token(data={"sub": "123"})
|
||||
|
||||
with patch("ea_chatbot.api.dependencies.history_manager") as mock_hm:
|
||||
# Even if the user exists, the dependency should reject the token type
|
||||
mock_hm.get_user_by_id.return_value = User(id="123", username="test@example.com")
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/auth/me",
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert "Cannot use refresh token" in response.json()["detail"]
|
||||
|
||||
42
backend/tests/test_review_fix_integration.py
Normal file
42
backend/tests/test_review_fix_integration.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import pytest
|
||||
from ea_chatbot.graph.workflow import create_workflow, data_analyst_worker_node
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
|
||||
def test_worker_merge_sets_summary_for_reflector():
|
||||
"""Verify that worker node (wrapper) sets the 'summary' field for the Reflector."""
|
||||
|
||||
state = AgentState(
|
||||
messages=[HumanMessage(content="test")],
|
||||
question="test",
|
||||
checklist=[{"task": "Analyze data", "worker": "data_analyst"}],
|
||||
current_step=0,
|
||||
iterations=0,
|
||||
vfs={},
|
||||
plots=[],
|
||||
dfs={},
|
||||
next_action="",
|
||||
analysis={},
|
||||
summary="Initial Planner Summary" # Stale summary
|
||||
)
|
||||
|
||||
# Mock the compiled worker subgraph to return a specific result
|
||||
with patch("ea_chatbot.graph.workflow._DATA_ANALYST_WORKER") as mock_worker:
|
||||
mock_worker.invoke.return_value = {
|
||||
"result": "Actual Worker Findings",
|
||||
"messages": [AIMessage(content="Internal")],
|
||||
"vfs_state": {},
|
||||
"plots": []
|
||||
}
|
||||
|
||||
# Execute the wrapper node
|
||||
updates = data_analyst_worker_node(state)
|
||||
|
||||
# Verify that 'summary' is in updates and has the worker result
|
||||
assert "summary" in updates
|
||||
assert updates["summary"] == "Actual Worker Findings"
|
||||
|
||||
# When applied to state, it should overwrite the stale summary
|
||||
state.update(updates)
|
||||
assert state["summary"] == "Actual Worker Findings"
|
||||
33
backend/tests/test_review_fix_reflector.py
Normal file
33
backend/tests/test_review_fix_reflector.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.graph.nodes.reflector import reflector_node
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
def test_reflector_does_not_advance_on_failure():
|
||||
"""Verify that reflector does not increment current_step if not satisfied."""
|
||||
state = AgentState(
|
||||
checklist=[{"task": "Task 1", "worker": "data_analyst"}],
|
||||
current_step=0,
|
||||
messages=[],
|
||||
question="test",
|
||||
analysis={},
|
||||
next_action="",
|
||||
iterations=0,
|
||||
vfs={},
|
||||
plots=[],
|
||||
dfs={},
|
||||
summary="Failed output"
|
||||
)
|
||||
|
||||
with patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_get_llm:
|
||||
mock_llm = MagicMock()
|
||||
# Mark as NOT satisfied
|
||||
mock_llm.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=False, reasoning="Incomplete.")
|
||||
mock_get_llm.return_value = mock_llm
|
||||
|
||||
result = reflector_node(state)
|
||||
|
||||
# Should NOT increment
|
||||
assert result["current_step"] == 0
|
||||
# Should probably route to planner or retry
|
||||
assert result["next_action"] == "delegate" # Or 'planner' if we want re-planning
|
||||
51
backend/tests/test_review_fix_summary.py
Normal file
51
backend/tests/test_review_fix_summary.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import pytest
|
||||
from ea_chatbot.graph.workers.data_analyst.mapping import merge_worker_output as merge_analyst
|
||||
from ea_chatbot.graph.workers.researcher.mapping import merge_researcher_output as merge_researcher
|
||||
from ea_chatbot.graph.workers.data_analyst.state import WorkerState as AnalystState
|
||||
from ea_chatbot.graph.workers.researcher.state import WorkerState as ResearcherState
|
||||
|
||||
def test_analyst_merge_updates_summary():
|
||||
"""Verify that analyst merge updates the global summary."""
|
||||
worker_state = AnalystState(
|
||||
result="Actual Worker Result",
|
||||
messages=[],
|
||||
task="test",
|
||||
iterations=1,
|
||||
vfs_state={},
|
||||
plots=[]
|
||||
)
|
||||
|
||||
updates = merge_analyst(worker_state)
|
||||
assert "summary" in updates
|
||||
assert updates["summary"] == "Actual Worker Result"
|
||||
|
||||
def test_researcher_merge_updates_summary():
|
||||
"""Verify that researcher merge updates the global summary."""
|
||||
worker_state = ResearcherState(
|
||||
result="Actual Research Result",
|
||||
messages=[],
|
||||
task="test",
|
||||
iterations=1,
|
||||
queries=[],
|
||||
raw_results=[]
|
||||
)
|
||||
|
||||
updates = merge_researcher(worker_state)
|
||||
assert "summary" in updates
|
||||
assert updates["summary"] == "Actual Research Result"
|
||||
|
||||
def test_merge_handles_none_result():
|
||||
"""Verify that merge functions handle None results gracefully."""
|
||||
worker_state = AnalystState(
|
||||
result=None,
|
||||
messages=[],
|
||||
task="test",
|
||||
iterations=1,
|
||||
vfs_state={},
|
||||
plots=[]
|
||||
)
|
||||
|
||||
updates = merge_analyst(worker_state)
|
||||
assert updates["summary"] is not None
|
||||
assert isinstance(updates["messages"][0].content, str)
|
||||
assert len(updates["messages"][0].content) > 0
|
||||
50
backend/tests/test_review_fix_vfs.py
Normal file
50
backend/tests/test_review_fix_vfs.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import pytest
|
||||
from ea_chatbot.utils.vfs import VFSHelper
|
||||
from ea_chatbot.graph.workers.data_analyst.mapping import prepare_worker_input
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
def test_vfs_isolation_deep_copy():
|
||||
"""Verify that worker VFS is deep-copied from global state."""
|
||||
global_vfs = {
|
||||
"data.txt": {
|
||||
"content": "original",
|
||||
"metadata": {"tags": ["a"]}
|
||||
}
|
||||
}
|
||||
state = AgentState(
|
||||
checklist=[{"task": "test", "worker": "data_analyst"}],
|
||||
current_step=0,
|
||||
messages=[],
|
||||
question="test",
|
||||
analysis={},
|
||||
next_action="",
|
||||
iterations=0,
|
||||
vfs=global_vfs,
|
||||
plots=[],
|
||||
dfs={}
|
||||
)
|
||||
|
||||
worker_input = prepare_worker_input(state)
|
||||
worker_vfs = worker_input["vfs_state"]
|
||||
|
||||
# Mutate worker VFS nested object
|
||||
worker_vfs["data.txt"]["metadata"]["tags"].append("b")
|
||||
|
||||
# Global VFS should remain unchanged
|
||||
assert global_vfs["data.txt"]["metadata"]["tags"] == ["a"]
|
||||
|
||||
def test_vfs_schema_normalization():
|
||||
"""Verify that VFSHelper handles inconsistent VFS entries."""
|
||||
vfs = {
|
||||
"raw.txt": "just a string", # Inconsistent with standard Dict[str, Any] entry
|
||||
"valid.txt": {"content": "data", "metadata": {}}
|
||||
}
|
||||
helper = VFSHelper(vfs)
|
||||
|
||||
# Should not crash during read
|
||||
content, metadata = helper.read("raw.txt")
|
||||
assert content == "just a string"
|
||||
assert metadata == {}
|
||||
|
||||
# Should not crash during list/other ops if they assume dict
|
||||
assert "raw.txt" in helper.list()
|
||||
@@ -12,14 +12,16 @@ def mock_llms():
|
||||
patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model") as mock_coder, \
|
||||
patch("ea_chatbot.graph.workers.data_analyst.nodes.summarizer.get_llm_model") as mock_worker_summarizer, \
|
||||
patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model") as mock_synthesizer, \
|
||||
patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_researcher:
|
||||
patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_researcher, \
|
||||
patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_reflector:
|
||||
yield {
|
||||
"qa": mock_qa,
|
||||
"planner": mock_planner,
|
||||
"coder": mock_coder,
|
||||
"worker_summarizer": mock_worker_summarizer,
|
||||
"synthesizer": mock_synthesizer,
|
||||
"researcher": mock_researcher
|
||||
"researcher": mock_researcher,
|
||||
"reflector": mock_reflector
|
||||
}
|
||||
|
||||
def test_workflow_data_analysis_flow(mock_llms):
|
||||
@@ -58,7 +60,12 @@ def test_workflow_data_analysis_flow(mock_llms):
|
||||
mock_llms["worker_summarizer"].return_value = mock_ws_instance
|
||||
mock_ws_instance.invoke.return_value = AIMessage(content="Worker Summary")
|
||||
|
||||
# 5. Mock Synthesizer
|
||||
# 5. Mock Reflector
|
||||
mock_reflector_instance = MagicMock()
|
||||
mock_llms["reflector"].return_value = mock_reflector_instance
|
||||
mock_reflector_instance.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=True, reasoning="Good.")
|
||||
|
||||
# 6. Mock Synthesizer
|
||||
mock_syn_instance = MagicMock()
|
||||
mock_llms["synthesizer"].return_value = mock_syn_instance
|
||||
mock_syn_instance.invoke.return_value = AIMessage(content="Final Summary: Success")
|
||||
@@ -111,7 +118,12 @@ def test_workflow_research_flow(mock_llms):
|
||||
mock_llms["researcher"].return_value = mock_res_instance
|
||||
mock_res_instance.invoke.return_value = AIMessage(content="Research Result")
|
||||
|
||||
# 4. Mock Synthesizer
|
||||
# 4. Mock Reflector
|
||||
mock_reflector_instance = MagicMock()
|
||||
mock_llms["reflector"].return_value = mock_reflector_instance
|
||||
mock_reflector_instance.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=True, reasoning="Good.")
|
||||
|
||||
# 5. Mock Synthesizer
|
||||
mock_syn_instance = MagicMock()
|
||||
mock_llms["synthesizer"].return_value = mock_syn_instance
|
||||
mock_syn_instance.invoke.return_value = AIMessage(content="Final Research Summary")
|
||||
|
||||
Reference in New Issue
Block a user