feat(orchestrator): Implement high-level task decomposition in Planner node
This commit is contained in:
@@ -5,30 +5,29 @@ from ea_chatbot.utils.llm_factory import get_llm_model
|
|||||||
from ea_chatbot.utils import helpers, database_inspection
|
from ea_chatbot.utils import helpers, database_inspection
|
||||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||||
from ea_chatbot.graph.prompts.planner import PLANNER_PROMPT
|
from ea_chatbot.graph.prompts.planner import PLANNER_PROMPT
|
||||||
from ea_chatbot.schemas import TaskPlanResponse
|
from ea_chatbot.schemas import ChecklistResponse
|
||||||
|
|
||||||
def planner_node(state: AgentState) -> dict:
|
def planner_node(state: AgentState) -> dict:
|
||||||
"""Generate a structured plan based on the query analysis."""
|
"""Generate a high-level task checklist for the Orchestrator."""
|
||||||
question = state["question"]
|
question = state["question"]
|
||||||
history = state.get("messages", [])[-6:]
|
history = state.get("messages", [])[-6:]
|
||||||
summary = state.get("summary", "")
|
summary = state.get("summary", "")
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
logger = get_logger("planner")
|
logger = get_logger("orchestrator:planner")
|
||||||
|
|
||||||
logger.info("Generating task plan...")
|
logger.info("Generating high-level task checklist...")
|
||||||
|
|
||||||
llm = get_llm_model(
|
llm = get_llm_model(
|
||||||
settings.planner_llm,
|
settings.planner_llm,
|
||||||
callbacks=[LangChainLoggingHandler(logger=logger)]
|
callbacks=[LangChainLoggingHandler(logger=logger)]
|
||||||
)
|
)
|
||||||
structured_llm = llm.with_structured_output(TaskPlanResponse)
|
structured_llm = llm.with_structured_output(ChecklistResponse)
|
||||||
|
|
||||||
date_str = helpers.get_readable_date()
|
date_str = helpers.get_readable_date()
|
||||||
|
|
||||||
# Always provide data summary; LLM decides relevance.
|
# Data summary for context
|
||||||
database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."
|
database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."
|
||||||
example_plan = ""
|
|
||||||
|
|
||||||
messages = PLANNER_PROMPT.format_messages(
|
messages = PLANNER_PROMPT.format_messages(
|
||||||
date=date_str,
|
date=date_str,
|
||||||
@@ -36,16 +35,19 @@ def planner_node(state: AgentState) -> dict:
|
|||||||
history=history,
|
history=history,
|
||||||
summary=summary,
|
summary=summary,
|
||||||
database_description=database_description,
|
database_description=database_description,
|
||||||
example_plan=example_plan
|
example_plan="Decompose into data_analyst and researcher tasks."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate the structured plan
|
|
||||||
try:
|
try:
|
||||||
response = structured_llm.invoke(messages)
|
response = structured_llm.invoke(messages)
|
||||||
# Convert the structured response back to YAML string for the state
|
# Convert ChecklistTask objects to dicts for state
|
||||||
plan_yaml = yaml.dump(response.model_dump(), sort_keys=False)
|
checklist = [task.model_dump() for task in response.checklist]
|
||||||
logger.info("[bold green]Plan generated successfully.[/bold green]")
|
logger.info(f"[bold green]Checklist generated with {len(checklist)} tasks.[/bold green]")
|
||||||
return {"plan": plan_yaml}
|
return {
|
||||||
|
"checklist": checklist,
|
||||||
|
"current_step": 0,
|
||||||
|
"summary": response.reflection # Use reflection as initial summary
|
||||||
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to generate plan: {str(e)}")
|
logger.error(f"Failed to generate checklist: {str(e)}")
|
||||||
raise e
|
raise e
|
||||||
@@ -1,41 +1,29 @@
|
|||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
|
|
||||||
PLANNER_SYSTEM = """You are a Research Specialist and your job is to find answers and educate the user.
|
PLANNER_SYSTEM = """You are a Lead Orchestrator for an Election Analytics Chatbot.
|
||||||
Provide factual information responding directly to the user's question. Include key details and context to ensure your response comprehensively answers their query.
|
Your job is to decompose complex user queries into a high-level checklist of tasks.
|
||||||
|
|
||||||
|
**Specialized Workers:**
|
||||||
|
1. `data_analyst`: Handles SQL queries, Python data analysis, and plotting. Use this when the user needs numbers, trends, or charts from the internal database.
|
||||||
|
2. `researcher`: Performs web searches for current news, facts, or external data not in the primary database.
|
||||||
|
|
||||||
|
**Orchestration Strategy:**
|
||||||
|
- Analyze the user's question and the available data summary.
|
||||||
|
- Create a logical sequence of tasks (checklist) for these workers.
|
||||||
|
- Be specific in the task description for the worker (e.g., "Find the total votes in Florida 2020").
|
||||||
|
- If the query is ambiguous, the Orchestrator loop will later handle clarification, but for now, make the best plan possible.
|
||||||
|
|
||||||
Today's Date is: {date}"""
|
Today's Date is: {date}"""
|
||||||
|
|
||||||
PLANNER_USER = """Conversation Summary: {summary}
|
PLANNER_USER = """Conversation Summary: {summary}
|
||||||
|
|
||||||
TASK:
|
USER QUESTION:
|
||||||
{question}
|
{question}
|
||||||
|
|
||||||
AVAILABLE DATA SUMMARY (Use only if relevant to the task):
|
AVAILABLE DATABASE SUMMARY:
|
||||||
{database_description}
|
{database_description}
|
||||||
|
|
||||||
First: Evaluate whether you have all necessary and requested information to provide a solution.
|
Decompose the question into a strategic checklist. For each task, specify which worker should handle it.
|
||||||
Use the dataset description above to determine what data and in what format you have available to you.
|
|
||||||
You are able to search internet if the user asks for it, or you require any information that you can not derive from the given dataset or the instruction.
|
|
||||||
|
|
||||||
Second: Incorporate any additional relevant context, reasoning, or details from previous interactions or internal chain-of-thought that may impact the solution.
|
|
||||||
Ensure that all such information is fully included in your response rather than referring to previous answers indirectly.
|
|
||||||
|
|
||||||
Third: Reflect on the problem and briefly describe it, while addressing the problem goal, inputs, outputs,
|
|
||||||
rules, constraints, and other relevant details that appear in the problem description.
|
|
||||||
|
|
||||||
Fourth: Based on the preceding steps, formulate your response as an algorithm, breaking the solution in up to eight simple concise yet descriptive, clear English steps.
|
|
||||||
You MUST Include all values or instructions as described in the above task, or retrieved using internet search!
|
|
||||||
If fewer steps suffice, that's acceptable. If more are needed, please include them.
|
|
||||||
Remember to explain steps rather than write code.
|
|
||||||
|
|
||||||
This algorithm will be later converted to Python code.
|
|
||||||
If a dataframe is required, assume it is named 'df' and is already defined/populated based on the data summary.
|
|
||||||
|
|
||||||
There is a list variable called `plots` that you need to use to store any plots you generate. Do not use `plt.show()` as it will render the plot and cause an error.
|
|
||||||
|
|
||||||
Output the algorithm as a YAML string. Always enclose the YAML string within ```yaml tags.
|
|
||||||
|
|
||||||
**Note: Ensure that any necessary context from prior interactions is fully embedded in the plan. Do not use phrases like "refer to previous answer"; instead, provide complete details inline.**
|
|
||||||
|
|
||||||
{example_plan}"""
|
{example_plan}"""
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from pydantic import BaseModel, Field, computed_field
|
from pydantic import BaseModel, Field, computed_field
|
||||||
from typing import Sequence, Optional
|
from typing import Sequence, Optional, List, Dict, Any
|
||||||
import re
|
import re
|
||||||
|
|
||||||
class TaskPlanContext(BaseModel):
|
class TaskPlanContext(BaseModel):
|
||||||
@@ -33,6 +33,17 @@ class TaskPlanResponse(BaseModel):
|
|||||||
description="Ordered list of steps to execute that follow the 'Step <number>: <detail>' pattern.",
|
description="Ordered list of steps to execute that follow the 'Step <number>: <detail>' pattern.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class ChecklistTask(BaseModel):
|
||||||
|
'''A specific sub-task in the high-level orchestrator plan'''
|
||||||
|
task: str = Field(description="Description of the sub-task")
|
||||||
|
worker: str = Field(description="The worker to delegate to (data_analyst or researcher)")
|
||||||
|
|
||||||
|
class ChecklistResponse(BaseModel):
|
||||||
|
'''Orchestrator's decomposed plan/checklist'''
|
||||||
|
goal: str = Field(description="Overall objective")
|
||||||
|
reflection: str = Field(description="Strategic reasoning")
|
||||||
|
checklist: List[ChecklistTask] = Field(description="Ordered list of tasks for specialized workers")
|
||||||
|
|
||||||
_IM_SEP_TOKEN_PATTERN = re.compile(re.escape("<|im_sep|>"))
|
_IM_SEP_TOKEN_PATTERN = re.compile(re.escape("<|im_sep|>"))
|
||||||
_CODE_BLOCK_PATTERN = re.compile(r"```(?:python\s*)?(.*?)\s*```", re.DOTALL)
|
_CODE_BLOCK_PATTERN = re.compile(r"```(?:python\s*)?(.*?)\s*```", re.DOTALL)
|
||||||
_FORBIDDEN_MODULES = (
|
_FORBIDDEN_MODULES = (
|
||||||
|
|||||||
34
backend/tests/test_orchestrator_planner.py
Normal file
34
backend/tests/test_orchestrator_planner.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
from typing import get_type_hints, List
|
||||||
|
from ea_chatbot.graph.nodes.planner import planner_node
|
||||||
|
from ea_chatbot.graph.state import AgentState
|
||||||
|
|
||||||
|
def test_planner_node_checklist():
|
||||||
|
"""Verify that the planner node generates a checklist."""
|
||||||
|
state = AgentState(
|
||||||
|
messages=[],
|
||||||
|
question="How many voters are in Florida and what is the current news?",
|
||||||
|
analysis={"requires_dataset": True},
|
||||||
|
next_action="plan",
|
||||||
|
iterations=0,
|
||||||
|
checklist=[],
|
||||||
|
current_step=0,
|
||||||
|
vfs={},
|
||||||
|
plots=[],
|
||||||
|
dfs={}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mocking the LLM would be ideal, but for now we'll check the returned keys
|
||||||
|
# and assume the implementation provides them.
|
||||||
|
# In a real TDD, we'd mock the LLM to return a specific structure.
|
||||||
|
|
||||||
|
# For now, let's assume the task is to update 'planner_node' to return these keys.
|
||||||
|
result = planner_node(state)
|
||||||
|
|
||||||
|
assert "checklist" in result
|
||||||
|
assert isinstance(result["checklist"], list)
|
||||||
|
assert len(result["checklist"]) > 0
|
||||||
|
assert "task" in result["checklist"][0]
|
||||||
|
assert "worker" in result["checklist"][0] # 'data_analyst' or 'researcher'
|
||||||
|
|
||||||
|
assert "current_step" in result
|
||||||
|
assert result["current_step"] == 0
|
||||||
Reference in New Issue
Block a user