38 Commits

Author SHA1 Message Date
Yunxiao Xu
f62a70f7c3 chore(graph): Remove obsolete linear nodes and legacy tests 2026-02-23 19:32:45 -08:00
Yunxiao Xu
b9084fcaef feat(frontend): Allow duplicate progress steps to support cyclical worker loops and retries 2026-02-23 19:32:45 -08:00
Yunxiao Xu
8fcfc4ee88 feat(frontend): Update chat service and UI to support Orchestrator architecture and native subgraph streaming 2026-02-23 19:32:45 -08:00
Yunxiao Xu
b8d8651924 refactor(graph): Use piped Runnables for worker nodes to enable subgraph event streaming 2026-02-23 19:32:45 -08:00
Yunxiao Xu
02d93120e0 feat(api): Update Chat Stream Protocol for Orchestrator Architecture 2026-02-23 19:32:45 -08:00
Yunxiao Xu
92c30d217e feat(orchestrator): Harden VFS and enhance artifact awareness across workers 2026-02-23 19:32:45 -08:00
Yunxiao Xu
88a27f5a8d docs: Update project documentation to reflect Orchestrator-Workers architecture 2026-02-23 19:32:45 -08:00
Yunxiao Xu
4d92c9aedb fix(orchestrator): Remove clarification interrupt to allow single-pass generation
Also fixes a test assertion in the reflector test to align with LangGraph state updates.
2026-02-23 19:32:45 -08:00
Yunxiao Xu
557b553c59 fix(orchestrator): Enforce retry budget to prevent unbounded loops 2026-02-23 19:32:45 -08:00
Yunxiao Xu
2cfbc5d1d0 fix(orchestrator): Apply refinements from code review 2026-02-23 19:32:45 -08:00
Yunxiao Xu
c5cf4b38a1 docs(config): Update .env.example with new Orchestrator node configuration keys 2026-02-23 19:32:45 -08:00
Yunxiao Xu
46129c6f1e test(config): Add tests to verify node-specific LLM configuration usage 2026-02-23 19:32:45 -08:00
Yunxiao Xu
8eea464be4 test(orchestrator): Add E2E multi-worker sequential flow tests 2026-02-23 19:32:45 -08:00
Yunxiao Xu
11c14fb8a8 feat(config): Implement asymmetric model configuration for Orchestrator and Workers 2026-02-23 19:32:45 -08:00
Yunxiao Xu
9b97140fff fix(researcher): Handle non-string search results in summarizer node 2026-02-23 19:32:45 -08:00
Yunxiao Xu
9e90f2c9ad feat(orchestrator): Integrate Researcher worker subgraph into the Orchestrator loop 2026-02-23 19:32:45 -08:00
Yunxiao Xu
5cc5bd91ae feat(workers): Implement Researcher worker subgraph for web research tasks 2026-02-23 19:32:45 -08:00
Yunxiao Xu
120b6fd11a feat(workers): Define WorkerState for the Researcher subgraph 2026-02-23 19:32:45 -08:00
Yunxiao Xu
f4d09c07c4 chore(graph): Relocate QueryAnalysis schema and update existing tests for Orchestrator architecture 2026-02-23 19:32:45 -08:00
Yunxiao Xu
ad7845cc6a test(orchestrator): Add integration tests for the Orchestrator-Workers loop 2026-02-23 19:32:45 -08:00
Yunxiao Xu
18e4e8db7d feat(orchestrator): Integrate Orchestrator-Workers loop and human-in-the-loop interrupts 2026-02-23 19:32:45 -08:00
Yunxiao Xu
9fef4888b5 feat(orchestrator): Implement Synthesizer node for final worker results integration 2026-02-23 19:32:45 -08:00
Yunxiao Xu
37c353a249 feat(orchestrator): Implement Reflector node for task evaluation and plan advancement 2026-02-23 19:32:45 -08:00
Yunxiao Xu
ff9b443bfe feat(orchestrator): Implement Delegate node for task routing 2026-02-23 19:32:45 -08:00
Yunxiao Xu
575e1a2e53 feat(orchestrator): Implement high-level task decomposition in Planner node 2026-02-23 19:32:45 -08:00
Yunxiao Xu
013208b929 feat(workers): Implement input/output mapping for Data Analyst subgraph 2026-02-23 19:32:45 -08:00
Yunxiao Xu
cb045504d1 feat(workers): Extract Coder and Executor nodes into Data Analyst worker subgraph 2026-02-23 19:32:45 -08:00
Yunxiao Xu
5324cbe851 feat(workers): Define WorkerState for the Data Analyst subgraph 2026-02-23 19:32:45 -08:00
Yunxiao Xu
eeb2be409b feat(executor): Integrate VFS helper for in-memory artifact tracking 2026-02-23 19:32:45 -08:00
Yunxiao Xu
92d9288f38 feat(utils): Implement VFSHelper for in-memory artifact management 2026-02-23 19:32:45 -08:00
Yunxiao Xu
8957e93f3d feat(graph): Extend AgentState with checklist, current_step, and vfs 2026-02-23 19:32:45 -08:00
Yunxiao Xu
45fe122580 chore(frontend): update package-lock.json 2026-02-23 06:10:56 -08:00
Yunxiao Xu
969165f4a7 fix(ui): ensure dropdown checkbox and radio items have default cursor when disabled 2026-02-23 06:02:23 -08:00
Yunxiao Xu
e7be0dbeca test(ui): strengthen disabled cursor assertions and fix linting 2026-02-23 05:46:44 -08:00
Yunxiao Xu
322ae1e7c8 fix(ui): resolve disabled state cursor regression and add dropdown tests 2026-02-23 05:39:32 -08:00
Yunxiao Xu
46b57d2a73 fix(ui): add cursor-pointer to dropdown menu items 2026-02-23 05:20:38 -08:00
Yunxiao Xu
99ded43483 fix(ui): add cursor-pointer to sidebar components 2026-02-23 05:16:39 -08:00
Yunxiao Xu
1394c0496a fix(ui): add cursor-pointer to base button variants 2026-02-23 05:03:05 -08:00
81 changed files with 2483 additions and 909 deletions

View File

@@ -4,6 +4,7 @@ A stateful, graph-based chatbot for election data analysis, built with LangGraph
## 🚀 Features ## 🚀 Features
- **Multi-Agent Orchestration**: Decomposes complex queries and delegates them to specialized sub-agents (Data Analyst, Researcher) using a robust feedback loop.
- **Intelligent Query Analysis**: Automatically determines if a query needs data analysis, web research, or clarification. - **Intelligent Query Analysis**: Automatically determines if a query needs data analysis, web research, or clarification.
- **Automated Data Analysis**: Generates and executes Python code to analyze election datasets and produce visualizations. - **Automated Data Analysis**: Generates and executes Python code to analyze election datasets and produce visualizations.
- **Web Research**: Integrates web search capabilities for general election-related questions. - **Web Research**: Integrates web search capabilities for general election-related questions.

View File

@@ -57,3 +57,11 @@ OIDC_REDIRECT_URI=http://localhost:8000/api/v1/auth/oidc/callback
# Researcher # Researcher
# RESEARCHER_LLM__PROVIDER=google # RESEARCHER_LLM__PROVIDER=google
# RESEARCHER_LLM__MODEL=gemini-2.0-flash # RESEARCHER_LLM__MODEL=gemini-2.0-flash
# Reflector
# REFLECTOR_LLM__PROVIDER=openai
# REFLECTOR_LLM__MODEL=gpt-5-mini
# Synthesizer
# SYNTHESIZER_LLM__PROVIDER=openai
# SYNTHESIZER_LLM__MODEL=gpt-5-mini

View File

@@ -1,12 +1,13 @@
# Election Analytics Chatbot - Backend Guide # Election Analytics Chatbot - Backend Guide
## Overview ## Overview
The backend is a Python-based FastAPI application that leverages **LangGraph** to provide a stateful, agentic workflow for election data analysis. It handles complex queries by decomposing them into tasks such as data analysis, web research, or user clarification. The backend is a Python-based FastAPI application that leverages **LangGraph** to provide a stateful, hierarchical multi-agent workflow for election data analysis. It handles complex queries using an Orchestrator-Workers pattern, decomposing tasks and delegating them to specialized subgraphs (Data Analyst, Researcher) with built-in reflection and error recovery.
## 1. Architecture Overview ## 1. Architecture Overview
- **Framework**: LangGraph for workflow orchestration and state management. - **Framework**: LangGraph for hierarchical workflow orchestration and state management.
- **API**: FastAPI for providing REST and streaming (SSE) endpoints. - **API**: FastAPI for providing REST and streaming (SSE) endpoints.
- **State Management**: Persistent state using LangGraph's `StateGraph` with a PostgreSQL checkpointer. - **State Management**: Persistent state using LangGraph's `StateGraph` with a PostgreSQL checkpointer. Maintains global state (`AgentState`) and isolated worker states (`WorkerState`).
- **Virtual File System (VFS)**: An in-memory abstraction passed between nodes to manage intermediate artifacts (scripts, CSVs, charts) without bloating the context window.
- **Database**: PostgreSQL. - **Database**: PostgreSQL.
- Application data: Uses `users` table for local and OIDC users (String IDs). - Application data: Uses `users` table for local and OIDC users (String IDs).
- History: Persists chat history and artifacts. - History: Persists chat history and artifacts.
@@ -14,23 +15,28 @@ The backend is a Python-based FastAPI application that leverages **LangGraph** t
## 2. Core Components ## 2. Core Components
### 2.1. The Graph State (`src/ea_chatbot/graph/state.py`) ### 2.1. State Management (`src/ea_chatbot/graph/state.py` & `workers/*/state.py`)
The state tracks the conversation context, plan, generated code, execution results, and artifacts. - **Global State**: Tracks the conversation context, the high-level task `checklist`, execution progress (`current_step`), and the VFS.
- **Worker State**: Isolated snapshot for specialized subgraphs, tracking internal retry loops (`iterations`), worker-specific prompts, and raw results.
### 2.2. Nodes (The Actors) ### 2.2. The Orchestrator
Located in `src/ea_chatbot/graph/nodes/`: Located in `src/ea_chatbot/graph/nodes/`:
- **`query_analyzer`**: Analyzes the user query to determine the intent and required data. - **`query_analyzer`**: Analyzes the user query to determine the intent and required data. If ambiguous, routes to `clarification`.
- **`planner`**: Creates a step-by-step plan for data analysis. - **`planner`**: Decomposes the user request into a strategic `checklist` of sub-tasks assigned to specific workers.
- **`coder`**: Generates Python code based on the plan and dataset metadata. - **`delegate`**: The traffic controller. Routes the current task to the appropriate worker and enforces a strict retry budget to prevent infinite loops.
- **`executor`**: Safely executes the generated code and captures outputs (dataframes, plots). - **`reflector`**: The quality control node. Evaluates a worker's summary against the sub-task requirements. Can trigger a retry if unsatisfied.
- **`error_corrector`**: Fixes code if execution fails. - **`synthesizer`**: Aggregates all worker results into a final, cohesive response for the user.
- **`researcher`**: Performs web searches for general election information. - **`clarification`**: Asks the user for more information if the query is critically ambiguous.
- **`summarizer`**: Generates a natural language response based on the analysis results.
- **`clarification`**: Asks the user for more information if the query is ambiguous.
### 2.3. The Workflow (Graph) ### 2.3. Specialized Workers (Sub-Graphs)
The graph connects these nodes with conditional edges, allowing for iterative refinement and error correction. Located in `src/ea_chatbot/graph/workers/`:
- **`data_analyst`**: Generates Python/SQL code, executes it securely, and captures dataframes/plots. Contains an internal retry loop (`coder` -> `executor` -> error check -> `coder`).
- **`researcher`**: Performs web searches for general election information and synthesizes factual findings.
### 2.4. The Workflow
The global graph connects the Orchestrator nodes, wrapping the Worker subgraphs as self-contained nodes with mapped inputs and outputs.
## 3. Key Modules ## 3. Key Modules

View File

@@ -56,7 +56,8 @@ async def stream_agent_events(
initial_state, initial_state,
config, config,
version="v2", version="v2",
checkpointer=checkpointer checkpointer=checkpointer,
subgraphs=True
): ):
kind = event.get("event") kind = event.get("event")
name = event.get("name") name = event.get("name")
@@ -71,8 +72,8 @@ async def stream_agent_events(
"data": data "data": data
} }
# Buffer assistant chunks (summarizer and researcher might stream) # Buffer assistant chunks (synthesizer and clarification might stream)
if kind == "on_chat_model_stream" and node_name in ["summarizer", "researcher", "clarification"]: if kind == "on_chat_model_stream" and node_name in ["synthesizer", "clarification"]:
chunk = data.get("chunk", "") chunk = data.get("chunk", "")
# Use utility to safely extract text content from the chunk # Use utility to safely extract text content from the chunk
chunk_data = convert_to_json_compatible(chunk) chunk_data = convert_to_json_compatible(chunk)
@@ -83,7 +84,7 @@ async def stream_agent_events(
assistant_chunks.append(str(chunk_data)) assistant_chunks.append(str(chunk_data))
# Buffer and encode plots # Buffer and encode plots
if kind == "on_chain_end" and name == "executor": if kind == "on_chain_end" and name == "data_analyst_worker":
output = data.get("output", {}) output = data.get("output", {})
if isinstance(output, dict) and "plots" in output: if isinstance(output, dict) and "plots" in output:
plots = output["plots"] plots = output["plots"]
@@ -95,7 +96,7 @@ async def stream_agent_events(
output_event["data"]["encoded_plots"] = encoded_plots output_event["data"]["encoded_plots"] = encoded_plots
# Collect final response from terminal nodes # Collect final response from terminal nodes
if kind == "on_chain_end" and name in ["summarizer", "researcher", "clarification"]: if kind == "on_chain_end" and name in ["synthesizer", "clarification"]:
output = data.get("output", {}) output = data.get("output", {})
if isinstance(output, dict) and "messages" in output: if isinstance(output, dict) and "messages" in output:
last_msg = output["messages"][-1] last_msg = output["messages"][-1]

View File

@@ -52,6 +52,8 @@ class Settings(BaseSettings):
coder_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0)) coder_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
summarizer_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0)) summarizer_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
researcher_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0)) researcher_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
reflector_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
synthesizer_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
# Allow nested env vars like QUERY_ANALYZER_LLM__MODEL # Allow nested env vars like QUERY_ANALYZER_LLM__MODEL
model_config = SettingsConfigDict(env_nested_delimiter='__', env_prefix='') model_config = SettingsConfigDict(env_nested_delimiter='__', env_prefix='')

View File

@@ -36,8 +36,7 @@ Please ask the user for the necessary details."""
response = llm.invoke(messages) response = llm.invoke(messages)
logger.info("[bold green]Clarification generated.[/bold green]") logger.info("[bold green]Clarification generated.[/bold green]")
return { return {
"messages": [response], "messages": [response]
"next_action": "end" # To indicate we are done for now
} }
except Exception as e: except Exception as e:
logger.error(f"Failed to generate clarification: {str(e)}") logger.error(f"Failed to generate clarification: {str(e)}")

View File

@@ -1,47 +0,0 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils import database_inspection
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT
from ea_chatbot.schemas import CodeGenerationResponse
def coder_node(state: AgentState) -> dict:
"""Generate Python code based on the plan and data summary."""
question = state["question"]
plan = state.get("plan", "")
code_output = state.get("code_output", "None")
settings = Settings()
logger = get_logger("coder")
logger.info("Generating Python code...")
llm = get_llm_model(
settings.coder_llm,
callbacks=[LangChainLoggingHandler(logger=logger)]
)
structured_llm = llm.with_structured_output(CodeGenerationResponse)
# Always provide data summary
database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."
example_code = "" # Placeholder
messages = CODE_GENERATOR_PROMPT.format_messages(
question=question,
plan=plan,
database_description=database_description,
code_exec_results=code_output,
example_code=example_code
)
try:
response = structured_llm.invoke(messages)
logger.info("[bold green]Code generated.[/bold green]")
return {
"code": response.parsed_code,
"error": None # Clear previous errors on new code generation
}
except Exception as e:
logger.error(f"Failed to generate code: {str(e)}")
raise e

View File

@@ -0,0 +1,28 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.utils.logging import get_logger
def delegate_node(state: AgentState) -> dict:
"""Determine which worker subgraph to call next based on the checklist."""
checklist = state.get("checklist", [])
current_step = state.get("current_step", 0)
iterations = state.get("iterations", 0)
logger = get_logger("orchestrator:delegate")
if not checklist or current_step >= len(checklist):
logger.info("Checklist complete or empty. Routing to summarizer.")
return {"next_action": "summarize"}
# Enforce retry budget
if iterations >= 3:
logger.error(f"Max retries reached for task {current_step}. Routing to summary with failure.")
return {
"next_action": "summarize",
"iterations": 0, # Reset for next turn
"summary": f"Failed to complete task {current_step} after {iterations} attempts."
}
task_info = checklist[current_step]
worker = task_info.get("worker", "data_analyst")
logger.info(f"Delegating next task to worker: {worker} (Attempt {iterations + 1})")
return {"next_action": worker}

View File

@@ -1,44 +0,0 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.coder import ERROR_CORRECTOR_PROMPT
from ea_chatbot.schemas import CodeGenerationResponse
def error_corrector_node(state: AgentState) -> dict:
"""Fix the code based on the execution error."""
code = state.get("code", "")
error = state.get("error", "Unknown error")
settings = Settings()
logger = get_logger("error_corrector")
logger.warning(f"[bold red]Execution error detected:[/bold red] {error[:100]}...")
logger.info("Attempting to correct the code...")
# Reuse coder LLM config or add a new one. Using coder_llm for now.
llm = get_llm_model(
settings.coder_llm,
callbacks=[LangChainLoggingHandler(logger=logger)]
)
structured_llm = llm.with_structured_output(CodeGenerationResponse)
messages = ERROR_CORRECTOR_PROMPT.format_messages(
code=code,
error=error
)
try:
response = structured_llm.invoke(messages)
logger.info("[bold green]Correction generated.[/bold green]")
current_iterations = state.get("iterations", 0)
return {
"code": response.parsed_code,
"error": None, # Clear error after fix attempt
"iterations": current_iterations + 1
}
except Exception as e:
logger.error(f"Failed to correct code: {str(e)}")
raise e

View File

@@ -1,34 +1,32 @@
import yaml
from ea_chatbot.graph.state import AgentState from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils import helpers, database_inspection from ea_chatbot.utils import helpers, database_inspection
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.planner import PLANNER_PROMPT from ea_chatbot.graph.prompts.planner import PLANNER_PROMPT
from ea_chatbot.schemas import TaskPlanResponse from ea_chatbot.schemas import ChecklistResponse
def planner_node(state: AgentState) -> dict: def planner_node(state: AgentState) -> dict:
"""Generate a structured plan based on the query analysis.""" """Generate a high-level task checklist for the Orchestrator."""
question = state["question"] question = state["question"]
history = state.get("messages", [])[-6:] history = state.get("messages", [])[-6:]
summary = state.get("summary", "") summary = state.get("summary", "")
settings = Settings() settings = Settings()
logger = get_logger("planner") logger = get_logger("orchestrator:planner")
logger.info("Generating task plan...") logger.info("Generating high-level task checklist...")
llm = get_llm_model( llm = get_llm_model(
settings.planner_llm, settings.planner_llm,
callbacks=[LangChainLoggingHandler(logger=logger)] callbacks=[LangChainLoggingHandler(logger=logger)]
) )
structured_llm = llm.with_structured_output(TaskPlanResponse) structured_llm = llm.with_structured_output(ChecklistResponse)
date_str = helpers.get_readable_date() date_str = helpers.get_readable_date()
# Always provide data summary; LLM decides relevance. # Data summary for context
database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available." database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."
example_plan = ""
messages = PLANNER_PROMPT.format_messages( messages = PLANNER_PROMPT.format_messages(
date=date_str, date=date_str,
@@ -36,16 +34,20 @@ def planner_node(state: AgentState) -> dict:
history=history, history=history,
summary=summary, summary=summary,
database_description=database_description, database_description=database_description,
example_plan=example_plan example_plan="Decompose into data_analyst and researcher tasks."
) )
# Generate the structured plan
try: try:
response = structured_llm.invoke(messages) response = ChecklistResponse.model_validate(structured_llm.invoke(messages))
# Convert the structured response back to YAML string for the state # Convert ChecklistTask objects to dicts for state
plan_yaml = yaml.dump(response.model_dump(), sort_keys=False) checklist = [task.model_dump() for task in response.checklist]
logger.info("[bold green]Plan generated successfully.[/bold green]") logger.info(f"[bold green]Checklist generated with {len(checklist)} tasks.[/bold green]")
return {"plan": plan_yaml} return {
"checklist": checklist,
"current_step": 0,
"iterations": 0, # Reset iteration counter for the new plan
"summary": response.reflection # Use reflection as initial summary
}
except Exception as e: except Exception as e:
logger.error(f"Failed to generate plan: {str(e)}") logger.error(f"Failed to generate checklist: {str(e)}")
raise e raise e

View File

@@ -1,18 +1,9 @@
from typing import List, Literal
from pydantic import BaseModel, Field
from ea_chatbot.graph.state import AgentState from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.query_analyzer import QUERY_ANALYZER_PROMPT from ea_chatbot.graph.prompts.query_analyzer import QUERY_ANALYZER_PROMPT
from ea_chatbot.schemas import QueryAnalysis
class QueryAnalysis(BaseModel):
"""Analysis of the user's query."""
data_required: List[str] = Field(description="List of data points or entities mentioned (e.g., ['2024 results', 'Florida']).")
unknowns: List[str] = Field(description="List of target information the user wants to know or needed for final answer (e.g., 'who won', 'total votes').")
ambiguities: List[str] = Field(description="List of CRITICAL missing details that prevent ANY analysis. Do NOT include database names or plot types if defaults can be used.")
conditions: List[str] = Field(description="List of any filters or constraints (e.g., ['year=2024', 'state=Florida']). Include context resolved from history.")
next_action: Literal["plan", "clarify", "research"] = Field(description="The next action to take. 'plan' for data analysis (even with defaults), 'research' for general knowledge, or 'clarify' ONLY for critical ambiguities.")
def query_analyzer_node(state: AgentState) -> dict: def query_analyzer_node(state: AgentState) -> dict:
"""Analyze the user's question and determine the next course of action.""" """Analyze the user's question and determine the next course of action."""

