feat(orchestrator): Implement Reflector node for task evaluation and plan advancement
This commit is contained in:
63
backend/src/ea_chatbot/graph/nodes/reflector.py
Normal file
63
backend/src/ea_chatbot/graph/nodes/reflector.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
from ea_chatbot.graph.state import AgentState
|
||||||
|
from ea_chatbot.config import Settings
|
||||||
|
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||||
|
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||||
|
from ea_chatbot.schemas import ReflectorResponse
|
||||||
|
|
||||||
|
def reflector_node(state: AgentState) -> dict:
    """Evaluate whether the worker's output satisfies the current sub-task.

    Reads the checklist, the current step index, and the worker's result
    summary from ``state``, asks the planner LLM for a structured verdict
    (``ReflectorResponse``), and returns a partial state update:

    * ``{"next_action": "summarize"}`` when the checklist is empty or
      exhausted — hand off to the summarizer.
    * ``{"current_step": current_step + 1, "next_action": "delegate"}``
      otherwise. The plan always advances — even when the sub-task is
      judged unsatisfactory or the LLM call raises — to avoid infinite
      delegate/reflect loops; only the logging differs between outcomes.
    """
    checklist = state.get("checklist", [])
    current_step = state.get("current_step", 0)
    summary = state.get("summary", "")  # This contains the worker's summary

    # Plan exhausted (or never produced): nothing left to evaluate.
    if not checklist or current_step >= len(checklist):
        return {"next_action": "summarize"}

    task_info = checklist[current_step]
    task_desc = task_info.get("task", "")

    settings = Settings()
    logger = get_logger("orchestrator:reflector")

    logger.info(f"Evaluating worker output for task: {task_desc[:50]}...")

    llm = get_llm_model(
        settings.planner_llm,  # Using planner model for evaluation
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(ReflectorResponse)

    prompt = f"""You are a Lead Orchestrator evaluating the work of a specialized sub-agent.

**Sub-Task assigned:**
{task_desc}

**Worker's Result Summary:**
{summary}

Evaluate if the result is satisfactory and complete for this specific sub-task.
If there were major errors or the output is missing critical data requested in the sub-task, mark satisfied as False."""

    # Every path below returns the same advancement; build it once instead
    # of duplicating the dict in three branches.
    advance = {"current_step": current_step + 1, "next_action": "delegate"}

    try:
        response = structured_llm.invoke(prompt)
        if response.satisfied:
            logger.info("[bold green]Sub-task satisfied.[/bold green] Advancing plan.")
        else:
            # For now, we'll still advance to avoid infinite loops, but a more
            # complex orchestrator would trigger a retry or adjustment.
            logger.warning(f"[bold yellow]Sub-task NOT satisfied.[/bold yellow] Reason: {response.reasoning}")
    except Exception as e:
        # Fallback: advance anyway so a flaky LLM call cannot stall the graph.
        logger.error(f"Failed to reflect: {str(e)}")
    return advance
|
||||||
@@ -44,6 +44,11 @@ class ChecklistResponse(BaseModel):
|
|||||||
reflection: str = Field(description="Strategic reasoning")
|
reflection: str = Field(description="Strategic reasoning")
|
||||||
checklist: List[ChecklistTask] = Field(description="Ordered list of tasks for specialized workers")
|
checklist: List[ChecklistTask] = Field(description="Ordered list of tasks for specialized workers")
|
||||||
|
|
||||||
|
class ReflectorResponse(BaseModel):
    """Orchestrator's evaluation of worker output.

    Produced by the reflector node's LLM call via
    ``llm.with_structured_output(ReflectorResponse)``; the reflector reads
    ``satisfied`` to decide how to log, and ``reasoning`` for the warning
    message when the sub-task is not satisfied.
    """
    # Verdict: did the worker's output meet the sub-task requirements?
    satisfied: bool = Field(description="Whether the worker's output satisfies the sub-task requirements")
    # Short justification for the verdict (surfaced in the warning log).
    reasoning: str = Field(description="Brief explanation of the evaluation")
||||||
_IM_SEP_TOKEN_PATTERN = re.compile(re.escape("<|im_sep|>"))
|
_IM_SEP_TOKEN_PATTERN = re.compile(re.escape("<|im_sep|>"))
|
||||||
_CODE_BLOCK_PATTERN = re.compile(r"```(?:python\s*)?(.*?)\s*```", re.DOTALL)
|
_CODE_BLOCK_PATTERN = re.compile(r"```(?:python\s*)?(.*?)\s*```", re.DOTALL)
|
||||||
_FORBIDDEN_MODULES = (
|
_FORBIDDEN_MODULES = (
|
||||||
|
|||||||
30
backend/tests/test_orchestrator_reflector.py
Normal file
30
backend/tests/test_orchestrator_reflector.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from ea_chatbot.graph.nodes.reflector import reflector_node
|
||||||
|
from ea_chatbot.graph.state import AgentState
|
||||||
|
|
||||||
|
def test_reflector_node_satisfied():
    """The reflector advances current_step and re-delegates when the LLM reports satisfied."""
    state = AgentState(
        checklist=[{"task": "Analyze votes", "worker": "data_analyst"}],
        current_step=0,
        messages=[],
        question="test",
        analysis={},
        next_action="",
        iterations=0,
        vfs={},
        plots=[],
        dfs={},
        summary="Worker successfully analyzed votes and found 5 million."
    )

    # Stub the LLM factory so the structured-output call yields satisfied=True.
    with patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_get_llm:
        fake_llm = MagicMock()
        verdict = MagicMock(satisfied=True, reasoning="Good.")
        fake_llm.with_structured_output.return_value.invoke.return_value = verdict
        mock_get_llm.return_value = fake_llm

        result = reflector_node(state)

    # One step consumed, control handed back to the delegator for the next task.
    assert result["current_step"] == 1
    assert result["next_action"] == "delegate"
|
||||||
Reference in New Issue
Block a user