fix(orchestrator): Apply refinements from code review
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
import io
|
import io
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
import copy
|
||||||
from contextlib import redirect_stdout
|
from contextlib import redirect_stdout
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@@ -39,7 +40,7 @@ def executor_node(state: AgentState) -> dict:
|
|||||||
db_client = DBClient(settings=db_settings)
|
db_client = DBClient(settings=db_settings)
|
||||||
|
|
||||||
# Initialize the Virtual File System (VFS) helper
|
# Initialize the Virtual File System (VFS) helper
|
||||||
vfs_state = dict(state.get("vfs", {}))
|
vfs_state = copy.deepcopy(state.get("vfs", {}))
|
||||||
vfs_helper = VFSHelper(vfs_state)
|
vfs_helper = VFSHelper(vfs_state)
|
||||||
|
|
||||||
# Initialize local variables for execution
|
# Initialize local variables for execution
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import yaml
|
|
||||||
from ea_chatbot.graph.state import AgentState
|
from ea_chatbot.graph.state import AgentState
|
||||||
from ea_chatbot.config import Settings
|
from ea_chatbot.config import Settings
|
||||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from typing import List, Literal
|
|
||||||
from ea_chatbot.graph.state import AgentState
|
from ea_chatbot.graph.state import AgentState
|
||||||
from ea_chatbot.config import Settings
|
from ea_chatbot.config import Settings
|
||||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||||
|
|||||||
@@ -48,16 +48,16 @@ If there were major errors or the output is missing critical data requested in t
|
|||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
logger.warning(f"[bold yellow]Sub-task NOT satisfied.[/bold yellow] Reason: {response.reasoning}")
|
logger.warning(f"[bold yellow]Sub-task NOT satisfied.[/bold yellow] Reason: {response.reasoning}")
|
||||||
# For now, we'll still advance to avoid infinite loops, but a more complex orchestrator
|
# Do NOT advance the step. This triggers a retry of the same task.
|
||||||
# would trigger a retry or adjustment.
|
# In a more advanced version, we might route to a 'planner' for revision.
|
||||||
return {
|
return {
|
||||||
"current_step": current_step + 1,
|
"current_step": current_step,
|
||||||
"next_action": "delegate"
|
"next_action": "delegate"
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to reflect: {str(e)}")
|
logger.error(f"Failed to reflect: {str(e)}")
|
||||||
# Fallback: advance anyway
|
# On error, do not advance to be safe
|
||||||
return {
|
return {
|
||||||
"current_step": current_step + 1,
|
"current_step": current_step,
|
||||||
"next_action": "delegate"
|
"next_action": "delegate"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,12 +12,12 @@ class AgentState(AS):
|
|||||||
|
|
||||||
# Query Analysis (Decomposition results)
|
# Query Analysis (Decomposition results)
|
||||||
analysis: Optional[Dict[str, Any]]
|
analysis: Optional[Dict[str, Any]]
|
||||||
# Expected keys: "requires_dataset", "expert", "data", "unknown", "condition"
|
# Expected keys: "data_required", "unknowns", "ambiguities", "conditions"
|
||||||
|
|
||||||
# Step-by-step reasoning
|
# Step-by-step reasoning (Legacy, use checklist for new flows)
|
||||||
plan: Optional[str]
|
plan: Optional[str]
|
||||||
|
|
||||||
# Code execution context
|
# Code execution context (Legacy, use workers for new flows)
|
||||||
code: Optional[str]
|
code: Optional[str]
|
||||||
code_output: Optional[str]
|
code_output: Optional[str]
|
||||||
error: Optional[str]
|
error: Optional[str]
|
||||||
@@ -26,10 +26,10 @@ class AgentState(AS):
|
|||||||
plots: Annotated[List[Any], operator.add] # Matplotlib figures
|
plots: Annotated[List[Any], operator.add] # Matplotlib figures
|
||||||
dfs: Dict[str, Any] # Pandas DataFrames
|
dfs: Dict[str, Any] # Pandas DataFrames
|
||||||
|
|
||||||
# Conversation summary
|
# Conversation summary / Latest worker result
|
||||||
summary: Optional[str]
|
summary: Optional[str]
|
||||||
|
|
||||||
# Routing hint: "clarify", "plan", "research", "end"
|
# Routing hint: "clarify", "plan", "end"
|
||||||
next_action: str
|
next_action: str
|
||||||
|
|
||||||
# Number of execution attempts
|
# Number of execution attempts
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from typing import Dict, Any, List
|
from typing import Dict, Any, List
|
||||||
|
import copy
|
||||||
from langchain_core.messages import HumanMessage, AIMessage
|
from langchain_core.messages import HumanMessage, AIMessage
|
||||||
from ea_chatbot.graph.state import AgentState
|
from ea_chatbot.graph.state import AgentState
|
||||||
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
||||||
@@ -16,7 +17,7 @@ def prepare_worker_input(state: AgentState) -> Dict[str, Any]:
|
|||||||
return {
|
return {
|
||||||
"task": task_desc,
|
"task": task_desc,
|
||||||
"messages": [HumanMessage(content=task_desc)], # Start worker loop with the task
|
"messages": [HumanMessage(content=task_desc)], # Start worker loop with the task
|
||||||
"vfs_state": dict(state.get("vfs", {})),
|
"vfs_state": copy.deepcopy(state.get("vfs", {})),
|
||||||
"iterations": 0,
|
"iterations": 0,
|
||||||
"plots": [],
|
"plots": [],
|
||||||
"code": None,
|
"code": None,
|
||||||
@@ -27,11 +28,12 @@ def prepare_worker_input(state: AgentState) -> Dict[str, Any]:
|
|||||||
|
|
||||||
def merge_worker_output(worker_state: WorkerState) -> Dict[str, Any]:
|
def merge_worker_output(worker_state: WorkerState) -> Dict[str, Any]:
|
||||||
"""Map worker results back to the global AgentState."""
|
"""Map worker results back to the global AgentState."""
|
||||||
result = worker_state.get("result", "Analysis complete.")
|
result = worker_state.get("result") or "No result produced by data analyst worker."
|
||||||
|
|
||||||
# We bubble up the summary as an AI message
|
# We bubble up the summary as an AI message and update the global summary field
|
||||||
updates = {
|
updates = {
|
||||||
"messages": [AIMessage(content=result)],
|
"messages": [AIMessage(content=result)],
|
||||||
|
"summary": result,
|
||||||
"vfs": worker_state.get("vfs_state", {}),
|
"vfs": worker_state.get("vfs_state", {}),
|
||||||
"plots": worker_state.get("plots", [])
|
"plots": worker_state.get("plots", [])
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,8 +33,11 @@ def coder_node(state: WorkerState) -> dict:
|
|||||||
vfs_summary = "Available in-memory files (VFS):\n"
|
vfs_summary = "Available in-memory files (VFS):\n"
|
||||||
if vfs_state:
|
if vfs_state:
|
||||||
for filename, data in vfs_state.items():
|
for filename, data in vfs_state.items():
|
||||||
meta = data.get("metadata", {})
|
if isinstance(data, dict):
|
||||||
vfs_summary += f"- {filename} ({meta.get('type', 'unknown')})\n"
|
meta = data.get("metadata", {})
|
||||||
|
vfs_summary += f"- {filename} ({meta.get('type', 'unknown')})\n"
|
||||||
|
else:
|
||||||
|
vfs_summary += f"- {filename} (raw data)\n"
|
||||||
else:
|
else:
|
||||||
vfs_summary += "- None"
|
vfs_summary += "- None"
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import io
|
import io
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
import copy
|
||||||
from contextlib import redirect_stdout
|
from contextlib import redirect_stdout
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@@ -39,7 +40,7 @@ def executor_node(state: WorkerState) -> dict:
|
|||||||
db_client = DBClient(settings=db_settings)
|
db_client = DBClient(settings=db_settings)
|
||||||
|
|
||||||
# Initialize the Virtual File System (VFS) helper with the snapshot from state
|
# Initialize the Virtual File System (VFS) helper with the snapshot from state
|
||||||
vfs_state = dict(state.get("vfs_state", {}))
|
vfs_state = copy.deepcopy(state.get("vfs_state", {}))
|
||||||
vfs_helper = VFSHelper(vfs_state)
|
vfs_helper = VFSHelper(vfs_state)
|
||||||
|
|
||||||
# Initialize local variables for execution
|
# Initialize local variables for execution
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from langgraph.graph import StateGraph, END
|
from langgraph.graph import StateGraph, END
|
||||||
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
||||||
from ea_chatbot.graph.workers.data_analyst.nodes.coder import coder_node
|
from ea_chatbot.graph.workers.data_analyst.nodes.coder import coder_node
|
||||||
from ea_chatbot.graph.workers.data_analyst.nodes.executor import executor_node
|
from ea_chatbot.graph.workers.data_analyst.nodes.executor import executor_node
|
||||||
@@ -20,7 +21,7 @@ def create_data_analyst_worker(
|
|||||||
coder=coder_node,
|
coder=coder_node,
|
||||||
executor=executor_node,
|
executor=executor_node,
|
||||||
summarizer=summarizer_node
|
summarizer=summarizer_node
|
||||||
) -> StateGraph:
|
) -> CompiledStateGraph:
|
||||||
"""Create the Data Analyst worker subgraph."""
|
"""Create the Data Analyst worker subgraph."""
|
||||||
workflow = StateGraph(WorkerState)
|
workflow = StateGraph(WorkerState)
|
||||||
|
|
||||||
|
|||||||
@@ -23,9 +23,10 @@ def prepare_researcher_input(state: AgentState) -> Dict[str, Any]:
|
|||||||
|
|
||||||
def merge_researcher_output(worker_state: WorkerState) -> Dict[str, Any]:
|
def merge_researcher_output(worker_state: WorkerState) -> Dict[str, Any]:
|
||||||
"""Map researcher results back to the global AgentState."""
|
"""Map researcher results back to the global AgentState."""
|
||||||
result = worker_state.get("result", "Research complete.")
|
result = worker_state.get("result") or "No result produced by researcher worker."
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"messages": [AIMessage(content=result)],
|
"messages": [AIMessage(content=result)],
|
||||||
|
"summary": result,
|
||||||
# Researcher doesn't usually update VFS or Plots, but we keep the structure
|
# Researcher doesn't usually update VFS or Plots, but we keep the structure
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from langgraph.graph import StateGraph, END
|
from langgraph.graph import StateGraph, END
|
||||||
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
from ea_chatbot.graph.workers.researcher.state import WorkerState
|
from ea_chatbot.graph.workers.researcher.state import WorkerState
|
||||||
from ea_chatbot.graph.workers.researcher.nodes.searcher import searcher_node
|
from ea_chatbot.graph.workers.researcher.nodes.searcher import searcher_node
|
||||||
from ea_chatbot.graph.workers.researcher.nodes.summarizer import summarizer_node
|
from ea_chatbot.graph.workers.researcher.nodes.summarizer import summarizer_node
|
||||||
@@ -6,7 +7,7 @@ from ea_chatbot.graph.workers.researcher.nodes.summarizer import summarizer_node
|
|||||||
def create_researcher_worker(
|
def create_researcher_worker(
|
||||||
searcher=searcher_node,
|
searcher=searcher_node,
|
||||||
summarizer=summarizer_node
|
summarizer=summarizer_node
|
||||||
) -> StateGraph:
|
) -> CompiledStateGraph:
|
||||||
"""Create the Researcher worker subgraph."""
|
"""Create the Researcher worker subgraph."""
|
||||||
workflow = StateGraph(WorkerState)
|
workflow = StateGraph(WorkerState)
|
||||||
|
|
||||||
|
|||||||
@@ -9,22 +9,23 @@ from ea_chatbot.graph.workers.data_analyst.workflow import create_data_analyst_w
|
|||||||
from ea_chatbot.graph.workers.data_analyst.mapping import prepare_worker_input, merge_worker_output
|
from ea_chatbot.graph.workers.data_analyst.mapping import prepare_worker_input, merge_worker_output
|
||||||
from ea_chatbot.graph.workers.researcher.workflow import create_researcher_worker
|
from ea_chatbot.graph.workers.researcher.workflow import create_researcher_worker
|
||||||
from ea_chatbot.graph.workers.researcher.mapping import prepare_researcher_input, merge_researcher_output
|
from ea_chatbot.graph.workers.researcher.mapping import prepare_researcher_input, merge_researcher_output
|
||||||
from ea_chatbot.graph.nodes.researcher import researcher_node
|
|
||||||
from ea_chatbot.graph.nodes.clarification import clarification_node
|
from ea_chatbot.graph.nodes.clarification import clarification_node
|
||||||
from ea_chatbot.graph.nodes.summarize_conversation import summarize_conversation_node
|
from ea_chatbot.graph.nodes.summarize_conversation import summarize_conversation_node
|
||||||
|
|
||||||
|
# Global cache for compiled subgraphs
|
||||||
|
_DATA_ANALYST_WORKER = create_data_analyst_worker()
|
||||||
|
_RESEARCHER_WORKER = create_researcher_worker()
|
||||||
|
|
||||||
def data_analyst_worker_node(state: AgentState) -> dict:
|
def data_analyst_worker_node(state: AgentState) -> dict:
|
||||||
"""Wrapper node for the Data Analyst subgraph with state mapping."""
|
"""Wrapper node for the Data Analyst subgraph with state mapping."""
|
||||||
worker_graph = create_data_analyst_worker()
|
|
||||||
worker_input = prepare_worker_input(state)
|
worker_input = prepare_worker_input(state)
|
||||||
worker_result = worker_graph.invoke(worker_input)
|
worker_result = _DATA_ANALYST_WORKER.invoke(worker_input)
|
||||||
return merge_worker_output(worker_result)
|
return merge_worker_output(worker_result)
|
||||||
|
|
||||||
def researcher_worker_node(state: AgentState) -> dict:
|
def researcher_worker_node(state: AgentState) -> dict:
|
||||||
"""Wrapper node for the Researcher subgraph with state mapping."""
|
"""Wrapper node for the Researcher subgraph with state mapping."""
|
||||||
worker_graph = create_researcher_worker()
|
|
||||||
worker_input = prepare_researcher_input(state)
|
worker_input = prepare_researcher_input(state)
|
||||||
worker_result = worker_graph.invoke(worker_input)
|
worker_result = _RESEARCHER_WORKER.invoke(worker_input)
|
||||||
return merge_researcher_output(worker_result)
|
return merge_researcher_output(worker_result)
|
||||||
|
|
||||||
def main_router(state: AgentState) -> str:
|
def main_router(state: AgentState) -> str:
|
||||||
@@ -32,6 +33,8 @@ def main_router(state: AgentState) -> str:
|
|||||||
next_action = state.get("next_action")
|
next_action = state.get("next_action")
|
||||||
if next_action == "clarify":
|
if next_action == "clarify":
|
||||||
return "clarification"
|
return "clarification"
|
||||||
|
# Even if QA suggests 'research', we now go through 'planner' for orchestration
|
||||||
|
# aligning with the new hierarchical architecture.
|
||||||
return "planner"
|
return "planner"
|
||||||
|
|
||||||
def delegation_router(state: AgentState) -> str:
|
def delegation_router(state: AgentState) -> str:
|
||||||
|
|||||||
@@ -17,8 +17,11 @@ class VFSHelper:
|
|||||||
def read(self, filename: str) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]:
|
def read(self, filename: str) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]:
|
||||||
"""Read a file and its metadata from the VFS. Returns (None, None) if not found."""
|
"""Read a file and its metadata from the VFS. Returns (None, None) if not found."""
|
||||||
file_data = self._vfs.get(filename)
|
file_data = self._vfs.get(filename)
|
||||||
if file_data:
|
if file_data is not None:
|
||||||
return file_data["content"], file_data["metadata"]
|
# Handle raw values (backwards compatibility or inconsistent schema)
|
||||||
|
if not isinstance(file_data, dict) or "content" not in file_data:
|
||||||
|
return file_data, {}
|
||||||
|
return file_data["content"], file_data.get("metadata", {})
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
def list(self) -> List[str]:
|
def list(self) -> List[str]:
|
||||||
|
|||||||
@@ -123,3 +123,20 @@ def test_get_me_success(client):
|
|||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["email"] == "test@example.com"
|
assert response.json()["email"] == "test@example.com"
|
||||||
assert response.json()["id"] == "123"
|
assert response.json()["id"] == "123"
|
||||||
|
|
||||||
|
def test_get_me_rejects_refresh_token(client):
|
||||||
|
"""Test that /auth/me rejects refresh tokens for authentication."""
|
||||||
|
from ea_chatbot.api.utils import create_refresh_token
|
||||||
|
token = create_refresh_token(data={"sub": "123"})
|
||||||
|
|
||||||
|
with patch("ea_chatbot.api.dependencies.history_manager") as mock_hm:
|
||||||
|
# Even if the user exists, the dependency should reject the token type
|
||||||
|
mock_hm.get_user_by_id.return_value = User(id="123", username="test@example.com")
|
||||||
|
|
||||||
|
response = client.get(
|
||||||
|
"/api/v1/auth/me",
|
||||||
|
headers={"Authorization": f"Bearer {token}"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Cannot use refresh token" in response.json()["detail"]
|
||||||
|
|||||||
42
backend/tests/test_review_fix_integration.py
Normal file
42
backend/tests/test_review_fix_integration.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
import pytest
|
||||||
|
from ea_chatbot.graph.workflow import create_workflow, data_analyst_worker_node
|
||||||
|
from ea_chatbot.graph.state import AgentState
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from langchain_core.messages import HumanMessage, AIMessage
|
||||||
|
|
||||||
|
def test_worker_merge_sets_summary_for_reflector():
|
||||||
|
"""Verify that worker node (wrapper) sets the 'summary' field for the Reflector."""
|
||||||
|
|
||||||
|
state = AgentState(
|
||||||
|
messages=[HumanMessage(content="test")],
|
||||||
|
question="test",
|
||||||
|
checklist=[{"task": "Analyze data", "worker": "data_analyst"}],
|
||||||
|
current_step=0,
|
||||||
|
iterations=0,
|
||||||
|
vfs={},
|
||||||
|
plots=[],
|
||||||
|
dfs={},
|
||||||
|
next_action="",
|
||||||
|
analysis={},
|
||||||
|
summary="Initial Planner Summary" # Stale summary
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock the compiled worker subgraph to return a specific result
|
||||||
|
with patch("ea_chatbot.graph.workflow._DATA_ANALYST_WORKER") as mock_worker:
|
||||||
|
mock_worker.invoke.return_value = {
|
||||||
|
"result": "Actual Worker Findings",
|
||||||
|
"messages": [AIMessage(content="Internal")],
|
||||||
|
"vfs_state": {},
|
||||||
|
"plots": []
|
||||||
|
}
|
||||||
|
|
||||||
|
# Execute the wrapper node
|
||||||
|
updates = data_analyst_worker_node(state)
|
||||||
|
|
||||||
|
# Verify that 'summary' is in updates and has the worker result
|
||||||
|
assert "summary" in updates
|
||||||
|
assert updates["summary"] == "Actual Worker Findings"
|
||||||
|
|
||||||
|
# When applied to state, it should overwrite the stale summary
|
||||||
|
state.update(updates)
|
||||||
|
assert state["summary"] == "Actual Worker Findings"
|
||||||
33
backend/tests/test_review_fix_reflector.py
Normal file
33
backend/tests/test_review_fix_reflector.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from ea_chatbot.graph.nodes.reflector import reflector_node
|
||||||
|
from ea_chatbot.graph.state import AgentState
|
||||||
|
|
||||||
|
def test_reflector_does_not_advance_on_failure():
|
||||||
|
"""Verify that reflector does not increment current_step if not satisfied."""
|
||||||
|
state = AgentState(
|
||||||
|
checklist=[{"task": "Task 1", "worker": "data_analyst"}],
|
||||||
|
current_step=0,
|
||||||
|
messages=[],
|
||||||
|
question="test",
|
||||||
|
analysis={},
|
||||||
|
next_action="",
|
||||||
|
iterations=0,
|
||||||
|
vfs={},
|
||||||
|
plots=[],
|
||||||
|
dfs={},
|
||||||
|
summary="Failed output"
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_get_llm:
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
# Mark as NOT satisfied
|
||||||
|
mock_llm.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=False, reasoning="Incomplete.")
|
||||||
|
mock_get_llm.return_value = mock_llm
|
||||||
|
|
||||||
|
result = reflector_node(state)
|
||||||
|
|
||||||
|
# Should NOT increment
|
||||||
|
assert result["current_step"] == 0
|
||||||
|
# Should probably route to planner or retry
|
||||||
|
assert result["next_action"] == "delegate" # Or 'planner' if we want re-planning
|
||||||
51
backend/tests/test_review_fix_summary.py
Normal file
51
backend/tests/test_review_fix_summary.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
import pytest
|
||||||
|
from ea_chatbot.graph.workers.data_analyst.mapping import merge_worker_output as merge_analyst
|
||||||
|
from ea_chatbot.graph.workers.researcher.mapping import merge_researcher_output as merge_researcher
|
||||||
|
from ea_chatbot.graph.workers.data_analyst.state import WorkerState as AnalystState
|
||||||
|
from ea_chatbot.graph.workers.researcher.state import WorkerState as ResearcherState
|
||||||
|
|
||||||
|
def test_analyst_merge_updates_summary():
|
||||||
|
"""Verify that analyst merge updates the global summary."""
|
||||||
|
worker_state = AnalystState(
|
||||||
|
result="Actual Worker Result",
|
||||||
|
messages=[],
|
||||||
|
task="test",
|
||||||
|
iterations=1,
|
||||||
|
vfs_state={},
|
||||||
|
plots=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
updates = merge_analyst(worker_state)
|
||||||
|
assert "summary" in updates
|
||||||
|
assert updates["summary"] == "Actual Worker Result"
|
||||||
|
|
||||||
|
def test_researcher_merge_updates_summary():
|
||||||
|
"""Verify that researcher merge updates the global summary."""
|
||||||
|
worker_state = ResearcherState(
|
||||||
|
result="Actual Research Result",
|
||||||
|
messages=[],
|
||||||
|
task="test",
|
||||||
|
iterations=1,
|
||||||
|
queries=[],
|
||||||
|
raw_results=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
updates = merge_researcher(worker_state)
|
||||||
|
assert "summary" in updates
|
||||||
|
assert updates["summary"] == "Actual Research Result"
|
||||||
|
|
||||||
|
def test_merge_handles_none_result():
|
||||||
|
"""Verify that merge functions handle None results gracefully."""
|
||||||
|
worker_state = AnalystState(
|
||||||
|
result=None,
|
||||||
|
messages=[],
|
||||||
|
task="test",
|
||||||
|
iterations=1,
|
||||||
|
vfs_state={},
|
||||||
|
plots=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
updates = merge_analyst(worker_state)
|
||||||
|
assert updates["summary"] is not None
|
||||||
|
assert isinstance(updates["messages"][0].content, str)
|
||||||
|
assert len(updates["messages"][0].content) > 0
|
||||||
50
backend/tests/test_review_fix_vfs.py
Normal file
50
backend/tests/test_review_fix_vfs.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
import pytest
|
||||||
|
from ea_chatbot.utils.vfs import VFSHelper
|
||||||
|
from ea_chatbot.graph.workers.data_analyst.mapping import prepare_worker_input
|
||||||
|
from ea_chatbot.graph.state import AgentState
|
||||||
|
|
||||||
|
def test_vfs_isolation_deep_copy():
|
||||||
|
"""Verify that worker VFS is deep-copied from global state."""
|
||||||
|
global_vfs = {
|
||||||
|
"data.txt": {
|
||||||
|
"content": "original",
|
||||||
|
"metadata": {"tags": ["a"]}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state = AgentState(
|
||||||
|
checklist=[{"task": "test", "worker": "data_analyst"}],
|
||||||
|
current_step=0,
|
||||||
|
messages=[],
|
||||||
|
question="test",
|
||||||
|
analysis={},
|
||||||
|
next_action="",
|
||||||
|
iterations=0,
|
||||||
|
vfs=global_vfs,
|
||||||
|
plots=[],
|
||||||
|
dfs={}
|
||||||
|
)
|
||||||
|
|
||||||
|
worker_input = prepare_worker_input(state)
|
||||||
|
worker_vfs = worker_input["vfs_state"]
|
||||||
|
|
||||||
|
# Mutate worker VFS nested object
|
||||||
|
worker_vfs["data.txt"]["metadata"]["tags"].append("b")
|
||||||
|
|
||||||
|
# Global VFS should remain unchanged
|
||||||
|
assert global_vfs["data.txt"]["metadata"]["tags"] == ["a"]
|
||||||
|
|
||||||
|
def test_vfs_schema_normalization():
|
||||||
|
"""Verify that VFSHelper handles inconsistent VFS entries."""
|
||||||
|
vfs = {
|
||||||
|
"raw.txt": "just a string", # Inconsistent with standard Dict[str, Any] entry
|
||||||
|
"valid.txt": {"content": "data", "metadata": {}}
|
||||||
|
}
|
||||||
|
helper = VFSHelper(vfs)
|
||||||
|
|
||||||
|
# Should not crash during read
|
||||||
|
content, metadata = helper.read("raw.txt")
|
||||||
|
assert content == "just a string"
|
||||||
|
assert metadata == {}
|
||||||
|
|
||||||
|
# Should not crash during list/other ops if they assume dict
|
||||||
|
assert "raw.txt" in helper.list()
|
||||||
@@ -12,14 +12,16 @@ def mock_llms():
|
|||||||
patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model") as mock_coder, \
|
patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model") as mock_coder, \
|
||||||
patch("ea_chatbot.graph.workers.data_analyst.nodes.summarizer.get_llm_model") as mock_worker_summarizer, \
|
patch("ea_chatbot.graph.workers.data_analyst.nodes.summarizer.get_llm_model") as mock_worker_summarizer, \
|
||||||
patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model") as mock_synthesizer, \
|
patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model") as mock_synthesizer, \
|
||||||
patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_researcher:
|
patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_researcher, \
|
||||||
|
patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_reflector:
|
||||||
yield {
|
yield {
|
||||||
"qa": mock_qa,
|
"qa": mock_qa,
|
||||||
"planner": mock_planner,
|
"planner": mock_planner,
|
||||||
"coder": mock_coder,
|
"coder": mock_coder,
|
||||||
"worker_summarizer": mock_worker_summarizer,
|
"worker_summarizer": mock_worker_summarizer,
|
||||||
"synthesizer": mock_synthesizer,
|
"synthesizer": mock_synthesizer,
|
||||||
"researcher": mock_researcher
|
"researcher": mock_researcher,
|
||||||
|
"reflector": mock_reflector
|
||||||
}
|
}
|
||||||
|
|
||||||
def test_workflow_data_analysis_flow(mock_llms):
|
def test_workflow_data_analysis_flow(mock_llms):
|
||||||
@@ -58,7 +60,12 @@ def test_workflow_data_analysis_flow(mock_llms):
|
|||||||
mock_llms["worker_summarizer"].return_value = mock_ws_instance
|
mock_llms["worker_summarizer"].return_value = mock_ws_instance
|
||||||
mock_ws_instance.invoke.return_value = AIMessage(content="Worker Summary")
|
mock_ws_instance.invoke.return_value = AIMessage(content="Worker Summary")
|
||||||
|
|
||||||
# 5. Mock Synthesizer
|
# 5. Mock Reflector
|
||||||
|
mock_reflector_instance = MagicMock()
|
||||||
|
mock_llms["reflector"].return_value = mock_reflector_instance
|
||||||
|
mock_reflector_instance.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=True, reasoning="Good.")
|
||||||
|
|
||||||
|
# 6. Mock Synthesizer
|
||||||
mock_syn_instance = MagicMock()
|
mock_syn_instance = MagicMock()
|
||||||
mock_llms["synthesizer"].return_value = mock_syn_instance
|
mock_llms["synthesizer"].return_value = mock_syn_instance
|
||||||
mock_syn_instance.invoke.return_value = AIMessage(content="Final Summary: Success")
|
mock_syn_instance.invoke.return_value = AIMessage(content="Final Summary: Success")
|
||||||
@@ -111,7 +118,12 @@ def test_workflow_research_flow(mock_llms):
|
|||||||
mock_llms["researcher"].return_value = mock_res_instance
|
mock_llms["researcher"].return_value = mock_res_instance
|
||||||
mock_res_instance.invoke.return_value = AIMessage(content="Research Result")
|
mock_res_instance.invoke.return_value = AIMessage(content="Research Result")
|
||||||
|
|
||||||
# 4. Mock Synthesizer
|
# 4. Mock Reflector
|
||||||
|
mock_reflector_instance = MagicMock()
|
||||||
|
mock_llms["reflector"].return_value = mock_reflector_instance
|
||||||
|
mock_reflector_instance.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=True, reasoning="Good.")
|
||||||
|
|
||||||
|
# 5. Mock Synthesizer
|
||||||
mock_syn_instance = MagicMock()
|
mock_syn_instance = MagicMock()
|
||||||
mock_llms["synthesizer"].return_value = mock_syn_instance
|
mock_llms["synthesizer"].return_value = mock_syn_instance
|
||||||
mock_syn_instance.invoke.return_value = AIMessage(content="Final Research Summary")
|
mock_syn_instance.invoke.return_value = AIMessage(content="Final Research Summary")
|
||||||
|
|||||||
Reference in New Issue
Block a user