View File

@@ -0,0 +1,63 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.schemas import ReflectorResponse
def reflector_node(state: AgentState) -> dict:
"""Evaluate if the worker's output satisfies the current sub-task."""
checklist = state.get("checklist", [])
current_step = state.get("current_step", 0)
summary = state.get("summary", "") # This contains the worker's summary
if not checklist or current_step >= len(checklist):
return {"next_action": "summarize"}
task_info = checklist[current_step]
task_desc = task_info.get("task", "")
settings = Settings()
logger = get_logger("orchestrator:reflector")
logger.info(f"Evaluating worker output for task: {task_desc[:50]}...")
llm = get_llm_model(
settings.reflector_llm,
callbacks=[LangChainLoggingHandler(logger=logger)]
)
structured_llm = llm.with_structured_output(ReflectorResponse)
prompt = f"""You are a Lead Orchestrator evaluating the work of a specialized sub-agent.
**Sub-Task assigned:**
{task_desc}
**Worker's Result Summary:**
{summary}
Evaluate if the result is satisfactory and complete for this specific sub-task.
If there were major errors or the output is missing critical data requested in the sub-task, mark satisfied as False."""
try:
response = structured_llm.invoke(prompt)
if response.satisfied:
logger.info("[bold green]Sub-task satisfied.[/bold green] Advancing plan.")
return {
"current_step": current_step + 1,
"iterations": 0, # Reset for next task
"next_action": "delegate"
}
else:
logger.warning(f"[bold yellow]Sub-task NOT satisfied.[/bold yellow] Reason: {response.reasoning}")
# Do NOT advance the step. Increment iterations to track retries.
return {
"iterations": state.get("iterations", 0) + 1,
"next_action": "delegate"
}
except Exception as e:
logger.error(f"Failed to reflect: {str(e)}")
# On error, increment iterations to avoid infinite loop if LLM is stuck
return {
"iterations": state.get("iterations", 0) + 1,
"next_action": "delegate"
}

View File

@@ -1,43 +0,0 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.summarizer import SUMMARIZER_PROMPT
def summarizer_node(state: AgentState) -> dict:
"""Summarize the code execution results into a final answer."""
question = state["question"]
plan = state.get("plan", "")
code_output = state.get("code_output", "")
history = state.get("messages", [])[-6:]
summary = state.get("summary", "")
settings = Settings()
logger = get_logger("summarizer")
logger.info("Generating final summary...")
llm = get_llm_model(
settings.summarizer_llm,
callbacks=[LangChainLoggingHandler(logger=logger)]
)
messages = SUMMARIZER_PROMPT.format_messages(
question=question,
plan=plan,
code_output=code_output,
history=history,
summary=summary
)
try:
response = llm.invoke(messages)
logger.info("[bold green]Summary generated.[/bold green]")
# Return the final message to be added to the state
return {
"messages": [response]
}
except Exception as e:
logger.error(f"Failed to generate summary: {str(e)}")
raise e

View File

@@ -0,0 +1,56 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.synthesizer import SYNTHESIZER_PROMPT
def synthesizer_node(state: AgentState) -> dict:
"""Synthesize the results from multiple workers into a final answer."""
question = state["question"]
history = state.get("messages", [])
# We look for the 'summary' from the last worker which might have cumulative info
# Or we can look at all messages in history bubbled up from workers.
# For now, let's assume the history contains all the worker summaries.
settings = Settings()
logger = get_logger("orchestrator:synthesizer")
logger.info("Synthesizing final answer from worker results...")
llm = get_llm_model(
settings.synthesizer_llm,
callbacks=[LangChainLoggingHandler(logger=logger)]
)
# Artifact summary
plots = state.get("plots", [])
vfs = state.get("vfs", {})
artifacts_summary = ""
if plots:
artifacts_summary += f"- {len(plots)} generated plot(s) are attached to this response.\n"
if vfs:
artifacts_summary += "- Data files available in VFS: " + ", ".join(vfs.keys()) + "\n"
if not artifacts_summary:
artifacts_summary = "No additional artifacts generated."
# We provide the full history and the original question
messages = SYNTHESIZER_PROMPT.format_messages(
question=question,
history=history,
worker_results="Review the worker summaries provided in the message history.",
artifacts_summary=artifacts_summary
)
try:
response = llm.invoke(messages)
logger.info("[bold green]Final synthesis complete.[/bold green]")
# Return the final message to be added to the state
return {
"messages": [response],
"next_action": "end"
}
except Exception as e:
logger.error(f"Failed to synthesize final answer: {str(e)}")
raise e

View File

@@ -9,6 +9,14 @@ The user will provide a task and a plan.
- Do NOT assume a dataframe `df` is already loaded unless explicitly stated. You usually need to query it first. - Do NOT assume a dataframe `df` is already loaded unless explicitly stated. You usually need to query it first.
- The database schema is described in the prompt. Use it to construct valid SQL queries. - The database schema is described in the prompt. Use it to construct valid SQL queries.
**Virtual File System (VFS):**
- An in-memory file system is available as `vfs`. Use it to persist intermediate data or large artifacts.
- `vfs.write(filename, content, metadata=None)`: Save a file (content can be any serializable object).
- `vfs.read(filename) -> (content, metadata)`: Read a file.
- `vfs.list() -> list[str]`: List all files.
- `vfs.delete(filename)`: Delete a file.
- Prefer using VFS for intermediate DataFrames or complex data structures instead of printing everything.
**Plotting:** **Plotting:**
- If you need to plot any data, use the `plots` list to store the figures. - If you need to plot any data, use the `plots` list to store the figures.
- Example: `plots.append(fig)` or `plots.append(plt.gcf())`. - Example: `plots.append(fig)` or `plots.append(plt.gcf())`.
@@ -18,7 +26,8 @@ The user will provide a task and a plan.
- Produce FULL, COMPLETE CODE that includes all steps and solves the task! - Produce FULL, COMPLETE CODE that includes all steps and solves the task!
- Always include the import statements at the top of the code (e.g., `import pandas as pd`, `import matplotlib.pyplot as plt`). - Always include the import statements at the top of the code (e.g., `import pandas as pd`, `import matplotlib.pyplot as plt`).
- Always include print statements to output the results of your code. - Always include print statements to output the results of your code.
- Use `db.query_df("SELECT ...")` to get data.""" - Use `db.query_df("SELECT ...")` to get data.
"""
CODE_GENERATOR_USER = """TASK: CODE_GENERATOR_USER = """TASK:
{question} {question}
@@ -43,6 +52,7 @@ Return a complete, corrected python code that incorporates the fixes for the err
- You have access to a database client via the variable `db`. - You have access to a database client via the variable `db`.
- Use `db.query_df(sql)` to run queries. - Use `db.query_df(sql)` to run queries.
- Use `plots.append(fig)` for plots. - Use `plots.append(fig)` for plots.
- You have access to `vfs` for persistent in-memory storage.
- Always include imports and print statements.""" - Always include imports and print statements."""
ERROR_CORRECTOR_USER = """FAILED CODE: ERROR_CORRECTOR_USER = """FAILED CODE:

View File

