Compare commits
38 commits: main...f62a70f7c3
Commits (SHA1):
f62a70f7c3, b9084fcaef, 8fcfc4ee88, b8d8651924, 02d93120e0, 92c30d217e, 88a27f5a8d, 4d92c9aedb, 557b553c59, 2cfbc5d1d0, c5cf4b38a1, 46129c6f1e, 8eea464be4, 11c14fb8a8, 9b97140fff, 9e90f2c9ad, 5cc5bd91ae, 120b6fd11a, f4d09c07c4, ad7845cc6a, 18e4e8db7d, 9fef4888b5, 37c353a249, ff9b443bfe, 575e1a2e53, 013208b929, cb045504d1, 5324cbe851, eeb2be409b, 92d9288f38, 8957e93f3d, 45fe122580, 969165f4a7, e7be0dbeca, 322ae1e7c8, 46b57d2a73, 99ded43483, 1394c0496a
@@ -4,6 +4,7 @@ A stateful, graph-based chatbot for election data analysis, built with LangGraph

## 🚀 Features

- **Multi-Agent Orchestration**: Decomposes complex queries and delegates them to specialized sub-agents (Data Analyst, Researcher) using a robust feedback loop.
- **Intelligent Query Analysis**: Automatically determines if a query needs data analysis, web research, or clarification.
- **Automated Data Analysis**: Generates and executes Python code to analyze election datasets and produce visualizations.
- **Web Research**: Integrates web search capabilities for general election-related questions.
@@ -57,3 +57,11 @@ OIDC_REDIRECT_URI=http://localhost:8000/api/v1/auth/oidc/callback

# Researcher
# RESEARCHER_LLM__PROVIDER=google
# RESEARCHER_LLM__MODEL=gemini-2.0-flash

# Reflector
# REFLECTOR_LLM__PROVIDER=openai
# REFLECTOR_LLM__MODEL=gpt-5-mini

# Synthesizer
# SYNTHESIZER_LLM__PROVIDER=openai
# SYNTHESIZER_LLM__MODEL=gpt-5-mini
@@ -1,12 +1,13 @@
# Election Analytics Chatbot - Backend Guide

## Overview
The backend is a Python-based FastAPI application that leverages **LangGraph** to provide a stateful, agentic workflow for election data analysis. It handles complex queries by decomposing them into tasks such as data analysis, web research, or user clarification.
The backend is a Python-based FastAPI application that leverages **LangGraph** to provide a stateful, hierarchical multi-agent workflow for election data analysis. It handles complex queries using an Orchestrator-Workers pattern, decomposing tasks and delegating them to specialized subgraphs (Data Analyst, Researcher) with built-in reflection and error recovery.

## 1. Architecture Overview
- **Framework**: LangGraph for workflow orchestration and state management.
- **Framework**: LangGraph for hierarchical workflow orchestration and state management.
- **API**: FastAPI for providing REST and streaming (SSE) endpoints.
- **State Management**: Persistent state using LangGraph's `StateGraph` with a PostgreSQL checkpointer.
- **State Management**: Persistent state using LangGraph's `StateGraph` with a PostgreSQL checkpointer. Maintains global state (`AgentState`) and isolated worker states (`WorkerState`).
- **Virtual File System (VFS)**: An in-memory abstraction passed between nodes to manage intermediate artifacts (scripts, CSVs, charts) without bloating the context window.
- **Database**: PostgreSQL.
  - Application data: Uses `users` table for local and OIDC users (String IDs).
  - History: Persists chat history and artifacts.
@@ -14,23 +15,28 @@ The backend is a Python-based FastAPI application that leverages **LangGraph** t

## 2. Core Components

### 2.1. The Graph State (`src/ea_chatbot/graph/state.py`)
The state tracks the conversation context, plan, generated code, execution results, and artifacts.
### 2.1. State Management (`src/ea_chatbot/graph/state.py` & `workers/*/state.py`)
- **Global State**: Tracks the conversation context, the high-level task `checklist`, execution progress (`current_step`), and the VFS.
- **Worker State**: An isolated snapshot for specialized subgraphs, tracking internal retry loops (`iterations`), worker-specific prompts, and raw results.

### 2.2. Nodes (The Actors)
### 2.2. The Orchestrator
Located in `src/ea_chatbot/graph/nodes/`:

- **`query_analyzer`**: Analyzes the user query to determine the intent and required data.
- **`planner`**: Creates a step-by-step plan for data analysis.
- **`coder`**: Generates Python code based on the plan and dataset metadata.
- **`executor`**: Safely executes the generated code and captures outputs (dataframes, plots).
- **`error_corrector`**: Fixes code if execution fails.
- **`researcher`**: Performs web searches for general election information.
- **`summarizer`**: Generates a natural language response based on the analysis results.
- **`clarification`**: Asks the user for more information if the query is ambiguous.
- **`query_analyzer`**: Analyzes the user query to determine the intent and required data. If the query is ambiguous, routes to `clarification`.
- **`planner`**: Decomposes the user request into a strategic `checklist` of sub-tasks assigned to specific workers.
- **`delegate`**: The traffic controller. Routes the current task to the appropriate worker and enforces a strict retry budget to prevent infinite loops.
- **`reflector`**: The quality-control node. Evaluates a worker's summary against the sub-task requirements and can trigger a retry if unsatisfied.
- **`synthesizer`**: Aggregates all worker results into a final, cohesive response for the user.
- **`clarification`**: Asks the user for more information if the query is critically ambiguous.

### 2.3. The Workflow (Graph)
The graph connects these nodes with conditional edges, allowing for iterative refinement and error correction.
### 2.3. Specialized Workers (Sub-Graphs)
Located in `src/ea_chatbot/graph/workers/`:

- **`data_analyst`**: Generates Python/SQL code, executes it securely, and captures dataframes/plots. Contains an internal retry loop (`coder` -> `executor` -> error check -> `coder`).
- **`researcher`**: Performs web searches for general election information and synthesizes factual findings.

### 2.4. The Workflow
The global graph connects the Orchestrator nodes, wrapping the Worker subgraphs as self-contained nodes with mapped inputs and outputs.

## 3. Key Modules
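The control flow described in sections 2.2-2.4 (delegate -> worker -> reflector, with a retry budget) can be sketched framework-free. The state keys `checklist`, `current_step`, and `iterations` come from the diff; the worker and reflector callables below are hypothetical stand-ins, not the repo's actual subgraphs:

```python
# Framework-free sketch of the Orchestrator loop (delegate -> worker -> reflector).
# `checklist`, `current_step`, and `iterations` mirror the AgentState keys in the
# diff; workers and the reflect callback are stand-ins for the real subgraphs.
MAX_RETRIES = 3  # matches the `iterations >= 3` budget in delegate_node

def run_orchestrator(state: dict, workers: dict, reflect) -> dict:
    """Drive the loop until the checklist is exhausted or retries run out."""
    while state["current_step"] < len(state["checklist"]):
        if state["iterations"] >= MAX_RETRIES:       # retry budget (delegate)
            state["summary"] = f"Failed task {state['current_step']}"
            break
        task = state["checklist"][state["current_step"]]
        result = workers[task["worker"]](task)       # run the worker subgraph
        if reflect(task, result):                    # reflector: satisfied?
            state["current_step"] += 1               # advance the plan
            state["iterations"] = 0                  # reset for next task
        else:
            state["iterations"] += 1                 # retry the same task
    return state
```

A satisfied reflector advances `current_step`; an unsatisfied one burns one unit of the retry budget, exactly as `reflector_node` and `delegate_node` do further down in this diff.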
@@ -56,7 +56,8 @@ async def stream_agent_events(
        initial_state,
        config,
        version="v2",
        checkpointer=checkpointer
        checkpointer=checkpointer,
        subgraphs=True
    ):
        kind = event.get("event")
        name = event.get("name")
@@ -71,8 +72,8 @@ async def stream_agent_events(
                "data": data
            }

        # Buffer assistant chunks (summarizer and researcher might stream)
        if kind == "on_chat_model_stream" and node_name in ["summarizer", "researcher", "clarification"]:
        # Buffer assistant chunks (synthesizer and clarification might stream)
        if kind == "on_chat_model_stream" and node_name in ["synthesizer", "clarification"]:
            chunk = data.get("chunk", "")
            # Use utility to safely extract text content from the chunk
            chunk_data = convert_to_json_compatible(chunk)
@@ -83,7 +84,7 @@ async def stream_agent_events(
                assistant_chunks.append(str(chunk_data))

        # Buffer and encode plots
        if kind == "on_chain_end" and name == "executor":
        if kind == "on_chain_end" and name == "data_analyst_worker":
            output = data.get("output", {})
            if isinstance(output, dict) and "plots" in output:
                plots = output["plots"]
@@ -95,7 +96,7 @@ async def stream_agent_events(
                output_event["data"]["encoded_plots"] = encoded_plots

        # Collect final response from terminal nodes
        if kind == "on_chain_end" and name in ["summarizer", "researcher", "clarification"]:
        if kind == "on_chain_end" and name in ["synthesizer", "clarification"]:
            output = data.get("output", {})
            if isinstance(output, dict) and "messages" in output:
                last_msg = output["messages"][-1]
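The filtering logic in this handler can be illustrated with plain dicts standing in for event payloads. The event shapes below are simplified assumptions, not the full `astream_events` v2 schema; only the routing by event kind and node name is modeled:

```python
# Simplified sketch of the chunk-buffering branch of stream_agent_events.
# Events are plain dicts approximating astream_events v2 payloads.
STREAMING_NODES = {"synthesizer", "clarification"}  # user-facing terminal nodes

def buffer_assistant_chunks(events) -> str:
    """Collect token chunks only from nodes whose output reaches the user."""
    chunks = []
    for event in events:
        if (event.get("event") == "on_chat_model_stream"
                and event.get("node") in STREAMING_NODES):
            chunks.append(event["data"].get("chunk", ""))
    return "".join(chunks)
```

Internal nodes (workers, reflector) still emit model-stream events after `subgraphs=True`, which is why the allow-list matters: without it, worker chatter would leak into the assistant's streamed answer.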
@@ -52,6 +52,8 @@ class Settings(BaseSettings):
    coder_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
    summarizer_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
    researcher_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
    reflector_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
    synthesizer_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))

    # Allow nested env vars like QUERY_ANALYZER_LLM__MODEL
    model_config = SettingsConfigDict(env_nested_delimiter='__', env_prefix='')
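The `__` convention can be shown without pydantic-settings; this toy parser only mimics what `env_nested_delimiter='__'` does (split each variable name on the first `__` into a field and a nested attribute), and is not the library's actual implementation:

```python
# Toy illustration of pydantic-settings' env_nested_delimiter="__" behavior:
# REFLECTOR_LLM__MODEL=gpt-5-mini maps onto Settings.reflector_llm.model.
def parse_nested_env(environ: dict) -> dict:
    """Turn FIELD__ATTR=value pairs into {"field": {"attr": value}}."""
    nested: dict = {}
    for key, value in environ.items():
        field, _, attr = key.lower().partition("__")
        if attr:  # ignore variables without the nested delimiter
            nested.setdefault(field, {})[attr] = value
    return nested
```

This is why the commented `.env` entries earlier in the diff (e.g. `REFLECTOR_LLM__MODEL=gpt-5-mini`) override the `LLMConfig` defaults without any extra wiring.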
@@ -36,8 +36,7 @@ Please ask the user for the necessary details."""
        response = llm.invoke(messages)
        logger.info("[bold green]Clarification generated.[/bold green]")
        return {
            "messages": [response],
            "next_action": "end"  # To indicate we are done for now
            "messages": [response]
        }
    except Exception as e:
        logger.error(f"Failed to generate clarification: {str(e)}")
@@ -1,47 +0,0 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils import database_inspection
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT
from ea_chatbot.schemas import CodeGenerationResponse


def coder_node(state: AgentState) -> dict:
    """Generate Python code based on the plan and data summary."""
    question = state["question"]
    plan = state.get("plan", "")
    code_output = state.get("code_output", "None")

    settings = Settings()
    logger = get_logger("coder")

    logger.info("Generating Python code...")

    llm = get_llm_model(
        settings.coder_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(CodeGenerationResponse)

    # Always provide data summary
    database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."
    example_code = ""  # Placeholder

    messages = CODE_GENERATOR_PROMPT.format_messages(
        question=question,
        plan=plan,
        database_description=database_description,
        code_exec_results=code_output,
        example_code=example_code
    )

    try:
        response = structured_llm.invoke(messages)
        logger.info("[bold green]Code generated.[/bold green]")
        return {
            "code": response.parsed_code,
            "error": None  # Clear previous errors on new code generation
        }
    except Exception as e:
        logger.error(f"Failed to generate code: {str(e)}")
        raise e
backend/src/ea_chatbot/graph/nodes/delegate.py (new file, 28 lines)
@@ -0,0 +1,28 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.utils.logging import get_logger


def delegate_node(state: AgentState) -> dict:
    """Determine which worker subgraph to call next based on the checklist."""
    checklist = state.get("checklist", [])
    current_step = state.get("current_step", 0)
    iterations = state.get("iterations", 0)
    logger = get_logger("orchestrator:delegate")

    if not checklist or current_step >= len(checklist):
        logger.info("Checklist complete or empty. Routing to summarizer.")
        return {"next_action": "summarize"}

    # Enforce retry budget
    if iterations >= 3:
        logger.error(f"Max retries reached for task {current_step}. Routing to summary with failure.")
        return {
            "next_action": "summarize",
            "iterations": 0,  # Reset for next turn
            "summary": f"Failed to complete task {current_step} after {iterations} attempts."
        }

    task_info = checklist[current_step]
    worker = task_info.get("worker", "data_analyst")

    logger.info(f"Delegating next task to worker: {worker} (Attempt {iterations + 1})")
    return {"next_action": worker}
@@ -1,44 +0,0 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.coder import ERROR_CORRECTOR_PROMPT
from ea_chatbot.schemas import CodeGenerationResponse


def error_corrector_node(state: AgentState) -> dict:
    """Fix the code based on the execution error."""
    code = state.get("code", "")
    error = state.get("error", "Unknown error")

    settings = Settings()
    logger = get_logger("error_corrector")

    logger.warning(f"[bold red]Execution error detected:[/bold red] {error[:100]}...")
    logger.info("Attempting to correct the code...")

    # Reuse coder LLM config or add a new one. Using coder_llm for now.
    llm = get_llm_model(
        settings.coder_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(CodeGenerationResponse)

    messages = ERROR_CORRECTOR_PROMPT.format_messages(
        code=code,
        error=error
    )

    try:
        response = structured_llm.invoke(messages)
        logger.info("[bold green]Correction generated.[/bold green]")

        current_iterations = state.get("iterations", 0)

        return {
            "code": response.parsed_code,
            "error": None,  # Clear error after fix attempt
            "iterations": current_iterations + 1
        }
    except Exception as e:
        logger.error(f"Failed to correct code: {str(e)}")
        raise e
@@ -1,34 +1,32 @@
import yaml
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils import helpers, database_inspection
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.planner import PLANNER_PROMPT
from ea_chatbot.schemas import TaskPlanResponse
from ea_chatbot.schemas import ChecklistResponse


def planner_node(state: AgentState) -> dict:
    """Generate a structured plan based on the query analysis."""
    """Generate a high-level task checklist for the Orchestrator."""
    question = state["question"]
    history = state.get("messages", [])[-6:]
    summary = state.get("summary", "")

    settings = Settings()
    logger = get_logger("planner")
    logger = get_logger("orchestrator:planner")

    logger.info("Generating task plan...")
    logger.info("Generating high-level task checklist...")

    llm = get_llm_model(
        settings.planner_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(TaskPlanResponse)
    structured_llm = llm.with_structured_output(ChecklistResponse)

    date_str = helpers.get_readable_date()

    # Always provide data summary; LLM decides relevance.
    # Data summary for context
    database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."
    example_plan = ""

    messages = PLANNER_PROMPT.format_messages(
        date=date_str,
@@ -36,16 +34,20 @@ def planner_node(state: AgentState) -> dict:
        history=history,
        summary=summary,
        database_description=database_description,
        example_plan=example_plan
        example_plan="Decompose into data_analyst and researcher tasks."
    )

    # Generate the structured plan
    try:
        response = structured_llm.invoke(messages)
        # Convert the structured response back to YAML string for the state
        plan_yaml = yaml.dump(response.model_dump(), sort_keys=False)
        logger.info("[bold green]Plan generated successfully.[/bold green]")
        return {"plan": plan_yaml}
        response = ChecklistResponse.model_validate(structured_llm.invoke(messages))
        # Convert ChecklistTask objects to dicts for state
        checklist = [task.model_dump() for task in response.checklist]
        logger.info(f"[bold green]Checklist generated with {len(checklist)} tasks.[/bold green]")
        return {
            "checklist": checklist,
            "current_step": 0,
            "iterations": 0,  # Reset iteration counter for the new plan
            "summary": response.reflection  # Use reflection as initial summary
        }
    except Exception as e:
        logger.error(f"Failed to generate plan: {str(e)}")
        logger.error(f"Failed to generate checklist: {str(e)}")
        raise e
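The state update returned by the new `planner_node` can be pictured with plain dicts. The `worker`/`task` keys and the `current_step`/`iterations` counters come straight from the diff; the concrete task strings below are invented examples:

```python
# Shape of the planner's state update, per the diff: each checklist entry
# is the dict form of a ChecklistTask (via model_dump()). Task texts are
# illustrative, not from the repo.
planner_update = {
    "checklist": [
        {"worker": "data_analyst", "task": "Find the total votes in Florida 2020"},
        {"worker": "researcher", "task": "Summarize 2020 Florida election news"},
    ],
    "current_step": 0,
    "iterations": 0,
    "summary": "Two-step plan: query internal data, then add context.",
}

def current_task(state: dict) -> dict:
    """Return the checklist entry the delegate node would dispatch next."""
    return state["checklist"][state["current_step"]]
```

The delegate node reads `current_task(state)["worker"]` to route, and the reflector advances `current_step` once a task's result is judged satisfactory.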
@@ -1,18 +1,9 @@
from typing import List, Literal
from pydantic import BaseModel, Field
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.query_analyzer import QUERY_ANALYZER_PROMPT


class QueryAnalysis(BaseModel):
    """Analysis of the user's query."""
    data_required: List[str] = Field(description="List of data points or entities mentioned (e.g., ['2024 results', 'Florida']).")
    unknowns: List[str] = Field(description="List of target information the user wants to know or needed for final answer (e.g., 'who won', 'total votes').")
    ambiguities: List[str] = Field(description="List of CRITICAL missing details that prevent ANY analysis. Do NOT include database names or plot types if defaults can be used.")
    conditions: List[str] = Field(description="List of any filters or constraints (e.g., ['year=2024', 'state=Florida']). Include context resolved from history.")
    next_action: Literal["plan", "clarify", "research"] = Field(description="The next action to take. 'plan' for data analysis (even with defaults), 'research' for general knowledge, or 'clarify' ONLY for critical ambiguities.")
from ea_chatbot.schemas import QueryAnalysis


def query_analyzer_node(state: AgentState) -> dict:
    """Analyze the user's question and determine the next course of action."""
backend/src/ea_chatbot/graph/nodes/reflector.py (new file, 63 lines)
@@ -0,0 +1,63 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.schemas import ReflectorResponse


def reflector_node(state: AgentState) -> dict:
    """Evaluate if the worker's output satisfies the current sub-task."""
    checklist = state.get("checklist", [])
    current_step = state.get("current_step", 0)
    summary = state.get("summary", "")  # This contains the worker's summary

    if not checklist or current_step >= len(checklist):
        return {"next_action": "summarize"}

    task_info = checklist[current_step]
    task_desc = task_info.get("task", "")

    settings = Settings()
    logger = get_logger("orchestrator:reflector")

    logger.info(f"Evaluating worker output for task: {task_desc[:50]}...")

    llm = get_llm_model(
        settings.reflector_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(ReflectorResponse)

    prompt = f"""You are a Lead Orchestrator evaluating the work of a specialized sub-agent.

**Sub-Task assigned:**
{task_desc}

**Worker's Result Summary:**
{summary}

Evaluate if the result is satisfactory and complete for this specific sub-task.
If there were major errors or the output is missing critical data requested in the sub-task, mark satisfied as False."""

    try:
        response = structured_llm.invoke(prompt)
        if response.satisfied:
            logger.info("[bold green]Sub-task satisfied.[/bold green] Advancing plan.")
            return {
                "current_step": current_step + 1,
                "iterations": 0,  # Reset for next task
                "next_action": "delegate"
            }
        else:
            logger.warning(f"[bold yellow]Sub-task NOT satisfied.[/bold yellow] Reason: {response.reasoning}")
            # Do NOT advance the step. Increment iterations to track retries.
            return {
                "iterations": state.get("iterations", 0) + 1,
                "next_action": "delegate"
            }
    except Exception as e:
        logger.error(f"Failed to reflect: {str(e)}")
        # On error, increment iterations to avoid an infinite loop if the LLM is stuck
        return {
            "iterations": state.get("iterations", 0) + 1,
            "next_action": "delegate"
        }
@@ -1,43 +0,0 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.summarizer import SUMMARIZER_PROMPT


def summarizer_node(state: AgentState) -> dict:
    """Summarize the code execution results into a final answer."""
    question = state["question"]
    plan = state.get("plan", "")
    code_output = state.get("code_output", "")
    history = state.get("messages", [])[-6:]
    summary = state.get("summary", "")

    settings = Settings()
    logger = get_logger("summarizer")

    logger.info("Generating final summary...")

    llm = get_llm_model(
        settings.summarizer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )

    messages = SUMMARIZER_PROMPT.format_messages(
        question=question,
        plan=plan,
        code_output=code_output,
        history=history,
        summary=summary
    )

    try:
        response = llm.invoke(messages)
        logger.info("[bold green]Summary generated.[/bold green]")

        # Return the final message to be added to the state
        return {
            "messages": [response]
        }
    except Exception as e:
        logger.error(f"Failed to generate summary: {str(e)}")
        raise e
backend/src/ea_chatbot/graph/nodes/synthesizer.py (new file, 56 lines)
@@ -0,0 +1,56 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.synthesizer import SYNTHESIZER_PROMPT


def synthesizer_node(state: AgentState) -> dict:
    """Synthesize the results from multiple workers into a final answer."""
    question = state["question"]
    history = state.get("messages", [])

    # The 'summary' from the last worker may carry cumulative info, and the
    # message history carries the summaries bubbled up from all workers.
    # For now, assume the history contains every worker summary.

    settings = Settings()
    logger = get_logger("orchestrator:synthesizer")

    logger.info("Synthesizing final answer from worker results...")

    llm = get_llm_model(
        settings.synthesizer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )

    # Artifact summary
    plots = state.get("plots", [])
    vfs = state.get("vfs", {})
    artifacts_summary = ""
    if plots:
        artifacts_summary += f"- {len(plots)} generated plot(s) are attached to this response.\n"
    if vfs:
        artifacts_summary += "- Data files available in VFS: " + ", ".join(vfs.keys()) + "\n"
    if not artifacts_summary:
        artifacts_summary = "No additional artifacts generated."

    # We provide the full history and the original question
    messages = SYNTHESIZER_PROMPT.format_messages(
        question=question,
        history=history,
        worker_results="Review the worker summaries provided in the message history.",
        artifacts_summary=artifacts_summary
    )

    try:
        response = llm.invoke(messages)
        logger.info("[bold green]Final synthesis complete.[/bold green]")

        # Return the final message to be added to the state
        return {
            "messages": [response],
            "next_action": "end"
        }
    except Exception as e:
        logger.error(f"Failed to synthesize final answer: {str(e)}")
        raise e
@@ -9,6 +9,14 @@ The user will provide a task and a plan.
- Do NOT assume a dataframe `df` is already loaded unless explicitly stated. You usually need to query it first.
- The database schema is described in the prompt. Use it to construct valid SQL queries.

**Virtual File System (VFS):**
- An in-memory file system is available as `vfs`. Use it to persist intermediate data or large artifacts.
- `vfs.write(filename, content, metadata=None)`: Save a file (content can be any serializable object).
- `vfs.read(filename) -> (content, metadata)`: Read a file.
- `vfs.list() -> list[str]`: List all files.
- `vfs.delete(filename)`: Delete a file.
- Prefer using VFS for intermediate DataFrames or complex data structures instead of printing everything.

**Plotting:**
- If you need to plot any data, use the `plots` list to store the figures.
- Example: `plots.append(fig)` or `plots.append(plt.gcf())`.
@@ -18,7 +26,8 @@ The user will provide a task and a plan.
- Produce FULL, COMPLETE CODE that includes all steps and solves the task!
- Always include the import statements at the top of the code (e.g., `import pandas as pd`, `import matplotlib.pyplot as plt`).
- Always include print statements to output the results of your code.
- Use `db.query_df("SELECT ...")` to get data."""
- Use `db.query_df("SELECT ...")` to get data.
"""

CODE_GENERATOR_USER = """TASK:
{question}
@@ -43,6 +52,7 @@ Return a complete, corrected python code that incorporates the fixes for the err
- You have access to a database client via the variable `db`.
- Use `db.query_df(sql)` to run queries.
- Use `plots.append(fig)` for plots.
- You have access to `vfs` for persistent in-memory storage.
- Always include imports and print statements."""

ERROR_CORRECTOR_USER = """FAILED CODE:
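The `vfs` interface the prompt documents (`write`/`read`/`list`/`delete`) could be backed by something as small as this in-memory sketch. The method signatures follow the prompt text above; the repo's actual class may differ in storage and validation details:

```python
# Minimal in-memory sketch of the VFS interface described in the coder prompt.
# Only the four documented methods are modeled; the real implementation in
# the repo may add serialization, size limits, etc.
class VirtualFileSystem:
    def __init__(self):
        # filename -> (content, metadata)
        self._files: dict = {}

    def write(self, filename: str, content, metadata=None) -> None:
        """Save a file; content can be any serializable object."""
        self._files[filename] = (content, metadata)

    def read(self, filename: str):
        """Return the (content, metadata) tuple for a file."""
        return self._files[filename]

    def list(self) -> list:
        """List all stored filenames."""
        return list(self._files)

    def delete(self, filename: str) -> None:
        """Remove a file."""
        del self._files[filename]
```

Because the store lives in graph state rather than on disk, intermediate DataFrames travel between nodes without being serialized into the LLM context window, which is the point of the VFS bullet in the architecture overview.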
@@ -1,41 +1,29 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

PLANNER_SYSTEM = """You are a Research Specialist and your job is to find answers and educate the user.
Provide factual information responding directly to the user's question. Include key details and context to ensure your response comprehensively answers their query.
PLANNER_SYSTEM = """You are a Lead Orchestrator for an Election Analytics Chatbot.
Your job is to decompose complex user queries into a high-level checklist of tasks.

**Specialized Workers:**
1. `data_analyst`: Handles SQL queries, Python data analysis, and plotting. Use this when the user needs numbers, trends, or charts from the internal database.
2. `researcher`: Performs web searches for current news, facts, or external data not in the primary database.

**Orchestration Strategy:**
- Analyze the user's question and the available data summary.
- Create a logical sequence of tasks (checklist) for these workers.
- Be specific in the task description for the worker (e.g., "Find the total votes in Florida 2020").
- If the query is ambiguous, the Orchestrator loop will later handle clarification, but for now, make the best plan possible.

Today's Date is: {date}"""

PLANNER_USER = """Conversation Summary: {summary}

TASK:
USER QUESTION:
{question}

AVAILABLE DATA SUMMARY (Use only if relevant to the task):
AVAILABLE DATABASE SUMMARY:
{database_description}

First: Evaluate whether you have all necessary and requested information to provide a solution.
Use the dataset description above to determine what data, and in what format, is available to you.
You are able to search the internet if the user asks for it, or if you require any information that you cannot derive from the given dataset or the instructions.

Second: Incorporate any additional relevant context, reasoning, or details from previous interactions or internal chain-of-thought that may impact the solution.
Ensure that all such information is fully included in your response rather than referring to previous answers indirectly.

Third: Reflect on the problem and briefly describe it, addressing the problem goal, inputs, outputs,
rules, constraints, and other relevant details that appear in the problem description.

Fourth: Based on the preceding steps, formulate your response as an algorithm, breaking the solution into up to eight simple, concise yet descriptive, clear English steps.
You MUST include all values or instructions as described in the above task, or retrieved using internet search!
If fewer steps suffice, that's acceptable. If more are needed, please include them.
Remember to explain steps rather than write code.

This algorithm will later be converted to Python code.
If a dataframe is required, assume it is named 'df' and is already defined/populated based on the data summary.

There is a list variable called `plots` that you need to use to store any plots you generate. Do not use `plt.show()` as it will render the plot and cause an error.

Output the algorithm as a YAML string. Always enclose the YAML string within ```yaml tags.

**Note: Ensure that any necessary context from prior interactions is fully embedded in the plan. Do not use phrases like "refer to previous answer"; instead, provide complete details inline.**
Decompose the question into a strategic checklist. For each task, specify which worker should handle it.

{example_plan}"""
backend/src/ea_chatbot/graph/prompts/synthesizer.py (new file, 31 lines)
@@ -0,0 +1,31 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

SYNTHESIZER_SYSTEM = """You are a Lead Orchestrator for an Election Analytics Chatbot.
You have coordinated several specialized workers (Data Analysts, Researchers) to answer a user's complex query.

Your goal is to synthesize their individual findings into a single, cohesive, and comprehensive final response for the user.

**Guidelines:**
- Do NOT mention the internal 'workers' or 'checklist' names.
- Combine the data insights (from Data Analysts) and factual research (from Researchers) into a natural narrative.
- Ensure all numbers, dates, and names from the worker reports are included accurately.
- **Artifacts & Plots:** If plots or charts were generated, refer to them naturally (e.g., "The chart below shows...").
- If any part of the plan failed, explain the status honestly but professionally.
- Present data in clear formats (tables, bullet points) where appropriate."""

SYNTHESIZER_USER = """USER QUESTION:
{question}

EXECUTION SUMMARY (Results from specialized workers):
{worker_results}

AVAILABLE ARTIFACTS:
{artifacts_summary}

Provide the final integrated response:"""

SYNTHESIZER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", SYNTHESIZER_SYSTEM),
    MessagesPlaceholder(variable_name="history"),
    ("human", SYNTHESIZER_USER),
])
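A minimal sketch of how the user template above gets filled in (hypothetical values; the real code does this through `ChatPromptTemplate.format_messages`, here plain `str.format` stands in for illustration):

```python
# The same placeholder substitution ChatPromptTemplate performs on
# SYNTHESIZER_USER, shown with str.format and hypothetical worker output.
SYNTHESIZER_USER = """USER QUESTION:
{question}

EXECUTION SUMMARY (Results from specialized workers):
{worker_results}

AVAILABLE ARTIFACTS:
{artifacts_summary}

Provide the final integrated response:"""

filled = SYNTHESIZER_USER.format(
    question="Who won Florida in 2024?",
    worker_results="- Data Analyst: Candidate A led by 3.1 points.",
    artifacts_summary="- 1 plot (vote share by county)",
)
print(filled.splitlines()[1])
```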
@@ -12,12 +12,12 @@ class AgentState(AS):
 
     # Query Analysis (Decomposition results)
     analysis: Optional[Dict[str, Any]]
-    # Expected keys: "requires_dataset", "expert", "data", "unknown", "condition"
+    # Expected keys: "data_required", "unknowns", "ambiguities", "conditions"
 
-    # Step-by-step reasoning
+    # Step-by-step reasoning (Legacy, use checklist for new flows)
     plan: Optional[str]
 
-    # Code execution context
+    # Code execution context (Legacy, use workers for new flows)
     code: Optional[str]
     code_output: Optional[str]
     error: Optional[str]
@@ -26,11 +26,21 @@ class AgentState(AS):
     plots: Annotated[List[Any], operator.add]  # Matplotlib figures
     dfs: Dict[str, Any]  # Pandas DataFrames
 
-    # Conversation summary
+    # Conversation summary / Latest worker result
     summary: Optional[str]
 
-    # Routing hint: "clarify", "plan", "research", "end"
+    # Routing hint: "clarify", "plan", "end"
     next_action: str
 
     # Number of execution attempts
     iterations: int
 
+    # --- DeepAgents Extensions ---
+    # High-level plan/checklist: List of Task objects (dicts)
+    checklist: List[Dict[str, Any]]
+
+    # Current active step in the checklist
+    current_step: int
+
+    # Virtual File System (VFS): Map of filenames to content/metadata
+    vfs: Dict[str, Any]
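A sketch of the new orchestration fields in action (hypothetical task values): the `checklist` is a list of task dicts and `current_step` indexes into it, which is the same bounds-checked lookup the worker mapping code performs:

```python
# Hypothetical global state showing the checklist/current_step contract.
state = {
    "checklist": [
        {"task": "Load 2024 turnout data and compute totals", "worker": "data_analyst"},
        {"task": "Find the official certification date", "worker": "researcher"},
    ],
    "current_step": 1,
    "vfs": {},
}

task_desc = "Analyze data"  # default when the index is out of range
if 0 <= state["current_step"] < len(state["checklist"]):
    task_desc = state["checklist"][state["current_step"]].get("task", task_desc)

print(task_desc)
```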
0
backend/src/ea_chatbot/graph/workers/__init__.py
Normal file
42
backend/src/ea_chatbot/graph/workers/data_analyst/mapping.py
Normal file
@@ -0,0 +1,42 @@
from typing import Dict, Any, List
import copy
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.state import AgentState
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
from ea_chatbot.utils.vfs import safe_vfs_copy

def prepare_worker_input(state: AgentState) -> Dict[str, Any]:
    """Prepare the initial state for the Data Analyst worker."""
    checklist = state.get("checklist", [])
    current_step = state.get("current_step", 0)

    # Get the current task description
    task_desc = "Analyze data"  # Default
    if 0 <= current_step < len(checklist):
        task_desc = checklist[current_step].get("task", task_desc)

    return {
        "task": task_desc,
        "messages": [HumanMessage(content=task_desc)],  # Start worker loop with the task
        "vfs_state": safe_vfs_copy(state.get("vfs", {})),
        "iterations": 0,
        "plots": [],
        "code": None,
        "output": None,
        "error": None,
        "result": None
    }

def merge_worker_output(worker_state: WorkerState) -> Dict[str, Any]:
    """Map worker results back to the global AgentState."""
    result = worker_state.get("result") or "No result produced by data analyst worker."

    # We bubble up the summary as an AI message and update the global summary field
    updates = {
        "messages": [AIMessage(content=result)],
        "summary": result,
        "vfs": worker_state.get("vfs_state", {}),
        "plots": worker_state.get("plots", [])
    }

    return updates
@@ -0,0 +1,64 @@
from typing import Dict, Any, List, Optional
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils import database_inspection
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT
from ea_chatbot.schemas import CodeGenerationResponse

def coder_node(state: WorkerState) -> dict:
    """Generate Python code based on the sub-task assigned to the worker."""
    task = state["task"]
    output = state.get("output", "None")
    error = state.get("error", "None")
    vfs_state = state.get("vfs_state", {})

    settings = Settings()
    logger = get_logger("data_analyst_worker:coder")

    logger.info(f"Generating Python code for task: {task[:50]}...")

    # We can use the configured 'coder_llm' for this node
    llm = get_llm_model(
        settings.coder_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(CodeGenerationResponse)

    # Data summary for context
    database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."

    # VFS Summary: Let the LLM know which files are available in-memory
    vfs_summary = "Available in-memory files (VFS):\n"
    if vfs_state:
        for filename, data in vfs_state.items():
            if isinstance(data, dict):
                meta = data.get("metadata", {})
                vfs_summary += f"- {filename} ({meta.get('type', 'unknown')})\n"
            else:
                vfs_summary += f"- {filename} (raw data)\n"
    else:
        vfs_summary += "- None"

    # Reuse the global prompt but adapt 'question' and 'plan' labels
    # For a sub-task worker, 'task' effectively replaces the high-level 'plan'
    messages = CODE_GENERATOR_PROMPT.format_messages(
        question=task,
        plan="Focus on the specific task below.",
        database_description=database_description,
        code_exec_results=f"Output: {output}\nError: {error}\n\n{vfs_summary}",
        example_code=""
    )

    try:
        response = structured_llm.invoke(messages)
        logger.info("[bold green]Code generated.[/bold green]")
        return {
            "code": response.parsed_code,
            "error": None,  # Clear previous errors
            "iterations": state.get("iterations", 0) + 1
        }
    except Exception as e:
        logger.error(f"Failed to generate code: {str(e)}")
        raise e
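The VFS-summary step can be isolated as a standalone sketch (hypothetical entries): dict values are treated as metadata-bearing files, anything else is listed as raw data:

```python
# Standalone version of the VFS-summary loop, with hypothetical file entries.
def build_vfs_summary(vfs_state):
    summary = "Available in-memory files (VFS):\n"
    if vfs_state:
        for filename, data in vfs_state.items():
            if isinstance(data, dict):
                # Structured entries carry a metadata dict with a 'type' key
                meta = data.get("metadata", {})
                summary += f"- {filename} ({meta.get('type', 'unknown')})\n"
            else:
                summary += f"- {filename} (raw data)\n"
    else:
        summary += "- None"
    return summary

print(build_vfs_summary({"turnout.csv": {"metadata": {"type": "dataframe"}}, "notes.txt": "raw"}))
```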
@@ -1,23 +1,25 @@
 import io
 import sys
 import traceback
+import copy
 from contextlib import redirect_stdout
+from typing import TYPE_CHECKING
 import pandas as pd
 from matplotlib.figure import Figure
 
-from ea_chatbot.graph.state import AgentState
+from ea_chatbot.graph.workers.data_analyst.state import WorkerState
 from ea_chatbot.utils.db_client import DBClient
+from ea_chatbot.utils.vfs import VFSHelper, safe_vfs_copy
 from ea_chatbot.utils.logging import get_logger
 from ea_chatbot.config import Settings
 
+if TYPE_CHECKING:
+    from ea_chatbot.types import DBSettings
+
-def executor_node(state: AgentState) -> dict:
-    """Execute the Python code and capture output, plots, and dataframes."""
+def executor_node(state: WorkerState) -> dict:
+    """Execute the Python code in the context of the Data Analyst worker."""
     code = state.get("code")
-    logger = get_logger("executor")
+    logger = get_logger("data_analyst_worker:executor")
 
     if not code:
         logger.error("No code provided to executor.")
@@ -37,45 +39,40 @@ def executor_node(state: AgentState) -> dict:
 
     db_client = DBClient(settings=db_settings)
 
+    # Initialize the Virtual File System (VFS) helper with the snapshot from state
+    vfs_state = safe_vfs_copy(state.get("vfs_state", {}))
+    vfs_helper = VFSHelper(vfs_state)
+
     # Initialize local variables for execution
     # 'db' is the DBClient instance, 'plots' is for matplotlib figures
     local_vars = {
         'db': db_client,
         'plots': [],
-        'pd': pd
+        'pd': pd,
+        'vfs': vfs_helper
     }
 
     stdout_buffer = io.StringIO()
     error = None
-    code_output = ""
+    output = ""
     plots = []
-    dfs = {}
 
     try:
         with redirect_stdout(stdout_buffer):
             # Execute the code in the context of local_vars
             exec(code, {}, local_vars)
 
-        code_output = stdout_buffer.getvalue()
+        output = stdout_buffer.getvalue()
 
         # Limit the output length if it's too long
-        if code_output.count('\n') > 32:
-            code_output = '\n'.join(code_output.split('\n')[:32]) + '\n...'
+        if output.count('\n') > 32:
+            output = '\n'.join(output.split('\n')[:32]) + '\n...'
 
         # Extract plots
         raw_plots = local_vars.get('plots', [])
        if isinstance(raw_plots, list):
            plots = [p for p in raw_plots if isinstance(p, Figure)]
 
-        # Extract DataFrames that were likely intended for display
-        # We look for DataFrames in local_vars that were mentioned in the code
-        for key, value in local_vars.items():
-            if isinstance(value, pd.DataFrame):
-                # Heuristic: if the variable name is in the code, it might be a result DF
-                if key in code:
-                    dfs[key] = value
-
-        logger.info(f"[bold green]Execution complete.[/bold green] Captured {len(plots)} plots and {len(dfs)} dataframes.")
+        logger.info(f"[bold green]Execution complete.[/bold green] Captured {len(plots)} plots.")
 
     except Exception as e:
         # Capture the traceback
@@ -90,13 +87,11 @@ def executor_node(state: AgentState) -> dict:
         error += f"{exc_type.__name__ if exc_type else 'Exception'}: {exc_value}"
 
         logger.error(f"Execution failed: {str(e)}")
 
    # If we have an error, we still might want to see partial stdout
-    code_output = stdout_buffer.getvalue()
+    output = stdout_buffer.getvalue()
 
    return {
-        "code_output": code_output,
+        "output": output,
        "error": error,
        "plots": plots,
-        "dfs": dfs
+        "vfs_state": vfs_state
    }
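The executor's core capture mechanism can be sketched with the standard library alone: run the generated code with `exec`, redirect its stdout into a buffer, and read back the `plots` list from the locals dict. (The real node also injects a DB client and a VFS helper, omitted here; the code string is hypothetical.)

```python
import io
from contextlib import redirect_stdout

# Hypothetical generated code: it prints a result and registers a "plot".
code = "plots.append('fake-figure')\nprint('rows: 42')"
local_vars = {"plots": []}

buffer = io.StringIO()
with redirect_stdout(buffer):
    # Execute with an empty globals dict and our locals as the sandbox
    exec(code, {}, local_vars)

output = buffer.getvalue()
# Truncate overly long output, as the node does
if output.count("\n") > 32:
    output = "\n".join(output.split("\n")[:32]) + "\n..."

print(output.strip())
print(local_vars["plots"])
```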
@@ -0,0 +1,53 @@
from typing import Dict, Any, List, Optional
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler

def summarizer_node(state: WorkerState) -> dict:
    """Summarize the data analysis results for the Orchestrator."""
    task = state["task"]
    output = state.get("output", "")
    error = state.get("error")
    plots = state.get("plots", [])
    vfs_state = state.get("vfs_state", {})

    settings = Settings()
    logger = get_logger("data_analyst_worker:summarizer")

    logger.info("Summarizing analysis results for the Orchestrator...")

    # Artifact summary
    artifact_info = ""
    if plots:
        artifact_info += f"- Generated {len(plots)} plot(s).\n"
    if vfs_state:
        artifact_info += "- VFS Artifacts: " + ", ".join(vfs_state.keys()) + "\n"

    # We can use a smaller/faster model for this summary if needed
    llm = get_llm_model(
        settings.planner_llm,  # Using planner model for summary logic
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )

    prompt = f"""You are a data analyst sub-agent. You have completed a sub-task.
Task: {task}
Execution Results: {output}
Error Log (if any): {error}
{artifact_info}

Provide a concise summary of the findings or status for the top-level Orchestrator.
If plots or data files were generated, mention them.
If the execution failed after multiple retries, explain why concisely.
Do NOT include the raw Python code, just the results of the analysis."""

    try:
        response = llm.invoke(prompt)
        result = response.content if hasattr(response, "content") else str(response)
        logger.info("[bold green]Analysis results summarized.[/bold green]")
        return {
            "result": result
        }
    except Exception as e:
        logger.error(f"Failed to summarize results: {str(e)}")
        raise e
28
backend/src/ea_chatbot/graph/workers/data_analyst/state.py
Normal file
@@ -0,0 +1,28 @@
from typing import TypedDict, List, Dict, Any, Optional
from langchain_core.messages import BaseMessage

class WorkerState(TypedDict):
    """Internal state for the Data Analyst worker subgraph."""

    # Internal worker conversation (not bubbled up to global unless summarized)
    messages: List[BaseMessage]

    # The specific sub-task assigned by the Orchestrator
    task: str

    # Generated code and execution context
    code: Optional[str]
    output: Optional[str]
    error: Optional[str]

    # Number of internal retry attempts
    iterations: int

    # Temporary storage for analysis results
    plots: List[Any]

    # Isolated/Snapshot view of the VFS
    vfs_state: Dict[str, Any]

    # The final summary or result to return to the Orchestrator
    result: Optional[str]
@@ -0,0 +1,51 @@
from langgraph.graph import StateGraph, END
from langgraph.graph.state import CompiledStateGraph
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
from ea_chatbot.graph.workers.data_analyst.nodes.coder import coder_node
from ea_chatbot.graph.workers.data_analyst.nodes.executor import executor_node
from ea_chatbot.graph.workers.data_analyst.nodes.summarizer import summarizer_node

def router(state: WorkerState) -> str:
    """Routes the subgraph between coding, execution, and summarization."""
    error = state.get("error")
    iterations = state.get("iterations", 0)

    if error and iterations < 3:
        # Retry with error correction
        return "coder"

    # Either success or max retries reached
    return "summarizer"

def create_data_analyst_worker(
    coder=coder_node,
    executor=executor_node,
    summarizer=summarizer_node
) -> CompiledStateGraph:
    """Create the Data Analyst worker subgraph."""
    workflow = StateGraph(WorkerState)

    # Add Nodes
    workflow.add_node("coder", coder)
    workflow.add_node("executor", executor)
    workflow.add_node("summarizer", summarizer)

    # Set entry point
    workflow.set_entry_point("coder")

    # Add Edges
    workflow.add_edge("coder", "executor")

    # Add Conditional Edges
    workflow.add_conditional_edges(
        "executor",
        router,
        {
            "coder": "coder",
            "summarizer": "summarizer"
        }
    )

    workflow.add_edge("summarizer", END)

    return workflow.compile()
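The retry contract of this subgraph can be restated as a plain function for illustration: loop back to the coder while an error is present and fewer than three attempts have been made, otherwise hand off to the summarizer:

```python
# Behavior sketch of the worker's router (same logic as the subgraph's router).
def router(state):
    if state.get("error") and state.get("iterations", 0) < 3:
        return "coder"      # retry with error feedback
    return "summarizer"     # success, or retry budget exhausted

print(router({"error": "NameError", "iterations": 1}))
print(router({"error": "NameError", "iterations": 3}))
print(router({"error": None, "iterations": 1}))
```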
32
backend/src/ea_chatbot/graph/workers/researcher/mapping.py
Normal file
@@ -0,0 +1,32 @@
from typing import Dict, Any, List
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.state import AgentState
from ea_chatbot.graph.workers.researcher.state import WorkerState

def prepare_researcher_input(state: AgentState) -> Dict[str, Any]:
    """Prepare the initial state for the Researcher worker."""
    checklist = state.get("checklist", [])
    current_step = state.get("current_step", 0)

    task_desc = "Perform research"
    if 0 <= current_step < len(checklist):
        task_desc = checklist[current_step].get("task", task_desc)

    return {
        "task": task_desc,
        "messages": [HumanMessage(content=task_desc)],
        "queries": [],
        "raw_results": [],
        "iterations": 0,
        "result": None
    }

def merge_researcher_output(worker_state: WorkerState) -> Dict[str, Any]:
    """Map researcher results back to the global AgentState."""
    result = worker_state.get("result") or "No result produced by researcher worker."

    return {
        "messages": [AIMessage(content=result)],
        "summary": result,
        # Researcher doesn't usually update VFS or Plots, but we keep the structure
    }
@@ -1,24 +1,20 @@
 from langchain_openai import ChatOpenAI
 from langchain_google_genai import ChatGoogleGenerativeAI
-from ea_chatbot.graph.state import AgentState
+from ea_chatbot.graph.workers.researcher.state import WorkerState
 from ea_chatbot.config import Settings
 from ea_chatbot.utils.llm_factory import get_llm_model
 from ea_chatbot.utils import helpers
 from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
 from ea_chatbot.graph.prompts.researcher import RESEARCHER_PROMPT
 
-def researcher_node(state: AgentState) -> dict:
-    """Handle general research queries or web searches."""
-    question = state["question"]
-    history = state.get("messages", [])[-6:]
-    summary = state.get("summary", "")
+def searcher_node(state: WorkerState) -> dict:
+    """Execute web research for the specific task."""
+    task = state["task"]
+    logger = get_logger("researcher_worker:searcher")
+
+    logger.info(f"Researching task: {task[:50]}...")
 
     settings = Settings()
-    logger = get_logger("researcher")
-
-    logger.info(f"Researching question: [italic]\"{question}\"[/italic]")
 
     # Use researcher_llm from settings
     llm = get_llm_model(
         settings.researcher_llm,
         callbacks=[LangChainLoggingHandler(logger=logger)]
@@ -26,34 +22,39 @@ def searcher_node(state: WorkerState) -> dict:
 
     date_str = helpers.get_readable_date()
 
+    # Adapt the global researcher prompt for the sub-task
     messages = RESEARCHER_PROMPT.format_messages(
         date=date_str,
-        question=question,
-        history=history,
-        summary=summary
+        question=task,
+        history=[],  # Worker has fresh context or task-specific history
+        summary=""
     )
 
-    # Provider-aware tool binding
+    # Tool binding
     try:
         if isinstance(llm, ChatGoogleGenerativeAI):
             # Native Google Search for Gemini
             llm_with_tools = llm.bind_tools([{"google_search": {}}])
         elif isinstance(llm, ChatOpenAI):
             # Native Web Search for OpenAI (built-in tool)
             llm_with_tools = llm.bind_tools([{"type": "web_search"}])
         else:
             # Fallback for other providers that might not support these specific search tools
             llm_with_tools = llm
     except Exception as e:
-        logger.warning(f"Failed to bind search tools: {str(e)}. Falling back to base LLM.")
+        logger.warning(f"Failed to bind search tools: {str(e)}")
         llm_with_tools = llm
 
     try:
         response = llm_with_tools.invoke(messages)
-        logger.info("[bold green]Research complete.[/bold green]")
+        logger.info("[bold green]Search complete.[/bold green]")
 
+        # In a real tool-use scenario, we'd extract the tool outputs here.
+        # For now, we'll store the response content as a 'raw_result'.
+        content = response.content if hasattr(response, "content") else str(response)
+
         return {
-            "messages": [response]
+            "messages": [response],
+            "raw_results": [content],
+            "iterations": state.get("iterations", 0) + 1
         }
     except Exception as e:
-        logger.error(f"Research failed: {str(e)}")
-        raise e
+        logger.error(f"Search failed: {str(e)}")
+        raise e
@@ -0,0 +1,49 @@
from ea_chatbot.graph.workers.researcher.state import WorkerState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler

def summarizer_node(state: WorkerState) -> dict:
    """Summarize research results for the Orchestrator."""
    task = state["task"]
    raw_results = state.get("raw_results", [])

    logger = get_logger("researcher_worker:summarizer")
    logger.info("Summarizing research results...")

    settings = Settings()
    llm = get_llm_model(
        settings.planner_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )

    # Ensure all results are strings (Gemini/OpenAI might return complex content)
    processed_results = []
    for res in raw_results:
        if isinstance(res, list):
            processed_results.append(str(res))
        else:
            processed_results.append(str(res))

    results_str = "\n---\n".join(processed_results)

    prompt = f"""You are a Research Specialist sub-agent. You have completed a research sub-task.
Task: {task}

Raw Research Findings:
{results_str}

Provide a concise, factual summary of the findings for the top-level Orchestrator.
Ensure all key facts, dates, and sources (if provided) are preserved.
Do NOT include internal reasoning, just the factual summary."""

    try:
        response = llm.invoke(prompt)
        result = response.content if hasattr(response, "content") else str(response)
        logger.info("[bold green]Research summary complete.[/bold green]")
        return {
            "result": result
        }
    except Exception as e:
        logger.error(f"Failed to summarize research: {str(e)}")
        raise e
21
backend/src/ea_chatbot/graph/workers/researcher/state.py
Normal file
@@ -0,0 +1,21 @@
from typing import TypedDict, List, Dict, Any, Optional
from langchain_core.messages import BaseMessage

class WorkerState(TypedDict):
    """Internal state for the Researcher worker subgraph."""

    # Internal worker conversation
    messages: List[BaseMessage]

    # The specific sub-task assigned by the Orchestrator
    task: str

    # Search context
    queries: List[str]
    raw_results: List[str]

    # Number of internal retry/refinement attempts
    iterations: int

    # The final research summary to return to the Orchestrator
    result: Optional[str]
25
backend/src/ea_chatbot/graph/workers/researcher/workflow.py
Normal file
@@ -0,0 +1,25 @@
from langgraph.graph import StateGraph, END
from langgraph.graph.state import CompiledStateGraph
from ea_chatbot.graph.workers.researcher.state import WorkerState
from ea_chatbot.graph.workers.researcher.nodes.searcher import searcher_node
from ea_chatbot.graph.workers.researcher.nodes.summarizer import summarizer_node

def create_researcher_worker(
    searcher=searcher_node,
    summarizer=summarizer_node
) -> CompiledStateGraph:
    """Create the Researcher worker subgraph."""
    workflow = StateGraph(WorkerState)

    # Add Nodes
    workflow.add_node("searcher", searcher)
    workflow.add_node("summarizer", summarizer)

    # Set entry point
    workflow.set_entry_point("searcher")

    # Add Edges
    workflow.add_edge("searcher", "summarizer")
    workflow.add_edge("summarizer", END)

    return workflow.compile()
@@ -2,86 +2,101 @@ from langgraph.graph import StateGraph, END
 from ea_chatbot.graph.state import AgentState
 from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node
 from ea_chatbot.graph.nodes.planner import planner_node
-from ea_chatbot.graph.nodes.coder import coder_node
-from ea_chatbot.graph.nodes.error_corrector import error_corrector_node
-from ea_chatbot.graph.nodes.executor import executor_node
-from ea_chatbot.graph.nodes.summarizer import summarizer_node
-from ea_chatbot.graph.nodes.researcher import researcher_node
+from ea_chatbot.graph.nodes.delegate import delegate_node
+from ea_chatbot.graph.nodes.reflector import reflector_node
+from ea_chatbot.graph.nodes.synthesizer import synthesizer_node
+from ea_chatbot.graph.workers.data_analyst.workflow import create_data_analyst_worker
+from ea_chatbot.graph.workers.data_analyst.mapping import prepare_worker_input, merge_worker_output
+from ea_chatbot.graph.workers.researcher.workflow import create_researcher_worker
+from ea_chatbot.graph.workers.researcher.mapping import prepare_researcher_input, merge_researcher_output
 from ea_chatbot.graph.nodes.clarification import clarification_node
 from ea_chatbot.graph.nodes.summarize_conversation import summarize_conversation_node
 
-MAX_ITERATIONS = 3
+# Global cache for compiled subgraphs
+_DATA_ANALYST_WORKER = create_data_analyst_worker()
+_RESEARCHER_WORKER = create_researcher_worker()
 
-def router(state: AgentState) -> str:
-    """Route to the next node based on the analysis."""
+# Define worker nodes as piped runnables to enable subgraph event propagation
+data_analyst_worker_runnable = prepare_worker_input | _DATA_ANALYST_WORKER | merge_worker_output
+researcher_worker_runnable = prepare_researcher_input | _RESEARCHER_WORKER | merge_researcher_output
+
+def main_router(state: AgentState) -> str:
+    """Route from query analyzer based on initial assessment."""
     next_action = state.get("next_action")
-    if next_action == "plan":
-        return "planner"
-    elif next_action == "research":
-        return "researcher"
-    elif next_action == "clarify":
+    if next_action == "clarify":
         return "clarification"
-    else:
-        return END
+    # Even if QA suggests 'research', we now go through 'planner' for orchestration
+    # aligning with the new hierarchical architecture.
+    return "planner"
 
-def create_workflow():
-    """Create the LangGraph workflow."""
+def delegation_router(state: AgentState) -> str:
+    """Route from delegate node to specific workers or synthesis."""
+    next_action = state.get("next_action")
+    if next_action == "data_analyst":
+        return "data_analyst_worker"
+    elif next_action == "researcher":
+        return "researcher_worker"
+    elif next_action == "summarize":
+        return "synthesizer"
+    return "synthesizer"
+
+def create_workflow(
+    query_analyzer=query_analyzer_node,
+    planner=planner_node,
+    delegate=delegate_node,
+    data_analyst_worker=data_analyst_worker_runnable,
+    researcher_worker=researcher_worker_runnable,
+    reflector=reflector_node,
+    synthesizer=synthesizer_node,
+    clarification=clarification_node,
+    summarize_conversation=summarize_conversation_node
+):
+    """Create the high-level Orchestrator workflow."""
     workflow = StateGraph(AgentState)
 
-    # Add nodes
-    workflow.add_node("query_analyzer", query_analyzer_node)
-    workflow.add_node("planner", planner_node)
-    workflow.add_node("coder", coder_node)
-    workflow.add_node("error_corrector", error_corrector_node)
-    workflow.add_node("researcher", researcher_node)
-    workflow.add_node("clarification", clarification_node)
-    workflow.add_node("executor", executor_node)
-    workflow.add_node("summarizer", summarizer_node)
-    workflow.add_node("summarize_conversation", summarize_conversation_node)
+    # Add Nodes
+    workflow.add_node("query_analyzer", query_analyzer)
+    workflow.add_node("planner", planner)
+    workflow.add_node("delegate", delegate)
+    workflow.add_node("data_analyst_worker", data_analyst_worker)
+    workflow.add_node("researcher_worker", researcher_worker)
+    workflow.add_node("reflector", reflector)
+    workflow.add_node("synthesizer", synthesizer)
+    workflow.add_node("clarification", clarification)
+    workflow.add_node("summarize_conversation", summarize_conversation)
 
     # Set entry point
     workflow.set_entry_point("query_analyzer")
 
-    # Add conditional edges from query_analyzer
+    # Edges
     workflow.add_conditional_edges(
         "query_analyzer",
-        router,
+        main_router,
         {
-            "planner": "planner",
-            "researcher": "researcher",
             "clarification": "clarification",
-            END: END
+            "planner": "planner"
         }
     )
 
-    # Linear flow for planning and coding
-    workflow.add_edge("planner", "coder")
-    workflow.add_edge("coder", "executor")
+    workflow.add_edge("planner", "delegate")
 
-    # Executor routing
-    def executor_router(state: AgentState) -> str:
-        if state.get("error"):
-            # Check for iteration limit to prevent infinite loops
-            if state.get("iterations", 0) >= MAX_ITERATIONS:
-                return "summarizer"
-            return "error_corrector"
-        return "summarizer"
-
     workflow.add_conditional_edges(
-        "executor",
-        executor_router,
+        "delegate",
+        delegation_router,
         {
-            "error_corrector": "error_corrector",
-            "summarizer": "summarizer"
+            "data_analyst_worker": "data_analyst_worker",
+            "researcher_worker": "researcher_worker",
+            "synthesizer": "synthesizer"
        }
    )
 
-    workflow.add_edge("error_corrector", "executor")
+    workflow.add_edge("data_analyst_worker", "reflector")
+    workflow.add_edge("researcher_worker", "reflector")
+    workflow.add_edge("reflector", "delegate")
 
-    workflow.add_edge("researcher", "summarize_conversation")
-    workflow.add_edge("clarification", END)
-    workflow.add_edge("summarizer", "summarize_conversation")
+    workflow.add_edge("synthesizer", "summarize_conversation")
     workflow.add_edge("summarize_conversation", END)
+    workflow.add_edge("clarification", END)
 
     # Compile the graph
    app = workflow.compile()
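The edge wiring above encodes a delegate → worker → reflector → delegate loop that repeats until the delegate routes to the synthesizer. A hedged, stdlib-only sketch with stubbed workers (the checklist contents and state transitions are hypothetical):

```python
# Simulate the orchestration loop: the delegation router picks a worker per
# step; stubs stand in for the worker and reflector, which advance the state.
def delegation_router(state):
    next_action = state.get("next_action")
    if next_action == "data_analyst":
        return "data_analyst_worker"
    elif next_action == "researcher":
        return "researcher_worker"
    return "synthesizer"

state = {"checklist": ["analyze", "research"], "current_step": 0, "next_action": "data_analyst"}
visited = []
while True:
    target = delegation_router(state)
    visited.append(target)
    if target == "synthesizer":
        break
    # Stub worker + reflector: mark the step done and advance
    state["current_step"] += 1
    if state["current_step"] >= len(state["checklist"]):
        state["next_action"] = "summarize"
    else:
        state["next_action"] = "researcher"

print(visited)
```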
@@ -1,7 +1,15 @@
 from pydantic import BaseModel, Field, computed_field
-from typing import Sequence, Optional
+from typing import Sequence, Optional, List, Dict, Any, Literal
 import re
 
+class QueryAnalysis(BaseModel):
+    """Analysis of the user's query."""
+    data_required: List[str] = Field(description="List of data points or entities mentioned (e.g., ['2024 results', 'Florida']).")
+    unknowns: List[str] = Field(description="List of target information the user wants to know or needed for final answer (e.g., 'who won', 'total votes').")
+    ambiguities: List[str] = Field(description="List of CRITICAL missing details that prevent ANY analysis. Do NOT include database names or plot types if defaults can be used.")
+    conditions: List[str] = Field(description="List of any filters or constraints (e.g., ['year=2024', 'state=Florida']). Include context resolved from history.")
+    next_action: Literal["plan", "clarify", "research"] = Field(description="The next action to take. 'plan' for data analysis (even with defaults), 'research' for general knowledge, or 'clarify' ONLY for critical ambiguities.")
+
 class TaskPlanContext(BaseModel):
     '''Background context relevant to the task plan'''
     initial_context: str = Field(
@@ -33,6 +41,22 @@ class TaskPlanResponse(BaseModel):
         description="Ordered list of steps to execute that follow the 'Step <number>: <detail>' pattern.",
     )
 
+class ChecklistTask(BaseModel):
+    '''A specific sub-task in the high-level orchestrator plan'''
+    task: str = Field(description="Description of the sub-task")
+    worker: str = Field(description="The worker to delegate to (data_analyst or researcher)")
+
+class ChecklistResponse(BaseModel):
+    '''Orchestrator's decomposed plan/checklist'''
+    goal: str = Field(description="Overall objective")
+    reflection: str = Field(description="Strategic reasoning")
+    checklist: List[ChecklistTask] = Field(description="Ordered list of tasks for specialized workers")
class ReflectorResponse(BaseModel):
|
||||
'''Orchestrator's evaluation of worker output'''
|
||||
satisfied: bool = Field(description="Whether the worker's output satisfies the sub-task requirements")
|
||||
reasoning: str = Field(description="Brief explanation of the evaluation")
|
||||
|
||||
_IM_SEP_TOKEN_PATTERN = re.compile(re.escape("<|im_sep|>"))
|
||||
_CODE_BLOCK_PATTERN = re.compile(r"```(?:python\s*)?(.*?)\s*```", re.DOTALL)
|
||||
_FORBIDDEN_MODULES = (
|
||||
|
||||
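As a quick illustration of the `_CODE_BLOCK_PATTERN` regex defined in this schema module, here is a standalone snippet (not part of the diff) showing it extracting the body of a fenced code block from an LLM reply:

```python
import re

# Same pattern as in the schema module: optional "python" language tag,
# lazy capture of the body, trailing whitespace stripped before the closing fence.
_CODE_BLOCK_PATTERN = re.compile(r"```(?:python\s*)?(.*?)\s*```", re.DOTALL)

reply = "Here is the code:\n```python\nprint('hi')\n```\nDone."
match = _CODE_BLOCK_PATTERN.search(reply)
code = match.group(1) if match else None
print(code)  # print('hi')
```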
backend/src/ea_chatbot/utils/vfs.py (new file, 69 lines)
@@ -0,0 +1,69 @@
import copy
from typing import Dict, Any, Optional, Tuple, List
from ea_chatbot.utils.logging import get_logger

logger = get_logger("utils:vfs")

def safe_vfs_copy(vfs_state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Perform a safe deep copy of the VFS state.

    If an entry cannot be deep-copied (e.g., it contains a non-copyable object like a DB handle),
    logs an error and replaces the entry with a descriptive error marker.
    This prevents crashing the graph/persistence while making the failure explicit.
    """
    new_vfs = {}
    for filename, data in vfs_state.items():
        try:
            # Attempt a standard deepcopy for isolation
            new_vfs[filename] = copy.deepcopy(data)
        except Exception as e:
            logger.error(
                f"CRITICAL: VFS artifact '{filename}' is NOT copyable/serializable: {str(e)}. "
                "Replacing with error placeholder to prevent graph crash."
            )
            # Replace with a standardized error artifact
            new_vfs[filename] = {
                "content": f"<ERROR: This artifact could not be persisted or copied: {str(e)}>",
                "metadata": {
                    "type": "error",
                    "error": str(e),
                    "original_filename": filename
                }
            }
    return new_vfs

class VFSHelper:
    """Helper class for managing in-memory Virtual File System (VFS) artifacts."""

    def __init__(self, vfs_state: Dict[str, Any]):
        """Initialize with a reference to the VFS state from AgentState."""
        self._vfs = vfs_state

    def write(self, filename: str, content: Any, metadata: Optional[Dict[str, Any]] = None) -> None:
        """Write a file to the VFS."""
        self._vfs[filename] = {
            "content": content,
            "metadata": metadata or {}
        }

    def read(self, filename: str) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]:
        """Read a file and its metadata from the VFS. Returns (None, None) if not found."""
        file_data = self._vfs.get(filename)
        if file_data is not None:
            # Handle raw values (backwards compatibility or inconsistent schema)
            if not isinstance(file_data, dict) or "content" not in file_data:
                return file_data, {}
            return file_data["content"], file_data.get("metadata", {})
        return None, None

    def list(self) -> List[str]:
        """List all filenames in the VFS."""
        return list(self._vfs.keys())

    def delete(self, filename: str) -> bool:
        """Delete a file from the VFS. Returns True if deleted, False if not found."""
        if filename in self._vfs:
            del self._vfs[filename]
            return True
        return False
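To see why `safe_vfs_copy` guards the deepcopy, here is a standalone illustration of the failure mode it handles; a generator stands in for a non-copyable artifact such as a DB handle:

```python
import copy

vfs = {
    "results.txt": {"content": "NJ totals", "metadata": {}},
    "stream": {"content": (i for i in range(3)), "metadata": {}},  # not deep-copyable
}

copied, failed = {}, []
for name, data in vfs.items():
    try:
        copied[name] = copy.deepcopy(data)
    except Exception as exc:
        # Mirror safe_vfs_copy: swap in an explicit error artifact
        failed.append(name)
        copied[name] = {"content": f"<ERROR: {exc}>", "metadata": {"type": "error"}}

print(failed)  # ['stream']
```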
@@ -36,24 +36,25 @@ async def test_stream_agent_events_all_features():
         # Stream chunk
         {
             "event": "on_chat_model_stream",
-            "metadata": {"langgraph_node": "summarizer"},
+            "metadata": {"langgraph_node": "synthesizer"},
             "data": {"chunk": AIMessage(content="Hello ")}
         },
         {
             "event": "on_chat_model_stream",
-            "metadata": {"langgraph_node": "summarizer"},
+            "metadata": {"langgraph_node": "synthesizer"},
             "data": {"chunk": AIMessage(content="world")}
         },
-        # Plot event
+        # Plot event - with nested subgraph it might bubble up or come directly from data_analyst_worker
+        # Let's mock it coming from the data_analyst_worker on_chain_end
         {
             "event": "on_chain_end",
-            "name": "executor",
+            "name": "data_analyst_worker",
             "data": {"output": {"plots": [fig]}}
         },
         # Final response
         {
             "event": "on_chain_end",
-            "name": "summarizer",
+            "name": "synthesizer",
             "data": {"output": {"messages": [AIMessage(content="Hello world final")]}}
         },
         # Summary update
@@ -91,7 +92,7 @@ async def test_stream_agent_events_all_features():
     assert any(r.get("type") == "on_chat_model_stream" for r in results)

     # Verify plot was encoded
-    plot_event = next(r for r in results if r.get("name") == "executor")
+    plot_event = next(r for r in results if r.get("name") == "data_analyst_worker")
     assert "encoded_plots" in plot_event["data"]
     assert len(plot_event["data"]["encoded_plots"]) == 1
@@ -123,3 +123,20 @@ def test_get_me_success(client):
     assert response.status_code == 200
     assert response.json()["email"] == "test@example.com"
     assert response.json()["id"] == "123"
+
+def test_get_me_rejects_refresh_token(client):
+    """Test that /auth/me rejects refresh tokens for authentication."""
+    from ea_chatbot.api.utils import create_refresh_token
+    token = create_refresh_token(data={"sub": "123"})
+
+    with patch("ea_chatbot.api.dependencies.history_manager") as mock_hm:
+        # Even if the user exists, the dependency should reject the token type
+        mock_hm.get_user_by_id.return_value = User(id="123", username="test@example.com")
+
+        response = client.get(
+            "/api/v1/auth/me",
+            headers={"Authorization": f"Bearer {token}"}
+        )
+
+        assert response.status_code == 401
+        assert "Cannot use refresh token" in response.json()["detail"]
@@ -19,14 +19,13 @@ def auth_header(mock_user):
     yield {"Authorization": f"Bearer {token}"}
     app.dependency_overrides.clear()

-def test_persistence_integration_success(auth_header, mock_user):
-    """Test that messages and plots are persisted correctly during streaming."""
-    mock_events = [
-        {"event": "on_chat_model_stream", "name": "summarizer", "data": {"chunk": "Final answer"}},
-        {"event": "on_chain_end", "name": "summarizer", "data": {"output": {"messages": [{"content": "Final answer"}]}}},
-        {"event": "on_chain_end", "name": "summarize_conversation", "data": {"output": {"summary": "New summary"}}}
-    ]
+def test_persistence_integration_success(auth_header, mock_user):
+    """Test that messages and plots are persisted correctly during streaming."""
+    mock_events = [
+        {"event": "on_chat_model_stream", "metadata": {"langgraph_node": "synthesizer"}, "data": {"chunk": "Final answer"}},
+        {"event": "on_chain_end", "name": "synthesizer", "data": {"output": {"messages": [{"content": "Final answer"}]}}},
+        {"event": "on_chain_end", "name": "summarize_conversation", "data": {"output": {"summary": "New summary"}}}
+    ]
     async def mock_astream_events(*args, **kwargs):
         for event in mock_events:
             yield event
@@ -1,62 +0,0 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.coder import coder_node
from ea_chatbot.graph.nodes.error_corrector import error_corrector_node

@pytest.fixture
def mock_state():
    return {
        "messages": [],
        "question": "Show me results for New Jersey",
        "plan": "Step 1: Load data\nStep 2: Filter by NJ",
        "code": None,
        "error": None,
        "plots": [],
        "dfs": {},
        "next_action": "plan"
    }

@patch("ea_chatbot.graph.nodes.coder.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
def test_coder_node(mock_get_summary, mock_get_llm, mock_state):
    """Test coder node generates code from plan."""
    mock_get_summary.return_value = "Column: Name, Type: text"

    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm

    from ea_chatbot.schemas import CodeGenerationResponse
    mock_response = CodeGenerationResponse(
        code="import pandas as pd\nprint('Hello')",
        explanation="Generated code"
    )
    mock_llm.with_structured_output.return_value.invoke.return_value = mock_response

    result = coder_node(mock_state)

    assert "code" in result
    assert "import pandas as pd" in result["code"]
    assert "error" in result
    assert result["error"] is None

@patch("ea_chatbot.graph.nodes.error_corrector.get_llm_model")
def test_error_corrector_node(mock_get_llm, mock_state):
    """Test error corrector node fixes code."""
    mock_state["code"] = "import pandas as pd\nprint(undefined_var)"
    mock_state["error"] = "NameError: name 'undefined_var' is not defined"

    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm

    from ea_chatbot.schemas import CodeGenerationResponse
    mock_response = CodeGenerationResponse(
        code="import pandas as pd\nprint('Defined')",
        explanation="Fixed variable"
    )
    mock_llm.with_structured_output.return_value.invoke.return_value = mock_response

    result = error_corrector_node(mock_state)

    assert "code" in result
    assert "print('Defined')" in result["code"]
    assert result["error"] is None
backend/tests/test_data_analyst_mapping.py (new file, 53 lines)
@@ -0,0 +1,53 @@
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.state import AgentState
from ea_chatbot.graph.workers.data_analyst.mapping import (
    prepare_worker_input,
    merge_worker_output
)
from ea_chatbot.graph.workers.data_analyst.state import WorkerState

def test_prepare_worker_input():
    """Verify that we correctly map global state to worker input."""
    global_state = AgentState(
        messages=[HumanMessage(content="global message")],
        question="original question",
        checklist=[{"task": "Worker Task", "status": "pending"}],
        current_step=0,
        vfs={"old.txt": "old data"},
        plots=[],
        dfs={},
        next_action="test",
        iterations=0
    )

    worker_input = prepare_worker_input(global_state)

    assert worker_input["task"] == "Worker Task"
    assert "old.txt" in worker_input["vfs_state"]
    # Internal worker messages should start fresh or with the task
    assert len(worker_input["messages"]) == 1
    assert worker_input["messages"][0].content == "Worker Task"

def test_merge_worker_output():
    """Verify that we correctly merge worker results back to global state."""
    worker_state = WorkerState(
        messages=[HumanMessage(content="internal"), AIMessage(content="summary")],
        task="Worker Task",
        result="Finished analysis",
        plots=["plot1"],
        vfs_state={"new.txt": "new data"},
        iterations=2
    )

    updates = merge_worker_output(worker_state)

    # We expect the 'result' to be added as an AI message to global history
    assert len(updates["messages"]) == 1
    assert updates["messages"][0].content == "Finished analysis"

    # VFS should be updated
    assert "new.txt" in updates["vfs"]

    # Plots should be bubbled up
    assert len(updates["plots"]) == 1
    assert updates["plots"][0] == "plot1"
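The two mapping tests pin down the contract between global and worker state. A minimal dict-based sketch satisfying that contract (a simplified stand-in, not the project's `prepare_worker_input`/`merge_worker_output`, which operate on `AgentState`/`WorkerState` and LangChain message objects) could be:

```python
# Simplified sketch of the global-state <-> worker-state mapping contract.
# Messages are plain strings here; the real code uses LangChain messages.
def prepare_worker_input(global_state: dict) -> dict:
    task = global_state["checklist"][global_state["current_step"]]["task"]
    return {
        "task": task,
        "vfs_state": dict(global_state["vfs"]),  # isolate the worker's VFS view
        "messages": [task],                      # worker history starts with the task
    }

def merge_worker_output(worker_state: dict) -> dict:
    return {
        "messages": [worker_state["result"]],  # surface the summary, not internal chatter
        "vfs": worker_state["vfs_state"],      # bubble VFS changes back up
        "plots": worker_state["plots"],
    }

state = {"checklist": [{"task": "Worker Task"}], "current_step": 0, "vfs": {"old.txt": "old"}}
print(prepare_worker_input(state)["task"])  # Worker Task
```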
backend/tests/test_data_analyst_state.py (new file, 17 lines)
@@ -0,0 +1,17 @@
from typing import get_type_hints, List
from langchain_core.messages import BaseMessage
from ea_chatbot.graph.workers.data_analyst.state import WorkerState

def test_data_analyst_worker_state():
    """Verify that DataAnalyst WorkerState has the required fields."""
    hints = get_type_hints(WorkerState)

    assert "messages" in hints
    assert "task" in hints
    assert hints["task"] == str

    assert "code" in hints
    assert "output" in hints
    assert "error" in hints
    assert "iterations" in hints
    assert hints["iterations"] == int
backend/tests/test_data_analyst_worker.py (new file, 81 lines)
@@ -0,0 +1,81 @@
import pytest
from unittest.mock import MagicMock
from ea_chatbot.graph.workers.data_analyst.workflow import create_data_analyst_worker, WorkerState

def test_data_analyst_worker_one_shot():
    """Verify a successful one-shot execution of the worker subgraph."""
    mock_coder = MagicMock()
    mock_executor = MagicMock()
    mock_summarizer = MagicMock()

    # Scenario: Coder -> Executor (Success) -> Summarizer -> END
    mock_coder.return_value = {"code": "print(1)", "error": None, "iterations": 1}
    mock_executor.return_value = {"output": "1\n", "error": None, "plots": []}
    mock_summarizer.return_value = {"result": "Result is 1"}

    graph = create_data_analyst_worker(
        coder=mock_coder,
        executor=mock_executor,
        summarizer=mock_summarizer
    )

    initial_state = WorkerState(
        messages=[],
        task="Calculate 1+1",
        code=None,
        output=None,
        error=None,
        iterations=0,
        plots=[],
        vfs_state={},
        result=None
    )

    final_state = graph.invoke(initial_state)

    assert final_state["result"] == "Result is 1"
    assert mock_coder.call_count == 1
    assert mock_executor.call_count == 1
    assert mock_summarizer.call_count == 1

def test_data_analyst_worker_retry():
    """Verify that the worker retries on error."""
    mock_coder = MagicMock()
    mock_executor = MagicMock()
    mock_summarizer = MagicMock()

    # Scenario: Coder (1) -> Executor (Error) -> Router (coder) -> Coder (2) -> Executor (Success) -> Summarizer -> END
    mock_coder.side_effect = [
        {"code": "error_code", "error": None, "iterations": 1},
        {"code": "fixed_code", "error": None, "iterations": 2}
    ]
    mock_executor.side_effect = [
        {"output": "", "error": "NameError", "plots": []},
        {"output": "Success", "error": None, "plots": []}
    ]
    mock_summarizer.return_value = {"result": "Fixed Result"}

    graph = create_data_analyst_worker(
        coder=mock_coder,
        executor=mock_executor,
        summarizer=mock_summarizer
    )

    initial_state = WorkerState(
        messages=[],
        task="Retry Task",
        code=None,
        output=None,
        error=None,
        iterations=0,
        plots=[],
        vfs_state={},
        result=None
    )

    final_state = graph.invoke(initial_state)

    assert final_state["result"] == "Fixed Result"
    assert mock_coder.call_count == 2
    assert mock_executor.call_count == 2
    assert mock_summarizer.call_count == 1
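The retry scenario in this test implies an internal router between the executor and the coder. A hedged sketch of such a router (the function name and `MAX_ITERATIONS` value are assumptions; the real routing logic is not shown in this diff):

```python
MAX_ITERATIONS = 3  # assumed budget; the real constant lives in the worker module

def route_after_executor(state: dict) -> str:
    # Loop back to the coder on failure until the iteration budget is spent,
    # otherwise hand off to the summarizer.
    if state.get("error") and state.get("iterations", 0) < MAX_ITERATIONS:
        return "coder"
    return "summarizer"

print(route_after_executor({"error": "NameError", "iterations": 1}))  # coder
```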
backend/tests/test_deepagents_e2e.py (new file, 79 lines)
@@ -0,0 +1,79 @@
import pytest
from unittest.mock import MagicMock
from ea_chatbot.graph.workflow import create_workflow
from ea_chatbot.graph.state import AgentState
from langchain_core.messages import AIMessage, HumanMessage

def test_deepagents_multi_worker_sequential_flow():
    """Verify that the Orchestrator can handle a sequence of different workers."""

    mock_analyzer = MagicMock()
    mock_planner = MagicMock()
    mock_delegate = MagicMock()
    mock_analyst = MagicMock()
    mock_researcher = MagicMock()
    mock_reflector = MagicMock()
    mock_synthesizer = MagicMock()

    # 1. Analyzer: Plan
    mock_analyzer.return_value = {"next_action": "plan"}

    # 2. Planner: Two tasks
    mock_planner.return_value = {
        "checklist": [
            {"task": "Get Numbers", "worker": "data_analyst"},
            {"task": "Get Facts", "worker": "researcher"}
        ],
        "current_step": 0
    }

    # 3. Delegate & Reflector: Loop through tasks
    mock_delegate.side_effect = [
        {"next_action": "data_analyst"},  # Step 0
        {"next_action": "researcher"},    # Step 1
        {"next_action": "summarize"}      # Step 2 (done)
    ]

    mock_analyst.return_value = {"messages": [AIMessage(content="Analyst summary")], "vfs": {"data.csv": "..."}}
    mock_researcher.return_value = {"messages": [AIMessage(content="Researcher summary")]}

    mock_reflector.side_effect = [
        {"current_step": 1, "next_action": "delegate"},  # Done with Analyst
        {"current_step": 2, "next_action": "delegate"}   # Done with Researcher
    ]

    mock_synthesizer.return_value = {
        "messages": [AIMessage(content="Final multi-agent response")],
        "next_action": "end"
    }

    app = create_workflow(
        query_analyzer=mock_analyzer,
        planner=mock_planner,
        delegate=mock_delegate,
        data_analyst_worker=mock_analyst,
        researcher_worker=mock_researcher,
        reflector=mock_reflector,
        synthesizer=mock_synthesizer
    )

    initial_state = AgentState(
        messages=[HumanMessage(content="Numbers and facts please")],
        question="Numbers and facts please",
        analysis={},
        next_action="",
        iterations=0,
        checklist=[],
        current_step=0,
        vfs={},
        plots=[],
        dfs={}
    )

    final_state = app.invoke(initial_state, config={"recursion_limit": 30})

    assert mock_analyst.called
    assert mock_researcher.called
    assert mock_reflector.call_count == 2
    assert "Final multi-agent response" in [m.content for m in final_state["messages"]]
    assert final_state["current_step"] == 2
@@ -1,122 +0,0 @@
import pytest
import pandas as pd
from unittest.mock import MagicMock, patch
from matplotlib.figure import Figure
from ea_chatbot.graph.nodes.executor import executor_node

@pytest.fixture
def mock_settings():
    with patch("ea_chatbot.graph.nodes.executor.Settings") as MockSettings:
        mock_settings_instance = MagicMock()
        mock_settings_instance.db_host = "localhost"
        mock_settings_instance.db_port = 5432
        mock_settings_instance.db_user = "user"
        mock_settings_instance.db_pswd = "pass"
        mock_settings_instance.db_name = "test_db"
        mock_settings_instance.db_table = "test_table"
        MockSettings.return_value = mock_settings_instance
        yield mock_settings_instance

@pytest.fixture
def mock_db_client():
    with patch("ea_chatbot.graph.nodes.executor.DBClient") as MockDBClient:
        mock_client_instance = MagicMock()
        MockDBClient.return_value = mock_client_instance
        yield mock_client_instance

def test_executor_node_success_simple_print(mock_settings, mock_db_client):
    """Test executing simple code that prints to stdout."""
    state = {
        "code": "print('Hello, World!')",
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert "code_output" in result
    assert "Hello, World!" in result["code_output"]
    assert result["error"] is None
    assert result["plots"] == []
    assert result["dfs"] == {}

def test_executor_node_success_dataframe(mock_settings, mock_db_client):
    """Test executing code that creates a DataFrame."""
    code = """
import pandas as pd
df = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
print(df)
"""
    state = {
        "code": code,
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert "code_output" in result
    assert "a b" in result["code_output"]  # Check part of DF string representation
    assert "dfs" in result
    assert "df" in result["dfs"]
    assert isinstance(result["dfs"]["df"], pd.DataFrame)

def test_executor_node_success_plot(mock_settings, mock_db_client):
    """Test executing code that generates a plot."""
    code = """
import matplotlib.pyplot as plt
fig = plt.figure()
plots.append(fig)
print('Plot generated')
"""
    state = {
        "code": code,
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert "Plot generated" in result["code_output"]
    assert "plots" in result
    assert len(result["plots"]) == 1
    assert isinstance(result["plots"][0], Figure)

def test_executor_node_error_syntax(mock_settings, mock_db_client):
    """Test executing code with a syntax error."""
    state = {
        "code": "print('Hello World",  # Missing closing quote
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert result["error"] is not None
    assert "SyntaxError" in result["error"]

def test_executor_node_error_runtime(mock_settings, mock_db_client):
    """Test executing code with a runtime error."""
    state = {
        "code": "print(1 / 0)",
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert result["error"] is not None
    assert "ZeroDivisionError" in result["error"]

def test_executor_node_no_code(mock_settings, mock_db_client):
    """Test handling when no code is provided."""
    state = {
        "code": None,
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert "error" in result
    assert "No code provided" in result["error"]
@@ -2,7 +2,7 @@ import json
 import pytest
 import logging
 from unittest.mock import patch
-from ea_chatbot.graph.workflow import app
+from ea_chatbot.graph.workflow import create_workflow
 from ea_chatbot.graph.state import AgentState
 from ea_chatbot.utils.logging import get_logger
 from langchain_community.chat_models import FakeListChatModel
@@ -43,10 +43,10 @@ def test_logging_e2e_json_output(tmp_path):
         "question": "Who won in 2024?",
         "analysis": None,
         "next_action": "",
-        "plan": None,
-        "code": None,
-        "code_output": None,
-        "error": None,
+        "iterations": 0,
+        "checklist": [],
+        "current_step": 0,
+        "vfs": {},
         "plots": [],
         "dfs": {}
     }
@@ -57,6 +57,20 @@ def test_logging_e2e_json_output(tmp_path):

     fake_clarify = FakeListChatModel(responses=["Please specify."])

+    # Create a test app without interrupts
+    # We need to manually compile it here to avoid the global 'app' which has interrupts
+    from langgraph.graph import StateGraph, END
+    from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node
+    from ea_chatbot.graph.nodes.clarification import clarification_node
+
+    workflow = StateGraph(AgentState)
+    workflow.add_node("query_analyzer", query_analyzer_node)
+    workflow.add_node("clarification", clarification_node)
+    workflow.set_entry_point("query_analyzer")
+    workflow.add_edge("query_analyzer", "clarification")
+    workflow.add_edge("clarification", END)
+    test_app = workflow.compile()
+
     with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_llm_factory:
         mock_llm_factory.return_value = fake_analyzer

@@ -64,7 +78,7 @@ def test_logging_e2e_json_output(tmp_path):
             mock_clarify_llm_factory.return_value = fake_clarify

             # Run the graph
-            list(app.stream(initial_state))
+            test_app.invoke(initial_state)

             # Verify file content
             assert log_file.exists()
@@ -1,24 +1,27 @@
 import pytest
 from unittest.mock import MagicMock, patch
 from langchain_core.messages import HumanMessage, AIMessage
 from ea_chatbot.graph.nodes.planner import planner_node
 from ea_chatbot.graph.nodes.researcher import researcher_node
 from ea_chatbot.graph.nodes.summarizer import summarizer_node
-from ea_chatbot.schemas import TaskPlanResponse
+from ea_chatbot.schemas import ChecklistResponse, ChecklistTask
-from langchain_core.messages import HumanMessage, AIMessage

 @pytest.fixture
 def mock_state_with_history():
     return {
         "messages": [
-            HumanMessage(content="Show me the 2024 results for Florida"),
-            AIMessage(content="Here are the results for Florida in 2024...")
+            HumanMessage(content="What about NJ?"),
+            AIMessage(content="NJ has 9 million voters.")
         ],
-        "question": "What about in New Jersey?",
-        "analysis": {"data_required": ["2024 results", "New Jersey"], "unknowns": [], "ambiguities": [], "conditions": []},
+        "question": "Show me the breakdown by county for 2024",
+        "analysis": {
+            "data_required": ["2024 results", "New Jersey"],
+            "unknowns": [],
+            "ambiguities": [],
+            "conditions": []
+        },
         "next_action": "plan",
-        "summary": "The user is asking about 2024 election results.",
-        "plan": "Plan steps...",
-        "code_output": "Code output..."
+        "summary": "The user is asking about NJ 2024 results.",
+        "checklist": [],
+        "current_step": 0
     }

 @patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@@ -31,49 +34,17 @@ def test_planner_uses_history_and_summary(mock_prompt, mock_get_summary, mock_ge
     mock_structured_llm = MagicMock()
     mock_llm_instance.with_structured_output.return_value = mock_structured_llm

-    mock_structured_llm.invoke.return_value = TaskPlanResponse(
+    mock_structured_llm.invoke.return_value = ChecklistResponse(
         goal="goal",
         reflection="reflection",
-        context={
-            "initial_context": "context",
-            "assumptions": [],
-            "constraints": []
-        },
-        steps=["Step 1: test"]
+        checklist=[ChecklistTask(task="Step 1: test", worker="data_analyst")]
     )

     planner_node(mock_state_with_history)

     mock_prompt.format_messages.assert_called_once()
-    kwargs = mock_prompt.format_messages.call_args[1]
-    assert kwargs["question"] == "What about in New Jersey?"
-    assert kwargs["summary"] == mock_state_with_history["summary"]
-    assert len(kwargs["history"]) == 2
-
-@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
-@patch("ea_chatbot.graph.nodes.researcher.RESEARCHER_PROMPT")
-def test_researcher_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
-    mock_llm_instance = MagicMock()
-    mock_get_llm.return_value = mock_llm_instance
-
-    researcher_node(mock_state_with_history)
-
-    mock_prompt.format_messages.assert_called_once()
-    kwargs = mock_prompt.format_messages.call_args[1]
-    assert kwargs["question"] == "What about in New Jersey?"
-    assert kwargs["summary"] == mock_state_with_history["summary"]
-    assert len(kwargs["history"]) == 2
-
-@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model")
-@patch("ea_chatbot.graph.nodes.summarizer.SUMMARIZER_PROMPT")
-def test_summarizer_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
-    mock_llm_instance = MagicMock()
-    mock_get_llm.return_value = mock_llm_instance
-
-    summarizer_node(mock_state_with_history)
-
-    mock_prompt.format_messages.assert_called_once()
-    kwargs = mock_prompt.format_messages.call_args[1]
-    assert kwargs["question"] == "What about in New Jersey?"
-    assert kwargs["summary"] == mock_state_with_history["summary"]
-    assert len(kwargs["history"]) == 2
+    # Verify history and summary were passed to prompt format
+    # We check the arguments passed to the mock_prompt.format_messages
+    call_args = mock_prompt.format_messages.call_args[1]
+    assert call_args["summary"] == "The user is asking about NJ 2024 results."
+    assert len(call_args["history"]) == 2
+    assert "breakdown by county" in call_args["question"]
backend/tests/test_node_model_selection.py (new file, 126 lines)
@@ -0,0 +1,126 @@
|
||||
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node
from ea_chatbot.graph.nodes.planner import planner_node
from ea_chatbot.graph.nodes.reflector import reflector_node
from ea_chatbot.graph.nodes.synthesizer import synthesizer_node
from ea_chatbot.graph.workers.data_analyst.nodes.coder import coder_node as analyst_coder
from ea_chatbot.graph.workers.researcher.nodes.searcher import searcher_node as researcher_searcher
from ea_chatbot.graph.state import AgentState
from ea_chatbot.graph.workers.data_analyst.state import WorkerState as AnalystState
from ea_chatbot.graph.workers.researcher.state import WorkerState as ResearcherState
from ea_chatbot.config import Settings


@pytest.fixture
def mock_settings():
    settings = Settings()
    settings.query_analyzer_llm.model = "model-qa"
    settings.planner_llm.model = "model-planner"
    settings.reflector_llm.model = "model-reflector"
    settings.synthesizer_llm.model = "model-synthesizer"
    settings.coder_llm.model = "model-coder"
    settings.researcher_llm.model = "model-researcher"
    return settings


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.query_analyzer.Settings")
def test_query_analyzer_model(mock_settings_cls, mock_get_llm, mock_settings):
    mock_settings_cls.return_value = mock_settings
    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm

    state = AgentState(question="test", messages=[], checklist=[], current_step=0, iterations=0, vfs={}, plots=[], dfs={}, next_action="", analysis={})

    try:
        query_analyzer_node(state)
    except Exception:
        pass  # We only care about the call

    # Verify it was called with the QA config
    args = mock_get_llm.call_args[0][0]
    assert args.model == "model-qa"


@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.graph.nodes.planner.Settings")
def test_planner_model(mock_settings_cls, mock_get_llm, mock_settings):
    mock_settings_cls.return_value = mock_settings
    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm

    state = AgentState(question="test", messages=[], checklist=[], current_step=0, iterations=0, vfs={}, plots=[], dfs={}, next_action="", analysis={})

    try:
        planner_node(state)
    except Exception:
        pass

    args = mock_get_llm.call_args[0][0]
    assert args.model == "model-planner"


@patch("ea_chatbot.graph.nodes.reflector.get_llm_model")
@patch("ea_chatbot.graph.nodes.reflector.Settings")
def test_reflector_model(mock_settings_cls, mock_get_llm, mock_settings):
    mock_settings_cls.return_value = mock_settings
    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm

    state = AgentState(question="test", messages=[], checklist=[{"task": "T1"}], current_step=0, iterations=0, vfs={}, plots=[], dfs={}, next_action="", analysis={}, summary="Worker did X")

    try:
        reflector_node(state)
    except Exception:
        pass

    args = mock_get_llm.call_args[0][0]
    assert args.model == "model-reflector"


@patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model")
@patch("ea_chatbot.graph.nodes.synthesizer.Settings")
def test_synthesizer_model(mock_settings_cls, mock_get_llm, mock_settings):
    mock_settings_cls.return_value = mock_settings
    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm

    state = AgentState(question="test", messages=[], checklist=[], current_step=0, iterations=0, vfs={}, plots=[], dfs={}, next_action="", analysis={})

    try:
        synthesizer_node(state)
    except Exception:
        pass

    args = mock_get_llm.call_args[0][0]
    assert args.model == "model-synthesizer"


@patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model")
@patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.Settings")
def test_analyst_coder_model(mock_settings_cls, mock_get_llm, mock_settings):
    mock_settings_cls.return_value = mock_settings
    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm

    state = AnalystState(task="test", messages=[], iterations=0, vfs_state={}, plots=[], result=None, code=None, output=None, error=None)

    try:
        analyst_coder(state)
    except Exception:
        pass

    args = mock_get_llm.call_args[0][0]
    assert args.model == "model-coder"


@patch("ea_chatbot.graph.workers.researcher.nodes.searcher.get_llm_model")
@patch("ea_chatbot.graph.workers.researcher.nodes.searcher.Settings")
def test_researcher_searcher_model(mock_settings_cls, mock_get_llm, mock_settings):
    mock_settings_cls.return_value = mock_settings
    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm

    state = ResearcherState(task="test", messages=[], iterations=0, result=None, queries=[], raw_results=[])

    try:
        researcher_searcher(state)
    except Exception:
        pass

    args = mock_get_llm.call_args[0][0]
    assert args.model == "model-researcher"

56 backend/tests/test_orchestrator_delegate.py Normal file
@@ -0,0 +1,56 @@
from ea_chatbot.graph.nodes.delegate import delegate_node
from ea_chatbot.graph.state import AgentState


def test_delegate_node_data_analyst():
    """Verify that the delegate node routes to data_analyst."""
    state = AgentState(
        checklist=[{"task": "Analyze data", "worker": "data_analyst"}],
        current_step=0,
        messages=[],
        question="test",
        analysis={},
        next_action="",
        iterations=0,
        vfs={},
        plots=[],
        dfs={}
    )

    result = delegate_node(state)
    assert result["next_action"] == "data_analyst"


def test_delegate_node_researcher():
    """Verify that the delegate node routes to researcher."""
    state = AgentState(
        checklist=[{"task": "Search web", "worker": "researcher"}],
        current_step=0,
        messages=[],
        question="test",
        analysis={},
        next_action="",
        iterations=0,
        vfs={},
        plots=[],
        dfs={}
    )

    result = delegate_node(state)
    assert result["next_action"] == "researcher"


def test_delegate_node_finished():
    """Verify that the delegate node routes to summarize if the checklist is complete."""
    state = AgentState(
        checklist=[{"task": "Task 1", "worker": "data_analyst"}],
        current_step=1,  # Already finished the only task
        messages=[],
        question="test",
        analysis={},
        next_action="",
        iterations=0,
        vfs={},
        plots=[],
        dfs={}
    )

    result = delegate_node(state)
    assert result["next_action"] == "summarize"

136 backend/tests/test_orchestrator_loop.py Normal file
@@ -0,0 +1,136 @@
import pytest
from unittest.mock import MagicMock
from ea_chatbot.graph.workflow import create_workflow
from ea_chatbot.graph.state import AgentState
from langchain_core.messages import AIMessage, HumanMessage


def test_orchestrator_full_flow():
    """Verify the full Orchestrator-Workers flow via direct node injection."""
    mock_analyzer = MagicMock()
    mock_planner = MagicMock()
    mock_delegate = MagicMock()
    mock_worker = MagicMock()
    mock_reflector = MagicMock()
    mock_synthesizer = MagicMock()
    mock_summarize_conv = MagicMock()

    # 1. Analyzer: Proceed to planning
    mock_analyzer.return_value = {"next_action": "plan"}

    # 2. Planner: Generate checklist
    mock_planner.return_value = {
        "checklist": [{"task": "T1", "worker": "data_analyst"}],
        "current_step": 0
    }

    # 3. Delegate: Route to data_analyst
    mock_delegate.side_effect = [
        {"next_action": "data_analyst"},  # First call
        {"next_action": "summarize"}      # Second call (after reflector)
    ]

    # 4. Worker: Success
    mock_worker.return_value = {
        "messages": [AIMessage(content="Worker result")],
        "vfs": {"res.txt": "data"}
    }

    # 5. Reflector: Advance
    mock_reflector.return_value = {
        "current_step": 1,
        "next_action": "delegate"
    }

    # 6. Synthesizer: Final answer
    mock_synthesizer.return_value = {
        "messages": [AIMessage(content="Final synthesized answer")],
        "next_action": "end"
    }

    # 7. Summarize Conv: End
    mock_summarize_conv.return_value = {"summary": "Done"}

    # Create workflow with injected mocks
    app = create_workflow(
        query_analyzer=mock_analyzer,
        planner=mock_planner,
        delegate=mock_delegate,
        data_analyst_worker=mock_worker,
        reflector=mock_reflector,
        synthesizer=mock_synthesizer,
        summarize_conversation=mock_summarize_conv
    )

    initial_state = AgentState(
        messages=[HumanMessage(content="Explain results")],
        question="Explain results",
        analysis={},
        next_action="",
        iterations=0,
        checklist=[],
        current_step=0,
        vfs={},
        plots=[],
        dfs={}
    )

    final_state = app.invoke(initial_state)

    assert mock_analyzer.called
    assert mock_planner.called
    assert mock_delegate.call_count == 2
    assert mock_worker.called
    assert mock_reflector.called
    assert mock_synthesizer.called
    assert "Final synthesized answer" in [m.content for m in final_state["messages"]]


def test_orchestrator_researcher_flow():
    """Verify that the Orchestrator can route to the researcher worker."""
    mock_analyzer = MagicMock()
    mock_planner = MagicMock()
    mock_delegate = MagicMock()
    mock_researcher = MagicMock()
    mock_reflector = MagicMock()
    mock_synthesizer = MagicMock()

    mock_analyzer.return_value = {"next_action": "plan"}
    mock_planner.return_value = {
        "checklist": [{"task": "Search news", "worker": "researcher"}],
        "current_step": 0
    }
    mock_delegate.side_effect = [
        {"next_action": "researcher"},
        {"next_action": "summarize"}
    ]
    mock_researcher.return_value = {"messages": [AIMessage(content="News found")]}
    mock_reflector.return_value = {"current_step": 1, "next_action": "delegate"}
    mock_synthesizer.return_value = {"messages": [AIMessage(content="Final News Summary")], "next_action": "end"}

    app = create_workflow(
        query_analyzer=mock_analyzer,
        planner=mock_planner,
        delegate=mock_delegate,
        researcher_worker=mock_researcher,
        reflector=mock_reflector,
        synthesizer=mock_synthesizer
    )

    initial_state = AgentState(
        messages=[HumanMessage(content="What's the news?")],
        question="What's the news?",
        analysis={},
        next_action="",
        iterations=0,
        checklist=[],
        current_step=0,
        vfs={},
        plots=[],
        dfs={}
    )

    final_state = app.invoke(initial_state)

    assert mock_researcher.called
    assert "Final News Summary" in [m.content for m in final_state["messages"]]

34 backend/tests/test_orchestrator_planner.py Normal file
@@ -0,0 +1,34 @@
from typing import get_type_hints, List
from ea_chatbot.graph.nodes.planner import planner_node
from ea_chatbot.graph.state import AgentState


def test_planner_node_checklist():
    """Verify that the planner node generates a checklist."""
    state = AgentState(
        messages=[],
        question="How many voters are in Florida and what is the current news?",
        analysis={"requires_dataset": True},
        next_action="plan",
        iterations=0,
        checklist=[],
        current_step=0,
        vfs={},
        plots=[],
        dfs={}
    )

    # Mocking the LLM would be ideal; for now we only check that the
    # implementation returns the expected keys.
    result = planner_node(state)

    assert "checklist" in result
    assert isinstance(result["checklist"], list)
    assert len(result["checklist"]) > 0
    assert "task" in result["checklist"][0]
    assert "worker" in result["checklist"][0]  # 'data_analyst' or 'researcher'

    assert "current_step" in result
    assert result["current_step"] == 0

30 backend/tests/test_orchestrator_reflector.py Normal file
@@ -0,0 +1,30 @@
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.reflector import reflector_node
from ea_chatbot.graph.state import AgentState


def test_reflector_node_satisfied():
    """Verify that the reflector node increments current_step when satisfied."""
    state = AgentState(
        checklist=[{"task": "Analyze votes", "worker": "data_analyst"}],
        current_step=0,
        messages=[],
        question="test",
        analysis={},
        next_action="",
        iterations=0,
        vfs={},
        plots=[],
        dfs={},
        summary="Worker successfully analyzed votes and found 5 million."
    )

    # Mock the LLM to return 'satisfied=True'
    with patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_get_llm:
        mock_llm = MagicMock()
        mock_llm.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=True, reasoning="Good.")
        mock_get_llm.return_value = mock_llm

        result = reflector_node(state)

    assert result["current_step"] == 1
    assert result["next_action"] == "delegate"  # Route back to delegate for next task

31 backend/tests/test_orchestrator_synthesizer.py Normal file
@@ -0,0 +1,31 @@
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.synthesizer import synthesizer_node
from ea_chatbot.graph.state import AgentState
from langchain_core.messages import AIMessage


def test_synthesizer_node_success():
    """Verify that the synthesizer node produces a final response."""
    state = AgentState(
        messages=[AIMessage(content="Worker 1 found data."), AIMessage(content="Worker 2 searched web.")],
        question="What are the results?",
        checklist=[],
        current_step=0,
        iterations=0,
        vfs={},
        plots=[],
        dfs={},
        next_action="",
        analysis={}
    )

    # Mock the LLM
    with patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model") as mock_get_llm:
        mock_llm = MagicMock()
        mock_llm.invoke.return_value = AIMessage(content="Final synthesized answer.")
        mock_get_llm.return_value = mock_llm

        result = synthesizer_node(state)

    assert len(result["messages"]) == 1
    assert result["messages"][0].content == "Final synthesized answer."
    assert result["next_action"] == "end"

@@ -1,6 +1,7 @@
 import pytest
 from unittest.mock import MagicMock, patch
 from ea_chatbot.graph.nodes.planner import planner_node
+from ea_chatbot.schemas import ChecklistResponse, ChecklistTask

 @pytest.fixture
 def mock_state():
@@ -8,39 +9,40 @@ def mock_state():
         "messages": [],
         "question": "Show me results for New Jersey",
         "analysis": {
+            # "requires_dataset" removed as it's no longer used
             "expert": "Data Analyst",
             "data": "NJ data",
             "unknown": "results",
             "condition": "state=NJ"
         },
         "next_action": "plan",
-        "plan": None
+        "checklist": [],
+        "current_step": 0
     }

 @patch("ea_chatbot.graph.nodes.planner.get_llm_model")
 @patch("ea_chatbot.utils.database_inspection.get_data_summary")
 def test_planner_node(mock_get_summary, mock_get_llm, mock_state):
-    """Test planner node with unified prompt."""
+    """Test planner node with checklist generation."""
     mock_get_summary.return_value = "Column: Name, Type: text"

     mock_llm = MagicMock()
     mock_get_llm.return_value = mock_llm

-    from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext
-    mock_plan = TaskPlanResponse(
+    mock_response = ChecklistResponse(
         goal="Get NJ results",
         reflection="The user wants NJ results",
-        context=TaskPlanContext(initial_context="NJ data", assumptions=[], constraints=[]),
-        steps=["Step 1: Load data", "Step 2: Filter by NJ"]
+        checklist=[
+            ChecklistTask(task="Query NJ data", worker="data_analyst")
+        ]
     )
-    mock_llm.with_structured_output.return_value.invoke.return_value = mock_plan
+    mock_llm.with_structured_output.return_value.invoke.return_value = mock_response

     result = planner_node(mock_state)

-    assert "plan" in result
-    assert "Step 1: Load data" in result["plan"]
-    assert "Step 2: Filter by NJ" in result["plan"]
+    assert "checklist" in result
+    assert result["checklist"][0]["task"] == "Query NJ data"
+    assert result["current_step"] == 0
+    assert result["summary"] == "The user wants NJ results"

     # Verify helper was called
     mock_get_summary.assert_called_once()

@@ -1,34 +0,0 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from ea_chatbot.graph.nodes.researcher import researcher_node


@pytest.fixture
def mock_llm():
    with patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_get_llm:
        mock_llm_instance = MagicMock(spec=ChatOpenAI)
        mock_get_llm.return_value = mock_llm_instance
        yield mock_llm_instance


def test_researcher_node_success(mock_llm):
    """Test that researcher_node invokes LLM with web_search tool and returns messages."""
    state = {
        "question": "What is the capital of France?",
        "messages": []
    }

    mock_llm_with_tools = MagicMock()
    mock_llm.bind_tools.return_value = mock_llm_with_tools
    mock_llm_with_tools.invoke.return_value = AIMessage(content="The capital of France is Paris.")

    result = researcher_node(state)

    assert mock_llm.bind_tools.called
    # Check that it was called with web_search
    args, kwargs = mock_llm.bind_tools.call_args
    assert {"type": "web_search"} in args[0]

    assert mock_llm_with_tools.invoke.called
    assert "messages" in result
    assert result["messages"][0].content == "The capital of France is Paris."

@@ -1,62 +0,0 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from ea_chatbot.graph.nodes.researcher import researcher_node


@pytest.fixture
def base_state():
    return {
        "question": "Who won the 2024 election?",
        "messages": [],
        "summary": ""
    }


@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_openai_search(mock_get_llm, base_state):
    """Test that OpenAI LLM binds 'web_search' tool."""
    mock_llm = MagicMock(spec=ChatOpenAI)
    mock_get_llm.return_value = mock_llm

    mock_llm_with_tools = MagicMock()
    mock_llm.bind_tools.return_value = mock_llm_with_tools
    mock_llm_with_tools.invoke.return_value = AIMessage(content="OpenAI Search Result")

    result = researcher_node(base_state)

    # Verify bind_tools called with correct OpenAI tool
    mock_llm.bind_tools.assert_called_once_with([{"type": "web_search"}])
    assert result["messages"][0].content == "OpenAI Search Result"


@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_google_search(mock_get_llm, base_state):
    """Test that Google LLM binds 'google_search' tool."""
    mock_llm = MagicMock(spec=ChatGoogleGenerativeAI)
    mock_get_llm.return_value = mock_llm

    mock_llm_with_tools = MagicMock()
    mock_llm.bind_tools.return_value = mock_llm_with_tools
    mock_llm_with_tools.invoke.return_value = AIMessage(content="Google Search Result")

    result = researcher_node(base_state)

    # Verify bind_tools called with correct Google tool
    mock_llm.bind_tools.assert_called_once_with([{"google_search": {}}])
    assert result["messages"][0].content == "Google Search Result"


@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_fallback_on_bind_error(mock_get_llm, base_state):
    """Test that researcher falls back to basic LLM if bind_tools fails."""
    mock_llm = MagicMock(spec=ChatOpenAI)
    mock_get_llm.return_value = mock_llm

    # Simulate bind_tools failing (e.g. model doesn't support it)
    mock_llm.bind_tools.side_effect = Exception("Not supported")
    mock_llm.invoke.return_value = AIMessage(content="Basic Result")

    result = researcher_node(base_state)

    # Should still succeed using the base LLM
    assert result["messages"][0].content == "Basic Result"
    mock_llm.invoke.assert_called_once()

13 backend/tests/test_researcher_state.py Normal file
@@ -0,0 +1,13 @@
from typing import get_type_hints, List
from ea_chatbot.graph.workers.researcher.state import WorkerState


def test_researcher_worker_state():
    """Verify that Researcher WorkerState has the required fields."""
    hints = get_type_hints(WorkerState)

    assert "messages" in hints
    assert "task" in hints
    assert "queries" in hints
    assert "raw_results" in hints
    assert "iterations" in hints
    assert "result" in hints

67 backend/tests/test_researcher_worker.py Normal file
@@ -0,0 +1,67 @@
import pytest
from unittest.mock import MagicMock
from ea_chatbot.graph.workers.researcher.workflow import create_researcher_worker, WorkerState
from ea_chatbot.graph.workers.researcher.mapping import prepare_researcher_input, merge_researcher_output
from ea_chatbot.graph.state import AgentState
from langchain_core.messages import AIMessage


def test_researcher_worker_flow():
    """Verify that the Researcher worker flow works as expected."""
    mock_searcher = MagicMock()
    mock_summarizer = MagicMock()

    mock_searcher.return_value = {
        "raw_results": ["Result A"],
        "messages": [AIMessage(content="Search result")]
    }
    mock_summarizer.return_value = {"result": "Consolidated Summary"}

    graph = create_researcher_worker(
        searcher=mock_searcher,
        summarizer=mock_summarizer
    )

    initial_state = WorkerState(
        messages=[],
        task="Find governor",
        queries=[],
        raw_results=[],
        iterations=0,
        result=None
    )

    final_state = graph.invoke(initial_state)

    assert final_state["result"] == "Consolidated Summary"
    assert mock_searcher.called
    assert mock_summarizer.called


def test_researcher_mapping():
    """Verify that we correctly map states for the researcher."""
    global_state = AgentState(
        checklist=[{"task": "Search X", "worker": "researcher"}],
        current_step=0,
        messages=[],
        question="test",
        analysis={},
        next_action="",
        iterations=0,
        vfs={},
        plots=[],
        dfs={}
    )

    worker_input = prepare_researcher_input(global_state)
    assert worker_input["task"] == "Search X"

    worker_output = WorkerState(
        messages=[],
        task="Search X",
        queries=[],
        raw_results=[],
        iterations=1,
        result="Found X"
    )

    updates = merge_researcher_output(worker_output)
    assert updates["messages"][0].content == "Found X"

55 backend/tests/test_review_fix_clarification.py Normal file
@@ -0,0 +1,55 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.workflow import create_workflow
from ea_chatbot.graph.state import AgentState
from langchain_core.messages import HumanMessage, AIMessage


def test_clarification_flow_immediate_execution():
    """Verify that an ambiguous query immediately executes the clarification node without interruption."""
    mock_analyzer = MagicMock()
    mock_clarification = MagicMock()

    # 1. Analyzer returns 'clarify'
    mock_analyzer.return_value = {"next_action": "clarify"}

    # 2. Clarification node returns a question
    mock_clarification.return_value = {"messages": [AIMessage(content="What year?")]}

    # The remaining nodes won't be reached, but create_workflow still
    # requires them, so plain mocks are passed.
    app = create_workflow(
        query_analyzer=mock_analyzer,
        clarification=mock_clarification,
        planner=MagicMock(),
        delegate=MagicMock(),
        data_analyst_worker=MagicMock(),
        researcher_worker=MagicMock(),
        reflector=MagicMock(),
        synthesizer=MagicMock(),
        summarize_conversation=MagicMock()
    )

    initial_state = AgentState(
        messages=[HumanMessage(content="Who won?")],
        question="Who won?",
        analysis={},
        next_action="",
        iterations=0,
        checklist=[],
        current_step=0,
        vfs={},
        plots=[],
        dfs={}
    )

    # Run the graph
    final_state = app.invoke(initial_state)

    # Assertions
    assert mock_analyzer.called
    assert mock_clarification.called

    # Verify the state contains the clarification message
    assert len(final_state["messages"]) > 0
    assert "What year?" in [m.content for m in final_state["messages"]]

48 backend/tests/test_review_fix_integration.py Normal file
@@ -0,0 +1,48 @@
import pytest
from ea_chatbot.graph.workflow import create_workflow, data_analyst_worker_runnable
from ea_chatbot.graph.state import AgentState
from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage


def test_worker_merge_sets_summary_for_reflector(monkeypatch):
    """Verify that the worker node (runnable) sets the 'summary' field for the Reflector."""
    state = AgentState(
        messages=[HumanMessage(content="test")],
        question="test",
        checklist=[{"task": "Analyze data", "worker": "data_analyst"}],
        current_step=0,
        iterations=0,
        vfs={},
        plots=[],
        dfs={},
        next_action="",
        analysis={},
        summary="Initial Planner Summary"  # Stale summary
    )

    # Create a mock for the invoke method
    mock_invoke = MagicMock()
    mock_invoke.return_value = {
        "summary": "Actual Worker Findings",
        "messages": [AIMessage(content="Actual Worker Findings")],
        "vfs": {},
        "plots": []
    }

    # Manually replace the runnable with a mock object that has an invoke method
    mock_runnable = MagicMock()
    mock_runnable.invoke = mock_invoke
    monkeypatch.setattr("ea_chatbot.graph.workflow.data_analyst_worker_runnable", mock_runnable)

    # Execute via the module reference (which is now mocked)
    from ea_chatbot.graph.workflow import data_analyst_worker_runnable
    updates = data_analyst_worker_runnable.invoke(state)

    # Verify that 'summary' is in updates and has the worker result
    assert "summary" in updates
    assert updates["summary"] == "Actual Worker Findings"

    # When applied to state, it should overwrite the stale summary
    state.update(updates)
    assert state["summary"] == "Actual Worker Findings"

35 backend/tests/test_review_fix_reflector.py Normal file
@@ -0,0 +1,35 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.reflector import reflector_node
from ea_chatbot.graph.state import AgentState


def test_reflector_does_not_advance_on_failure():
    """Verify that reflector does not increment current_step if not satisfied."""
    state = AgentState(
        checklist=[{"task": "Task 1", "worker": "data_analyst"}],
        current_step=0,
        messages=[],
        question="test",
        analysis={},
        next_action="",
        iterations=0,
        vfs={},
        plots=[],
        dfs={},
        summary="Failed output"
    )

    with patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_get_llm:
        mock_llm = MagicMock()
        # Mark as NOT satisfied
        mock_llm.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=False, reasoning="Incomplete.")
        mock_get_llm.return_value = mock_llm

        result = reflector_node(state)

    # Should NOT increment (therefore not in the updates dict)
    assert "current_step" not in result
    # Should increment iterations
    assert result["iterations"] == 1
    # Should route back to delegate for a retry (or 'planner' if we want re-planning)
    assert result["next_action"] == "delegate"

51 backend/tests/test_review_fix_summary.py Normal file
@@ -0,0 +1,51 @@
import pytest
|
||||
from ea_chatbot.graph.workers.data_analyst.mapping import merge_worker_output as merge_analyst
|
||||
from ea_chatbot.graph.workers.researcher.mapping import merge_researcher_output as merge_researcher
|
||||
from ea_chatbot.graph.workers.data_analyst.state import WorkerState as AnalystState
|
||||
from ea_chatbot.graph.workers.researcher.state import WorkerState as ResearcherState
|
||||
|
||||
def test_analyst_merge_updates_summary():
|
||||
"""Verify that analyst merge updates the global summary."""
|
||||
worker_state = AnalystState(
|
||||
result="Actual Worker Result",
|
||||
messages=[],
|
||||
task="test",
|
||||
iterations=1,
|
||||
vfs_state={},
|
||||
plots=[]
|
||||
)
|
||||
|
||||
updates = merge_analyst(worker_state)
|
||||
assert "summary" in updates
|
||||
assert updates["summary"] == "Actual Worker Result"
|
||||
|
||||
def test_researcher_merge_updates_summary():
|
||||
"""Verify that researcher merge updates the global summary."""
|
||||
worker_state = ResearcherState(
|
||||
result="Actual Research Result",
|
||||
messages=[],
|
||||
task="test",
|
||||
iterations=1,
|
||||
queries=[],
|
||||
raw_results=[]
|
||||
)
|
||||
|
||||
updates = merge_researcher(worker_state)
|
||||
assert "summary" in updates
|
||||
assert updates["summary"] == "Actual Research Result"
|
||||
|
||||
def test_merge_handles_none_result():
|
||||
"""Verify that merge functions handle None results gracefully."""
|
||||
worker_state = AnalystState(
|
||||
result=None,
|
||||
messages=[],
|
||||
task="test",
|
||||
iterations=1,
|
||||
vfs_state={},
|
||||
plots=[]
|
||||
)
|
||||
|
||||
updates = merge_analyst(worker_state)
|
||||
assert updates["summary"] is not None
|
||||
assert isinstance(updates["messages"][0].content, str)
|
||||
assert len(updates["messages"][0].content) > 0
|
||||
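The assertions above fix the contract of the mapping layer: a worker's `result` is promoted to the global `summary`, and a `None` result must still produce a non-empty message. A minimal sketch of a merge function satisfying that contract — using plain dicts and a stand-in `Message` class rather than the real `WorkerState`/`AIMessage` types, so the actual `merge_worker_output` in `ea_chatbot` will differ in detail:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Message:
    """Stand-in for langchain's AIMessage; only `content` matters here."""
    content: str


def merge_worker_output(worker_state: dict) -> dict:
    """Map a worker's local state back onto the global agent state.

    Contract (from the tests): `result` becomes `summary`, and a None
    result still yields a non-empty fallback summary and message.
    """
    result: Optional[str] = worker_state.get("result")
    if result is None:
        result = f"Worker produced no result for task: {worker_state.get('task', '?')}"
    return {
        "summary": result,
        "messages": [Message(content=result)],
    }
```

The fallback branch is the path `test_merge_handles_none_result` exercises: the merge never propagates `None` into the global summary.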
85  backend/tests/test_review_fix_unbounded_loop.py  Normal file
@@ -0,0 +1,85 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.workflow import create_workflow
from ea_chatbot.graph.state import AgentState
from langchain_core.messages import AIMessage, HumanMessage


def test_orchestrator_loop_retry_budget():
    """Verify that the orchestrator loop is bounded and terminates after max retries."""

    mock_analyzer = MagicMock()
    mock_planner = MagicMock()
    mock_delegate = MagicMock()
    mock_worker = MagicMock()
    mock_reflector = MagicMock()
    mock_synthesizer = MagicMock()

    # 1. Analyzer: Proceed to planning
    mock_analyzer.return_value = {"next_action": "plan"}

    # 2. Planner: One task
    mock_planner.return_value = {
        "checklist": [{"task": "Unsolvable Task", "worker": "data_analyst"}],
        "current_step": 0,
        "iterations": 0
    }

    # Use the REAL delegate and reflector nodes so the retry-budget logic under
    # test actually runs; only the LLM call inside the reflector is mocked.
    from ea_chatbot.graph.nodes.delegate import delegate_node
    from ea_chatbot.graph.nodes.reflector import reflector_node

    # Mock the LLM inside the reflector to always be unsatisfied
    with patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_get_llm:
        mock_llm = MagicMock()
        # Mark as NOT satisfied
        mock_llm.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=False, reasoning="Still bad.")
        mock_get_llm.return_value = mock_llm

        app = create_workflow(
            query_analyzer=mock_analyzer,
            planner=mock_planner,
            # delegate=delegate_node,  # Use real
            data_analyst_worker=mock_worker,
            # reflector=reflector_node,  # Use real
            synthesizer=mock_synthesizer
        )

        # Mock worker to return something
        mock_worker.return_value = {"result": "Bad Output", "messages": [AIMessage(content="Bad")]}
        mock_synthesizer.return_value = {"messages": [AIMessage(content="Failure Summary")], "next_action": "end"}

        initial_state = AgentState(
            messages=[HumanMessage(content="test")],
            question="test",
            analysis={},
            next_action="",
            iterations=0,
            checklist=[],
            current_step=0,
            vfs={},
            plots=[],
            dfs={}
        )

        # Run the graph. If the fix works, the loop hits the retry budget and
        # routes to the synthesizer. recursion_limit is higher than the budget
        # but low enough to fail fast if the loop is unbounded.
        final_state = app.invoke(initial_state, config={"recursion_limit": 20})

        # Expected loop trace (the reflector increments iterations; delegate
        # routes to summarize once iterations >= 3):
        #   Delegate (it=0) -> Worker -> Reflector (fail, it=1)
        #   Delegate (it=1) -> Worker -> Reflector (fail, it=2)
        #   Delegate (it=2) -> Worker -> Reflector (fail, it=3)
        #   Delegate (it=3) -> Summarize (iterations reset to 0)

        assert final_state["iterations"] == 0  # Reset in delegate or handled in synthesizer
        # Check that we hit the failure summary
        assert "Failed to complete task" in final_state["summary"]
        assert mock_worker.call_count == 3
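The loop traced in the comments above reduces to two small pieces of pure logic: a routing function with a retry budget and a reflector update. This is a hedged sketch of that shape, not the real nodes in `ea_chatbot.graph.nodes` — `MAX_RETRIES = 3` and the `"summarize"` route name are assumptions taken from the test's expectations:

```python
MAX_RETRIES = 3  # assumed retry budget; the test expects exactly 3 worker calls


def route_from_delegate(state: dict) -> str:
    """Bounded routing: give up and summarize once the retry budget is spent."""
    if state["iterations"] >= MAX_RETRIES:
        return "summarize"
    # Otherwise route to the worker assigned to the current checklist task
    return state["checklist"][state["current_step"]]["worker"]


def reflect(state: dict, satisfied: bool) -> dict:
    """Reflector update: advance on success, otherwise burn one retry."""
    if satisfied:
        return {"current_step": state["current_step"] + 1, "iterations": 0}
    return {"iterations": state["iterations"] + 1}
```

Driving these two functions with an always-unsatisfied reflector reproduces the three-call trace asserted by `mock_worker.call_count == 3`.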
50  backend/tests/test_review_fix_vfs.py  Normal file
@@ -0,0 +1,50 @@
import pytest
from ea_chatbot.utils.vfs import VFSHelper
from ea_chatbot.graph.workers.data_analyst.mapping import prepare_worker_input
from ea_chatbot.graph.state import AgentState


def test_vfs_isolation_deep_copy():
    """Verify that worker VFS is deep-copied from global state."""
    global_vfs = {
        "data.txt": {
            "content": "original",
            "metadata": {"tags": ["a"]}
        }
    }
    state = AgentState(
        checklist=[{"task": "test", "worker": "data_analyst"}],
        current_step=0,
        messages=[],
        question="test",
        analysis={},
        next_action="",
        iterations=0,
        vfs=global_vfs,
        plots=[],
        dfs={}
    )

    worker_input = prepare_worker_input(state)
    worker_vfs = worker_input["vfs_state"]

    # Mutate worker VFS nested object
    worker_vfs["data.txt"]["metadata"]["tags"].append("b")

    # Global VFS should remain unchanged
    assert global_vfs["data.txt"]["metadata"]["tags"] == ["a"]


def test_vfs_schema_normalization():
    """Verify that VFSHelper handles inconsistent VFS entries."""
    vfs = {
        "raw.txt": "just a string",  # Inconsistent with standard Dict[str, Any] entry
        "valid.txt": {"content": "data", "metadata": {}}
    }
    helper = VFSHelper(vfs)

    # Should not crash during read
    content, metadata = helper.read("raw.txt")
    assert content == "just a string"
    assert metadata == {}

    # Should not crash during list/other ops if they assume dict
    assert "raw.txt" in helper.list()
15  backend/tests/test_state_extensions.py  Normal file
@@ -0,0 +1,15 @@
from typing import get_type_hints, List, Dict, Any
from ea_chatbot.graph.state import AgentState


def test_agent_state_extensions():
    """Verify that AgentState has the new DeepAgents fields."""
    hints = get_type_hints(AgentState)

    assert "checklist" in hints
    # checklist: List[Dict[str, Any]] (or similar for tasks)

    assert "current_step" in hints
    assert hints["current_step"] == int

    assert "vfs" in hints
    # vfs: Dict[str, Any]
@@ -1,47 +0,0 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from ea_chatbot.graph.nodes.summarizer import summarizer_node


@pytest.fixture
def mock_llm():
    with patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") as mock_get_llm:
        mock_llm_instance = MagicMock()
        mock_get_llm.return_value = mock_llm_instance
        yield mock_llm_instance


def test_summarizer_node_success(mock_llm):
    """Test that summarizer_node invokes LLM with correct inputs and returns messages."""
    state = {
        "question": "What is the total count?",
        "plan": "1. Run query\n2. Sum results",
        "code_output": "The total is 100",
        "messages": []
    }

    mock_llm.invoke.return_value = AIMessage(content="The final answer is 100.")

    result = summarizer_node(state)

    # Verify LLM was called
    assert mock_llm.invoke.called

    # Verify result structure
    assert "messages" in result
    assert len(result["messages"]) == 1
    assert isinstance(result["messages"][0], AIMessage)
    assert result["messages"][0].content == "The final answer is 100."


def test_summarizer_node_empty_state(mock_llm):
    """Test handling of empty or minimal state."""
    state = {
        "question": "Empty?",
        "messages": []
    }

    mock_llm.invoke.return_value = AIMessage(content="No data provided.")

    result = summarizer_node(state)

    assert "messages" in result
    assert result["messages"][0].content == "No data provided."
49  backend/tests/test_vfs.py  Normal file
@@ -0,0 +1,49 @@
import pytest
from ea_chatbot.utils.vfs import VFSHelper


def test_vfs_read_write():
    """Verify that we can write and read from the VFS."""
    vfs = {}
    helper = VFSHelper(vfs)

    helper.write("test.txt", "Hello World", metadata={"type": "text"})
    assert "test.txt" in vfs
    assert vfs["test.txt"]["content"] == "Hello World"
    assert vfs["test.txt"]["metadata"]["type"] == "text"

    content, metadata = helper.read("test.txt")
    assert content == "Hello World"
    assert metadata["type"] == "text"


def test_vfs_list():
    """Verify that we can list files in the VFS."""
    vfs = {}
    helper = VFSHelper(vfs)

    helper.write("a.txt", "content a")
    helper.write("b.txt", "content b")

    files = helper.list()
    assert "a.txt" in files
    assert "b.txt" in files
    assert len(files) == 2


def test_vfs_delete():
    """Verify that we can delete files from the VFS."""
    vfs = {}
    helper = VFSHelper(vfs)

    helper.write("test.txt", "Hello")
    helper.delete("test.txt")

    assert "test.txt" not in vfs
    assert len(helper.list()) == 0


def test_vfs_not_found():
    """Verify that reading a non-existent file returns None."""
    vfs = {}
    helper = VFSHelper(vfs)

    content, metadata = helper.read("missing.txt")
    assert content is None
    assert metadata is None
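These tests, together with `test_review_fix_vfs.py` above, pin down `VFSHelper`'s surface: `write`/`read`/`list`/`delete` over a plain dict, tolerance for bare-string entries, and `(None, None)` for missing files. A minimal sketch consistent with those assertions — the real implementation lives in `ea_chatbot.utils.vfs` and may carry more behavior:

```python
from typing import Any, Dict, Optional, Tuple


class VFSHelper:
    """Thin wrapper over a plain-dict virtual file system.

    Entries are normally {"content": ..., "metadata": {...}}, but bare
    values are tolerated and normalized on read.
    """

    def __init__(self, vfs: Dict[str, Any]):
        self.vfs = vfs  # kept by reference so callers observe writes

    def write(self, path: str, content: Any, metadata: Optional[dict] = None) -> None:
        self.vfs[path] = {"content": content, "metadata": metadata or {}}

    def read(self, path: str) -> Tuple[Any, Optional[dict]]:
        if path not in self.vfs:
            return None, None
        entry = self.vfs[path]
        if isinstance(entry, dict):
            return entry.get("content"), entry.get("metadata", {})
        return entry, {}  # bare value: treat it as raw content

    def list(self) -> list:
        return list(self.vfs.keys())

    def delete(self, path: str) -> None:
        self.vfs.pop(path, None)
```

Holding the backing dict by reference (not copying it) is what lets the tests assert directly against `vfs` after going through the helper.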
48  backend/tests/test_vfs_robustness.py  Normal file
@@ -0,0 +1,48 @@
import pytest
import threading
from ea_chatbot.utils.vfs import safe_vfs_copy


def test_safe_vfs_copy_success():
    """Test standard success case."""
    vfs = {
        "test.csv": {"content": "data", "metadata": {"type": "csv"}},
        "num": 42
    }
    copied = safe_vfs_copy(vfs)
    assert copied == vfs
    assert copied is not vfs
    assert copied["test.csv"] is not vfs["test.csv"]


def test_safe_vfs_copy_handles_non_copyable():
    """Test replacing uncopyable objects with error placeholders."""
    # A threading.Lock is famously uncopyable
    lock = threading.Lock()

    vfs = {
        "safe_file": "important data",
        "unsafe_lock": lock
    }

    copied = safe_vfs_copy(vfs)

    # Safe one remains
    assert copied["safe_file"] == "important data"

    # Unsafe one is REPLACED with an error dict
    assert isinstance(copied["unsafe_lock"], dict)
    assert "content" in copied["unsafe_lock"]
    assert "ERROR" in copied["unsafe_lock"]["content"]
    assert copied["unsafe_lock"]["metadata"]["type"] == "error"
    assert "lock" in str(copied["unsafe_lock"]["metadata"]["error"]).lower()

    # Original is unchanged (it was a lock)
    assert vfs["unsafe_lock"] is lock


def test_safe_vfs_copy_preserves_nested_copyable():
    """Test deepcopy still works for complex but copyable objects."""
    vfs = {
        "data": {"a": [1, 2, 3], "b": {"c": True}}
    }
    copied = safe_vfs_copy(vfs)
    assert copied["data"]["a"] == [1, 2, 3]
    assert copied["data"]["a"] is not vfs["data"]["a"]
@@ -1,92 +1,80 @@
 import pytest
 from unittest.mock import MagicMock, patch
-from ea_chatbot.graph.workflow import app
-from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis
-from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse
+from ea_chatbot.graph.workflow import create_workflow
+from ea_chatbot.graph.state import AgentState
+from langchain_core.messages import AIMessage, HumanMessage

-from langchain_core.messages import AIMessage

-@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
-@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
-@patch("ea_chatbot.graph.nodes.coder.get_llm_model")
-@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model")
-@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
-@patch("ea_chatbot.utils.database_inspection.get_data_summary")
-@patch("ea_chatbot.graph.nodes.executor.Settings")
-@patch("ea_chatbot.graph.nodes.executor.DBClient")
-def test_workflow_full_flow(mock_db_client, mock_settings, mock_get_summary, mock_researcher_llm, mock_summarizer_llm, mock_coder_llm, mock_planner_llm, mock_qa_llm):
-    """Test the flow from query_analyzer through planner to coder."""
+def test_workflow_full_flow():
+    """Test the full Orchestrator-Workers flow using node injection."""

-    # Mock Settings for Executor
-    mock_settings_instance = MagicMock()
-    mock_settings_instance.db_host = "localhost"
-    mock_settings_instance.db_port = 5432
-    mock_settings_instance.db_user = "user"
-    mock_settings_instance.db_pswd = "pass"
-    mock_settings_instance.db_name = "test_db"
-    mock_settings_instance.db_table = "test_table"
-    mock_settings.return_value = mock_settings_instance
+    mock_analyzer = MagicMock()
+    mock_planner = MagicMock()
+    mock_delegate = MagicMock()
+    mock_worker = MagicMock()
+    mock_reflector = MagicMock()
+    mock_synthesizer = MagicMock()
+    mock_summarize_conv = MagicMock()

-    # Mock DBClient
-    mock_client_instance = MagicMock()
-    mock_db_client.return_value = mock_client_instance
+    # 1. Analyzer: Proceed to planning
+    mock_analyzer.return_value = {"next_action": "plan"}

-    # 1. Mock Query Analyzer
-    mock_qa_instance = MagicMock()
-    mock_qa_llm.return_value = mock_qa_instance
-    mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
-        data_required=["2024 results"],
-        unknowns=[],
-        ambiguities=[],
-        conditions=[],
-        next_action="plan"
-    )
+    # 2. Planner: Generate checklist
+    mock_planner.return_value = {
+        "checklist": [{"task": "Step 1", "worker": "data_analyst"}],
+        "current_step": 0
+    }

+    # 3. Delegate: Route to data_analyst
+    mock_delegate.side_effect = [
+        {"next_action": "data_analyst"},
+        {"next_action": "summarize"}
+    ]

+    # 4. Worker: Success
+    mock_worker.return_value = {
+        "messages": [AIMessage(content="Worker Summary")],
+        "vfs": {}
+    }

+    # 5. Reflector: Advance
+    mock_reflector.return_value = {
+        "current_step": 1,
+        "next_action": "delegate"
+    }

+    # 6. Synthesizer: Final answer
+    mock_synthesizer.return_value = {
+        "messages": [AIMessage(content="Final Summary")],
+        "next_action": "end"
+    }

+    # 7. Summarize Conv: End
+    mock_summarize_conv.return_value = {"summary": "Done"}

+    app = create_workflow(
+        query_analyzer=mock_analyzer,
+        planner=mock_planner,
+        delegate=mock_delegate,
+        data_analyst_worker=mock_worker,
+        reflector=mock_reflector,
+        synthesizer=mock_synthesizer,
+        summarize_conversation=mock_summarize_conv
+    )

-    # 2. Mock Planner
-    mock_planner_instance = MagicMock()
-    mock_planner_llm.return_value = mock_planner_instance
-    mock_get_summary.return_value = "Data summary"
-    mock_planner_instance.with_structured_output.return_value.invoke.return_value = TaskPlanResponse(
-        goal="Task Goal",
-        reflection="Reflection",
-        context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]),
-        steps=["Step 1"]
-    )

-    # 3. Mock Coder
-    mock_coder_instance = MagicMock()
-    mock_coder_llm.return_value = mock_coder_instance
-    mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse(
-        code="print('Hello')",
-        explanation="Explanation"
-    )

-    # 4. Mock Summarizer
-    mock_summarizer_instance = MagicMock()
-    mock_summarizer_llm.return_value = mock_summarizer_instance
-    mock_summarizer_instance.invoke.return_value = AIMessage(content="Summary")

-    # 5. Mock Researcher (just in case)
-    mock_researcher_instance = MagicMock()
-    mock_researcher_llm.return_value = mock_researcher_instance

     # Initial state
     initial_state = {
-        "messages": [],
-        "question": "Show me the 2024 results",
+        "messages": [HumanMessage(content="Show me results")],
+        "question": "Show me results",
         "analysis": None,
         "next_action": "",
-        "plan": None,
-        "code": None,
-        "error": None,
         "iterations": 0,
+        "checklist": [],
+        "current_step": 0,
+        "vfs": {},
+        "plots": [],
+        "dfs": {}
     }

     # Run the graph
-    # We use recursion_limit to avoid infinite loops in placeholders if any
-    result = app.invoke(initial_state, config={"recursion_limit": 10})
+    result = app.invoke(initial_state, config={"recursion_limit": 20})

-    assert result["next_action"] == "plan"
-    assert "plan" in result and result["plan"] is not None
-    assert "code" in result and "print('Hello')" in result["code"]
-    assert "analysis" in result
+    assert "Final Summary" in [m.content for m in result["messages"]]
+    assert result["current_step"] == 1
@@ -1,60 +1,55 @@
 import pytest
 from unittest.mock import MagicMock, patch
-from langchain_core.messages import AIMessage
-from ea_chatbot.graph.workflow import app
-from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis
-from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse
+from ea_chatbot.schemas import QueryAnalysis, ChecklistResponse, ChecklistTask, CodeGenerationResponse
+from ea_chatbot.graph.state import AgentState
+from langchain_core.messages import AIMessage

 @pytest.fixture
 def mock_llms():
-    with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_qa_llm, \
-         patch("ea_chatbot.graph.nodes.planner.get_llm_model") as mock_planner_llm, \
-         patch("ea_chatbot.graph.nodes.coder.get_llm_model") as mock_coder_llm, \
-         patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") as mock_summarizer_llm, \
-         patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_researcher_llm, \
-         patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model") as mock_summary_llm, \
-         patch("ea_chatbot.utils.database_inspection.get_data_summary") as mock_get_summary:
-        mock_get_summary.return_value = "Data summary"
-
-        # Mock summary LLM to return a simple response
-        mock_summary_instance = MagicMock()
-        mock_summary_llm.return_value = mock_summary_instance
-        mock_summary_instance.invoke.return_value = AIMessage(content="Turn summary")
-
+    with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_qa, \
+         patch("ea_chatbot.graph.nodes.planner.get_llm_model") as mock_planner, \
+         patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model") as mock_coder, \
+         patch("ea_chatbot.graph.workers.data_analyst.nodes.summarizer.get_llm_model") as mock_worker_summarizer, \
+         patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model") as mock_synthesizer, \
+         patch("ea_chatbot.graph.workers.researcher.nodes.searcher.get_llm_model") as mock_researcher, \
+         patch("ea_chatbot.graph.workers.researcher.nodes.summarizer.get_llm_model") as mock_res_summarizer, \
+         patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_reflector:
         yield {
-            "qa": mock_qa_llm,
-            "planner": mock_planner_llm,
-            "coder": mock_coder_llm,
-            "summarizer": mock_summarizer_llm,
-            "researcher": mock_researcher_llm,
-            "summary": mock_summary_llm
+            "qa": mock_qa,
+            "planner": mock_planner,
+            "coder": mock_coder,
+            "worker_summarizer": mock_worker_summarizer,
+            "synthesizer": mock_synthesizer,
+            "researcher": mock_researcher,
+            "res_summarizer": mock_res_summarizer,
+            "reflector": mock_reflector
         }

 def test_workflow_data_analysis_flow(mock_llms):
-    """Test full flow: QueryAnalyzer -> Planner -> Coder -> Executor -> Summarizer."""
+    """Test full flow: QueryAnalyzer -> Planner -> Delegate -> DataAnalyst -> Reflector -> Synthesizer."""

-    # 1. Mock Query Analyzer (routes to plan)
+    # 1. Mock Query Analyzer
     mock_qa_instance = MagicMock()
     mock_llms["qa"].return_value = mock_qa_instance
     mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
         data_required=["2024 results"],
         unknowns=[],
         ambiguities=[],
         conditions=[],
         next_action="plan"
     )

     # 2. Mock Planner
     mock_planner_instance = MagicMock()
     mock_llms["planner"].return_value = mock_planner_instance
-    mock_planner_instance.with_structured_output.return_value.invoke.return_value = TaskPlanResponse(
+    mock_planner_instance.with_structured_output.return_value.invoke.return_value = ChecklistResponse(
         goal="Get results",
         reflection="Reflect",
-        context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]),
-        steps=["Step 1"]
+        checklist=[ChecklistTask(task="Query Data", worker="data_analyst")]
     )

-    # 3. Mock Coder
+    # 3. Mock Coder (Worker)
     mock_coder_instance = MagicMock()
     mock_llms["coder"].return_value = mock_coder_instance
     mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse(
@@ -62,10 +57,20 @@ def test_workflow_data_analysis_flow(mock_llms):
         explanation="Explain"
     )

-    # 4. Mock Summarizer
-    mock_summarizer_instance = MagicMock()
-    mock_llms["summarizer"].return_value = mock_summarizer_instance
-    mock_summarizer_instance.invoke.return_value = AIMessage(content="Final Summary: Success")
+    # 4. Mock Worker Summarizer
+    mock_ws_instance = MagicMock()
+    mock_llms["worker_summarizer"].return_value = mock_ws_instance
+    mock_ws_instance.invoke.return_value = AIMessage(content="Worker Summary")

+    # 5. Mock Reflector
+    mock_reflector_instance = MagicMock()
+    mock_llms["reflector"].return_value = mock_reflector_instance
+    mock_reflector_instance.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=True, reasoning="Good.")

+    # 6. Mock Synthesizer
+    mock_syn_instance = MagicMock()
+    mock_llms["synthesizer"].return_value = mock_syn_instance
+    mock_syn_instance.invoke.return_value = AIMessage(content="Final Summary: Success")

     # Initial state
     initial_state = {
@@ -73,66 +78,77 @@ def test_workflow_data_analysis_flow(mock_llms):
         "question": "Show me 2024 results",
         "analysis": None,
         "next_action": "",
-        "plan": None,
-        "code": None,
-        "error": None,
         "iterations": 0,
+        "checklist": [],
+        "current_step": 0,
+        "vfs": {},
+        "plots": [],
+        "dfs": {}
     }

     # Run the graph
-    result = app.invoke(initial_state, config={"recursion_limit": 15})
+    result = app.invoke(initial_state, config={"recursion_limit": 20})

-    assert result["next_action"] == "plan"
-    assert "Execution Success" in result["code_output"]
-    assert "Final Summary: Success" in result["messages"][-1].content
+    assert "Final Summary: Success" in [m.content for m in result["messages"]]
+    assert result["current_step"] == 1

 def test_workflow_research_flow(mock_llms):
-    """Test flow: QueryAnalyzer -> Researcher -> Summarizer."""
+    """Test flow with research task."""

-    # 1. Mock Query Analyzer (routes to research)
+    # 1. Mock Query Analyzer
     mock_qa_instance = MagicMock()
     mock_llms["qa"].return_value = mock_qa_instance
     mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
         data_required=[],
         unknowns=[],
         ambiguities=[],
         conditions=[],
         next_action="research"
     )

-    # 2. Mock Researcher
-    mock_researcher_instance = MagicMock()
-    mock_llms["researcher"].return_value = mock_researcher_instance
-    # Researcher node uses bind_tools if it's ChatOpenAI/ChatGoogleGenerativeAI
-    # Since it's a MagicMock, it will fallback to using the base instance
-    mock_researcher_instance.invoke.return_value = AIMessage(content="Research Results")
+    # 2. Mock Planner
+    mock_planner_instance = MagicMock()
+    mock_llms["planner"].return_value = mock_planner_instance
+    mock_planner_instance.with_structured_output.return_value.invoke.return_value = ChecklistResponse(
+        goal="Search",
+        reflection="Reflect",
+        checklist=[ChecklistTask(task="Search Web", worker="researcher")]
+    )

-    # Also mock bind_tools just in case we ever use spec
-    mock_llm_with_tools = MagicMock()
-    mock_researcher_instance.bind_tools.return_value = mock_llm_with_tools
-    mock_llm_with_tools.invoke.return_value = AIMessage(content="Research Results")
+    # 3. Mock Researcher Searcher
+    mock_res_instance = MagicMock()
+    mock_llms["researcher"].return_value = mock_res_instance
+    mock_res_instance.invoke.return_value = AIMessage(content="Research Result")

-    # 3. Mock Summarizer (not used in this flow, but kept for completeness)
-    mock_summarizer_instance = MagicMock()
-    mock_llms["summarizer"].return_value = mock_summarizer_instance
-    mock_summarizer_instance.invoke.return_value = AIMessage(content="Final Summary: Research Success")
+    # 4. Mock Researcher Summarizer
+    mock_rs_instance = MagicMock()
+    mock_llms["res_summarizer"].return_value = mock_rs_instance
+    mock_rs_instance.invoke.return_value = AIMessage(content="Researcher Summary")

+    # 5. Mock Reflector
+    mock_reflector_instance = MagicMock()
+    mock_llms["reflector"].return_value = mock_reflector_instance
+    mock_reflector_instance.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=True, reasoning="Good.")

+    # 6. Mock Synthesizer
+    mock_syn_instance = MagicMock()
+    mock_llms["synthesizer"].return_value = mock_syn_instance
+    mock_syn_instance.invoke.return_value = AIMessage(content="Final Research Summary")

     # Initial state
     initial_state = {
         "messages": [],
-        "question": "Who is the governor of Florida?",
+        "question": "Who is the governor?",
         "analysis": None,
         "next_action": "",
-        "plan": None,
-        "code": None,
-        "error": None,
         "iterations": 0,
+        "checklist": [],
+        "current_step": 0,
+        "vfs": {},
+        "plots": [],
+        "dfs": {}
     }

     # Run the graph
-    result = app.invoke(initial_state, config={"recursion_limit": 10})
+    result = app.invoke(initial_state, config={"recursion_limit": 20})

-    assert result["next_action"] == "research"
-    assert "Research Results" in result["messages"][-1].content
+    assert "Final Research Summary" in [m.content for m in result["messages"]]
+    assert result["current_step"] == 1
4  frontend/package-lock.json  generated
@@ -1,12 +1,12 @@
 {
   "name": "frontend",
-  "version": "0.0.0",
+  "version": "0.1.0",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "frontend",
-      "version": "0.0.0",
+      "version": "0.1.0",
       "dependencies": {
         "@hookform/resolvers": "^5.2.2",
         "@tailwindcss/vite": "^4.1.18",
@@ -9,9 +9,9 @@ interface ExecutionStatusProps {

 const PHASE_CONFIG = [
   { label: "Analyzing query...", match: "Query analysis complete." },
-  { label: "Generating strategic plan...", match: "Strategic plan generated." },
-  { label: "Writing analysis code...", match: "Analysis code generated." },
-  { label: "Performing data analysis...", match: "Data analysis and visualization complete." }
+  { label: "Generating high-level plan...", match: "Checklist generated." },
+  { label: "Delegating to specialists...", match: "Task assigned." },
+  { label: "Synthesizing final answer...", match: "Final synthesis complete." }
 ]

 export function ExecutionStatus({ steps, isComplete, className }: ExecutionStatusProps) {
@@ -1,7 +1,7 @@
 import { cva } from "class-variance-authority"

 export const buttonVariants = cva(
-  "inline-flex items-center justify-center gap-2 whitespace-nowrap rounded-md text-sm font-medium transition-all disabled:pointer-events-none disabled:opacity-50 [&_svg]:pointer-events-none [&_svg:not([class*='size-'])]:size-4 shrink-0 [&_svg]:shrink-0 outline-none focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:ring-[3px] aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive",
+  "inline-flex items-center justify-center gap-2 whitespace-nowrap rounded-md text-sm font-medium transition-all cursor-pointer disabled:pointer-events-none disabled:opacity-50 disabled:cursor-default [&_svg]:pointer-events-none [&_svg:not([class*='size-'])]:size-4 shrink-0 [&_svg]:shrink-0 outline-none focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:ring-[3px] aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive",
   {
     variants: {
       variant: {
29  frontend/src/components/ui/button.test.tsx  Normal file
@@ -0,0 +1,29 @@
import { render, screen } from "@testing-library/react"
import { describe, expect, it } from "vitest"
import { Button } from "./button"

describe("Button component", () => {
  it("renders with a pointer cursor", () => {
    render(<Button>Click me</Button>)
    const button = screen.getByRole("button", { name: /click me/i })
    expect(button).toHaveClass("cursor-pointer")
  })

  it("renders different variants with a pointer cursor", () => {
    const { rerender } = render(<Button variant="outline">Outline</Button>)
    let button = screen.getByRole("button", { name: /outline/i })
    expect(button).toHaveClass("cursor-pointer")

    rerender(<Button variant="destructive">Destructive</Button>)
    button = screen.getByRole("button", { name: /destructive/i })
    expect(button).toHaveClass("cursor-pointer")
  })

  it("renders with a default cursor when disabled", () => {
    render(<Button disabled>Disabled Button</Button>)
    const button = screen.getByRole("button", { name: /disabled button/i })
    expect(button).toBeDisabled()
    expect(button).toHaveClass("disabled:pointer-events-none")
    expect(button).toHaveClass("disabled:cursor-default")
  })
})
74  frontend/src/components/ui/dropdown-menu.test.tsx  Normal file
@@ -0,0 +1,74 @@
import { render, screen } from "@testing-library/react"
import { describe, expect, it } from "vitest"
import {
  DropdownMenu,
  DropdownMenuContent,
  DropdownMenuItem,
  DropdownMenuSub,
  DropdownMenuSubTrigger,
  DropdownMenuSubContent,
  DropdownMenuPortal
} from "./dropdown-menu"

describe("DropdownMenu components", () => {
  it("DropdownMenuItem should have cursor-pointer when enabled", () => {
    render(
      <DropdownMenu open={true}>
        <DropdownMenuContent>
          <DropdownMenuItem>Enabled Item</DropdownMenuItem>
        </DropdownMenuContent>
      </DropdownMenu>
    )
    const item = screen.getByText("Enabled Item")
    expect(item).toHaveClass("cursor-pointer")
  })

  it("DropdownMenuSubTrigger should have cursor-pointer when enabled", () => {
    render(
      <DropdownMenu open={true}>
        <DropdownMenuContent>
          <DropdownMenuSub>
            <DropdownMenuSubTrigger>Sub Trigger</DropdownMenuSubTrigger>
            <DropdownMenuPortal>
              <DropdownMenuSubContent />
            </DropdownMenuPortal>
          </DropdownMenuSub>
        </DropdownMenuContent>
      </DropdownMenu>
    )
    const trigger = screen.getByText("Sub Trigger")
    expect(trigger).toHaveClass("cursor-pointer")
  })

  it("DropdownMenuItem should show default cursor when disabled", () => {
    render(
      <DropdownMenu open={true}>
        <DropdownMenuContent>
          <DropdownMenuItem disabled>Disabled Item</DropdownMenuItem>
        </DropdownMenuContent>
      </DropdownMenu>
    )
    const item = screen.getByText("Disabled Item")
    // Radix sets data-disabled attribute
    expect(item).toHaveAttribute("data-disabled")
    expect(item).toHaveClass("data-[disabled]:cursor-default")
  })

  it("DropdownMenuSubTrigger should show default cursor when disabled", () => {
    render(
      <DropdownMenu open={true}>
        <DropdownMenuContent>
          <DropdownMenuSub>
            <DropdownMenuSubTrigger disabled>Disabled Sub Trigger</DropdownMenuSubTrigger>
            <DropdownMenuPortal>
              <DropdownMenuSubContent />
            </DropdownMenuPortal>
          </DropdownMenuSub>
        </DropdownMenuContent>
      </DropdownMenu>
    )
    const trigger = screen.getByText("Disabled Sub Trigger")
    expect(trigger).toHaveAttribute("data-disabled")
    expect(trigger).toHaveClass("data-[disabled]:cursor-default")
  })
})
@@ -27,7 +27,7 @@ const DropdownMenuSubTrigger = React.forwardRef<
   <DropdownMenuPrimitive.SubTrigger
     ref={ref}
     className={cn(
-      "focus:bg-accent data-[state=open]:bg-accent flex cursor-default items-center rounded-sm px-2 py-1.5 text-sm outline-hidden select-none",
+      "focus:bg-accent data-[state=open]:bg-accent flex cursor-pointer items-center rounded-sm px-2 py-1.5 text-sm outline-hidden select-none data-[disabled]:cursor-default data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
       inset && "pl-8",
       className
     )}
@@ -81,7 +81,7 @@ const DropdownMenuItem = React.forwardRef<
   <DropdownMenuPrimitive.Item
     ref={ref}
     className={cn(
-      "focus:bg-accent focus:text-accent-foreground relative flex cursor-default items-center rounded-sm px-2 py-1.5 text-sm outline-hidden transition-colors data-[disabled]:pointer-events-none data-[disabled]:opacity-50 select-none",
+      "focus:bg-accent focus:text-accent-foreground relative flex cursor-pointer items-center rounded-sm px-2 py-1.5 text-sm outline-hidden transition-colors data-[disabled]:pointer-events-none data-[disabled]:opacity-50 data-[disabled]:cursor-default select-none",
       inset && "pl-8",
       className
     )}
@@ -99,7 +99,7 @@ const DropdownMenuCheckboxItem = React.forwardRef<
   <DropdownMenuPrimitive.CheckboxItem
     ref={ref}
     className={cn(
-      "focus:bg-accent focus:text-accent-foreground relative flex cursor-default items-center rounded-sm py-1.5 pr-2 pl-8 text-sm outline-hidden transition-colors data-[disabled]:pointer-events-none data-[disabled]:opacity-50 select-none",
+      "focus:bg-accent focus:text-accent-foreground relative flex cursor-pointer items-center rounded-sm py-1.5 pr-2 pl-8 text-sm outline-hidden transition-colors data-[disabled]:pointer-events-none data-[disabled]:opacity-50 data-[disabled]:cursor-default select-none",
       className
     )}
     checked={checked}
@@ -122,7 +122,7 @@ const DropdownMenuRadioItem = React.forwardRef<
   <DropdownMenuPrimitive.RadioItem
     ref={ref}
     className={cn(
-      "focus:bg-accent focus:text-accent-foreground relative flex cursor-default items-center rounded-sm py-1.5 pr-2 pl-8 text-sm outline-hidden transition-colors data-[disabled]:pointer-events-none data-[disabled]:opacity-50 select-none",
+      "focus:bg-accent focus:text-accent-foreground relative flex cursor-pointer items-center rounded-sm py-1.5 pr-2 pl-8 text-sm outline-hidden transition-colors data-[disabled]:pointer-events-none data-[disabled]:opacity-50 data-[disabled]:cursor-default select-none",
       className
     )}
     {...props}
frontend/src/components/ui/sidebar.test.tsx (new file, 83 lines)
@@ -0,0 +1,83 @@
import { render, screen } from "@testing-library/react"
import { describe, expect, it } from "vitest"
import {
  SidebarProvider,
  SidebarMenuButton,
  SidebarMenuAction,
  SidebarGroupAction,
  SidebarMenu,
  SidebarMenuItem,
  SidebarGroup
} from "./sidebar"

describe("Sidebar components", () => {
  it("SidebarMenuButton should have cursor-pointer", () => {
    render(
      <SidebarProvider>
        <SidebarMenu>
          <SidebarMenuItem>
            <SidebarMenuButton>Menu Button</SidebarMenuButton>
          </SidebarMenuItem>
        </SidebarMenu>
      </SidebarProvider>
    )
    const button = screen.getByRole("button", { name: /menu button/i })
    expect(button).toHaveClass("cursor-pointer")
  })

  it("SidebarMenuAction should have cursor-pointer", () => {
    render(
      <SidebarProvider>
        <SidebarMenu>
          <SidebarMenuItem>
            <SidebarMenuButton>Button</SidebarMenuButton>
            <SidebarMenuAction>Action</SidebarMenuAction>
          </SidebarMenuItem>
        </SidebarMenu>
      </SidebarProvider>
    )
    const action = screen.getByRole("button", { name: /action/i })
    expect(action).toHaveClass("cursor-pointer")
  })

  it("SidebarGroupAction should have cursor-pointer", () => {
    render(
      <SidebarProvider>
        <SidebarGroup>
          <SidebarGroupAction>Group Action</SidebarGroupAction>
        </SidebarGroup>
      </SidebarProvider>
    )
    const action = screen.getByRole("button", { name: /group action/i })
    expect(action).toHaveClass("cursor-pointer")
  })

  it("SidebarGroupAction should have cursor-default when disabled", () => {
    render(
      <SidebarProvider>
        <SidebarGroup>
          <SidebarGroupAction disabled>Disabled Group Action</SidebarGroupAction>
        </SidebarGroup>
      </SidebarProvider>
    )
    const action = screen.getByRole("button", { name: /disabled group action/i })
    expect(action).toHaveClass("disabled:cursor-default")
    expect(action).toBeDisabled()
  })

  it("SidebarMenuAction should not show pointer when disabled", () => {
    render(
      <SidebarProvider>
        <SidebarMenu>
          <SidebarMenuItem>
            <SidebarMenuButton>Button</SidebarMenuButton>
            <SidebarMenuAction disabled>Disabled Action</SidebarMenuAction>
          </SidebarMenuItem>
        </SidebarMenu>
      </SidebarProvider>
    )
    const action = screen.getByRole("button", { name: /disabled action/i })
    expect(action).toHaveClass("disabled:cursor-default")
    expect(action).toBeDisabled()
  })
})
@@ -405,7 +405,7 @@ function SidebarGroupAction({
       data-slot="sidebar-group-action"
       data-sidebar="group-action"
       className={cn(
-        "text-sidebar-foreground ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground absolute top-3.5 right-3 flex aspect-square w-5 items-center justify-center rounded-md p-0 outline-hidden transition-transform focus-visible:ring-2 [&>svg]:size-4 [&>svg]:shrink-0",
+        "text-sidebar-foreground ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground absolute top-3.5 right-3 flex aspect-square w-5 items-center justify-center rounded-md p-0 outline-hidden transition-transform focus-visible:ring-2 [&>svg]:size-4 [&>svg]:shrink-0 cursor-pointer disabled:cursor-default disabled:pointer-events-none disabled:opacity-50",
         // Increases the hit area of the button on mobile.
         "after:absolute after:-inset-2 md:after:hidden",
         "group-data-[collapsible=icon]:hidden",
@@ -477,7 +477,7 @@ function SidebarMenuButton({
       data-size={size}
       data-active={isActive}
       className={cn(
-        "peer/menu-button flex w-full items-center gap-2 overflow-hidden rounded-md p-2 text-left text-sm outline-hidden ring-sidebar-ring transition-[width,height,padding] hover:bg-sidebar-accent hover:text-sidebar-accent-foreground focus-visible:ring-2 active:bg-sidebar-accent active:text-sidebar-accent-foreground disabled:pointer-events-none disabled:opacity-50 group-has-data-[sidebar=menu-action]/menu-item:pr-8 aria-disabled:pointer-events-none aria-disabled:opacity-50 data-[active=true]:bg-sidebar-accent data-[active=true]:font-medium data-[active=true]:text-sidebar-accent-foreground data-[state=open]:hover:bg-sidebar-accent data-[state=open]:hover:text-sidebar-accent-foreground group-data-[collapsible=icon]:size-8! group-data-[collapsible=icon]:p-2! [&>span:last-child]:truncate [&>svg]:size-4 [&>svg]:shrink-0",
+        "peer/menu-button flex w-full items-center gap-2 overflow-hidden rounded-md p-2 text-left text-sm outline-hidden ring-sidebar-ring transition-[width,height,padding] hover:bg-sidebar-accent hover:text-sidebar-accent-foreground focus-visible:ring-2 active:bg-sidebar-accent active:text-sidebar-accent-foreground disabled:pointer-events-none disabled:opacity-50 group-has-data-[sidebar=menu-action]/menu-item:pr-8 aria-disabled:pointer-events-none aria-disabled:opacity-50 data-[active=true]:bg-sidebar-accent data-[active=true]:font-medium data-[active=true]:text-sidebar-accent-foreground data-[state=open]:hover:bg-sidebar-accent data-[state=open]:hover:text-sidebar-accent-foreground group-data-[collapsible=icon]:size-8! group-data-[collapsible=icon]:p-2! [&>span:last-child]:truncate [&>svg]:size-4 [&>svg]:shrink-0 cursor-pointer disabled:cursor-default",
         variant === "default" && "hover:bg-sidebar-accent hover:text-sidebar-accent-foreground",
         variant === "outline" && "bg-background shadow-[0_0_0_1px_hsl(var(--sidebar-border))] hover:bg-sidebar-accent hover:text-sidebar-accent-foreground hover:shadow-[0_0_0_1px_hsl(var(--sidebar-accent))]",
         size === "default" && "h-8 text-sm",
@@ -528,7 +528,7 @@ function SidebarMenuAction({
       data-slot="sidebar-menu-action"
       data-sidebar="menu-action"
       className={cn(
-        "text-sidebar-foreground ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground peer-hover/menu-button:text-sidebar-accent-foreground absolute top-1.5 right-1 flex aspect-square w-5 items-center justify-center rounded-md p-0 outline-hidden transition-transform focus-visible:ring-2 [&>svg]:size-4 [&>svg]:shrink-0",
+        "text-sidebar-foreground ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground peer-hover/menu-button:text-sidebar-accent-foreground absolute top-1.5 right-1 flex aspect-square w-5 items-center justify-center rounded-md p-0 outline-hidden transition-transform focus-visible:ring-2 [&>svg]:size-4 [&>svg]:shrink-0 cursor-pointer disabled:cursor-default disabled:pointer-events-none disabled:opacity-50",
         // Increases the hit area of the button on mobile.
         "after:absolute after:-inset-2 md:after:hidden",
         "peer-data-[size=sm]/menu-button:top-1",
@@ -3,21 +3,21 @@ import { ChatService, type ChatEvent, type MessageResponse } from "./chat"

 describe("ChatService SSE Parsing", () => {
   it("should correctly parse a text stream chunk", () => {
-    const rawChunk = `data: {"type": "on_chat_model_stream", "name": "summarizer", "data": {"chunk": "Hello"}}\n\n`
+    const rawChunk = `data: {"type": "on_chat_model_stream", "name": "synthesizer", "data": {"chunk": "Hello"}}\n\n`
     const events = ChatService.parseSSEChunk(rawChunk)

     expect(events).toHaveLength(1)
     expect(events[0]).toEqual({
       type: "on_chat_model_stream",
-      name: "summarizer",
+      name: "synthesizer",
       data: { chunk: "Hello" }
     })
   })

   it("should handle multiple events in one chunk", () => {
     const rawChunk =
-      `data: {"type": "on_chat_model_stream", "name": "summarizer", "data": {"chunk": "Hello"}}\n\n` +
-      `data: {"type": "on_chat_model_stream", "name": "summarizer", "data": {"chunk": " World"}}\n\n`
+      `data: {"type": "on_chat_model_stream", "name": "synthesizer", "data": {"chunk": "Hello"}}\n\n` +
+      `data: {"type": "on_chat_model_stream", "name": "synthesizer", "data": {"chunk": " World"}}\n\n`

     const events = ChatService.parseSSEChunk(rawChunk)

@@ -25,8 +25,8 @@ describe("ChatService SSE Parsing", () => {
     expect(events[1].data!.chunk).toBe(" World")
   })

-  it("should parse encoded plots from executor node", () => {
-    const rawChunk = `data: {"type": "on_chain_end", "name": "executor", "data": {"encoded_plots": ["base64data"]}}\n\n`
+  it("should parse encoded plots from data_analyst_worker node", () => {
+    const rawChunk = `data: {"type": "on_chain_end", "name": "data_analyst_worker", "data": {"encoded_plots": ["base64data"]}}\n\n`
     const events = ChatService.parseSSEChunk(rawChunk)

     expect(events[0].data!.encoded_plots).toEqual(["base64data"])
@@ -45,7 +45,7 @@ describe("ChatService Message State Management", () => {
     const messages: MessageResponse[] = [{ id: "1", role: "assistant", content: "Initial", created_at: new Date().toISOString() }]
     const event: ChatEvent = {
       type: "on_chat_model_stream",
-      node: "summarizer",
+      node: "synthesizer",
       data: { chunk: { content: " text" } }
     }

@@ -57,7 +57,7 @@ describe("ChatService Message State Management", () => {
     const messages: MessageResponse[] = [{ id: "1", role: "assistant", content: "Analysis", created_at: new Date().toISOString(), plots: [] }]
     const event: ChatEvent = {
       type: "on_chain_end",
-      name: "executor",
+      name: "data_analyst_worker",
       data: { encoded_plots: ["plot1"] }
     }
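The diff above renames the streaming nodes but never shows `ChatService.parseSSEChunk` itself. As a minimal sketch of the behavior the tests assume (split on blank-line event delimiters, strip the `data: ` prefix, JSON-parse each payload) — a standalone illustration, not the project's actual implementation:

```typescript
// Hypothetical standalone parser matching the behavior the tests above assume.
interface ChatEvent {
  type: string
  name?: string
  node?: string
  data?: Record<string, any>
}

function parseSSEChunk(raw: string): ChatEvent[] {
  return raw
    .split("\n\n")                                  // SSE events are delimited by a blank line
    .map(block => block.trim())
    .filter(block => block.startsWith("data: "))    // ignore comments/empty fragments
    .map(block => JSON.parse(block.slice("data: ".length)) as ChatEvent)
}

// Two events arriving in a single network chunk, as in the tests:
const chunk =
  `data: {"type": "on_chat_model_stream", "name": "synthesizer", "data": {"chunk": "Hello"}}\n\n` +
  `data: {"type": "on_chat_model_stream", "name": "synthesizer", "data": {"chunk": " World"}}\n\n`
const events = parseSSEChunk(chunk)
```

Note this sketch assumes each `data:` payload is complete within the chunk; a production parser would also buffer partial events across chunk boundaries.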
@@ -86,7 +86,8 @@ export const ChatService = {
     const { type, name, node, data } = event

     // 1. Handle incremental LLM chunks for terminal nodes
-    if (type === "on_chat_model_stream" && (node === "summarizer" || node === "researcher" || node === "clarification")) {
+    // Now using 'synthesizer' for the final user response
+    if (type === "on_chat_model_stream" && (node === "synthesizer" || node === "clarification")) {
       const chunk = data?.chunk?.content || ""
       if (!chunk) return messages

@@ -110,7 +111,7 @@ export const ChatService = {
     if (!lastMsg || lastMsg.role !== "assistant") return messages

     // Terminal nodes final text
-    if (name === "summarizer" || name === "researcher" || name === "clarification") {
+    if (name === "synthesizer" || name === "clarification") {
       const messages_list = data?.output?.messages
       const msg = messages_list ? messages_list[messages_list.length - 1]?.content : null

@@ -121,8 +122,8 @@ export const ChatService = {
       }
     }

-    // Plots from executor
-    if (name === "executor" && data?.encoded_plots) {
+    // Plots from data analyst worker
+    if (name === "data_analyst_worker" && data?.encoded_plots) {
       lastMsg.plots = [...(lastMsg.plots || []), ...data.encoded_plots]
       // Filter out the 'active' step and replace with 'complete'
       const filteredSteps = (lastMsg.steps || []).filter(s => s !== "Performing data analysis...");
@@ -134,15 +135,25 @@ export const ChatService = {
     // Status for intermediate nodes (completion)
     const statusMap: Record<string, string> = {
       "query_analyzer": "Query analysis complete.",
-      "planner": "Strategic plan generated.",
-      "coder": "Analysis code generated."
+      "planner": "Checklist generated.",
+      "delegate": "Task assigned.",
+      "reflector": "Result verified.",
+      "coder": "Analysis code generated.",
+      "executor": "Code execution complete.",
+      "searcher": "Web search complete.",
+      "summarizer": "Task summary generated."
     }

     if (name && statusMap[name]) {
       // Find and replace the active status if it exists
       const activeStatus = name === "query_analyzer" ? "Analyzing query..." :
-        name === "planner" ? "Generating strategic plan..." :
-        name === "coder" ? "Writing analysis code..." : null;
+        name === "planner" ? "Generating high-level plan..." :
+        name === "delegate" ? "Routing task..." :
+        name === "reflector" ? "Evaluating results..." :
+        name === "coder" ? "Writing analysis code..." :
+        name === "executor" ? "Executing code..." :
+        name === "searcher" ? "Searching web..." :
+        name === "summarizer" ? "Summarizing results..." : null;

       let filteredSteps = lastMsg.steps || [];
       if (activeStatus) {
@@ -159,9 +170,13 @@ export const ChatService = {
     if (type === "on_chain_start") {
       const startStatusMap: Record<string, string> = {
         "query_analyzer": "Analyzing query...",
-        "planner": "Generating strategic plan...",
+        "planner": "Generating high-level plan...",
+        "delegate": "Routing task...",
+        "reflector": "Evaluating results...",
         "coder": "Writing analysis code...",
-        "executor": "Performing data analysis..."
+        "executor": "Executing code...",
+        "searcher": "Searching web...",
+        "summarizer": "Summarizing results..."
       }

       if (name && startStatusMap[name]) {
@@ -170,9 +185,11 @@ export const ChatService = {
         const lastMsg = { ...newMessages[lastMsgIndex] }

         if (lastMsg && lastMsg.role === "assistant") {
-          // Avoid duplicate start messages
-          if (!(lastMsg.steps || []).includes(startStatusMap[name])) {
-            lastMsg.steps = [...(lastMsg.steps || []), startStatusMap[name]]
+          const currentSteps = lastMsg.steps || []
+          // Allow duplicate steps if it's a retry (cycle),
+          // but avoid spamming the same active step multiple times in a row
+          if (currentSteps[currentSteps.length - 1] !== startStatusMap[name]) {
+            lastMsg.steps = [...currentSteps, startStatusMap[name]]
             newMessages[lastMsgIndex] = lastMsg
             return newMessages
           }
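The step-tracking change in the last hunk replaces a global "seen before" check with a "same as the previous step" check, so an agent retry cycle can legitimately repeat a step. Extracted as a hypothetical pure helper (`appendStep` is an illustrative name, not a function in the diff):

```typescript
// Hypothetical helper illustrating the dedup rule from the last hunk:
// a step may recur later (retries in the agent feedback loop are allowed),
// but the same step is never appended twice in a row.
function appendStep(steps: string[], step: string): string[] {
  if (steps[steps.length - 1] === step) return steps  // consecutive duplicate: skip
  return [...steps, step]
}

let steps: string[] = []
steps = appendStep(steps, "Routing task...")
steps = appendStep(steps, "Routing task...")        // ignored: same as previous
steps = appendStep(steps, "Evaluating results...")
steps = appendStep(steps, "Routing task...")        // kept: a retry after another step
```

Under the old `includes`-based check, the final "Routing task..." would have been dropped and a retry would be invisible to the user.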