fix(orchestrator): Apply refinements from code review
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import io
|
||||
import sys
|
||||
import traceback
|
||||
import copy
|
||||
from contextlib import redirect_stdout
|
||||
from typing import TYPE_CHECKING
|
||||
import pandas as pd
|
||||
@@ -39,7 +40,7 @@ def executor_node(state: AgentState) -> dict:
|
||||
db_client = DBClient(settings=db_settings)
|
||||
|
||||
# Initialize the Virtual File System (VFS) helper
|
||||
vfs_state = dict(state.get("vfs", {}))
|
||||
vfs_state = copy.deepcopy(state.get("vfs", {}))
|
||||
vfs_helper = VFSHelper(vfs_state)
|
||||
|
||||
# Initialize local variables for execution
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import yaml
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from typing import List, Literal
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
|
||||
@@ -48,16 +48,16 @@ If there were major errors or the output is missing critical data requested in t
|
||||
}
|
||||
else:
|
||||
logger.warning(f"[bold yellow]Sub-task NOT satisfied.[/bold yellow] Reason: {response.reasoning}")
|
||||
# For now, we'll still advance to avoid infinite loops, but a more complex orchestrator
|
||||
# would trigger a retry or adjustment.
|
||||
# Do NOT advance the step. This triggers a retry of the same task.
|
||||
# In a more advanced version, we might route to a 'planner' for revision.
|
||||
return {
|
||||
"current_step": current_step + 1,
|
||||
"current_step": current_step,
|
||||
"next_action": "delegate"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reflect: {str(e)}")
|
||||
# Fallback: advance anyway
|
||||
# On error, do not advance to be safe
|
||||
return {
|
||||
"current_step": current_step + 1,
|
||||
"current_step": current_step,
|
||||
"next_action": "delegate"
|
||||
}
|
||||
|
||||
@@ -12,12 +12,12 @@ class AgentState(AS):
|
||||
|
||||
# Query Analysis (Decomposition results)
|
||||
analysis: Optional[Dict[str, Any]]
|
||||
# Expected keys: "requires_dataset", "expert", "data", "unknown", "condition"
|
||||
# Expected keys: "data_required", "unknowns", "ambiguities", "conditions"
|
||||
|
||||
# Step-by-step reasoning
|
||||
# Step-by-step reasoning (Legacy, use checklist for new flows)
|
||||
plan: Optional[str]
|
||||
|
||||
# Code execution context
|
||||
# Code execution context (Legacy, use workers for new flows)
|
||||
code: Optional[str]
|
||||
code_output: Optional[str]
|
||||
error: Optional[str]
|
||||
@@ -26,10 +26,10 @@ class AgentState(AS):
|
||||
plots: Annotated[List[Any], operator.add] # Matplotlib figures
|
||||
dfs: Dict[str, Any] # Pandas DataFrames
|
||||
|
||||
# Conversation summary
|
||||
# Conversation summary / Latest worker result
|
||||
summary: Optional[str]
|
||||
|
||||
# Routing hint: "clarify", "plan", "research", "end"
|
||||
# Routing hint: "clarify", "plan", "end"
|
||||
next_action: str
|
||||
|
||||
# Number of execution attempts
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Dict, Any, List
|
||||
import copy
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
||||
@@ -16,7 +17,7 @@ def prepare_worker_input(state: AgentState) -> Dict[str, Any]:
|
||||
return {
|
||||
"task": task_desc,
|
||||
"messages": [HumanMessage(content=task_desc)], # Start worker loop with the task
|
||||
"vfs_state": dict(state.get("vfs", {})),
|
||||
"vfs_state": copy.deepcopy(state.get("vfs", {})),
|
||||
"iterations": 0,
|
||||
"plots": [],
|
||||
"code": None,
|
||||
@@ -27,11 +28,12 @@ def prepare_worker_input(state: AgentState) -> Dict[str, Any]:
|
||||
|
||||
def merge_worker_output(worker_state: WorkerState) -> Dict[str, Any]:
|
||||
"""Map worker results back to the global AgentState."""
|
||||
result = worker_state.get("result", "Analysis complete.")
|
||||
result = worker_state.get("result") or "No result produced by data analyst worker."
|
||||
|
||||
# We bubble up the summary as an AI message
|
||||
# We bubble up the summary as an AI message and update the global summary field
|
||||
updates = {
|
||||
"messages": [AIMessage(content=result)],
|
||||
"summary": result,
|
||||
"vfs": worker_state.get("vfs_state", {}),
|
||||
"plots": worker_state.get("plots", [])
|
||||
}
|
||||
|
||||
@@ -33,8 +33,11 @@ def coder_node(state: WorkerState) -> dict:
|
||||
vfs_summary = "Available in-memory files (VFS):\n"
|
||||
if vfs_state:
|
||||
for filename, data in vfs_state.items():
|
||||
meta = data.get("metadata", {})
|
||||
vfs_summary += f"- {filename} ({meta.get('type', 'unknown')})\n"
|
||||
if isinstance(data, dict):
|
||||
meta = data.get("metadata", {})
|
||||
vfs_summary += f"- {filename} ({meta.get('type', 'unknown')})\n"
|
||||
else:
|
||||
vfs_summary += f"- {filename} (raw data)\n"
|
||||
else:
|
||||
vfs_summary += "- None"
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import io
|
||||
import sys
|
||||
import traceback
|
||||
import copy
|
||||
from contextlib import redirect_stdout
|
||||
from typing import TYPE_CHECKING
|
||||
import pandas as pd
|
||||
@@ -39,7 +40,7 @@ def executor_node(state: WorkerState) -> dict:
|
||||
db_client = DBClient(settings=db_settings)
|
||||
|
||||
# Initialize the Virtual File System (VFS) helper with the snapshot from state
|
||||
vfs_state = dict(state.get("vfs_state", {}))
|
||||
vfs_state = copy.deepcopy(state.get("vfs_state", {}))
|
||||
vfs_helper = VFSHelper(vfs_state)
|
||||
|
||||
# Initialize local variables for execution
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from langgraph.graph import StateGraph, END
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
||||
from ea_chatbot.graph.workers.data_analyst.nodes.coder import coder_node
|
||||
from ea_chatbot.graph.workers.data_analyst.nodes.executor import executor_node
|
||||
@@ -20,7 +21,7 @@ def create_data_analyst_worker(
|
||||
coder=coder_node,
|
||||
executor=executor_node,
|
||||
summarizer=summarizer_node
|
||||
) -> StateGraph:
|
||||
) -> CompiledStateGraph:
|
||||
"""Create the Data Analyst worker subgraph."""
|
||||
workflow = StateGraph(WorkerState)
|
||||
|
||||
|
||||
@@ -23,9 +23,10 @@ def prepare_researcher_input(state: AgentState) -> Dict[str, Any]:
|
||||
|
||||
def merge_researcher_output(worker_state: WorkerState) -> Dict[str, Any]:
|
||||
"""Map researcher results back to the global AgentState."""
|
||||
result = worker_state.get("result", "Research complete.")
|
||||
result = worker_state.get("result") or "No result produced by researcher worker."
|
||||
|
||||
return {
|
||||
"messages": [AIMessage(content=result)],
|
||||
"summary": result,
|
||||
# Researcher doesn't usually update VFS or Plots, but we keep the structure
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from langgraph.graph import StateGraph, END
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from ea_chatbot.graph.workers.researcher.state import WorkerState
|
||||
from ea_chatbot.graph.workers.researcher.nodes.searcher import searcher_node
|
||||
from ea_chatbot.graph.workers.researcher.nodes.summarizer import summarizer_node
|
||||
@@ -6,7 +7,7 @@ from ea_chatbot.graph.workers.researcher.nodes.summarizer import summarizer_node
|
||||
def create_researcher_worker(
|
||||
searcher=searcher_node,
|
||||
summarizer=summarizer_node
|
||||
) -> StateGraph:
|
||||
) -> CompiledStateGraph:
|
||||
"""Create the Researcher worker subgraph."""
|
||||
workflow = StateGraph(WorkerState)
|
||||
|
||||
|
||||
@@ -9,22 +9,23 @@ from ea_chatbot.graph.workers.data_analyst.workflow import create_data_analyst_w
|
||||
from ea_chatbot.graph.workers.data_analyst.mapping import prepare_worker_input, merge_worker_output
|
||||
from ea_chatbot.graph.workers.researcher.workflow import create_researcher_worker
|
||||
from ea_chatbot.graph.workers.researcher.mapping import prepare_researcher_input, merge_researcher_output
|
||||
from ea_chatbot.graph.nodes.researcher import researcher_node
|
||||
from ea_chatbot.graph.nodes.clarification import clarification_node
|
||||
from ea_chatbot.graph.nodes.summarize_conversation import summarize_conversation_node
|
||||
|
||||
# Global cache for compiled subgraphs
|
||||
_DATA_ANALYST_WORKER = create_data_analyst_worker()
|
||||
_RESEARCHER_WORKER = create_researcher_worker()
|
||||
|
||||
def data_analyst_worker_node(state: AgentState) -> dict:
|
||||
"""Wrapper node for the Data Analyst subgraph with state mapping."""
|
||||
worker_graph = create_data_analyst_worker()
|
||||
worker_input = prepare_worker_input(state)
|
||||
worker_result = worker_graph.invoke(worker_input)
|
||||
worker_result = _DATA_ANALYST_WORKER.invoke(worker_input)
|
||||
return merge_worker_output(worker_result)
|
||||
|
||||
def researcher_worker_node(state: AgentState) -> dict:
|
||||
"""Wrapper node for the Researcher subgraph with state mapping."""
|
||||
worker_graph = create_researcher_worker()
|
||||
worker_input = prepare_researcher_input(state)
|
||||
worker_result = worker_graph.invoke(worker_input)
|
||||
worker_result = _RESEARCHER_WORKER.invoke(worker_input)
|
||||
return merge_researcher_output(worker_result)
|
||||
|
||||
def main_router(state: AgentState) -> str:
|
||||
@@ -32,6 +33,8 @@ def main_router(state: AgentState) -> str:
|
||||
next_action = state.get("next_action")
|
||||
if next_action == "clarify":
|
||||
return "clarification"
|
||||
# Even if QA suggests 'research', we now go through 'planner' for orchestration
|
||||
# aligning with the new hierarchical architecture.
|
||||
return "planner"
|
||||
|
||||
def delegation_router(state: AgentState) -> str:
|
||||
|
||||
@@ -17,8 +17,11 @@ class VFSHelper:
|
||||
def read(self, filename: str) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]:
|
||||
"""Read a file and its metadata from the VFS. Returns (None, None) if not found."""
|
||||
file_data = self._vfs.get(filename)
|
||||
if file_data:
|
||||
return file_data["content"], file_data["metadata"]
|
||||
if file_data is not None:
|
||||
# Handle raw values (backwards compatibility or inconsistent schema)
|
||||
if not isinstance(file_data, dict) or "content" not in file_data:
|
||||
return file_data, {}
|
||||
return file_data["content"], file_data.get("metadata", {})
|
||||
return None, None
|
||||
|
||||
def list(self) -> List[str]:
|
||||
|
||||
Reference in New Issue
Block a user