@@ -1,41 +1,29 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
PLANNER_SYSTEM = """You are a Research Specialist and your job is to find answers and educate the user. PLANNER_SYSTEM = """You are a Lead Orchestrator for an Election Analytics Chatbot.
Provide factual information responding directly to the user's question. Include key details and context to ensure your response comprehensively answers their query. Your job is to decompose complex user queries into a high-level checklist of tasks.
**Specialized Workers:**
1. `data_analyst`: Handles SQL queries, Python data analysis, and plotting. Use this when the user needs numbers, trends, or charts from the internal database.
2. `researcher`: Performs web searches for current news, facts, or external data not in the primary database.
**Orchestration Strategy:**
- Analyze the user's question and the available data summary.
- Create a logical sequence of tasks (checklist) for these workers.
- Be specific in the task description for the worker (e.g., "Find the total votes in Florida 2020").
- If the query is ambiguous, the Orchestrator loop will later handle clarification, but for now, make the best plan possible.
Today's Date is: {date}""" Today's Date is: {date}"""
PLANNER_USER = """Conversation Summary: {summary} PLANNER_USER = """Conversation Summary: {summary}
TASK: USER QUESTION:
{question} {question}
AVAILABLE DATA SUMMARY (Use only if relevant to the task): AVAILABLE DATABASE SUMMARY:
{database_description} {database_description}
First: Evaluate whether you have all necessary and requested information to provide a solution. Decompose the question into a strategic checklist. For each task, specify which worker should handle it.
Use the dataset description above to determine what data and in what format you have available to you.
You are able to search internet if the user asks for it, or you require any information that you can not derive from the given dataset or the instruction.
Second: Incorporate any additional relevant context, reasoning, or details from previous interactions or internal chain-of-thought that may impact the solution.
Ensure that all such information is fully included in your response rather than referring to previous answers indirectly.
Third: Reflect on the problem and briefly describe it, while addressing the problem goal, inputs, outputs,
rules, constraints, and other relevant details that appear in the problem description.
Fourth: Based on the preceding steps, formulate your response as an algorithm, breaking the solution in up to eight simple concise yet descriptive, clear English steps.
You MUST Include all values or instructions as described in the above task, or retrieved using internet search!
If fewer steps suffice, that's acceptable. If more are needed, please include them.
Remember to explain steps rather than write code.
This algorithm will be later converted to Python code.
If a dataframe is required, assume it is named 'df' and is already defined/populated based on the data summary.
There is a list variable called `plots` that you need to use to store any plots you generate. Do not use `plt.show()` as it will render the plot and cause an error.
Output the algorithm as a YAML string. Always enclose the YAML string within ```yaml tags.
**Note: Ensure that any necessary context from prior interactions is fully embedded in the plan. Do not use phrases like "refer to previous answer"; instead, provide complete details inline.**
{example_plan}""" {example_plan}"""

View File

@@ -0,0 +1,31 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
SYNTHESIZER_SYSTEM = """You are a Lead Orchestrator for an Election Analytics Chatbot.
You have coordinated several specialized workers (Data Analysts, Researchers) to answer a user's complex query.
Your goal is to synthesize their individual findings into a single, cohesive, and comprehensive final response for the user.
**Guidelines:**
- Do NOT mention the internal 'workers' or 'checklist' names.
- Combine the data insights (from Data Analysts) and factual research (from Researchers) into a natural narrative.
- Ensure all numbers, dates, and names from the worker reports are included accurately.
- **Artifacts & Plots:** If plots or charts were generated, refer to them naturally (e.g., "The chart below shows...").
- If any part of the plan failed, explain the status honestly but professionally.
- Present data in clear formats (tables, bullet points) where appropriate."""
SYNTHESIZER_USER = """USER QUESTION:
{question}
EXECUTION SUMMARY (Results from specialized workers):
{worker_results}
AVAILABLE ARTIFACTS:
{artifacts_summary}
Provide the final integrated response:"""
SYNTHESIZER_PROMPT = ChatPromptTemplate.from_messages([
("system", SYNTHESIZER_SYSTEM),
MessagesPlaceholder(variable_name="history"),
("human", SYNTHESIZER_USER),
])

View File

@@ -12,12 +12,12 @@ class AgentState(AS):
# Query Analysis (Decomposition results) # Query Analysis (Decomposition results)
analysis: Optional[Dict[str, Any]] analysis: Optional[Dict[str, Any]]
# Expected keys: "requires_dataset", "expert", "data", "unknown", "condition" # Expected keys: "data_required", "unknowns", "ambiguities", "conditions"
# Step-by-step reasoning # Step-by-step reasoning (Legacy, use checklist for new flows)
plan: Optional[str] plan: Optional[str]
# Code execution context # Code execution context (Legacy, use workers for new flows)
code: Optional[str] code: Optional[str]
code_output: Optional[str] code_output: Optional[str]
error: Optional[str] error: Optional[str]
@@ -26,11 +26,21 @@ class AgentState(AS):
plots: Annotated[List[Any], operator.add] # Matplotlib figures plots: Annotated[List[Any], operator.add] # Matplotlib figures
dfs: Dict[str, Any] # Pandas DataFrames dfs: Dict[str, Any] # Pandas DataFrames
# Conversation summary # Conversation summary / Latest worker result
summary: Optional[str] summary: Optional[str]
# Routing hint: "clarify", "plan", "research", "end" # Routing hint: "clarify", "plan", "end"
next_action: str next_action: str
# Number of execution attempts # Number of execution attempts
iterations: int iterations: int
# --- DeepAgents Extensions ---
# High-level plan/checklist: List of Task objects (dicts)
checklist: List[Dict[str, Any]]
# Current active step in the checklist
current_step: int
# Virtual File System (VFS): Map of filenames to content/metadata
vfs: Dict[str, Any]

View File

@@ -0,0 +1,42 @@
from typing import Dict, Any, List
import copy
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.state import AgentState
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
from ea_chatbot.utils.vfs import safe_vfs_copy
def prepare_worker_input(state: AgentState) -> Dict[str, Any]:
"""Prepare the initial state for the Data Analyst worker."""
checklist = state.get("checklist", [])
current_step = state.get("current_step", 0)
# Get the current task description
task_desc = "Analyze data" # Default
if 0 <= current_step < len(checklist):
task_desc = checklist[current_step].get("task", task_desc)
return {
"task": task_desc,
"messages": [HumanMessage(content=task_desc)], # Start worker loop with the task
"vfs_state": safe_vfs_copy(state.get("vfs", {})),
"iterations": 0,
"plots": [],
"code": None,
"output": None,
"error": None,
"result": None
}
def merge_worker_output(worker_state: WorkerState) -> Dict[str, Any]:
"""Map worker results back to the global AgentState."""
result = worker_state.get("result") or "No result produced by data analyst worker."
# We bubble up the summary as an AI message and update the global summary field
updates = {
"messages": [AIMessage(content=result)],
"summary": result,
"vfs": worker_state.get("vfs_state", {}),
"plots": worker_state.get("plots", [])
}
return updates

View File

@@ -0,0 +1,64 @@
from typing import Dict, Any, List, Optional
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils import database_inspection
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT
from ea_chatbot.schemas import CodeGenerationResponse
def coder_node(state: WorkerState) -> dict:
"""Generate Python code based on the sub-task assigned to the worker."""
task = state["task"]
output = state.get("output", "None")
error = state.get("error", "None")
vfs_state = state.get("vfs_state", {})
settings = Settings()
logger = get_logger("data_analyst_worker:coder")
logger.info(f"Generating Python code for task: {task[:50]}...")
# We can use the configured 'coder_llm' for this node
llm = get_llm_model(
settings.coder_llm,
callbacks=[LangChainLoggingHandler(logger=logger)]
)
structured_llm = llm.with_structured_output(CodeGenerationResponse)
# Data summary for context
database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."
# VFS Summary: Let the LLM know which files are available in-memory
vfs_summary = "Available in-memory files (VFS):\n"
if vfs_state:
for filename, data in vfs_state.items():
if isinstance(data, dict):
meta = data.get("metadata", {})
vfs_summary += f"- {filename} ({meta.get('type', 'unknown')})\n"
else:
vfs_summary += f"- {filename} (raw data)\n"
else:
vfs_summary += "- None"
# Reuse the global prompt but adapt 'question' and 'plan' labels
# For a sub-task worker, 'task' effectively replaces the high-level 'plan'
messages = CODE_GENERATOR_PROMPT.format_messages(
question=task,
plan="Focus on the specific task below.",
database_description=database_description,
code_exec_results=f"Output: {output}\nError: {error}\n\n{vfs_summary}",
example_code=""
)
try:
response = structured_llm.invoke(messages)
logger.info("[bold green]Code generated.[/bold green]")
return {
"code": response.parsed_code,
"error": None, # Clear previous errors
"iterations": state.get("iterations", 0) + 1
}
except Exception as e:
logger.error(f"Failed to generate code: {str(e)}")
raise e

View File

@@ -1,23 +1,25 @@
import io import io
import sys import sys
import traceback import traceback
import copy
from contextlib import redirect_stdout from contextlib import redirect_stdout
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import pandas as pd import pandas as pd
from matplotlib.figure import Figure from matplotlib.figure import Figure
from ea_chatbot.graph.state import AgentState from ea_chatbot.graph.workers.data_analyst.state import WorkerState
from ea_chatbot.utils.db_client import DBClient from ea_chatbot.utils.db_client import DBClient
from ea_chatbot.utils.vfs import VFSHelper, safe_vfs_copy
from ea_chatbot.utils.logging import get_logger from ea_chatbot.utils.logging import get_logger
from ea_chatbot.config import Settings from ea_chatbot.config import Settings
if TYPE_CHECKING: if TYPE_CHECKING:
from ea_chatbot.types import DBSettings from ea_chatbot.types import DBSettings
def executor_node(state: AgentState) -> dict: def executor_node(state: WorkerState) -> dict:
"""Execute the Python code and capture output, plots, and dataframes.""" """Execute the Python code in the context of the Data Analyst worker."""
code = state.get("code") code = state.get("code")
logger = get_logger("executor") logger = get_logger("data_analyst_worker:executor")
if not code: if not code:
logger.error("No code provided to executor.") logger.error("No code provided to executor.")
@@ -37,45 +39,40 @@ def executor_node(state: AgentState) -> dict:
db_client = DBClient(settings=db_settings) db_client = DBClient(settings=db_settings)
# Initialize the Virtual File System (VFS) helper with the snapshot from state
vfs_state = safe_vfs_copy(state.get("vfs_state", {}))
vfs_helper = VFSHelper(vfs_state)
# Initialize local variables for execution # Initialize local variables for execution
# 'db' is the DBClient instance, 'plots' is for matplotlib figures
local_vars = { local_vars = {
'db': db_client, 'db': db_client,
'plots': [], 'plots': [],
'pd': pd 'pd': pd,
'vfs': vfs_helper
} }
stdout_buffer = io.StringIO() stdout_buffer = io.StringIO()
error = None error = None
code_output = "" output = ""
plots = [] plots = []
dfs = {}
try: try:
with redirect_stdout(stdout_buffer): with redirect_stdout(stdout_buffer):
# Execute the code in the context of local_vars # Execute the code in the context of local_vars
exec(code, {}, local_vars) exec(code, {}, local_vars)
code_output = stdout_buffer.getvalue() output = stdout_buffer.getvalue()
# Limit the output length if it's too long # Limit the output length if it's too long
if code_output.count('\n') > 32: if output.count('\n') > 32:
code_output = '\n'.join(code_output.split('\n')[:32]) + '\n...' output = '\n'.join(output.split('\n')[:32]) + '\n...'
# Extract plots # Extract plots
raw_plots = local_vars.get('plots', []) raw_plots = local_vars.get('plots', [])
if isinstance(raw_plots, list): if isinstance(raw_plots, list):
plots = [p for p in raw_plots if isinstance(p, Figure)] plots = [p for p in raw_plots if isinstance(p, Figure)]
# Extract DataFrames that were likely intended for display logger.info(f"[bold green]Execution complete.[/bold green] Captured {len(plots)} plots.")
# We look for DataFrames in local_vars that were mentioned in the code
for key, value in local_vars.items():
if isinstance(value, pd.DataFrame):
# Heuristic: if the variable name is in the code, it might be a result DF
if key in code:
dfs[key] = value
logger.info(f"[bold green]Execution complete.[/bold green] Captured {len(plots)} plots and {len(dfs)} dataframes.")
except Exception as e: except Exception as e:
# Capture the traceback # Capture the traceback
@@ -90,13 +87,11 @@ def executor_node(state: AgentState) -> dict:
error += f"{exc_type.__name__ if exc_type else 'Exception'}: {exc_value}" error += f"{exc_type.__name__ if exc_type else 'Exception'}: {exc_value}"
logger.error(f"Execution failed: {str(e)}") logger.error(f"Execution failed: {str(e)}")
output = stdout_buffer.getvalue()
# If we have an error, we still might want to see partial stdout
code_output = stdout_buffer.getvalue()
return { return {
"code_output": code_output, "output": output,
"error": error, "error": error,
"plots": plots, "plots": plots,
"dfs": dfs "vfs_state": vfs_state
} }

View File

@@ -0,0 +1,53 @@
from typing import Dict, Any, List, Optional
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
def summarizer_node(state: WorkerState) -> dict:
"""Summarize the data analysis results for the Orchestrator."""
task = state["task"]
output = state.get("output", "")
error = state.get("error")
plots = state.get("plots", [])
vfs_state = state.get("vfs_state", {})
settings = Settings()
logger = get_logger("data_analyst_worker:summarizer")
logger.info("Summarizing analysis results for the Orchestrator...")
# Artifact summary
artifact_info = ""
if plots:
artifact_info += f"- Generated {len(plots)} plot(s).\n"
if vfs_state:
artifact_info += "- VFS Artifacts: " + ", ".join(vfs_state.keys()) + "\n"
# We can use a smaller/faster model for this summary if needed
llm = get_llm_model(
settings.planner_llm, # Using planner model for summary logic
callbacks=[LangChainLoggingHandler(logger=logger)]
)
prompt = f"""You are a data analyst sub-agent. You have completed a sub-task.
Task: {task}
Execution Results: {output}
Error Log (if any): {error}
{artifact_info}
Provide a concise summary of the findings or status for the top-level Orchestrator.
If plots or data files were generated, mention them.
If the execution failed after multiple retries, explain why concisely.
Do NOT include the raw Python code, just the results of the analysis."""
try:
response = llm.invoke(prompt)
result = response.content if hasattr(response, "content") else str(response)
logger.info("[bold green]Analysis results summarized.[/bold green]")
return {
"result": result
}
except Exception as e:
logger.error(f"Failed to summarize results: {str(e)}")
raise e

View File

@@ -0,0 +1,28 @@
from typing import TypedDict, List, Dict, Any, Optional
from langchain_core.messages import BaseMessage
class WorkerState(TypedDict):
"""Internal state for the Data Analyst worker subgraph."""
# Internal worker conversation (not bubbled up to global unless summary)
messages: List[BaseMessage]
# The specific sub-task assigned by the Orchestrator
task: str
# Generated code and execution context
code: Optional[str]
output: Optional[str]
error: Optional[str]
# Number of internal retry attempts
iterations: int
# Temporary storage for analysis results
plots: List[Any]
# Isolated/Snapshot view of the VFS
vfs_state: Dict[str, Any]
# The final summary or result to return to the Orchestrator
result: Optional[str]

View File

@@ -0,0 +1,51 @@
from langgraph.graph import StateGraph, END
from langgraph.graph.state import CompiledStateGraph
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
from ea_chatbot.graph.workers.data_analyst.nodes.coder import coder_node
from ea_chatbot.graph.workers.data_analyst.nodes.executor import executor_node
from ea_chatbot.graph.workers.data_analyst.nodes.summarizer import summarizer_node
def router(state: WorkerState) -> str:
"""Routes the subgraph between coding, execution, and summarization."""
error = state.get("error")
iterations = state.get("iterations", 0)
if error and iterations < 3:
# Retry with error correction
return "coder"
# Either success or max retries reached
return "summarizer"
def create_data_analyst_worker(
coder=coder_node,
executor=executor_node,
summarizer=summarizer_node
) -> CompiledStateGraph:
"""Create the Data Analyst worker subgraph."""
workflow = StateGraph(WorkerState)
# Add Nodes
workflow.add_node("coder", coder)
workflow.add_node("executor", executor)
workflow.add_node("summarizer", summarizer)
# Set entry point
workflow.set_entry_point("coder")
# Add Edges
workflow.add_edge("coder", "executor")
# Add Conditional Edges
workflow.add_conditional_edges(
"executor",
router,
{
"coder": "coder",
"summarizer": "summarizer"
}
)
workflow.add_edge("summarizer", END)
return workflow.compile()

View File

@@ -0,0 +1,32 @@
from typing import Dict, Any, List
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.state import AgentState
from ea_chatbot.graph.workers.researcher.state import WorkerState
def prepare_researcher_input(state: AgentState) -> Dict[str, Any]:
"""Prepare the initial state for the Researcher worker."""
checklist = state.get("checklist", [])
current_step = state.get("current_step", 0)
task_desc = "Perform research"
if 0 <= current_step < len(checklist):
task_desc = checklist[current_step].get("task", task_desc)
return {
"task": task_desc,
"messages": [HumanMessage(content=task_desc)],
"queries": [],
"raw_results": [],
"iterations": 0,
"result": None
}
def merge_researcher_output(worker_state: WorkerState) -> Dict[str, Any]:
"""Map researcher results back to the global AgentState."""
result = worker_state.get("result") or "No result produced by researcher worker."
return {
"messages": [AIMessage(content=result)],
"summary": result,
# Researcher doesn't usually update VFS or Plots, but we keep the structure
}

View File

@@ -1,24 +1,20 @@
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI from langchain_google_genai import ChatGoogleGenerativeAI
from ea_chatbot.graph.state import AgentState from ea_chatbot.graph.workers.researcher.state import WorkerState
from ea_chatbot.config import Settings from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils import helpers from ea_chatbot.utils import helpers
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.researcher import RESEARCHER_PROMPT from ea_chatbot.graph.prompts.researcher import RESEARCHER_PROMPT
def researcher_node(state: AgentState) -> dict: def searcher_node(state: WorkerState) -> dict:
"""Handle general research queries or web searches.""" """Execute web research for the specific task."""
question = state["question"] task = state["task"]
history = state.get("messages", [])[-6:] logger = get_logger("researcher_worker:searcher")
summary = state.get("summary", "")
logger.info(f"Researching task: {task[:50]}...")
settings = Settings() settings = Settings()
logger = get_logger("researcher")
logger.info(f"Researching question: [italic]\"{question}\"[/italic]")
# Use researcher_llm from settings
llm = get_llm_model( llm = get_llm_model(
settings.researcher_llm, settings.researcher_llm,
callbacks=[LangChainLoggingHandler(logger=logger)] callbacks=[LangChainLoggingHandler(logger=logger)]
@@ -26,34 +22,39 @@ def researcher_node(state: AgentState) -> dict:
date_str = helpers.get_readable_date() date_str = helpers.get_readable_date()
# Adapt the global researcher prompt for the sub-task
messages = RESEARCHER_PROMPT.format_messages( messages = RESEARCHER_PROMPT.format_messages(
date=date_str, date=date_str,
question=question, question=task,
history=history, history=[], # Worker has fresh context or task-specific history
summary=summary summary=""
) )
# Provider-aware tool binding # Tool binding
try: try:
if isinstance(llm, ChatGoogleGenerativeAI): if isinstance(llm, ChatGoogleGenerativeAI):
# Native Google Search for Gemini
llm_with_tools = llm.bind_tools([{"google_search": {}}]) llm_with_tools = llm.bind_tools([{"google_search": {}}])
elif isinstance(llm, ChatOpenAI): elif isinstance(llm, ChatOpenAI):
# Native Web Search for OpenAI (built-in tool)
llm_with_tools = llm.bind_tools([{"type": "web_search"}]) llm_with_tools = llm.bind_tools([{"type": "web_search"}])
else: else:
# Fallback for other providers that might not support these specific search tools
llm_with_tools = llm llm_with_tools = llm
except Exception as e: except Exception as e:
logger.warning(f"Failed to bind search tools: {str(e)}. Falling back to base LLM.") logger.warning(f"Failed to bind search tools: {str(e)}")
llm_with_tools = llm llm_with_tools = llm
try: try:
response = llm_with_tools.invoke(messages) response = llm_with_tools.invoke(messages)
logger.info("[bold green]Research complete.[/bold green]") logger.info("[bold green]Search complete.[/bold green]")
# In a real tool-use scenario, we'd extract the tool outputs here.
# For now, we'll store the response content as a 'raw_result'.
content = response.content if hasattr(response, "content") else str(response)
return { return {
"messages": [response] "messages": [response],
"raw_results": [content],
"iterations": state.get("iterations", 0) + 1
} }
except Exception as e: except Exception as e:
logger.error(f"Research failed: {str(e)}") logger.error(f"Search failed: {str(e)}")
raise e raise e

View File

@@ -0,0 +1,49 @@
from ea_chatbot.graph.workers.researcher.state import WorkerState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
def summarizer_node(state: WorkerState) -> dict:
"""Summarize research results for the Orchestrator."""
task = state["task"]
raw_results = state.get("raw_results", [])
logger = get_logger("researcher_worker:summarizer")
logger.info("Summarizing research results...")
settings = Settings()
llm = get_llm_model(
settings.planner_llm,
callbacks=[LangChainLoggingHandler(logger=logger)]
)
# Ensure all results are strings (Gemini/OpenAI might return complex content)
processed_results = []
for res in raw_results:
if isinstance(res, list):
processed_results.append(str(res))
else:
processed_results.append(str(res))
results_str = "\n---\n".join(processed_results)
prompt = f"""You are a Research Specialist sub-agent. You have completed a research sub-task.
Task: {task}
Raw Research Findings:
{results_str}
Provide a concise, factual summary of the findings for the top-level Orchestrator.
Ensure all key facts, dates, and sources (if provided) are preserved.
Do NOT include internal reasoning, just the factual summary."""
try:
response = llm.invoke(prompt)
result = response.content if hasattr(response, "content") else str(response)
logger.info("[bold green]Research summary complete.[/bold green]")
return {
"result": result
}
except Exception as e:
logger.error(f"Failed to summarize research: {str(e)}")
raise e

View File

@@ -0,0 +1,21 @@
from typing import TypedDict, List, Dict, Any, Optional
from langchain_core.messages import BaseMessage
class WorkerState(TypedDict):
"""Internal state for the Researcher worker subgraph."""
# Internal worker conversation
messages: List[BaseMessage]
# The specific sub-task assigned by the Orchestrator
task: str
# Search context
queries: List[str]
raw_results: List[str]
# Number of internal retry/refinement attempts
iterations: int
# The final research summary to return to the Orchestrator
result: Optional[str]

View File

@@ -0,0 +1,25 @@
from langgraph.graph import StateGraph, END
from langgraph.graph.state import CompiledStateGraph
from ea_chatbot.graph.workers.researcher.state import WorkerState
from ea_chatbot.graph.workers.researcher.nodes.searcher import searcher_node
from ea_chatbot.graph.workers.researcher.nodes.summarizer import summarizer_node
def create_researcher_worker(
searcher=searcher_node,
summarizer=summarizer_node
) -> CompiledStateGraph:
"""Create the Researcher worker subgraph."""
workflow = StateGraph(WorkerState)
# Add Nodes
workflow.add_node("searcher", searcher)
workflow.add_node("summarizer", summarizer)
# Set entry point
workflow.set_entry_point("searcher")
# Add Edges
workflow.add_edge("searcher", "summarizer")
workflow.add_edge("summarizer", END)
return workflow.compile()

View File

@@ -2,86 +2,101 @@ from langgraph.graph import StateGraph, END
from ea_chatbot.graph.state import AgentState from ea_chatbot.graph.state import AgentState
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node
from ea_chatbot.graph.nodes.planner import planner_node from ea_chatbot.graph.nodes.planner import planner_node
from ea_chatbot.graph.nodes.coder import coder_node from ea_chatbot.graph.nodes.delegate import delegate_node
from ea_chatbot.graph.nodes.error_corrector import error_corrector_node from ea_chatbot.graph.nodes.reflector import reflector_node
from ea_chatbot.graph.nodes.executor import executor_node from ea_chatbot.graph.nodes.synthesizer import synthesizer_node
from ea_chatbot.graph.nodes.summarizer import summarizer_node from ea_chatbot.graph.workers.data_analyst.workflow import create_data_analyst_worker
from ea_chatbot.graph.nodes.researcher import researcher_node from ea_chatbot.graph.workers.data_analyst.mapping import prepare_worker_input, merge_worker_output
from ea_chatbot.graph.workers.researcher.workflow import create_researcher_worker
from ea_chatbot.graph.workers.researcher.mapping import prepare_researcher_input, merge_researcher_output
from ea_chatbot.graph.nodes.clarification import clarification_node from ea_chatbot.graph.nodes.clarification import clarification_node
from ea_chatbot.graph.nodes.summarize_conversation import summarize_conversation_node from ea_chatbot.graph.nodes.summarize_conversation import summarize_conversation_node
MAX_ITERATIONS = 3 # Global cache for compiled subgraphs
_DATA_ANALYST_WORKER = create_data_analyst_worker()
_RESEARCHER_WORKER = create_researcher_worker()
def router(state: AgentState) -> str: # Define worker nodes as piped runnables to enable subgraph event propagation
"""Route to the next node based on the analysis.""" data_analyst_worker_runnable = prepare_worker_input | _DATA_ANALYST_WORKER | merge_worker_output
researcher_worker_runnable = prepare_researcher_input | _RESEARCHER_WORKER | merge_researcher_output
def main_router(state: AgentState) -> str:
"""Route from query analyzer based on initial assessment."""
next_action = state.get("next_action") next_action = state.get("next_action")
if next_action == "plan": if next_action == "clarify":
return "planner"
elif next_action == "research":
return "researcher"
elif next_action == "clarify":
return "clarification" return "clarification"
else: # Even if QA suggests 'research', we now go through 'planner' for orchestration
return END # aligning with the new hierarchical architecture.
return "planner"
def create_workflow(): def delegation_router(state: AgentState) -> str:
"""Create the LangGraph workflow.""" """Route from delegate node to specific workers or synthesis."""
next_action = state.get("next_action")
if next_action == "data_analyst":
return "data_analyst_worker"
elif next_action == "researcher":
return "researcher_worker"
elif next_action == "summarize":
return "synthesizer"
return "synthesizer"
def create_workflow(
query_analyzer=query_analyzer_node,
planner=planner_node,
delegate=delegate_node,
data_analyst_worker=data_analyst_worker_runnable,
researcher_worker=researcher_worker_runnable,
reflector=reflector_node,
synthesizer=synthesizer_node,
clarification=clarification_node,
summarize_conversation=summarize_conversation_node
):
"""Create the high-level Orchestrator workflow."""
workflow = StateGraph(AgentState) workflow = StateGraph(AgentState)
# Add nodes # Add Nodes
workflow.add_node("query_analyzer", query_analyzer_node) workflow.add_node("query_analyzer", query_analyzer)
workflow.add_node("planner", planner_node) workflow.add_node("planner", planner)
workflow.add_node("coder", coder_node) workflow.add_node("delegate", delegate)
workflow.add_node("error_corrector", error_corrector_node) workflow.add_node("data_analyst_worker", data_analyst_worker)
workflow.add_node("researcher", researcher_node) workflow.add_node("researcher_worker", researcher_worker)
workflow.add_node("clarification", clarification_node) workflow.add_node("reflector", reflector)
workflow.add_node("executor", executor_node) workflow.add_node("synthesizer", synthesizer)
workflow.add_node("summarizer", summarizer_node) workflow.add_node("clarification", clarification)
workflow.add_node("summarize_conversation", summarize_conversation_node) workflow.add_node("summarize_conversation", summarize_conversation)
# Set entry point # Set entry point
workflow.set_entry_point("query_analyzer") workflow.set_entry_point("query_analyzer")
# Add conditional edges from query_analyzer # Edges
workflow.add_conditional_edges( workflow.add_conditional_edges(
"query_analyzer", "query_analyzer",
router, main_router,
{ {
"planner": "planner",
"researcher": "researcher",
"clarification": "clarification", "clarification": "clarification",
END: END "planner": "planner"
} }
) )
# Linear flow for planning and coding workflow.add_edge("planner", "delegate")
workflow.add_edge("planner", "coder")
workflow.add_edge("coder", "executor")
# Executor routing
def executor_router(state: AgentState) -> str:
if state.get("error"):
# Check for iteration limit to prevent infinite loops
if state.get("iterations", 0) >= MAX_ITERATIONS:
return "summarizer"
return "error_corrector"
return "summarizer"
workflow.add_conditional_edges( workflow.add_conditional_edges(
"executor", "delegate",
executor_router, delegation_router,
{ {
"error_corrector": "error_corrector", "data_analyst_worker": "data_analyst_worker",
"summarizer": "summarizer" "researcher_worker": "researcher_worker",
"synthesizer": "synthesizer"
} }
) )
workflow.add_edge("error_corrector", "executor") workflow.add_edge("data_analyst_worker", "reflector")
workflow.add_edge("researcher_worker", "reflector")
workflow.add_edge("reflector", "delegate")
workflow.add_edge("researcher", "summarize_conversation") workflow.add_edge("synthesizer", "summarize_conversation")
workflow.add_edge("clarification", END)
workflow.add_edge("summarizer", "summarize_conversation")
workflow.add_edge("summarize_conversation", END) workflow.add_edge("summarize_conversation", END)
workflow.add_edge("clarification", END)
# Compile the graph # Compile the graph
app = workflow.compile() app = workflow.compile()

View File

@@ -1,7 +1,15 @@
from pydantic import BaseModel, Field, computed_field from pydantic import BaseModel, Field, computed_field
from typing import Sequence, Optional from typing import Sequence, Optional, List, Dict, Any, Literal
import re import re
class QueryAnalysis(BaseModel):
"""Analysis of the user's query."""
data_required: List[str] = Field(description="List of data points or entities mentioned (e.g., ['2024 results', 'Florida']).")
unknowns: List[str] = Field(description="List of target information the user wants to know or needed for final answer (e.g., 'who won', 'total votes').")
ambiguities: List[str] = Field(description="List of CRITICAL missing details that prevent ANY analysis. Do NOT include database names or plot types if defaults can be used.")
conditions: List[str] = Field(description="List of any filters or constraints (e.g., ['year=2024', 'state=Florida']). Include context resolved from history.")
next_action: Literal["plan", "clarify", "research"] = Field(description="The next action to take. 'plan' for data analysis (even with defaults), 'research' for general knowledge, or 'clarify' ONLY for critical ambiguities.")
class TaskPlanContext(BaseModel): class TaskPlanContext(BaseModel):
'''Background context relevant to the task plan''' '''Background context relevant to the task plan'''
initial_context: str = Field( initial_context: str = Field(
@@ -33,6 +41,22 @@ class TaskPlanResponse(BaseModel):
description="Ordered list of steps to execute that follow the 'Step <number>: <detail>' pattern.", description="Ordered list of steps to execute that follow the 'Step <number>: <detail>' pattern.",
) )
class ChecklistTask(BaseModel):
'''A specific sub-task in the high-level orchestrator plan'''
task: str = Field(description="Description of the sub-task")
worker: str = Field(description="The worker to delegate to (data_analyst or researcher)")
class ChecklistResponse(BaseModel):
'''Orchestrator's decomposed plan/checklist'''
goal: str = Field(description="Overall objective")
reflection: str = Field(description="Strategic reasoning")
checklist: List[ChecklistTask] = Field(description="Ordered list of tasks for specialized workers")
class ReflectorResponse(BaseModel):
'''Orchestrator's evaluation of worker output'''
satisfied: bool = Field(description="Whether the worker's output satisfies the sub-task requirements")
reasoning: str = Field(description="Brief explanation of the evaluation")
_IM_SEP_TOKEN_PATTERN = re.compile(re.escape("<|im_sep|>")) _IM_SEP_TOKEN_PATTERN = re.compile(re.escape("<|im_sep|>"))
_CODE_BLOCK_PATTERN = re.compile(r"```(?:python\s*)?(.*?)\s*```", re.DOTALL) _CODE_BLOCK_PATTERN = re.compile(r"```(?:python\s*)?(.*?)\s*```", re.DOTALL)
_FORBIDDEN_MODULES = ( _FORBIDDEN_MODULES = (

View File

@@ -0,0 +1,69 @@
import copy
from typing import Dict, Any, Optional, Tuple, List
from ea_chatbot.utils.logging import get_logger
logger = get_logger("utils:vfs")
def safe_vfs_copy(vfs_state: Dict[str, Any]) -> Dict[str, Any]:
"""
Perform a safe deep copy of the VFS state.
If an entry cannot be deep-copied (e.g., it contains a non-copyable object like a DB handle),
logs an error and replaces the entry with a descriptive error marker.
This prevents crashing the graph/persistence while making the failure explicit.
"""
new_vfs = {}
for filename, data in vfs_state.items():
try:
# Attempt a standard deepcopy for isolation
new_vfs[filename] = copy.deepcopy(data)
except Exception as e:
logger.error(
f"CRITICAL: VFS artifact '{filename}' is NOT copyable/serializable: {str(e)}. "
"Replacing with error placeholder to prevent graph crash."
)
# Replace with a standardized error artifact
new_vfs[filename] = {
"content": f"<ERROR: This artifact could not be persisted or copied: {str(e)}>",
"metadata": {
"type": "error",
"error": str(e),
"original_filename": filename
}
}
return new_vfs
class VFSHelper:
"""Helper class for managing in-memory Virtual File System (VFS) artifacts."""
def __init__(self, vfs_state: Dict[str, Any]):
"""Initialize with a reference to the VFS state from AgentState."""
self._vfs = vfs_state
def write(self, filename: str, content: Any, metadata: Optional[Dict[str, Any]] = None) -> None:
"""Write a file to the VFS."""
self._vfs[filename] = {
"content": content,
"metadata": metadata or {}
}
def read(self, filename: str) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]:
"""Read a file and its metadata from the VFS. Returns (None, None) if not found."""
file_data = self._vfs.get(filename)
if file_data is not None:
# Handle raw values (backwards compatibility or inconsistent schema)
if not isinstance(file_data, dict) or "content" not in file_data:
return file_data, {}
return file_data["content"], file_data.get("metadata", {})
return None, None
def list(self) -> List[str]:
"""List all filenames in the VFS."""
return list(self._vfs.keys())
def delete(self, filename: str) -> bool:
"""Delete a file from the VFS. Returns True if deleted, False if not found."""
if filename in self._vfs:
del self._vfs[filename]
return True
return False

View File

@@ -36,24 +36,25 @@ async def test_stream_agent_events_all_features():
# Stream chunk # Stream chunk
{ {
"event": "on_chat_model_stream", "event": "on_chat_model_stream",
"metadata": {"langgraph_node": "summarizer"}, "metadata": {"langgraph_node": "synthesizer"},
"data": {"chunk": AIMessage(content="Hello ")} "data": {"chunk": AIMessage(content="Hello ")}
}, },
{ {
"event": "on_chat_model_stream", "event": "on_chat_model_stream",
"metadata": {"langgraph_node": "summarizer"}, "metadata": {"langgraph_node": "synthesizer"},
"data": {"chunk": AIMessage(content="world")} "data": {"chunk": AIMessage(content="world")}
}, },
# Plot event # Plot event - with nested subgraph it might bubble up or come directly from data_analyst_worker
# Let's mock it coming from the data_analyst_worker on_chain_end
{ {
"event": "on_chain_end", "event": "on_chain_end",
"name": "executor", "name": "data_analyst_worker",
"data": {"output": {"plots": [fig]}} "data": {"output": {"plots": [fig]}}
}, },
# Final response # Final response
{ {
"event": "on_chain_end", "event": "on_chain_end",
"name": "summarizer", "name": "synthesizer",
"data": {"output": {"messages": [AIMessage(content="Hello world final")]}} "data": {"output": {"messages": [AIMessage(content="Hello world final")]}}
}, },
# Summary update # Summary update
@@ -91,7 +92,7 @@ async def test_stream_agent_events_all_features():
assert any(r.get("type") == "on_chat_model_stream" for r in results) assert any(r.get("type") == "on_chat_model_stream" for r in results)
# Verify plot was encoded # Verify plot was encoded
plot_event = next(r for r in results if r.get("name") == "executor") plot_event = next(r for r in results if r.get("name") == "data_analyst_worker")
assert "encoded_plots" in plot_event["data"] assert "encoded_plots" in plot_event["data"]
assert len(plot_event["data"]["encoded_plots"]) == 1 assert len(plot_event["data"]["encoded_plots"]) == 1

View File

@@ -123,3 +123,20 @@ def test_get_me_success(client):
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["email"] == "test@example.com" assert response.json()["email"] == "test@example.com"
assert response.json()["id"] == "123" assert response.json()["id"] == "123"
def test_get_me_rejects_refresh_token(client):
"""Test that /auth/me rejects refresh tokens for authentication."""
from ea_chatbot.api.utils import create_refresh_token
token = create_refresh_token(data={"sub": "123"})
with patch("ea_chatbot.api.dependencies.history_manager") as mock_hm:
# Even if the user exists, the dependency should reject the token type
mock_hm.get_user_by_id.return_value = User(id="123", username="test@example.com")
response = client.get(
"/api/v1/auth/me",
headers={"Authorization": f"Bearer {token}"}
)
assert response.status_code == 401
assert "Cannot use refresh token" in response.json()["detail"]

View File

@@ -19,14 +19,13 @@ def auth_header(mock_user):
yield {"Authorization": f"Bearer {token}"} yield {"Authorization": f"Bearer {token}"}
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_persistence_integration_success(auth_header, mock_user): def test_persistence_integration_success(auth_header, mock_user):
"""Test that messages and plots are persisted correctly during streaming.""" """Test that messages and plots are persisted correctly during streaming."""
mock_events = [ mock_events = [
{"event": "on_chat_model_stream", "name": "summarizer", "data": {"chunk": "Final answer"}}, {"event": "on_chat_model_stream", "metadata": {"langgraph_node": "synthesizer"}, "data": {"chunk": "Final answer"}},
{"event": "on_chain_end", "name": "summarizer", "data": {"output": {"messages": [{"content": "Final answer"}]}}}, {"event": "on_chain_end", "name": "synthesizer", "data": {"output": {"messages": [{"content": "Final answer"}]}}},
{"event": "on_chain_end", "name": "summarize_conversation", "data": {"output": {"summary": "New summary"}}} {"event": "on_chain_end", "name": "summarize_conversation", "data": {"output": {"summary": "New summary"}}}
] ]
async def mock_astream_events(*args, **kwargs): async def mock_astream_events(*args, **kwargs):
for event in mock_events: for event in mock_events:
yield event yield event

View File

@@ -1,62 +0,0 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.coder import coder_node
from ea_chatbot.graph.nodes.error_corrector import error_corrector_node
@pytest.fixture
def mock_state():
return {
"messages": [],
"question": "Show me results for New Jersey",
"plan": "Step 1: Load data\nStep 2: Filter by NJ",
"code": None,
"error": None,
"plots": [],
"dfs": {},
"next_action": "plan"
}
@patch("ea_chatbot.graph.nodes.coder.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
def test_coder_node(mock_get_summary, mock_get_llm, mock_state):
"""Test coder node generates code from plan."""
mock_get_summary.return_value = "Column: Name, Type: text"
mock_llm = MagicMock()
mock_get_llm.return_value = mock_llm
from ea_chatbot.schemas import CodeGenerationResponse
mock_response = CodeGenerationResponse(
code="import pandas as pd\nprint('Hello')",
explanation="Generated code"
)
mock_llm.with_structured_output.return_value.invoke.return_value = mock_response
result = coder_node(mock_state)
assert "code" in result
assert "import pandas as pd" in result["code"]
assert "error" in result
assert result["error"] is None
@patch("ea_chatbot.graph.nodes.error_corrector.get_llm_model")
def test_error_corrector_node(mock_get_llm, mock_state):
"""Test error corrector node fixes code."""
mock_state["code"] = "import pandas as pd\nprint(undefined_var)"
mock_state["error"] = "NameError: name 'undefined_var' is not defined"
mock_llm = MagicMock()
mock_get_llm.return_value = mock_llm
from ea_chatbot.schemas import CodeGenerationResponse
mock_response = CodeGenerationResponse(
code="import pandas as pd\nprint('Defined')",
explanation="Fixed variable"
)
mock_llm.with_structured_output.return_value.invoke.return_value = mock_response
result = error_corrector_node(mock_state)
assert "code" in result
assert "print('Defined')" in result["code"]
assert result["error"] is None

View File

@@ -0,0 +1,53 @@
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.state import AgentState
from ea_chatbot.graph.workers.data_analyst.mapping import (
prepare_worker_input,
merge_worker_output
)
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
def test_prepare_worker_input():
"""Verify that we correctly map global state to worker input."""
global_state = AgentState(
messages=[HumanMessage(content="global message")],
question="original question",
checklist=[{"task": "Worker Task", "status": "pending"}],
current_step=0,
vfs={"old.txt": "old data"},
plots=[],
dfs={},
next_action="test",
iterations=0
)
worker_input = prepare_worker_input(global_state)
assert worker_input["task"] == "Worker Task"
assert "old.txt" in worker_input["vfs_state"]
# Internal worker messages should start fresh or with the task
assert len(worker_input["messages"]) == 1
assert worker_input["messages"][0].content == "Worker Task"
def test_merge_worker_output():
"""Verify that we correctly merge worker results back to global state."""
worker_state = WorkerState(
messages=[HumanMessage(content="internal"), AIMessage(content="summary")],
task="Worker Task",
result="Finished analysis",
plots=["plot1"],
vfs_state={"new.txt": "new data"},
iterations=2
)
updates = merge_worker_output(worker_state)
# We expect the 'result' to be added as an AI message to global history
assert len(updates["messages"]) == 1
assert updates["messages"][0].content == "Finished analysis"
# VFS should be updated
assert "new.txt" in updates["vfs"]
# Plots should be bubbled up
assert len(updates["plots"]) == 1
assert updates["plots"][0] == "plot1"

View File

@@ -0,0 +1,17 @@
from typing import get_type_hints, List
from langchain_core.messages import BaseMessage
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
def test_data_analyst_worker_state():
"""Verify that DataAnalyst WorkerState has the required fields."""
hints = get_type_hints(WorkerState)
assert "messages" in hints
assert "task" in hints
assert hints["task"] == str
assert "code" in hints
assert "output" in hints
assert "error" in hints
assert "iterations" in hints
assert hints["iterations"] == int

View File

@@ -0,0 +1,81 @@
import pytest
from unittest.mock import MagicMock
from ea_chatbot.graph.workers.data_analyst.workflow import create_data_analyst_worker, WorkerState
def test_data_analyst_worker_one_shot():
"""Verify a successful one-shot execution of the worker subgraph."""
mock_coder = MagicMock()
mock_executor = MagicMock()
mock_summarizer = MagicMock()
# Scenario: Coder -> Executor (Success) -> Summarizer -> END
mock_coder.return_value = {"code": "print(1)", "error": None, "iterations": 1}
mock_executor.return_value = {"output": "1\n", "error": None, "plots": []}
mock_summarizer.return_value = {"result": "Result is 1"}
graph = create_data_analyst_worker(
coder=mock_coder,
executor=mock_executor,
summarizer=mock_summarizer
)
initial_state = WorkerState(
messages=[],
task="Calculate 1+1",
code=None,
output=None,
error=None,
iterations=0,
plots=[],
vfs_state={},
result=None
)
final_state = graph.invoke(initial_state)
assert final_state["result"] == "Result is 1"
assert mock_coder.call_count == 1
assert mock_executor.call_count == 1
assert mock_summarizer.call_count == 1
def test_data_analyst_worker_retry():
"""Verify that the worker retries on error."""
mock_coder = MagicMock()
mock_executor = MagicMock()
mock_summarizer = MagicMock()
# Scenario: Coder (1) -> Executor (Error) -> Router (coder) -> Coder (2) -> Executor (Success) -> Summarizer -> END
mock_coder.side_effect = [
{"code": "error_code", "error": None, "iterations": 1},
{"code": "fixed_code", "error": None, "iterations": 2}
]
mock_executor.side_effect = [
{"output": "", "error": "NameError", "plots": []},
{"output": "Success", "error": None, "plots": []}
]
mock_summarizer.return_value = {"result": "Fixed Result"}
graph = create_data_analyst_worker(
coder=mock_coder,
executor=mock_executor,
summarizer=mock_summarizer
)
initial_state = WorkerState(
messages=[],
task="Retry Task",
code=None,
output=None,
error=None,
iterations=0,
plots=[],
vfs_state={},
result=None
)
final_state = graph.invoke(initial_state)
assert final_state["result"] == "Fixed Result"
assert mock_coder.call_count == 2
assert mock_executor.call_count == 2
assert mock_summarizer.call_count == 1

View File

@@ -0,0 +1,79 @@
import pytest
from unittest.mock import MagicMock
from ea_chatbot.graph.workflow import create_workflow
from ea_chatbot.graph.state import AgentState
from langchain_core.messages import AIMessage, HumanMessage
def test_deepagents_multi_worker_sequential_flow():
"""Verify that the Orchestrator can handle a sequence of different workers."""
mock_analyzer = MagicMock()
mock_planner = MagicMock()
mock_delegate = MagicMock()
mock_analyst = MagicMock()
mock_researcher = MagicMock()
mock_reflector = MagicMock()
mock_synthesizer = MagicMock()
# 1. Analyzer: Plan
mock_analyzer.return_value = {"next_action": "plan"}
# 2. Planner: Two tasks
mock_planner.return_value = {
"checklist": [
{"task": "Get Numbers", "worker": "data_analyst"},
{"task": "Get Facts", "worker": "researcher"}
],
"current_step": 0
}
# 3. Delegate & Reflector: Loop through tasks
mock_delegate.side_effect = [
{"next_action": "data_analyst"}, # Step 0
{"next_action": "researcher"}, # Step 1
{"next_action": "summarize"} # Step 2 (done)
]
mock_analyst.return_value = {"messages": [AIMessage(content="Analyst summary")], "vfs": {"data.csv": "..."}}
mock_researcher.return_value = {"messages": [AIMessage(content="Researcher summary")]}
mock_reflector.side_effect = [
{"current_step": 1, "next_action": "delegate"}, # Done with Analyst
{"current_step": 2, "next_action": "delegate"} # Done with Researcher
]
mock_synthesizer.return_value = {
"messages": [AIMessage(content="Final multi-agent response")],
"next_action": "end"
}
app = create_workflow(
query_analyzer=mock_analyzer,
planner=mock_planner,
delegate=mock_delegate,
data_analyst_worker=mock_analyst,
researcher_worker=mock_researcher,
reflector=mock_reflector,
synthesizer=mock_synthesizer
)
initial_state = AgentState(
messages=[HumanMessage(content="Numbers and facts please")],
question="Numbers and facts please",
analysis={},
next_action="",
iterations=0,
checklist=[],
current_step=0,
vfs={},
plots=[],
dfs={}
)
final_state = app.invoke(initial_state, config={"recursion_limit": 30})
assert mock_analyst.called
assert mock_researcher.called
assert mock_reflector.call_count == 2
assert "Final multi-agent response" in [m.content for m in final_state["messages"]]
assert final_state["current_step"] == 2

View File

@@ -1,122 +0,0 @@
import pytest
import pandas as pd
from unittest.mock import MagicMock, patch
from matplotlib.figure import Figure
from ea_chatbot.graph.nodes.executor import executor_node
@pytest.fixture
def mock_settings():
with patch("ea_chatbot.graph.nodes.executor.Settings") as MockSettings:
mock_settings_instance = MagicMock()
mock_settings_instance.db_host = "localhost"
mock_settings_instance.db_port = 5432
mock_settings_instance.db_user = "user"
mock_settings_instance.db_pswd = "pass"
mock_settings_instance.db_name = "test_db"
mock_settings_instance.db_table = "test_table"
MockSettings.return_value = mock_settings_instance
yield mock_settings_instance
@pytest.fixture
def mock_db_client():
with patch("ea_chatbot.graph.nodes.executor.DBClient") as MockDBClient:
mock_client_instance = MagicMock()
MockDBClient.return_value = mock_client_instance
yield mock_client_instance
def test_executor_node_success_simple_print(mock_settings, mock_db_client):
"""Test executing simple code that prints to stdout."""
state = {
"code": "print('Hello, World!')",
"question": "test",
"messages": []
}
result = executor_node(state)
assert "code_output" in result
assert "Hello, World!" in result["code_output"]
assert result["error"] is None
assert result["plots"] == []
assert result["dfs"] == {}
def test_executor_node_success_dataframe(mock_settings, mock_db_client):
"""Test executing code that creates a DataFrame."""
code = """
import pandas as pd
df = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
print(df)
"""
state = {
"code": code,
"question": "test",
"messages": []
}
result = executor_node(state)
assert "code_output" in result
assert "a b" in result["code_output"] # Check part of DF string representation
assert "dfs" in result
assert "df" in result["dfs"]
assert isinstance(result["dfs"]["df"], pd.DataFrame)
def test_executor_node_success_plot(mock_settings, mock_db_client):
"""Test executing code that generates a plot."""
code = """
import matplotlib.pyplot as plt
fig = plt.figure()
plots.append(fig)
print('Plot generated')
"""
state = {
"code": code,
"question": "test",
"messages": []
}
result = executor_node(state)
assert "Plot generated" in result["code_output"]
assert "plots" in result
assert len(result["plots"]) == 1
assert isinstance(result["plots"][0], Figure)
def test_executor_node_error_syntax(mock_settings, mock_db_client):
"""Test executing code with a syntax error."""
state = {
"code": "print('Hello World", # Missing closing quote
"question": "test",
"messages": []
}
result = executor_node(state)
assert result["error"] is not None
assert "SyntaxError" in result["error"]
def test_executor_node_error_runtime(mock_settings, mock_db_client):
"""Test executing code with a runtime error."""
state = {
"code": "print(1 / 0)",
"question": "test",
"messages": []
}
result = executor_node(state)
assert result["error"] is not None
assert "ZeroDivisionError" in result["error"]
def test_executor_node_no_code(mock_settings, mock_db_client):
"""Test handling when no code is provided."""
state = {
"code": None,
"question": "test",
"messages": []
}
result = executor_node(state)
assert "error" in result
assert "No code provided" in result["error"]

View File

@@ -2,7 +2,7 @@ import json
import pytest import pytest
import logging import logging
from unittest.mock import patch from unittest.mock import patch
from ea_chatbot.graph.workflow import app from ea_chatbot.graph.workflow import create_workflow
from ea_chatbot.graph.state import AgentState from ea_chatbot.graph.state import AgentState
from ea_chatbot.utils.logging import get_logger from ea_chatbot.utils.logging import get_logger
from langchain_community.chat_models import FakeListChatModel from langchain_community.chat_models import FakeListChatModel
@@ -43,10 +43,10 @@ def test_logging_e2e_json_output(tmp_path):
"question": "Who won in 2024?", "question": "Who won in 2024?",
"analysis": None, "analysis": None,
"next_action": "", "next_action": "",
"plan": None, "iterations": 0,
"code": None, "checklist": [],
"code_output": None, "current_step": 0,
"error": None, "vfs": {},
"plots": [], "plots": [],
"dfs": {} "dfs": {}
} }
@@ -57,6 +57,20 @@ def test_logging_e2e_json_output(tmp_path):
fake_clarify = FakeListChatModel(responses=["Please specify."]) fake_clarify = FakeListChatModel(responses=["Please specify."])
# Create a test app without interrupts
# We need to manually compile it here to avoid the global 'app' which has interrupts
from langgraph.graph import StateGraph, END
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node
from ea_chatbot.graph.nodes.clarification import clarification_node
workflow = StateGraph(AgentState)
workflow.add_node("query_analyzer", query_analyzer_node)
workflow.add_node("clarification", clarification_node)
workflow.set_entry_point("query_analyzer")
workflow.add_edge("query_analyzer", "clarification")
workflow.add_edge("clarification", END)
test_app = workflow.compile()
with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_llm_factory: with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_llm_factory:
mock_llm_factory.return_value = fake_analyzer mock_llm_factory.return_value = fake_analyzer
@@ -64,7 +78,7 @@ def test_logging_e2e_json_output(tmp_path):
mock_clarify_llm_factory.return_value = fake_clarify mock_clarify_llm_factory.return_value = fake_clarify
# Run the graph # Run the graph
list(app.stream(initial_state)) test_app.invoke(initial_state)
# Verify file content # Verify file content
assert log_file.exists() assert log_file.exists()

