Compare commits: 2c44df3a5c...develop (38 commits)

Commits (SHA1): f62a70f7c3, b9084fcaef, 8fcfc4ee88, b8d8651924, 02d93120e0, 92c30d217e, 88a27f5a8d, 4d92c9aedb, 557b553c59, 2cfbc5d1d0, c5cf4b38a1, 46129c6f1e, 8eea464be4, 11c14fb8a8, 9b97140fff, 9e90f2c9ad, 5cc5bd91ae, 120b6fd11a, f4d09c07c4, ad7845cc6a, 18e4e8db7d, 9fef4888b5, 37c353a249, ff9b443bfe, 575e1a2e53, 013208b929, cb045504d1, 5324cbe851, eeb2be409b, 92d9288f38, 8957e93f3d, 45fe122580, 969165f4a7, e7be0dbeca, 322ae1e7c8, 46b57d2a73, 99ded43483, 1394c0496a
@@ -4,6 +4,7 @@ A stateful, graph-based chatbot for election data analysis, built with LangGraph
 
 ## 🚀 Features
 
+- **Multi-Agent Orchestration**: Decomposes complex queries and delegates them to specialized sub-agents (Data Analyst, Researcher) using a robust feedback loop.
 - **Intelligent Query Analysis**: Automatically determines if a query needs data analysis, web research, or clarification.
 - **Automated Data Analysis**: Generates and executes Python code to analyze election datasets and produce visualizations.
 - **Web Research**: Integrates web search capabilities for general election-related questions.
@@ -57,3 +57,11 @@ OIDC_REDIRECT_URI=http://localhost:8000/api/v1/auth/oidc/callback
 # Researcher
 # RESEARCHER_LLM__PROVIDER=google
 # RESEARCHER_LLM__MODEL=gemini-2.0-flash
+
+# Reflector
+# REFLECTOR_LLM__PROVIDER=openai
+# REFLECTOR_LLM__MODEL=gpt-5-mini
+
+# Synthesizer
+# SYNTHESIZER_LLM__PROVIDER=openai
+# SYNTHESIZER_LLM__MODEL=gpt-5-mini
@@ -1,12 +1,13 @@
 # Election Analytics Chatbot - Backend Guide
 
 ## Overview
-The backend is a Python-based FastAPI application that leverages **LangGraph** to provide a stateful, agentic workflow for election data analysis. It handles complex queries by decomposing them into tasks such as data analysis, web research, or user clarification.
+The backend is a Python-based FastAPI application that leverages **LangGraph** to provide a stateful, hierarchical multi-agent workflow for election data analysis. It handles complex queries using an Orchestrator-Workers pattern, decomposing tasks and delegating them to specialized subgraphs (Data Analyst, Researcher) with built-in reflection and error recovery.
 
 ## 1. Architecture Overview
-- **Framework**: LangGraph for workflow orchestration and state management.
+- **Framework**: LangGraph for hierarchical workflow orchestration and state management.
 - **API**: FastAPI for providing REST and streaming (SSE) endpoints.
-- **State Management**: Persistent state using LangGraph's `StateGraph` with a PostgreSQL checkpointer.
+- **State Management**: Persistent state using LangGraph's `StateGraph` with a PostgreSQL checkpointer. Maintains global state (`AgentState`) and isolated worker states (`WorkerState`).
+- **Virtual File System (VFS)**: An in-memory abstraction passed between nodes to manage intermediate artifacts (scripts, CSVs, charts) without bloating the context window.
 - **Database**: PostgreSQL.
   - Application data: Uses `users` table for local and OIDC users (String IDs).
   - History: Persists chat history and artifacts.
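The VFS added in the hunk above can be pictured as a small in-memory store keyed by filename. The sketch below is illustrative only, not the project's actual implementation; it assumes the `write`/`read`/`list`/`delete` surface documented later in the coder prompt.

```python
from typing import Any, Optional, Tuple


class VirtualFileSystem:
    """Minimal in-memory file store: filename -> (content, metadata)."""

    def __init__(self) -> None:
        self._files: dict[str, Tuple[Any, Optional[dict]]] = {}

    def write(self, filename: str, content: Any, metadata: Optional[dict] = None) -> None:
        # Overwrites silently, like a plain dict assignment.
        self._files[filename] = (content, metadata)

    def read(self, filename: str) -> Tuple[Any, Optional[dict]]:
        return self._files[filename]

    def list(self) -> list[str]:
        return sorted(self._files)

    def delete(self, filename: str) -> None:
        del self._files[filename]


# Hypothetical usage: a worker stashes an intermediate CSV instead of
# printing it into the LLM context window.
vfs = VirtualFileSystem()
vfs.write("results_2024.csv", "state,votes\nFL,100", metadata={"rows": 1})
content, meta = vfs.read("results_2024.csv")
print(vfs.list())  # ['results_2024.csv']
```

Because the store lives in graph state rather than in prompt text, only the filenames (not the payloads) need to be surfaced to the model.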
@@ -14,23 +15,28 @@ The backend is a Python-based FastAPI application that leverages **LangGraph** t
 
 ## 2. Core Components
 
-### 2.1. The Graph State (`src/ea_chatbot/graph/state.py`)
-The state tracks the conversation context, plan, generated code, execution results, and artifacts.
+### 2.1. State Management (`src/ea_chatbot/graph/state.py` & `workers/*/state.py`)
+- **Global State**: Tracks the conversation context, the high-level task `checklist`, execution progress (`current_step`), and the VFS.
+- **Worker State**: Isolated snapshot for specialized subgraphs, tracking internal retry loops (`iterations`), worker-specific prompts, and raw results.
 
-### 2.2. Nodes (The Actors)
+### 2.2. The Orchestrator
 Located in `src/ea_chatbot/graph/nodes/`:
 
-- **`query_analyzer`**: Analyzes the user query to determine the intent and required data.
-- **`planner`**: Creates a step-by-step plan for data analysis.
-- **`coder`**: Generates Python code based on the plan and dataset metadata.
-- **`executor`**: Safely executes the generated code and captures outputs (dataframes, plots).
-- **`error_corrector`**: Fixes code if execution fails.
-- **`researcher`**: Performs web searches for general election information.
-- **`summarizer`**: Generates a natural language response based on the analysis results.
-- **`clarification`**: Asks the user for more information if the query is ambiguous.
+- **`query_analyzer`**: Analyzes the user query to determine the intent and required data. If ambiguous, routes to `clarification`.
+- **`planner`**: Decomposes the user request into a strategic `checklist` of sub-tasks assigned to specific workers.
+- **`delegate`**: The traffic controller. Routes the current task to the appropriate worker and enforces a strict retry budget to prevent infinite loops.
+- **`reflector`**: The quality control node. Evaluates a worker's summary against the sub-task requirements. Can trigger a retry if unsatisfied.
+- **`synthesizer`**: Aggregates all worker results into a final, cohesive response for the user.
+- **`clarification`**: Asks the user for more information if the query is critically ambiguous.
 
-### 2.3. The Workflow (Graph)
-The graph connects these nodes with conditional edges, allowing for iterative refinement and error correction.
+### 2.3. Specialized Workers (Sub-Graphs)
+Located in `src/ea_chatbot/graph/workers/`:
+
+- **`data_analyst`**: Generates Python/SQL code, executes it securely, and captures dataframes/plots. Contains an internal retry loop (`coder` -> `executor` -> error check -> `coder`).
+- **`researcher`**: Performs web searches for general election information and synthesizes factual findings.
+
+### 2.4. The Workflow
+The global graph connects the Orchestrator nodes, wrapping the Worker subgraphs as self-contained nodes with mapped inputs and outputs.
 
 ## 3. Key Modules
 
@@ -56,7 +56,8 @@ async def stream_agent_events(
         initial_state,
         config,
         version="v2",
-        checkpointer=checkpointer
+        checkpointer=checkpointer,
+        subgraphs=True
     ):
         kind = event.get("event")
         name = event.get("name")
@@ -71,8 +72,8 @@ async def stream_agent_events(
             "data": data
         }
 
-        # Buffer assistant chunks (summarizer and researcher might stream)
-        if kind == "on_chat_model_stream" and node_name in ["summarizer", "researcher", "clarification"]:
+        # Buffer assistant chunks (synthesizer and clarification might stream)
+        if kind == "on_chat_model_stream" and node_name in ["synthesizer", "clarification"]:
            chunk = data.get("chunk", "")
            # Use utility to safely extract text content from the chunk
            chunk_data = convert_to_json_compatible(chunk)
@@ -83,7 +84,7 @@ async def stream_agent_events(
            assistant_chunks.append(str(chunk_data))
 
        # Buffer and encode plots
-       if kind == "on_chain_end" and name == "executor":
+       if kind == "on_chain_end" and name == "data_analyst_worker":
            output = data.get("output", {})
            if isinstance(output, dict) and "plots" in output:
                plots = output["plots"]
@@ -95,7 +96,7 @@ async def stream_agent_events(
                output_event["data"]["encoded_plots"] = encoded_plots
 
        # Collect final response from terminal nodes
-       if kind == "on_chain_end" and name in ["summarizer", "researcher", "clarification"]:
+       if kind == "on_chain_end" and name in ["synthesizer", "clarification"]:
            output = data.get("output", {})
            if isinstance(output, dict) and "messages" in output:
                last_msg = output["messages"][-1]
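The renamed branches in the streaming hunk above boil down to a three-way event filter. The function below restates that branching in isolation; it is an illustration of the dispatch logic, not the actual handler.

```python
def classify_event(kind, name, node_name):
    """Mirror of the updated filter branches in stream_agent_events."""
    if kind == "on_chat_model_stream" and node_name in ("synthesizer", "clarification"):
        return "buffer_chunk"     # token-level chunks worth buffering/streaming
    if kind == "on_chain_end" and name == "data_analyst_worker":
        return "encode_plots"     # worker subgraph finished; encode its figures
    if kind == "on_chain_end" and name in ("synthesizer", "clarification"):
        return "final_response"   # terminal node produced the user-facing answer
    return "ignore"


print(classify_event("on_chat_model_stream", None, "synthesizer"))  # buffer_chunk
print(classify_event("on_chain_end", "data_analyst_worker", None))  # encode_plots
```

Note that after this change `researcher` no longer appears in either list: its output reaches the user only via the `synthesizer`, which is why `subgraphs=True` is needed to surface worker events at all.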
@@ -52,6 +52,8 @@ class Settings(BaseSettings):
     coder_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
     summarizer_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
     researcher_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
+    reflector_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
+    synthesizer_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
 
     # Allow nested env vars like QUERY_ANALYZER_LLM__MODEL
     model_config = SettingsConfigDict(env_nested_delimiter='__', env_prefix='')
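The `env_nested_delimiter='__'` setting above is what lets the `.env` entries added earlier (e.g. `REFLECTOR_LLM__MODEL`) target nested `LLMConfig` fields. pydantic-settings performs this mapping natively; the snippet below only illustrates the naming convention with a hand-rolled parser, and is not how the library implements it.

```python
def parse_nested_env(env: dict) -> dict:
    """Sketch: map '__'-delimited env var names onto nested settings fields."""
    settings: dict = {}
    for key, value in env.items():
        # Split at the first '__': left side names the settings field,
        # right side names the attribute inside the nested model.
        field, _, attr = key.lower().partition("__")
        settings.setdefault(field, {})[attr] = value
    return settings


env = {"REFLECTOR_LLM__MODEL": "gpt-5-mini", "SYNTHESIZER_LLM__PROVIDER": "openai"}
print(parse_nested_env(env))
# {'reflector_llm': {'model': 'gpt-5-mini'}, 'synthesizer_llm': {'provider': 'openai'}}
```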
@@ -36,8 +36,7 @@ Please ask the user for the necessary details."""
         response = llm.invoke(messages)
         logger.info("[bold green]Clarification generated.[/bold green]")
         return {
-            "messages": [response],
-            "next_action": "end"  # To indicate we are done for now
+            "messages": [response]
         }
     except Exception as e:
         logger.error(f"Failed to generate clarification: {str(e)}")
(file deleted)
@@ -1,47 +0,0 @@
-from ea_chatbot.graph.state import AgentState
-from ea_chatbot.config import Settings
-from ea_chatbot.utils.llm_factory import get_llm_model
-from ea_chatbot.utils import database_inspection
-from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
-from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT
-from ea_chatbot.schemas import CodeGenerationResponse
-
-def coder_node(state: AgentState) -> dict:
-    """Generate Python code based on the plan and data summary."""
-    question = state["question"]
-    plan = state.get("plan", "")
-    code_output = state.get("code_output", "None")
-
-    settings = Settings()
-    logger = get_logger("coder")
-
-    logger.info("Generating Python code...")
-
-    llm = get_llm_model(
-        settings.coder_llm,
-        callbacks=[LangChainLoggingHandler(logger=logger)]
-    )
-    structured_llm = llm.with_structured_output(CodeGenerationResponse)
-
-    # Always provide data summary
-    database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."
-    example_code = ""  # Placeholder
-
-    messages = CODE_GENERATOR_PROMPT.format_messages(
-        question=question,
-        plan=plan,
-        database_description=database_description,
-        code_exec_results=code_output,
-        example_code=example_code
-    )
-
-    try:
-        response = structured_llm.invoke(messages)
-        logger.info("[bold green]Code generated.[/bold green]")
-        return {
-            "code": response.parsed_code,
-            "error": None  # Clear previous errors on new code generation
-        }
-    except Exception as e:
-        logger.error(f"Failed to generate code: {str(e)}")
-        raise e
backend/src/ea_chatbot/graph/nodes/delegate.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+from ea_chatbot.graph.state import AgentState
+from ea_chatbot.utils.logging import get_logger
+
+def delegate_node(state: AgentState) -> dict:
+    """Determine which worker subgraph to call next based on the checklist."""
+    checklist = state.get("checklist", [])
+    current_step = state.get("current_step", 0)
+    iterations = state.get("iterations", 0)
+    logger = get_logger("orchestrator:delegate")
+
+    if not checklist or current_step >= len(checklist):
+        logger.info("Checklist complete or empty. Routing to summarizer.")
+        return {"next_action": "summarize"}
+
+    # Enforce retry budget
+    if iterations >= 3:
+        logger.error(f"Max retries reached for task {current_step}. Routing to summary with failure.")
+        return {
+            "next_action": "summarize",
+            "iterations": 0,  # Reset for next turn
+            "summary": f"Failed to complete task {current_step} after {iterations} attempts."
+        }
+
+    task_info = checklist[current_step]
+    worker = task_info.get("worker", "data_analyst")
+
+    logger.info(f"Delegating next task to worker: {worker} (Attempt {iterations + 1})")
+    return {"next_action": worker}
(file deleted)
@@ -1,44 +0,0 @@
-from ea_chatbot.graph.state import AgentState
-from ea_chatbot.config import Settings
-from ea_chatbot.utils.llm_factory import get_llm_model
-from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
-from ea_chatbot.graph.prompts.coder import ERROR_CORRECTOR_PROMPT
-from ea_chatbot.schemas import CodeGenerationResponse
-
-def error_corrector_node(state: AgentState) -> dict:
-    """Fix the code based on the execution error."""
-    code = state.get("code", "")
-    error = state.get("error", "Unknown error")
-
-    settings = Settings()
-    logger = get_logger("error_corrector")
-
-    logger.warning(f"[bold red]Execution error detected:[/bold red] {error[:100]}...")
-    logger.info("Attempting to correct the code...")
-
-    # Reuse coder LLM config or add a new one. Using coder_llm for now.
-    llm = get_llm_model(
-        settings.coder_llm,
-        callbacks=[LangChainLoggingHandler(logger=logger)]
-    )
-    structured_llm = llm.with_structured_output(CodeGenerationResponse)
-
-    messages = ERROR_CORRECTOR_PROMPT.format_messages(
-        code=code,
-        error=error
-    )
-
-    try:
-        response = structured_llm.invoke(messages)
-        logger.info("[bold green]Correction generated.[/bold green]")
-
-        current_iterations = state.get("iterations", 0)
-
-        return {
-            "code": response.parsed_code,
-            "error": None,  # Clear error after fix attempt
-            "iterations": current_iterations + 1
-        }
-    except Exception as e:
-        logger.error(f"Failed to correct code: {str(e)}")
-        raise e
@@ -1,34 +1,32 @@
-import yaml
 from ea_chatbot.graph.state import AgentState
 from ea_chatbot.config import Settings
 from ea_chatbot.utils.llm_factory import get_llm_model
 from ea_chatbot.utils import helpers, database_inspection
 from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
 from ea_chatbot.graph.prompts.planner import PLANNER_PROMPT
-from ea_chatbot.schemas import TaskPlanResponse
+from ea_chatbot.schemas import ChecklistResponse
 
 def planner_node(state: AgentState) -> dict:
-    """Generate a structured plan based on the query analysis."""
+    """Generate a high-level task checklist for the Orchestrator."""
     question = state["question"]
     history = state.get("messages", [])[-6:]
     summary = state.get("summary", "")
 
     settings = Settings()
-    logger = get_logger("planner")
+    logger = get_logger("orchestrator:planner")
 
-    logger.info("Generating task plan...")
+    logger.info("Generating high-level task checklist...")
 
     llm = get_llm_model(
         settings.planner_llm,
         callbacks=[LangChainLoggingHandler(logger=logger)]
     )
-    structured_llm = llm.with_structured_output(TaskPlanResponse)
+    structured_llm = llm.with_structured_output(ChecklistResponse)
 
     date_str = helpers.get_readable_date()
 
-    # Always provide data summary; LLM decides relevance.
+    # Data summary for context
     database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."
-    example_plan = ""
 
     messages = PLANNER_PROMPT.format_messages(
         date=date_str,
@@ -36,16 +34,20 @@ def planner_node(state: AgentState) -> dict:
         history=history,
         summary=summary,
         database_description=database_description,
-        example_plan=example_plan
+        example_plan="Decompose into data_analyst and researcher tasks."
     )
 
-    # Generate the structured plan
     try:
-        response = structured_llm.invoke(messages)
-        # Convert the structured response back to YAML string for the state
-        plan_yaml = yaml.dump(response.model_dump(), sort_keys=False)
-        logger.info("[bold green]Plan generated successfully.[/bold green]")
-        return {"plan": plan_yaml}
+        response = ChecklistResponse.model_validate(structured_llm.invoke(messages))
+        # Convert ChecklistTask objects to dicts for state
+        checklist = [task.model_dump() for task in response.checklist]
+        logger.info(f"[bold green]Checklist generated with {len(checklist)} tasks.[/bold green]")
+        return {
+            "checklist": checklist,
+            "current_step": 0,
+            "iterations": 0,  # Reset iteration counter for the new plan
+            "summary": response.reflection  # Use reflection as initial summary
+        }
     except Exception as e:
-        logger.error(f"Failed to generate plan: {str(e)}")
+        logger.error(f"Failed to generate checklist: {str(e)}")
         raise e
@@ -1,18 +1,9 @@
-from typing import List, Literal
-from pydantic import BaseModel, Field
 from ea_chatbot.graph.state import AgentState
 from ea_chatbot.config import Settings
 from ea_chatbot.utils.llm_factory import get_llm_model
 from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
 from ea_chatbot.graph.prompts.query_analyzer import QUERY_ANALYZER_PROMPT
+from ea_chatbot.schemas import QueryAnalysis
 
-class QueryAnalysis(BaseModel):
-    """Analysis of the user's query."""
-    data_required: List[str] = Field(description="List of data points or entities mentioned (e.g., ['2024 results', 'Florida']).")
-    unknowns: List[str] = Field(description="List of target information the user wants to know or needed for final answer (e.g., 'who won', 'total votes').")
-    ambiguities: List[str] = Field(description="List of CRITICAL missing details that prevent ANY analysis. Do NOT include database names or plot types if defaults can be used.")
-    conditions: List[str] = Field(description="List of any filters or constraints (e.g., ['year=2024', 'state=Florida']). Include context resolved from history.")
-    next_action: Literal["plan", "clarify", "research"] = Field(description="The next action to take. 'plan' for data analysis (even with defaults), 'research' for general knowledge, or 'clarify' ONLY for critical ambiguities.")
-
 def query_analyzer_node(state: AgentState) -> dict:
     """Analyze the user's question and determine the next course of action."""
backend/src/ea_chatbot/graph/nodes/reflector.py (new file, 63 lines)
@@ -0,0 +1,63 @@
+from ea_chatbot.graph.state import AgentState
+from ea_chatbot.config import Settings
+from ea_chatbot.utils.llm_factory import get_llm_model
+from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
+from ea_chatbot.schemas import ReflectorResponse
+
+def reflector_node(state: AgentState) -> dict:
+    """Evaluate if the worker's output satisfies the current sub-task."""
+    checklist = state.get("checklist", [])
+    current_step = state.get("current_step", 0)
+    summary = state.get("summary", "")  # This contains the worker's summary
+
+    if not checklist or current_step >= len(checklist):
+        return {"next_action": "summarize"}
+
+    task_info = checklist[current_step]
+    task_desc = task_info.get("task", "")
+
+    settings = Settings()
+    logger = get_logger("orchestrator:reflector")
+
+    logger.info(f"Evaluating worker output for task: {task_desc[:50]}...")
+
+    llm = get_llm_model(
+        settings.reflector_llm,
+        callbacks=[LangChainLoggingHandler(logger=logger)]
+    )
+    structured_llm = llm.with_structured_output(ReflectorResponse)
+
+    prompt = f"""You are a Lead Orchestrator evaluating the work of a specialized sub-agent.
+
+**Sub-Task assigned:**
+{task_desc}
+
+**Worker's Result Summary:**
+{summary}
+
+Evaluate if the result is satisfactory and complete for this specific sub-task.
+If there were major errors or the output is missing critical data requested in the sub-task, mark satisfied as False."""
+
+    try:
+        response = structured_llm.invoke(prompt)
+        if response.satisfied:
+            logger.info("[bold green]Sub-task satisfied.[/bold green] Advancing plan.")
+            return {
+                "current_step": current_step + 1,
+                "iterations": 0,  # Reset for next task
+                "next_action": "delegate"
+            }
+        else:
+            logger.warning(f"[bold yellow]Sub-task NOT satisfied.[/bold yellow] Reason: {response.reasoning}")
+            # Do NOT advance the step. Increment iterations to track retries.
+            return {
+                "iterations": state.get("iterations", 0) + 1,
+                "next_action": "delegate"
+            }
+    except Exception as e:
+        logger.error(f"Failed to reflect: {str(e)}")
+        # On error, increment iterations to avoid infinite loop if LLM is stuck
+        return {
+            "iterations": state.get("iterations", 0) + 1,
+            "next_action": "delegate"
+        }
(file deleted)
@@ -1,43 +0,0 @@
-from ea_chatbot.graph.state import AgentState
-from ea_chatbot.config import Settings
-from ea_chatbot.utils.llm_factory import get_llm_model
-from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
-from ea_chatbot.graph.prompts.summarizer import SUMMARIZER_PROMPT
-
-def summarizer_node(state: AgentState) -> dict:
-    """Summarize the code execution results into a final answer."""
-    question = state["question"]
-    plan = state.get("plan", "")
-    code_output = state.get("code_output", "")
-    history = state.get("messages", [])[-6:]
-    summary = state.get("summary", "")
-
-    settings = Settings()
-    logger = get_logger("summarizer")
-
-    logger.info("Generating final summary...")
-
-    llm = get_llm_model(
-        settings.summarizer_llm,
-        callbacks=[LangChainLoggingHandler(logger=logger)]
-    )
-
-    messages = SUMMARIZER_PROMPT.format_messages(
-        question=question,
-        plan=plan,
-        code_output=code_output,
-        history=history,
-        summary=summary
-    )
-
-    try:
-        response = llm.invoke(messages)
-        logger.info("[bold green]Summary generated.[/bold green]")
-
-        # Return the final message to be added to the state
-        return {
-            "messages": [response]
-        }
-    except Exception as e:
-        logger.error(f"Failed to generate summary: {str(e)}")
-        raise e
backend/src/ea_chatbot/graph/nodes/synthesizer.py (new file, 56 lines)
@@ -0,0 +1,56 @@
+from ea_chatbot.graph.state import AgentState
+from ea_chatbot.config import Settings
+from ea_chatbot.utils.llm_factory import get_llm_model
+from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
+from ea_chatbot.graph.prompts.synthesizer import SYNTHESIZER_PROMPT
+
+def synthesizer_node(state: AgentState) -> dict:
+    """Synthesize the results from multiple workers into a final answer."""
+    question = state["question"]
+    history = state.get("messages", [])
+
+    # We look for the 'summary' from the last worker which might have cumulative info
+    # Or we can look at all messages in history bubbled up from workers.
+    # For now, let's assume the history contains all the worker summaries.
+
+    settings = Settings()
+    logger = get_logger("orchestrator:synthesizer")
+
+    logger.info("Synthesizing final answer from worker results...")
+
+    llm = get_llm_model(
+        settings.synthesizer_llm,
+        callbacks=[LangChainLoggingHandler(logger=logger)]
+    )
+
+    # Artifact summary
+    plots = state.get("plots", [])
+    vfs = state.get("vfs", {})
+    artifacts_summary = ""
+    if plots:
+        artifacts_summary += f"- {len(plots)} generated plot(s) are attached to this response.\n"
+    if vfs:
+        artifacts_summary += "- Data files available in VFS: " + ", ".join(vfs.keys()) + "\n"
+    if not artifacts_summary:
+        artifacts_summary = "No additional artifacts generated."
+
+    # We provide the full history and the original question
+    messages = SYNTHESIZER_PROMPT.format_messages(
+        question=question,
+        history=history,
+        worker_results="Review the worker summaries provided in the message history.",
+        artifacts_summary=artifacts_summary
+    )
+
+    try:
+        response = llm.invoke(messages)
+        logger.info("[bold green]Final synthesis complete.[/bold green]")
+
+        # Return the final message to be added to the state
+        return {
+            "messages": [response],
+            "next_action": "end"
+        }
+    except Exception as e:
+        logger.error(f"Failed to synthesize final answer: {str(e)}")
+        raise e
@@ -9,6 +9,14 @@ The user will provide a task and a plan.
 - Do NOT assume a dataframe `df` is already loaded unless explicitly stated. You usually need to query it first.
 - The database schema is described in the prompt. Use it to construct valid SQL queries.
 
+**Virtual File System (VFS):**
+- An in-memory file system is available as `vfs`. Use it to persist intermediate data or large artifacts.
+- `vfs.write(filename, content, metadata=None)`: Save a file (content can be any serializable object).
+- `vfs.read(filename) -> (content, metadata)`: Read a file.
+- `vfs.list() -> list[str]`: List all files.
+- `vfs.delete(filename)`: Delete a file.
+- Prefer using VFS for intermediate DataFrames or complex data structures instead of printing everything.
+
 **Plotting:**
 - If you need to plot any data, use the `plots` list to store the figures.
 - Example: `plots.append(fig)` or `plots.append(plt.gcf())`.
@@ -18,7 +26,8 @@ The user will provide a task and a plan.
|
|||||||
- Produce FULL, COMPLETE CODE that includes all steps and solves the task!
|
- Produce FULL, COMPLETE CODE that includes all steps and solves the task!
|
||||||
- Always include the import statements at the top of the code (e.g., `import pandas as pd`, `import matplotlib.pyplot as plt`).
|
- Always include the import statements at the top of the code (e.g., `import pandas as pd`, `import matplotlib.pyplot as plt`).
|
||||||
- Always include print statements to output the results of your code.
|
- Always include print statements to output the results of your code.
|
||||||
- Use `db.query_df("SELECT ...")` to get data."""
|
- Use `db.query_df("SELECT ...")` to get data.
|
||||||
|
"""
|
||||||
|
|
||||||
CODE_GENERATOR_USER = """TASK:
|
CODE_GENERATOR_USER = """TASK:
|
||||||
{question}
|
{question}
|
||||||
@@ -43,6 +52,7 @@ Return a complete, corrected python code that incorporates the fixes for the err
|
|||||||
- You have access to a database client via the variable `db`.
|
- You have access to a database client via the variable `db`.
|
||||||
- Use `db.query_df(sql)` to run queries.
|
- Use `db.query_df(sql)` to run queries.
|
||||||
- Use `plots.append(fig)` for plots.
|
- Use `plots.append(fig)` for plots.
|
||||||
|
- You have access to `vfs` for persistent in-memory storage.
|
||||||
- Always include imports and print statements."""
|
- Always include imports and print statements."""
|
||||||
|
|
||||||
ERROR_CORRECTOR_USER = """FAILED CODE:
|
ERROR_CORRECTOR_USER = """FAILED CODE:
|
||||||
|
|||||||
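The VFS contract the prompt describes (`write`/`read`/`list`/`delete`) can be pictured as a small in-memory helper. The class below is a hypothetical stand-in for whatever `ea_chatbot.utils.vfs` actually provides; only the method names and signatures are taken from the prompt text above.

```python
from typing import Any, Dict, List, Optional, Tuple

class VFSHelper:
    """Minimal in-memory file system matching the prompt's contract (sketch)."""

    def __init__(self, store: Optional[Dict[str, Any]] = None):
        # The backing dict doubles as the serializable snapshot kept in graph state.
        self.store: Dict[str, Any] = store if store is not None else {}

    def write(self, filename: str, content: Any, metadata: Optional[dict] = None) -> None:
        self.store[filename] = {"content": content, "metadata": metadata or {}}

    def read(self, filename: str) -> Tuple[Any, dict]:
        entry = self.store[filename]
        return entry["content"], entry["metadata"]

    def list(self) -> List[str]:
        return sorted(self.store.keys())

    def delete(self, filename: str) -> None:
        self.store.pop(filename, None)
```

Because the backing store is a plain dict, a snapshot of it can be passed between nodes in state and re-wrapped in a helper on each execution.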
@@ -1,41 +1,29 @@
|
|||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
|
|
||||||
PLANNER_SYSTEM = """You are a Research Specialist and your job is to find answers and educate the user.
|
PLANNER_SYSTEM = """You are a Lead Orchestrator for an Election Analytics Chatbot.
|
||||||
Provide factual information responding directly to the user's question. Include key details and context to ensure your response comprehensively answers their query.
|
Your job is to decompose complex user queries into a high-level checklist of tasks.
|
||||||
|
|
||||||
|
**Specialized Workers:**
|
||||||
|
1. `data_analyst`: Handles SQL queries, Python data analysis, and plotting. Use this when the user needs numbers, trends, or charts from the internal database.
|
||||||
|
2. `researcher`: Performs web searches for current news, facts, or external data not in the primary database.
|
||||||
|
|
||||||
|
**Orchestration Strategy:**
|
||||||
|
- Analyze the user's question and the available data summary.
|
||||||
|
- Create a logical sequence of tasks (checklist) for these workers.
|
||||||
|
- Be specific in the task description for the worker (e.g., "Find the total votes in Florida 2020").
|
||||||
|
- If the query is ambiguous, the Orchestrator loop will later handle clarification, but for now, make the best plan possible.
|
||||||
|
|
||||||
Today's Date is: {date}"""
|
Today's Date is: {date}"""
|
||||||
|
|
||||||
PLANNER_USER = """Conversation Summary: {summary}
|
PLANNER_USER = """Conversation Summary: {summary}
|
||||||
|
|
||||||
TASK:
|
USER QUESTION:
|
||||||
{question}
|
{question}
|
||||||
|
|
||||||
AVAILABLE DATA SUMMARY (Use only if relevant to the task):
|
AVAILABLE DATABASE SUMMARY:
|
||||||
{database_description}
|
{database_description}
|
||||||
|
|
||||||
First: Evaluate whether you have all necessary and requested information to provide a solution.
|
Decompose the question into a strategic checklist. For each task, specify which worker should handle it.
|
||||||
Use the dataset description above to determine what data and in what format you have available to you.
|
|
||||||
You are able to search internet if the user asks for it, or you require any information that you can not derive from the given dataset or the instruction.
|
|
||||||
|
|
||||||
Second: Incorporate any additional relevant context, reasoning, or details from previous interactions or internal chain-of-thought that may impact the solution.
|
|
||||||
Ensure that all such information is fully included in your response rather than referring to previous answers indirectly.
|
|
||||||
|
|
||||||
Third: Reflect on the problem and briefly describe it, while addressing the problem goal, inputs, outputs,
|
|
||||||
rules, constraints, and other relevant details that appear in the problem description.
|
|
||||||
|
|
||||||
Fourth: Based on the preceding steps, formulate your response as an algorithm, breaking the solution in up to eight simple concise yet descriptive, clear English steps.
|
|
||||||
You MUST Include all values or instructions as described in the above task, or retrieved using internet search!
|
|
||||||
If fewer steps suffice, that's acceptable. If more are needed, please include them.
|
|
||||||
Remember to explain steps rather than write code.
|
|
||||||
|
|
||||||
This algorithm will be later converted to Python code.
|
|
||||||
If a dataframe is required, assume it is named 'df' and is already defined/populated based on the data summary.
|
|
||||||
|
|
||||||
There is a list variable called `plots` that you need to use to store any plots you generate. Do not use `plt.show()` as it will render the plot and cause an error.
|
|
||||||
|
|
||||||
Output the algorithm as a YAML string. Always enclose the YAML string within ```yaml tags.
|
|
||||||
|
|
||||||
**Note: Ensure that any necessary context from prior interactions is fully embedded in the plan. Do not use phrases like "refer to previous answer"; instead, provide complete details inline.**
|
|
||||||
|
|
||||||
{example_plan}"""
|
{example_plan}"""
|
||||||
|
|
||||||
|
|||||||
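The checklist the planner is asked to produce can be pictured as a list of task dicts, one per worker assignment. The field names below (`task`, `worker`) are assumptions for illustration, not the actual planner output schema; a tiny validator shows the invariant an orchestrator would want before dispatching.

```python
# Hypothetical shape of a decomposed checklist; field names are illustrative.
checklist = [
    {"task": "Find the total votes in Florida 2020", "worker": "data_analyst"},
    {"task": "Summarize recent news on Florida recounts", "worker": "researcher"},
]

def validate_checklist(items):
    """Ensure every step names a known worker and a non-empty task description."""
    allowed = {"data_analyst", "researcher"}
    for step in items:
        assert step.get("worker") in allowed, f"unknown worker: {step}"
        assert step.get("task"), f"empty task: {step}"
    return True
```

Validating the plan up front lets the orchestrator fail fast instead of routing a malformed step to a worker subgraph.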
31  backend/src/ea_chatbot/graph/prompts/synthesizer.py  Normal file
@@ -0,0 +1,31 @@
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+
+SYNTHESIZER_SYSTEM = """You are a Lead Orchestrator for an Election Analytics Chatbot.
+You have coordinated several specialized workers (Data Analysts, Researchers) to answer a user's complex query.
+
+Your goal is to synthesize their individual findings into a single, cohesive, and comprehensive final response for the user.
+
+**Guidelines:**
+- Do NOT mention the internal 'workers' or 'checklist' names.
+- Combine the data insights (from Data Analysts) and factual research (from Researchers) into a natural narrative.
+- Ensure all numbers, dates, and names from the worker reports are included accurately.
+- **Artifacts & Plots:** If plots or charts were generated, refer to them naturally (e.g., "The chart below shows...").
+- If any part of the plan failed, explain the status honestly but professionally.
+- Present data in clear formats (tables, bullet points) where appropriate."""
+
+SYNTHESIZER_USER = """USER QUESTION:
+{question}
+
+EXECUTION SUMMARY (Results from specialized workers):
+{worker_results}
+
+AVAILABLE ARTIFACTS:
+{artifacts_summary}
+
+Provide the final integrated response:"""
+
+SYNTHESIZER_PROMPT = ChatPromptTemplate.from_messages([
+    ("system", SYNTHESIZER_SYSTEM),
+    MessagesPlaceholder(variable_name="history"),
+    ("human", SYNTHESIZER_USER),
+])
@@ -12,12 +12,12 @@ class AgentState(AS):

     # Query Analysis (Decomposition results)
     analysis: Optional[Dict[str, Any]]
-    # Expected keys: "requires_dataset", "expert", "data", "unknown", "condition"
+    # Expected keys: "data_required", "unknowns", "ambiguities", "conditions"

-    # Step-by-step reasoning
+    # Step-by-step reasoning (Legacy, use checklist for new flows)
     plan: Optional[str]

-    # Code execution context
+    # Code execution context (Legacy, use workers for new flows)
     code: Optional[str]
     code_output: Optional[str]
     error: Optional[str]
@@ -26,11 +26,21 @@ class AgentState(AS):
     plots: Annotated[List[Any], operator.add]  # Matplotlib figures
     dfs: Dict[str, Any]  # Pandas DataFrames

-    # Conversation summary
+    # Conversation summary / Latest worker result
     summary: Optional[str]

-    # Routing hint: "clarify", "plan", "research", "end"
+    # Routing hint: "clarify", "plan", "end"
     next_action: str

     # Number of execution attempts
     iterations: int
+
+    # --- DeepAgents Extensions ---
+    # High-level plan/checklist: List of Task objects (dicts)
+    checklist: List[Dict[str, Any]]
+
+    # Current active step in the checklist
+    current_step: int
+
+    # Virtual File System (VFS): Map of filenames to content/metadata
+    vfs: Dict[str, Any]
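The `plots: Annotated[List[Any], operator.add]` annotation tells LangGraph to merge updates to that state channel with `operator.add` (list concatenation) rather than overwriting. A minimal sketch of that reducer mechanism, independent of LangGraph itself (the `merge_channel` helper is an illustration of the idea, not LangGraph internals):

```python
import operator
from typing import Annotated, Any, List, get_type_hints

class PlotChannel:
    # Same annotation shape as AgentState's `plots` field.
    plots: Annotated[List[Any], operator.add]

def merge_channel(annotated_type, current, update):
    """Apply the reducer stored in the Annotated metadata to merge two updates."""
    reducer = annotated_type.__metadata__[0]  # e.g. operator.add
    return reducer(current, update)

# Reading the annotation back with include_extras=True preserves the metadata.
hints = get_type_hints(PlotChannel, include_extras=True)
merged = merge_channel(hints["plots"], ["fig_a"], ["fig_b"])  # list concatenation
```

This is why both the legacy flow and the new worker subgraphs can each return `plots` without clobbering one another's figures.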
0   backend/src/ea_chatbot/graph/workers/__init__.py  Normal file
42  backend/src/ea_chatbot/graph/workers/data_analyst/mapping.py  Normal file
@@ -0,0 +1,42 @@
+from typing import Dict, Any, List
+import copy
+from langchain_core.messages import HumanMessage, AIMessage
+from ea_chatbot.graph.state import AgentState
+from ea_chatbot.graph.workers.data_analyst.state import WorkerState
+from ea_chatbot.utils.vfs import safe_vfs_copy
+
+def prepare_worker_input(state: AgentState) -> Dict[str, Any]:
+    """Prepare the initial state for the Data Analyst worker."""
+    checklist = state.get("checklist", [])
+    current_step = state.get("current_step", 0)
+
+    # Get the current task description
+    task_desc = "Analyze data"  # Default
+    if 0 <= current_step < len(checklist):
+        task_desc = checklist[current_step].get("task", task_desc)
+
+    return {
+        "task": task_desc,
+        "messages": [HumanMessage(content=task_desc)],  # Start worker loop with the task
+        "vfs_state": safe_vfs_copy(state.get("vfs", {})),
+        "iterations": 0,
+        "plots": [],
+        "code": None,
+        "output": None,
+        "error": None,
+        "result": None
+    }
+
+def merge_worker_output(worker_state: WorkerState) -> Dict[str, Any]:
+    """Map worker results back to the global AgentState."""
+    result = worker_state.get("result") or "No result produced by data analyst worker."
+
+    # We bubble up the summary as an AI message and update the global summary field
+    updates = {
+        "messages": [AIMessage(content=result)],
+        "summary": result,
+        "vfs": worker_state.get("vfs_state", {}),
+        "plots": worker_state.get("plots", [])
+    }
+
+    return updates
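The guard in `prepare_worker_input` (fall back to a default task when `current_step` is out of range or the step has no `task` key) can be exercised with plain dicts. This sketch strips the LangChain message wrapping and VFS copy and keeps only the indexing logic:

```python
DEFAULT_TASK = "Analyze data"

def pick_task(state: dict) -> str:
    """Safely index the checklist, as prepare_worker_input does."""
    checklist = state.get("checklist", [])
    current_step = state.get("current_step", 0)
    task = DEFAULT_TASK
    # Bounds check protects against a stale or corrupted current_step.
    if 0 <= current_step < len(checklist):
        task = checklist[current_step].get("task", task)
    return task
```

The same pattern appears again in the researcher's `prepare_researcher_input`, so the bounds check is effectively the contract between the orchestrator's checklist and every worker.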
@@ -0,0 +1,64 @@
+from typing import Dict, Any, List, Optional
+from ea_chatbot.graph.workers.data_analyst.state import WorkerState
+from ea_chatbot.config import Settings
+from ea_chatbot.utils.llm_factory import get_llm_model
+from ea_chatbot.utils import database_inspection
+from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
+from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT
+from ea_chatbot.schemas import CodeGenerationResponse
+
+def coder_node(state: WorkerState) -> dict:
+    """Generate Python code based on the sub-task assigned to the worker."""
+    task = state["task"]
+    output = state.get("output", "None")
+    error = state.get("error", "None")
+    vfs_state = state.get("vfs_state", {})
+
+    settings = Settings()
+    logger = get_logger("data_analyst_worker:coder")
+
+    logger.info(f"Generating Python code for task: {task[:50]}...")
+
+    # We can use the configured 'coder_llm' for this node
+    llm = get_llm_model(
+        settings.coder_llm,
+        callbacks=[LangChainLoggingHandler(logger=logger)]
+    )
+    structured_llm = llm.with_structured_output(CodeGenerationResponse)
+
+    # Data summary for context
+    database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."
+
+    # VFS Summary: Let the LLM know which files are available in-memory
+    vfs_summary = "Available in-memory files (VFS):\n"
+    if vfs_state:
+        for filename, data in vfs_state.items():
+            if isinstance(data, dict):
+                meta = data.get("metadata", {})
+                vfs_summary += f"- {filename} ({meta.get('type', 'unknown')})\n"
+            else:
+                vfs_summary += f"- {filename} (raw data)\n"
+    else:
+        vfs_summary += "- None"
+
+    # Reuse the global prompt but adapt 'question' and 'plan' labels
+    # For a sub-task worker, 'task' effectively replaces the high-level 'plan'
+    messages = CODE_GENERATOR_PROMPT.format_messages(
+        question=task,
+        plan="Focus on the specific task below.",
+        database_description=database_description,
+        code_exec_results=f"Output: {output}\nError: {error}\n\n{vfs_summary}",
+        example_code=""
+    )
+
+    try:
+        response = structured_llm.invoke(messages)
+        logger.info("[bold green]Code generated.[/bold green]")
+        return {
+            "code": response.parsed_code,
+            "error": None,  # Clear previous errors
+            "iterations": state.get("iterations", 0) + 1
+        }
+    except Exception as e:
+        logger.error(f"Failed to generate code: {str(e)}")
+        raise e
@@ -1,23 +1,25 @@
 import io
 import sys
 import traceback
+import copy
 from contextlib import redirect_stdout
 from typing import TYPE_CHECKING
 import pandas as pd
 from matplotlib.figure import Figure

-from ea_chatbot.graph.state import AgentState
+from ea_chatbot.graph.workers.data_analyst.state import WorkerState
 from ea_chatbot.utils.db_client import DBClient
+from ea_chatbot.utils.vfs import VFSHelper, safe_vfs_copy
 from ea_chatbot.utils.logging import get_logger
 from ea_chatbot.config import Settings

 if TYPE_CHECKING:
     from ea_chatbot.types import DBSettings

-def executor_node(state: AgentState) -> dict:
-    """Execute the Python code and capture output, plots, and dataframes."""
+def executor_node(state: WorkerState) -> dict:
+    """Execute the Python code in the context of the Data Analyst worker."""
     code = state.get("code")
-    logger = get_logger("executor")
+    logger = get_logger("data_analyst_worker:executor")

     if not code:
         logger.error("No code provided to executor.")
@@ -37,45 +39,40 @@ def executor_node(state: AgentState) -> dict:

     db_client = DBClient(settings=db_settings)

+    # Initialize the Virtual File System (VFS) helper with the snapshot from state
+    vfs_state = safe_vfs_copy(state.get("vfs_state", {}))
+    vfs_helper = VFSHelper(vfs_state)
+
     # Initialize local variables for execution
-    # 'db' is the DBClient instance, 'plots' is for matplotlib figures
     local_vars = {
         'db': db_client,
         'plots': [],
-        'pd': pd
+        'pd': pd,
+        'vfs': vfs_helper
     }

     stdout_buffer = io.StringIO()
     error = None
-    code_output = ""
+    output = ""
     plots = []
-    dfs = {}

     try:
         with redirect_stdout(stdout_buffer):
             # Execute the code in the context of local_vars
             exec(code, {}, local_vars)

-        code_output = stdout_buffer.getvalue()
+        output = stdout_buffer.getvalue()

         # Limit the output length if it's too long
-        if code_output.count('\n') > 32:
-            code_output = '\n'.join(code_output.split('\n')[:32]) + '\n...'
+        if output.count('\n') > 32:
+            output = '\n'.join(output.split('\n')[:32]) + '\n...'

         # Extract plots
         raw_plots = local_vars.get('plots', [])
         if isinstance(raw_plots, list):
             plots = [p for p in raw_plots if isinstance(p, Figure)]

-        # Extract DataFrames that were likely intended for display
-        # We look for DataFrames in local_vars that were mentioned in the code
-        for key, value in local_vars.items():
-            if isinstance(value, pd.DataFrame):
-                # Heuristic: if the variable name is in the code, it might be a result DF
-                if key in code:
-                    dfs[key] = value
-
-        logger.info(f"[bold green]Execution complete.[/bold green] Captured {len(plots)} plots and {len(dfs)} dataframes.")
+        logger.info(f"[bold green]Execution complete.[/bold green] Captured {len(plots)} plots.")

     except Exception as e:
         # Capture the traceback
@@ -90,13 +87,11 @@ def executor_node(state: AgentState) -> dict:
         error += f"{exc_type.__name__ if exc_type else 'Exception'}: {exc_value}"

         logger.error(f"Execution failed: {str(e)}")
+        output = stdout_buffer.getvalue()
-        # If we have an error, we still might want to see partial stdout
-        code_output = stdout_buffer.getvalue()

     return {
-        "code_output": code_output,
+        "output": output,
         "error": error,
         "plots": plots,
-        "dfs": dfs
+        "vfs_state": vfs_state
     }
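The executor's core mechanics — `exec` under `redirect_stdout`, error capture, and the 32-line output cap — can be reproduced in a few lines. A simplified sketch (no `db`, `vfs`, or plot handling), assuming the same truncation rule as the diff:

```python
import io
from contextlib import redirect_stdout

def run_snippet(code: str, local_vars: dict, max_lines: int = 32):
    """Execute code, capture stdout, and truncate overly long output."""
    buf = io.StringIO()
    error = None
    try:
        with redirect_stdout(buf):
            # Empty globals keep the snippet isolated from the caller's namespace.
            exec(code, {}, local_vars)
    except Exception as e:
        error = f"{type(e).__name__}: {e}"
    # Even on failure, partial stdout is worth surfacing to the error corrector.
    output = buf.getvalue()
    if output.count("\n") > max_lines:
        output = "\n".join(output.split("\n")[:max_lines]) + "\n..."
    return output, error
```

Note that `exec` with separate empty globals is an isolation convenience, not a security boundary; untrusted code would need a real sandbox.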
@@ -0,0 +1,53 @@
+from typing import Dict, Any, List, Optional
+from ea_chatbot.graph.workers.data_analyst.state import WorkerState
+from ea_chatbot.config import Settings
+from ea_chatbot.utils.llm_factory import get_llm_model
+from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
+
+def summarizer_node(state: WorkerState) -> dict:
+    """Summarize the data analysis results for the Orchestrator."""
+    task = state["task"]
+    output = state.get("output", "")
+    error = state.get("error")
+    plots = state.get("plots", [])
+    vfs_state = state.get("vfs_state", {})
+
+    settings = Settings()
+    logger = get_logger("data_analyst_worker:summarizer")
+
+    logger.info("Summarizing analysis results for the Orchestrator...")
+
+    # Artifact summary
+    artifact_info = ""
+    if plots:
+        artifact_info += f"- Generated {len(plots)} plot(s).\n"
+    if vfs_state:
+        artifact_info += "- VFS Artifacts: " + ", ".join(vfs_state.keys()) + "\n"
+
+    # We can use a smaller/faster model for this summary if needed
+    llm = get_llm_model(
+        settings.planner_llm,  # Using planner model for summary logic
+        callbacks=[LangChainLoggingHandler(logger=logger)]
+    )
+
+    prompt = f"""You are a data analyst sub-agent. You have completed a sub-task.
+Task: {task}
+Execution Results: {output}
+Error Log (if any): {error}
+{artifact_info}
+
+Provide a concise summary of the findings or status for the top-level Orchestrator.
+If plots or data files were generated, mention them.
+If the execution failed after multiple retries, explain why concisely.
+Do NOT include the raw Python code, just the results of the analysis."""
+
+    try:
+        response = llm.invoke(prompt)
+        result = response.content if hasattr(response, "content") else str(response)
+        logger.info("[bold green]Analysis results summarized.[/bold green]")
+        return {
+            "result": result
+        }
+    except Exception as e:
+        logger.error(f"Failed to summarize results: {str(e)}")
+        raise e
28  backend/src/ea_chatbot/graph/workers/data_analyst/state.py  Normal file
@@ -0,0 +1,28 @@
+from typing import TypedDict, List, Dict, Any, Optional
+from langchain_core.messages import BaseMessage
+
+class WorkerState(TypedDict):
+    """Internal state for the Data Analyst worker subgraph."""
+
+    # Internal worker conversation (not bubbled up to global unless summary)
+    messages: List[BaseMessage]
+
+    # The specific sub-task assigned by the Orchestrator
+    task: str
+
+    # Generated code and execution context
+    code: Optional[str]
+    output: Optional[str]
+    error: Optional[str]
+
+    # Number of internal retry attempts
+    iterations: int
+
+    # Temporary storage for analysis results
+    plots: List[Any]
+
+    # Isolated/Snapshot view of the VFS
+    vfs_state: Dict[str, Any]
+
+    # The final summary or result to return to the Orchestrator
+    result: Optional[str]
@@ -0,0 +1,51 @@
+from langgraph.graph import StateGraph, END
+from langgraph.graph.state import CompiledStateGraph
+from ea_chatbot.graph.workers.data_analyst.state import WorkerState
+from ea_chatbot.graph.workers.data_analyst.nodes.coder import coder_node
+from ea_chatbot.graph.workers.data_analyst.nodes.executor import executor_node
+from ea_chatbot.graph.workers.data_analyst.nodes.summarizer import summarizer_node
+
+def router(state: WorkerState) -> str:
+    """Routes the subgraph between coding, execution, and summarization."""
+    error = state.get("error")
+    iterations = state.get("iterations", 0)
+
+    if error and iterations < 3:
+        # Retry with error correction
+        return "coder"
+
+    # Either success or max retries reached
+    return "summarizer"
+
+def create_data_analyst_worker(
+    coder=coder_node,
+    executor=executor_node,
+    summarizer=summarizer_node
+) -> CompiledStateGraph:
+    """Create the Data Analyst worker subgraph."""
+    workflow = StateGraph(WorkerState)
+
+    # Add Nodes
+    workflow.add_node("coder", coder)
+    workflow.add_node("executor", executor)
+    workflow.add_node("summarizer", summarizer)
+
+    # Set entry point
+    workflow.set_entry_point("coder")
+
+    # Add Edges
+    workflow.add_edge("coder", "executor")
+
+    # Add Conditional Edges
+    workflow.add_conditional_edges(
+        "executor",
+        router,
+        {
+            "coder": "coder",
+            "summarizer": "summarizer"
+        }
+    )
+
+    workflow.add_edge("summarizer", END)
+
+    return workflow.compile()
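The `coder -> executor -> (coder | summarizer)` wiring with the `iterations < 3` retry rule can be mimicked without LangGraph. `run_worker_loop` below is a plain-Python sketch of the compiled subgraph's control flow (not the LangGraph runtime); the stub nodes fail twice before succeeding to exercise the retry path.

```python
def run_worker_loop(coder, executor, summarizer, state, max_retries=3):
    """Plain-Python mimic of the subgraph: coder -> executor -> (coder | summarizer)."""
    while True:
        state.update(coder(state))
        state.update(executor(state))
        # Mirrors router(): loop back to the coder while there is an error
        # and attempts remain; otherwise fall through to the summarizer.
        if state.get("error") and state.get("iterations", 0) < max_retries:
            continue
        state.update(summarizer(state))
        return state

# Stub nodes: the "executor" fails on its first two calls, then succeeds.
_calls = {"n": 0}
def _coder(s):
    return {"code": "...", "iterations": s.get("iterations", 0) + 1}
def _executor(s):
    _calls["n"] += 1
    return {"error": "boom" if _calls["n"] < 3 else None, "output": "ok"}
def _summarizer(s):
    return {"result": "done"}

final = run_worker_loop(_coder, _executor, _summarizer, {"iterations": 0})
```

Because the coder increments `iterations` on every pass, the router's `iterations < 3` check bounds the loop at three code-generation attempts even if the error never clears.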
32  backend/src/ea_chatbot/graph/workers/researcher/mapping.py  Normal file
@@ -0,0 +1,32 @@
+from typing import Dict, Any, List
+from langchain_core.messages import HumanMessage, AIMessage
+from ea_chatbot.graph.state import AgentState
+from ea_chatbot.graph.workers.researcher.state import WorkerState
+
+def prepare_researcher_input(state: AgentState) -> Dict[str, Any]:
+    """Prepare the initial state for the Researcher worker."""
+    checklist = state.get("checklist", [])
+    current_step = state.get("current_step", 0)
+
+    task_desc = "Perform research"
+    if 0 <= current_step < len(checklist):
+        task_desc = checklist[current_step].get("task", task_desc)
+
+    return {
+        "task": task_desc,
+        "messages": [HumanMessage(content=task_desc)],
+        "queries": [],
+        "raw_results": [],
+        "iterations": 0,
+        "result": None
+    }
+
+def merge_researcher_output(worker_state: WorkerState) -> Dict[str, Any]:
+    """Map researcher results back to the global AgentState."""
+    result = worker_state.get("result") or "No result produced by researcher worker."
+
+    return {
+        "messages": [AIMessage(content=result)],
+        "summary": result,
+        # Researcher doesn't usually update VFS or Plots, but we keep the structure
+    }
@@ -1,24 +1,20 @@
 from langchain_openai import ChatOpenAI
 from langchain_google_genai import ChatGoogleGenerativeAI
-from ea_chatbot.graph.state import AgentState
+from ea_chatbot.graph.workers.researcher.state import WorkerState
 from ea_chatbot.config import Settings
 from ea_chatbot.utils.llm_factory import get_llm_model
 from ea_chatbot.utils import helpers
 from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
 from ea_chatbot.graph.prompts.researcher import RESEARCHER_PROMPT

-def researcher_node(state: AgentState) -> dict:
-    """Handle general research queries or web searches."""
-    question = state["question"]
-    history = state.get("messages", [])[-6:]
-    summary = state.get("summary", "")
+def searcher_node(state: WorkerState) -> dict:
+    """Execute web research for the specific task."""
+    task = state["task"]
+    logger = get_logger("researcher_worker:searcher")
+
+    logger.info(f"Researching task: {task[:50]}...")

     settings = Settings()
-    logger = get_logger("researcher")
-
-    logger.info(f"Researching question: [italic]\"{question}\"[/italic]")
-
-    # Use researcher_llm from settings
     llm = get_llm_model(
         settings.researcher_llm,
         callbacks=[LangChainLoggingHandler(logger=logger)]
@@ -26,34 +22,39 @@ def researcher_node(state: AgentState) -> dict:

     date_str = helpers.get_readable_date()

+    # Adapt the global researcher prompt for the sub-task
     messages = RESEARCHER_PROMPT.format_messages(
         date=date_str,
-        question=question,
+        question=task,
-        history=history,
+        history=[],  # Worker has fresh context or task-specific history
-        summary=summary
+        summary=""
     )

-    # Provider-aware tool binding
+    # Tool binding
     try:
         if isinstance(llm, ChatGoogleGenerativeAI):
-            # Native Google Search for Gemini
             llm_with_tools = llm.bind_tools([{"google_search": {}}])
         elif isinstance(llm, ChatOpenAI):
-            # Native Web Search for OpenAI (built-in tool)
             llm_with_tools = llm.bind_tools([{"type": "web_search"}])
         else:
-            # Fallback for other providers that might not support these specific search tools
             llm_with_tools = llm
     except Exception as e:
-        logger.warning(f"Failed to bind search tools: {str(e)}. Falling back to base LLM.")
+        logger.warning(f"Failed to bind search tools: {str(e)}")
         llm_with_tools = llm

     try:
         response = llm_with_tools.invoke(messages)
-        logger.info("[bold green]Research complete.[/bold green]")
+        logger.info("[bold green]Search complete.[/bold green]")
+
+        # In a real tool-use scenario, we'd extract the tool outputs here.
+        # For now, we'll store the response content as a 'raw_result'.
+        content = response.content if hasattr(response, "content") else str(response)
+
         return {
-            "messages": [response]
+            "messages": [response],
+            "raw_results": [content],
+            "iterations": state.get("iterations", 0) + 1
         }
     except Exception as e:
-        logger.error(f"Research failed: {str(e)}")
+        logger.error(f"Search failed: {str(e)}")
         raise e
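The provider dispatch with graceful fallback can be sketched generically. The stub classes below stand in for the real LangChain clients so the control flow is runnable in isolation; only the tool payloads (`{"google_search": {}}` for Gemini, `{"type": "web_search"}` for OpenAI) are taken from the diff, and the real `bind_tools` return type differs.

```python
class ChatGoogleGenerativeAI:  # stub standing in for langchain_google_genai's client
    def bind_tools(self, tools):
        return ("gemini", tools)

class ChatOpenAI:  # stub standing in for langchain_openai's client
    def bind_tools(self, tools):
        return ("openai", tools)

def bind_search_tool(llm, warn=lambda msg: None):
    """Choose a provider-specific search tool; fall back to the bare model on failure."""
    try:
        if isinstance(llm, ChatGoogleGenerativeAI):
            return llm.bind_tools([{"google_search": {}}])
        if isinstance(llm, ChatOpenAI):
            return llm.bind_tools([{"type": "web_search"}])
        # Unknown provider: no native search tool, use the model as-is.
        return llm
    except Exception as e:
        warn(f"Failed to bind search tools: {e}")
        return llm
```

Returning the bare model on both the unknown-provider and exception paths keeps the worker functional (just without live search) rather than failing the whole checklist step.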
@@ -0,0 +1,49 @@
+from ea_chatbot.graph.workers.researcher.state import WorkerState
+from ea_chatbot.config import Settings
+from ea_chatbot.utils.llm_factory import get_llm_model
+from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
+
+def summarizer_node(state: WorkerState) -> dict:
+    """Summarize research results for the Orchestrator."""
+    task = state["task"]
+    raw_results = state.get("raw_results", [])
+
+    logger = get_logger("researcher_worker:summarizer")
+    logger.info("Summarizing research results...")
+
+    settings = Settings()
+    llm = get_llm_model(
+        settings.planner_llm,
+        callbacks=[LangChainLoggingHandler(logger=logger)]
+    )
+
+    # Ensure all results are strings (Gemini/OpenAI might return complex content)
+    processed_results = []
+    for res in raw_results:
+        if isinstance(res, list):
+            processed_results.append(str(res))
+        else:
+            processed_results.append(str(res))
+
+    results_str = "\n---\n".join(processed_results)
+
+    prompt = f"""You are a Research Specialist sub-agent. You have completed a research sub-task.
+Task: {task}
+
+Raw Research Findings:
+{results_str}
+
+Provide a concise, factual summary of the findings for the top-level Orchestrator.
+Ensure all key facts, dates, and sources (if provided) are preserved.
+Do NOT include internal reasoning, just the factual summary."""
+
+    try:
+        response = llm.invoke(prompt)
+        result = response.content if hasattr(response, "content") else str(response)
+        logger.info("[bold green]Research summary complete.[/bold green]")
+        return {
+            "result": result
+        }
+    except Exception as e:
+        logger.error(f"Failed to summarize research: {str(e)}")
+        raise e
backend/src/ea_chatbot/graph/workers/researcher/state.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+from typing import TypedDict, List, Dict, Any, Optional
+from langchain_core.messages import BaseMessage
+
+class WorkerState(TypedDict):
+    """Internal state for the Researcher worker subgraph."""
+
+    # Internal worker conversation
+    messages: List[BaseMessage]
+
+    # The specific sub-task assigned by the Orchestrator
+    task: str
+
+    # Search context
+    queries: List[str]
+    raw_results: List[str]
+
+    # Number of internal retry/refinement attempts
+    iterations: int
+
+    # The final research summary to return to the Orchestrator
+    result: Optional[str]
backend/src/ea_chatbot/graph/workers/researcher/workflow.py (new file, 25 lines)
@@ -0,0 +1,25 @@
+from langgraph.graph import StateGraph, END
+from langgraph.graph.state import CompiledStateGraph
+from ea_chatbot.graph.workers.researcher.state import WorkerState
+from ea_chatbot.graph.workers.researcher.nodes.searcher import searcher_node
+from ea_chatbot.graph.workers.researcher.nodes.summarizer import summarizer_node
+
+def create_researcher_worker(
+    searcher=searcher_node,
+    summarizer=summarizer_node
+) -> CompiledStateGraph:
+    """Create the Researcher worker subgraph."""
+    workflow = StateGraph(WorkerState)
+
+    # Add Nodes
+    workflow.add_node("searcher", searcher)
+    workflow.add_node("summarizer", summarizer)
+
+    # Set entry point
+    workflow.set_entry_point("searcher")
+
+    # Add Edges
+    workflow.add_edge("searcher", "summarizer")
+    workflow.add_edge("summarizer", END)
+
+    return workflow.compile()
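Because the researcher subgraph above is strictly linear (searcher → summarizer → END), its behavior is equivalent to piping the two node functions over a shared state dict. A hypothetical sketch of that composition, with toy node functions standing in for the real LLM-backed ones:

```python
def searcher(state: dict) -> dict:
    # Pretend to run a web search for the assigned task
    return {**state, "raw_results": [f"findings for: {state['task']}"]}

def summarizer(state: dict) -> dict:
    # Condense raw findings into the final worker result
    return {**state, "result": " | ".join(state["raw_results"])}

def run_researcher_worker(state: dict) -> dict:
    # Same edge order as the compiled graph: searcher -> summarizer
    for node in (searcher, summarizer):
        state = node(state)
    return state

final = run_researcher_worker({"task": "2024 turnout"})
print(final["result"])  # findings for: 2024 turnout
```

Using the graph abstraction anyway (rather than a plain pipe) buys per-node event streaming and the ability to add retry edges later without rewriting callers.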
@@ -2,86 +2,101 @@ from langgraph.graph import StateGraph, END
 from ea_chatbot.graph.state import AgentState
 from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node
 from ea_chatbot.graph.nodes.planner import planner_node
-from ea_chatbot.graph.nodes.coder import coder_node
-from ea_chatbot.graph.nodes.error_corrector import error_corrector_node
-from ea_chatbot.graph.nodes.executor import executor_node
-from ea_chatbot.graph.nodes.summarizer import summarizer_node
-from ea_chatbot.graph.nodes.researcher import researcher_node
+from ea_chatbot.graph.nodes.delegate import delegate_node
+from ea_chatbot.graph.nodes.reflector import reflector_node
+from ea_chatbot.graph.nodes.synthesizer import synthesizer_node
+from ea_chatbot.graph.workers.data_analyst.workflow import create_data_analyst_worker
+from ea_chatbot.graph.workers.data_analyst.mapping import prepare_worker_input, merge_worker_output
+from ea_chatbot.graph.workers.researcher.workflow import create_researcher_worker
+from ea_chatbot.graph.workers.researcher.mapping import prepare_researcher_input, merge_researcher_output
 from ea_chatbot.graph.nodes.clarification import clarification_node
 from ea_chatbot.graph.nodes.summarize_conversation import summarize_conversation_node

-MAX_ITERATIONS = 3
+# Global cache for compiled subgraphs
+_DATA_ANALYST_WORKER = create_data_analyst_worker()
+_RESEARCHER_WORKER = create_researcher_worker()

-def router(state: AgentState) -> str:
-    """Route to the next node based on the analysis."""
+# Define worker nodes as piped runnables to enable subgraph event propagation
+data_analyst_worker_runnable = prepare_worker_input | _DATA_ANALYST_WORKER | merge_worker_output
+researcher_worker_runnable = prepare_researcher_input | _RESEARCHER_WORKER | merge_researcher_output
+
+def main_router(state: AgentState) -> str:
+    """Route from query analyzer based on initial assessment."""
     next_action = state.get("next_action")
-    if next_action == "plan":
-        return "planner"
-    elif next_action == "research":
-        return "researcher"
-    elif next_action == "clarify":
+    if next_action == "clarify":
         return "clarification"
-    else:
-        return END
+    # Even if QA suggests 'research', we now go through 'planner' for orchestration
+    # aligning with the new hierarchical architecture.
+    return "planner"

-def create_workflow():
-    """Create the LangGraph workflow."""
+def delegation_router(state: AgentState) -> str:
+    """Route from delegate node to specific workers or synthesis."""
+    next_action = state.get("next_action")
+    if next_action == "data_analyst":
+        return "data_analyst_worker"
+    elif next_action == "researcher":
+        return "researcher_worker"
+    elif next_action == "summarize":
+        return "synthesizer"
+    return "synthesizer"
+
+def create_workflow(
+    query_analyzer=query_analyzer_node,
+    planner=planner_node,
+    delegate=delegate_node,
+    data_analyst_worker=data_analyst_worker_runnable,
+    researcher_worker=researcher_worker_runnable,
+    reflector=reflector_node,
+    synthesizer=synthesizer_node,
+    clarification=clarification_node,
+    summarize_conversation=summarize_conversation_node
+):
+    """Create the high-level Orchestrator workflow."""
     workflow = StateGraph(AgentState)

-    # Add nodes
-    workflow.add_node("query_analyzer", query_analyzer_node)
-    workflow.add_node("planner", planner_node)
-    workflow.add_node("coder", coder_node)
-    workflow.add_node("error_corrector", error_corrector_node)
-    workflow.add_node("researcher", researcher_node)
-    workflow.add_node("clarification", clarification_node)
-    workflow.add_node("executor", executor_node)
-    workflow.add_node("summarizer", summarizer_node)
-    workflow.add_node("summarize_conversation", summarize_conversation_node)
+    # Add Nodes
+    workflow.add_node("query_analyzer", query_analyzer)
+    workflow.add_node("planner", planner)
+    workflow.add_node("delegate", delegate)
+    workflow.add_node("data_analyst_worker", data_analyst_worker)
+    workflow.add_node("researcher_worker", researcher_worker)
+    workflow.add_node("reflector", reflector)
+    workflow.add_node("synthesizer", synthesizer)
+    workflow.add_node("clarification", clarification)
+    workflow.add_node("summarize_conversation", summarize_conversation)

     # Set entry point
     workflow.set_entry_point("query_analyzer")

-    # Add conditional edges from query_analyzer
+    # Edges
     workflow.add_conditional_edges(
         "query_analyzer",
-        router,
+        main_router,
         {
-            "planner": "planner",
-            "researcher": "researcher",
             "clarification": "clarification",
-            END: END
+            "planner": "planner"
         }
     )

-    # Linear flow for planning and coding
-    workflow.add_edge("planner", "coder")
-    workflow.add_edge("coder", "executor")
-
-    # Executor routing
-    def executor_router(state: AgentState) -> str:
-        if state.get("error"):
-            # Check for iteration limit to prevent infinite loops
-            if state.get("iterations", 0) >= MAX_ITERATIONS:
-                return "summarizer"
-            return "error_corrector"
-        return "summarizer"
+    workflow.add_edge("planner", "delegate")

     workflow.add_conditional_edges(
-        "executor",
-        executor_router,
+        "delegate",
+        delegation_router,
         {
-            "error_corrector": "error_corrector",
-            "summarizer": "summarizer"
+            "data_analyst_worker": "data_analyst_worker",
+            "researcher_worker": "researcher_worker",
+            "synthesizer": "synthesizer"
         }
     )

-    workflow.add_edge("error_corrector", "executor")
-    workflow.add_edge("researcher", "summarize_conversation")
-    workflow.add_edge("clarification", END)
-    workflow.add_edge("summarizer", "summarize_conversation")
+    workflow.add_edge("data_analyst_worker", "reflector")
+    workflow.add_edge("researcher_worker", "reflector")
+    workflow.add_edge("reflector", "delegate")
+
+    workflow.add_edge("synthesizer", "summarize_conversation")
     workflow.add_edge("summarize_conversation", END)
+    workflow.add_edge("clarification", END)

     # Compile the graph
     app = workflow.compile()
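The delegate node itself is not part of this hunk, but `delegation_router` together with the reflector → delegate edge implies a loop that walks the planner's checklist one sub-task at a time and hands off to the synthesizer once the list is exhausted. A hypothetical sketch of that control flow (`pick_next_action` is an illustrative name, not a function from the codebase):

```python
def pick_next_action(checklist: list, current_step: int) -> str:
    # Hypothetical checklist-walking logic implied by delegation_router
    if current_step >= len(checklist):
        return "summarize"  # all sub-tasks done -> route to synthesizer
    return checklist[current_step]["worker"]

checklist = [
    {"task": "Get Numbers", "worker": "data_analyst"},
    {"task": "Get Facts", "worker": "researcher"},
]
actions = [pick_next_action(checklist, step) for step in range(3)]
print(actions)  # ['data_analyst', 'researcher', 'summarize']
```

This matches the delegate mock in the e2e test later in this diff, which yields `data_analyst`, then `researcher`, then `summarize`.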
@@ -1,7 +1,15 @@
 from pydantic import BaseModel, Field, computed_field
-from typing import Sequence, Optional
+from typing import Sequence, Optional, List, Dict, Any, Literal
 import re

+class QueryAnalysis(BaseModel):
+    """Analysis of the user's query."""
+    data_required: List[str] = Field(description="List of data points or entities mentioned (e.g., ['2024 results', 'Florida']).")
+    unknowns: List[str] = Field(description="List of target information the user wants to know or needed for final answer (e.g., 'who won', 'total votes').")
+    ambiguities: List[str] = Field(description="List of CRITICAL missing details that prevent ANY analysis. Do NOT include database names or plot types if defaults can be used.")
+    conditions: List[str] = Field(description="List of any filters or constraints (e.g., ['year=2024', 'state=Florida']). Include context resolved from history.")
+    next_action: Literal["plan", "clarify", "research"] = Field(description="The next action to take. 'plan' for data analysis (even with defaults), 'research' for general knowledge, or 'clarify' ONLY for critical ambiguities.")
+
 class TaskPlanContext(BaseModel):
     '''Background context relevant to the task plan'''
     initial_context: str = Field(
@@ -33,6 +41,22 @@ class TaskPlanResponse(BaseModel):
         description="Ordered list of steps to execute that follow the 'Step <number>: <detail>' pattern.",
     )

+class ChecklistTask(BaseModel):
+    '''A specific sub-task in the high-level orchestrator plan'''
+    task: str = Field(description="Description of the sub-task")
+    worker: str = Field(description="The worker to delegate to (data_analyst or researcher)")
+
+class ChecklistResponse(BaseModel):
+    '''Orchestrator's decomposed plan/checklist'''
+    goal: str = Field(description="Overall objective")
+    reflection: str = Field(description="Strategic reasoning")
+    checklist: List[ChecklistTask] = Field(description="Ordered list of tasks for specialized workers")
+
+class ReflectorResponse(BaseModel):
+    '''Orchestrator's evaluation of worker output'''
+    satisfied: bool = Field(description="Whether the worker's output satisfies the sub-task requirements")
+    reasoning: str = Field(description="Brief explanation of the evaluation")
+
 _IM_SEP_TOKEN_PATTERN = re.compile(re.escape("<|im_sep|>"))
 _CODE_BLOCK_PATTERN = re.compile(r"```(?:python\s*)?(.*?)\s*```", re.DOTALL)
 _FORBIDDEN_MODULES = (
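The `_CODE_BLOCK_PATTERN` kept as context above strips optional python-tagged markdown fences from LLM output. A runnable illustration (the fence string is assembled from backtick characters so the example can be shown inline without breaking this page's formatting; the regex itself is the one from the schemas module):

```python
import re

FENCE = "`" * 3  # a literal triple-backtick fence, built up character by character
_CODE_BLOCK_PATTERN = re.compile(
    FENCE + r"(?:python\s*)?(.*?)\s*" + FENCE, re.DOTALL
)

reply = f"Here is the code:\n{FENCE}python\nprint('hi')\n{FENCE}\nDone."
match = _CODE_BLOCK_PATTERN.search(reply)
print(match.group(1))  # print('hi')
```

`re.DOTALL` lets `.*?` span multiple lines, and the lazy quantifier stops at the first closing fence rather than swallowing a second code block.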
backend/src/ea_chatbot/utils/vfs.py (new file, 69 lines)
@@ -0,0 +1,69 @@
+import copy
+from typing import Dict, Any, Optional, Tuple, List
+from ea_chatbot.utils.logging import get_logger
+
+logger = get_logger("utils:vfs")
+
+def safe_vfs_copy(vfs_state: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Perform a safe deep copy of the VFS state.
+
+    If an entry cannot be deep-copied (e.g., it contains a non-copyable object like a DB handle),
+    logs an error and replaces the entry with a descriptive error marker.
+    This prevents crashing the graph/persistence while making the failure explicit.
+    """
+    new_vfs = {}
+    for filename, data in vfs_state.items():
+        try:
+            # Attempt a standard deepcopy for isolation
+            new_vfs[filename] = copy.deepcopy(data)
+        except Exception as e:
+            logger.error(
+                f"CRITICAL: VFS artifact '{filename}' is NOT copyable/serializable: {str(e)}. "
+                "Replacing with error placeholder to prevent graph crash."
+            )
+            # Replace with a standardized error artifact
+            new_vfs[filename] = {
+                "content": f"<ERROR: This artifact could not be persisted or copied: {str(e)}>",
+                "metadata": {
+                    "type": "error",
+                    "error": str(e),
+                    "original_filename": filename
+                }
+            }
+    return new_vfs
+
+class VFSHelper:
+    """Helper class for managing in-memory Virtual File System (VFS) artifacts."""
+
+    def __init__(self, vfs_state: Dict[str, Any]):
+        """Initialize with a reference to the VFS state from AgentState."""
+        self._vfs = vfs_state
+
+    def write(self, filename: str, content: Any, metadata: Optional[Dict[str, Any]] = None) -> None:
+        """Write a file to the VFS."""
+        self._vfs[filename] = {
+            "content": content,
+            "metadata": metadata or {}
+        }
+
+    def read(self, filename: str) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]:
+        """Read a file and its metadata from the VFS. Returns (None, None) if not found."""
+        file_data = self._vfs.get(filename)
+        if file_data is not None:
+            # Handle raw values (backwards compatibility or inconsistent schema)
+            if not isinstance(file_data, dict) or "content" not in file_data:
+                return file_data, {}
+            return file_data["content"], file_data.get("metadata", {})
+        return None, None
+
+    def list(self) -> List[str]:
+        """List all filenames in the VFS."""
+        return list(self._vfs.keys())
+
+    def delete(self, filename: str) -> bool:
+        """Delete a file from the VFS. Returns True if deleted, False if not found."""
+        if filename in self._vfs:
+            del self._vfs[filename]
+            return True
+        return False
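The failure mode `safe_vfs_copy` guards against can be reproduced with any object `copy.deepcopy` rejects, such as a thread lock. A condensed, self-contained version of the same pattern (`safe_copy` is a minimal stand-in for the module's function, not the module itself):

```python
import copy
import threading

def safe_copy(vfs: dict) -> dict:
    # Deep-copy each artifact; replace non-copyable ones with an error marker
    out = {}
    for name, data in vfs.items():
        try:
            out[name] = copy.deepcopy(data)
        except Exception as e:
            out[name] = {
                "content": f"<ERROR: {e}>",
                "metadata": {"type": "error", "original_filename": name},
            }
    return out

# threading.Lock() cannot be deep-copied, much like a live DB handle
vfs = {"table.csv": "a,b\n1,2", "db_handle": threading.Lock()}
copied = safe_copy(vfs)
```

Copyable entries survive intact, while the lock entry is swapped for an explicit error artifact instead of raising mid-checkpoint.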
@@ -36,24 +36,25 @@ async def test_stream_agent_events_all_features():
         # Stream chunk
         {
             "event": "on_chat_model_stream",
-            "metadata": {"langgraph_node": "summarizer"},
+            "metadata": {"langgraph_node": "synthesizer"},
             "data": {"chunk": AIMessage(content="Hello ")}
         },
         {
             "event": "on_chat_model_stream",
-            "metadata": {"langgraph_node": "summarizer"},
+            "metadata": {"langgraph_node": "synthesizer"},
             "data": {"chunk": AIMessage(content="world")}
         },
-        # Plot event
+        # Plot event - with nested subgraph it might bubble up or come directly from data_analyst_worker
+        # Let's mock it coming from the data_analyst_worker on_chain_end
         {
             "event": "on_chain_end",
-            "name": "executor",
+            "name": "data_analyst_worker",
             "data": {"output": {"plots": [fig]}}
         },
         # Final response
         {
             "event": "on_chain_end",
-            "name": "summarizer",
+            "name": "synthesizer",
             "data": {"output": {"messages": [AIMessage(content="Hello world final")]}}
         },
         # Summary update
@@ -91,7 +92,7 @@ async def test_stream_agent_events_all_features():
     assert any(r.get("type") == "on_chat_model_stream" for r in results)

     # Verify plot was encoded
-    plot_event = next(r for r in results if r.get("name") == "executor")
+    plot_event = next(r for r in results if r.get("name") == "data_analyst_worker")
     assert "encoded_plots" in plot_event["data"]
     assert len(plot_event["data"]["encoded_plots"]) == 1
@@ -123,3 +123,20 @@ def test_get_me_success(client):
     assert response.status_code == 200
     assert response.json()["email"] == "test@example.com"
     assert response.json()["id"] == "123"
+
+def test_get_me_rejects_refresh_token(client):
+    """Test that /auth/me rejects refresh tokens for authentication."""
+    from ea_chatbot.api.utils import create_refresh_token
+    token = create_refresh_token(data={"sub": "123"})
+
+    with patch("ea_chatbot.api.dependencies.history_manager") as mock_hm:
+        # Even if the user exists, the dependency should reject the token type
+        mock_hm.get_user_by_id.return_value = User(id="123", username="test@example.com")
+
+        response = client.get(
+            "/api/v1/auth/me",
+            headers={"Authorization": f"Bearer {token}"}
+        )
+
+        assert response.status_code == 401
+        assert "Cannot use refresh token" in response.json()["detail"]
@@ -19,14 +19,13 @@ def auth_header(mock_user):
     yield {"Authorization": f"Bearer {token}"}
     app.dependency_overrides.clear()

 def test_persistence_integration_success(auth_header, mock_user):
     """Test that messages and plots are persisted correctly during streaming."""
     mock_events = [
-        {"event": "on_chat_model_stream", "name": "summarizer", "data": {"chunk": "Final answer"}},
+        {"event": "on_chat_model_stream", "metadata": {"langgraph_node": "synthesizer"}, "data": {"chunk": "Final answer"}},
-        {"event": "on_chain_end", "name": "summarizer", "data": {"output": {"messages": [{"content": "Final answer"}]}}},
+        {"event": "on_chain_end", "name": "synthesizer", "data": {"output": {"messages": [{"content": "Final answer"}]}}},
         {"event": "on_chain_end", "name": "summarize_conversation", "data": {"output": {"summary": "New summary"}}}
     ]

     async def mock_astream_events(*args, **kwargs):
         for event in mock_events:
             yield event
@@ -1,62 +0,0 @@
-import pytest
-from unittest.mock import MagicMock, patch
-from ea_chatbot.graph.nodes.coder import coder_node
-from ea_chatbot.graph.nodes.error_corrector import error_corrector_node
-
-@pytest.fixture
-def mock_state():
-    return {
-        "messages": [],
-        "question": "Show me results for New Jersey",
-        "plan": "Step 1: Load data\nStep 2: Filter by NJ",
-        "code": None,
-        "error": None,
-        "plots": [],
-        "dfs": {},
-        "next_action": "plan"
-    }
-
-@patch("ea_chatbot.graph.nodes.coder.get_llm_model")
-@patch("ea_chatbot.utils.database_inspection.get_data_summary")
-def test_coder_node(mock_get_summary, mock_get_llm, mock_state):
-    """Test coder node generates code from plan."""
-    mock_get_summary.return_value = "Column: Name, Type: text"
-
-    mock_llm = MagicMock()
-    mock_get_llm.return_value = mock_llm
-
-    from ea_chatbot.schemas import CodeGenerationResponse
-    mock_response = CodeGenerationResponse(
-        code="import pandas as pd\nprint('Hello')",
-        explanation="Generated code"
-    )
-    mock_llm.with_structured_output.return_value.invoke.return_value = mock_response
-
-    result = coder_node(mock_state)
-
-    assert "code" in result
-    assert "import pandas as pd" in result["code"]
-    assert "error" in result
-    assert result["error"] is None
-
-@patch("ea_chatbot.graph.nodes.error_corrector.get_llm_model")
-def test_error_corrector_node(mock_get_llm, mock_state):
-    """Test error corrector node fixes code."""
-    mock_state["code"] = "import pandas as pd\nprint(undefined_var)"
-    mock_state["error"] = "NameError: name 'undefined_var' is not defined"
-
-    mock_llm = MagicMock()
-    mock_get_llm.return_value = mock_llm
-
-    from ea_chatbot.schemas import CodeGenerationResponse
-    mock_response = CodeGenerationResponse(
-        code="import pandas as pd\nprint('Defined')",
-        explanation="Fixed variable"
-    )
-    mock_llm.with_structured_output.return_value.invoke.return_value = mock_response
-
-    result = error_corrector_node(mock_state)
-
-    assert "code" in result
-    assert "print('Defined')" in result["code"]
-    assert result["error"] is None
backend/tests/test_data_analyst_mapping.py (new file, 53 lines)
@@ -0,0 +1,53 @@
+from langchain_core.messages import HumanMessage, AIMessage
+from ea_chatbot.graph.state import AgentState
+from ea_chatbot.graph.workers.data_analyst.mapping import (
+    prepare_worker_input,
+    merge_worker_output
+)
+from ea_chatbot.graph.workers.data_analyst.state import WorkerState
+
+def test_prepare_worker_input():
+    """Verify that we correctly map global state to worker input."""
+    global_state = AgentState(
+        messages=[HumanMessage(content="global message")],
+        question="original question",
+        checklist=[{"task": "Worker Task", "status": "pending"}],
+        current_step=0,
+        vfs={"old.txt": "old data"},
+        plots=[],
+        dfs={},
+        next_action="test",
+        iterations=0
+    )
+
+    worker_input = prepare_worker_input(global_state)
+
+    assert worker_input["task"] == "Worker Task"
+    assert "old.txt" in worker_input["vfs_state"]
+    # Internal worker messages should start fresh or with the task
+    assert len(worker_input["messages"]) == 1
+    assert worker_input["messages"][0].content == "Worker Task"
+
+def test_merge_worker_output():
+    """Verify that we correctly merge worker results back to global state."""
+    worker_state = WorkerState(
+        messages=[HumanMessage(content="internal"), AIMessage(content="summary")],
+        task="Worker Task",
+        result="Finished analysis",
+        plots=["plot1"],
+        vfs_state={"new.txt": "new data"},
+        iterations=2
+    )
+
+    updates = merge_worker_output(worker_state)
+
+    # We expect the 'result' to be added as an AI message to global history
+    assert len(updates["messages"]) == 1
+    assert updates["messages"][0].content == "Finished analysis"
+
+    # VFS should be updated
+    assert "new.txt" in updates["vfs"]
+
+    # Plots should be bubbled up
+    assert len(updates["plots"]) == 1
+    assert updates["plots"][0] == "plot1"
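The mapping module these tests exercise is not itself shown in this diff. A hypothetical minimal implementation consistent with the assertions above (plain strings stand in for LangChain message objects; names and shapes are assumptions, not the real module):

```python
def prepare_worker_input(state: dict) -> dict:
    # Hypothetical sketch: project global AgentState onto the worker's input
    task = state["checklist"][state["current_step"]]["task"]
    return {
        "task": task,
        "messages": [task],               # worker conversation starts from the task
        "vfs_state": dict(state["vfs"]),  # hand the worker the current VFS
    }

def merge_worker_output(worker: dict) -> dict:
    # Hypothetical sketch: surface only the worker's summary, VFS, and plots
    return {
        "messages": [worker["result"]],
        "vfs": worker["vfs_state"],
        "plots": worker["plots"],
    }

state = {
    "checklist": [{"task": "Worker Task", "status": "pending"}],
    "current_step": 0,
    "vfs": {"old.txt": "old data"},
}
worker_input = prepare_worker_input(state)
```

The isolation idea is the point: the worker never sees the global conversation, and the orchestrator never sees the worker's internal messages, only the merged summary.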
backend/tests/test_data_analyst_state.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+from typing import get_type_hints, List
+from langchain_core.messages import BaseMessage
+from ea_chatbot.graph.workers.data_analyst.state import WorkerState
+
+def test_data_analyst_worker_state():
+    """Verify that DataAnalyst WorkerState has the required fields."""
+    hints = get_type_hints(WorkerState)
+
+    assert "messages" in hints
+    assert "task" in hints
+    assert hints["task"] == str
+
+    assert "code" in hints
+    assert "output" in hints
+    assert "error" in hints
+    assert "iterations" in hints
+    assert hints["iterations"] == int
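Checking a TypedDict's schema via `get_type_hints`, as the test above does, works on any TypedDict and is a cheap way to lock down a state contract without importing the rest of the app. A self-contained example:

```python
from typing import TypedDict, get_type_hints

class MiniState(TypedDict):
    # A toy state schema, analogous in spirit to the worker states in this diff
    task: str
    iterations: int

hints = get_type_hints(MiniState)
print(hints["task"] is str, hints["iterations"] is int)  # True True
```

Because `get_type_hints` resolves string annotations too, this style of test keeps working even if the state module adopts `from __future__ import annotations`.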
81
backend/tests/test_data_analyst_worker.py
Normal file
81
backend/tests/test_data_analyst_worker.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from ea_chatbot.graph.workers.data_analyst.workflow import create_data_analyst_worker, WorkerState
|
||||||
|
|
||||||
|
def test_data_analyst_worker_one_shot():
|
||||||
|
"""Verify a successful one-shot execution of the worker subgraph."""
|
||||||
|
mock_coder = MagicMock()
|
||||||
|
mock_executor = MagicMock()
|
||||||
|
mock_summarizer = MagicMock()
|
||||||
|
|
||||||
|
# Scenario: Coder -> Executor (Success) -> Summarizer -> END
|
||||||
|
mock_coder.return_value = {"code": "print(1)", "error": None, "iterations": 1}
|
||||||
|
mock_executor.return_value = {"output": "1\n", "error": None, "plots": []}
|
||||||
|
mock_summarizer.return_value = {"result": "Result is 1"}
|
||||||
|
|
||||||
|
    graph = create_data_analyst_worker(
        coder=mock_coder,
        executor=mock_executor,
        summarizer=mock_summarizer
    )

    initial_state = WorkerState(
        messages=[],
        task="Calculate 1+1",
        code=None,
        output=None,
        error=None,
        iterations=0,
        plots=[],
        vfs_state={},
        result=None
    )

    final_state = graph.invoke(initial_state)

    assert final_state["result"] == "Result is 1"
    assert mock_coder.call_count == 1
    assert mock_executor.call_count == 1
    assert mock_summarizer.call_count == 1


def test_data_analyst_worker_retry():
    """Verify that the worker retries on error."""
    mock_coder = MagicMock()
    mock_executor = MagicMock()
    mock_summarizer = MagicMock()

    # Scenario: Coder (1) -> Executor (Error) -> Router (coder) -> Coder (2) -> Executor (Success) -> Summarizer -> END
    mock_coder.side_effect = [
        {"code": "error_code", "error": None, "iterations": 1},
        {"code": "fixed_code", "error": None, "iterations": 2}
    ]
    mock_executor.side_effect = [
        {"output": "", "error": "NameError", "plots": []},
        {"output": "Success", "error": None, "plots": []}
    ]
    mock_summarizer.return_value = {"result": "Fixed Result"}

    graph = create_data_analyst_worker(
        coder=mock_coder,
        executor=mock_executor,
        summarizer=mock_summarizer
    )

    initial_state = WorkerState(
        messages=[],
        task="Retry Task",
        code=None,
        output=None,
        error=None,
        iterations=0,
        plots=[],
        vfs_state={},
        result=None
    )

    final_state = graph.invoke(initial_state)

    assert final_state["result"] == "Fixed Result"
    assert mock_coder.call_count == 2
    assert mock_executor.call_count == 2
    assert mock_summarizer.call_count == 1
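The retry scenario scripted by `mock_executor.side_effect` implies a conditional edge after the executor: loop back to the coder while the run errored and the retry budget remains, otherwise hand off to the summarizer. A minimal sketch of such a router, with hypothetical names and an assumed retry limit (the real routing lives inside `create_data_analyst_worker`):

```python
MAX_ITERATIONS = 3  # assumed retry budget; the real limit lives in the worker graph


def route_after_execution(state: dict) -> str:
    """Decide the next node once the executor has run.

    Returns "coder" to retry when the execution errored and the retry
    budget remains, otherwise "summarizer" to finish the task.
    """
    if state.get("error") and state.get("iterations", 0) < MAX_ITERATIONS:
        return "coder"
    return "summarizer"


# First attempt failed with a NameError -> retry via the coder.
assert route_after_execution({"error": "NameError", "iterations": 1}) == "coder"
# Second attempt succeeded -> hand off to the summarizer.
assert route_after_execution({"error": None, "iterations": 2}) == "summarizer"
```

This is what the `mock_coder.call_count == 2` assertion above exercises: one failed execution forces exactly one extra trip through the coder.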
backend/tests/test_deepagents_e2e.py (new file, 79 lines)
@@ -0,0 +1,79 @@
import pytest
from unittest.mock import MagicMock
from ea_chatbot.graph.workflow import create_workflow
from ea_chatbot.graph.state import AgentState
from langchain_core.messages import AIMessage, HumanMessage


def test_deepagents_multi_worker_sequential_flow():
    """Verify that the Orchestrator can handle a sequence of different workers."""

    mock_analyzer = MagicMock()
    mock_planner = MagicMock()
    mock_delegate = MagicMock()
    mock_analyst = MagicMock()
    mock_researcher = MagicMock()
    mock_reflector = MagicMock()
    mock_synthesizer = MagicMock()

    # 1. Analyzer: Plan
    mock_analyzer.return_value = {"next_action": "plan"}

    # 2. Planner: Two tasks
    mock_planner.return_value = {
        "checklist": [
            {"task": "Get Numbers", "worker": "data_analyst"},
            {"task": "Get Facts", "worker": "researcher"}
        ],
        "current_step": 0
    }

    # 3. Delegate & Reflector: Loop through tasks
    mock_delegate.side_effect = [
        {"next_action": "data_analyst"},  # Step 0
        {"next_action": "researcher"},    # Step 1
        {"next_action": "summarize"}      # Step 2 (done)
    ]

    mock_analyst.return_value = {"messages": [AIMessage(content="Analyst summary")], "vfs": {"data.csv": "..."}}
    mock_researcher.return_value = {"messages": [AIMessage(content="Researcher summary")]}

    mock_reflector.side_effect = [
        {"current_step": 1, "next_action": "delegate"},  # Done with Analyst
        {"current_step": 2, "next_action": "delegate"}   # Done with Researcher
    ]

    mock_synthesizer.return_value = {
        "messages": [AIMessage(content="Final multi-agent response")],
        "next_action": "end"
    }

    app = create_workflow(
        query_analyzer=mock_analyzer,
        planner=mock_planner,
        delegate=mock_delegate,
        data_analyst_worker=mock_analyst,
        researcher_worker=mock_researcher,
        reflector=mock_reflector,
        synthesizer=mock_synthesizer
    )

    initial_state = AgentState(
        messages=[HumanMessage(content="Numbers and facts please")],
        question="Numbers and facts please",
        analysis={},
        next_action="",
        iterations=0,
        checklist=[],
        current_step=0,
        vfs={},
        plots=[],
        dfs={}
    )

    final_state = app.invoke(initial_state, config={"recursion_limit": 30})

    assert mock_analyst.called
    assert mock_researcher.called
    assert mock_reflector.call_count == 2
    assert "Final multi-agent response" in [m.content for m in final_state["messages"]]
    assert final_state["current_step"] == 2
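The mocks above script a delegate -> worker -> reflector cycle that repeats once per checklist item before the synthesizer runs. A dependency-free rendering of that loop may make the control flow easier to follow; this is a hypothetical simplification, not the real implementation, which is a LangGraph `StateGraph` built by `create_workflow`:

```python
def run_checklist(checklist, workers, reflect):
    """Drive a planner checklist through worker delegation.

    Simplified sketch of the delegate -> worker -> reflector cycle:
    the delegate picks the worker named by the current task, the worker
    appends its result, and the reflector decides whether to advance.
    """
    state = {"checklist": checklist, "current_step": 0, "messages": []}
    while state["current_step"] < len(state["checklist"]):
        task = state["checklist"][state["current_step"]]
        result = workers[task["worker"]](task["task"])  # delegate to the named worker
        state["messages"].append(result)
        state = reflect(state)  # reflector decides whether to advance
    return state


workers = {
    "data_analyst": lambda t: f"analyst: {t}",
    "researcher": lambda t: f"researcher: {t}",
}


def reflect(state):
    # Always satisfied: advance to the next checklist item.
    return {**state, "current_step": state["current_step"] + 1}


final = run_checklist(
    [{"task": "Get Numbers", "worker": "data_analyst"},
     {"task": "Get Facts", "worker": "researcher"}],
    workers,
    reflect,
)
assert final["messages"] == ["analyst: Get Numbers", "researcher: Get Facts"]
```

The test's `mock_reflector.call_count == 2` assertion corresponds to the two passes through `reflect` here, one per checklist item.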
(deleted file, 122 lines)
@@ -1,122 +0,0 @@
import pytest
import pandas as pd
from unittest.mock import MagicMock, patch
from matplotlib.figure import Figure
from ea_chatbot.graph.nodes.executor import executor_node


@pytest.fixture
def mock_settings():
    with patch("ea_chatbot.graph.nodes.executor.Settings") as MockSettings:
        mock_settings_instance = MagicMock()
        mock_settings_instance.db_host = "localhost"
        mock_settings_instance.db_port = 5432
        mock_settings_instance.db_user = "user"
        mock_settings_instance.db_pswd = "pass"
        mock_settings_instance.db_name = "test_db"
        mock_settings_instance.db_table = "test_table"
        MockSettings.return_value = mock_settings_instance
        yield mock_settings_instance


@pytest.fixture
def mock_db_client():
    with patch("ea_chatbot.graph.nodes.executor.DBClient") as MockDBClient:
        mock_client_instance = MagicMock()
        MockDBClient.return_value = mock_client_instance
        yield mock_client_instance


def test_executor_node_success_simple_print(mock_settings, mock_db_client):
    """Test executing simple code that prints to stdout."""
    state = {
        "code": "print('Hello, World!')",
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert "code_output" in result
    assert "Hello, World!" in result["code_output"]
    assert result["error"] is None
    assert result["plots"] == []
    assert result["dfs"] == {}


def test_executor_node_success_dataframe(mock_settings, mock_db_client):
    """Test executing code that creates a DataFrame."""
    code = """
import pandas as pd
df = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
print(df)
"""
    state = {
        "code": code,
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert "code_output" in result
    assert "a b" in result["code_output"]  # Check part of DF string representation
    assert "dfs" in result
    assert "df" in result["dfs"]
    assert isinstance(result["dfs"]["df"], pd.DataFrame)


def test_executor_node_success_plot(mock_settings, mock_db_client):
    """Test executing code that generates a plot."""
    code = """
import matplotlib.pyplot as plt
fig = plt.figure()
plots.append(fig)
print('Plot generated')
"""
    state = {
        "code": code,
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert "Plot generated" in result["code_output"]
    assert "plots" in result
    assert len(result["plots"]) == 1
    assert isinstance(result["plots"][0], Figure)


def test_executor_node_error_syntax(mock_settings, mock_db_client):
    """Test executing code with a syntax error."""
    state = {
        "code": "print('Hello World",  # Missing closing quote
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert result["error"] is not None
    assert "SyntaxError" in result["error"]


def test_executor_node_error_runtime(mock_settings, mock_db_client):
    """Test executing code with a runtime error."""
    state = {
        "code": "print(1 / 0)",
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert result["error"] is not None
    assert "ZeroDivisionError" in result["error"]


def test_executor_node_no_code(mock_settings, mock_db_client):
    """Test handling when no code is provided."""
    state = {
        "code": None,
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert "error" in result
    assert "No code provided" in result["error"]
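The deleted tests pin down the executor node's contract: run a snippet, capture stdout into `code_output`, expose a `plots` list in the execution namespace, report errors as strings like "SyntaxError: ...", and refuse empty code. A stdlib-only sketch of that core, assuming the real `executor_node` follows this shape (the hypothetical `run_snippet` name and the exact return keys are illustrative):

```python
import io
from contextlib import redirect_stdout


def run_snippet(code):
    """Minimal sketch of an executor node's core: exec() the snippet,
    capture stdout, expose a `plots` list in the namespace, and report
    errors as "<ExceptionType>: <message>" strings.
    """
    if not code:
        return {"code_output": "", "error": "No code provided", "plots": []}
    namespace = {"plots": []}
    buffer = io.StringIO()
    try:
        with redirect_stdout(buffer):
            exec(code, namespace)  # real implementations must sandbox this
    except Exception as exc:
        return {"code_output": buffer.getvalue(),
                "error": f"{type(exc).__name__}: {exc}",
                "plots": namespace["plots"]}
    return {"code_output": buffer.getvalue(), "error": None, "plots": namespace["plots"]}


assert run_snippet("print('hi')")["code_output"] == "hi\n"
assert run_snippet("print(1/0)")["error"].startswith("ZeroDivisionError")
assert run_snippet("print('oops")["error"].startswith("SyntaxError")
```

Note that `exec` raises `SyntaxError` at compile time, so the same `except Exception` branch covers both the syntax-error and runtime-error test cases above.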
@@ -2,7 +2,7 @@ import json
 import pytest
 import logging
 from unittest.mock import patch
-from ea_chatbot.graph.workflow import app
+from ea_chatbot.graph.workflow import create_workflow
 from ea_chatbot.graph.state import AgentState
 from ea_chatbot.utils.logging import get_logger
 from langchain_community.chat_models import FakeListChatModel
@@ -43,10 +43,10 @@ def test_logging_e2e_json_output(tmp_path):
         "question": "Who won in 2024?",
         "analysis": None,
         "next_action": "",
-        "plan": None,
-        "code": None,
-        "code_output": None,
-        "error": None,
+        "iterations": 0,
+        "checklist": [],
+        "current_step": 0,
+        "vfs": {},
         "plots": [],
         "dfs": {}
     }
@@ -57,6 +57,20 @@ def test_logging_e2e_json_output(tmp_path):

     fake_clarify = FakeListChatModel(responses=["Please specify."])

+    # Create a test app without interrupts
+    # We need to manually compile it here to avoid the global 'app' which has interrupts
+    from langgraph.graph import StateGraph, END
+    from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node
+    from ea_chatbot.graph.nodes.clarification import clarification_node
+
+    workflow = StateGraph(AgentState)
+    workflow.add_node("query_analyzer", query_analyzer_node)
+    workflow.add_node("clarification", clarification_node)
+    workflow.set_entry_point("query_analyzer")
+    workflow.add_edge("query_analyzer", "clarification")
+    workflow.add_edge("clarification", END)
+    test_app = workflow.compile()
+
     with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_llm_factory:
         mock_llm_factory.return_value = fake_analyzer

@@ -64,7 +78,7 @@ def test_logging_e2e_json_output(tmp_path):
             mock_clarify_llm_factory.return_value = fake_clarify

             # Run the graph
-            list(app.stream(initial_state))
+            test_app.invoke(initial_state)

             # Verify file content
             assert log_file.exists()
@@ -1,24 +1,27 @@
 import pytest
 from unittest.mock import MagicMock, patch
-from langchain_core.messages import HumanMessage, AIMessage
 from ea_chatbot.graph.nodes.planner import planner_node
-from ea_chatbot.graph.nodes.researcher import researcher_node
-from ea_chatbot.graph.nodes.summarizer import summarizer_node
-from ea_chatbot.schemas import TaskPlanResponse
+from ea_chatbot.schemas import ChecklistResponse, ChecklistTask
+from langchain_core.messages import HumanMessage, AIMessage


 @pytest.fixture
 def mock_state_with_history():
     return {
         "messages": [
-            HumanMessage(content="Show me the 2024 results for Florida"),
-            AIMessage(content="Here are the results for Florida in 2024...")
+            HumanMessage(content="What about NJ?"),
+            AIMessage(content="NJ has 9 million voters.")
         ],
-        "question": "What about in New Jersey?",
-        "analysis": {"data_required": ["2024 results", "New Jersey"], "unknowns": [], "ambiguities": [], "conditions": []},
+        "question": "Show me the breakdown by county for 2024",
+        "analysis": {
+            "data_required": ["2024 results", "New Jersey"],
+            "unknowns": [],
+            "ambiguities": [],
+            "conditions": []
+        },
         "next_action": "plan",
-        "summary": "The user is asking about 2024 election results.",
-        "plan": "Plan steps...",
-        "code_output": "Code output..."
+        "summary": "The user is asking about NJ 2024 results.",
+        "checklist": [],
+        "current_step": 0
     }


 @patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@@ -31,49 +34,17 @@ def test_planner_uses_history_and_summary(mock_prompt, mock_get_summary, mock_ge
     mock_structured_llm = MagicMock()
     mock_llm_instance.with_structured_output.return_value = mock_structured_llm

-    mock_structured_llm.invoke.return_value = TaskPlanResponse(
+    mock_structured_llm.invoke.return_value = ChecklistResponse(
         goal="goal",
         reflection="reflection",
-        context={
-            "initial_context": "context",
-            "assumptions": [],
-            "constraints": []
-        },
-        steps=["Step 1: test"]
+        checklist=[ChecklistTask(task="Step 1: test", worker="data_analyst")]
     )

     planner_node(mock_state_with_history)

-    mock_prompt.format_messages.assert_called_once()
-    kwargs = mock_prompt.format_messages.call_args[1]
-    assert kwargs["question"] == "What about in New Jersey?"
-    assert kwargs["summary"] == mock_state_with_history["summary"]
-    assert len(kwargs["history"]) == 2
-
-
-@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
-@patch("ea_chatbot.graph.nodes.researcher.RESEARCHER_PROMPT")
-def test_researcher_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
-    mock_llm_instance = MagicMock()
-    mock_get_llm.return_value = mock_llm_instance
-
-    researcher_node(mock_state_with_history)
-
-    mock_prompt.format_messages.assert_called_once()
-    kwargs = mock_prompt.format_messages.call_args[1]
-    assert kwargs["question"] == "What about in New Jersey?"
-    assert kwargs["summary"] == mock_state_with_history["summary"]
-    assert len(kwargs["history"]) == 2
-
-
-@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model")
-@patch("ea_chatbot.graph.nodes.summarizer.SUMMARIZER_PROMPT")
-def test_summarizer_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
-    mock_llm_instance = MagicMock()
-    mock_get_llm.return_value = mock_llm_instance
-
-    summarizer_node(mock_state_with_history)
-
-    mock_prompt.format_messages.assert_called_once()
-    kwargs = mock_prompt.format_messages.call_args[1]
-    assert kwargs["question"] == "What about in New Jersey?"
-    assert kwargs["summary"] == mock_state_with_history["summary"]
-    assert len(kwargs["history"]) == 2
+    # Verify history and summary were passed to prompt format
+    # We check the arguments passed to the mock_prompt.format_messages
+    call_args = mock_prompt.format_messages.call_args[1]
+    assert call_args["summary"] == "The user is asking about NJ 2024 results."
+    assert len(call_args["history"]) == 2
+    assert "breakdown by county" in call_args["question"]
|||||||
126
backend/tests/test_node_model_selection.py
Normal file
126
backend/tests/test_node_model_selection.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node
from ea_chatbot.graph.nodes.planner import planner_node
from ea_chatbot.graph.nodes.reflector import reflector_node
from ea_chatbot.graph.nodes.synthesizer import synthesizer_node
from ea_chatbot.graph.workers.data_analyst.nodes.coder import coder_node as analyst_coder
from ea_chatbot.graph.workers.researcher.nodes.searcher import searcher_node as researcher_searcher
from ea_chatbot.graph.state import AgentState
from ea_chatbot.graph.workers.data_analyst.state import WorkerState as AnalystState
from ea_chatbot.graph.workers.researcher.state import WorkerState as ResearcherState
from ea_chatbot.config import Settings


@pytest.fixture
def mock_settings():
    settings = Settings()
    settings.query_analyzer_llm.model = "model-qa"
    settings.planner_llm.model = "model-planner"
    settings.reflector_llm.model = "model-reflector"
    settings.synthesizer_llm.model = "model-synthesizer"
    settings.coder_llm.model = "model-coder"
    settings.researcher_llm.model = "model-researcher"
    return settings


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.query_analyzer.Settings")
def test_query_analyzer_model(mock_settings_cls, mock_get_llm, mock_settings):
    mock_settings_cls.return_value = mock_settings
    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm

    state = AgentState(question="test", messages=[], checklist=[], current_step=0, iterations=0, vfs={}, plots=[], dfs={}, next_action="", analysis={})

    try:
        query_analyzer_node(state)
    except Exception:
        pass  # We only care about the call

    # Verify it was called with the QA config
    args = mock_get_llm.call_args[0][0]
    assert args.model == "model-qa"


@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.graph.nodes.planner.Settings")
def test_planner_model(mock_settings_cls, mock_get_llm, mock_settings):
    mock_settings_cls.return_value = mock_settings
    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm

    state = AgentState(question="test", messages=[], checklist=[], current_step=0, iterations=0, vfs={}, plots=[], dfs={}, next_action="", analysis={})

    try:
        planner_node(state)
    except Exception:
        pass

    args = mock_get_llm.call_args[0][0]
    assert args.model == "model-planner"


@patch("ea_chatbot.graph.nodes.reflector.get_llm_model")
@patch("ea_chatbot.graph.nodes.reflector.Settings")
def test_reflector_model(mock_settings_cls, mock_get_llm, mock_settings):
    mock_settings_cls.return_value = mock_settings
    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm

    state = AgentState(question="test", messages=[], checklist=[{"task": "T1"}], current_step=0, iterations=0, vfs={}, plots=[], dfs={}, next_action="", analysis={}, summary="Worker did X")

    try:
        reflector_node(state)
    except Exception:
        pass

    args = mock_get_llm.call_args[0][0]
    assert args.model == "model-reflector"


@patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model")
@patch("ea_chatbot.graph.nodes.synthesizer.Settings")
def test_synthesizer_model(mock_settings_cls, mock_get_llm, mock_settings):
    mock_settings_cls.return_value = mock_settings
    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm

    state = AgentState(question="test", messages=[], checklist=[], current_step=0, iterations=0, vfs={}, plots=[], dfs={}, next_action="", analysis={})

    try:
        synthesizer_node(state)
    except Exception:
        pass

    args = mock_get_llm.call_args[0][0]
    assert args.model == "model-synthesizer"


@patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model")
@patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.Settings")
def test_analyst_coder_model(mock_settings_cls, mock_get_llm, mock_settings):
    mock_settings_cls.return_value = mock_settings
    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm

    state = AnalystState(task="test", messages=[], iterations=0, vfs_state={}, plots=[], result=None, code=None, output=None, error=None)

    try:
        analyst_coder(state)
    except Exception:
        pass

    args = mock_get_llm.call_args[0][0]
    assert args.model == "model-coder"


@patch("ea_chatbot.graph.workers.researcher.nodes.searcher.get_llm_model")
@patch("ea_chatbot.graph.workers.researcher.nodes.searcher.Settings")
def test_researcher_searcher_model(mock_settings_cls, mock_get_llm, mock_settings):
    mock_settings_cls.return_value = mock_settings
    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm

    state = ResearcherState(task="test", messages=[], iterations=0, result=None, queries=[], raw_results=[])

    try:
        researcher_searcher(state)
    except Exception:
        pass

    args = mock_get_llm.call_args[0][0]
    assert args.model == "model-researcher"
backend/tests/test_orchestrator_delegate.py (new file, 56 lines)
@@ -0,0 +1,56 @@
from ea_chatbot.graph.nodes.delegate import delegate_node
from ea_chatbot.graph.state import AgentState


def test_delegate_node_data_analyst():
    """Verify that the delegate node routes to data_analyst."""
    state = AgentState(
        checklist=[{"task": "Analyze data", "worker": "data_analyst"}],
        current_step=0,
        messages=[],
        question="test",
        analysis={},
        next_action="",
        iterations=0,
        vfs={},
        plots=[],
        dfs={}
    )

    result = delegate_node(state)
    assert result["next_action"] == "data_analyst"


def test_delegate_node_researcher():
    """Verify that the delegate node routes to researcher."""
    state = AgentState(
        checklist=[{"task": "Search web", "worker": "researcher"}],
        current_step=0,
        messages=[],
        question="test",
        analysis={},
        next_action="",
        iterations=0,
        vfs={},
        plots=[],
        dfs={}
    )

    result = delegate_node(state)
    assert result["next_action"] == "researcher"


def test_delegate_node_finished():
    """Verify that the delegate node routes to summarize if checklist is complete."""
    state = AgentState(
        checklist=[{"task": "Task 1", "worker": "data_analyst"}],
        current_step=1,  # Already finished the only task
        messages=[],
        question="test",
        analysis={},
        next_action="",
        iterations=0,
        vfs={},
        plots=[],
        dfs={}
    )

    result = delegate_node(state)
    assert result["next_action"] == "summarize"
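Taken together, the three delegate tests specify a small routing function: if every checklist item is done, route to "summarize"; otherwise route to the worker named by the current task. A minimal sketch of logic that would satisfy them (the `delegate_sketch` name is hypothetical; the real `delegate_node` operates on the full `AgentState` and may do more):

```python
def delegate_sketch(state: dict) -> dict:
    """Route to the current task's worker, or to the synthesizer when done."""
    checklist = state["checklist"]
    step = state["current_step"]
    if step >= len(checklist):  # every task done -> hand off for summarization
        return {"next_action": "summarize"}
    return {"next_action": checklist[step]["worker"]}  # e.g. "data_analyst" or "researcher"


assert delegate_sketch({"checklist": [{"task": "Analyze data", "worker": "data_analyst"}],
                        "current_step": 0})["next_action"] == "data_analyst"
assert delegate_sketch({"checklist": [{"task": "Task 1", "worker": "data_analyst"}],
                        "current_step": 1})["next_action"] == "summarize"
```

Keeping the routing purely a function of `checklist` and `current_step` is what makes these tests deterministic with no LLM mock at all.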
backend/tests/test_orchestrator_loop.py (new file, 136 lines)
@@ -0,0 +1,136 @@
import pytest
from unittest.mock import MagicMock
from ea_chatbot.graph.workflow import create_workflow
from ea_chatbot.graph.state import AgentState
from langchain_core.messages import AIMessage, HumanMessage


def test_orchestrator_full_flow():
    """Verify the full Orchestrator-Workers flow via direct node injection."""

    mock_analyzer = MagicMock()
    mock_planner = MagicMock()
    mock_delegate = MagicMock()
    mock_worker = MagicMock()
    mock_reflector = MagicMock()
    mock_synthesizer = MagicMock()
    mock_summarize_conv = MagicMock()

    # 1. Analyzer: Proceed to planning
    mock_analyzer.return_value = {"next_action": "plan"}

    # 2. Planner: Generate checklist
    mock_planner.return_value = {
        "checklist": [{"task": "T1", "worker": "data_analyst"}],
        "current_step": 0
    }

    # 3. Delegate: Route to data_analyst
    mock_delegate.side_effect = [
        {"next_action": "data_analyst"},  # First call
        {"next_action": "summarize"}      # Second call (after reflector)
    ]

    # 4. Worker: Success
    mock_worker.return_value = {
        "messages": [AIMessage(content="Worker result")],
        "vfs": {"res.txt": "data"}
    }

    # 5. Reflector: Advance
    mock_reflector.return_value = {
        "current_step": 1,
        "next_action": "delegate"
    }

    # 6. Synthesizer: Final answer
    mock_synthesizer.return_value = {
        "messages": [AIMessage(content="Final synthesized answer")],
        "next_action": "end"
    }

    # 7. Summarize Conv: End
    mock_summarize_conv.return_value = {"summary": "Done"}

    # Create workflow with injected mocks
    app = create_workflow(
        query_analyzer=mock_analyzer,
        planner=mock_planner,
        delegate=mock_delegate,
        data_analyst_worker=mock_worker,
        reflector=mock_reflector,
        synthesizer=mock_synthesizer,
        summarize_conversation=mock_summarize_conv
    )

    initial_state = AgentState(
        messages=[HumanMessage(content="Explain results")],
        question="Explain results",
        analysis={},
        next_action="",
        iterations=0,
        checklist=[],
        current_step=0,
        vfs={},
        plots=[],
        dfs={}
    )

    final_state = app.invoke(initial_state)

    assert mock_analyzer.called
    assert mock_planner.called
    assert mock_delegate.call_count == 2
    assert mock_worker.called
    assert mock_reflector.called
    assert mock_synthesizer.called
    assert "Final synthesized answer" in [m.content for m in final_state["messages"]]


def test_orchestrator_researcher_flow():
    """Verify that the Orchestrator can route to the researcher worker."""

    mock_analyzer = MagicMock()
    mock_planner = MagicMock()
    mock_delegate = MagicMock()
    mock_researcher = MagicMock()
    mock_reflector = MagicMock()
    mock_synthesizer = MagicMock()

    mock_analyzer.return_value = {"next_action": "plan"}
    mock_planner.return_value = {
        "checklist": [{"task": "Search news", "worker": "researcher"}],
        "current_step": 0
    }
    mock_delegate.side_effect = [
        {"next_action": "researcher"},
        {"next_action": "summarize"}
    ]
    mock_researcher.return_value = {"messages": [AIMessage(content="News found")]}
    mock_reflector.return_value = {"current_step": 1, "next_action": "delegate"}
    mock_synthesizer.return_value = {"messages": [AIMessage(content="Final News Summary")], "next_action": "end"}

    app = create_workflow(
        query_analyzer=mock_analyzer,
        planner=mock_planner,
        delegate=mock_delegate,
        researcher_worker=mock_researcher,
        reflector=mock_reflector,
        synthesizer=mock_synthesizer
    )

    initial_state = AgentState(
        messages=[HumanMessage(content="What's the news?")],
        question="What's the news?",
        analysis={},
        next_action="",
        iterations=0,
        checklist=[],
        current_step=0,
        vfs={},
        plots=[],
        dfs={}
    )

    final_state = app.invoke(initial_state)

    assert mock_researcher.called
    assert "Final News Summary" in [m.content for m in final_state["messages"]]
backend/tests/test_orchestrator_planner.py (new file, 34 lines)
@@ -0,0 +1,34 @@
from typing import get_type_hints, List
from ea_chatbot.graph.nodes.planner import planner_node
from ea_chatbot.graph.state import AgentState


def test_planner_node_checklist():
    """Verify that the planner node generates a checklist."""
    state = AgentState(
        messages=[],
        question="How many voters are in Florida and what is the current news?",
        analysis={"requires_dataset": True},
        next_action="plan",
        iterations=0,
        checklist=[],
        current_step=0,
        vfs={},
        plots=[],
        dfs={}
    )

    # Mocking the LLM would be ideal, but for now we'll check the returned keys
    # and assume the implementation provides them.
    # In a real TDD, we'd mock the LLM to return a specific structure.

    # For now, let's assume the task is to update 'planner_node' to return these keys.
    result = planner_node(state)

    assert "checklist" in result
    assert isinstance(result["checklist"], list)
    assert len(result["checklist"]) > 0
    assert "task" in result["checklist"][0]
    assert "worker" in result["checklist"][0]  # 'data_analyst' or 'researcher'

    assert "current_step" in result
    assert result["current_step"] == 0
backend/tests/test_orchestrator_reflector.py (new file, 30 lines)
@@ -0,0 +1,30 @@
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.reflector import reflector_node
from ea_chatbot.graph.state import AgentState


def test_reflector_node_satisfied():
    """Verify that the reflector node increments current_step when satisfied."""
    state = AgentState(
        checklist=[{"task": "Analyze votes", "worker": "data_analyst"}],
        current_step=0,
        messages=[],
        question="test",
        analysis={},
        next_action="",
        iterations=0,
        vfs={},
        plots=[],
        dfs={},
        summary="Worker successfully analyzed votes and found 5 million."
    )

    # Mocking the LLM to return 'satisfied=True'
    with patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_get_llm:
        mock_llm = MagicMock()
        mock_llm.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=True, reasoning="Good.")
        mock_get_llm.return_value = mock_llm

        result = reflector_node(state)

    assert result["current_step"] == 1
    assert result["next_action"] == "delegate"  # Route back to delegate for next task
31 backend/tests/test_orchestrator_synthesizer.py Normal file
@@ -0,0 +1,31 @@
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.synthesizer import synthesizer_node
from ea_chatbot.graph.state import AgentState
from langchain_core.messages import AIMessage

def test_synthesizer_node_success():
    """Verify that the synthesizer node produces a final response."""
    state = AgentState(
        messages=[AIMessage(content="Worker 1 found data."), AIMessage(content="Worker 2 searched web.")],
        question="What are the results?",
        checklist=[],
        current_step=0,
        iterations=0,
        vfs={},
        plots=[],
        dfs={},
        next_action="",
        analysis={}
    )

    # Mocking the LLM
    with patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model") as mock_get_llm:
        mock_llm = MagicMock()
        mock_llm.invoke.return_value = AIMessage(content="Final synthesized answer.")
        mock_get_llm.return_value = mock_llm

        result = synthesizer_node(state)

    assert len(result["messages"]) == 1
    assert result["messages"][0].content == "Final synthesized answer."
    assert result["next_action"] == "end"
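The contract this test pins down — collapse the accumulated worker messages into a single final answer and set `next_action` to `"end"` — can be sketched as below. This is a dependency-free reconstruction consistent with the test, not the repository's actual `synthesizer_node`: the `Message` dataclass stands in for LangChain's `AIMessage`, and the optional `llm` argument stands in for the project's `get_llm_model()` helper.

```python
from dataclasses import dataclass

@dataclass
class Message:
    # Stand-in for langchain's AIMessage; only `content` matters for this sketch.
    content: str

def synthesizer_node(state: dict, llm=None) -> dict:
    """Fuse all intermediate worker messages into one final response."""
    findings = "\n".join(m.content for m in state.get("messages", []))
    prompt = f"Question: {state.get('question', '')}\n\nFindings:\n{findings}"
    if llm is not None:
        # The real node would invoke the LLM here with a synthesis prompt.
        final = llm.invoke(prompt)
    else:
        # Without an LLM, fall back to echoing the concatenated findings.
        final = Message(content=findings or "No findings available.")
    # Exactly one final message, and the graph is told to terminate.
    return {"messages": [final], "next_action": "end"}
```

A usage sketch mirroring the test: calling it with two worker messages yields one message and `next_action == "end"`.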
@@ -1,6 +1,7 @@
 import pytest
 from unittest.mock import MagicMock, patch
 from ea_chatbot.graph.nodes.planner import planner_node
+from ea_chatbot.schemas import ChecklistResponse, ChecklistTask

 @pytest.fixture
 def mock_state():
@@ -8,39 +9,40 @@ def mock_state():
         "messages": [],
         "question": "Show me results for New Jersey",
         "analysis": {
-            # "requires_dataset" removed as it's no longer used
             "expert": "Data Analyst",
             "data": "NJ data",
             "unknown": "results",
             "condition": "state=NJ"
         },
         "next_action": "plan",
-        "plan": None
+        "checklist": [],
+        "current_step": 0
     }

 @patch("ea_chatbot.graph.nodes.planner.get_llm_model")
 @patch("ea_chatbot.utils.database_inspection.get_data_summary")
 def test_planner_node(mock_get_summary, mock_get_llm, mock_state):
-    """Test planner node with unified prompt."""
+    """Test planner node with checklist generation."""
     mock_get_summary.return_value = "Column: Name, Type: text"

     mock_llm = MagicMock()
     mock_get_llm.return_value = mock_llm

-    from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext
-    mock_plan = TaskPlanResponse(
+    mock_response = ChecklistResponse(
         goal="Get NJ results",
         reflection="The user wants NJ results",
-        context=TaskPlanContext(initial_context="NJ data", assumptions=[], constraints=[]),
-        steps=["Step 1: Load data", "Step 2: Filter by NJ"]
+        checklist=[
+            ChecklistTask(task="Query NJ data", worker="data_analyst")
+        ]
     )
-    mock_llm.with_structured_output.return_value.invoke.return_value = mock_plan
+    mock_llm.with_structured_output.return_value.invoke.return_value = mock_response

     result = planner_node(mock_state)

-    assert "plan" in result
-    assert "Step 1: Load data" in result["plan"]
-    assert "Step 2: Filter by NJ" in result["plan"]
+    assert "checklist" in result
+    assert result["checklist"][0]["task"] == "Query NJ data"
+    assert result["current_step"] == 0
+    assert result["summary"] == "The user wants NJ results"

     # Verify helper was called
     mock_get_summary.assert_called_once()
@@ -1,34 +0,0 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from ea_chatbot.graph.nodes.researcher import researcher_node

@pytest.fixture
def mock_llm():
    with patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_get_llm:
        mock_llm_instance = MagicMock(spec=ChatOpenAI)
        mock_get_llm.return_value = mock_llm_instance
        yield mock_llm_instance

def test_researcher_node_success(mock_llm):
    """Test that researcher_node invokes LLM with web_search tool and returns messages."""
    state = {
        "question": "What is the capital of France?",
        "messages": []
    }

    mock_llm_with_tools = MagicMock()
    mock_llm.bind_tools.return_value = mock_llm_with_tools
    mock_llm_with_tools.invoke.return_value = AIMessage(content="The capital of France is Paris.")

    result = researcher_node(state)

    assert mock_llm.bind_tools.called
    # Check that it was called with web_search
    args, kwargs = mock_llm.bind_tools.call_args
    assert {"type": "web_search"} in args[0]

    assert mock_llm_with_tools.invoke.called
    assert "messages" in result
    assert result["messages"][0].content == "The capital of France is Paris."
@@ -1,62 +0,0 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from ea_chatbot.graph.nodes.researcher import researcher_node

@pytest.fixture
def base_state():
    return {
        "question": "Who won the 2024 election?",
        "messages": [],
        "summary": ""
    }

@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_openai_search(mock_get_llm, base_state):
    """Test that OpenAI LLM binds 'web_search' tool."""
    mock_llm = MagicMock(spec=ChatOpenAI)
    mock_get_llm.return_value = mock_llm

    mock_llm_with_tools = MagicMock()
    mock_llm.bind_tools.return_value = mock_llm_with_tools
    mock_llm_with_tools.invoke.return_value = AIMessage(content="OpenAI Search Result")

    result = researcher_node(base_state)

    # Verify bind_tools called with correct OpenAI tool
    mock_llm.bind_tools.assert_called_once_with([{"type": "web_search"}])
    assert result["messages"][0].content == "OpenAI Search Result"

@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_google_search(mock_get_llm, base_state):
    """Test that Google LLM binds 'google_search' tool."""
    mock_llm = MagicMock(spec=ChatGoogleGenerativeAI)
    mock_get_llm.return_value = mock_llm

    mock_llm_with_tools = MagicMock()
    mock_llm.bind_tools.return_value = mock_llm_with_tools
    mock_llm_with_tools.invoke.return_value = AIMessage(content="Google Search Result")

    result = researcher_node(base_state)

    # Verify bind_tools called with correct Google tool
    mock_llm.bind_tools.assert_called_once_with([{"google_search": {}}])
    assert result["messages"][0].content == "Google Search Result"

@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_fallback_on_bind_error(mock_get_llm, base_state):
    """Test that researcher falls back to basic LLM if bind_tools fails."""
    mock_llm = MagicMock(spec=ChatOpenAI)
    mock_get_llm.return_value = mock_llm

    # Simulate bind_tools failing (e.g. model doesn't support it)
    mock_llm.bind_tools.side_effect = Exception("Not supported")
    mock_llm.invoke.return_value = AIMessage(content="Basic Result")

    result = researcher_node(base_state)

    # Should still succeed using the base LLM
    assert result["messages"][0].content == "Basic Result"
    mock_llm.invoke.assert_called_once()
13 backend/tests/test_researcher_state.py Normal file
@@ -0,0 +1,13 @@
from typing import get_type_hints, List
from ea_chatbot.graph.workers.researcher.state import WorkerState

def test_researcher_worker_state():
    """Verify that Researcher WorkerState has the required fields."""
    hints = get_type_hints(WorkerState)

    assert "messages" in hints
    assert "task" in hints
    assert "queries" in hints
    assert "raw_results" in hints
    assert "iterations" in hints
    assert "result" in hints
67 backend/tests/test_researcher_worker.py Normal file
@@ -0,0 +1,67 @@
import pytest
from unittest.mock import MagicMock
from ea_chatbot.graph.workers.researcher.workflow import create_researcher_worker, WorkerState
from ea_chatbot.graph.workers.researcher.mapping import prepare_researcher_input, merge_researcher_output
from ea_chatbot.graph.state import AgentState
from langchain_core.messages import AIMessage

def test_researcher_worker_flow():
    """Verify that the Researcher worker flow works as expected."""
    mock_searcher = MagicMock()
    mock_summarizer = MagicMock()

    mock_searcher.return_value = {
        "raw_results": ["Result A"],
        "messages": [AIMessage(content="Search result")]
    }
    mock_summarizer.return_value = {"result": "Consolidated Summary"}

    graph = create_researcher_worker(
        searcher=mock_searcher,
        summarizer=mock_summarizer
    )

    initial_state = WorkerState(
        messages=[],
        task="Find governor",
        queries=[],
        raw_results=[],
        iterations=0,
        result=None
    )

    final_state = graph.invoke(initial_state)

    assert final_state["result"] == "Consolidated Summary"
    assert mock_searcher.called
    assert mock_summarizer.called

def test_researcher_mapping():
    """Verify that we correctly map states for the researcher."""
    global_state = AgentState(
        checklist=[{"task": "Search X", "worker": "researcher"}],
        current_step=0,
        messages=[],
        question="test",
        analysis={},
        next_action="",
        iterations=0,
        vfs={},
        plots=[],
        dfs={}
    )

    worker_input = prepare_researcher_input(global_state)
    assert worker_input["task"] == "Search X"

    worker_output = WorkerState(
        messages=[],
        task="Search X",
        queries=[],
        raw_results=[],
        iterations=1,
        result="Found X"
    )

    updates = merge_researcher_output(worker_output)
    assert updates["messages"][0].content == "Found X"
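The mapping pair tested above projects the orchestrator state down to the worker and lifts the result back up. The sketch below is a plausible reconstruction under those test contracts (task selected by `current_step`, result surfaced as both a message and the global `summary`); it is not the repository's actual `mapping` module, and the `Message` dataclass stands in for LangChain's `AIMessage`.

```python
from dataclasses import dataclass

@dataclass
class Message:
    content: str  # stand-in for langchain's AIMessage

def prepare_researcher_input(global_state: dict) -> dict:
    """Project the orchestrator state down to a fresh researcher WorkerState."""
    step = global_state["current_step"]
    task = global_state["checklist"][step]["task"]
    return {
        "messages": [], "task": task, "queries": [],
        "raw_results": [], "iterations": 0, "result": None,
    }

def merge_researcher_output(worker_state: dict) -> dict:
    """Lift the worker's consolidated result back into the global state."""
    # Fall back to a non-empty string so downstream nodes never see None.
    result = worker_state.get("result") or "Researcher produced no result."
    return {"messages": [Message(content=result)], "summary": result}
```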
55 backend/tests/test_review_fix_clarification.py Normal file
@@ -0,0 +1,55 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.workflow import create_workflow
from ea_chatbot.graph.state import AgentState
from langchain_core.messages import HumanMessage, AIMessage

def test_clarification_flow_immediate_execution():
    """Verify that an ambiguous query immediately executes the clarification node without interruption."""

    mock_analyzer = MagicMock()
    mock_clarification = MagicMock()

    # 1. Analyzer returns 'clarify'
    mock_analyzer.return_value = {"next_action": "clarify"}

    # 2. Clarification node returns a question
    mock_clarification.return_value = {"messages": [AIMessage(content="What year?")]}

    # Create workflow without other nodes since they won't be reached
    # We still need to provide mock planners etc. to create_workflow
    app = create_workflow(
        query_analyzer=mock_analyzer,
        clarification=mock_clarification,
        planner=MagicMock(),
        delegate=MagicMock(),
        data_analyst_worker=MagicMock(),
        researcher_worker=MagicMock(),
        reflector=MagicMock(),
        synthesizer=MagicMock(),
        summarize_conversation=MagicMock()
    )

    initial_state = AgentState(
        messages=[HumanMessage(content="Who won?")],
        question="Who won?",
        analysis={},
        next_action="",
        iterations=0,
        checklist=[],
        current_step=0,
        vfs={},
        plots=[],
        dfs={}
    )

    # Run the graph
    final_state = app.invoke(initial_state)

    # Assertions
    assert mock_analyzer.called
    assert mock_clarification.called

    # Verify the state contains the clarification message
    assert len(final_state["messages"]) > 0
    assert "What year?" in [m.content for m in final_state["messages"]]
48 backend/tests/test_review_fix_integration.py Normal file
@@ -0,0 +1,48 @@
import pytest
from ea_chatbot.graph.workflow import create_workflow, data_analyst_worker_runnable
from ea_chatbot.graph.state import AgentState
from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage

def test_worker_merge_sets_summary_for_reflector(monkeypatch):
    """Verify that worker node (runnable) sets the 'summary' field for the Reflector."""

    state = AgentState(
        messages=[HumanMessage(content="test")],
        question="test",
        checklist=[{"task": "Analyze data", "worker": "data_analyst"}],
        current_step=0,
        iterations=0,
        vfs={},
        plots=[],
        dfs={},
        next_action="",
        analysis={},
        summary="Initial Planner Summary"  # Stale summary
    )

    # Create a mock for the invoke method
    mock_invoke = MagicMock()
    mock_invoke.return_value = {
        "summary": "Actual Worker Findings",
        "messages": [AIMessage(content="Actual Worker Findings")],
        "vfs": {},
        "plots": []
    }

    # Manually replace the runnable with a mock object that has an invoke method
    mock_runnable = MagicMock()
    mock_runnable.invoke = mock_invoke
    monkeypatch.setattr("ea_chatbot.graph.workflow.data_analyst_worker_runnable", mock_runnable)

    # Execute via the module reference (which is now mocked)
    from ea_chatbot.graph.workflow import data_analyst_worker_runnable
    updates = data_analyst_worker_runnable.invoke(state)

    # Verify that 'summary' is in updates and has the worker result
    assert "summary" in updates
    assert updates["summary"] == "Actual Worker Findings"

    # When applied to state, it should overwrite the stale summary
    state.update(updates)
    assert state["summary"] == "Actual Worker Findings"
35 backend/tests/test_review_fix_reflector.py Normal file
@@ -0,0 +1,35 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.reflector import reflector_node
from ea_chatbot.graph.state import AgentState

def test_reflector_does_not_advance_on_failure():
    """Verify that reflector does not increment current_step if not satisfied."""
    state = AgentState(
        checklist=[{"task": "Task 1", "worker": "data_analyst"}],
        current_step=0,
        messages=[],
        question="test",
        analysis={},
        next_action="",
        iterations=0,
        vfs={},
        plots=[],
        dfs={},
        summary="Failed output"
    )

    with patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_get_llm:
        mock_llm = MagicMock()
        # Mark as NOT satisfied
        mock_llm.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=False, reasoning="Incomplete.")
        mock_get_llm.return_value = mock_llm

        result = reflector_node(state)

    # Should NOT increment (therefore not in the updates dict)
    assert "current_step" not in result
    # Should increment iterations
    assert result["iterations"] == 1
    # Should probably route to planner or retry
    assert result["next_action"] == "delegate"  # Or 'planner' if we want re-planning
51 backend/tests/test_review_fix_summary.py Normal file
@@ -0,0 +1,51 @@
import pytest
from ea_chatbot.graph.workers.data_analyst.mapping import merge_worker_output as merge_analyst
from ea_chatbot.graph.workers.researcher.mapping import merge_researcher_output as merge_researcher
from ea_chatbot.graph.workers.data_analyst.state import WorkerState as AnalystState
from ea_chatbot.graph.workers.researcher.state import WorkerState as ResearcherState

def test_analyst_merge_updates_summary():
    """Verify that analyst merge updates the global summary."""
    worker_state = AnalystState(
        result="Actual Worker Result",
        messages=[],
        task="test",
        iterations=1,
        vfs_state={},
        plots=[]
    )

    updates = merge_analyst(worker_state)
    assert "summary" in updates
    assert updates["summary"] == "Actual Worker Result"

def test_researcher_merge_updates_summary():
    """Verify that researcher merge updates the global summary."""
    worker_state = ResearcherState(
        result="Actual Research Result",
        messages=[],
        task="test",
        iterations=1,
        queries=[],
        raw_results=[]
    )

    updates = merge_researcher(worker_state)
    assert "summary" in updates
    assert updates["summary"] == "Actual Research Result"

def test_merge_handles_none_result():
    """Verify that merge functions handle None results gracefully."""
    worker_state = AnalystState(
        result=None,
        messages=[],
        task="test",
        iterations=1,
        vfs_state={},
        plots=[]
    )

    updates = merge_analyst(worker_state)
    assert updates["summary"] is not None
    assert isinstance(updates["messages"][0].content, str)
    assert len(updates["messages"][0].content) > 0
85 backend/tests/test_review_fix_unbounded_loop.py Normal file
@@ -0,0 +1,85 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.workflow import create_workflow
from ea_chatbot.graph.state import AgentState
from langchain_core.messages import AIMessage, HumanMessage

def test_orchestrator_loop_retry_budget():
    """Verify that the orchestrator loop is bounded and terminates after max retries."""

    mock_analyzer = MagicMock()
    mock_planner = MagicMock()
    mock_delegate = MagicMock()
    mock_worker = MagicMock()
    mock_reflector = MagicMock()
    mock_synthesizer = MagicMock()

    # 1. Analyzer: Proceed to planning
    mock_analyzer.return_value = {"next_action": "plan"}

    # 2. Planner: One task
    mock_planner.return_value = {
        "checklist": [{"task": "Unsolvable Task", "worker": "data_analyst"}],
        "current_step": 0,
        "iterations": 0
    }

    # We'll use the REAL delegate and reflector logic to verify the fix
    # But we mock the LLM calls inside them if necessary.
    # Actually, it's easier to just mock the node return values but follow the logic.

    from ea_chatbot.graph.nodes.delegate import delegate_node
    from ea_chatbot.graph.nodes.reflector import reflector_node

    # Mocking the LLM inside reflector to always be unsatisfied
    with patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_get_llm:
        mock_llm = MagicMock()
        # Mark as NOT satisfied
        mock_llm.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=False, reasoning="Still bad.")
        mock_get_llm.return_value = mock_llm

        app = create_workflow(
            query_analyzer=mock_analyzer,
            planner=mock_planner,
            # delegate=delegate_node,  # Use real
            data_analyst_worker=mock_worker,
            # reflector=reflector_node,  # Use real
            synthesizer=mock_synthesizer
        )

        # Mock worker to return something
        mock_worker.return_value = {"result": "Bad Output", "messages": [AIMessage(content="Bad")]}
        mock_synthesizer.return_value = {"messages": [AIMessage(content="Failure Summary")], "next_action": "end"}

        initial_state = AgentState(
            messages=[HumanMessage(content="test")],
            question="test",
            analysis={},
            next_action="",
            iterations=0,
            checklist=[],
            current_step=0,
            vfs={},
            plots=[],
            dfs={}
        )

        # Run the graph. If fix works, it should hit iterations=3 and route to synthesizer.
        # We use a recursion_limit higher than our retry budget but low enough to fail fast if unbounded.
        final_state = app.invoke(initial_state, config={"recursion_limit": 20})

    # Assertions
    # 1. We tried 3 times (iterations 0, 1, 2) and failed on 3rd.
    # Wait, delegate routes to summarize when iterations >= 3.
    # Reflector increments iterations.
    # Loop:
    # Start: it=0
    # Delegate (it=0) -> Worker -> Reflector (fail, it=1) -> Delegate (it=1)
    # Delegate (it=1) -> Worker -> Reflector (fail, it=2) -> Delegate (it=2)
    # Delegate (it=2) -> Worker -> Reflector (fail, it=3) -> Delegate (it=3)
    # Delegate (it=3) -> Summarize (it=0)

    assert final_state["iterations"] == 0  # Reset in delegate or handled in synthesizer
    # Check if we hit the failure summary
    assert "Failed to complete task" in final_state["summary"]
    assert mock_worker.call_count == 3
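The loop trace in the comments above implies a delegate policy: route the current checklist task to its worker while `iterations` is below the retry budget, and bail out to the synthesizer (resetting the counter and recording a failure summary) once the budget is spent. The sketch below is only a reconstruction of that policy under the assumptions visible in the test (a budget of 3, a `"synthesize"` routing target, the `"Failed to complete task"` failure text); the project's real `delegate_node` may differ in all of these details.

```python
MAX_RETRIES = 3  # retry budget assumed from the test's loop trace

def delegate_node(state: dict) -> dict:
    """Route the current checklist task to a worker, or give up.

    Bounded-loop rule: after MAX_RETRIES failed reflections, stop retrying,
    reset the counter, and hand a failure summary to the synthesizer.
    """
    if state["iterations"] >= MAX_RETRIES:
        return {
            "next_action": "synthesize",
            "iterations": 0,  # reset so the counter is clean for the next query
            "summary": "Failed to complete task after maximum retries.",
        }
    if state["current_step"] >= len(state["checklist"]):
        # Every task done: nothing left to delegate.
        return {"next_action": "synthesize"}
    worker = state["checklist"][state["current_step"]]["worker"]
    return {"next_action": worker}
```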
50 backend/tests/test_review_fix_vfs.py Normal file
@@ -0,0 +1,50 @@
import pytest
from ea_chatbot.utils.vfs import VFSHelper
from ea_chatbot.graph.workers.data_analyst.mapping import prepare_worker_input
from ea_chatbot.graph.state import AgentState

def test_vfs_isolation_deep_copy():
    """Verify that worker VFS is deep-copied from global state."""
    global_vfs = {
        "data.txt": {
            "content": "original",
            "metadata": {"tags": ["a"]}
        }
    }
    state = AgentState(
        checklist=[{"task": "test", "worker": "data_analyst"}],
        current_step=0,
        messages=[],
        question="test",
        analysis={},
        next_action="",
        iterations=0,
        vfs=global_vfs,
        plots=[],
        dfs={}
    )

    worker_input = prepare_worker_input(state)
    worker_vfs = worker_input["vfs_state"]

    # Mutate worker VFS nested object
    worker_vfs["data.txt"]["metadata"]["tags"].append("b")

    # Global VFS should remain unchanged
    assert global_vfs["data.txt"]["metadata"]["tags"] == ["a"]

def test_vfs_schema_normalization():
    """Verify that VFSHelper handles inconsistent VFS entries."""
    vfs = {
        "raw.txt": "just a string",  # Inconsistent with standard Dict[str, Any] entry
        "valid.txt": {"content": "data", "metadata": {}}
    }
    helper = VFSHelper(vfs)

    # Should not crash during read
    content, metadata = helper.read("raw.txt")
    assert content == "just a string"
    assert metadata == {}

    # Should not crash during list/other ops if they assume dict
    assert "raw.txt" in helper.list()
15 backend/tests/test_state_extensions.py Normal file
@@ -0,0 +1,15 @@
from typing import get_type_hints, List, Dict, Any
from ea_chatbot.graph.state import AgentState

def test_agent_state_extensions():
    """Verify that AgentState has the new DeepAgents fields."""
    hints = get_type_hints(AgentState)

    assert "checklist" in hints
    # checklist: List[Dict[str, Any]] (or similar for tasks)

    assert "current_step" in hints
    assert hints["current_step"] == int

    assert "vfs" in hints
    # vfs: Dict[str, Any]
@@ -1,47 +0,0 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from ea_chatbot.graph.nodes.summarizer import summarizer_node

@pytest.fixture
def mock_llm():
    with patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") as mock_get_llm:
        mock_llm_instance = MagicMock()
        mock_get_llm.return_value = mock_llm_instance
        yield mock_llm_instance

def test_summarizer_node_success(mock_llm):
    """Test that summarizer_node invokes LLM with correct inputs and returns messages."""
    state = {
        "question": "What is the total count?",
        "plan": "1. Run query\n2. Sum results",
        "code_output": "The total is 100",
        "messages": []
    }

    mock_llm.invoke.return_value = AIMessage(content="The final answer is 100.")

    result = summarizer_node(state)

    # Verify LLM was called
    assert mock_llm.invoke.called

    # Verify result structure
    assert "messages" in result
    assert len(result["messages"]) == 1
    assert isinstance(result["messages"][0], AIMessage)
    assert result["messages"][0].content == "The final answer is 100."

def test_summarizer_node_empty_state(mock_llm):
    """Test handling of empty or minimal state."""
    state = {
        "question": "Empty?",
        "messages": []
    }

    mock_llm.invoke.return_value = AIMessage(content="No data provided.")

    result = summarizer_node(state)

    assert "messages" in result
    assert result["messages"][0].content == "No data provided."
49 backend/tests/test_vfs.py Normal file
@@ -0,0 +1,49 @@
import pytest
from ea_chatbot.utils.vfs import VFSHelper

def test_vfs_read_write():
    """Verify that we can write and read from the VFS."""
    vfs = {}
    helper = VFSHelper(vfs)

    helper.write("test.txt", "Hello World", metadata={"type": "text"})
    assert "test.txt" in vfs
    assert vfs["test.txt"]["content"] == "Hello World"
    assert vfs["test.txt"]["metadata"]["type"] == "text"

    content, metadata = helper.read("test.txt")
    assert content == "Hello World"
    assert metadata["type"] == "text"

def test_vfs_list():
    """Verify that we can list files in the VFS."""
    vfs = {}
    helper = VFSHelper(vfs)

    helper.write("a.txt", "content a")
    helper.write("b.txt", "content b")

    files = helper.list()
    assert "a.txt" in files
    assert "b.txt" in files
    assert len(files) == 2

def test_vfs_delete():
    """Verify that we can delete files from the VFS."""
    vfs = {}
    helper = VFSHelper(vfs)

    helper.write("test.txt", "Hello")
    helper.delete("test.txt")

    assert "test.txt" not in vfs
    assert len(helper.list()) == 0

def test_vfs_not_found():
    """Verify that reading a non-existent file returns None."""
    vfs = {}
    helper = VFSHelper(vfs)

    content, metadata = helper.read("missing.txt")
    assert content is None
    assert metadata is None
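The `VFSHelper` tests in this diff fully specify the helper's observable behavior: it wraps a plain dict in place (so the graph state sees every write), stores entries as `{"content": ..., "metadata": {...}}`, returns `(None, None)` for missing files, and tolerates bare non-dict entries on read. A minimal implementation satisfying exactly those tests would look like the sketch below; the repository's actual class may carry extra methods or validation.

```python
from typing import Optional, Tuple

class VFSHelper:
    """Thin wrapper over a plain-dict virtual file system.

    Mutates the dict passed in, so graph state holding the same dict
    sees every write and delete.
    """

    def __init__(self, vfs: dict):
        self.vfs = vfs

    def write(self, path: str, content, metadata: Optional[dict] = None) -> None:
        # Canonical entry shape: {"content": ..., "metadata": {...}}.
        self.vfs[path] = {"content": content, "metadata": metadata or {}}

    def read(self, path: str) -> Tuple[object, Optional[dict]]:
        entry = self.vfs.get(path)
        if entry is None:
            return None, None          # missing file
        if not isinstance(entry, dict):
            return entry, {}           # normalize bare values (e.g. a raw string)
        return entry.get("content"), entry.get("metadata", {})

    def list(self) -> list:
        return list(self.vfs.keys())

    def delete(self, path: str) -> None:
        self.vfs.pop(path, None)       # deleting a missing path is a no-op
```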
48
backend/tests/test_vfs_robustness.py
Normal file
48
backend/tests/test_vfs_robustness.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
import pytest
import threading

from ea_chatbot.utils.vfs import safe_vfs_copy


def test_safe_vfs_copy_success():
    """Test standard success case."""
    vfs = {
        "test.csv": {"content": "data", "metadata": {"type": "csv"}},
        "num": 42
    }
    copied = safe_vfs_copy(vfs)
    assert copied == vfs
    assert copied is not vfs
    assert copied["test.csv"] is not vfs["test.csv"]


def test_safe_vfs_copy_handles_non_copyable():
    """Test replacing uncopyable objects with error placeholders."""
    # A threading.Lock is famously uncopyable
    lock = threading.Lock()

    vfs = {
        "safe_file": "important data",
        "unsafe_lock": lock
    }

    copied = safe_vfs_copy(vfs)

    # Safe one remains
    assert copied["safe_file"] == "important data"

    # Unsafe one is REPLACED with an error dict
    assert isinstance(copied["unsafe_lock"], dict)
    assert "content" in copied["unsafe_lock"]
    assert "ERROR" in copied["unsafe_lock"]["content"]
    assert copied["unsafe_lock"]["metadata"]["type"] == "error"
    assert "lock" in str(copied["unsafe_lock"]["metadata"]["error"]).lower()

    # Original is unchanged (it was a lock)
    assert vfs["unsafe_lock"] is lock


def test_safe_vfs_copy_preserves_nested_copyable():
    """Test deepcopy still works for complex but copyable objects."""
    vfs = {
        "data": {"a": [1, 2, 3], "b": {"c": True}}
    }
    copied = safe_vfs_copy(vfs)
    assert copied["data"]["a"] == [1, 2, 3]
    assert copied["data"]["a"] is not vfs["data"]["a"]
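The helper under test is likewise not shown in this compare. A sketch that would satisfy these robustness tests (hypothetical: `copy.deepcopy` with a per-entry fallback placeholder) is:

```python
import copy


def safe_vfs_copy(vfs):
    """Deep-copy a VFS dict; entries that cannot be deep-copied are replaced
    with an error placeholder instead of letting the whole copy blow up."""
    copied = {}
    for name, entry in vfs.items():
        try:
            copied[name] = copy.deepcopy(entry)
        except Exception as exc:  # e.g. TypeError: cannot pickle '_thread.lock' object
            copied[name] = {
                "content": f"ERROR: could not copy {name!r}: {exc}",
                "metadata": {"type": "error", "error": str(exc)},
            }
    return copied
```

The per-entry `try` is the key design choice: one uncopyable object (a lock, a generator, an open handle) degrades into a diagnostic record rather than poisoning the snapshot of every other file.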
@@ -1,92 +1,80 @@
+import pytest
 from unittest.mock import MagicMock, patch
-from ea_chatbot.graph.workflow import app
+from ea_chatbot.graph.workflow import create_workflow
-from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis
+from ea_chatbot.graph.state import AgentState
-from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse
+from langchain_core.messages import AIMessage, HumanMessage

-from langchain_core.messages import AIMessage
-@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
-@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
-@patch("ea_chatbot.graph.nodes.coder.get_llm_model")
-@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model")
-@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
-@patch("ea_chatbot.utils.database_inspection.get_data_summary")
-@patch("ea_chatbot.graph.nodes.executor.Settings")
-@patch("ea_chatbot.graph.nodes.executor.DBClient")
-def test_workflow_full_flow(mock_db_client, mock_settings, mock_get_summary, mock_researcher_llm, mock_summarizer_llm, mock_coder_llm, mock_planner_llm, mock_qa_llm):
-    """Test the flow from query_analyzer through planner to coder."""
+def test_workflow_full_flow():
+    """Test the full Orchestrator-Workers flow using node injection."""

-    # Mock Settings for Executor
-    mock_settings_instance = MagicMock()
-    mock_settings_instance.db_host = "localhost"
-    mock_settings_instance.db_port = 5432
-    mock_settings_instance.db_user = "user"
-    mock_settings_instance.db_pswd = "pass"
-    mock_settings_instance.db_name = "test_db"
-    mock_settings_instance.db_table = "test_table"
-    mock_settings.return_value = mock_settings_instance
+    mock_analyzer = MagicMock()
+    mock_planner = MagicMock()
+    mock_delegate = MagicMock()
+    mock_worker = MagicMock()
+    mock_reflector = MagicMock()
+    mock_synthesizer = MagicMock()
+    mock_summarize_conv = MagicMock()

-    # Mock DBClient
-    mock_client_instance = MagicMock()
-    mock_db_client.return_value = mock_client_instance
+    # 1. Analyzer: Proceed to planning
+    mock_analyzer.return_value = {"next_action": "plan"}

-    # 1. Mock Query Analyzer
-    mock_qa_instance = MagicMock()
-    mock_qa_llm.return_value = mock_qa_instance
-    mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
-        data_required=["2024 results"],
-        unknowns=[],
-        ambiguities=[],
-        conditions=[],
-        next_action="plan"
-    )
+    # 2. Planner: Generate checklist
+    mock_planner.return_value = {
+        "checklist": [{"task": "Step 1", "worker": "data_analyst"}],
+        "current_step": 0
+    }

-    # 2. Mock Planner
-    mock_planner_instance = MagicMock()
-    mock_planner_llm.return_value = mock_planner_instance
-    mock_get_summary.return_value = "Data summary"
-    mock_planner_instance.with_structured_output.return_value.invoke.return_value = TaskPlanResponse(
-        goal="Task Goal",
-        reflection="Reflection",
-        context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]),
-        steps=["Step 1"]
-    )
+    # 3. Delegate: Route to data_analyst
+    mock_delegate.side_effect = [
+        {"next_action": "data_analyst"},
+        {"next_action": "summarize"}
+    ]

-    # 3. Mock Coder
-    mock_coder_instance = MagicMock()
-    mock_coder_llm.return_value = mock_coder_instance
-    mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse(
-        code="print('Hello')",
-        explanation="Explanation"
-    )
+    # 4. Worker: Success
+    mock_worker.return_value = {
+        "messages": [AIMessage(content="Worker Summary")],
+        "vfs": {}
+    }

-    # 4. Mock Summarizer
-    mock_summarizer_instance = MagicMock()
-    mock_summarizer_llm.return_value = mock_summarizer_instance
-    mock_summarizer_instance.invoke.return_value = AIMessage(content="Summary")
+    # 5. Reflector: Advance
+    mock_reflector.return_value = {
+        "current_step": 1,
+        "next_action": "delegate"
+    }

-    # 5. Mock Researcher (just in case)
-    mock_researcher_instance = MagicMock()
-    mock_researcher_llm.return_value = mock_researcher_instance
+    # 6. Synthesizer: Final answer
+    mock_synthesizer.return_value = {
+        "messages": [AIMessage(content="Final Summary")],
+        "next_action": "end"
+    }

+    # 7. Summarize Conv: End
+    mock_summarize_conv.return_value = {"summary": "Done"}

+    app = create_workflow(
+        query_analyzer=mock_analyzer,
+        planner=mock_planner,
+        delegate=mock_delegate,
+        data_analyst_worker=mock_worker,
+        reflector=mock_reflector,
+        synthesizer=mock_synthesizer,
+        summarize_conversation=mock_summarize_conv
+    )

-    # Initial state
     initial_state = {
-        "messages": [],
+        "messages": [HumanMessage(content="Show me results")],
-        "question": "Show me the 2024 results",
+        "question": "Show me results",
         "analysis": None,
         "next_action": "",
-        "plan": None,
+        "iterations": 0,
-        "code": None,
+        "checklist": [],
-        "error": None,
+        "current_step": 0,
+        "vfs": {},
         "plots": [],
         "dfs": {}
     }

-    # Run the graph
-    # We use recursion_limit to avoid infinite loops in placeholders if any
-    result = app.invoke(initial_state, config={"recursion_limit": 10})
+    result = app.invoke(initial_state, config={"recursion_limit": 20})

-    assert result["next_action"] == "plan"
+    assert "Final Summary" in [m.content for m in result["messages"]]
-    assert "plan" in result and result["plan"] is not None
+    assert result["current_step"] == 1
-    assert "code" in result and "print('Hello')" in result["code"]
-    assert "analysis" in result
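The rewritten test above depends on `create_workflow` accepting its nodes as parameters instead of importing them at module scope. Stripped of LangGraph specifics, the injection pattern can be sketched as follows (hypothetical names; each node takes the state dict and returns a partial update):

```python
def create_pipeline(analyzer, planner, worker):
    """Build a tiny pipeline from injected node callables.

    Each node receives the state dict and returns a partial update, mirroring
    how LangGraph nodes behave. Because the nodes arrive as arguments, a test
    can pass plain lambdas or MagicMocks without patching module-level imports.
    """
    def invoke(state):
        for node in (analyzer, planner, worker):
            state = {**state, **node(state)}
        return state
    return invoke


# In a test, stand-ins replace the real LLM-backed nodes:
app = create_pipeline(
    analyzer=lambda s: {"next_action": "plan"},
    planner=lambda s: {"checklist": [{"task": "Step 1"}], "current_step": 0},
    worker=lambda s: {"messages": s.get("messages", []) + ["Worker Summary"]},
)
result = app({"messages": []})
```

Compared with the stack of `@patch` decorators in the old test, injection keeps the test independent of the module layout, which is exactly what this refactor exercises.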
@@ -1,60 +1,55 @@
 import pytest
 from unittest.mock import MagicMock, patch
-from langchain_core.messages import AIMessage
 from ea_chatbot.graph.workflow import app
-from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis
-from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse
+from ea_chatbot.schemas import QueryAnalysis, ChecklistResponse, ChecklistTask, CodeGenerationResponse
+from ea_chatbot.graph.state import AgentState
+from langchain_core.messages import AIMessage

 @pytest.fixture
 def mock_llms():
-    with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_qa_llm, \
-         patch("ea_chatbot.graph.nodes.planner.get_llm_model") as mock_planner_llm, \
-         patch("ea_chatbot.graph.nodes.coder.get_llm_model") as mock_coder_llm, \
-         patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") as mock_summarizer_llm, \
-         patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_researcher_llm, \
-         patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model") as mock_summary_llm, \
-         patch("ea_chatbot.utils.database_inspection.get_data_summary") as mock_get_summary:
-        mock_get_summary.return_value = "Data summary"
+    with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_qa, \
+         patch("ea_chatbot.graph.nodes.planner.get_llm_model") as mock_planner, \
+         patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model") as mock_coder, \
+         patch("ea_chatbot.graph.workers.data_analyst.nodes.summarizer.get_llm_model") as mock_worker_summarizer, \
+         patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model") as mock_synthesizer, \
+         patch("ea_chatbot.graph.workers.researcher.nodes.searcher.get_llm_model") as mock_researcher, \
+         patch("ea_chatbot.graph.workers.researcher.nodes.summarizer.get_llm_model") as mock_res_summarizer, \
+         patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_reflector:

-        # Mock summary LLM to return a simple response
-        mock_summary_instance = MagicMock()
-        mock_summary_llm.return_value = mock_summary_instance
-        mock_summary_instance.invoke.return_value = AIMessage(content="Turn summary")

         yield {
-            "qa": mock_qa_llm,
-            "planner": mock_planner_llm,
-            "coder": mock_coder_llm,
-            "summarizer": mock_summarizer_llm,
-            "researcher": mock_researcher_llm,
-            "summary": mock_summary_llm
+            "qa": mock_qa,
+            "planner": mock_planner,
+            "coder": mock_coder,
+            "worker_summarizer": mock_worker_summarizer,
+            "synthesizer": mock_synthesizer,
+            "researcher": mock_researcher,
+            "res_summarizer": mock_res_summarizer,
+            "reflector": mock_reflector
         }

 def test_workflow_data_analysis_flow(mock_llms):
-    """Test full flow: QueryAnalyzer -> Planner -> Coder -> Executor -> Summarizer."""
+    """Test full flow: QueryAnalyzer -> Planner -> Delegate -> DataAnalyst -> Reflector -> Synthesizer."""

-    # 1. Mock Query Analyzer (routes to plan)
+    # 1. Mock Query Analyzer
     mock_qa_instance = MagicMock()
     mock_llms["qa"].return_value = mock_qa_instance
     mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
         data_required=["2024 results"],
         unknowns=[],
         ambiguities=[],
         conditions=[],
         next_action="plan"
     )

     # 2. Mock Planner
     mock_planner_instance = MagicMock()
     mock_llms["planner"].return_value = mock_planner_instance
-    mock_planner_instance.with_structured_output.return_value.invoke.return_value = TaskPlanResponse(
+    mock_planner_instance.with_structured_output.return_value.invoke.return_value = ChecklistResponse(
         goal="Get results",
         reflection="Reflect",
-        context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]),
-        steps=["Step 1"]
+        checklist=[ChecklistTask(task="Query Data", worker="data_analyst")]
     )

-    # 3. Mock Coder
+    # 3. Mock Coder (Worker)
     mock_coder_instance = MagicMock()
     mock_llms["coder"].return_value = mock_coder_instance
     mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse(
@@ -62,10 +57,20 @@ def test_workflow_data_analysis_flow(mock_llms):
         explanation="Explain"
     )

-    # 4. Mock Summarizer
-    mock_summarizer_instance = MagicMock()
-    mock_llms["summarizer"].return_value = mock_summarizer_instance
-    mock_summarizer_instance.invoke.return_value = AIMessage(content="Final Summary: Success")
+    # 4. Mock Worker Summarizer
+    mock_ws_instance = MagicMock()
+    mock_llms["worker_summarizer"].return_value = mock_ws_instance
+    mock_ws_instance.invoke.return_value = AIMessage(content="Worker Summary")

+    # 5. Mock Reflector
+    mock_reflector_instance = MagicMock()
+    mock_llms["reflector"].return_value = mock_reflector_instance
+    mock_reflector_instance.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=True, reasoning="Good.")

+    # 6. Mock Synthesizer
+    mock_syn_instance = MagicMock()
+    mock_llms["synthesizer"].return_value = mock_syn_instance
+    mock_syn_instance.invoke.return_value = AIMessage(content="Final Summary: Success")

     # Initial state
     initial_state = {
@@ -73,66 +78,77 @@ def test_workflow_data_analysis_flow(mock_llms):
         "question": "Show me 2024 results",
         "analysis": None,
         "next_action": "",
-        "plan": None,
-        "code": None,
-        "error": None,
+        "iterations": 0,
+        "checklist": [],
+        "current_step": 0,
+        "vfs": {},
         "plots": [],
         "dfs": {}
     }

     # Run the graph
-    result = app.invoke(initial_state, config={"recursion_limit": 15})
+    result = app.invoke(initial_state, config={"recursion_limit": 20})

-    assert result["next_action"] == "plan"
-    assert "Execution Success" in result["code_output"]
-    assert "Final Summary: Success" in result["messages"][-1].content
+    assert "Final Summary: Success" in [m.content for m in result["messages"]]
+    assert result["current_step"] == 1

 def test_workflow_research_flow(mock_llms):
-    """Test flow: QueryAnalyzer -> Researcher -> Summarizer."""
+    """Test flow with research task."""

-    # 1. Mock Query Analyzer (routes to research)
+    # 1. Mock Query Analyzer
     mock_qa_instance = MagicMock()
     mock_llms["qa"].return_value = mock_qa_instance
     mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
         data_required=[],
         unknowns=[],
         ambiguities=[],
         conditions=[],
         next_action="research"
     )

-    # 2. Mock Researcher
-    mock_researcher_instance = MagicMock()
-    mock_llms["researcher"].return_value = mock_researcher_instance
-    # Researcher node uses bind_tools if it's ChatOpenAI/ChatGoogleGenerativeAI
-    # Since it's a MagicMock, it will fallback to using the base instance
-    mock_researcher_instance.invoke.return_value = AIMessage(content="Research Results")
+    # 2. Mock Planner
+    mock_planner_instance = MagicMock()
+    mock_llms["planner"].return_value = mock_planner_instance
+    mock_planner_instance.with_structured_output.return_value.invoke.return_value = ChecklistResponse(
+        goal="Search",
+        reflection="Reflect",
+        checklist=[ChecklistTask(task="Search Web", worker="researcher")]
+    )

-    # Also mock bind_tools just in case we ever use spec
-    mock_llm_with_tools = MagicMock()
-    mock_researcher_instance.bind_tools.return_value = mock_llm_with_tools
-    mock_llm_with_tools.invoke.return_value = AIMessage(content="Research Results")
+    # 3. Mock Researcher Searcher
+    mock_res_instance = MagicMock()
+    mock_llms["researcher"].return_value = mock_res_instance
+    mock_res_instance.invoke.return_value = AIMessage(content="Research Result")

-    # 3. Mock Summarizer (not used in this flow, but kept for completeness)
-    mock_summarizer_instance = MagicMock()
-    mock_llms["summarizer"].return_value = mock_summarizer_instance
-    mock_summarizer_instance.invoke.return_value = AIMessage(content="Final Summary: Research Success")
+    # 4. Mock Researcher Summarizer
+    mock_rs_instance = MagicMock()
+    mock_llms["res_summarizer"].return_value = mock_rs_instance
+    mock_rs_instance.invoke.return_value = AIMessage(content="Researcher Summary")

+    # 5. Mock Reflector
+    mock_reflector_instance = MagicMock()
+    mock_llms["reflector"].return_value = mock_reflector_instance
+    mock_reflector_instance.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=True, reasoning="Good.")

+    # 6. Mock Synthesizer
+    mock_syn_instance = MagicMock()
+    mock_llms["synthesizer"].return_value = mock_syn_instance
+    mock_syn_instance.invoke.return_value = AIMessage(content="Final Research Summary")

-    # Initial state
     initial_state = {
         "messages": [],
-        "question": "Who is the governor of Florida?",
+        "question": "Who is the governor?",
         "analysis": None,
         "next_action": "",
-        "plan": None,
-        "code": None,
-        "error": None,
+        "iterations": 0,
+        "checklist": [],
+        "current_step": 0,
+        "vfs": {},
         "plots": [],
         "dfs": {}
     }

-    # Run the graph
-    result = app.invoke(initial_state, config={"recursion_limit": 10})
+    result = app.invoke(initial_state, config={"recursion_limit": 20})

-    assert result["next_action"] == "research"
-    assert "Research Results" in result["messages"][-1].content
+    assert "Final Research Summary" in [m.content for m in result["messages"]]
+    assert result["current_step"] == 1
4 frontend/package-lock.json generated
@@ -1,12 +1,12 @@
 {
   "name": "frontend",
-  "version": "0.0.0",
+  "version": "0.1.0",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "frontend",
-      "version": "0.0.0",
+      "version": "0.1.0",
       "dependencies": {
         "@hookform/resolvers": "^5.2.2",
         "@tailwindcss/vite": "^4.1.18",
@@ -9,9 +9,9 @@ interface ExecutionStatusProps {

 const PHASE_CONFIG = [
   { label: "Analyzing query...", match: "Query analysis complete." },
-  { label: "Generating strategic plan...", match: "Strategic plan generated." },
-  { label: "Writing analysis code...", match: "Analysis code generated." },
-  { label: "Performing data analysis...", match: "Data analysis and visualization complete." }
+  { label: "Generating high-level plan...", match: "Checklist generated." },
+  { label: "Delegating to specialists...", match: "Task assigned." },
+  { label: "Synthesizing final answer...", match: "Final synthesis complete." }
 ]

 export function ExecutionStatus({ steps, isComplete, className }: ExecutionStatusProps) {
@@ -1,7 +1,7 @@
 import { cva } from "class-variance-authority"

 export const buttonVariants = cva(
-  "inline-flex items-center justify-center gap-2 whitespace-nowrap rounded-md text-sm font-medium transition-all disabled:pointer-events-none disabled:opacity-50 [&_svg]:pointer-events-none [&_svg:not([class*='size-'])]:size-4 shrink-0 [&_svg]:shrink-0 outline-none focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:ring-[3px] aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive",
+  "inline-flex items-center justify-center gap-2 whitespace-nowrap rounded-md text-sm font-medium transition-all cursor-pointer disabled:pointer-events-none disabled:opacity-50 disabled:cursor-default [&_svg]:pointer-events-none [&_svg:not([class*='size-'])]:size-4 shrink-0 [&_svg]:shrink-0 outline-none focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:ring-[3px] aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive",
   {
     variants: {
       variant: {
29 frontend/src/components/ui/button.test.tsx Normal file
@@ -0,0 +1,29 @@
import { render, screen } from "@testing-library/react"
import { describe, expect, it } from "vitest"
import { Button } from "./button"

describe("Button component", () => {
  it("renders with a pointer cursor", () => {
    render(<Button>Click me</Button>)
    const button = screen.getByRole("button", { name: /click me/i })
    expect(button).toHaveClass("cursor-pointer")
  })

  it("renders different variants with a pointer cursor", () => {
    const { rerender } = render(<Button variant="outline">Outline</Button>)
    let button = screen.getByRole("button", { name: /outline/i })
    expect(button).toHaveClass("cursor-pointer")

    rerender(<Button variant="destructive">Destructive</Button>)
    button = screen.getByRole("button", { name: /destructive/i })
    expect(button).toHaveClass("cursor-pointer")
  })

  it("renders with a default cursor when disabled", () => {
    render(<Button disabled>Disabled Button</Button>)
    const button = screen.getByRole("button", { name: /disabled button/i })
    expect(button).toBeDisabled()
    expect(button).toHaveClass("disabled:pointer-events-none")
    expect(button).toHaveClass("disabled:cursor-default")
  })
})
74 frontend/src/components/ui/dropdown-menu.test.tsx Normal file
@@ -0,0 +1,74 @@
import { render, screen } from "@testing-library/react"
import { describe, expect, it } from "vitest"
import {
  DropdownMenu,
  DropdownMenuContent,
  DropdownMenuItem,
  DropdownMenuSub,
  DropdownMenuSubTrigger,
  DropdownMenuSubContent,
  DropdownMenuPortal
} from "./dropdown-menu"

describe("DropdownMenu components", () => {
  it("DropdownMenuItem should have cursor-pointer when enabled", () => {
    render(
      <DropdownMenu open={true}>
        <DropdownMenuContent>
          <DropdownMenuItem>Enabled Item</DropdownMenuItem>
        </DropdownMenuContent>
      </DropdownMenu>
    )
    const item = screen.getByText("Enabled Item")
    expect(item).toHaveClass("cursor-pointer")
  })

  it("DropdownMenuSubTrigger should have cursor-pointer when enabled", () => {
    render(
      <DropdownMenu open={true}>
        <DropdownMenuContent>
          <DropdownMenuSub>
            <DropdownMenuSubTrigger>Sub Trigger</DropdownMenuSubTrigger>
            <DropdownMenuPortal>
              <DropdownMenuSubContent />
            </DropdownMenuPortal>
          </DropdownMenuSub>
        </DropdownMenuContent>
      </DropdownMenu>
    )
    const trigger = screen.getByText("Sub Trigger")
    expect(trigger).toHaveClass("cursor-pointer")
  })

  it("DropdownMenuItem should show default cursor when disabled", () => {
    render(
      <DropdownMenu open={true}>
        <DropdownMenuContent>
          <DropdownMenuItem disabled>Disabled Item</DropdownMenuItem>
        </DropdownMenuContent>
      </DropdownMenu>
    )
    const item = screen.getByText("Disabled Item")
    // Radix sets data-disabled attribute
    expect(item).toHaveAttribute("data-disabled")
    expect(item).toHaveClass("data-[disabled]:cursor-default")
  })

  it("DropdownMenuSubTrigger should show default cursor when disabled", () => {
    render(
      <DropdownMenu open={true}>
        <DropdownMenuContent>
          <DropdownMenuSub>
            <DropdownMenuSubTrigger disabled>Disabled Sub Trigger</DropdownMenuSubTrigger>
            <DropdownMenuPortal>
              <DropdownMenuSubContent />
            </DropdownMenuPortal>
          </DropdownMenuSub>
        </DropdownMenuContent>
      </DropdownMenu>
    )
    const trigger = screen.getByText("Disabled Sub Trigger")
    expect(trigger).toHaveAttribute("data-disabled")
    expect(trigger).toHaveClass("data-[disabled]:cursor-default")
  })
})
@@ -27,7 +27,7 @@ const DropdownMenuSubTrigger = React.forwardRef<
     <DropdownMenuPrimitive.SubTrigger
       ref={ref}
       className={cn(
-        "focus:bg-accent data-[state=open]:bg-accent flex cursor-default items-center rounded-sm px-2 py-1.5 text-sm outline-hidden select-none",
+        "focus:bg-accent data-[state=open]:bg-accent flex cursor-pointer items-center rounded-sm px-2 py-1.5 text-sm outline-hidden select-none data-[disabled]:cursor-default data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
         inset && "pl-8",
         className
       )}
@@ -81,7 +81,7 @@ const DropdownMenuItem = React.forwardRef<
     <DropdownMenuPrimitive.Item
       ref={ref}
       className={cn(
-        "focus:bg-accent focus:text-accent-foreground relative flex cursor-default items-center rounded-sm px-2 py-1.5 text-sm outline-hidden transition-colors data-[disabled]:pointer-events-none data-[disabled]:opacity-50 select-none",
+        "focus:bg-accent focus:text-accent-foreground relative flex cursor-pointer items-center rounded-sm px-2 py-1.5 text-sm outline-hidden transition-colors data-[disabled]:pointer-events-none data-[disabled]:opacity-50 data-[disabled]:cursor-default select-none",
         inset && "pl-8",
         className
       )}
@@ -99,7 +99,7 @@ const DropdownMenuCheckboxItem = React.forwardRef<
     <DropdownMenuPrimitive.CheckboxItem
       ref={ref}
       className={cn(
-        "focus:bg-accent focus:text-accent-foreground relative flex cursor-default items-center rounded-sm py-1.5 pr-2 pl-8 text-sm outline-hidden transition-colors data-[disabled]:pointer-events-none data-[disabled]:opacity-50 select-none",
+        "focus:bg-accent focus:text-accent-foreground relative flex cursor-pointer items-center rounded-sm py-1.5 pr-2 pl-8 text-sm outline-hidden transition-colors data-[disabled]:pointer-events-none data-[disabled]:opacity-50 data-[disabled]:cursor-default select-none",
         className
       )}
       checked={checked}
@@ -122,7 +122,7 @@ const DropdownMenuRadioItem = React.forwardRef<
     <DropdownMenuPrimitive.RadioItem
       ref={ref}
       className={cn(
-        "focus:bg-accent focus:text-accent-foreground relative flex cursor-default items-center rounded-sm py-1.5 pr-2 pl-8 text-sm outline-hidden transition-colors data-[disabled]:pointer-events-none data-[disabled]:opacity-50 select-none",
+        "focus:bg-accent focus:text-accent-foreground relative flex cursor-pointer items-center rounded-sm py-1.5 pr-2 pl-8 text-sm outline-hidden transition-colors data-[disabled]:pointer-events-none data-[disabled]:opacity-50 data-[disabled]:cursor-default select-none",
         className
       )}
       {...props}
frontend/src/components/ui/sidebar.test.tsx (new file, 83 lines)
@@ -0,0 +1,83 @@
+import { render, screen } from "@testing-library/react"
+import { describe, expect, it } from "vitest"
+import {
+  SidebarProvider,
+  SidebarMenuButton,
+  SidebarMenuAction,
+  SidebarGroupAction,
+  SidebarMenu,
+  SidebarMenuItem,
+  SidebarGroup
+} from "./sidebar"
+
+describe("Sidebar components", () => {
+  it("SidebarMenuButton should have cursor-pointer", () => {
+    render(
+      <SidebarProvider>
+        <SidebarMenu>
+          <SidebarMenuItem>
+            <SidebarMenuButton>Menu Button</SidebarMenuButton>
+          </SidebarMenuItem>
+        </SidebarMenu>
+      </SidebarProvider>
+    )
+    const button = screen.getByRole("button", { name: /menu button/i })
+    expect(button).toHaveClass("cursor-pointer")
+  })
+
+  it("SidebarMenuAction should have cursor-pointer", () => {
+    render(
+      <SidebarProvider>
+        <SidebarMenu>
+          <SidebarMenuItem>
+            <SidebarMenuButton>Button</SidebarMenuButton>
+            <SidebarMenuAction>Action</SidebarMenuAction>
+          </SidebarMenuItem>
+        </SidebarMenu>
+      </SidebarProvider>
+    )
+    const action = screen.getByRole("button", { name: /action/i })
+    expect(action).toHaveClass("cursor-pointer")
+  })
+
+  it("SidebarGroupAction should have cursor-pointer", () => {
+    render(
+      <SidebarProvider>
+        <SidebarGroup>
+          <SidebarGroupAction>Group Action</SidebarGroupAction>
+        </SidebarGroup>
+      </SidebarProvider>
+    )
+    const action = screen.getByRole("button", { name: /group action/i })
+    expect(action).toHaveClass("cursor-pointer")
+  })
+
+  it("SidebarGroupAction should have cursor-default when disabled", () => {
+    render(
+      <SidebarProvider>
+        <SidebarGroup>
+          <SidebarGroupAction disabled>Disabled Group Action</SidebarGroupAction>
+        </SidebarGroup>
+      </SidebarProvider>
+    )
+    const action = screen.getByRole("button", { name: /disabled group action/i })
+    expect(action).toHaveClass("disabled:cursor-default")
+    expect(action).toBeDisabled()
+  })
+
+  it("SidebarMenuAction should not show pointer when disabled", () => {
+    render(
+      <SidebarProvider>
+        <SidebarMenu>
+          <SidebarMenuItem>
+            <SidebarMenuButton>Button</SidebarMenuButton>
+            <SidebarMenuAction disabled>Disabled Action</SidebarMenuAction>
+          </SidebarMenuItem>
+        </SidebarMenu>
+      </SidebarProvider>
+    )
+    const action = screen.getByRole("button", { name: /disabled action/i })
+    expect(action).toHaveClass("disabled:cursor-default")
+    expect(action).toBeDisabled()
+  })
+})
@@ -405,7 +405,7 @@ function SidebarGroupAction({
       data-slot="sidebar-group-action"
       data-sidebar="group-action"
       className={cn(
-        "text-sidebar-foreground ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground absolute top-3.5 right-3 flex aspect-square w-5 items-center justify-center rounded-md p-0 outline-hidden transition-transform focus-visible:ring-2 [&>svg]:size-4 [&>svg]:shrink-0",
+        "text-sidebar-foreground ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground absolute top-3.5 right-3 flex aspect-square w-5 items-center justify-center rounded-md p-0 outline-hidden transition-transform focus-visible:ring-2 [&>svg]:size-4 [&>svg]:shrink-0 cursor-pointer disabled:cursor-default disabled:pointer-events-none disabled:opacity-50",
         // Increases the hit area of the button on mobile.
         "after:absolute after:-inset-2 md:after:hidden",
         "group-data-[collapsible=icon]:hidden",
@@ -477,7 +477,7 @@ function SidebarMenuButton({
       data-size={size}
       data-active={isActive}
       className={cn(
-        "peer/menu-button flex w-full items-center gap-2 overflow-hidden rounded-md p-2 text-left text-sm outline-hidden ring-sidebar-ring transition-[width,height,padding] hover:bg-sidebar-accent hover:text-sidebar-accent-foreground focus-visible:ring-2 active:bg-sidebar-accent active:text-sidebar-accent-foreground disabled:pointer-events-none disabled:opacity-50 group-has-data-[sidebar=menu-action]/menu-item:pr-8 aria-disabled:pointer-events-none aria-disabled:opacity-50 data-[active=true]:bg-sidebar-accent data-[active=true]:font-medium data-[active=true]:text-sidebar-accent-foreground data-[state=open]:hover:bg-sidebar-accent data-[state=open]:hover:text-sidebar-accent-foreground group-data-[collapsible=icon]:size-8! group-data-[collapsible=icon]:p-2! [&>span:last-child]:truncate [&>svg]:size-4 [&>svg]:shrink-0",
+        "peer/menu-button flex w-full items-center gap-2 overflow-hidden rounded-md p-2 text-left text-sm outline-hidden ring-sidebar-ring transition-[width,height,padding] hover:bg-sidebar-accent hover:text-sidebar-accent-foreground focus-visible:ring-2 active:bg-sidebar-accent active:text-sidebar-accent-foreground disabled:pointer-events-none disabled:opacity-50 group-has-data-[sidebar=menu-action]/menu-item:pr-8 aria-disabled:pointer-events-none aria-disabled:opacity-50 data-[active=true]:bg-sidebar-accent data-[active=true]:font-medium data-[active=true]:text-sidebar-accent-foreground data-[state=open]:hover:bg-sidebar-accent data-[state=open]:hover:text-sidebar-accent-foreground group-data-[collapsible=icon]:size-8! group-data-[collapsible=icon]:p-2! [&>span:last-child]:truncate [&>svg]:size-4 [&>svg]:shrink-0 cursor-pointer disabled:cursor-default",
         variant === "default" && "hover:bg-sidebar-accent hover:text-sidebar-accent-foreground",
         variant === "outline" && "bg-background shadow-[0_0_0_1px_hsl(var(--sidebar-border))] hover:bg-sidebar-accent hover:text-sidebar-accent-foreground hover:shadow-[0_0_0_1px_hsl(var(--sidebar-accent))]",
         size === "default" && "h-8 text-sm",
@@ -528,7 +528,7 @@ function SidebarMenuAction({
       data-slot="sidebar-menu-action"
       data-sidebar="menu-action"
       className={cn(
-        "text-sidebar-foreground ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground peer-hover/menu-button:text-sidebar-accent-foreground absolute top-1.5 right-1 flex aspect-square w-5 items-center justify-center rounded-md p-0 outline-hidden transition-transform focus-visible:ring-2 [&>svg]:size-4 [&>svg]:shrink-0",
+        "text-sidebar-foreground ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground peer-hover/menu-button:text-sidebar-accent-foreground absolute top-1.5 right-1 flex aspect-square w-5 items-center justify-center rounded-md p-0 outline-hidden transition-transform focus-visible:ring-2 [&>svg]:size-4 [&>svg]:shrink-0 cursor-pointer disabled:cursor-default disabled:pointer-events-none disabled:opacity-50",
         // Increases the hit area of the button on mobile.
         "after:absolute after:-inset-2 md:after:hidden",
         "peer-data-[size=sm]/menu-button:top-1",
@@ -3,21 +3,21 @@ import { ChatService, type ChatEvent, type MessageResponse } from "./chat"
 
 describe("ChatService SSE Parsing", () => {
   it("should correctly parse a text stream chunk", () => {
-    const rawChunk = `data: {"type": "on_chat_model_stream", "name": "summarizer", "data": {"chunk": "Hello"}}\n\n`
+    const rawChunk = `data: {"type": "on_chat_model_stream", "name": "synthesizer", "data": {"chunk": "Hello"}}\n\n`
     const events = ChatService.parseSSEChunk(rawChunk)
 
     expect(events).toHaveLength(1)
     expect(events[0]).toEqual({
       type: "on_chat_model_stream",
-      name: "summarizer",
+      name: "synthesizer",
       data: { chunk: "Hello" }
     })
   })
 
   it("should handle multiple events in one chunk", () => {
     const rawChunk =
-      `data: {"type": "on_chat_model_stream", "name": "summarizer", "data": {"chunk": "Hello"}}\n\n` +
-      `data: {"type": "on_chat_model_stream", "name": "summarizer", "data": {"chunk": " World"}}\n\n`
+      `data: {"type": "on_chat_model_stream", "name": "synthesizer", "data": {"chunk": "Hello"}}\n\n` +
+      `data: {"type": "on_chat_model_stream", "name": "synthesizer", "data": {"chunk": " World"}}\n\n`
 
     const events = ChatService.parseSSEChunk(rawChunk)
 
@@ -25,8 +25,8 @@ describe("ChatService SSE Parsing", () => {
     expect(events[1].data!.chunk).toBe(" World")
   })
 
-  it("should parse encoded plots from executor node", () => {
-    const rawChunk = `data: {"type": "on_chain_end", "name": "executor", "data": {"encoded_plots": ["base64data"]}}\n\n`
+  it("should parse encoded plots from data_analyst_worker node", () => {
+    const rawChunk = `data: {"type": "on_chain_end", "name": "data_analyst_worker", "data": {"encoded_plots": ["base64data"]}}\n\n`
     const events = ChatService.parseSSEChunk(rawChunk)
 
     expect(events[0].data!.encoded_plots).toEqual(["base64data"])
@@ -45,7 +45,7 @@ describe("ChatService Message State Management", () => {
     const messages: MessageResponse[] = [{ id: "1", role: "assistant", content: "Initial", created_at: new Date().toISOString() }]
     const event: ChatEvent = {
       type: "on_chat_model_stream",
-      node: "summarizer",
+      node: "synthesizer",
       data: { chunk: { content: " text" } }
     }
 
@@ -57,7 +57,7 @@ describe("ChatService Message State Management", () => {
     const messages: MessageResponse[] = [{ id: "1", role: "assistant", content: "Analysis", created_at: new Date().toISOString(), plots: [] }]
     const event: ChatEvent = {
       type: "on_chain_end",
-      name: "executor",
+      name: "data_analyst_worker",
       data: { encoded_plots: ["plot1"] }
     }
 
@@ -86,7 +86,8 @@ export const ChatService = {
     const { type, name, node, data } = event
 
     // 1. Handle incremental LLM chunks for terminal nodes
-    if (type === "on_chat_model_stream" && (node === "summarizer" || node === "researcher" || node === "clarification")) {
+    // Now using 'synthesizer' for the final user response
+    if (type === "on_chat_model_stream" && (node === "synthesizer" || node === "clarification")) {
       const chunk = data?.chunk?.content || ""
       if (!chunk) return messages
 
@@ -110,7 +111,7 @@ export const ChatService = {
     if (!lastMsg || lastMsg.role !== "assistant") return messages
 
     // Terminal nodes final text
-    if (name === "summarizer" || name === "researcher" || name === "clarification") {
+    if (name === "synthesizer" || name === "clarification") {
       const messages_list = data?.output?.messages
       const msg = messages_list ? messages_list[messages_list.length - 1]?.content : null
 
@@ -121,8 +122,8 @@ export const ChatService = {
       }
     }
 
-    // Plots from executor
-    if (name === "executor" && data?.encoded_plots) {
+    // Plots from data analyst worker
+    if (name === "data_analyst_worker" && data?.encoded_plots) {
       lastMsg.plots = [...(lastMsg.plots || []), ...data.encoded_plots]
       // Filter out the 'active' step and replace with 'complete'
       const filteredSteps = (lastMsg.steps || []).filter(s => s !== "Performing data analysis...");
@@ -134,15 +135,25 @@ export const ChatService = {
     // Status for intermediate nodes (completion)
     const statusMap: Record<string, string> = {
       "query_analyzer": "Query analysis complete.",
-      "planner": "Strategic plan generated.",
-      "coder": "Analysis code generated."
+      "planner": "Checklist generated.",
+      "delegate": "Task assigned.",
+      "reflector": "Result verified.",
+      "coder": "Analysis code generated.",
+      "executor": "Code execution complete.",
+      "searcher": "Web search complete.",
+      "summarizer": "Task summary generated."
     }
 
     if (name && statusMap[name]) {
       // Find and replace the active status if it exists
       const activeStatus = name === "query_analyzer" ? "Analyzing query..." :
-        name === "planner" ? "Generating strategic plan..." :
-        name === "coder" ? "Writing analysis code..." : null;
+        name === "planner" ? "Generating high-level plan..." :
+        name === "delegate" ? "Routing task..." :
+        name === "reflector" ? "Evaluating results..." :
+        name === "coder" ? "Writing analysis code..." :
+        name === "executor" ? "Executing code..." :
+        name === "searcher" ? "Searching web..." :
+        name === "summarizer" ? "Summarizing results..." : null;
 
       let filteredSteps = lastMsg.steps || [];
       if (activeStatus) {
@@ -159,9 +170,13 @@ export const ChatService = {
     if (type === "on_chain_start") {
       const startStatusMap: Record<string, string> = {
         "query_analyzer": "Analyzing query...",
-        "planner": "Generating strategic plan...",
+        "planner": "Generating high-level plan...",
+        "delegate": "Routing task...",
+        "reflector": "Evaluating results...",
         "coder": "Writing analysis code...",
-        "executor": "Performing data analysis..."
+        "executor": "Executing code...",
+        "searcher": "Searching web...",
+        "summarizer": "Summarizing results..."
       }
 
       if (name && startStatusMap[name]) {
@@ -170,9 +185,11 @@ export const ChatService = {
         const lastMsg = { ...newMessages[lastMsgIndex] }
 
         if (lastMsg && lastMsg.role === "assistant") {
-          // Avoid duplicate start messages
-          if (!(lastMsg.steps || []).includes(startStatusMap[name])) {
-            lastMsg.steps = [...(lastMsg.steps || []), startStatusMap[name]]
+          const currentSteps = lastMsg.steps || []
+          // Allow duplicate steps if it's a retry (cycle),
+          // but avoid spamming the same active step multiple times in a row
+          if (currentSteps[currentSteps.length - 1] !== startStatusMap[name]) {
+            lastMsg.steps = [...currentSteps, startStatusMap[name]]
             newMessages[lastMsgIndex] = lastMsg
             return newMessages
           }
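The chat.test.ts hunks above pin down the SSE framing contract: each event arrives as a `data: <json>` block terminated by a blank line, and one network chunk may carry several blocks. A minimal sketch of a parser satisfying that contract (a hypothetical stand-in, not the project's actual `ChatService.parseSSEChunk`):

```typescript
// Hypothetical sketch of the SSE parsing contract exercised by chat.test.ts.
interface ParsedEvent {
  type: string
  name?: string
  data?: { chunk?: string; encoded_plots?: string[] }
}

function parseSSEChunk(raw: string): ParsedEvent[] {
  return raw
    .split("\n\n")                               // a blank line separates SSE events
    .map(block => block.trim())
    .filter(block => block.startsWith("data: ")) // skip comments / empty blocks
    .map(block => JSON.parse(block.slice("data: ".length)) as ParsedEvent)
}

const events = parseSSEChunk(
  `data: {"type": "on_chat_model_stream", "name": "synthesizer", "data": {"chunk": "Hello"}}\n\n` +
  `data: {"type": "on_chat_model_stream", "name": "synthesizer", "data": {"chunk": " World"}}\n\n`
)
console.log(events.length)          // 2
console.log(events[1].data?.chunk)  // " World"
```

Splitting on the blank-line delimiter before JSON-parsing each `data:` payload is what lets a single chunk yield multiple events, which is exactly what the "should handle multiple events in one chunk" spec asserts.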