View File

@@ -1,24 +1,27 @@
import pytest import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.nodes.planner import planner_node from ea_chatbot.graph.nodes.planner import planner_node
from ea_chatbot.graph.nodes.researcher import researcher_node from ea_chatbot.schemas import ChecklistResponse, ChecklistTask
from ea_chatbot.graph.nodes.summarizer import summarizer_node from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.schemas import TaskPlanResponse
@pytest.fixture @pytest.fixture
def mock_state_with_history(): def mock_state_with_history():
return { return {
"messages": [ "messages": [
HumanMessage(content="Show me the 2024 results for Florida"), HumanMessage(content="What about NJ?"),
AIMessage(content="Here are the results for Florida in 2024...") AIMessage(content="NJ has 9 million voters.")
], ],
"question": "What about in New Jersey?", "question": "Show me the breakdown by county for 2024",
"analysis": {"data_required": ["2024 results", "New Jersey"], "unknowns": [], "ambiguities": [], "conditions": []}, "analysis": {
"data_required": ["2024 results", "New Jersey"],
"unknowns": [],
"ambiguities": [],
"conditions": []
},
"next_action": "plan", "next_action": "plan",
"summary": "The user is asking about 2024 election results.", "summary": "The user is asking about NJ 2024 results.",
"plan": "Plan steps...", "checklist": [],
"code_output": "Code output..." "current_step": 0
} }
@patch("ea_chatbot.graph.nodes.planner.get_llm_model") @patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@@ -31,49 +34,17 @@ def test_planner_uses_history_and_summary(mock_prompt, mock_get_summary, mock_ge
mock_structured_llm = MagicMock() mock_structured_llm = MagicMock()
mock_llm_instance.with_structured_output.return_value = mock_structured_llm mock_llm_instance.with_structured_output.return_value = mock_structured_llm
mock_structured_llm.invoke.return_value = TaskPlanResponse( mock_structured_llm.invoke.return_value = ChecklistResponse(
goal="goal", goal="goal",
reflection="reflection", reflection="reflection",
context={ checklist=[ChecklistTask(task="Step 1: test", worker="data_analyst")]
"initial_context": "context",
"assumptions": [],
"constraints": []
},
steps=["Step 1: test"]
) )
planner_node(mock_state_with_history) planner_node(mock_state_with_history)
mock_prompt.format_messages.assert_called_once() # Verify history and summary were passed to prompt format
kwargs = mock_prompt.format_messages.call_args[1] # We check the arguments passed to the mock_prompt.format_messages
assert kwargs["question"] == "What about in New Jersey?" call_args = mock_prompt.format_messages.call_args[1]
assert kwargs["summary"] == mock_state_with_history["summary"] assert call_args["summary"] == "The user is asking about NJ 2024 results."
assert len(kwargs["history"]) == 2 assert len(call_args["history"]) == 2
assert "breakdown by county" in call_args["question"]
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
@patch("ea_chatbot.graph.nodes.researcher.RESEARCHER_PROMPT")
def test_researcher_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
mock_llm_instance = MagicMock()
mock_get_llm.return_value = mock_llm_instance
researcher_node(mock_state_with_history)
mock_prompt.format_messages.assert_called_once()
kwargs = mock_prompt.format_messages.call_args[1]
assert kwargs["question"] == "What about in New Jersey?"
assert kwargs["summary"] == mock_state_with_history["summary"]
assert len(kwargs["history"]) == 2
@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model")
@patch("ea_chatbot.graph.nodes.summarizer.SUMMARIZER_PROMPT")
def test_summarizer_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
mock_llm_instance = MagicMock()
mock_get_llm.return_value = mock_llm_instance
summarizer_node(mock_state_with_history)
mock_prompt.format_messages.assert_called_once()
kwargs = mock_prompt.format_messages.call_args[1]
assert kwargs["question"] == "What about in New Jersey?"
assert kwargs["summary"] == mock_state_with_history["summary"]
assert len(kwargs["history"]) == 2

View File

@@ -0,0 +1,126 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node
from ea_chatbot.graph.nodes.planner import planner_node
from ea_chatbot.graph.nodes.reflector import reflector_node
from ea_chatbot.graph.nodes.synthesizer import synthesizer_node
from ea_chatbot.graph.workers.data_analyst.nodes.coder import coder_node as analyst_coder
from ea_chatbot.graph.workers.researcher.nodes.searcher import searcher_node as researcher_searcher
from ea_chatbot.graph.state import AgentState
from ea_chatbot.graph.workers.data_analyst.state import WorkerState as AnalystState
from ea_chatbot.graph.workers.researcher.state import WorkerState as ResearcherState
from ea_chatbot.config import Settings
@pytest.fixture
def mock_settings():
settings = Settings()
settings.query_analyzer_llm.model = "model-qa"
settings.planner_llm.model = "model-planner"
settings.reflector_llm.model = "model-reflector"
settings.synthesizer_llm.model = "model-synthesizer"
settings.coder_llm.model = "model-coder"
settings.researcher_llm.model = "model-researcher"
return settings
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.query_analyzer.Settings")
def test_query_analyzer_model(mock_settings_cls, mock_get_llm, mock_settings):
mock_settings_cls.return_value = mock_settings
mock_llm = MagicMock()
mock_get_llm.return_value = mock_llm
state = AgentState(question="test", messages=[], checklist=[], current_step=0, iterations=0, vfs={}, plots=[], dfs={}, next_action="", analysis={})
try:
query_analyzer_node(state)
except:
pass # We only care about the call
# Verify it was called with the QA config
args = mock_get_llm.call_args[0][0]
assert args.model == "model-qa"
@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.graph.nodes.planner.Settings")
def test_planner_model(mock_settings_cls, mock_get_llm, mock_settings):
mock_settings_cls.return_value = mock_settings
mock_llm = MagicMock()
mock_get_llm.return_value = mock_llm
state = AgentState(question="test", messages=[], checklist=[], current_step=0, iterations=0, vfs={}, plots=[], dfs={}, next_action="", analysis={})
try:
planner_node(state)
except:
pass
args = mock_get_llm.call_args[0][0]
assert args.model == "model-planner"
@patch("ea_chatbot.graph.nodes.reflector.get_llm_model")
@patch("ea_chatbot.graph.nodes.reflector.Settings")
def test_reflector_model(mock_settings_cls, mock_get_llm, mock_settings):
mock_settings_cls.return_value = mock_settings
mock_llm = MagicMock()
mock_get_llm.return_value = mock_llm
state = AgentState(question="test", messages=[], checklist=[{"task": "T1"}], current_step=0, iterations=0, vfs={}, plots=[], dfs={}, next_action="", analysis={}, summary="Worker did X")
try:
reflector_node(state)
except:
pass
args = mock_get_llm.call_args[0][0]
assert args.model == "model-reflector"
@patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model")
@patch("ea_chatbot.graph.nodes.synthesizer.Settings")
def test_synthesizer_model(mock_settings_cls, mock_get_llm, mock_settings):
mock_settings_cls.return_value = mock_settings
mock_llm = MagicMock()
mock_get_llm.return_value = mock_llm
state = AgentState(question="test", messages=[], checklist=[], current_step=0, iterations=0, vfs={}, plots=[], dfs={}, next_action="", analysis={})
try:
synthesizer_node(state)
except:
pass
args = mock_get_llm.call_args[0][0]
assert args.model == "model-synthesizer"
@patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model")
@patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.Settings")
def test_analyst_coder_model(mock_settings_cls, mock_get_llm, mock_settings):
mock_settings_cls.return_value = mock_settings
mock_llm = MagicMock()
mock_get_llm.return_value = mock_llm
state = AnalystState(task="test", messages=[], iterations=0, vfs_state={}, plots=[], result=None, code=None, output=None, error=None)
try:
analyst_coder(state)
except:
pass
args = mock_get_llm.call_args[0][0]
assert args.model == "model-coder"
@patch("ea_chatbot.graph.workers.researcher.nodes.searcher.get_llm_model")
@patch("ea_chatbot.graph.workers.researcher.nodes.searcher.Settings")
def test_researcher_searcher_model(mock_settings_cls, mock_get_llm, mock_settings):
mock_settings_cls.return_value = mock_settings
mock_llm = MagicMock()
mock_get_llm.return_value = mock_llm
state = ResearcherState(task="test", messages=[], iterations=0, result=None, queries=[], raw_results=[])
try:
researcher_searcher(state)
except:
pass
args = mock_get_llm.call_args[0][0]
assert args.model == "model-researcher"

View File

@@ -0,0 +1,56 @@
from ea_chatbot.graph.nodes.delegate import delegate_node
from ea_chatbot.graph.state import AgentState
def test_delegate_node_data_analyst():
"""Verify that the delegate node routes to data_analyst."""
state = AgentState(
checklist=[{"task": "Analyze data", "worker": "data_analyst"}],
current_step=0,
messages=[],
question="test",
analysis={},
next_action="",
iterations=0,
vfs={},
plots=[],
dfs={}
)
result = delegate_node(state)
assert result["next_action"] == "data_analyst"
def test_delegate_node_researcher():
"""Verify that the delegate node routes to researcher."""
state = AgentState(
checklist=[{"task": "Search web", "worker": "researcher"}],
current_step=0,
messages=[],
question="test",
analysis={},
next_action="",
iterations=0,
vfs={},
plots=[],
dfs={}
)
result = delegate_node(state)
assert result["next_action"] == "researcher"
def test_delegate_node_finished():
"""Verify that the delegate node routes to summarize if checklist is complete."""
state = AgentState(
checklist=[{"task": "Task 1", "worker": "data_analyst"}],
current_step=1, # Already finished the only task
messages=[],
question="test",
analysis={},
next_action="",
iterations=0,
vfs={},
plots=[],
dfs={}
)
result = delegate_node(state)
assert result["next_action"] == "summarize"

View File

@@ -0,0 +1,136 @@
import pytest
from unittest.mock import MagicMock
from ea_chatbot.graph.workflow import create_workflow
from ea_chatbot.graph.state import AgentState
from langchain_core.messages import AIMessage, HumanMessage
def test_orchestrator_full_flow():
"""Verify the full Orchestrator-Workers flow via direct node injection."""
mock_analyzer = MagicMock()
mock_planner = MagicMock()
mock_delegate = MagicMock()
mock_worker = MagicMock()
mock_reflector = MagicMock()
mock_synthesizer = MagicMock()
mock_summarize_conv = MagicMock()
# 1. Analyzer: Proceed to planning
mock_analyzer.return_value = {"next_action": "plan"}
# 2. Planner: Generate checklist
mock_planner.return_value = {
"checklist": [{"task": "T1", "worker": "data_analyst"}],
"current_step": 0
}
# 3. Delegate: Route to data_analyst
mock_delegate.side_effect = [
{"next_action": "data_analyst"}, # First call
{"next_action": "summarize"} # Second call (after reflector)
]
# 4. Worker: Success
mock_worker.return_value = {
"messages": [AIMessage(content="Worker result")],
"vfs": {"res.txt": "data"}
}
# 5. Reflector: Advance
mock_reflector.return_value = {
"current_step": 1,
"next_action": "delegate"
}
# 6. Synthesizer: Final answer
mock_synthesizer.return_value = {
"messages": [AIMessage(content="Final synthesized answer")],
"next_action": "end"
}
# 7. Summarize Conv: End
mock_summarize_conv.return_value = {"summary": "Done"}
# Create workflow with injected mocks
app = create_workflow(
query_analyzer=mock_analyzer,
planner=mock_planner,
delegate=mock_delegate,
data_analyst_worker=mock_worker,
reflector=mock_reflector,
synthesizer=mock_synthesizer,
summarize_conversation=mock_summarize_conv
)
initial_state = AgentState(
messages=[HumanMessage(content="Explain results")],
question="Explain results",
analysis={},
next_action="",
iterations=0,
checklist=[],
current_step=0,
vfs={},
plots=[],
dfs={}
)
final_state = app.invoke(initial_state)
assert mock_analyzer.called
assert mock_planner.called
assert mock_delegate.call_count == 2
assert mock_worker.called
assert mock_reflector.called
assert mock_synthesizer.called
assert "Final synthesized answer" in [m.content for m in final_state["messages"]]
def test_orchestrator_researcher_flow():
"""Verify that the Orchestrator can route to the researcher worker."""
mock_analyzer = MagicMock()
mock_planner = MagicMock()
mock_delegate = MagicMock()
mock_researcher = MagicMock()
mock_reflector = MagicMock()
mock_synthesizer = MagicMock()
mock_analyzer.return_value = {"next_action": "plan"}
mock_planner.return_value = {
"checklist": [{"task": "Search news", "worker": "researcher"}],
"current_step": 0
}
mock_delegate.side_effect = [
{"next_action": "researcher"},
{"next_action": "summarize"}
]
mock_researcher.return_value = {"messages": [AIMessage(content="News found")]}
mock_reflector.return_value = {"current_step": 1, "next_action": "delegate"}
mock_synthesizer.return_value = {"messages": [AIMessage(content="Final News Summary")], "next_action": "end"}
app = create_workflow(
query_analyzer=mock_analyzer,
planner=mock_planner,
delegate=mock_delegate,
researcher_worker=mock_researcher,
reflector=mock_reflector,
synthesizer=mock_synthesizer
)
initial_state = AgentState(
messages=[HumanMessage(content="What's the news?")],
question="What's the news?",
analysis={},
next_action="",
iterations=0,
checklist=[],
current_step=0,
vfs={},
plots=[],
dfs={}
)
final_state = app.invoke(initial_state)
assert mock_researcher.called
assert "Final News Summary" in [m.content for m in final_state["messages"]]

View File

@@ -0,0 +1,34 @@
from typing import get_type_hints, List
from ea_chatbot.graph.nodes.planner import planner_node
from ea_chatbot.graph.state import AgentState
def test_planner_node_checklist():
"""Verify that the planner node generates a checklist."""
state = AgentState(
messages=[],
question="How many voters are in Florida and what is the current news?",
analysis={"requires_dataset": True},
next_action="plan",
iterations=0,
checklist=[],
current_step=0,
vfs={},
plots=[],
dfs={}
)
# Mocking the LLM would be ideal, but for now we'll check the returned keys
# and assume the implementation provides them.
# In a real TDD, we'd mock the LLM to return a specific structure.
# For now, let's assume the task is to update 'planner_node' to return these keys.
result = planner_node(state)
assert "checklist" in result
assert isinstance(result["checklist"], list)
assert len(result["checklist"]) > 0
assert "task" in result["checklist"][0]
assert "worker" in result["checklist"][0] # 'data_analyst' or 'researcher'
assert "current_step" in result
assert result["current_step"] == 0

View File

@@ -0,0 +1,30 @@
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.reflector import reflector_node
from ea_chatbot.graph.state import AgentState
def test_reflector_node_satisfied():
"""Verify that the reflector node increments current_step when satisfied."""
state = AgentState(
checklist=[{"task": "Analyze votes", "worker": "data_analyst"}],
current_step=0,
messages=[],
question="test",
analysis={},
next_action="",
iterations=0,
vfs={},
plots=[],
dfs={},
summary="Worker successfully analyzed votes and found 5 million."
)
# Mocking the LLM to return 'satisfied=True'
with patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_get_llm:
mock_llm = MagicMock()
mock_llm.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=True, reasoning="Good.")
mock_get_llm.return_value = mock_llm
result = reflector_node(state)
assert result["current_step"] == 1
assert result["next_action"] == "delegate" # Route back to delegate for next task

View File

@@ -0,0 +1,31 @@
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.synthesizer import synthesizer_node
from ea_chatbot.graph.state import AgentState
from langchain_core.messages import AIMessage
def test_synthesizer_node_success():
"""Verify that the synthesizer node produces a final response."""
state = AgentState(
messages=[AIMessage(content="Worker 1 found data."), AIMessage(content="Worker 2 searched web.")],
question="What are the results?",
checklist=[],
current_step=0,
iterations=0,
vfs={},
plots=[],
dfs={},
next_action="",
analysis={}
)
# Mocking the LLM
with patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model") as mock_get_llm:
mock_llm = MagicMock()
mock_llm.invoke.return_value = AIMessage(content="Final synthesized answer.")
mock_get_llm.return_value = mock_llm
result = synthesizer_node(state)
assert len(result["messages"]) == 1
assert result["messages"][0].content == "Final synthesized answer."
assert result["next_action"] == "end"

View File

@@ -1,6 +1,7 @@
import pytest import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.planner import planner_node from ea_chatbot.graph.nodes.planner import planner_node
from ea_chatbot.schemas import ChecklistResponse, ChecklistTask
@pytest.fixture @pytest.fixture
def mock_state(): def mock_state():
@@ -8,39 +9,40 @@ def mock_state():
"messages": [], "messages": [],
"question": "Show me results for New Jersey", "question": "Show me results for New Jersey",
"analysis": { "analysis": {
# "requires_dataset" removed as it's no longer used
"expert": "Data Analyst", "expert": "Data Analyst",
"data": "NJ data", "data": "NJ data",
"unknown": "results", "unknown": "results",
"condition": "state=NJ" "condition": "state=NJ"
}, },
"next_action": "plan", "next_action": "plan",
"plan": None "checklist": [],
"current_step": 0
} }
@patch("ea_chatbot.graph.nodes.planner.get_llm_model") @patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary") @patch("ea_chatbot.utils.database_inspection.get_data_summary")
def test_planner_node(mock_get_summary, mock_get_llm, mock_state): def test_planner_node(mock_get_summary, mock_get_llm, mock_state):
"""Test planner node with unified prompt.""" """Test planner node with checklist generation."""
mock_get_summary.return_value = "Column: Name, Type: text" mock_get_summary.return_value = "Column: Name, Type: text"
mock_llm = MagicMock() mock_llm = MagicMock()
mock_get_llm.return_value = mock_llm mock_get_llm.return_value = mock_llm
from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext mock_response = ChecklistResponse(
mock_plan = TaskPlanResponse(
goal="Get NJ results", goal="Get NJ results",
reflection="The user wants NJ results", reflection="The user wants NJ results",
context=TaskPlanContext(initial_context="NJ data", assumptions=[], constraints=[]), checklist=[
steps=["Step 1: Load data", "Step 2: Filter by NJ"] ChecklistTask(task="Query NJ data", worker="data_analyst")
]
) )
mock_llm.with_structured_output.return_value.invoke.return_value = mock_plan mock_llm.with_structured_output.return_value.invoke.return_value = mock_response
result = planner_node(mock_state) result = planner_node(mock_state)
assert "plan" in result assert "checklist" in result
assert "Step 1: Load data" in result["plan"] assert result["checklist"][0]["task"] == "Query NJ data"
assert "Step 2: Filter by NJ" in result["plan"] assert result["current_step"] == 0
assert result["summary"] == "The user wants NJ results"
# Verify helper was called # Verify helper was called
mock_get_summary.assert_called_once() mock_get_summary.assert_called_once()

View File

@@ -1,34 +0,0 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from ea_chatbot.graph.nodes.researcher import researcher_node
@pytest.fixture
def mock_llm():
with patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_get_llm:
mock_llm_instance = MagicMock(spec=ChatOpenAI)
mock_get_llm.return_value = mock_llm_instance
yield mock_llm_instance
def test_researcher_node_success(mock_llm):
"""Test that researcher_node invokes LLM with web_search tool and returns messages."""
state = {
"question": "What is the capital of France?",
"messages": []
}
mock_llm_with_tools = MagicMock()
mock_llm.bind_tools.return_value = mock_llm_with_tools
mock_llm_with_tools.invoke.return_value = AIMessage(content="The capital of France is Paris.")
result = researcher_node(state)
assert mock_llm.bind_tools.called
# Check that it was called with web_search
args, kwargs = mock_llm.bind_tools.call_args
assert {"type": "web_search"} in args[0]
assert mock_llm_with_tools.invoke.called
assert "messages" in result
assert result["messages"][0].content == "The capital of France is Paris."

View File

@@ -1,62 +0,0 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from ea_chatbot.graph.nodes.researcher import researcher_node
@pytest.fixture
def base_state():
return {
"question": "Who won the 2024 election?",
"messages": [],
"summary": ""
}
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_openai_search(mock_get_llm, base_state):
"""Test that OpenAI LLM binds 'web_search' tool."""
mock_llm = MagicMock(spec=ChatOpenAI)
mock_get_llm.return_value = mock_llm
mock_llm_with_tools = MagicMock()
mock_llm.bind_tools.return_value = mock_llm_with_tools
mock_llm_with_tools.invoke.return_value = AIMessage(content="OpenAI Search Result")
result = researcher_node(base_state)
# Verify bind_tools called with correct OpenAI tool
mock_llm.bind_tools.assert_called_once_with([{"type": "web_search"}])
assert result["messages"][0].content == "OpenAI Search Result"
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_google_search(mock_get_llm, base_state):
"""Test that Google LLM binds 'google_search' tool."""
mock_llm = MagicMock(spec=ChatGoogleGenerativeAI)
mock_get_llm.return_value = mock_llm
mock_llm_with_tools = MagicMock()
mock_llm.bind_tools.return_value = mock_llm_with_tools
mock_llm_with_tools.invoke.return_value = AIMessage(content="Google Search Result")
result = researcher_node(base_state)
# Verify bind_tools called with correct Google tool
mock_llm.bind_tools.assert_called_once_with([{"google_search": {}}])
assert result["messages"][0].content == "Google Search Result"
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_fallback_on_bind_error(mock_get_llm, base_state):
"""Test that researcher falls back to basic LLM if bind_tools fails."""
mock_llm = MagicMock(spec=ChatOpenAI)
mock_get_llm.return_value = mock_llm
# Simulate bind_tools failing (e.g. model doesn't support it)
mock_llm.bind_tools.side_effect = Exception("Not supported")
mock_llm.invoke.return_value = AIMessage(content="Basic Result")
result = researcher_node(base_state)
# Should still succeed using the base LLM
assert result["messages"][0].content == "Basic Result"
mock_llm.invoke.assert_called_once()

View File

@@ -0,0 +1,13 @@
from typing import get_type_hints, List
from ea_chatbot.graph.workers.researcher.state import WorkerState
def test_researcher_worker_state():
"""Verify that Researcher WorkerState has the required fields."""
hints = get_type_hints(WorkerState)
assert "messages" in hints
assert "task" in hints
assert "queries" in hints
assert "raw_results" in hints
assert "iterations" in hints
assert "result" in hints

View File

@@ -0,0 +1,67 @@
import pytest
from unittest.mock import MagicMock
from ea_chatbot.graph.workers.researcher.workflow import create_researcher_worker, WorkerState
from ea_chatbot.graph.workers.researcher.mapping import prepare_researcher_input, merge_researcher_output
from ea_chatbot.graph.state import AgentState
from langchain_core.messages import AIMessage
def test_researcher_worker_flow():
"""Verify that the Researcher worker flow works as expected."""
mock_searcher = MagicMock()
mock_summarizer = MagicMock()
mock_searcher.return_value = {
"raw_results": ["Result A"],
"messages": [AIMessage(content="Search result")]
}
mock_summarizer.return_value = {"result": "Consolidated Summary"}
graph = create_researcher_worker(
searcher=mock_searcher,
summarizer=mock_summarizer
)
initial_state = WorkerState(
messages=[],
task="Find governor",
queries=[],
raw_results=[],
iterations=0,
result=None
)
final_state = graph.invoke(initial_state)
assert final_state["result"] == "Consolidated Summary"
assert mock_searcher.called
assert mock_summarizer.called
def test_researcher_mapping():
"""Verify that we correctly map states for the researcher."""
global_state = AgentState(
checklist=[{"task": "Search X", "worker": "researcher"}],
current_step=0,
messages=[],
question="test",
analysis={},
next_action="",
iterations=0,
vfs={},
plots=[],
dfs={}
)
worker_input = prepare_researcher_input(global_state)
assert worker_input["task"] == "Search X"
worker_output = WorkerState(
messages=[],
task="Search X",
queries=[],
raw_results=[],
iterations=1,
result="Found X"
)
updates = merge_researcher_output(worker_output)
assert updates["messages"][0].content == "Found X"

View File

@@ -0,0 +1,55 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.workflow import create_workflow
from ea_chatbot.graph.state import AgentState
from langchain_core.messages import HumanMessage, AIMessage
def test_clarification_flow_immediate_execution():
"""Verify that an ambiguous query immediately executes the clarification node without interruption."""
mock_analyzer = MagicMock()
mock_clarification = MagicMock()
# 1. Analyzer returns 'clarify'
mock_analyzer.return_value = {"next_action": "clarify"}
# 2. Clarification node returns a question
mock_clarification.return_value = {"messages": [AIMessage(content="What year?")]}
# Create workflow without other nodes since they won't be reached
# We still need to provide mock planners etc. to create_workflow
app = create_workflow(
query_analyzer=mock_analyzer,
clarification=mock_clarification,
planner=MagicMock(),
delegate=MagicMock(),
data_analyst_worker=MagicMock(),
researcher_worker=MagicMock(),
reflector=MagicMock(),
synthesizer=MagicMock(),
summarize_conversation=MagicMock()
)
initial_state = AgentState(
messages=[HumanMessage(content="Who won?")],
question="Who won?",
analysis={},
next_action="",
iterations=0,
checklist=[],
current_step=0,
vfs={},
plots=[],
dfs={}
)
# Run the graph
final_state = app.invoke(initial_state)
# Assertions
assert mock_analyzer.called
assert mock_clarification.called
# Verify the state contains the clarification message
assert len(final_state["messages"]) > 0
assert "What year?" in [m.content for m in final_state["messages"]]

View File

@@ -0,0 +1,48 @@
import pytest
from ea_chatbot.graph.workflow import create_workflow, data_analyst_worker_runnable
from ea_chatbot.graph.state import AgentState
from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage
def test_worker_merge_sets_summary_for_reflector(monkeypatch):
"""Verify that worker node (runnable) sets the 'summary' field for the Reflector."""
state = AgentState(
messages=[HumanMessage(content="test")],
question="test",
checklist=[{"task": "Analyze data", "worker": "data_analyst"}],
current_step=0,
iterations=0,
vfs={},
plots=[],
dfs={},
next_action="",
analysis={},
summary="Initial Planner Summary" # Stale summary
)
# Create a mock for the invoke method
mock_invoke = MagicMock()
mock_invoke.return_value = {
"summary": "Actual Worker Findings",
"messages": [AIMessage(content="Actual Worker Findings")],
"vfs": {},
"plots": []
}
# Manually replace the runnable with a mock object that has an invoke method
mock_runnable = MagicMock()
mock_runnable.invoke = mock_invoke
monkeypatch.setattr("ea_chatbot.graph.workflow.data_analyst_worker_runnable", mock_runnable)
# Execute via the module reference (which is now mocked)
from ea_chatbot.graph.workflow import data_analyst_worker_runnable
updates = data_analyst_worker_runnable.invoke(state)
# Verify that 'summary' is in updates and has the worker result
assert "summary" in updates
assert updates["summary"] == "Actual Worker Findings"
# When applied to state, it should overwrite the stale summary
state.update(updates)
assert state["summary"] == "Actual Worker Findings"

View File

@@ -0,0 +1,35 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.reflector import reflector_node
from ea_chatbot.graph.state import AgentState
def test_reflector_does_not_advance_on_failure():
"""Verify that reflector does not increment current_step if not satisfied."""
state = AgentState(
checklist=[{"task": "Task 1", "worker": "data_analyst"}],
current_step=0,
messages=[],
question="test",
analysis={},
next_action="",
iterations=0,
vfs={},
plots=[],
dfs={},
summary="Failed output"
)
with patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_get_llm:
mock_llm = MagicMock()
# Mark as NOT satisfied
mock_llm.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=False, reasoning="Incomplete.")
mock_get_llm.return_value = mock_llm
result = reflector_node(state)
# Should NOT increment (therefore not in the updates dict)
assert "current_step" not in result
# Should increment iterations
assert result["iterations"] == 1
# Should probably route to planner or retry
assert result["next_action"] == "delegate" # Or 'planner' if we want re-planning

View File

@@ -0,0 +1,51 @@
import pytest
from ea_chatbot.graph.workers.data_analyst.mapping import merge_worker_output as merge_analyst
from ea_chatbot.graph.workers.researcher.mapping import merge_researcher_output as merge_researcher
from ea_chatbot.graph.workers.data_analyst.state import WorkerState as AnalystState
from ea_chatbot.graph.workers.researcher.state import WorkerState as ResearcherState
def test_analyst_merge_updates_summary():
"""Verify that analyst merge updates the global summary."""
worker_state = AnalystState(
result="Actual Worker Result",
messages=[],
task="test",
iterations=1,
vfs_state={},
plots=[]
)
updates = merge_analyst(worker_state)
assert "summary" in updates
assert updates["summary"] == "Actual Worker Result"
def test_researcher_merge_updates_summary():
"""Verify that researcher merge updates the global summary."""
worker_state = ResearcherState(
result="Actual Research Result",
messages=[],
task="test",
iterations=1,
queries=[],
raw_results=[]
)
updates = merge_researcher(worker_state)
assert "summary" in updates
assert updates["summary"] == "Actual Research Result"
def test_merge_handles_none_result():
"""Verify that merge functions handle None results gracefully."""
worker_state = AnalystState(
result=None,
messages=[],
task="test",
iterations=1,
vfs_state={},
plots=[]
)
updates = merge_analyst(worker_state)
assert updates["summary"] is not None
assert isinstance(updates["messages"][0].content, str)
assert len(updates["messages"][0].content) > 0

View File

@@ -0,0 +1,85 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.workflow import create_workflow
from ea_chatbot.graph.state import AgentState
from langchain_core.messages import AIMessage, HumanMessage
def test_orchestrator_loop_retry_budget():
"""Verify that the orchestrator loop is bounded and terminates after max retries."""
mock_analyzer = MagicMock()
mock_planner = MagicMock()
mock_delegate = MagicMock()
mock_worker = MagicMock()
mock_reflector = MagicMock()
mock_synthesizer = MagicMock()
# 1. Analyzer: Proceed to planning
mock_analyzer.return_value = {"next_action": "plan"}
# 2. Planner: One task
mock_planner.return_value = {
"checklist": [{"task": "Unsolvable Task", "worker": "data_analyst"}],
"current_step": 0,
"iterations": 0
}
# We'll use the REAL delegate and reflector logic to verify the fix
# But we mock the LLM calls inside them if necessary.
# Actually, it's easier to just mock the node return values but follow the logic.
from ea_chatbot.graph.nodes.delegate import delegate_node
from ea_chatbot.graph.nodes.reflector import reflector_node
# Mocking the LLM inside reflector to always be unsatisfied
with patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_get_llm:
mock_llm = MagicMock()
# Mark as NOT satisfied
mock_llm.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=False, reasoning="Still bad.")
mock_get_llm.return_value = mock_llm
app = create_workflow(
query_analyzer=mock_analyzer,
planner=mock_planner,
# delegate=delegate_node, # Use real
data_analyst_worker=mock_worker,
# reflector=reflector_node, # Use real
synthesizer=mock_synthesizer
)
# Mock worker to return something
mock_worker.return_value = {"result": "Bad Output", "messages": [AIMessage(content="Bad")]}
mock_synthesizer.return_value = {"messages": [AIMessage(content="Failure Summary")], "next_action": "end"}
initial_state = AgentState(
messages=[HumanMessage(content="test")],
question="test",
analysis={},
next_action="",
iterations=0,
checklist=[],
current_step=0,
vfs={},
plots=[],
dfs={}
)
# Run the graph. If fix works, it should hit iterations=3 and route to synthesizer.
# We use a recursion_limit higher than our retry budget but low enough to fail fast if unbounded.
final_state = app.invoke(initial_state, config={"recursion_limit": 20})
# Assertions
# 1. We tried 3 times (iterations 0, 1, 2) and failed on 3rd.
# Wait, delegate routes to summarize when iterations >= 3.
# Reflector increments iterations.
# Loop:
# Start: it=0
# Delegate (it=0) -> Worker -> Reflector (fail, it=1) -> Delegate (it=1)
# Delegate (it=1) -> Worker -> Reflector (fail, it=2) -> Delegate (it=2)
# Delegate (it=2) -> Worker -> Reflector (fail, it=3) -> Delegate (it=3)
# Delegate (it=3) -> Summarize (it=0)
assert final_state["iterations"] == 0 # Reset in delegate or handled in synthesizer
# Check if we hit the failure summary
assert "Failed to complete task" in final_state["summary"]
assert mock_worker.call_count == 3

View File

@@ -0,0 +1,50 @@
import pytest
from ea_chatbot.utils.vfs import VFSHelper
from ea_chatbot.graph.workers.data_analyst.mapping import prepare_worker_input
from ea_chatbot.graph.state import AgentState
def test_vfs_isolation_deep_copy():
"""Verify that worker VFS is deep-copied from global state."""
global_vfs = {
"data.txt": {
"content": "original",
"metadata": {"tags": ["a"]}
}
}
state = AgentState(
checklist=[{"task": "test", "worker": "data_analyst"}],
current_step=0,
messages=[],
question="test",
analysis={},
next_action="",
iterations=0,
vfs=global_vfs,
plots=[],
dfs={}
)
worker_input = prepare_worker_input(state)
worker_vfs = worker_input["vfs_state"]
# Mutate worker VFS nested object
worker_vfs["data.txt"]["metadata"]["tags"].append("b")
# Global VFS should remain unchanged
assert global_vfs["data.txt"]["metadata"]["tags"] == ["a"]
def test_vfs_schema_normalization():
"""Verify that VFSHelper handles inconsistent VFS entries."""
vfs = {
"raw.txt": "just a string", # Inconsistent with standard Dict[str, Any] entry
"valid.txt": {"content": "data", "metadata": {}}
}
helper = VFSHelper(vfs)
# Should not crash during read
content, metadata = helper.read("raw.txt")
assert content == "just a string"
assert metadata == {}
# Should not crash during list/other ops if they assume dict
assert "raw.txt" in helper.list()

View File

@@ -0,0 +1,15 @@
from typing import get_type_hints, List, Dict, Any
from ea_chatbot.graph.state import AgentState
def test_agent_state_extensions():
"""Verify that AgentState has the new DeepAgents fields."""
hints = get_type_hints(AgentState)
assert "checklist" in hints
# checklist: List[Dict[str, Any]] (or similar for tasks)
assert "current_step" in hints
assert hints["current_step"] == int
assert "vfs" in hints
# vfs: Dict[str, Any]

View File

@@ -1,47 +0,0 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from ea_chatbot.graph.nodes.summarizer import summarizer_node
@pytest.fixture
def mock_llm():
with patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") as mock_get_llm:
mock_llm_instance = MagicMock()
mock_get_llm.return_value = mock_llm_instance
yield mock_llm_instance
def test_summarizer_node_success(mock_llm):
"""Test that summarizer_node invokes LLM with correct inputs and returns messages."""
state = {
"question": "What is the total count?",
"plan": "1. Run query\n2. Sum results",
"code_output": "The total is 100",
"messages": []
}
mock_llm.invoke.return_value = AIMessage(content="The final answer is 100.")
result = summarizer_node(state)
# Verify LLM was called
assert mock_llm.invoke.called
# Verify result structure
assert "messages" in result
assert len(result["messages"]) == 1
assert isinstance(result["messages"][0], AIMessage)
assert result["messages"][0].content == "The final answer is 100."
def test_summarizer_node_empty_state(mock_llm):
"""Test handling of empty or minimal state."""
state = {
"question": "Empty?",
"messages": []
}
mock_llm.invoke.return_value = AIMessage(content="No data provided.")
result = summarizer_node(state)
assert "messages" in result
assert result["messages"][0].content == "No data provided."

49
backend/tests/test_vfs.py Normal file
View File

@@ -0,0 +1,49 @@
import pytest
from ea_chatbot.utils.vfs import VFSHelper
def test_vfs_read_write():
"""Verify that we can write and read from the VFS."""
vfs = {}
helper = VFSHelper(vfs)
helper.write("test.txt", "Hello World", metadata={"type": "text"})
assert "test.txt" in vfs
assert vfs["test.txt"]["content"] == "Hello World"
assert vfs["test.txt"]["metadata"]["type"] == "text"
content, metadata = helper.read("test.txt")
assert content == "Hello World"
assert metadata["type"] == "text"
def test_vfs_list():
"""Verify that we can list files in the VFS."""
vfs = {}
helper = VFSHelper(vfs)
helper.write("a.txt", "content a")
helper.write("b.txt", "content b")
files = helper.list()
assert "a.txt" in files
assert "b.txt" in files
assert len(files) == 2
def test_vfs_delete():
"""Verify that we can delete files from the VFS."""
vfs = {}
helper = VFSHelper(vfs)
helper.write("test.txt", "Hello")
helper.delete("test.txt")
assert "test.txt" not in vfs
assert len(helper.list()) == 0
def test_vfs_not_found():
"""Verify that reading a non-existent file returns None."""
vfs = {}
helper = VFSHelper(vfs)
content, metadata = helper.read("missing.txt")
assert content is None
assert metadata is None

View File

@@ -0,0 +1,48 @@
import pytest
import threading
from ea_chatbot.utils.vfs import safe_vfs_copy
def test_safe_vfs_copy_success():
"""Test standard success case."""
vfs = {
"test.csv": {"content": "data", "metadata": {"type": "csv"}},
"num": 42
}
copied = safe_vfs_copy(vfs)
assert copied == vfs
assert copied is not vfs
assert copied["test.csv"] is not vfs["test.csv"]
def test_safe_vfs_copy_handles_non_copyable():
"""Test replacing uncopyable objects with error placeholders."""
# A threading.Lock is famously uncopyable
lock = threading.Lock()
vfs = {
"safe_file": "important data",
"unsafe_lock": lock
}
copied = safe_vfs_copy(vfs)
# Safe one remains
assert copied["safe_file"] == "important data"
# Unsafe one is REPLACED with an error dict
assert isinstance(copied["unsafe_lock"], dict)
assert "content" in copied["unsafe_lock"]
assert "ERROR" in copied["unsafe_lock"]["content"]
assert copied["unsafe_lock"]["metadata"]["type"] == "error"
assert "lock" in str(copied["unsafe_lock"]["metadata"]["error"]).lower()
# Original is unchanged (it was a lock)
assert vfs["unsafe_lock"] is lock
def test_safe_vfs_copy_preserves_nested_copyable():
"""Test deepcopy still works for complex but copyable objects."""
vfs = {
"data": {"a": [1, 2, 3], "b": {"c": True}}
}
copied = safe_vfs_copy(vfs)
assert copied["data"]["a"] == [1, 2, 3]
assert copied["data"]["a"] is not vfs["data"]["a"]

View File

@@ -1,92 +1,80 @@
import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from ea_chatbot.graph.workflow import app from ea_chatbot.graph.workflow import create_workflow
from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis from ea_chatbot.graph.state import AgentState
from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage def test_workflow_full_flow():
"""Test the full Orchestrator-Workers flow using node injection."""
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.graph.nodes.coder.get_llm_model")
@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model")
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
@patch("ea_chatbot.graph.nodes.executor.Settings")
@patch("ea_chatbot.graph.nodes.executor.DBClient")
def test_workflow_full_flow(mock_db_client, mock_settings, mock_get_summary, mock_researcher_llm, mock_summarizer_llm, mock_coder_llm, mock_planner_llm, mock_qa_llm):
"""Test the flow from query_analyzer through planner to coder."""
# Mock Settings for Executor mock_analyzer = MagicMock()
mock_settings_instance = MagicMock() mock_planner = MagicMock()
mock_settings_instance.db_host = "localhost" mock_delegate = MagicMock()
mock_settings_instance.db_port = 5432 mock_worker = MagicMock()
mock_settings_instance.db_user = "user" mock_reflector = MagicMock()
mock_settings_instance.db_pswd = "pass" mock_synthesizer = MagicMock()
mock_settings_instance.db_name = "test_db" mock_summarize_conv = MagicMock()
mock_settings_instance.db_table = "test_table"
mock_settings.return_value = mock_settings_instance
# Mock DBClient # 1. Analyzer: Proceed to planning
mock_client_instance = MagicMock() mock_analyzer.return_value = {"next_action": "plan"}
mock_db_client.return_value = mock_client_instance
# 1. Mock Query Analyzer # 2. Planner: Generate checklist
mock_qa_instance = MagicMock() mock_planner.return_value = {
mock_qa_llm.return_value = mock_qa_instance "checklist": [{"task": "Step 1", "worker": "data_analyst"}],
mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis( "current_step": 0
data_required=["2024 results"], }
unknowns=[],
ambiguities=[], # 3. Delegate: Route to data_analyst
conditions=[], mock_delegate.side_effect = [
next_action="plan" {"next_action": "data_analyst"},
{"next_action": "summarize"}
]
# 4. Worker: Success
mock_worker.return_value = {
"messages": [AIMessage(content="Worker Summary")],
"vfs": {}
}
# 5. Reflector: Advance
mock_reflector.return_value = {
"current_step": 1,
"next_action": "delegate"
}
# 6. Synthesizer: Final answer
mock_synthesizer.return_value = {
"messages": [AIMessage(content="Final Summary")],
"next_action": "end"
}
# 7. Summarize Conv: End
mock_summarize_conv.return_value = {"summary": "Done"}
app = create_workflow(
query_analyzer=mock_analyzer,
planner=mock_planner,
delegate=mock_delegate,
data_analyst_worker=mock_worker,
reflector=mock_reflector,
synthesizer=mock_synthesizer,
summarize_conversation=mock_summarize_conv
) )
# 2. Mock Planner
mock_planner_instance = MagicMock()
mock_planner_llm.return_value = mock_planner_instance
mock_get_summary.return_value = "Data summary"
mock_planner_instance.with_structured_output.return_value.invoke.return_value = TaskPlanResponse(
goal="Task Goal",
reflection="Reflection",
context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]),
steps=["Step 1"]
)
# 3. Mock Coder
mock_coder_instance = MagicMock()
mock_coder_llm.return_value = mock_coder_instance
mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse(
code="print('Hello')",
explanation="Explanation"
)
# 4. Mock Summarizer
mock_summarizer_instance = MagicMock()
mock_summarizer_llm.return_value = mock_summarizer_instance
mock_summarizer_instance.invoke.return_value = AIMessage(content="Summary")
# 5. Mock Researcher (just in case)
mock_researcher_instance = MagicMock()
mock_researcher_llm.return_value = mock_researcher_instance
# Initial state
initial_state = { initial_state = {
"messages": [], "messages": [HumanMessage(content="Show me results")],
"question": "Show me the 2024 results", "question": "Show me results",
"analysis": None, "analysis": None,
"next_action": "", "next_action": "",
"plan": None, "iterations": 0,
"code": None, "checklist": [],
"error": None, "current_step": 0,
"vfs": {},
"plots": [], "plots": [],
"dfs": {} "dfs": {}
} }
# Run the graph result = app.invoke(initial_state, config={"recursion_limit": 20})
# We use recursion_limit to avoid infinite loops in placeholders if any
result = app.invoke(initial_state, config={"recursion_limit": 10})
assert result["next_action"] == "plan" assert "Final Summary" in [m.content for m in result["messages"]]
assert "plan" in result and result["plan"] is not None assert result["current_step"] == 1
assert "code" in result and "print('Hello')" in result["code"]
assert "analysis" in result

View File

@@ -1,60 +1,55 @@
import pytest import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from ea_chatbot.graph.workflow import app from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis from ea_chatbot.schemas import QueryAnalysis, ChecklistResponse, ChecklistTask, CodeGenerationResponse
from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse from ea_chatbot.graph.state import AgentState
from langchain_core.messages import AIMessage
@pytest.fixture @pytest.fixture
def mock_llms(): def mock_llms():
with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_qa_llm, \ with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_qa, \
patch("ea_chatbot.graph.nodes.planner.get_llm_model") as mock_planner_llm, \ patch("ea_chatbot.graph.nodes.planner.get_llm_model") as mock_planner, \
patch("ea_chatbot.graph.nodes.coder.get_llm_model") as mock_coder_llm, \ patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model") as mock_coder, \
patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") as mock_summarizer_llm, \ patch("ea_chatbot.graph.workers.data_analyst.nodes.summarizer.get_llm_model") as mock_worker_summarizer, \
patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_researcher_llm, \ patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model") as mock_synthesizer, \
patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model") as mock_summary_llm, \ patch("ea_chatbot.graph.workers.researcher.nodes.searcher.get_llm_model") as mock_researcher, \
patch("ea_chatbot.utils.database_inspection.get_data_summary") as mock_get_summary: patch("ea_chatbot.graph.workers.researcher.nodes.summarizer.get_llm_model") as mock_res_summarizer, \
mock_get_summary.return_value = "Data summary" patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_reflector:
# Mock summary LLM to return a simple response
mock_summary_instance = MagicMock()
mock_summary_llm.return_value = mock_summary_instance
mock_summary_instance.invoke.return_value = AIMessage(content="Turn summary")
yield { yield {
"qa": mock_qa_llm, "qa": mock_qa,
"planner": mock_planner_llm, "planner": mock_planner,
"coder": mock_coder_llm, "coder": mock_coder,
"summarizer": mock_summarizer_llm, "worker_summarizer": mock_worker_summarizer,
"researcher": mock_researcher_llm, "synthesizer": mock_synthesizer,
"summary": mock_summary_llm "researcher": mock_researcher,
"res_summarizer": mock_res_summarizer,
"reflector": mock_reflector
} }
def test_workflow_data_analysis_flow(mock_llms): def test_workflow_data_analysis_flow(mock_llms):
"""Test full flow: QueryAnalyzer -> Planner -> Coder -> Executor -> Summarizer.""" """Test full flow: QueryAnalyzer -> Planner -> Delegate -> DataAnalyst -> Reflector -> Synthesizer."""
# 1. Mock Query Analyzer (routes to plan) # 1. Mock Query Analyzer
mock_qa_instance = MagicMock() mock_qa_instance = MagicMock()
mock_llms["qa"].return_value = mock_qa_instance mock_llms["qa"].return_value = mock_qa_instance
mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis( mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
data_required=["2024 results"], data_required=["2024 results"],
unknowns=[], unknowns=[],
ambiguities=[], ambiguities=[],
conditions=[], conditions=[],
next_action="plan" next_action="plan"
) )
# 2. Mock Planner # 2. Mock Planner
mock_planner_instance = MagicMock() mock_planner_instance = MagicMock()
mock_llms["planner"].return_value = mock_planner_instance mock_llms["planner"].return_value = mock_planner_instance
mock_planner_instance.with_structured_output.return_value.invoke.return_value = TaskPlanResponse( mock_planner_instance.with_structured_output.return_value.invoke.return_value = ChecklistResponse(
goal="Get results", goal="Get results",
reflection="Reflect", reflection="Reflect",
context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]), checklist=[ChecklistTask(task="Query Data", worker="data_analyst")]
steps=["Step 1"]
) )
# 3. Mock Coder # 3. Mock Coder (Worker)
mock_coder_instance = MagicMock() mock_coder_instance = MagicMock()
mock_llms["coder"].return_value = mock_coder_instance mock_llms["coder"].return_value = mock_coder_instance
mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse( mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse(
@@ -62,10 +57,20 @@ def test_workflow_data_analysis_flow(mock_llms):
explanation="Explain" explanation="Explain"
) )
# 4. Mock Summarizer # 4. Mock Worker Summarizer
mock_summarizer_instance = MagicMock() mock_ws_instance = MagicMock()
mock_llms["summarizer"].return_value = mock_summarizer_instance mock_llms["worker_summarizer"].return_value = mock_ws_instance
mock_summarizer_instance.invoke.return_value = AIMessage(content="Final Summary: Success") mock_ws_instance.invoke.return_value = AIMessage(content="Worker Summary")
# 5. Mock Reflector
mock_reflector_instance = MagicMock()
mock_llms["reflector"].return_value = mock_reflector_instance
mock_reflector_instance.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=True, reasoning="Good.")
# 6. Mock Synthesizer
mock_syn_instance = MagicMock()
mock_llms["synthesizer"].return_value = mock_syn_instance
mock_syn_instance.invoke.return_value = AIMessage(content="Final Summary: Success")
# Initial state # Initial state
initial_state = { initial_state = {
@@ -73,66 +78,77 @@ def test_workflow_data_analysis_flow(mock_llms):
"question": "Show me 2024 results", "question": "Show me 2024 results",
"analysis": None, "analysis": None,
"next_action": "", "next_action": "",
"plan": None, "iterations": 0,
"code": None, "checklist": [],
"error": None, "current_step": 0,
"vfs": {},
"plots": [], "plots": [],
"dfs": {} "dfs": {}
} }
# Run the graph # Run the graph
result = app.invoke(initial_state, config={"recursion_limit": 15}) result = app.invoke(initial_state, config={"recursion_limit": 20})
assert result["next_action"] == "plan" assert "Final Summary: Success" in [m.content for m in result["messages"]]
assert "Execution Success" in result["code_output"] assert result["current_step"] == 1
assert "Final Summary: Success" in result["messages"][-1].content
def test_workflow_research_flow(mock_llms): def test_workflow_research_flow(mock_llms):
"""Test flow: QueryAnalyzer -> Researcher -> Summarizer.""" """Test flow with research task."""
# 1. Mock Query Analyzer (routes to research) # 1. Mock Query Analyzer
mock_qa_instance = MagicMock() mock_qa_instance = MagicMock()
mock_llms["qa"].return_value = mock_qa_instance mock_llms["qa"].return_value = mock_qa_instance
mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis( mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
data_required=[], data_required=[],
unknowns=[], unknowns=[],
ambiguities=[], ambiguities=[],
conditions=[], conditions=[],
next_action="research" next_action="research"
) )
# 2. Mock Researcher # 2. Mock Planner
mock_researcher_instance = MagicMock() mock_planner_instance = MagicMock()
mock_llms["researcher"].return_value = mock_researcher_instance mock_llms["planner"].return_value = mock_planner_instance
# Researcher node uses bind_tools if it's ChatOpenAI/ChatGoogleGenerativeAI mock_planner_instance.with_structured_output.return_value.invoke.return_value = ChecklistResponse(
# Since it's a MagicMock, it will fallback to using the base instance goal="Search",
mock_researcher_instance.invoke.return_value = AIMessage(content="Research Results") reflection="Reflect",
checklist=[ChecklistTask(task="Search Web", worker="researcher")]
)
# Also mock bind_tools just in case we ever use spec # 3. Mock Researcher Searcher
mock_llm_with_tools = MagicMock() mock_res_instance = MagicMock()
mock_researcher_instance.bind_tools.return_value = mock_llm_with_tools mock_llms["researcher"].return_value = mock_res_instance
mock_llm_with_tools.invoke.return_value = AIMessage(content="Research Results") mock_res_instance.invoke.return_value = AIMessage(content="Research Result")
# 3. Mock Summarizer (not used in this flow, but kept for completeness) # 4. Mock Researcher Summarizer
mock_summarizer_instance = MagicMock() mock_rs_instance = MagicMock()
mock_llms["summarizer"].return_value = mock_summarizer_instance mock_llms["res_summarizer"].return_value = mock_rs_instance
mock_summarizer_instance.invoke.return_value = AIMessage(content="Final Summary: Research Success") mock_rs_instance.invoke.return_value = AIMessage(content="Researcher Summary")
# 5. Mock Reflector
mock_reflector_instance = MagicMock()
mock_llms["reflector"].return_value = mock_reflector_instance
mock_reflector_instance.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=True, reasoning="Good.")
# 6. Mock Synthesizer
mock_syn_instance = MagicMock()
mock_llms["synthesizer"].return_value = mock_syn_instance
mock_syn_instance.invoke.return_value = AIMessage(content="Final Research Summary")
# Initial state
initial_state = { initial_state = {
"messages": [], "messages": [],
"question": "Who is the governor of Florida?", "question": "Who is the governor?",
"analysis": None, "analysis": None,
"next_action": "", "next_action": "",
"plan": None, "iterations": 0,
"code": None, "checklist": [],
"error": None, "current_step": 0,
"vfs": {},
"plots": [], "plots": [],
"dfs": {} "dfs": {}
} }
# Run the graph result = app.invoke(initial_state, config={"recursion_limit": 20})
result = app.invoke(initial_state, config={"recursion_limit": 10})
assert result["next_action"] == "research" assert "Final Research Summary" in [m.content for m in result["messages"]]
assert "Research Results" in result["messages"][-1].content assert result["current_step"] == 1

View File

@@ -1,12 +1,12 @@
{ {
"name": "frontend", "name": "frontend",
"version": "0.0.0", "version": "0.1.0",
"lockfileVersion": 3, "lockfileVersion": 3,
"requires": true, "requires": true,
"packages": { "packages": {
"": { "": {
"name": "frontend", "name": "frontend",
"version": "0.0.0", "version": "0.1.0",
"dependencies": { "dependencies": {
"@hookform/resolvers": "^5.2.2", "@hookform/resolvers": "^5.2.2",
"@tailwindcss/vite": "^4.1.18", "@tailwindcss/vite": "^4.1.18",

View File

@@ -9,9 +9,9 @@ interface ExecutionStatusProps {
const PHASE_CONFIG = [ const PHASE_CONFIG = [
{ label: "Analyzing query...", match: "Query analysis complete." }, { label: "Analyzing query...", match: "Query analysis complete." },
{ label: "Generating strategic plan...", match: "Strategic plan generated." }, { label: "Generating high-level plan...", match: "Checklist generated." },
{ label: "Writing analysis code...", match: "Analysis code generated." }, { label: "Delegating to specialists...", match: "Task assigned." },
{ label: "Performing data analysis...", match: "Data analysis and visualization complete." } { label: "Synthesizing final answer...", match: "Final synthesis complete." }
] ]
export function ExecutionStatus({ steps, isComplete, className }: ExecutionStatusProps) { export function ExecutionStatus({ steps, isComplete, className }: ExecutionStatusProps) {

View File

@@ -1,7 +1,7 @@
import { cva } from "class-variance-authority" import { cva } from "class-variance-authority"
export const buttonVariants = cva( export const buttonVariants = cva(
"inline-flex items-center justify-center gap-2 whitespace-nowrap rounded-md text-sm font-medium transition-all disabled:pointer-events-none disabled:opacity-50 [&_svg]:pointer-events-none [&_svg:not([class*='size-'])]:size-4 shrink-0 [&_svg]:shrink-0 outline-none focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:ring-[3px] aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive", "inline-flex items-center justify-center gap-2 whitespace-nowrap rounded-md text-sm font-medium transition-all cursor-pointer disabled:pointer-events-none disabled:opacity-50 disabled:cursor-default [&_svg]:pointer-events-none [&_svg:not([class*='size-'])]:size-4 shrink-0 [&_svg]:shrink-0 outline-none focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:ring-[3px] aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive",
{ {
variants: { variants: {
variant: { variant: {

View File

@@ -0,0 +1,29 @@
import { render, screen } from "@testing-library/react"
import { describe, expect, it } from "vitest"
import { Button } from "./button"
describe("Button component", () => {
it("renders with a pointer cursor", () => {
render(<Button>Click me</Button>)
const button = screen.getByRole("button", { name: /click me/i })
expect(button).toHaveClass("cursor-pointer")
})
it("renders different variants with a pointer cursor", () => {
const { rerender } = render(<Button variant="outline">Outline</Button>)
let button = screen.getByRole("button", { name: /outline/i })
expect(button).toHaveClass("cursor-pointer")
rerender(<Button variant="destructive">Destructive</Button>)
button = screen.getByRole("button", { name: /destructive/i })
expect(button).toHaveClass("cursor-pointer")
})
it("renders with a default cursor when disabled", () => {
render(<Button disabled>Disabled Button</Button>)
const button = screen.getByRole("button", { name: /disabled button/i })
expect(button).toBeDisabled()
expect(button).toHaveClass("disabled:pointer-events-none")
expect(button).toHaveClass("disabled:cursor-default")
})
})

View File

@@ -0,0 +1,74 @@
import { render, screen } from "@testing-library/react"
import { describe, expect, it } from "vitest"
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuSub,
DropdownMenuSubTrigger,
DropdownMenuSubContent,
DropdownMenuPortal
} from "./dropdown-menu"
describe("DropdownMenu components", () => {
it("DropdownMenuItem should have cursor-pointer when enabled", () => {
render(
<DropdownMenu open={true}>
<DropdownMenuContent>
<DropdownMenuItem>Enabled Item</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
)
const item = screen.getByText("Enabled Item")
expect(item).toHaveClass("cursor-pointer")
})
it("DropdownMenuSubTrigger should have cursor-pointer when enabled", () => {
render(
<DropdownMenu open={true}>
<DropdownMenuContent>
<DropdownMenuSub>
<DropdownMenuSubTrigger>Sub Trigger</DropdownMenuSubTrigger>
<DropdownMenuPortal>
<DropdownMenuSubContent />
</DropdownMenuPortal>
</DropdownMenuSub>
</DropdownMenuContent>
</DropdownMenu>
)
const trigger = screen.getByText("Sub Trigger")
expect(trigger).toHaveClass("cursor-pointer")
})
it("DropdownMenuItem should show default cursor when disabled", () => {
render(
<DropdownMenu open={true}>
<DropdownMenuContent>
<DropdownMenuItem disabled>Disabled Item</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
)
const item = screen.getByText("Disabled Item")
// Radix sets data-disabled attribute
expect(item).toHaveAttribute("data-disabled")
expect(item).toHaveClass("data-[disabled]:cursor-default")
})
it("DropdownMenuSubTrigger should show default cursor when disabled", () => {
render(
<DropdownMenu open={true}>
<DropdownMenuContent>
<DropdownMenuSub>
<DropdownMenuSubTrigger disabled>Disabled Sub Trigger</DropdownMenuSubTrigger>
<DropdownMenuPortal>
<DropdownMenuSubContent />
</DropdownMenuPortal>
</DropdownMenuSub>
</DropdownMenuContent>
</DropdownMenu>
)
const trigger = screen.getByText("Disabled Sub Trigger")
expect(trigger).toHaveAttribute("data-disabled")
expect(trigger).toHaveClass("data-[disabled]:cursor-default")
})
})

View File

@@ -27,7 +27,7 @@ const DropdownMenuSubTrigger = React.forwardRef<
<DropdownMenuPrimitive.SubTrigger <DropdownMenuPrimitive.SubTrigger
ref={ref} ref={ref}
className={cn( className={cn(
"focus:bg-accent data-[state=open]:bg-accent flex cursor-default items-center rounded-sm px-2 py-1.5 text-sm outline-hidden select-none", "focus:bg-accent data-[state=open]:bg-accent flex cursor-pointer items-center rounded-sm px-2 py-1.5 text-sm outline-hidden select-none data-[disabled]:cursor-default data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
inset && "pl-8", inset && "pl-8",
className className
)} )}
@@ -81,7 +81,7 @@ const DropdownMenuItem = React.forwardRef<
<DropdownMenuPrimitive.Item <DropdownMenuPrimitive.Item
ref={ref} ref={ref}
className={cn( className={cn(
"focus:bg-accent focus:text-accent-foreground relative flex cursor-default items-center rounded-sm px-2 py-1.5 text-sm outline-hidden transition-colors data-[disabled]:pointer-events-none data-[disabled]:opacity-50 select-none", "focus:bg-accent focus:text-accent-foreground relative flex cursor-pointer items-center rounded-sm px-2 py-1.5 text-sm outline-hidden transition-colors data-[disabled]:pointer-events-none data-[disabled]:opacity-50 data-[disabled]:cursor-default select-none",
inset && "pl-8", inset && "pl-8",
className className
)} )}
@@ -99,7 +99,7 @@ const DropdownMenuCheckboxItem = React.forwardRef<
<DropdownMenuPrimitive.CheckboxItem <DropdownMenuPrimitive.CheckboxItem
ref={ref} ref={ref}
className={cn( className={cn(
"focus:bg-accent focus:text-accent-foreground relative flex cursor-default items-center rounded-sm py-1.5 pr-2 pl-8 text-sm outline-hidden transition-colors data-[disabled]:pointer-events-none data-[disabled]:opacity-50 select-none", "focus:bg-accent focus:text-accent-foreground relative flex cursor-pointer items-center rounded-sm py-1.5 pr-2 pl-8 text-sm outline-hidden transition-colors data-[disabled]:pointer-events-none data-[disabled]:opacity-50 data-[disabled]:cursor-default select-none",
className className
)} )}
checked={checked} checked={checked}
@@ -122,7 +122,7 @@ const DropdownMenuRadioItem = React.forwardRef<
<DropdownMenuPrimitive.RadioItem <DropdownMenuPrimitive.RadioItem
ref={ref} ref={ref}
className={cn( className={cn(
"focus:bg-accent focus:text-accent-foreground relative flex cursor-default items-center rounded-sm py-1.5 pr-2 pl-8 text-sm outline-hidden transition-colors data-[disabled]:pointer-events-none data-[disabled]:opacity-50 select-none", "focus:bg-accent focus:text-accent-foreground relative flex cursor-pointer items-center rounded-sm py-1.5 pr-2 pl-8 text-sm outline-hidden transition-colors data-[disabled]:pointer-events-none data-[disabled]:opacity-50 data-[disabled]:cursor-default select-none",
className className
)} )}
{...props} {...props}

View File

@@ -0,0 +1,83 @@
import { render, screen } from "@testing-library/react"
import { describe, expect, it } from "vitest"
import {
SidebarProvider,
SidebarMenuButton,
SidebarMenuAction,
SidebarGroupAction,
SidebarMenu,
SidebarMenuItem,
SidebarGroup
} from "./sidebar"
describe("Sidebar components", () => {
it("SidebarMenuButton should have cursor-pointer", () => {
render(
<SidebarProvider>
<SidebarMenu>
<SidebarMenuItem>
<SidebarMenuButton>Menu Button</SidebarMenuButton>
</SidebarMenuItem>
</SidebarMenu>
</SidebarProvider>
)
const button = screen.getByRole("button", { name: /menu button/i })
expect(button).toHaveClass("cursor-pointer")
})
it("SidebarMenuAction should have cursor-pointer", () => {
render(
<SidebarProvider>
<SidebarMenu>
<SidebarMenuItem>
<SidebarMenuButton>Button</SidebarMenuButton>
<SidebarMenuAction>Action</SidebarMenuAction>
</SidebarMenuItem>
</SidebarMenu>
</SidebarProvider>
)
const action = screen.getByRole("button", { name: /action/i })
expect(action).toHaveClass("cursor-pointer")
})
it("SidebarGroupAction should have cursor-pointer", () => {
render(
<SidebarProvider>
<SidebarGroup>
<SidebarGroupAction>Group Action</SidebarGroupAction>
</SidebarGroup>
</SidebarProvider>
)
const action = screen.getByRole("button", { name: /group action/i })
expect(action).toHaveClass("cursor-pointer")
})
it("SidebarGroupAction should have cursor-default when disabled", () => {
render(
<SidebarProvider>
<SidebarGroup>
<SidebarGroupAction disabled>Disabled Group Action</SidebarGroupAction>
</SidebarGroup>
</SidebarProvider>
)
const action = screen.getByRole("button", { name: /disabled group action/i })
expect(action).toHaveClass("disabled:cursor-default")
expect(action).toBeDisabled()
})
it("SidebarMenuAction should not show pointer when disabled", () => {
render(
<SidebarProvider>
<SidebarMenu>
<SidebarMenuItem>
<SidebarMenuButton>Button</SidebarMenuButton>
<SidebarMenuAction disabled>Disabled Action</SidebarMenuAction>
</SidebarMenuItem>
</SidebarMenu>
</SidebarProvider>
)
const action = screen.getByRole("button", { name: /disabled action/i })
expect(action).toHaveClass("disabled:cursor-default")
expect(action).toBeDisabled()
})
})

View File

@@ -405,7 +405,7 @@ function SidebarGroupAction({
data-slot="sidebar-group-action" data-slot="sidebar-group-action"
data-sidebar="group-action" data-sidebar="group-action"
className={cn( className={cn(
"text-sidebar-foreground ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground absolute top-3.5 right-3 flex aspect-square w-5 items-center justify-center rounded-md p-0 outline-hidden transition-transform focus-visible:ring-2 [&>svg]:size-4 [&>svg]:shrink-0", "text-sidebar-foreground ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground absolute top-3.5 right-3 flex aspect-square w-5 items-center justify-center rounded-md p-0 outline-hidden transition-transform focus-visible:ring-2 [&>svg]:size-4 [&>svg]:shrink-0 cursor-pointer disabled:cursor-default disabled:pointer-events-none disabled:opacity-50",
// Increases the hit area of the button on mobile. // Increases the hit area of the button on mobile.
"after:absolute after:-inset-2 md:after:hidden", "after:absolute after:-inset-2 md:after:hidden",
"group-data-[collapsible=icon]:hidden", "group-data-[collapsible=icon]:hidden",
@@ -477,7 +477,7 @@ function SidebarMenuButton({
data-size={size} data-size={size}
data-active={isActive} data-active={isActive}
className={cn( className={cn(
"peer/menu-button flex w-full items-center gap-2 overflow-hidden rounded-md p-2 text-left text-sm outline-hidden ring-sidebar-ring transition-[width,height,padding] hover:bg-sidebar-accent hover:text-sidebar-accent-foreground focus-visible:ring-2 active:bg-sidebar-accent active:text-sidebar-accent-foreground disabled:pointer-events-none disabled:opacity-50 group-has-data-[sidebar=menu-action]/menu-item:pr-8 aria-disabled:pointer-events-none aria-disabled:opacity-50 data-[active=true]:bg-sidebar-accent data-[active=true]:font-medium data-[active=true]:text-sidebar-accent-foreground data-[state=open]:hover:bg-sidebar-accent data-[state=open]:hover:text-sidebar-accent-foreground group-data-[collapsible=icon]:size-8! group-data-[collapsible=icon]:p-2! [&>span:last-child]:truncate [&>svg]:size-4 [&>svg]:shrink-0", "peer/menu-button flex w-full items-center gap-2 overflow-hidden rounded-md p-2 text-left text-sm outline-hidden ring-sidebar-ring transition-[width,height,padding] hover:bg-sidebar-accent hover:text-sidebar-accent-foreground focus-visible:ring-2 active:bg-sidebar-accent active:text-sidebar-accent-foreground disabled:pointer-events-none disabled:opacity-50 group-has-data-[sidebar=menu-action]/menu-item:pr-8 aria-disabled:pointer-events-none aria-disabled:opacity-50 data-[active=true]:bg-sidebar-accent data-[active=true]:font-medium data-[active=true]:text-sidebar-accent-foreground data-[state=open]:hover:bg-sidebar-accent data-[state=open]:hover:text-sidebar-accent-foreground group-data-[collapsible=icon]:size-8! group-data-[collapsible=icon]:p-2! [&>span:last-child]:truncate [&>svg]:size-4 [&>svg]:shrink-0 cursor-pointer disabled:cursor-default",
variant === "default" && "hover:bg-sidebar-accent hover:text-sidebar-accent-foreground", variant === "default" && "hover:bg-sidebar-accent hover:text-sidebar-accent-foreground",
variant === "outline" && "bg-background shadow-[0_0_0_1px_hsl(var(--sidebar-border))] hover:bg-sidebar-accent hover:text-sidebar-accent-foreground hover:shadow-[0_0_0_1px_hsl(var(--sidebar-accent))]", variant === "outline" && "bg-background shadow-[0_0_0_1px_hsl(var(--sidebar-border))] hover:bg-sidebar-accent hover:text-sidebar-accent-foreground hover:shadow-[0_0_0_1px_hsl(var(--sidebar-accent))]",
size === "default" && "h-8 text-sm", size === "default" && "h-8 text-sm",
@@ -528,7 +528,7 @@ function SidebarMenuAction({
data-slot="sidebar-menu-action" data-slot="sidebar-menu-action"
data-sidebar="menu-action" data-sidebar="menu-action"
className={cn( className={cn(
"text-sidebar-foreground ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground peer-hover/menu-button:text-sidebar-accent-foreground absolute top-1.5 right-1 flex aspect-square w-5 items-center justify-center rounded-md p-0 outline-hidden transition-transform focus-visible:ring-2 [&>svg]:size-4 [&>svg]:shrink-0", "text-sidebar-foreground ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground peer-hover/menu-button:text-sidebar-accent-foreground absolute top-1.5 right-1 flex aspect-square w-5 items-center justify-center rounded-md p-0 outline-hidden transition-transform focus-visible:ring-2 [&>svg]:size-4 [&>svg]:shrink-0 cursor-pointer disabled:cursor-default disabled:pointer-events-none disabled:opacity-50",
// Increases the hit area of the button on mobile. // Increases the hit area of the button on mobile.
"after:absolute after:-inset-2 md:after:hidden", "after:absolute after:-inset-2 md:after:hidden",
"peer-data-[size=sm]/menu-button:top-1", "peer-data-[size=sm]/menu-button:top-1",

View File

@@ -3,21 +3,21 @@ import { ChatService, type ChatEvent, type MessageResponse } from "./chat"
describe("ChatService SSE Parsing", () => { describe("ChatService SSE Parsing", () => {
it("should correctly parse a text stream chunk", () => { it("should correctly parse a text stream chunk", () => {
const rawChunk = `data: {"type": "on_chat_model_stream", "name": "summarizer", "data": {"chunk": "Hello"}}\n\n` const rawChunk = `data: {"type": "on_chat_model_stream", "name": "synthesizer", "data": {"chunk": "Hello"}}\n\n`
const events = ChatService.parseSSEChunk(rawChunk) const events = ChatService.parseSSEChunk(rawChunk)
expect(events).toHaveLength(1) expect(events).toHaveLength(1)
expect(events[0]).toEqual({ expect(events[0]).toEqual({
type: "on_chat_model_stream", type: "on_chat_model_stream",
name: "summarizer", name: "synthesizer",
data: { chunk: "Hello" } data: { chunk: "Hello" }
}) })
}) })
it("should handle multiple events in one chunk", () => { it("should handle multiple events in one chunk", () => {
const rawChunk = const rawChunk =
`data: {"type": "on_chat_model_stream", "name": "summarizer", "data": {"chunk": "Hello"}}\n\n` + `data: {"type": "on_chat_model_stream", "name": "synthesizer", "data": {"chunk": "Hello"}}\n\n` +
`data: {"type": "on_chat_model_stream", "name": "summarizer", "data": {"chunk": " World"}}\n\n` `data: {"type": "on_chat_model_stream", "name": "synthesizer", "data": {"chunk": " World"}}\n\n`
const events = ChatService.parseSSEChunk(rawChunk) const events = ChatService.parseSSEChunk(rawChunk)
@@ -25,8 +25,8 @@ describe("ChatService SSE Parsing", () => {
expect(events[1].data!.chunk).toBe(" World") expect(events[1].data!.chunk).toBe(" World")
}) })
it("should parse encoded plots from executor node", () => { it("should parse encoded plots from data_analyst_worker node", () => {
const rawChunk = `data: {"type": "on_chain_end", "name": "executor", "data": {"encoded_plots": ["base64data"]}}\n\n` const rawChunk = `data: {"type": "on_chain_end", "name": "data_analyst_worker", "data": {"encoded_plots": ["base64data"]}}\n\n`
const events = ChatService.parseSSEChunk(rawChunk) const events = ChatService.parseSSEChunk(rawChunk)
expect(events[0].data!.encoded_plots).toEqual(["base64data"]) expect(events[0].data!.encoded_plots).toEqual(["base64data"])
@@ -45,7 +45,7 @@ describe("ChatService Message State Management", () => {
const messages: MessageResponse[] = [{ id: "1", role: "assistant", content: "Initial", created_at: new Date().toISOString() }] const messages: MessageResponse[] = [{ id: "1", role: "assistant", content: "Initial", created_at: new Date().toISOString() }]
const event: ChatEvent = { const event: ChatEvent = {
type: "on_chat_model_stream", type: "on_chat_model_stream",
node: "summarizer", node: "synthesizer",
data: { chunk: { content: " text" } } data: { chunk: { content: " text" } }
} }
@@ -57,7 +57,7 @@ describe("ChatService Message State Management", () => {
const messages: MessageResponse[] = [{ id: "1", role: "assistant", content: "Analysis", created_at: new Date().toISOString(), plots: [] }] const messages: MessageResponse[] = [{ id: "1", role: "assistant", content: "Analysis", created_at: new Date().toISOString(), plots: [] }]
const event: ChatEvent = { const event: ChatEvent = {
type: "on_chain_end", type: "on_chain_end",
name: "executor", name: "data_analyst_worker",
data: { encoded_plots: ["plot1"] } data: { encoded_plots: ["plot1"] }
} }

View File

@@ -86,7 +86,8 @@ export const ChatService = {
const { type, name, node, data } = event const { type, name, node, data } = event
// 1. Handle incremental LLM chunks for terminal nodes // 1. Handle incremental LLM chunks for terminal nodes
if (type === "on_chat_model_stream" && (node === "summarizer" || node === "researcher" || node === "clarification")) { // Now using 'synthesizer' for the final user response
if (type === "on_chat_model_stream" && (node === "synthesizer" || node === "clarification")) {
const chunk = data?.chunk?.content || "" const chunk = data?.chunk?.content || ""
if (!chunk) return messages if (!chunk) return messages
@@ -110,7 +111,7 @@ export const ChatService = {
if (!lastMsg || lastMsg.role !== "assistant") return messages if (!lastMsg || lastMsg.role !== "assistant") return messages
// Terminal nodes final text // Terminal nodes final text
if (name === "summarizer" || name === "researcher" || name === "clarification") { if (name === "synthesizer" || name === "clarification") {
const messages_list = data?.output?.messages const messages_list = data?.output?.messages
const msg = messages_list ? messages_list[messages_list.length - 1]?.content : null const msg = messages_list ? messages_list[messages_list.length - 1]?.content : null
@@ -121,8 +122,8 @@ export const ChatService = {
} }
} }
// Plots from executor // Plots from data analyst worker
if (name === "executor" && data?.encoded_plots) { if (name === "data_analyst_worker" && data?.encoded_plots) {
lastMsg.plots = [...(lastMsg.plots || []), ...data.encoded_plots] lastMsg.plots = [...(lastMsg.plots || []), ...data.encoded_plots]
// Filter out the 'active' step and replace with 'complete' // Filter out the 'active' step and replace with 'complete'
const filteredSteps = (lastMsg.steps || []).filter(s => s !== "Performing data analysis..."); const filteredSteps = (lastMsg.steps || []).filter(s => s !== "Performing data analysis...");
@@ -134,15 +135,25 @@ export const ChatService = {
// Status for intermediate nodes (completion) // Status for intermediate nodes (completion)
const statusMap: Record<string, string> = { const statusMap: Record<string, string> = {
"query_analyzer": "Query analysis complete.", "query_analyzer": "Query analysis complete.",
"planner": "Strategic plan generated.", "planner": "Checklist generated.",
"coder": "Analysis code generated." "delegate": "Task assigned.",
"reflector": "Result verified.",
"coder": "Analysis code generated.",
"executor": "Code execution complete.",
"searcher": "Web search complete.",
"summarizer": "Task summary generated."
} }
if (name && statusMap[name]) { if (name && statusMap[name]) {
// Find and replace the active status if it exists // Find and replace the active status if it exists
const activeStatus = name === "query_analyzer" ? "Analyzing query..." : const activeStatus = name === "query_analyzer" ? "Analyzing query..." :
name === "planner" ? "Generating strategic plan..." : name === "planner" ? "Generating high-level plan..." :
name === "coder" ? "Writing analysis code..." : null; name === "delegate" ? "Routing task..." :
name === "reflector" ? "Evaluating results..." :
name === "coder" ? "Writing analysis code..." :
name === "executor" ? "Executing code..." :
name === "searcher" ? "Searching web..." :
name === "summarizer" ? "Summarizing results..." : null;
let filteredSteps = lastMsg.steps || []; let filteredSteps = lastMsg.steps || [];
if (activeStatus) { if (activeStatus) {
@@ -159,9 +170,13 @@ export const ChatService = {
if (type === "on_chain_start") { if (type === "on_chain_start") {
const startStatusMap: Record<string, string> = { const startStatusMap: Record<string, string> = {
"query_analyzer": "Analyzing query...", "query_analyzer": "Analyzing query...",
"planner": "Generating strategic plan...", "planner": "Generating high-level plan...",
"delegate": "Routing task...",
"reflector": "Evaluating results...",
"coder": "Writing analysis code...", "coder": "Writing analysis code...",
"executor": "Performing data analysis..." "executor": "Executing code...",
"searcher": "Searching web...",
"summarizer": "Summarizing results..."
} }
if (name && startStatusMap[name]) { if (name && startStatusMap[name]) {
@@ -170,9 +185,11 @@ export const ChatService = {
const lastMsg = { ...newMessages[lastMsgIndex] } const lastMsg = { ...newMessages[lastMsgIndex] }
if (lastMsg && lastMsg.role === "assistant") { if (lastMsg && lastMsg.role === "assistant") {
// Avoid duplicate start messages const currentSteps = lastMsg.steps || []
if (!(lastMsg.steps || []).includes(startStatusMap[name])) { // Allow duplicate steps if it's a retry (cycle),
lastMsg.steps = [...(lastMsg.steps || []), startStatusMap[name]] // but avoid spamming the same active step multiple times in a row
if (currentSteps[currentSteps.length - 1] !== startStatusMap[name]) {
lastMsg.steps = [...currentSteps, startStatusMap[name]]
newMessages[lastMsgIndex] = lastMsg newMessages[lastMsgIndex] = lastMsg
return newMessages return newMessages
} }