feat: implement mvp with email-first login flow and langgraph architecture

Yunxiao Xu
2026-02-09 23:22:30 -08:00
parent af227d40e6
commit 5a943b902a
79 changed files with 8200 additions and 1 deletion

.env.example Normal file

@@ -0,0 +1,51 @@
# API Keys
OPENAI_API_KEY=your_openai_api_key_here
GOOGLE_API_KEY=your_google_api_key_here
# App Configuration
DATA_DIR=data
DATA_STATE=new_jersey
LOG_LEVEL=INFO
DEV_MODE=false
# Voter Database Configuration
DB_HOST=localhost
DB_PORT=5432
DB_NAME=blockdata
DB_USER=user
DB_PSWD=password
DB_TABLE=rd_gc_voters_nj
# Application/History Database Configuration
HISTORY_DB_URL=postgresql://user:password@localhost:5433/ea_history
# OIDC Configuration (Authentik/SSO)
OIDC_CLIENT_ID=your_client_id
OIDC_CLIENT_SECRET=your_client_secret
OIDC_SERVER_METADATA_URL=https://your-authentik.example.com/application/o/ea-chatbot/.well-known/openid-configuration
OIDC_REDIRECT_URI=http://localhost:8501
# Node Configuration Overrides (Optional)
# Format: <NODE_NAME>_LLM__<PARAMETER>
# Possible parameters: PROVIDER, MODEL, TEMPERATURE, MAX_TOKENS
# Query Analyzer
# QUERY_ANALYZER_LLM__PROVIDER=openai
# QUERY_ANALYZER_LLM__MODEL=gpt-5-mini
# QUERY_ANALYZER_LLM__TEMPERATURE=0.0
# Planner
# PLANNER_LLM__PROVIDER=openai
# PLANNER_LLM__MODEL=gpt-5-mini
# Coder
# CODER_LLM__PROVIDER=openai
# CODER_LLM__MODEL=gpt-5-mini
# Summarizer
# SUMMARIZER_LLM__PROVIDER=openai
# SUMMARIZER_LLM__MODEL=gpt-5-mini
# Researcher
# RESEARCHER_LLM__PROVIDER=google
# RESEARCHER_LLM__MODEL=gemini-2.0-flash
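The override format above is `<NODE_NAME>_LLM__<PARAMETER>` with a double-underscore separator. As a minimal sketch of how such variables could be collected per node (a hypothetical helper for illustration; the actual `Settings` class in this commit may parse them differently):

```python
import os

def load_node_llm_overrides(node_name: str, environ=os.environ) -> dict:
    """Collect <NODE_NAME>_LLM__<PARAMETER> environment overrides into a dict."""
    prefix = f"{node_name.upper()}_LLM__"
    allowed = {"PROVIDER", "MODEL", "TEMPERATURE", "MAX_TOKENS"}
    overrides = {}
    for key, value in environ.items():
        if key.startswith(prefix):
            param = key[len(prefix):]
            if param in allowed:
                overrides[param.lower()] = value
    return overrides

env = {"QUERY_ANALYZER_LLM__MODEL": "gpt-5-mini",
       "QUERY_ANALYZER_LLM__TEMPERATURE": "0.0",
       "OTHER": "ignored"}
print(load_node_llm_overrides("query_analyzer", env))
# → {'model': 'gpt-5-mini', 'temperature': '0.0'}
```

Values arrive as strings from the environment; a real settings layer would also coerce types (e.g. `temperature` to float).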

.gitignore vendored

@@ -168,4 +168,13 @@ cython_debug/
#.idea/
# PyPI configuration file
.pypirc
conductor/
data/
# Logs
logs/
postgres-data/
langchain-docs/

GEMINI.md Normal file

@@ -0,0 +1,169 @@
# Election Analytics Chatbot - Project Guide
## Overview
This document serves as a guide for rewriting the current "BambooAI" based chatbot system into a modern, stateful, and graph-based architecture using **LangGraph**. The goal is to improve maintainability, observability, and flexibility of the agentic workflows.
## 1. Migration Goals
- **Framework Switch**: Move from the custom linear `ChatBot` class (in `src/ea_chatbot/bambooai/core/chatbot.py`) to `LangGraph`.
- **State Management**: Make state explicit using LangGraph's `StateGraph`.
- **Modularity**: Break down monolithic methods (`pd_agent_converse`, `execute_code`) into distinct Nodes.
- **Observability**: Easier debugging of the decision process (Routing -> Planning -> Coding -> Executing).
## 2. Architecture Proposal
### 2.1. The Graph State
The state will track the conversation and execution context.
```python
from typing import TypedDict, Annotated, List, Dict, Any, Optional
import operator

from langchain_core.messages import BaseMessage
from matplotlib.figure import Figure
from pandas import DataFrame


class AgentState(TypedDict):
    # Conversation history
    messages: Annotated[List[BaseMessage], operator.add]
    # Task context
    question: str
    # Query Analysis (Decomposition results)
    analysis: Optional[Dict[str, Any]]
    # Expected keys: "requires_dataset", "expert", "data", "unknown", "condition"
    # Step-by-step reasoning
    plan: Optional[str]
    # Code execution context
    code: Optional[str]
    code_output: Optional[str]
    error: Optional[str]
    # Artifacts (for UI display)
    plots: List[Figure]  # Matplotlib figures
    dfs: Dict[str, DataFrame]  # Pandas DataFrames
    # Control flow
    iterations: int
    next_action: str  # Routing hint: "clarify", "plan", "research", "end"
```
### 2.2. Nodes (The Actors)
We will map existing logic to these nodes:
1. **`query_analyzer_node`** (Router & Refiner):
    * **Logic**: Replaces `Expert Selector` and `Analyst Selector`.
    * **Function**:
        1. Decomposes the user's query into key elements (Data, Unknowns, Conditions).
        2. Determines if the query is ambiguous or missing critical information.
    * **Output**: Updates `messages`. Returns a routing decision:
        * `clarification_node` (if ambiguous).
        * `planner_node` (if clear data task).
        * `researcher_node` (if general/web task).
2. **`clarification_node`** (Human-in-the-loop):
    * **Logic**: Replaces `Theorist-Clarification`.
    * **Function**: Formulates a specific question to ask the user for missing details.
    * **Output**: Returns a message to the user and **interrupts** the graph execution to await user input.
3. **`researcher_node`** (Theorist):
    * **Logic**: Handles general queries or web searches.
    * **Function**: Uses `GoogleSearch` tool if necessary.
    * **Output**: Final answer.
4. **`planner_node`**:
    * **Logic**: Replaces `Planner`.
    * **Function**: Generates a step-by-step plan based on the decomposed query elements and dataframe ontology.
    * **Output**: Updates `plan`.
5. **`coder_node`**:
    * **Logic**: Replaces `Code Generator` & `Error Corrector`.
    * **Function**: Generates Python code. If `error` exists in state, it attempts to fix it.
    * **Output**: Updates `code`.
6. **`executor_node`**:
    * **Logic**: Replaces `Code Executor`.
    * **Function**: Executes the Python code in a safe(r) environment. It needs access to the `DBClient`.
    * **Output**: Updates `code_output`, `plots`, `dfs`. If exception, updates `error`.
7. **`summarizer_node`**:
    * **Logic**: Replaces `Solution Summarizer`.
    * **Function**: Interprets the code output and generates a natural language response.
    * **Output**: Final response message.
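In LangGraph, each of these nodes is a function that reads the state and returns only the keys it changed. A minimal sketch of what `planner_node` might look like (the prompt and LLM call are stubbed out; the real node would invoke an LLM with the ontology):

```python
from typing import Any, Dict

def planner_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """Sketch of a LangGraph node: read from state, return only updated keys.
    The real implementation would call an LLM with the decomposed query."""
    analysis = state.get("analysis") or {}
    plan = (
        f"1. Load data for: {analysis.get('data', 'unknown')}\n"
        f"2. Compute: {analysis.get('unknown', 'n/a')}"
    )
    # Returning a partial dict lets LangGraph merge it into AgentState.
    return {"plan": plan, "iterations": state.get("iterations", 0) + 1}

out = planner_node({"analysis": {"data": "nj voters", "unknown": "turnout"},
                    "iterations": 0})
print(out["iterations"])  # → 1
```

Because nodes return partial updates, the `Annotated[..., operator.add]` reducer on `messages` controls how lists accumulate across node invocations.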
### 2.3. The Workflow (Graph)
```mermaid
graph TD
    Start --> QueryAnalyzer
    QueryAnalyzer -->|Ambiguous| Clarification
    Clarification -->|User Input| QueryAnalyzer
    QueryAnalyzer -->|General/Web| Researcher
    QueryAnalyzer -->|Data Analysis| Planner
    Planner --> Coder
    Coder --> Executor
    Executor -->|Success| Summarizer
    Executor -->|Error| Coder
    Researcher --> End
    Summarizer --> End
```
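The routing above can be sketched in plain Python as a transition table (the real build would use `StateGraph.add_node` / `add_conditional_edges`; this stand-in just makes the edge logic concrete):

```python
# Transition table mirroring the mermaid graph. "end" is a sentinel;
# the Clarification interrupt loop is omitted for brevity.
EDGES = {
    "query_analyzer": lambda s: {"clarify": "clarification",
                                 "research": "researcher",
                                 "plan": "planner"}[s["next_action"]],
    "planner": lambda s: "coder",
    "coder": lambda s: "executor",
    # On error, loop back to the coder; the real graph caps retry iterations.
    "executor": lambda s: "coder" if s.get("error") else "summarizer",
    "researcher": lambda s: "end",
    "summarizer": lambda s: "end",
}

def run(state, max_steps=10):
    node, path = "query_analyzer", []
    while node != "end" and len(path) < max_steps:
        path.append(node)
        node = EDGES[node](state)
    return path

print(run({"next_action": "plan", "error": None}))
# → ['query_analyzer', 'planner', 'coder', 'executor', 'summarizer']
```

In the LangGraph version, the conditional edge out of `query_analyzer` would be registered once and the state updates would drive the same decisions.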
## 3. Implementation Steps
### Step 1: Dependencies
Add the following packages to `pyproject.toml`:
* `langgraph`
* `langchain`
* `langchain-openai`
* `langchain-google-genai`
* `langchain-community`
### Step 2: Directory Structure
Create a new package for the graph logic to keep it separate from the old one during migration.
```
src/ea_chatbot/
├── graph/
│ ├── __init__.py
│ ├── state.py # State definition
│ ├── nodes/ # Individual node implementations
│ │ ├── __init__.py
│ │ ├── router.py
│ │ ├── planner.py
│ │ ├── coder.py
│ │ ├── executor.py
│ │ └── ...
│ ├── workflow.py # Graph construction
│ └── tools/ # DB and Search tools wrapped for LangChain
└── ...
```
### Step 3: Tool Wrapping
Wrap the existing `DBClient` (from `src/ea_chatbot/bambooai/utils/db_client.py`) into a structure accessible by the `executor_node`. The `executor_node` will likely keep the existing `exec()` based approach initially for compatibility with the generated code, but structured as a graph node.
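One way to give the `executor_node` access to a shared client while keeping the `exec()` approach is a closure-based factory. A hedged sketch (the `DBClient` below is a stand-in class, not the real one from `db_client.py`):

```python
class DBClient:
    """Stand-in for the real DBClient; returns canned rows."""
    def run_query(self, sql: str):
        return [("demo_row",)]

def executor_node_factory(db_client):
    """Close over the client so the node itself stays a plain state function."""
    def executor_node(state):
        # Generated code sees `db` in its namespace and assigns to `result`.
        namespace = {"db": db_client, "result": None}
        try:
            exec(state["code"], namespace)  # existing exec()-based approach
            return {"code_output": repr(namespace.get("result")), "error": None}
        except Exception as e:
            return {"error": str(e)}
    return executor_node

node = executor_node_factory(DBClient())
print(node({"code": "result = db.run_query('SELECT 1')"}))
```

Alternatively, LangGraph's `configurable` parameters can carry the client, which avoids the closure but changes the node signature.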
### Step 4: Prompt Migration
Port the prompts from `data/PROMPT_TEMPLATES.json` or `src/ea_chatbot/bambooai/prompts/strings.py` into the respective nodes. Use LangChain's `ChatPromptTemplate` for better management.
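The practical win over plain `format` strings is partial formatting: static context (state name, ontology) is bound once at startup, while the question is filled per request. A stdlib-only sketch of the idea (the guide itself proposes LangChain's `ChatPromptTemplate`, which supports `.partial(...)` natively; the template text here is illustrative):

```python
from string import Template

PLANNER_TEMPLATE = Template(
    "You are a planner for $state election data.\n"
    "Ontology:\n$ontology\n\n"
    "Question: $question"
)

# Partially fill the static parts once at startup...
partial = PLANNER_TEMPLATE.safe_substitute(
    state="new_jersey",
    ontology="voters(id, party, county)",
)
# ...then fill the per-request part at call time.
prompt = Template(partial).substitute(question="Turnout by county in 2024?")
print("Turnout by county" in prompt)  # → True
```

`safe_substitute` leaves the unbound `$question` placeholder intact, which is exactly the behavior partial formatting relies on.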
### Step 5: Streamlit Integration
Update `src/ea_chatbot/app.py` to use the new `workflow.compile()` runnable.
* Instead of `chatbot.pd_agent_converse(...)`, use `app.stream(...)` (LangGraph app).
* Handle the streaming output to update the UI progressively.
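The consumption loop hinges on the shape of `.stream()` events: LangGraph yields one `{node_name: state_update}` mapping per executed node. A sketch with a fake generator standing in for the compiled graph (the real stream comes from `workflow.compile().stream(...)`):

```python
def fake_stream(initial_state):
    """Stand-in for app.stream(): yields one {node_name: update} per node."""
    yield {"query_analyzer": {"next_action": "plan"}}
    yield {"planner": {"plan": "1. load\n2. aggregate"}}
    yield {"summarizer": {"messages": ["Turnout was 61%."]}}

final_state = {}
for event in fake_stream({"question": "Turnout?"}):
    for node_name, update in event.items():
        # Real code merges more carefully (appending message/plot lists).
        final_state.update(update)
        print(f"node={node_name}")  # drive a progress indicator per node
print(final_state["plan"])
```

In the Streamlit app this per-node loop is where progress widgets (e.g. a status spinner) get updated before the final answer is rendered.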
## 4. Key Considerations for Refactoring
* **Database Connection**: Ensure `DBClient` is initialized once and passed to the `Executor` node efficiently (e.g., via `configurable` parameters or closure).
* **Prompt Templating**: The current system uses simple `format` strings. Switching to LangChain templates allows for easier model switching and partial formatting.
* **Token Management**: LangGraph provides built-in tracing (if LangSmith is enabled), but we should ensure the `OutputManager` logic (printing costs/tokens) is preserved or adapted if still needed for the CLI/Logs.
* **Vector DB**: The current system has `PineconeWrapper` for RAG. This should be integrated into the `Planner` or `Coder` node to fetch few-shot examples or context.
## 5. Next Actions
1. **Initialize**: Create the folder structure.
2. **Define State**: Create `src/ea_chatbot/graph/state.py`.
3. **Implement Router**: Create the first node to replicate `Expert Selector` logic.
4. **Implement Executor**: Port the `exec()` logic to a node.
## 6. Git Operations
- Branches should be used for specific features or bug fixes.
- New branches should be created from the `main` branch or the `conductor` branch, as appropriate.
- The conductor should always use the `conductor` branch and derived branches.
- When a feature or fix is complete, use rebase to keep the commit history clean before merging.
- Conductor-related changes should never be merged into the `main` branch.

alembic.ini Normal file

@@ -0,0 +1,149 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts.
# this is typically a path given in POSIX (e.g. forward slashes)
# format, relative to the token %(here)s which refers to the location of this
# ini file
script_location = %(here)s/alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# Or organize into date-based subdirectories (requires recursive_version_locations = true)
# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. for multiple paths, the path separator
# is defined by "path_separator" below.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the tzdata library which can be installed by adding
# `alembic[tz]` to the pip requirements.
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to <script_location>/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "path_separator"
# below.
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
# path_separator; This indicates what character is used to split lists of file
# paths, including version_locations and prepend_sys_path within configparser
# files such as alembic.ini.
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
# to provide os-dependent path splitting.
#
# Note that in order to support legacy alembic.ini files, this default does NOT
# take place if path_separator is not present in alembic.ini. If this
# option is omitted entirely, fallback logic is as follows:
#
# 1. Parsing of the version_locations option falls back to using the legacy
# "version_path_separator" key, which if absent then falls back to the legacy
# behavior of splitting on spaces and/or commas.
# 2. Parsing of the prepend_sys_path option falls back to the legacy
# behavior of splitting on spaces, commas, or colons.
#
# Valid values for path_separator are:
#
# path_separator = :
# path_separator = ;
# path_separator = space
# path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# database URL. This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file.
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
# hooks = ruff
# ruff.type = module
# ruff.module = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Alternatively, use the exec runner to execute a binary found on your PATH
# hooks = ruff
# ruff.type = exec
# ruff.executable = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration. This is also consumed by the user-maintained
# env.py script only.
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

alembic/README Normal file

@@ -0,0 +1 @@
Generic single-database configuration.

alembic/env.py Normal file

@@ -0,0 +1,83 @@
from logging.config import fileConfig
import os
import sys

from sqlalchemy import engine_from_config
from sqlalchemy import pool

from alembic import context

# Add src to path to ensure we can import the app
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../src'))

from ea_chatbot.config import Settings
from ea_chatbot.history.models import Base

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# add your model's MetaData object here
# for 'autogenerate' support
target_metadata = Base.metadata

# Override sqlalchemy.url from settings
settings = Settings()
config.set_main_option("sqlalchemy.url", settings.history_db_url)


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.
    """
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context.
    """
    connectable = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(
            connection=connection, target_metadata=target_metadata
        )

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()

alembic/script.py.mako Normal file

@@ -0,0 +1,28 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    """Upgrade schema."""
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    """Downgrade schema."""
    ${downgrades if downgrades else "pass"}


@@ -0,0 +1,32 @@
"""Add summary to conversation

Revision ID: 63886baa1255
Revises: a15fde9a62df
Create Date: 2026-02-07 22:34:47.254569

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '63886baa1255'
down_revision: Union[str, Sequence[str], None] = 'a15fde9a62df'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column('conversations', sa.Column('summary', sa.Text(), nullable=True))
    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column('conversations', 'summary')
    # ### end Alembic commands ###


@@ -0,0 +1,68 @@
"""Initial history tables

Revision ID: a15fde9a62df
Revises:
Create Date: 2026-02-07 21:18:57.524534

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = 'a15fde9a62df'
down_revision: Union[str, Sequence[str], None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('users',
        sa.Column('id', sa.String(), nullable=False),
        sa.Column('username', sa.String(), nullable=False),
        sa.Column('password_hash', sa.String(), nullable=True),
        sa.Column('display_name', sa.String(), nullable=True),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True)
    op.create_table('conversations',
        sa.Column('id', sa.String(), nullable=False),
        sa.Column('user_id', sa.String(), nullable=False),
        sa.Column('data_state', sa.String(), nullable=False),
        sa.Column('name', sa.String(), nullable=False),
        sa.Column('created_at', sa.DateTime(), nullable=False),
        sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_table('messages',
        sa.Column('id', sa.String(), nullable=False),
        sa.Column('conversation_id', sa.String(), nullable=False),
        sa.Column('role', sa.String(), nullable=False),
        sa.Column('content', sa.Text(), nullable=False),
        sa.Column('created_at', sa.DateTime(), nullable=False),
        sa.ForeignKeyConstraint(['conversation_id'], ['conversations.id'], ),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_table('plots',
        sa.Column('id', sa.String(), nullable=False),
        sa.Column('message_id', sa.String(), nullable=False),
        sa.Column('image_data', sa.LargeBinary(), nullable=False),
        sa.ForeignKeyConstraint(['message_id'], ['messages.id'], ),
        sa.PrimaryKeyConstraint('id')
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('plots')
    op.drop_table('messages')
    op.drop_table('conversations')
    op.drop_index(op.f('ix_users_username'), table_name='users')
    op.drop_table('users')
    # ### end Alembic commands ###

docker-compose.yml Normal file

@@ -0,0 +1,18 @@
services:
  history-db:
    image: postgres:16
    container_name: ea-history-db
    restart: unless-stopped
    environment:
      POSTGRES_DB: ${HISTORY_DB_NAME:-ea_history}
      POSTGRES_USER: ${HISTORY_DB_USER:-user}
      POSTGRES_PASSWORD: ${HISTORY_DB_PSWD:-password}
    ports:
      - "5433:5432"
    volumes:
      - ./postgres-data:/var/lib/postgresql/data
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U $${POSTGRES_USER} -d $${POSTGRES_DB}"]
      interval: 10s
      timeout: 5s
      retries: 5

pyproject.toml Normal file

@@ -0,0 +1,39 @@
[project]
name = "ea-chatbot"
version = "0.1.0"
description = "An election analytics chatbot using LangGraph."
requires-python = ">=3.13"
dependencies = [
    "langgraph",
    "langchain",
    "langchain-openai",
    "langchain-google-genai",
    "langchain-community",
    "pandas",
    "matplotlib",
    "psycopg2-binary",
    "termcolor>=3.3.0",
    "sqlalchemy>=2.0.46",
    "pyyaml>=6.0.3",
    "streamlit>=1.54.0",
    "rich>=14.3.2",
    "alembic>=1.13.0",
    "authlib>=1.3.0",
    "httpx>=0.27.0",
    "argon2-cffi>=23.1.0",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src/ea_chatbot"]

[dependency-groups]
dev = [
    "pytest>=9.0.2",
    "pytest-cov>=7.0.0",
    "ruff>=0.9.3",
    "mypy>=1.14.1",
]


src/ea_chatbot/app.py Normal file

@@ -0,0 +1,451 @@
import streamlit as st
import asyncio
import os
import io
from dotenv import load_dotenv
from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.state import AgentState
from ea_chatbot.utils.logging import get_logger
from ea_chatbot.utils.helpers import merge_agent_state
from ea_chatbot.config import Settings
from ea_chatbot.history.manager import HistoryManager
from ea_chatbot.auth import OIDCClient, AuthType, get_user_auth_type
# Load environment variables
load_dotenv()
# Initialize Config and Manager
settings = Settings()
history_manager = HistoryManager(settings.history_db_url)
# Initialize OIDC Client if configured
oidc_client = None
if settings.oidc_client_id and settings.oidc_client_secret and settings.oidc_server_metadata_url:
    oidc_client = OIDCClient(
        client_id=settings.oidc_client_id,
        client_secret=settings.oidc_client_secret,
        server_metadata_url=settings.oidc_server_metadata_url,
        # Redirect back to the same page
        redirect_uri=os.getenv("OIDC_REDIRECT_URI", "http://localhost:8501"),
    )
# Initialize Logger
logger = get_logger(level=settings.log_level, log_file="logs/app.jsonl")
# --- Authentication Helpers ---
def login_user(user):
    st.session_state.user = user
    st.session_state.messages = []
    st.session_state.summary = ""
    st.session_state.current_conversation_id = None
    st.rerun()

def logout_user():
    for key in list(st.session_state.keys()):
        del st.session_state[key]
    st.rerun()

def load_conversation(conv_id):
    messages = history_manager.get_messages(conv_id)
    formatted_messages = []
    for m in messages:
        # Convert DB models to session state dicts
        msg_dict = {
            "role": m.role,
            "content": m.content,
            "plots": [p.image_data for p in m.plots],
        }
        formatted_messages.append(msg_dict)
    st.session_state.messages = formatted_messages
    st.session_state.current_conversation_id = conv_id
    # Fetch summary from DB
    with history_manager.get_session() as session:
        from ea_chatbot.history.models import Conversation
        conv = session.get(Conversation, conv_id)
        st.session_state.summary = conv.summary if conv else ""
    st.rerun()
def main():
    st.set_page_config(
        page_title="Election Analytics Chatbot",
        page_icon="🗳️",
        layout="wide",
    )
    # Check for OIDC Callback
    if "code" in st.query_params and oidc_client:
        code = st.query_params["code"]
        try:
            token = oidc_client.exchange_code_for_token(code)
            user_info = oidc_client.get_user_info(token)
            email = user_info.get("email")
            name = user_info.get("name") or user_info.get("preferred_username")
            if email:
                user = history_manager.sync_user_from_oidc(email=email, display_name=name)
                # Clear query params
                st.query_params.clear()
                login_user(user)
        except Exception as e:
            st.error(f"OIDC Login failed: {str(e)}")
    # Display Login Screen if not authenticated
    if "user" not in st.session_state:
        st.title("🗳️ Election Analytics Chatbot")
        # Initialize Login State
        if "login_step" not in st.session_state:
            st.session_state.login_step = "email"
        if "login_email" not in st.session_state:
            st.session_state.login_email = ""
        col1, col2 = st.columns([1, 1])
        with col1:
            st.header("Login")
            # Step 1: Identification
            if st.session_state.login_step == "email":
                st.write("Please enter your email to begin:")
                with st.form("email_form"):
                    email_input = st.text_input("Email", value=st.session_state.login_email)
                    submitted = st.form_submit_button("Next")
                    if submitted:
                        if not email_input.strip():
                            st.error("Email cannot be empty.")
                        else:
                            st.session_state.login_email = email_input.strip()
                            auth_type = get_user_auth_type(st.session_state.login_email, history_manager)
                            if auth_type == AuthType.LOCAL:
                                st.session_state.login_step = "login_password"
                            elif auth_type == AuthType.OIDC:
                                st.session_state.login_step = "oidc_login"
                            else:  # AuthType.NEW
                                st.session_state.login_step = "register_details"
                            st.rerun()
            # Step 2a: Local Login
            elif st.session_state.login_step == "login_password":
                st.info(f"Welcome back, **{st.session_state.login_email}**!")
                with st.form("password_form"):
                    password = st.text_input("Password", type="password")
                    col_login, col_back = st.columns([1, 1])
                    submitted = col_login.form_submit_button("Login")
                    back = col_back.form_submit_button("Back")
                    if back:
                        st.session_state.login_step = "email"
                        st.rerun()
                    if submitted:
                        user = history_manager.authenticate_user(st.session_state.login_email, password)
                        if user:
                            login_user(user)
                        else:
                            st.error("Invalid email or password")
            # Step 2b: Registration
            elif st.session_state.login_step == "register_details":
                st.info(f"Create an account for **{st.session_state.login_email}**")
                with st.form("register_form"):
                    reg_name = st.text_input("Display Name")
                    reg_password = st.text_input("Password", type="password")
                    col_reg, col_back = st.columns([1, 1])
                    submitted = col_reg.form_submit_button("Register & Login")
                    back = col_back.form_submit_button("Back")
                    if back:
                        st.session_state.login_step = "email"
                        st.rerun()
                    if submitted:
                        if not reg_password:
                            st.error("Password is required for registration.")
                        else:
                            user = history_manager.create_user(st.session_state.login_email, reg_password, reg_name)
                            st.success("Registered! Logging in...")
                            login_user(user)
            # Step 2c: OIDC Redirection
            elif st.session_state.login_step == "oidc_login":
                st.info(f"**{st.session_state.login_email}** is configured for Single Sign-On (SSO).")
                col_sso, col_back = st.columns([1, 1])
                with col_sso:
                    if oidc_client:
                        login_url = oidc_client.get_login_url()
                        st.link_button("Login with SSO", login_url, type="primary", use_container_width=True)
                    else:
                        st.error("OIDC is not configured.")
                with col_back:
                    if st.button("Back", use_container_width=True):
                        st.session_state.login_step = "email"
                        st.rerun()
        with col2:
            if oidc_client:
                st.header("Single Sign-On")
                st.write("Login with your organizational account.")
                if st.button("Login with SSO"):
                    login_url = oidc_client.get_login_url()
                    st.link_button("Go to **YXXU**", login_url, type="primary")
            else:
                st.info("SSO is not configured.")
        st.stop()
    # --- Main App (Authenticated) ---
    user = st.session_state.user
    # Sidebar configuration
    with st.sidebar:
        st.title(f"Hi, {user.display_name or user.username}!")
        if st.button("Logout"):
            logout_user()
        st.divider()
        st.header("History")
        if st.button(" New Chat", use_container_width=True):
            st.session_state.messages = []
            st.session_state.summary = ""
            st.session_state.current_conversation_id = None
            st.rerun()
        # List conversations for the current user and data state
        conversations = history_manager.get_conversations(user.id, settings.data_state)
        for conv in conversations:
            col_c, col_r, col_d = st.columns([0.7, 0.15, 0.15])
            is_current = st.session_state.get("current_conversation_id") == conv.id
            label = f"💬 {conv.name}" if not is_current else f"👉 {conv.name}"
            if col_c.button(label, key=f"conv_{conv.id}", use_container_width=True):
                load_conversation(conv.id)
            if col_r.button("✏️", key=f"ren_{conv.id}"):
                st.session_state.renaming_id = conv.id
            if col_d.button("🗑️", key=f"del_{conv.id}"):
                if history_manager.delete_conversation(conv.id):
                    if is_current:
                        st.session_state.current_conversation_id = None
                        st.session_state.messages = []
                    st.rerun()
        # Rename dialog
        if st.session_state.get("renaming_id"):
            rid = st.session_state.renaming_id
            with st.form("rename_form"):
                new_name = st.text_input("New Name")
                if st.form_submit_button("Save"):
                    history_manager.rename_conversation(rid, new_name)
                    del st.session_state.renaming_id
                    st.rerun()
                if st.form_submit_button("Cancel"):
                    del st.session_state.renaming_id
                    st.rerun()
        st.divider()
        st.header("Settings")
        # Check for DEV_MODE env var (defaults to False)
        default_dev_mode = os.getenv("DEV_MODE", "false").lower() == "true"
        dev_mode = st.checkbox("Dev Mode", value=default_dev_mode, help="Enable to see code generation and raw reasoning steps.")
    st.title("🗳️ Election Analytics Chatbot")
    # Initialize chat history state
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "summary" not in st.session_state:
        st.session_state.summary = ""
    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            if message.get("plan") and dev_mode:
                with st.expander("Reasoning Plan"):
                    st.code(message["plan"], language="yaml")
            if message.get("code") and dev_mode:
                with st.expander("Generated Code"):
                    st.code(message["code"], language="python")
            st.markdown(message["content"])
            if message.get("plots"):
                for plot_data in message["plots"]:
                    # If plot_data is bytes, convert to image
                    if isinstance(plot_data, bytes):
                        st.image(plot_data)
                    else:
                        # Fallback for old session state or non-binary
                        st.pyplot(plot_data)
            if message.get("dfs"):
                for df_name, df in message["dfs"].items():
                    st.subheader(f"Data: {df_name}")
                    st.dataframe(df)
# Accept user input
    # Accept user input
    if prompt := st.chat_input("Ask a question about election data..."):
        # Ensure we have a conversation ID
        if not st.session_state.get("current_conversation_id"):
            # Auto-create conversation
            conv_name = (prompt[:30] + '...') if len(prompt) > 30 else prompt
            conv = history_manager.create_conversation(user.id, settings.data_state, conv_name)
            st.session_state.current_conversation_id = conv.id
        conv_id = st.session_state.current_conversation_id
        # Save user message to DB
        history_manager.add_message(conv_id, "user", prompt)
        # Add user message to session state
        st.session_state.messages.append({"role": "user", "content": prompt})
        # Display user message in chat message container
        with st.chat_message("user"):
            st.markdown(prompt)
        # Prepare graph input
        initial_state: AgentState = {
            "messages": st.session_state.messages[:-1],  # Pass history (excluding the current prompt)
            "question": prompt,
            "summary": st.session_state.summary,
            "analysis": None,
            "next_action": "",
            "plan": None,
            "code": None,
            "code_output": None,
            "error": None,
            "plots": [],
            "dfs": {},
        }
# Placeholder for graph output
with st.chat_message("assistant"):
final_state = initial_state
# Real-time node updates
with st.status("Thinking...", expanded=True) as status:
try:
# Use app.stream to capture node transitions
for event in app.stream(initial_state):
for node_name, state_update in event.items():
prev_error = final_state.get("error")
# Use helper to merge state correctly (appending messages/plots, updating dfs)
final_state = merge_agent_state(final_state, state_update)
if node_name == "query_analyzer":
analysis = state_update.get("analysis", {})
next_action = state_update.get("next_action", "unknown")
status.write(f"🔍 **Analyzed Query:**")
for k,v in analysis.items():
status.write(f"- {k:<8}: {v}")
status.markdown(f"Next Step: {next_action.capitalize()}")
elif node_name == "planner":
status.write("📋 **Plan Generated**")
# Render artifacts
if state_update.get("plan") and dev_mode:
with st.expander("Reasoning Plan", expanded=True):
st.code(state_update["plan"], language="yaml")
elif node_name == "researcher":
status.write("🌐 **Research Complete**")
if state_update.get("messages") and dev_mode:
for msg in state_update["messages"]:
# Extract content from BaseMessage or show raw string
content = getattr(msg, "text", msg.content)
status.markdown(content)
elif node_name == "coder":
status.write("💻 **Code Generated**")
if state_update.get("code") and dev_mode:
with st.expander("Generated Code"):
st.code(state_update["code"], language="python")
elif node_name == "error_corrector":
status.write("🛠️ **Fixing Execution Error...**")
if prev_error:
truncated_error = prev_error.strip()
if len(truncated_error) > 180:
truncated_error = truncated_error[:180] + "..."
status.write(f"Previous error: {truncated_error}")
if state_update.get("code") and dev_mode:
with st.expander("Corrected Code"):
st.code(state_update["code"], language="python")
elif node_name == "executor":
if state_update.get("error"):
error_text = state_update["error"]
if not dev_mode and len(error_text) > 100:
error_text = error_text[:100] + "..."
status.write(f"❌ **Execution Error:** {error_text}")
else:
status.write("✅ **Execution Successful**")
if state_update.get("plots"):
status.write(f"📊 Generated {len(state_update['plots'])} plot(s)")
elif node_name == "summarizer":
status.write("📝 **Summarizing Results...**")
status.update(label="Complete!", state="complete", expanded=False)
except Exception as e:
status.update(label="Error!", state="error")
st.error(f"Error during graph execution: {str(e)}")
# Extract results
response_text: str = ""
if final_state.get("messages"):
# The last message is the assistant's response
last_msg = final_state["messages"][-1]
response_text = str(getattr(last_msg, "content", last_msg))
st.markdown(response_text)
# Collect plot bytes for saving to DB
plot_bytes_list = []
if final_state.get("plots"):
for fig in final_state["plots"]:
st.pyplot(fig)
# Convert fig to bytes
buf = io.BytesIO()
fig.savefig(buf, format="png")
plot_bytes_list.append(buf.getvalue())
if final_state.get("dfs"):
for df_name, df in final_state["dfs"].items():
st.subheader(f"Data: {df_name}")
st.dataframe(df)
# Save assistant message to DB
history_manager.add_message(conv_id, "assistant", response_text, plots=plot_bytes_list)
# Update summary in DB
new_summary = final_state.get("summary", "")
if new_summary:
history_manager.update_conversation_summary(conv_id, new_summary)
# Store assistant response in session history
st.session_state.messages.append({
"role": "assistant",
"content": response_text,
"plan": final_state.get("plan"),
"code": final_state.get("code"),
"plots": plot_bytes_list,
"dfs": final_state.get("dfs")
})
st.session_state.summary = new_summary
if __name__ == "__main__":
main()

src/ea_chatbot/auth.py Normal file

@@ -0,0 +1,97 @@
import requests
from enum import Enum
from typing import Dict, Any, Optional
from authlib.integrations.requests_client import OAuth2Session
class AuthType(Enum):
LOCAL = "local"
OIDC = "oidc"
NEW = "new"
def get_user_auth_type(email: str, history_manager: Any) -> AuthType:
"""
Determine the authentication type for a given email.
Args:
email: The user's email address.
history_manager: Instance of HistoryManager to check the DB.
Returns:
AuthType: LOCAL if password exists, OIDC if user exists but no password, NEW otherwise.
"""
user = history_manager.get_user(email)
if not user:
return AuthType.NEW
if user.password_hash:
return AuthType.LOCAL
return AuthType.OIDC
class OIDCClient:
"""
Client for OIDC Authentication using Authlib.
Designed to work within a Streamlit environment.
"""
def __init__(
self,
client_id: str,
client_secret: str,
server_metadata_url: str,
redirect_uri: str = "http://localhost:8501"
):
self.client_id = client_id
self.client_secret = client_secret
self.server_metadata_url = server_metadata_url
self.redirect_uri = redirect_uri
self.oauth_session = OAuth2Session(
client_id=client_id,
client_secret=client_secret,
redirect_uri=redirect_uri,
scope="openid email profile"
)
self.metadata: Dict[str, Any] = {}
def fetch_metadata(self) -> Dict[str, Any]:
"""Fetch OIDC provider metadata if not already fetched."""
if not self.metadata:
resp = requests.get(self.server_metadata_url, timeout=10)
resp.raise_for_status()
self.metadata = resp.json()
return self.metadata
def get_login_url(self) -> str:
"""Generate the authorization URL."""
metadata = self.fetch_metadata()
authorization_endpoint = metadata.get("authorization_endpoint")
if not authorization_endpoint:
raise ValueError("authorization_endpoint not found in OIDC metadata")
uri, state = self.oauth_session.create_authorization_url(authorization_endpoint)
# NOTE: persist `state` (e.g. in the Streamlit session) and verify it on the
# callback to protect against CSRF; it is currently discarded.
return uri
def exchange_code_for_token(self, code: str) -> Dict[str, Any]:
"""Exchange the authorization code for an access token."""
metadata = self.fetch_metadata()
token_endpoint = metadata.get("token_endpoint")
if not token_endpoint:
raise ValueError("token_endpoint not found in OIDC metadata")
token = self.oauth_session.fetch_token(
token_endpoint,
code=code,
client_secret=self.client_secret
)
return token
def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
"""Fetch user information using the access token."""
metadata = self.fetch_metadata()
userinfo_endpoint = metadata.get("userinfo_endpoint")
if not userinfo_endpoint:
raise ValueError("userinfo_endpoint not found in OIDC metadata")
# Set the token on the session so it's used in the request
self.oauth_session.token = token
resp = self.oauth_session.get(userinfo_endpoint)
resp.raise_for_status()
return resp.json()
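The email-first branching in `get_user_auth_type` can be exercised in isolation with a stubbed history manager — the `FakeHistoryManager` below is purely illustrative and stands in for the real `HistoryManager`:

```python
from enum import Enum
from types import SimpleNamespace


class AuthType(Enum):
    LOCAL = "local"
    OIDC = "oidc"
    NEW = "new"


def get_user_auth_type(email, history_manager):
    # Same routing logic as auth.py: password hash => local login,
    # known user without one => SSO, unknown email => signup.
    user = history_manager.get_user(email)
    if not user:
        return AuthType.NEW
    if user.password_hash:
        return AuthType.LOCAL
    return AuthType.OIDC


class FakeHistoryManager:
    """Stand-in for HistoryManager with two known accounts."""
    _users = {
        "local@example.com": SimpleNamespace(password_hash="abc123"),
        "sso@example.com": SimpleNamespace(password_hash=None),
    }

    def get_user(self, email):
        return self._users.get(email)


hm = FakeHistoryManager()
print(get_user_auth_type("local@example.com", hm).value)   # local
print(get_user_auth_type("sso@example.com", hm).value)     # oidc
print(get_user_auth_type("nobody@example.com", hm).value)  # new
```

This keeps the UI free to render the right form (password field, SSO redirect, or registration) from a single email input.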

src/ea_chatbot/config.py Normal file

@@ -0,0 +1,44 @@
from typing import Dict, Any, Optional
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class LLMConfig(BaseModel):
"""Configuration for a specific LLM node."""
provider: str = "openai"
model: str = "gpt-5-mini"
temperature: float = 0.0
max_tokens: Optional[int] = None
provider_specific: Dict[str, Any] = Field(default_factory=dict)
class Settings(BaseSettings):
"""Global application settings."""
data_dir: str = "data"
data_state: str = "new_jersey"
log_level: str = Field(default="INFO", alias="LOG_LEVEL")
# Voter Database configuration
db_host: str = Field(default="localhost", alias="DB_HOST")
db_port: int = Field(default=5432, alias="DB_PORT")
db_name: str = Field(default="blockdata", alias="DB_NAME")
db_user: str = Field(default="user", alias="DB_USER")
db_pswd: str = Field(default="password", alias="DB_PSWD")
db_table: str = Field(default="rd_gc_voters_nj", alias="DB_TABLE")
# Application/History Database
history_db_url: str = Field(default="postgresql://user:password@localhost:5433/ea_history", alias="HISTORY_DB_URL")
# OIDC Configuration
oidc_client_id: Optional[str] = Field(default=None, alias="OIDC_CLIENT_ID")
oidc_client_secret: Optional[str] = Field(default=None, alias="OIDC_CLIENT_SECRET")
oidc_server_metadata_url: Optional[str] = Field(default=None, alias="OIDC_SERVER_METADATA_URL")
# Default configurations for each node
query_analyzer_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
planner_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
coder_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
summarizer_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
researcher_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
# Allow nested env vars like QUERY_ANALYZER_LLM__MODEL
model_config = SettingsConfigDict(env_nested_delimiter='__', env_prefix='')

@@ -0,0 +1,45 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
def clarification_node(state: AgentState) -> dict:
"""Ask the user for missing information or clarifications."""
question = state["question"]
analysis = state.get("analysis", {})
ambiguities = analysis.get("ambiguities", [])
settings = Settings()
logger = get_logger("clarification")
logger.info(f"Generating clarification for {len(ambiguities)} ambiguities.")
llm = get_llm_model(
settings.query_analyzer_llm,
callbacks=[LangChainLoggingHandler(logger=logger)]
)
system_prompt = """You are a Clarification Specialist. Your role is to identify what information is missing from a user's request to perform a data analysis or research task.
Based on the analysis of the user's question, formulate a polite and concise request for the missing information."""
prompt = f"""Original Question: {question}
Missing/Ambiguous Information: {', '.join(ambiguities) if ambiguities else 'Unknown ambiguities'}
Please ask the user for the necessary details."""
messages = [
("system", system_prompt),
("user", prompt)
]
try:
response = llm.invoke(messages)
logger.info("[bold green]Clarification generated.[/bold green]")
return {
"messages": [response],
"next_action": "end" # To indicate we are done for now
}
except Exception as e:
logger.error(f"Failed to generate clarification: {str(e)}")
raise e


@@ -0,0 +1,47 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils import database_inspection
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT
from ea_chatbot.schemas import CodeGenerationResponse
def coder_node(state: AgentState) -> dict:
"""Generate Python code based on the plan and data summary."""
question = state["question"]
plan = state.get("plan", "")
code_output = state.get("code_output", "None")
settings = Settings()
logger = get_logger("coder")
logger.info("Generating Python code...")
llm = get_llm_model(
settings.coder_llm,
callbacks=[LangChainLoggingHandler(logger=logger)]
)
structured_llm = llm.with_structured_output(CodeGenerationResponse)
# Always provide data summary
database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."
example_code = "" # Placeholder
messages = CODE_GENERATOR_PROMPT.format_messages(
question=question,
plan=plan,
database_description=database_description,
code_exec_results=code_output,
example_code=example_code
)
try:
response = structured_llm.invoke(messages)
logger.info("[bold green]Code generated.[/bold green]")
return {
"code": response.parsed_code,
"error": None # Clear previous errors on new code generation
}
except Exception as e:
logger.error(f"Failed to generate code: {str(e)}")
raise e


@@ -0,0 +1,40 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.coder import ERROR_CORRECTOR_PROMPT
from ea_chatbot.schemas import CodeGenerationResponse
def error_corrector_node(state: AgentState) -> dict:
"""Fix the code based on the execution error."""
code = state.get("code", "")
error = state.get("error", "Unknown error")
settings = Settings()
logger = get_logger("error_corrector")
logger.warning(f"[bold red]Execution error detected:[/bold red] {error[:100]}...")
logger.info("Attempting to correct the code...")
# Reuse coder LLM config or add a new one. Using coder_llm for now.
llm = get_llm_model(
settings.coder_llm,
callbacks=[LangChainLoggingHandler(logger=logger)]
)
structured_llm = llm.with_structured_output(CodeGenerationResponse)
messages = ERROR_CORRECTOR_PROMPT.format_messages(
code=code,
error=error
)
try:
response = structured_llm.invoke(messages)
logger.info("[bold green]Correction generated.[/bold green]")
return {
"code": response.parsed_code,
"error": None # Clear error after fix attempt
}
except Exception as e:
logger.error(f"Failed to correct code: {str(e)}")
raise e


@@ -0,0 +1,102 @@
import io
import sys
import traceback
from contextlib import redirect_stdout
from typing import TYPE_CHECKING
import pandas as pd
from matplotlib.figure import Figure
from ea_chatbot.graph.state import AgentState
from ea_chatbot.utils.db_client import DBClient
from ea_chatbot.utils.logging import get_logger
from ea_chatbot.config import Settings
if TYPE_CHECKING:
from ea_chatbot.types import DBSettings
def executor_node(state: AgentState) -> dict:
"""Execute the Python code and capture output, plots, and dataframes."""
code = state.get("code")
logger = get_logger("executor")
if not code:
logger.error("No code provided to executor.")
return {"error": "No code provided to executor."}
logger.info("Executing Python code...")
settings = Settings()
db_settings: "DBSettings" = {
"host": settings.db_host,
"port": settings.db_port,
"user": settings.db_user,
"pswd": settings.db_pswd,
"db": settings.db_name,
"table": settings.db_table
}
db_client = DBClient(settings=db_settings)
# Initialize local variables for execution
# 'db' is the DBClient instance, 'plots' is for matplotlib figures
local_vars = {
'db': db_client,
'plots': [],
'pd': pd
}
stdout_buffer = io.StringIO()
error = None
code_output = ""
plots = []
dfs = {}
try:
with redirect_stdout(stdout_buffer):
# Execute with a single namespace (globals == locals) so that functions
# and comprehensions defined in the generated code can resolve
# module-level names like `db` and imported modules.
exec(code, local_vars)
code_output = stdout_buffer.getvalue()
# Limit the output length if it's too long
if code_output.count('\n') > 32:
code_output = '\n'.join(code_output.split('\n')[:32]) + '\n...'
# Extract plots
raw_plots = local_vars.get('plots', [])
if isinstance(raw_plots, list):
plots = [p for p in raw_plots if isinstance(p, Figure)]
# Extract DataFrames that were likely intended for display
# We look for DataFrames in local_vars that were mentioned in the code
for key, value in local_vars.items():
if isinstance(value, pd.DataFrame):
# Heuristic: if the variable name is in the code, it might be a result DF
if key in code:
dfs[key] = value
logger.info(f"[bold green]Execution complete.[/bold green] Captured {len(plots)} plots and {len(dfs)} dataframes.")
except Exception as e:
# Capture the traceback
exc_type, exc_value, tb = sys.exc_info()
full_traceback = traceback.format_exc()
# Filter traceback to show only the relevant part (the executed string)
filtered_tb_lines = [line for line in full_traceback.split('\n') if '<string>' in line]
error = '\n'.join(filtered_tb_lines)
if error:
error += '\n'
error += f"{exc_type.__name__ if exc_type else 'Exception'}: {exc_value}"
logger.error(f"Execution failed: {str(e)}")
# If we have an error, we still might want to see partial stdout
code_output = stdout_buffer.getvalue()
return {
"code_output": code_output,
"error": error,
"plots": plots,
"dfs": dfs
}
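The sandbox pattern the executor relies on — a shared namespace, stdout capture, and artifact collection through a pre-seeded list — can be exercised on its own. This standalone sketch swaps the DB client and matplotlib `plots` list for a generic `artifacts` list:

```python
import io
from contextlib import redirect_stdout

code = """
results = [x * x for x in range(4)]
print(sum(results))
artifacts.append("summary")
"""

# A single dict serves as both globals and locals so that names defined
# by the executed snippet (including inside comprehensions) resolve.
namespace = {"artifacts": []}
buffer = io.StringIO()
with redirect_stdout(buffer):
    exec(code, namespace)

captured = buffer.getvalue().strip()
print(captured)                # stdout captured from the executed snippet
print(namespace["artifacts"])  # side-channel artifacts, like `plots`
```

Passing separate `globals`/`locals` dicts to `exec` would break this: a comprehension inside the snippet looks names up in `globals`, so `artifacts` seeded only into `locals` would raise `NameError`.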


@@ -0,0 +1,51 @@
import yaml
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils import helpers, database_inspection
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.planner import PLANNER_PROMPT
from ea_chatbot.schemas import TaskPlanResponse
def planner_node(state: AgentState) -> dict:
"""Generate a structured plan based on the query analysis."""
question = state["question"]
history = state.get("messages", [])[-6:]
summary = state.get("summary", "")
settings = Settings()
logger = get_logger("planner")
logger.info("Generating task plan...")
llm = get_llm_model(
settings.planner_llm,
callbacks=[LangChainLoggingHandler(logger=logger)]
)
structured_llm = llm.with_structured_output(TaskPlanResponse)
date_str = helpers.get_readable_date()
# Always provide data summary; LLM decides relevance.
database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."
example_plan = ""
messages = PLANNER_PROMPT.format_messages(
date=date_str,
question=question,
history=history,
summary=summary,
database_description=database_description,
example_plan=example_plan
)
# Generate the structured plan
try:
response = structured_llm.invoke(messages)
# Convert the structured response back to YAML string for the state
plan_yaml = yaml.dump(response.model_dump(), sort_keys=False)
logger.info("[bold green]Plan generated successfully.[/bold green]")
return {"plan": plan_yaml}
except Exception as e:
logger.error(f"Failed to generate plan: {str(e)}")
raise e


@@ -0,0 +1,72 @@
from typing import List, Literal
from pydantic import BaseModel, Field
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.query_analyzer import QUERY_ANALYZER_PROMPT
class QueryAnalysis(BaseModel):
"""Analysis of the user's query."""
data_required: List[str] = Field(description="List of data points or entities mentioned (e.g., ['2024 results', 'Florida']).")
unknowns: List[str] = Field(description="List of target information the user wants to know or needed for final answer (e.g., 'who won', 'total votes').")
ambiguities: List[str] = Field(description="List of CRITICAL missing details that prevent ANY analysis. Do NOT include database names or plot types if defaults can be used.")
conditions: List[str] = Field(description="List of any filters or constraints (e.g., ['year=2024', 'state=Florida']). Include context resolved from history.")
next_action: Literal["plan", "clarify", "research"] = Field(description="The next action to take. 'plan' for data analysis (even with defaults), 'research' for general knowledge, or 'clarify' ONLY for critical ambiguities.")
def query_analyzer_node(state: AgentState) -> dict:
"""Analyze the user's question and determine the next course of action."""
question = state["question"]
history = state.get("messages", [])
summary = state.get("summary", "")
# Keep last 3 turns (6 messages)
history = history[-6:]
settings = Settings()
logger = get_logger("query_analyzer")
logger.info(f"Analyzing question: [italic]\"{question}\"[/italic]")
# Initialize the LLM with structured output using the factory
# Pass logging callback to track LLM usage
llm = get_llm_model(
settings.query_analyzer_llm,
callbacks=[LangChainLoggingHandler(logger=logger)]
)
structured_llm = llm.with_structured_output(QueryAnalysis)
# Prepare messages using the prompt template
messages = QUERY_ANALYZER_PROMPT.format_messages(
question=question,
history=history,
summary=summary
)
try:
# Invoke the structured LLM; with_structured_output already returns a
# validated QueryAnalysis instance, so no extra model_validate is needed.
analysis_result = structured_llm.invoke(messages)
analysis_dict = analysis_result.model_dump()
analysis_dict.pop("next_action")
next_action = analysis_result.next_action
logger.info(f"Analysis complete. Next action: [bold magenta]{next_action}[/bold magenta]")
except Exception as e:
logger.error(f"Error during query analysis: {str(e)}")
analysis_dict = {
"data_required": [],
"unknowns": [],
"ambiguities": [f"Error during analysis: {str(e)}"],
"conditions": []
}
next_action = "clarify"
return {
"analysis": analysis_dict,
"next_action": next_action
}
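The structured-output contract that drives this routing can be sketched with Pydantic alone (the payloads below are illustrative, not real model output): a well-formed response parses into the schema, while an out-of-vocabulary `next_action` raises, which is exactly the failure mode the node's `except` branch turns into a `clarify` fallback.

```python
from typing import List, Literal

from pydantic import BaseModel, Field, ValidationError


class QueryAnalysis(BaseModel):
    data_required: List[str] = Field(default_factory=list)
    unknowns: List[str] = Field(default_factory=list)
    ambiguities: List[str] = Field(default_factory=list)
    conditions: List[str] = Field(default_factory=list)
    next_action: Literal["plan", "clarify", "research"] = "plan"


# A well-formed payload parses into the schema...
ok = QueryAnalysis.model_validate({
    "data_required": ["2024 results"],
    "unknowns": ["who won"],
    "ambiguities": [],
    "conditions": ["state=New Jersey"],
    "next_action": "plan",
})
print(ok.next_action)

# ...while an action outside the Literal vocabulary is rejected.
try:
    QueryAnalysis.model_validate({"next_action": "guess"})
except ValidationError:
    print("invalid action rejected")
```

Constraining `next_action` with `Literal` keeps the graph's conditional edges exhaustive: every validated analysis maps to a defined route.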


@@ -0,0 +1,60 @@
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils import helpers
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.researcher import RESEARCHER_PROMPT
def researcher_node(state: AgentState) -> dict:
"""Handle general research queries or web searches."""
question = state["question"]
history = state.get("messages", [])[-6:]
summary = state.get("summary", "")
settings = Settings()
logger = get_logger("researcher")
logger.info(f"Researching question: [italic]\"{question}\"[/italic]")
# Use researcher_llm from settings
llm = get_llm_model(
settings.researcher_llm,
callbacks=[LangChainLoggingHandler(logger=logger)]
)
date_str = helpers.get_readable_date()
messages = RESEARCHER_PROMPT.format_messages(
date=date_str,
question=question,
history=history,
summary=summary
)
# Provider-aware tool binding
try:
if isinstance(llm, ChatGoogleGenerativeAI):
# Native Google Search for Gemini
llm_with_tools = llm.bind_tools([{"google_search": {}}])
elif isinstance(llm, ChatOpenAI):
# Native Web Search for OpenAI (built-in tool)
llm_with_tools = llm.bind_tools([{"type": "web_search"}])
else:
# Fallback for other providers that might not support these specific search tools
llm_with_tools = llm
except Exception as e:
logger.warning(f"Failed to bind search tools: {str(e)}. Falling back to base LLM.")
llm_with_tools = llm
try:
response = llm_with_tools.invoke(messages)
logger.info("[bold green]Research complete.[/bold green]")
return {
"messages": [response]
}
except Exception as e:
logger.error(f"Research failed: {str(e)}")
raise e


@@ -0,0 +1,52 @@
from langchain_core.messages import SystemMessage, HumanMessage
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
def summarize_conversation_node(state: AgentState) -> dict:
"""Update the conversation summary based on the latest interaction."""
summary = state.get("summary", "")
messages = state.get("messages", [])
# We only summarize if there are messages
if not messages:
return {}
# Get the last turn (User + Assistant)
last_turn = messages[-2:]
settings = Settings()
logger = get_logger("summarize_conversation")
logger.info("Updating conversation summary...")
# Use summarizer_llm for this task as well
llm = get_llm_model(
settings.summarizer_llm,
callbacks=[LangChainLoggingHandler(logger=logger)]
)
if summary:
prompt = (
f"This is a summary of the conversation so far: {summary}\n\n"
"Extend the summary by taking into account the new messages above."
)
else:
prompt = "Create a summary of the conversation above."
# Construct the messages for the summarization LLM; the running summary is
# already embedded in `prompt`, so the system message only sets the persona.
summarization_messages = [
SystemMessage(content="You are a helpful assistant that summarizes conversations."),
HumanMessage(content=f"Recent messages:\n{last_turn}\n\n{prompt}\n\nKeep the summary concise and focused on the key topics and data points discussed.")
]
try:
response = llm.invoke(summarization_messages)
new_summary = response.content
logger.info("[bold green]Conversation summary updated.[/bold green]")
return {"summary": new_summary}
except Exception as e:
logger.error(f"Failed to update summary: {str(e)}")
# If summarization fails, we keep the old one
return {"summary": summary}


@@ -0,0 +1,44 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.summarizer import SUMMARIZER_PROMPT
def summarizer_node(state: AgentState) -> dict:
"""Summarize the code execution results into a final answer."""
question = state["question"]
plan = state.get("plan", "")
code_output = state.get("code_output", "")
history = state.get("messages", [])[-6:]
summary = state.get("summary", "")
settings = Settings()
logger = get_logger("summarizer")
logger.info("Generating final summary...")
llm = get_llm_model(
settings.summarizer_llm,
callbacks=[LangChainLoggingHandler(logger=logger)]
)
messages = SUMMARIZER_PROMPT.format_messages(
question=question,
plan=plan,
code_output=code_output,
history=history,
summary=summary
)
try:
response = llm.invoke(messages)
logger.info("[bold green]Summary generated.[/bold green]")
# Return the final message to be added to the state
return {
"messages": [response]
}
except Exception as e:
logger.error(f"Failed to generate summary: {str(e)}")
raise e


@@ -0,0 +1,10 @@
from .query_analyzer import QUERY_ANALYZER_PROMPT
from .planner import PLANNER_PROMPT
from .coder import CODE_GENERATOR_PROMPT, ERROR_CORRECTOR_PROMPT
__all__ = [
"QUERY_ANALYZER_PROMPT",
"PLANNER_PROMPT",
"CODE_GENERATOR_PROMPT",
"ERROR_CORRECTOR_PROMPT",
]


@@ -0,0 +1,64 @@
from langchain_core.prompts import ChatPromptTemplate
CODE_GENERATOR_SYSTEM = """You are an AI data analyst and your job is to assist users with data analysis and coding tasks.
The user will provide a task and a plan.
**Data Access:**
- A database client is available as a variable named `db`.
- You MUST use `db.query_df(sql_query)` to execute SQL queries and retrieve data as a Pandas DataFrame.
- Do NOT assume a dataframe `df` is already loaded unless explicitly stated. You usually need to query it first.
- The database schema is described in the prompt. Use it to construct valid SQL queries.
**Plotting:**
- If you need to plot any data, use the `plots` list to store the figures.
- Example: `plots.append(fig)` or `plots.append(plt.gcf())`.
- Do not use `plt.show()` as it will render the plot and cause an error.
**Code Requirements:**
- Produce FULL, COMPLETE CODE that includes all steps and solves the task!
- Always include the import statements at the top of the code (e.g., `import pandas as pd`, `import matplotlib.pyplot as plt`).
- Always include print statements to output the results of your code.
- Use `db.query_df("SELECT ...")` to get data."""
CODE_GENERATOR_USER = """TASK:
{question}
PLAN:
```yaml
{plan}
```
AVAILABLE DATA SUMMARY (Database Schema):
{database_description}
CODE EXECUTION OF THE PREVIOUS TASK RESULTED IN:
{code_exec_results}
{example_code}"""
ERROR_CORRECTOR_SYSTEM = """The execution of the code resulted in an error.
Return a complete, corrected python code that incorporates the fixes for the error.
**Reminders:**
- You have access to a database client via the variable `db`.
- Use `db.query_df(sql)` to run queries.
- Use `plots.append(fig)` for plots.
- Always include imports and print statements."""
ERROR_CORRECTOR_USER = """FAILED CODE:
```python
{code}
```
ERROR:
{error}"""
CODE_GENERATOR_PROMPT = ChatPromptTemplate.from_messages([
("system", CODE_GENERATOR_SYSTEM),
("human", CODE_GENERATOR_USER),
])
ERROR_CORRECTOR_PROMPT = ChatPromptTemplate.from_messages([
("system", ERROR_CORRECTOR_SYSTEM),
("human", ERROR_CORRECTOR_USER),
])


@@ -0,0 +1,46 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
PLANNER_SYSTEM = """You are a Planning Specialist and your job is to design a clear, step-by-step plan for solving the user's data analysis task.
Ground every step in the available data and conversation context so the plan can be converted directly into working code.
Today's Date is: {date}"""
PLANNER_USER = """Conversation Summary: {summary}
TASK:
{question}
AVAILABLE DATA SUMMARY (Use only if relevant to the task):
{database_description}
First: Evaluate whether you have all necessary and requested information to provide a solution.
Use the dataset description above to determine what data and in what format you have available to you.
You are able to search internet if the user asks for it, or you require any information that you can not derive from the given dataset or the instruction.
Second: Incorporate any additional relevant context, reasoning, or details from previous interactions or internal chain-of-thought that may impact the solution.
Ensure that all such information is fully included in your response rather than referring to previous answers indirectly.
Third: Reflect on the problem and briefly describe it, while addressing the problem goal, inputs, outputs,
rules, constraints, and other relevant details that appear in the problem description.
Fourth: Based on the preceding steps, formulate your response as an algorithm, breaking the solution in up to eight simple concise yet descriptive, clear English steps.
You MUST Include all values or instructions as described in the above task, or retrieved using internet search!
If fewer steps suffice, that's acceptable. If more are needed, please include them.
Remember to explain steps rather than write code.
This algorithm will be later converted to Python code.
If a dataframe is required, plan for it to be retrieved via the database client (`db.query_df(...)`); do not assume any dataframe is preloaded.
There is a list variable called `plots` that you need to use to store any plots you generate. Do not use `plt.show()` as it will render the plot and cause an error.
Output the algorithm as a YAML string. Always enclose the YAML string within ```yaml tags.
**Note: Ensure that any necessary context from prior interactions is fully embedded in the plan. Do not use phrases like "refer to previous answer"; instead, provide complete details inline.**
{example_plan}"""
PLANNER_PROMPT = ChatPromptTemplate.from_messages([
("system", PLANNER_SYSTEM),
MessagesPlaceholder(variable_name="history"),
("human", PLANNER_USER),
])


@@ -0,0 +1,33 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
SYSTEM_PROMPT = """You are an expert election data analyst. Decompose the user's question into key elements to determine the next action.
### Context & Defaults
- **History:** Use the conversation history and summary to resolve coreferences (e.g., "those results", "that state"). Assume the current question inherits missing context (Year, State, County) from history.
- **Data Access:** You have access to voter and election databases. Proceed to planning without asking for database or table names.
- **Downstream Capabilities:** Visualizations are generated as Matplotlib figures. Proceed to planning for "graphs" or "plots" without asking for file formats or plot types.
- **Trends:** For trend requests without a specified interval, allow the Planner to use a sensible default (e.g., by election cycle).
### Instructions:
1. **Analyze:** Identify if the request is for data analysis, general facts (web research), or is critically ambiguous.
2. **Extract Entities & Conditions:**
- **Data Required:** e.g., "vote count", "demographics".
- **Conditions:** e.g., "Year=2024". Include context from history.
3. **Identify Target & Critical Ambiguities:**
- **Unknowns:** The core target question.
- **Critical Ambiguities:** ONLY list issues that PREVENT any analysis.
- Examples: No timeframe/geography in query OR history; "track the same voter" without an identity definition.
4. **Determine Action:**
- `plan`: For data analysis where defaults or history provide sufficient context.
- `research`: For general knowledge.
- `clarify`: ONLY for CRITICAL ambiguities."""
USER_PROMPT_TEMPLATE = """Conversation Summary: {summary}
Analyze the following question: {question}"""
QUERY_ANALYZER_PROMPT = ChatPromptTemplate.from_messages([
("system", SYSTEM_PROMPT),
MessagesPlaceholder(variable_name="history"),
("human", USER_PROMPT_TEMPLATE),
])


@@ -0,0 +1,12 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
RESEARCHER_PROMPT = ChatPromptTemplate.from_messages([
("system", """You are a Research Specialist and your job is to find answers and educate the user.
Provide factual information responding directly to the user's question. Include key details and context to ensure your response comprehensively answers their query.
Today's Date is: {date}"""),
MessagesPlaceholder(variable_name="history"),
("user", """Conversation Summary: {summary}
{question}""")
])


@@ -0,0 +1,27 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
SUMMARIZER_PROMPT = ChatPromptTemplate.from_messages([
("system", """You are an expert election data analyst providing a final answer to the user.
Use the provided conversation history and summary to ensure your response is contextually relevant and flows naturally from previous turns.
Conversation Summary: {summary}"""),
MessagesPlaceholder(variable_name="history"),
("user", """The user presented you with the following question.
Question: {question}
To address this, you have designed an algorithm.
Algorithm: {plan}.
You have crafted a Python code based on this algorithm, and the output generated by the code's execution is as follows.
Output: {code_output}.
Please produce a comprehensive, easy-to-understand answer that:
1. Summarizes the main insights or conclusions achieved through your method's implementation. Include execution results if necessary.
2. Includes relevant findings from the code execution in a clear format (e.g., text explanation, tables, lists, bullet points).
- Avoid referencing the code or output as 'the above results' or saying 'it's in the code output.'
- Instead, present the actual key data or statistics within your explanation.
3. If the user requested specific information that does not appear in the code's output but you can provide it, include that information directly in your summary.
4. Present any data or tables that might have been generated by the code in full, since the user cannot directly see the execution output.
Your goal is to give a final answer that stands on its own without requiring the user to see the code or raw output directly.""")
])

View File

@@ -0,0 +1,33 @@
from typing import TypedDict, Annotated, List, Dict, Any, Optional
from langchain_core.messages import BaseMessage
import operator
class AgentState(TypedDict):
# Conversation history
messages: Annotated[List[BaseMessage], operator.add]
# Task context
question: str
# Query Analysis (Decomposition results)
analysis: Optional[Dict[str, Any]]
# Expected keys: "requires_dataset", "expert", "data", "unknown", "condition"
# Step-by-step reasoning
plan: Optional[str]
# Code execution context
code: Optional[str]
code_output: Optional[str]
error: Optional[str]
# Artifacts (for UI display)
plots: Annotated[List[Any], operator.add] # Matplotlib figures
dfs: Dict[str, Any] # Pandas DataFrames
# Conversation summary
summary: Optional[str]
# Routing hint: "clarify", "plan", "research", "end"
next_action: str
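The `Annotated[..., operator.add]` channels above are accumulated rather than overwritten when nodes return partial updates. A rough stdlib-only sketch of that reduction (the `REDUCERS` table and `apply_update` helper are illustrative, not LangGraph internals):

```python
import operator
from typing import Any, Dict

# Channels annotated with a reducer (operator.add for messages/plots) are
# accumulated across node updates; every other key is simply overwritten.
REDUCERS = {"messages": operator.add, "plots": operator.add}

def apply_update(state: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
    merged = dict(state)
    for key, value in update.items():
        reducer = REDUCERS.get(key)
        if reducer is not None and key in merged:
            merged[key] = reducer(merged[key], value)
        else:
            merged[key] = value
    return merged

state = {"messages": ["user: hi"], "plots": [], "next_action": ""}
state = apply_update(state, {"messages": ["ai: analyzing"], "next_action": "plan"})
```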

View File

@@ -0,0 +1,87 @@
from langgraph.graph import StateGraph, END
from ea_chatbot.graph.state import AgentState
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node
from ea_chatbot.graph.nodes.planner import planner_node
from ea_chatbot.graph.nodes.coder import coder_node
from ea_chatbot.graph.nodes.error_corrector import error_corrector_node
from ea_chatbot.graph.nodes.executor import executor_node
from ea_chatbot.graph.nodes.summarizer import summarizer_node
from ea_chatbot.graph.nodes.researcher import researcher_node
from ea_chatbot.graph.nodes.clarification import clarification_node
from ea_chatbot.graph.nodes.summarize_conversation import summarize_conversation_node
def router(state: AgentState) -> str:
"""Route to the next node based on the analysis."""
next_action = state.get("next_action")
if next_action == "plan":
return "planner"
elif next_action == "research":
return "researcher"
elif next_action == "clarify":
return "clarification"
else:
return END
def create_workflow():
"""Create the LangGraph workflow."""
workflow = StateGraph(AgentState)
# Add nodes
workflow.add_node("query_analyzer", query_analyzer_node)
workflow.add_node("planner", planner_node)
workflow.add_node("coder", coder_node)
workflow.add_node("error_corrector", error_corrector_node)
workflow.add_node("researcher", researcher_node)
workflow.add_node("clarification", clarification_node)
workflow.add_node("executor", executor_node)
workflow.add_node("summarizer", summarizer_node)
workflow.add_node("summarize_conversation", summarize_conversation_node)
# Set entry point
workflow.set_entry_point("query_analyzer")
# Add conditional edges from query_analyzer
workflow.add_conditional_edges(
"query_analyzer",
router,
{
"planner": "planner",
"researcher": "researcher",
"clarification": "clarification",
END: END
}
)
# Linear flow for planning and coding
workflow.add_edge("planner", "coder")
workflow.add_edge("coder", "executor")
# Executor routing
def executor_router(state: AgentState) -> str:
if state.get("error"):
return "error_corrector"
return "summarizer"
workflow.add_conditional_edges(
"executor",
executor_router,
{
"error_corrector": "error_corrector",
"summarizer": "summarizer"
}
)
workflow.add_edge("error_corrector", "executor")
workflow.add_edge("researcher", "summarize_conversation")
workflow.add_edge("clarification", END)
workflow.add_edge("summarizer", "summarize_conversation")
workflow.add_edge("summarize_conversation", END)
# Compile the graph
app = workflow.compile()
return app
# Initialize the app
app = create_workflow()
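The if/elif chain in `router` is equivalent to a table lookup; a self-contained sketch (the local `END` sentinel is an assumption standing in for `langgraph.graph.END`, so no langgraph install is needed to run it):

```python
# Table-driven equivalent of the router's if/elif dispatch.
END = "__end__"  # stand-in for langgraph.graph.END
ROUTES = {"plan": "planner", "research": "researcher", "clarify": "clarification"}

def route(state: dict) -> str:
    # Unknown or missing next_action falls through to END, as in the original.
    return ROUTES.get(state.get("next_action"), END)
```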

View File

View File

@@ -0,0 +1,183 @@
from contextlib import contextmanager
from typing import Optional, List
from sqlalchemy import create_engine, select, delete
from sqlalchemy.orm import sessionmaker, Session
from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError
from ea_chatbot.history.models import User, Conversation, Message, Plot
# Argon2 Password Hasher
ph = PasswordHasher()
class HistoryManager:
"""Manages database sessions and operations for history and user data."""
def __init__(self, db_url: str):
self.engine = create_engine(db_url)
# expire_on_commit=False is important so we can use objects after session closes
self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine, expire_on_commit=False)
@contextmanager
def get_session(self):
"""Context manager for database sessions."""
session = self.SessionLocal()
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
# --- User Management ---
def get_user(self, email: str) -> Optional[User]:
"""Fetch a user by their email (username)."""
with self.get_session() as session:
result = session.execute(select(User).where(User.username == email))
return result.scalar_one_or_none()
def create_user(self, email: str, password: Optional[str] = None, display_name: Optional[str] = None) -> User:
"""Create a new local user."""
hashed_password = ph.hash(password) if password else None
user = User(
username=email,
password_hash=hashed_password,
display_name=display_name or email.split("@")[0]
)
with self.get_session() as session:
session.add(user)
session.commit()
session.refresh(user)
return user
def authenticate_user(self, email: str, password: str) -> Optional[User]:
"""Authenticate a user by email and password."""
user = self.get_user(email)
if not user or not user.password_hash:
return None
try:
ph.verify(user.password_hash, password)
return user
except VerifyMismatchError:
return None
def sync_user_from_oidc(self, email: str, display_name: Optional[str] = None) -> User:
"""
Synchronize a user from an OIDC provider.
If a user with the same email exists, update their display name.
Otherwise, create a new user.
"""
user = self.get_user(email)
if user:
# Update existing user if needed
if display_name and user.display_name != display_name:
with self.get_session() as session:
db_user = session.get(User, user.id)
db_user.display_name = display_name
session.commit()
session.refresh(db_user)
return db_user
return user
else:
# Create new user (no password for OIDC users initially)
return self.create_user(email=email, display_name=display_name)
# --- Conversation Management ---
def create_conversation(self, user_id: str, data_state: str, name: str, summary: Optional[str] = None) -> Conversation:
"""Create a new conversation for a user."""
conv = Conversation(
user_id=user_id,
data_state=data_state,
name=name,
summary=summary
)
with self.get_session() as session:
session.add(conv)
session.commit()
session.refresh(conv)
return conv
def get_conversations(self, user_id: str, data_state: str) -> List[Conversation]:
"""Get all conversations for a user and data state, ordered by creation time."""
with self.get_session() as session:
stmt = (
select(Conversation)
.where(Conversation.user_id == user_id, Conversation.data_state == data_state)
.order_by(Conversation.created_at.desc())
)
result = session.execute(stmt)
return list(result.scalars().all())
def rename_conversation(self, conversation_id: str, new_name: str) -> Optional[Conversation]:
"""Rename an existing conversation."""
with self.get_session() as session:
conv = session.get(Conversation, conversation_id)
if conv:
conv.name = new_name
session.commit()
session.refresh(conv)
return conv
def delete_conversation(self, conversation_id: str) -> bool:
"""Delete a conversation and its associated messages/plots (via cascade)."""
with self.get_session() as session:
conv = session.get(Conversation, conversation_id)
if conv:
session.delete(conv)
session.commit()
return True
return False
def update_conversation_summary(self, conversation_id: str, summary: str) -> Optional[Conversation]:
"""Update the summary of a conversation."""
with self.get_session() as session:
conv = session.get(Conversation, conversation_id)
if conv:
conv.summary = summary
session.commit()
session.refresh(conv)
return conv
# --- Message & Plot Management ---
def add_message(self, conversation_id: str, role: str, content: str, plots: Optional[List[bytes]] = None) -> Message:
"""Add a message to a conversation, optionally with plots."""
msg = Message(
conversation_id=conversation_id,
role=role,
content=content
)
with self.get_session() as session:
session.add(msg)
session.flush() # Populate msg.id for plots
if plots:
for plot_data in plots:
plot = Plot(message_id=msg.id, image_data=plot_data)
session.add(plot)
session.commit()
session.refresh(msg)
# Ensure plots are loaded before session closes if we need them
_ = msg.plots
return msg
def get_messages(self, conversation_id: str) -> List[Message]:
"""Get all messages for a conversation, ordered by creation time."""
with self.get_session() as session:
stmt = (
select(Message)
.where(Message.conversation_id == conversation_id)
.order_by(Message.created_at.asc())
)
result = session.execute(stmt)
messages = list(result.scalars().all())
# Pre-load plots for each message
for m in messages:
_ = m.plots
return messages
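The commit/rollback/close discipline in `get_session` can be exercised without a database; a sketch with a stub session (all names here are illustrative, not SQLAlchemy API):

```python
from contextlib import contextmanager

# Records which lifecycle calls happen, in order, so the two paths
# of the context manager can be observed without a real database.
class StubSession:
    def __init__(self):
        self.events = []
    def commit(self):
        self.events.append("commit")
    def rollback(self):
        self.events.append("rollback")
    def close(self):
        self.events.append("close")

@contextmanager
def get_session(session):
    try:
        yield session
        session.commit()      # happy path: commit after the block
    except Exception:
        session.rollback()    # failure path: undo, then re-raise
        raise
    finally:
        session.close()       # always release the session

ok = StubSession()
with get_session(ok):
    pass

bad = StubSession()
try:
    with get_session(bad):
        raise ValueError("boom")
except ValueError:
    pass
```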

View File

@@ -0,0 +1,52 @@
from typing import List, Optional
from datetime import datetime, timezone
import uuid
from sqlalchemy import String, ForeignKey, DateTime, LargeBinary, Text
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
class Base(DeclarativeBase):
pass
class User(Base):
__tablename__ = "users"
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
username: Mapped[str] = mapped_column(String, unique=True, index=True)
password_hash: Mapped[Optional[str]] = mapped_column(String, nullable=True)
display_name: Mapped[Optional[str]] = mapped_column(String, nullable=True)
conversations: Mapped[List["Conversation"]] = relationship(back_populates="user", cascade="all, delete-orphan")
class Conversation(Base):
__tablename__ = "conversations"
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
user_id: Mapped[str] = mapped_column(ForeignKey("users.id"))
data_state: Mapped[str] = mapped_column(String)
name: Mapped[str] = mapped_column(String)
summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc))
user: Mapped["User"] = relationship(back_populates="conversations")
messages: Mapped[List["Message"]] = relationship(back_populates="conversation", cascade="all, delete-orphan")
class Message(Base):
__tablename__ = "messages"
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
conversation_id: Mapped[str] = mapped_column(ForeignKey("conversations.id"))
role: Mapped[str] = mapped_column(String)
content: Mapped[str] = mapped_column(Text)
created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc))
conversation: Mapped["Conversation"] = relationship(back_populates="messages")
plots: Mapped[List["Plot"]] = relationship(back_populates="message", cascade="all, delete-orphan")
class Plot(Base):
__tablename__ = "plots"
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
message_id: Mapped[str] = mapped_column(ForeignKey("messages.id"))
image_data: Mapped[bytes] = mapped_column(LargeBinary)
message: Mapped["Message"] = relationship(back_populates="plots")

83
src/ea_chatbot/schemas.py Normal file
View File

@@ -0,0 +1,83 @@
from pydantic import BaseModel, Field, computed_field
from typing import Sequence, Optional
import re
class TaskPlanContext(BaseModel):
'''Background context relevant to the task plan'''
initial_context: str = Field(
min_length=1,
description="Background information about the database/tables and previous conversations relevant to the task.",
)
assumptions: Sequence[str] = Field(
description="Assumptions made while working on the task.",
)
constraints: Optional[Sequence[str]] = Field(
description="Constraints that apply to the task.",
)
class TaskPlanResponse(BaseModel):
'''Structured plan to achieve the task objective'''
goal: str = Field(
min_length=1,
description="Single-sentence objective the plan must achieve.",
)
reflection: str = Field(
min_length=1,
description="High-level natural-language reasoning describing the user's request and the intended solution approach.",
)
context: TaskPlanContext = Field(
description="Background context relevant to the task plan.",
)
steps: Sequence[str] = Field(
min_length=1,
description="Ordered list of steps to execute that follow the 'Step <number>: <detail>' pattern.",
)
_IM_SEP_TOKEN_PATTERN = re.compile(re.escape("<|im_sep|>"))
_CODE_BLOCK_PATTERN = re.compile(r"```(?:python\s*)?(.*?)\s*```", re.DOTALL)
_FORBIDDEN_MODULES = (
"subprocess",
"sys",
"eval",
"exec",
"socket",
"urllib",
"shutil",
"pickle",
"ctypes",
"multiprocessing",
"tempfile",
"glob",
"pty",
"commands",
"cgi",
"cgitb",
"xml.etree.ElementTree",
"builtins",
)
_FORBIDDEN_MODULE_PATTERN = re.compile(
r"^((?:[^#].*)?\b(" + "|".join(map(re.escape, _FORBIDDEN_MODULES)) + r")\b.*)$",
flags=re.MULTILINE,
)
class CodeGenerationResponse(BaseModel):
'''Code generation response structure'''
code: str = Field(description="The generated code snippet to accomplish the task")
explanation: str = Field(description="Explanation of the generated code and its functionality")
@computed_field(return_type=str)
@property
def parsed_code(self) -> str:
'''Extracts the code snippet without any surrounding text'''
normalised = _IM_SEP_TOKEN_PATTERN.sub("```", self.code).strip()
match = _CODE_BLOCK_PATTERN.search(normalised)
candidate = match.group(1).strip() if match else normalised
sanitised = _FORBIDDEN_MODULE_PATTERN.sub(r"# not allowed \1", candidate)
return sanitised.strip()
class RankResponse(BaseModel):
'''Code ranking response structure'''
rank: int = Field(
ge=1, le=10,
description="Rank of the code snippet from 1 (best) to 10 (worst)"
)
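The extraction and sanitisation performed by `parsed_code` can be sketched on its own; this reuses the same regex shapes but a deliberately shortened, illustrative forbidden list:

```python
import re

# Pull the fenced block out of a raw LLM reply, then comment out any line
# touching a forbidden module (two-entry list here, for brevity only).
CODE_BLOCK = re.compile(r"```(?:python\s*)?(.*?)\s*```", re.DOTALL)
FORBIDDEN = re.compile(r"^((?:[^#].*)?\b(subprocess|socket)\b.*)$", re.MULTILINE)

def parse_code(raw: str) -> str:
    match = CODE_BLOCK.search(raw)
    candidate = match.group(1).strip() if match else raw.strip()
    return FORBIDDEN.sub(r"# not allowed \1", candidate).strip()

raw = "Sure, here you go:\n```python\nimport subprocess\nprint('hi')\n```"
cleaned = parse_code(raw)
```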

24
src/ea_chatbot/types.py Normal file
View File

@@ -0,0 +1,24 @@
from typing import TypedDict, Optional
from enum import StrEnum
class DBSettings(TypedDict):
host: str
port: int
user: str
pswd: str
db: str
table: Optional[str]
class Agent(StrEnum):
EXPERT_SELECTOR = "Expert Selector"
ANALYST_SELECTOR = "Analyst Selector"
THEORIST = "Theorist"
THEORIST_WEB = "Theorist-Web"
THEORIST_CLARIFICATION = "Theorist-Clarification"
PLANNER = "Planner"
CODE_GENERATOR = "Code Generator"
CODE_DEBUGGER = "Code Debugger"
CODE_EXECUTOR = "Code Executor"
ERROR_CORRECTOR = "Error Corrector"
CODE_RANKER = "Code Ranker"
SOLUTION_SUMMARIZER = "Solution Summarizer"

View File

@@ -0,0 +1,12 @@
from .db_client import DBClient
from .llm_factory import get_llm_model
from .logging import get_logger, LangChainLoggingHandler
from . import helpers
__all__ = [
"DBClient",
"get_llm_model",
"get_logger",
"LangChainLoggingHandler",
"helpers"
]

View File

@@ -0,0 +1,234 @@
from typing import Optional, Dict, Any, List, TYPE_CHECKING
import yaml
import json
import os
from ea_chatbot.utils.db_client import DBClient
if TYPE_CHECKING:
from ea_chatbot.types import DBSettings
def _get_table_checksum(db_client: DBClient, table: str) -> str:
"""Calculates the checksum of the table using DML statistics from pg_stat_user_tables."""
query = f"""
SELECT md5(concat_ws('|', n_tup_ins, n_tup_upd, n_tup_del)) AS dml_hash
FROM pg_stat_user_tables
WHERE schemaname = 'public' AND relname = '{table}';"""
try:
return str(db_client.query_df(query).iloc[0, 0])
except Exception:
return "unknown_checksum"
def _update_checksum_file(filepath: str, table: str, checksum: str):
"""Updates the checksum file with the new checksum for the table."""
checksums = {}
if os.path.exists(filepath):
with open(filepath, 'r') as f:
for line in f:
if ':' in line:
k, v = line.strip().split(':', 1)
checksums[k] = v
checksums[table] = checksum
with open(filepath, 'w') as f:
for k, v in checksums.items():
f.write(f"{k}:{v}\n")
def get_data_summary(data_dir: str = "data") -> Optional[str]:
"""
Reads the inspection.yaml file and returns its content as a string.
"""
inspection_file = os.path.join(data_dir, "inspection.yaml")
if os.path.exists(inspection_file):
with open(inspection_file, 'r') as f:
return f.read()
return None
def get_primary_key(db_client: DBClient, table_name: str) -> Optional[str]:
"""
Dynamically identifies the primary key of the table.
Returns the column name of the primary key, or None if not found.
"""
query = f"""
SELECT kcu.column_name
FROM information_schema.key_column_usage AS kcu
JOIN information_schema.table_constraints AS tc
ON kcu.constraint_name = tc.constraint_name
AND kcu.table_schema = tc.table_schema
WHERE kcu.table_name = '{table_name}'
AND tc.constraint_type = 'PRIMARY KEY'
LIMIT 1;
"""
try:
df = db_client.query_df(query)
if not df.empty:
return str(df.iloc[0, 0])
except Exception as e:
print(f"Warning: Could not determine primary key for {table_name}: {e}")
return None
def inspect_db_table(
db_client: Optional[DBClient]=None,
db_settings: Optional["DBSettings"]=None,
data_dir: str = "data",
force_update: bool = False
) -> Optional[str]:
"""
Inspects the database table, generates statistics for each column,
and saves the inspection results to a YAML file locally.
Improvements:
- Dynamic Primary Key Discovery
- Cardinality (Unique Counts)
- Categorical Sample Values for low cardinality columns
- Robust Quoting
"""
inspection_file = os.path.join(data_dir, "inspection.yaml")
checksum_file = os.path.join(data_dir, "checksum")
# Initialize DB Client
if db_client is None:
if db_settings is None:
print("Error: Either db_client or db_settings must be provided.")
return None
try:
db_client = DBClient(db_settings)
except Exception as e:
print(f"Failed to create DBClient: {e}")
return None
table_name = db_client.settings.get('table')
if not table_name:
print("Error: Table name must be specified in DBSettings.")
return None
os.makedirs(data_dir, exist_ok=True)
# Checksum verification
new_checksum = _get_table_checksum(db_client, table_name)
has_changed = True
if os.path.exists(checksum_file):
try:
with open(checksum_file, 'r') as f:
saved_checksums = f.read().strip()
if f"{table_name}:{new_checksum}" in saved_checksums:
has_changed = False
except Exception:
pass # Force update on read error
if not has_changed and not force_update:
return get_data_summary(data_dir)
print(f"Regenerating inspection file for table '{table_name}'...")
# Fetch Table Metadata
try:
# Get columns and types
columns_query = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';"
columns_df = db_client.query_df(columns_query)
# Get Row Counts
total_rows_df = db_client.query_df(f'SELECT COUNT(*) FROM "{table_name}"')
total_rows = int(total_rows_df.iloc[0, 0])
# Dynamic Primary Key
primary_key = get_primary_key(db_client, table_name)
# Get First/Last Rows (if PK exists)
first_row_df = None
last_row_df = None
if primary_key:
first_row_df = db_client.query_df(f'SELECT * FROM "{table_name}" ORDER BY "{primary_key}" ASC LIMIT 1')
last_row_df = db_client.query_df(f'SELECT * FROM "{table_name}" ORDER BY "{primary_key}" DESC LIMIT 1')
except Exception as e:
print(f"Failed to retrieve basic table info: {e}")
return None
stats_dict: Dict[str, Any] = {}
if primary_key:
stats_dict['primary_key'] = primary_key
for _, row in columns_df.iterrows():
col_name = row['column_name']
dtype = row['data_type']
try:
# Count Values
# Using robust quoting
count_df = db_client.query_df(f'SELECT COUNT("{col_name}") FROM "{table_name}"')
count_val = int(count_df.iloc[0,0])
# Count Unique (Cardinality)
unique_df = db_client.query_df(f'SELECT COUNT(DISTINCT "{col_name}") FROM "{table_name}"')
unique_count = int(unique_df.iloc[0,0])
col_stats: Dict[str, Any] = {
'dtype': dtype,
'count_of_values': count_val,
'count_of_nulls': total_rows - count_val,
'unique_count': unique_count
}
if count_val == 0:
stats_dict[col_name] = col_stats
continue
# Numerical Stats
if any(t in dtype for t in ('int', 'float', 'numeric', 'double', 'real', 'decimal')):
stats_query = f'SELECT AVG("{col_name}"), MIN("{col_name}"), MAX("{col_name}") FROM "{table_name}"'
stats_df = db_client.query_df(stats_query)
if not stats_df.empty:
col_stats['mean'] = float(stats_df.iloc[0,0]) if stats_df.iloc[0,0] is not None else None
col_stats['min'] = float(stats_df.iloc[0,1]) if stats_df.iloc[0,1] is not None else None
col_stats['max'] = float(stats_df.iloc[0,2]) if stats_df.iloc[0,2] is not None else None
# Temporal Stats
elif any(t in dtype for t in ('date', 'timestamp')):
stats_query = f'SELECT MIN("{col_name}"), MAX("{col_name}") FROM "{table_name}"'
stats_df = db_client.query_df(stats_query)
if not stats_df.empty:
col_stats['min'] = str(stats_df.iloc[0,0])
col_stats['max'] = str(stats_df.iloc[0,1])
# Categorical/Text Stats
else:
# Sample values if cardinality is low (< 20)
if 0 < unique_count < 20:
distinct_query = f'SELECT DISTINCT "{col_name}" FROM "{table_name}" ORDER BY "{col_name}" LIMIT 20'
distinct_df = db_client.query_df(distinct_query)
col_stats['distinct_values'] = distinct_df.iloc[:, 0].tolist()
if first_row_df is not None and not first_row_df.empty and col_name in first_row_df.columns:
col_stats['first_value'] = str(first_row_df.iloc[0][col_name])
if last_row_df is not None and not last_row_df.empty and col_name in last_row_df.columns:
col_stats['last_value'] = str(last_row_df.iloc[0][col_name])
stats_dict[col_name] = col_stats
except Exception as e:
print(f"Warning: Could not process column {col_name}: {e}")
# Load existing inspections to merge (if multiple tables)
existing_inspections = {}
if os.path.exists(inspection_file):
try:
with open(inspection_file, 'r') as f:
existing_inspections = yaml.safe_load(f) or {}
# Backup old file
os.rename(inspection_file, inspection_file + ".old")
except Exception:
pass
existing_inspections[table_name] = stats_dict
# Save new inspection
inspection_content = yaml.dump(existing_inspections, sort_keys=False, default_flow_style=False)
with open(inspection_file, 'w') as f:
f.write(inspection_content)
# Update Checksum
_update_checksum_file(checksum_file, table_name, new_checksum)
print(f"Inspection saved to {inspection_file}")
return inspection_content
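The checksum query above hashes the `'|'`-joined DML counters from `pg_stat_user_tables`; the same digest can be reproduced locally, as a sketch of why any insert/update/delete invalidates the cached inspection:

```python
import hashlib

# Mirrors md5(concat_ws('|', n_tup_ins, n_tup_upd, n_tup_del)) in Postgres:
# the three counters are joined with '|' and hashed, so any DML activity
# on the table produces a different digest.
def dml_checksum(n_ins: int, n_upd: int, n_del: int) -> str:
    return hashlib.md5(f"{n_ins}|{n_upd}|{n_del}".encode()).hexdigest()

before = dml_checksum(100, 5, 0)
after_insert = dml_checksum(101, 5, 0)  # one extra insert
```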

View File

@@ -0,0 +1,21 @@
from typing import Optional, TYPE_CHECKING
import pandas as pd
from sqlalchemy import create_engine, text
if TYPE_CHECKING:
from ea_chatbot.types import DBSettings
class DBClient:
def __init__(self, settings: "DBSettings"):
self.settings = settings
self._engine = self._create_engine()
def _create_engine(self):
url = f"postgresql://{self.settings['user']}:{self.settings['pswd']}@{self.settings['host']}:{self.settings['port']}/{self.settings['db']}"
return create_engine(url)
def query_df(self, sql: str, params: Optional[dict] = None) -> pd.DataFrame:
with self._engine.connect() as conn:
result = conn.execute(text(sql), params or {})
df = pd.DataFrame(result.fetchall(), columns=result.keys())
return df
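The `query_df` pattern, minus the pandas dependency, can be sketched with stdlib `sqlite3` (returning labelled rows instead of a `DataFrame`; table and column names are illustrative):

```python
import sqlite3
from typing import Optional

# Run a parameterised query and label each row with its column names,
# the same shape query_df produces, just as a list of dicts.
def query_rows(conn: sqlite3.Connection, sql: str, params: Optional[dict] = None) -> list:
    cur = conn.execute(sql, params or {})
    cols = [d[0] for d in cur.description]
    return [dict(zip(cols, row)) for row in cur.fetchall()]

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE voters (id INTEGER, county TEXT)")
conn.executemany("INSERT INTO voters VALUES (?, ?)", [(1, "Essex"), (2, "Bergen")])
rows = query_rows(conn, "SELECT * FROM voters WHERE county = :c", {"c": "Essex"})
```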

View File

@@ -0,0 +1,73 @@
from typing import Optional, TYPE_CHECKING, Dict, Any
from datetime import datetime, timezone
import yaml
import json
if TYPE_CHECKING:
from ea_chatbot.graph.state import AgentState
def ordinal(n: int) -> str:
return f"{n}{'th' if 11 <= n % 100 <= 13 else {1:'st',2:'nd',3:'rd'}.get(n%10, 'th')}"
def get_readable_date(date_obj: Optional[datetime] = None, tz: Optional[timezone] = None) -> str:
if date_obj is None:
date_obj = datetime.now(timezone.utc)
if tz:
date_obj = date_obj.astimezone(tz)
return date_obj.strftime(f"%a {ordinal(date_obj.day)} of %b %Y")
def to_yaml(json_str: str, indent: int = 2) -> str:
"""
Attempts to convert a JSON string (potentially malformed from LLM) to a YAML string.
"""
if not json_str: return ""
try:
# Try direct parse
data = json.loads(json_str)
except json.JSONDecodeError:
# Try simplified repair: replace single quotes
try:
cleaned = json_str.replace("'", '"')
data = json.loads(cleaned)
except Exception:
# Fallback: return raw string if unparseable
return json_str
return yaml.dump(data, indent=indent, sort_keys=False)
def merge_agent_state(current_state: "AgentState", update: Dict[str, Any]) -> "AgentState":
"""
Merges a partial state update into the current state, mimicking LangGraph reduction logic.
- Lists (messages, plots) are appended.
- Dictionaries (dfs) are shallow merged.
- Other fields are overwritten.
"""
new_state = current_state.copy()
for key, value in update.items():
if value is None:
new_state[key] = None
continue
# Accumulate lists (messages, plots)
if key in ["messages", "plots"] and isinstance(value, list):
current_list = new_state.get(key, [])
if not isinstance(current_list, list):
current_list = []
new_state[key] = current_list + value
# Shallow merge dictionaries (dfs)
elif key == "dfs" and isinstance(value, dict):
current_dict = new_state.get(key, {})
if not isinstance(current_dict, dict):
current_dict = {}
merged_dict = current_dict.copy()
merged_dict.update(value)
new_state[key] = merged_dict
# Overwrite everything else
else:
new_state[key] = value
return new_state
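A self-contained copy of the `ordinal` helper with worked examples; the `n % 100` guard used here makes 111-113 map to 'th' along with 11-13:

```python
# English ordinal suffixes: 11-13 (and 111-113, etc.) are special-cased
# to 'th'; otherwise the last digit decides st/nd/rd/th.
def ordinal(n: int) -> str:
    suffix = 'th' if 11 <= n % 100 <= 13 else {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')
    return f"{n}{suffix}"

samples = [ordinal(n) for n in (1, 2, 3, 4, 11, 12, 13, 21, 22, 23, 111)]
```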

View File

@@ -0,0 +1,36 @@
from typing import Optional, cast, TYPE_CHECKING, Literal, Dict, List, Tuple, Any
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.callbacks import BaseCallbackHandler
from ea_chatbot.config import LLMConfig
def get_llm_model(config: LLMConfig, callbacks: Optional[List[BaseCallbackHandler]] = None) -> BaseChatModel:
"""
Factory function to get a LangChain chat model based on configuration.
Args:
config: LLMConfig object containing model settings.
callbacks: Optional list of LangChain callback handlers.
Returns:
Initialized BaseChatModel instance.
Raises:
ValueError: If the provider is not supported.
"""
params = {
"temperature": config.temperature,
"max_tokens": config.max_tokens,
**config.provider_specific
}
# Filter out None values to allow defaults to take over if not specified
params = {k: v for k, v in params.items() if v is not None}
if config.provider.lower() == "openai":
return ChatOpenAI(model=config.model, callbacks=callbacks, **params)
elif config.provider.lower() == "google" or config.provider.lower() == "google_genai":
return ChatGoogleGenerativeAI(model=config.model, callbacks=callbacks, **params)
else:
raise ValueError(f"Unsupported LLM provider: {config.provider}")
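The dispatch in `get_llm_model` boils down to a provider table plus a None filter; a sketch with stub classes standing in for `ChatOpenAI` / `ChatGoogleGenerativeAI` (the stub names are assumptions, not real APIs):

```python
# Stubs recording their constructor arguments, in place of real chat models.
class StubOpenAI:
    def __init__(self, model: str, **params):
        self.model, self.params = model, params

class StubGoogle(StubOpenAI):
    pass

PROVIDERS = {"openai": StubOpenAI, "google": StubGoogle, "google_genai": StubGoogle}

def get_model(provider: str, model: str, **params):
    # Dropping None values lets each provider's own defaults take over.
    params = {k: v for k, v in params.items() if v is not None}
    try:
        cls = PROVIDERS[provider.lower()]
    except KeyError:
        raise ValueError(f"Unsupported LLM provider: {provider}") from None
    return cls(model, **params)

m = get_model("OpenAI", "gpt-5-mini", temperature=0.0, max_tokens=None)
```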

View File

@@ -0,0 +1,141 @@
import json
import os
import logging
from rich.logging import RichHandler
from langchain_core.callbacks import BaseCallbackHandler
from logging.handlers import RotatingFileHandler
from typing import Any, Optional, Dict, List
class LangChainLoggingHandler(BaseCallbackHandler):
"""Callback handler for logging LangChain events."""
def __init__(self, logger: Optional[logging.Logger] = None):
self.logger = logger or get_logger("langchain")
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
# Serialized might be empty or missing name depending on how it's called
model_name = serialized.get("name") or kwargs.get("name") or "LLM"
self.logger.info(f"[bold blue]LLM Started:[/bold blue] {model_name}")
def on_llm_end(self, response: Any, **kwargs: Any) -> Any:
llm_output = getattr(response, "llm_output", {}) or {}
# Try to find model name in output or use fallback
model_name = llm_output.get("model_name") or "LLM"
token_usage = llm_output.get("token_usage", {})
msg = f"[bold green]LLM Ended:[/bold green] {model_name}"
if token_usage:
prompt = token_usage.get("prompt_tokens", 0)
completion = token_usage.get("completion_tokens", 0)
total = token_usage.get("total_tokens", 0)
msg += f" | [yellow]Tokens: {total}[/yellow] ({prompt} prompt, {completion} completion)"
self.logger.info(msg)
def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> Any:
self.logger.error(f"[bold red]LLM Error:[/bold red] {str(error)}")
class ContextLoggerAdapter(logging.LoggerAdapter):
"""Adapter to inject contextual metadata into log records."""
def process(self, msg: Any, kwargs: Any) -> tuple[Any, Any]:
extra = self.extra.copy()
if "extra" in kwargs:
extra.update(kwargs.pop("extra"))
kwargs["extra"] = extra
return msg, kwargs
class FlexibleJSONEncoder(json.JSONEncoder):
def default(self, obj: Any) -> Any:
if hasattr(obj, 'model_dump'): # Pydantic v2
return obj.model_dump()
if hasattr(obj, 'dict'): # Pydantic v1
return obj.dict()
if hasattr(obj, '__dict__'):
return self.serialize_custom_object(obj)
elif isinstance(obj, (dict, list)):
# The encoder serializes dict/list contents natively; recursing with
# self.default() here would raise TypeError on primitive values.
return obj
return super().default(obj)
def serialize_custom_object(self, obj: Any) -> dict:
obj_dict = obj.__dict__.copy()
obj_dict['__custom_class__'] = obj.__class__.__name__
return obj_dict
class JsonFormatter(logging.Formatter):
"""Custom JSON formatter for structured logging."""
def format(self, record: logging.LogRecord) -> str:
# Standard fields
log_record = {
"timestamp": self.formatTime(record, self.datefmt),
"level": record.levelname,
"message": record.getMessage(),
"module": record.module,
"name": record.name,
}
# Add exception info if present
if record.exc_info:
log_record["exception"] = self.formatException(record.exc_info)
# Add all other extra fields from the record
# Filter out standard logging attributes
standard_attrs = {
'args', 'asctime', 'created', 'exc_info', 'exc_text', 'filename',
'funcName', 'levelname', 'levelno', 'lineno', 'module',
'msecs', 'message', 'msg', 'name', 'pathname', 'process',
'processName', 'relativeCreated', 'stack_info', 'thread', 'threadName'
}
for key, value in record.__dict__.items():
if key not in standard_attrs:
log_record[key] = value
return json.dumps(log_record, cls=FlexibleJSONEncoder)
def get_logger(name: str = "ea_chatbot", level: Optional[str] = None, log_file: Optional[str] = None) -> logging.Logger:
"""Get a configured logger with RichHandler and optional Json FileHandler."""
# Ensure name starts with ea_chatbot for hierarchy if not already
if name != "ea_chatbot" and not name.startswith("ea_chatbot."):
full_name = f"ea_chatbot.{name}"
else:
full_name = name
logger = logging.getLogger(full_name)
# Configure root ea_chatbot logger if it hasn't been configured
root_logger = logging.getLogger("ea_chatbot")
if not root_logger.handlers:
# Default to INFO if level not provided
log_level = getattr(logging, (level or "INFO").upper(), logging.INFO)
root_logger.setLevel(log_level)
# Console Handler (Rich)
rich_handler = RichHandler(
rich_tracebacks=True,
markup=True,
show_time=False,
show_path=False
)
root_logger.addHandler(rich_handler)
root_logger.propagate = False
# Always check if we need to add a FileHandler, even if root is already configured
if log_file:
existing_file_handlers = [h for h in root_logger.handlers if isinstance(h, RotatingFileHandler)]
if not existing_file_handlers:
os.makedirs(os.path.dirname(log_file), exist_ok=True)
file_handler = RotatingFileHandler(
log_file, maxBytes=5*1024*1024, backupCount=3
)
file_handler.setFormatter(JsonFormatter())
root_logger.addHandler(file_handler)
# Refresh logger object in case it was created before root was configured
logger = logging.getLogger(full_name)
# If level is explicitly provided for a sub-logger, set it
if level:
logger.setLevel(getattr(logging, level.upper(), logging.INFO))
return logger
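The `get_logger` helper above hangs every logger off an `ea_chatbot` root and attaches a JSON formatter for file output. A minimal self-contained sketch of the same pattern (assumptions: a plain `StreamHandler` stands in for the Rich console handler, and the formatter keeps only a few fields):

```python
import json
import logging

class MiniJsonFormatter(logging.Formatter):
    # Serialize a record to a single JSON line, as the file handler above does.
    def format(self, record):
        return json.dumps({
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        })

def get_child_logger(name: str) -> logging.Logger:
    # Mirror the naming rule: every logger lives under the "ea_chatbot" root.
    if name == "ea_chatbot" or name.startswith("ea_chatbot."):
        full = name
    else:
        full = f"ea_chatbot.{name}"
    return logging.getLogger(full)

log = get_child_logger("graph.coder")
record = logging.LogRecord(
    "ea_chatbot.graph.coder", logging.INFO, __file__, 1, "hello", None, None
)
line = MiniJsonFormatter().format(record)
```

Because `logging.getLogger` returns hierarchical loggers, handlers configured once on the `ea_chatbot` root apply to every `ea_chatbot.*` child.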

90
tests/test_app.py Normal file

@@ -0,0 +1,90 @@
import os
import sys
import pytest
from unittest.mock import MagicMock, patch
from streamlit.testing.v1 import AppTest
from langchain_core.messages import AIMessage
# Ensure src is in python path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
@pytest.fixture(autouse=True)
def mock_history_manager():
"""Globally mock HistoryManager to avoid DB calls during AppTest."""
with patch("ea_chatbot.history.manager.HistoryManager") as mock_cls:
instance = mock_cls.return_value
instance.create_conversation.return_value = MagicMock(id="conv_123")
instance.get_conversations.return_value = []
instance.get_messages.return_value = []
instance.add_message.return_value = MagicMock()
instance.update_conversation_summary.return_value = MagicMock()
yield instance
@pytest.fixture
def mock_app_stream():
with patch("ea_chatbot.graph.workflow.app.stream") as mock_stream:
# Mock events from app.stream
mock_stream.return_value = [
{"query_analyzer": {"next_action": "research"}},
{"researcher": {"messages": [AIMessage(content="Research result")]}}
]
yield mock_stream
@pytest.fixture
def mock_user():
user = MagicMock()
user.id = "test_id"
user.username = "test@example.com"
user.display_name = "Test User"
return user
def test_app_initial_state(mock_app_stream, mock_user):
"""Test that the app initializes with the correct title and empty history."""
at = AppTest.from_file("src/ea_chatbot/app.py")
# Simulate logged-in user
at.session_state["user"] = mock_user
at.run()
assert not at.exception
assert at.title[0].value == "🗳️ Election Analytics Chatbot"
# Check session state initialization
assert "messages" in at.session_state
assert len(at.session_state["messages"]) == 0
def test_app_dev_mode_toggle(mock_app_stream, mock_user):
"""Test that the dev mode toggle exists in the sidebar."""
with patch.dict(os.environ, {"DEV_MODE": "false"}):
at = AppTest.from_file("src/ea_chatbot/app.py")
at.session_state["user"] = mock_user
at.run()
# Check for sidebar toggle (checkbox)
assert len(at.sidebar.checkbox) > 0
dev_mode_toggle = at.sidebar.checkbox[0]
assert dev_mode_toggle.label == "Dev Mode"
assert dev_mode_toggle.value is False
def test_app_graph_execution_streaming(mock_app_stream, mock_user, mock_history_manager):
"""Test that entering a prompt triggers the graph stream and displays response."""
at = AppTest.from_file("src/ea_chatbot/app.py")
at.session_state["user"] = mock_user
at.run()
# Input a question
at.chat_input[0].set_value("Test question").run()
# Verify graph stream was called
assert mock_app_stream.called
# Message should be added to history
assert len(at.session_state["messages"]) == 2
assert at.session_state["messages"][0]["role"] == "user"
assert at.session_state["messages"][1]["role"] == "assistant"
assert "Research result" in at.session_state["messages"][1]["content"]
# Verify history manager was used
assert mock_history_manager.create_conversation.called
assert mock_history_manager.add_message.called

83
tests/test_app_auth.py Normal file

@@ -0,0 +1,83 @@
import pytest
from unittest.mock import MagicMock, patch
from streamlit.testing.v1 import AppTest
from ea_chatbot.auth import AuthType
@pytest.fixture
def mock_history_manager_instance():
# We need to patch before the AppTest loads the module
with patch("ea_chatbot.history.manager.HistoryManager") as mock_cls:
instance = mock_cls.return_value
yield instance
def test_auth_ui_flow_step1_to_password(mock_history_manager_instance):
"""Test UI transition from Step 1 (email) to Step 2a (password) for LOCAL user."""
# Patch BEFORE creating AppTest
mock_user = MagicMock()
mock_user.password_hash = "hashed_password"
mock_history_manager_instance.get_user.return_value = mock_user
at = AppTest.from_file("src/ea_chatbot/app.py")
at.run()
# Step 1: Identification
assert at.session_state["login_step"] == "email"
at.text_input[0].set_value("local@example.com")
at.button[0].click().run()
# Verify transition to password step
assert at.session_state["login_step"] == "login_password"
assert at.session_state["login_email"] == "local@example.com"
assert "Welcome back" in at.info[0].value
def test_auth_ui_flow_step1_to_register(mock_history_manager_instance):
"""Test UI transition from Step 1 (email) to Step 2b (registration) for NEW user."""
mock_history_manager_instance.get_user.return_value = None
at = AppTest.from_file("src/ea_chatbot/app.py")
at.run()
# Step 1: Identification
at.text_input[0].set_value("new@example.com")
at.button[0].click().run()
# Verify transition to registration step
assert at.session_state["login_step"] == "register_details"
assert at.session_state["login_email"] == "new@example.com"
assert "Create an account" in at.info[0].value
def test_auth_ui_flow_step1_to_oidc(mock_history_manager_instance):
"""Test UI transition from Step 1 (email) to Step 2c (OIDC) for OIDC user."""
# Mock history_manager.get_user to return a user WITHOUT a password
mock_user = MagicMock()
mock_user.password_hash = None
mock_history_manager_instance.get_user.return_value = mock_user
at = AppTest.from_file("src/ea_chatbot/app.py")
at.run()
# Step 1: Identification
at.text_input[0].set_value("oidc@example.com")
at.button[0].click().run()
# Verify transition to OIDC step
assert at.session_state["login_step"] == "oidc_login"
assert at.session_state["login_email"] == "oidc@example.com"
assert "configured for Single Sign-On" in at.info[0].value
def test_auth_ui_flow_back_button(mock_history_manager_instance):
"""Test that the 'Back' button returns to Step 1."""
at = AppTest.from_file("src/ea_chatbot/app.py")
# Simulate being on Step 2a
at.session_state["login_step"] = "login_password"
at.session_state["login_email"] = "local@example.com"
at.run()
# Click Back (index 1 in Step 2a)
at.button[1].click().run()
# Verify return to email step
assert at.session_state["login_step"] == "email"

43
tests/test_auth.py Normal file

@@ -0,0 +1,43 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.auth import OIDCClient
@patch("ea_chatbot.auth.OAuth2Session")
def test_oidc_client_initialization(mock_oauth):
client = OIDCClient(
client_id="test_id",
client_secret="test_secret",
server_metadata_url="https://test.server/.well-known/openid-configuration"
)
assert client.oauth_session is not None
@patch("ea_chatbot.auth.requests")
@patch("ea_chatbot.auth.OAuth2Session")
def test_get_login_url(mock_oauth_cls, mock_requests):
# Setup mock session
mock_session = MagicMock()
mock_oauth_cls.return_value = mock_session
# Mock metadata response
mock_response = MagicMock()
mock_response.json.return_value = {
"authorization_endpoint": "https://test.server/auth",
"token_endpoint": "https://test.server/token",
"userinfo_endpoint": "https://test.server/userinfo"
}
mock_requests.get.return_value = mock_response
# Mock authorization url generation
mock_session.create_authorization_url.return_value = ("https://test.server/auth?response_type=code", "state")
client = OIDCClient(
client_id="test_id",
client_secret="test_secret",
server_metadata_url="https://test.server/.well-known/openid-configuration"
)
url = client.get_login_url()
assert url == "https://test.server/auth?response_type=code"
# Verify metadata was fetched via requests
mock_requests.get.assert_called_with("https://test.server/.well-known/openid-configuration")

49
tests/test_auth_flow.py Normal file

@@ -0,0 +1,49 @@
import pytest
from unittest.mock import MagicMock
from ea_chatbot.history.manager import HistoryManager
from ea_chatbot.auth import get_user_auth_type, AuthType
# Mocks
@pytest.fixture
def mock_history_manager():
return MagicMock(spec=HistoryManager)
def test_auth_flow_existing_local_user(mock_history_manager):
"""Test that an existing user with a password returns LOCAL auth type."""
# Setup
mock_user = MagicMock()
mock_user.password_hash = "hashed_secret"
mock_history_manager.get_user.return_value = mock_user
# Execute
auth_type = get_user_auth_type("test@example.com", mock_history_manager)
# Verify
assert auth_type == AuthType.LOCAL
mock_history_manager.get_user.assert_called_once_with("test@example.com")
def test_auth_flow_existing_oidc_user(mock_history_manager):
"""Test that an existing user WITHOUT a password returns OIDC auth type."""
# Setup
mock_user = MagicMock()
mock_user.password_hash = None # No password implies OIDC
mock_history_manager.get_user.return_value = mock_user
# Execute
auth_type = get_user_auth_type("sso@example.com", mock_history_manager)
# Verify
assert auth_type == AuthType.OIDC
mock_history_manager.get_user.assert_called_once_with("sso@example.com")
def test_auth_flow_new_user(mock_history_manager):
"""Test that a non-existent user returns NEW auth type."""
# Setup
mock_history_manager.get_user.return_value = None
# Execute
auth_type = get_user_auth_type("new@example.com", mock_history_manager)
# Verify
assert auth_type == AuthType.NEW
mock_history_manager.get_user.assert_called_once_with("new@example.com")

62
tests/test_coder.py Normal file

@@ -0,0 +1,62 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.coder import coder_node
from ea_chatbot.graph.nodes.error_corrector import error_corrector_node
@pytest.fixture
def mock_state():
return {
"messages": [],
"question": "Show me results for New Jersey",
"plan": "Step 1: Load data\nStep 2: Filter by NJ",
"code": None,
"error": None,
"plots": [],
"dfs": {},
"next_action": "plan"
}
@patch("ea_chatbot.graph.nodes.coder.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
def test_coder_node(mock_get_summary, mock_get_llm, mock_state):
"""Test coder node generates code from plan."""
mock_get_summary.return_value = "Column: Name, Type: text"
mock_llm = MagicMock()
mock_get_llm.return_value = mock_llm
from ea_chatbot.schemas import CodeGenerationResponse
mock_response = CodeGenerationResponse(
code="import pandas as pd\nprint('Hello')",
explanation="Generated code"
)
mock_llm.with_structured_output.return_value.invoke.return_value = mock_response
result = coder_node(mock_state)
assert "code" in result
assert "import pandas as pd" in result["code"]
assert "error" in result
assert result["error"] is None
@patch("ea_chatbot.graph.nodes.error_corrector.get_llm_model")
def test_error_corrector_node(mock_get_llm, mock_state):
"""Test error corrector node fixes code."""
mock_state["code"] = "import pandas as pd\nprint(undefined_var)"
mock_state["error"] = "NameError: name 'undefined_var' is not defined"
mock_llm = MagicMock()
mock_get_llm.return_value = mock_llm
from ea_chatbot.schemas import CodeGenerationResponse
mock_response = CodeGenerationResponse(
code="import pandas as pd\nprint('Defined')",
explanation="Fixed variable"
)
mock_llm.with_structured_output.return_value.invoke.return_value = mock_response
result = error_corrector_node(mock_state)
assert "code" in result
assert "print('Defined')" in result["code"]
assert result["error"] is None

47
tests/test_config.py Normal file

@@ -0,0 +1,47 @@
import pytest
from pydantic import ValidationError
from ea_chatbot.config import Settings, LLMConfig
def test_default_settings():
"""Test that default settings are loaded correctly."""
settings = Settings()
# Check default config for query analyzer
assert isinstance(settings.query_analyzer_llm, LLMConfig)
assert settings.query_analyzer_llm.provider == "openai"
assert settings.query_analyzer_llm.model == "gpt-5-mini"
assert settings.query_analyzer_llm.temperature == 0.0
# Check default config for planner
assert isinstance(settings.planner_llm, LLMConfig)
assert settings.planner_llm.provider == "openai"
assert settings.planner_llm.model == "gpt-5-mini"
def test_env_override(monkeypatch):
"""Test that environment variables override defaults."""
monkeypatch.setenv("QUERY_ANALYZER_LLM__MODEL", "gpt-3.5-turbo")
monkeypatch.setenv("QUERY_ANALYZER_LLM__TEMPERATURE", "0.7")
settings = Settings()
assert settings.query_analyzer_llm.model == "gpt-3.5-turbo"
assert settings.query_analyzer_llm.temperature == 0.7
def test_provider_specific_params():
"""Test that provider specific parameters can be set."""
config = LLMConfig(
provider="openai",
model="o1-preview",
provider_specific={"reasoning_effort": "high"}
)
assert config.provider_specific["reasoning_effort"] == "high"
def test_oidc_settings(monkeypatch):
"""Test OIDC settings configuration."""
monkeypatch.setenv("OIDC_CLIENT_ID", "test_client_id")
monkeypatch.setenv("OIDC_CLIENT_SECRET", "test_client_secret")
monkeypatch.setenv("OIDC_SERVER_METADATA_URL", "https://test.server/.well-known/openid-configuration")
settings = Settings()
assert settings.oidc_client_id == "test_client_id"
assert settings.oidc_client_secret == "test_client_secret"
assert settings.oidc_server_metadata_url == "https://test.server/.well-known/openid-configuration"
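`test_env_override` relies on the `<NODE>_LLM__<PARAMETER>` naming convention from `.env.example`. A minimal sketch of that override mechanism (assumption: the real `Settings` class gets this behavior from pydantic-settings with `env_nested_delimiter="__"` rather than hand-rolled parsing):

```python
import os

def load_llm_config(node_prefix, defaults):
    # For each known parameter, look up <PREFIX>_LLM__<PARAM> in the
    # environment and coerce the string to the default's type.
    config = dict(defaults)
    for param, default in defaults.items():
        raw = os.environ.get(f"{node_prefix}_LLM__{param.upper()}")
        if raw is not None:
            config[param] = type(default)(raw)  # e.g. "0.7" -> 0.7
    return config

os.environ["QUERY_ANALYZER_LLM__MODEL"] = "gpt-3.5-turbo"
os.environ["QUERY_ANALYZER_LLM__TEMPERATURE"] = "0.7"
cfg = load_llm_config(
    "QUERY_ANALYZER",
    {"provider": "openai", "model": "gpt-5-mini", "temperature": 0.0},
)
```

Unset parameters fall back to the node's defaults, which is exactly what `test_default_settings` asserts.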


@@ -0,0 +1,56 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.nodes.summarize_conversation import summarize_conversation_node
@pytest.fixture
def mock_state_with_history():
return {
"messages": [
HumanMessage(content="Show me the 2024 results for Florida"),
AIMessage(content="Here are the results for Florida in 2024...")
],
"summary": "The user is asking about 2024 election results."
}
@patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model")
def test_summarize_conversation_node_updates_summary(mock_get_llm, mock_state_with_history):
mock_llm_instance = MagicMock()
mock_get_llm.return_value = mock_llm_instance
# Mock LLM response for updating summary
mock_llm_instance.invoke.return_value = AIMessage(content="Updated summary including NJ results.")
# Add new messages to simulate a completed turn
mock_state_with_history["messages"].extend([
HumanMessage(content="What about in New Jersey?"),
AIMessage(content="In New Jersey, the 2024 results were...")
])
result = summarize_conversation_node(mock_state_with_history)
assert "summary" in result
assert result["summary"] == "Updated summary including NJ results."
# Verify LLM was called with the correct context
call_messages = mock_llm_instance.invoke.call_args[0][0]
# Should include current summary and last turn messages
assert "Current summary: The user is asking about 2024 election results." in call_messages[0].content
@patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model")
def test_summarize_conversation_node_initial_summary(mock_get_llm):
state = {
"messages": [
HumanMessage(content="Hi"),
AIMessage(content="Hello! How can I help you today?")
],
"summary": ""
}
mock_llm_instance = MagicMock()
mock_get_llm.return_value = mock_llm_instance
mock_llm_instance.invoke.return_value = AIMessage(content="Initial greeting.")
result = summarize_conversation_node(state)
assert result["summary"] == "Initial greeting."


@@ -0,0 +1,195 @@
import pytest
import pandas as pd
import os
from unittest.mock import MagicMock, patch
from ea_chatbot.utils.database_inspection import get_primary_key, inspect_db_table, get_data_summary
@pytest.fixture
def mock_db_client():
mock_client = MagicMock()
mock_client.settings = {"table": "test_table"}
return mock_client
def test_get_primary_key(mock_db_client):
"""Test dynamic primary key discovery."""
# Mock response for primary key query
mock_df = pd.DataFrame({"column_name": ["my_pk"]})
mock_db_client.query_df.return_value = mock_df
pk = get_primary_key(mock_db_client, "test_table")
assert pk == "my_pk"
# Verify the query was called (at least once)
assert mock_db_client.query_df.called
def test_inspect_db_table_improved(mock_db_client, tmp_path):
"""Test improved inspect_db_table with cardinality and sampling."""
data_dir = str(tmp_path)
# 1. Mock columns and types
columns_df = pd.DataFrame({
"column_name": ["id", "category", "count"],
"data_type": ["integer", "text", "integer"]
})
# 2. Mock row count
total_rows_df = pd.DataFrame([{"count": 100}])
# 3. Mock PK discovery
pk_df = pd.DataFrame({"column_name": ["id"]})
# 4. Mock stats for columns
# We need to handle multiple calls to query_df
def side_effect(query):
if "information_schema.columns" in query:
return columns_df
if "COUNT(*)" in query:
return total_rows_df
if "information_schema.key_column_usage" in query:
return pk_df
# Category stats
if 'COUNT("category")' in query:
return pd.DataFrame([{"count": 100}])
if 'COUNT(DISTINCT "category")' in query:
return pd.DataFrame([{"count": 5}])
if 'SELECT DISTINCT "category"' in query:
return pd.DataFrame({"category": ["A", "B", "C", "D", "E"]})
# Count stats
if 'COUNT("count")' in query:
return pd.DataFrame([{"count": 100}])
if 'COUNT(DISTINCT "count")' in query:
return pd.DataFrame([{"count": 100}])
if 'AVG("count")' in query:
return pd.DataFrame([{"avg": 10.0, "min": 1, "max": 20}])
# ID stats (fix for IndexError)
if 'COUNT("id")' in query:
return pd.DataFrame([{"count": 100}])
if 'COUNT(DISTINCT "id")' in query:
return pd.DataFrame([{"count": 100}])
if 'AVG("id")' in query:
return pd.DataFrame([{"avg": 50.0, "min": 1, "max": 100}])
return pd.DataFrame()
mock_db_client.query_df.side_effect = side_effect
# Run inspection
inspect_db_table(mock_db_client, data_dir=data_dir)
# Read summary to verify
summary = get_data_summary(data_dir)
assert summary is not None
assert "test_table" in summary
assert "category" in summary
assert "distinct_values" in summary
assert "unique_count: 5" in summary
assert "- A" in summary
assert "- E" in summary
assert "primary_key: id" in summary
def test_get_data_summary_none(tmp_path):
"""Test get_data_summary when file doesn't exist."""
assert get_data_summary(str(tmp_path)) is None
def test_inspect_db_table_temporal(mock_db_client, tmp_path):
"""Test inspect_db_table with temporal columns."""
data_dir = str(tmp_path)
columns_df = pd.DataFrame({
"column_name": ["created_at"],
"data_type": ["timestamp without time zone"]
})
total_rows_df = pd.DataFrame([{"count": 50}])
pk_df = pd.DataFrame() # No PK
def side_effect(query):
if "information_schema.columns" in query:
return columns_df
if "COUNT(*)" in query:
return total_rows_df
if "information_schema.key_column_usage" in query:
return pk_df
if 'COUNT("created_at")' in query:
return pd.DataFrame([{"count": 50}])
if 'COUNT(DISTINCT "created_at")' in query:
return pd.DataFrame([{"count": 50}])
if 'MIN("created_at")' in query:
return pd.DataFrame([{"min": "2023-01-01", "max": "2023-12-31"}])
return pd.DataFrame()
mock_db_client.query_df.side_effect = side_effect
inspect_db_table(mock_db_client, data_dir=data_dir)
summary = get_data_summary(data_dir)
assert "created_at" in summary
assert "min: '2023-01-01'" in summary
assert "max: '2023-12-31'" in summary
def test_inspect_db_table_high_cardinality(mock_db_client, tmp_path):
"""Test inspect_db_table with high cardinality categorical column (no sample values)."""
data_dir = str(tmp_path)
columns_df = pd.DataFrame({
"column_name": ["user_id"],
"data_type": ["text"]
})
total_rows_df = pd.DataFrame([{"count": 100}])
pk_df = pd.DataFrame()
def side_effect(query):
if "information_schema.columns" in query:
return columns_df
if "COUNT(*)" in query:
return total_rows_df
if "information_schema.key_column_usage" in query:
return pk_df
if 'COUNT("user_id")' in query:
return pd.DataFrame([{"count": 100}])
if 'COUNT(DISTINCT "user_id")' in query:
# High cardinality > 20
return pd.DataFrame([{"count": 50}])
return pd.DataFrame()
mock_db_client.query_df.side_effect = side_effect
inspect_db_table(mock_db_client, data_dir=data_dir)
summary = get_data_summary(data_dir)
assert "user_id" in summary
assert "unique_count: 50" in summary
# Should NOT have distinct_values
assert "distinct_values" not in summary
def test_inspect_db_table_checksum_skip(mock_db_client, tmp_path):
"""Test that inspection is skipped if checksum matches."""
data_dir = str(tmp_path)
table = "test_table"
# 1. Create a fake checksum file
os.makedirs(data_dir, exist_ok=True)
# The checksum is an md5 of the insert/update/delete counts; the mocked query returns "my_hash"
# Mock checksum query
mock_db_client.query_df.return_value = pd.DataFrame([{"dml_hash": "my_hash"}])
# Write existing checksum
with open(os.path.join(data_dir, "checksum"), "w") as f:
f.write(f"{table}:my_hash\n")
# Write existing inspection
with open(os.path.join(data_dir, "inspection.yaml"), "w") as f:
f.write(f"{table}: {{ existing: true }}")
# Run inspection
result = inspect_db_table(mock_db_client, data_dir=data_dir)
# Should return existing content
assert "existing: true" in result
# query_df should be called exactly once (for the checksum); a matching
# hash short-circuits the inspection before any further queries
assert mock_db_client.query_df.call_count == 1
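The checksum test above exercises a cache short-circuit. A hypothetical sketch of that logic (file name `checksum` and the `table:hash` line format are taken from the test; the real `inspect_db_table` may structure this differently):

```python
import os
import tempfile

def should_skip_inspection(data_dir, table, current_hash):
    # If the stored hash for this table matches the current DML hash,
    # the cached inspection file can be reused without re-querying.
    path = os.path.join(data_dir, "checksum")
    if not os.path.exists(path):
        return False
    with open(path) as f:
        for line in f:
            name, _, stored = line.strip().partition(":")
            if name == table and stored == current_hash:
                return True
    return False

tmp = tempfile.mkdtemp()
with open(os.path.join(tmp, "checksum"), "w") as f:
    f.write("test_table:my_hash\n")
skip = should_skip_inspection(tmp, "test_table", "my_hash")
```

This keeps the expensive per-column statistics queries off the hot path when the table has not changed.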

123
tests/test_executor.py Normal file

@@ -0,0 +1,123 @@
import pytest
import pandas as pd
from unittest.mock import MagicMock, patch
from matplotlib.figure import Figure
from ea_chatbot.graph.nodes.executor import executor_node
from ea_chatbot.graph.state import AgentState
@pytest.fixture
def mock_settings():
with patch("ea_chatbot.graph.nodes.executor.Settings") as MockSettings:
mock_settings_instance = MagicMock()
mock_settings_instance.db_host = "localhost"
mock_settings_instance.db_port = 5432
mock_settings_instance.db_user = "user"
mock_settings_instance.db_pswd = "pass"
mock_settings_instance.db_name = "test_db"
mock_settings_instance.db_table = "test_table"
MockSettings.return_value = mock_settings_instance
yield mock_settings_instance
@pytest.fixture
def mock_db_client():
with patch("ea_chatbot.graph.nodes.executor.DBClient") as MockDBClient:
mock_client_instance = MagicMock()
MockDBClient.return_value = mock_client_instance
yield mock_client_instance
def test_executor_node_success_simple_print(mock_settings, mock_db_client):
"""Test executing simple code that prints to stdout."""
state = {
"code": "print('Hello, World!')",
"question": "test",
"messages": []
}
result = executor_node(state)
assert "code_output" in result
assert "Hello, World!" in result["code_output"]
assert result["error"] is None
assert result["plots"] == []
assert result["dfs"] == {}
def test_executor_node_success_dataframe(mock_settings, mock_db_client):
"""Test executing code that creates a DataFrame."""
code = """
import pandas as pd
df = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
print(df)
"""
state = {
"code": code,
"question": "test",
"messages": []
}
result = executor_node(state)
assert "code_output" in result
assert "a b" in result["code_output"] # Check part of DF string representation
assert "dfs" in result
assert "df" in result["dfs"]
assert isinstance(result["dfs"]["df"], pd.DataFrame)
def test_executor_node_success_plot(mock_settings, mock_db_client):
"""Test executing code that generates a plot."""
code = """
import matplotlib.pyplot as plt
fig = plt.figure()
plots.append(fig)
print('Plot generated')
"""
state = {
"code": code,
"question": "test",
"messages": []
}
result = executor_node(state)
assert "Plot generated" in result["code_output"]
assert "plots" in result
assert len(result["plots"]) == 1
assert isinstance(result["plots"][0], Figure)
def test_executor_node_error_syntax(mock_settings, mock_db_client):
"""Test executing code with a syntax error."""
state = {
"code": "print('Hello World", # Missing closing quote
"question": "test",
"messages": []
}
result = executor_node(state)
assert result["error"] is not None
assert "SyntaxError" in result["error"]
def test_executor_node_error_runtime(mock_settings, mock_db_client):
"""Test executing code with a runtime error."""
state = {
"code": "print(1 / 0)",
"question": "test",
"messages": []
}
result = executor_node(state)
assert result["error"] is not None
assert "ZeroDivisionError" in result["error"]
def test_executor_node_no_code(mock_settings, mock_db_client):
"""Test handling when no code is provided."""
state = {
"code": None,
"question": "test",
"messages": []
}
result = executor_node(state)
assert "error" in result
assert "No code provided" in result["error"]

77
tests/test_helpers.py Normal file

@@ -0,0 +1,77 @@
import pytest
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.utils.helpers import merge_agent_state
def test_merge_agent_state_list_accumulation():
"""Verify that list fields (messages, plots) are accumulated (appended)."""
current_state = {
"messages": [HumanMessage(content="hello")],
"plots": ["plot1"]
}
update = {
"messages": [AIMessage(content="hi")],
"plots": ["plot2"]
}
merged = merge_agent_state(current_state, update)
assert len(merged["messages"]) == 2
assert merged["messages"][0].content == "hello"
assert merged["messages"][1].content == "hi"
assert len(merged["plots"]) == 2
assert merged["plots"] == ["plot1", "plot2"]
def test_merge_agent_state_dict_update():
"""Verify that dictionary fields (dfs) are updated (shallow merge)."""
current_state = {
"dfs": {"df1": "data1"}
}
update = {
"dfs": {"df2": "data2"}
}
merged = merge_agent_state(current_state, update)
assert merged["dfs"] == {"df1": "data1", "df2": "data2"}
# Verify overwrite within dict
update_overwrite = {
"dfs": {"df1": "new_data1"}
}
merged_overwrite = merge_agent_state(merged, update_overwrite)
assert merged_overwrite["dfs"] == {"df1": "new_data1", "df2": "data2"}
def test_merge_agent_state_standard_overwrite():
"""Verify that standard fields are overwritten."""
current_state = {
"question": "old question",
"next_action": "old action",
"plan": "old plan"
}
update = {
"question": "new question",
"next_action": "new action",
"plan": "new plan"
}
merged = merge_agent_state(current_state, update)
assert merged["question"] == "new question"
assert merged["next_action"] == "new action"
assert merged["plan"] == "new plan"
def test_merge_agent_state_none_handling():
"""Verify that None updates or missing keys in update don't break things."""
current_state = {
"question": "test",
"messages": ["msg1"]
}
# Empty update
assert merge_agent_state(current_state, {}) == current_state
# Update with None value for overwritable field
merged_none = merge_agent_state(current_state, {"question": None})
assert merged_none["question"] is None
assert merged_none["messages"] == ["msg1"]
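These four tests fully determine the merge semantics. A hypothetical reconstruction of `merge_agent_state` consistent with them (the real helper may special-case fields by name instead of by type):

```python
def merge_agent_state(current, update):
    # List fields accumulate, dict fields shallow-merge, and every other
    # field is overwritten -- including explicit None values.
    merged = dict(current)
    for key, value in update.items():
        existing = merged.get(key)
        if isinstance(value, list) and isinstance(existing, list):
            merged[key] = existing + value
        elif isinstance(value, dict) and isinstance(existing, dict):
            merged[key] = {**existing, **value}
        else:
            merged[key] = value
    return merged

state = merge_agent_state(
    {"plots": ["p1"], "dfs": {"a": 1}, "plan": "old"},
    {"plots": ["p2"], "dfs": {"b": 2}, "plan": "new"},
)
```

This mirrors LangGraph's reducer-annotated state channels, where list channels append and scalar channels replace.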


@@ -0,0 +1,145 @@
import pytest
from ea_chatbot.history.manager import HistoryManager
from ea_chatbot.history.models import User, Conversation, Message, Plot
from ea_chatbot.config import Settings
from sqlalchemy import delete
@pytest.fixture
def history_manager():
settings = Settings()
manager = HistoryManager(settings.history_db_url)
# Clean up tables before tests (order matters because of foreign keys)
with manager.get_session() as session:
session.execute(delete(Plot))
session.execute(delete(Message))
session.execute(delete(Conversation))
session.execute(delete(User))
return manager
def test_history_manager_initialization(history_manager):
assert history_manager.engine is not None
assert history_manager.SessionLocal is not None
def test_history_manager_session_context(history_manager):
with history_manager.get_session() as session:
assert session is not None
def test_get_user_not_found(history_manager):
user = history_manager.get_user("nonexistent@example.com")
assert user is None
def test_authenticate_user_success(history_manager):
email = "test@example.com"
password = "secretpassword"
history_manager.create_user(email=email, password=password)
user = history_manager.authenticate_user(email, password)
assert user is not None
assert user.username == email
def test_authenticate_user_failure(history_manager):
email = "test@example.com"
history_manager.create_user(email=email, password="correctpassword")
user = history_manager.authenticate_user(email, "wrongpassword")
assert user is None
def test_sync_user_from_oidc_new_user(history_manager):
user = history_manager.sync_user_from_oidc(
email="new@example.com",
display_name="New User"
)
assert user is not None
assert user.username == "new@example.com"
assert user.display_name == "New User"
def test_sync_user_from_oidc_existing_user(history_manager):
# First sync
history_manager.sync_user_from_oidc(
email="existing@example.com",
display_name="First Name"
)
# Second sync should update or return same user
user = history_manager.sync_user_from_oidc(
email="existing@example.com",
display_name="Updated Name"
)
assert user.display_name == "Updated Name"
# --- Conversation Management Tests ---
@pytest.fixture
def user(history_manager):
return history_manager.create_user(email="conv_user@example.com")
def test_create_conversation(history_manager, user):
conv = history_manager.create_conversation(
user_id=user.id,
data_state="new_jersey",
name="Test Chat",
summary="A test conversation summary"
)
assert conv is not None
assert conv.name == "Test Chat"
assert conv.summary == "A test conversation summary"
assert conv.user_id == user.id
def test_get_conversations(history_manager, user):
history_manager.create_conversation(user_id=user.id, data_state="nj", name="C1")
history_manager.create_conversation(user_id=user.id, data_state="nj", name="C2")
history_manager.create_conversation(user_id=user.id, data_state="ny", name="C3")
nj_convs = history_manager.get_conversations(user_id=user.id, data_state="nj")
assert len(nj_convs) == 2
ny_convs = history_manager.get_conversations(user_id=user.id, data_state="ny")
assert len(ny_convs) == 1
def test_rename_conversation(history_manager, user):
conv = history_manager.create_conversation(user.id, "nj", "Old Name")
updated = history_manager.rename_conversation(conv.id, "New Name")
assert updated.name == "New Name"
def test_delete_conversation(history_manager, user):
conv = history_manager.create_conversation(user.id, "nj", "To Delete")
history_manager.delete_conversation(conv.id)
convs = history_manager.get_conversations(user.id, "nj")
assert len(convs) == 0
# --- Message Management Tests ---
@pytest.fixture
def conversation(history_manager, user):
return history_manager.create_conversation(user.id, "nj", "Msg Test Conv")
def test_add_message(history_manager, conversation):
msg = history_manager.add_message(
conversation_id=conversation.id,
role="user",
content="Hello world"
)
assert msg is not None
assert msg.content == "Hello world"
assert msg.role == "user"
assert msg.conversation_id == conversation.id
def test_add_message_with_plots(history_manager, conversation):
plots_data = [b"fake_plot_1", b"fake_plot_2"]
msg = history_manager.add_message(
conversation_id=conversation.id,
role="assistant",
content="Here are plots",
plots=plots_data
)
assert len(msg.plots) == 2
assert msg.plots[0].image_data == b"fake_plot_1"
def test_get_messages(history_manager, conversation):
history_manager.add_message(conversation.id, "user", "Q1")
history_manager.add_message(conversation.id, "assistant", "A1")
messages = history_manager.get_messages(conversation.id)
assert len(messages) == 2
assert messages[0].content == "Q1"
assert messages[1].content == "A1"


@@ -0,0 +1,55 @@
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, DeclarativeBase
# We anticipate these imports will fail initially
try:
from ea_chatbot.history.models import Base, User, Conversation, Message, Plot
except ImportError:
Base = None
User = None
Conversation = None
Message = None
Plot = None
def test_models_exist():
assert User is not None, "User model not found"
assert Conversation is not None, "Conversation model not found"
assert Message is not None, "Message model not found"
assert Plot is not None, "Plot model not found"
assert Base is not None, "Base declarative class not found"
def test_user_model_columns():
if not User: pytest.fail("User model undefined")
# Basic check if columns exist (by inspecting __table__.columns)
columns = User.__table__.columns.keys()
assert "id" in columns
assert "username" in columns
assert "password_hash" in columns
assert "display_name" in columns
def test_conversation_model_columns():
if not Conversation: pytest.fail("Conversation model undefined")
columns = Conversation.__table__.columns.keys()
assert "id" in columns
assert "user_id" in columns
assert "data_state" in columns
assert "name" in columns
assert "summary" in columns
assert "created_at" in columns
def test_message_model_columns():
if not Message: pytest.fail("Message model undefined")
columns = Message.__table__.columns.keys()
assert "id" in columns
assert "role" in columns
assert "content" in columns
assert "conversation_id" in columns
assert "created_at" in columns
def test_plot_model_columns():
if not Plot: pytest.fail("Plot model undefined")
columns = Plot.__table__.columns.keys()
assert "id" in columns
assert "message_id" in columns
assert "image_data" in columns


@@ -0,0 +1,4 @@
import ea_chatbot.history
def test_history_module_importable():
assert ea_chatbot.history is not None

tests/test_llm_factory.py Normal file

@@ -0,0 +1,54 @@
import pytest
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from ea_chatbot.config import LLMConfig
from ea_chatbot.utils.llm_factory import get_llm_model
def test_get_openai_model(monkeypatch):
"""Test creating an OpenAI model."""
monkeypatch.setenv("OPENAI_API_KEY", "dummy")
config = LLMConfig(
provider="openai",
model="gpt-4o",
temperature=0.5,
max_tokens=100
)
model = get_llm_model(config)
assert isinstance(model, ChatOpenAI)
assert model.model_name == "gpt-4o"
assert model.temperature == 0.5
assert model.max_tokens == 100
def test_get_google_model(monkeypatch):
"""Test creating a Google model."""
monkeypatch.setenv("GOOGLE_API_KEY", "dummy")
config = LLMConfig(
provider="google",
model="gemini-1.5-pro",
temperature=0.7
)
model = get_llm_model(config)
assert isinstance(model, ChatGoogleGenerativeAI)
assert model.model == "gemini-1.5-pro"
assert model.temperature == 0.7
def test_unsupported_provider():
"""Test that an unsupported provider raises an error."""
config = LLMConfig(provider="unknown", model="test")
with pytest.raises(ValueError, match="Unsupported LLM provider: unknown"):
get_llm_model(config)
def test_provider_specific_params(monkeypatch):
"""Test passing provider specific params."""
monkeypatch.setenv("OPENAI_API_KEY", "dummy")
config = LLMConfig(
provider="openai",
model="o1-preview",
provider_specific={"reasoning_effort": "high"}
)
# Note: reasoning_effort support depends on the langchain-openai version,
# but we check if kwargs are passed.
model = get_llm_model(config)
assert isinstance(model, ChatOpenAI)
# Check if reasoning_effort was passed correctly
assert getattr(model, "reasoning_effort", None) == "high"
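The factory tests above imply a simple provider-dispatch shape. Here is a stdlib-only sketch of that dispatch logic; the real `get_llm_model` would construct `ChatOpenAI` / `ChatGoogleGenerativeAI` instances, so the dict return values here are stand-ins for illustration only:

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class LLMConfig:
    provider: str
    model: str
    temperature: float = 0.0
    max_tokens: Optional[int] = None
    provider_specific: dict = field(default_factory=dict)


# Map provider name -> builder; provider_specific kwargs (e.g. reasoning_effort)
# are forwarded unchanged, mirroring test_provider_specific_params.
_BUILDERS = {
    "openai": lambda cfg: {"cls": "ChatOpenAI", "model_name": cfg.model,
                           "temperature": cfg.temperature,
                           "max_tokens": cfg.max_tokens, **cfg.provider_specific},
    "google": lambda cfg: {"cls": "ChatGoogleGenerativeAI", "model": cfg.model,
                           "temperature": cfg.temperature, **cfg.provider_specific},
}


def get_llm_model(config: LLMConfig):
    try:
        builder = _BUILDERS[config.provider]
    except KeyError:
        # Matches the error message test_unsupported_provider asserts on.
        raise ValueError(f"Unsupported LLM provider: {config.provider}")
    return builder(config)
```

The table-driven dispatch keeps adding a provider to a one-line change, which is presumably why the tests only pin the error message and the forwarded parameters.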


@@ -0,0 +1,19 @@
import pytest
from langchain_openai import ChatOpenAI
from ea_chatbot.config import LLMConfig
from ea_chatbot.utils.llm_factory import get_llm_model
from langchain_core.callbacks import BaseCallbackHandler
class MockHandler(BaseCallbackHandler):
pass
def test_get_llm_model_with_callbacks(monkeypatch):
"""Test that callbacks are passed to the model."""
monkeypatch.setenv("OPENAI_API_KEY", "dummy")
config = LLMConfig(provider="openai", model="gpt-4o")
handler = MockHandler()
model = get_llm_model(config, callbacks=[handler])
assert isinstance(model, ChatOpenAI)
assert handler in model.callbacks


@@ -0,0 +1,44 @@
import logging
import pytest
import io
import json
from ea_chatbot.utils.logging import ContextLoggerAdapter, JsonFormatter
@pytest.fixture
def json_log_capture():
"""Fixture to capture JSON logs."""
log_stream = io.StringIO()
logger = logging.getLogger("test_context")
logger.setLevel(logging.INFO)
for handler in logger.handlers[:]:
logger.removeHandler(handler)
handler = logging.StreamHandler(log_stream)
handler.setFormatter(JsonFormatter())
logger.addHandler(handler)
return logger, log_stream
def test_context_logger_adapter_injects_metadata(json_log_capture):
"""Test that ContextLoggerAdapter injects metadata into the log record."""
logger, log_stream = json_log_capture
adapter = ContextLoggerAdapter(logger, {"run_id": "123", "node_name": "test_node"})
adapter.info("test message")
data = json.loads(log_stream.getvalue())
assert data["message"] == "test message"
assert data["run_id"] == "123"
assert data["node_name"] == "test_node"
def test_context_logger_adapter_override_metadata(json_log_capture):
"""Test that extra metadata can be provided during call."""
logger, log_stream = json_log_capture
adapter = ContextLoggerAdapter(logger, {"run_id": "123"})
# Passing extra context via the 'extra' parameter in standard logging
# Note: Our adapter should handle merging this.
adapter.info("test message", extra={"node_name": "dynamic_node"})
data = json.loads(log_stream.getvalue())
assert data["run_id"] == "123"
assert data["node_name"] == "dynamic_node"


@@ -0,0 +1,67 @@
import logging
import pytest
from ea_chatbot.utils.logging import get_logger
@pytest.fixture(autouse=True)
def reset_logging():
"""Reset the ea_chatbot logger handlers before each test."""
logger = logging.getLogger("ea_chatbot")
# Remove all existing handlers
for handler in logger.handlers[:]:
logger.removeHandler(handler)
yield
# Also clean up after test
for handler in logger.handlers[:]:
logger.removeHandler(handler)
def test_get_logger_singleton():
"""Test that get_logger returns the same logger instance for the same name."""
logger1 = get_logger("test_logger")
logger2 = get_logger("test_logger")
assert logger1 is logger2
def test_get_logger_rich_handler():
"""Test that get_logger configures a RichHandler on root."""
get_logger("test_rich")
root = logging.getLogger("ea_chatbot")
# Check if any handler is a RichHandler
handler_names = [h.__class__.__name__ for h in root.handlers]
assert "RichHandler" in handler_names
def test_get_logger_level():
"""Test that get_logger sets the correct log level."""
logger = get_logger("test_level", level="DEBUG")
assert logger.level == logging.DEBUG
def test_json_formatter_serializes_dict():
"""Test that JsonFormatter serializes log records to JSON."""
from ea_chatbot.utils.logging import JsonFormatter
import json
formatter = JsonFormatter()
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="test.py", lineno=10,
msg="test message", args=(), exc_info=None
)
formatted = formatter.format(record)
data = json.loads(formatted)
assert data["message"] == "test message"
assert data["level"] == "INFO"
assert "timestamp" in data
def test_get_logger_file_handler(tmp_path):
"""Test that get_logger configures a file handler on root."""
log_file = tmp_path / "test.json"
logger = get_logger("test_file", log_file=str(log_file))
root = logging.getLogger("ea_chatbot")
handler_names = [h.__class__.__name__ for h in root.handlers]
assert "RotatingFileHandler" in handler_names
logger.info("file log test")
# Check if file exists and has content
assert log_file.exists()
content = log_file.read_text()
assert "file log test" in content

tests/test_logging_e2e.py Normal file

@@ -0,0 +1,83 @@
import os
import json
import pytest
import logging
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.state import AgentState
from ea_chatbot.utils.logging import get_logger
from langchain_community.chat_models import FakeListChatModel
from langchain_core.runnables import RunnableLambda
@pytest.fixture(autouse=True)
def reset_logging():
"""Reset handlers on the root ea_chatbot logger."""
root = logging.getLogger("ea_chatbot")
for handler in root.handlers[:]:
root.removeHandler(handler)
yield
for handler in root.handlers[:]:
root.removeHandler(handler)
class FakeStructuredModel(FakeListChatModel):
def with_structured_output(self, schema, **kwargs):
# Return a runnable that returns a parsed object
def _invoke(input, config=None, **kwargs):
content = self.responses[0]
import json
data = json.loads(content)
if hasattr(schema, "model_validate"):
return schema.model_validate(data)
return data
return RunnableLambda(_invoke)
def test_logging_e2e_json_output(tmp_path):
"""Test that a full graph run produces structured JSON logs from multiple nodes."""
log_file = tmp_path / "e2e_test.jsonl"
# Configure the root logger
get_logger("ea_chatbot", log_file=str(log_file))
initial_state: AgentState = {
"messages": [],
"question": "Who won in 2024?",
"analysis": None,
"next_action": "",
"plan": None,
"code": None,
"code_output": None,
"error": None,
"plots": [],
"dfs": {}
}
# Create fake models that support callbacks and structured output
fake_analyzer_response = """{"data_required": [], "unknowns": [], "ambiguities": ["Which year?"], "conditions": [], "next_action": "clarify"}"""
fake_analyzer = FakeStructuredModel(responses=[fake_analyzer_response])
fake_clarify = FakeListChatModel(responses=["Please specify."])
with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_llm_factory:
mock_llm_factory.return_value = fake_analyzer
with patch("ea_chatbot.graph.nodes.clarification.get_llm_model") as mock_clarify_llm_factory:
mock_clarify_llm_factory.return_value = fake_clarify
# Run the graph
list(app.stream(initial_state))
# Verify file content
assert log_file.exists()
lines = log_file.read_text().splitlines()
assert len(lines) > 0
# Verify we have logs from different nodes
node_names = [json.loads(line)["name"] for line in lines]
assert "ea_chatbot.query_analyzer" in node_names
assert "ea_chatbot.clarification" in node_names
# Verify events
messages = [json.loads(line)["message"] for line in lines]
assert any("Analyzing question" in m for m in messages)
assert any("Clarification generated" in m for m in messages)


@@ -0,0 +1,64 @@
import logging
import pytest
import io
from unittest.mock import MagicMock
from ea_chatbot.utils.logging import LangChainLoggingHandler
@pytest.fixture
def log_capture():
"""Fixture to capture logs from a logger."""
log_stream = io.StringIO()
logger = logging.getLogger("test_langchain")
logger.setLevel(logging.INFO)
# Remove existing handlers
for handler in logger.handlers[:]:
logger.removeHandler(handler)
handler = logging.StreamHandler(log_stream)
logger.addHandler(handler)
return logger, log_stream
def test_langchain_logging_handler_on_llm_start(log_capture):
"""Test that on_llm_start logs the correct message."""
logger, log_stream = log_capture
handler = LangChainLoggingHandler(logger=logger)
handler.on_llm_start(serialized={"name": "test_model"}, prompts=["test prompt"])
output = log_stream.getvalue()
assert "LLM Started:" in output
assert "test_model" in output
def test_langchain_logging_handler_on_llm_end(log_capture):
"""Test that on_llm_end logs token usage."""
logger, log_stream = log_capture
handler = LangChainLoggingHandler(logger=logger)
response = MagicMock()
response.llm_output = {
"token_usage": {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30
},
"model_name": "test_model"
}
handler.on_llm_end(response=response)
output = log_stream.getvalue()
assert "LLM Ended:" in output
assert "test_model" in output
assert "Tokens: 30" in output
assert "10 prompt" in output
assert "20 completion" in output
def test_langchain_logging_handler_on_llm_error(log_capture):
"""Test that on_llm_error logs the error."""
logger, log_stream = log_capture
handler = LangChainLoggingHandler(logger=logger)
error = Exception("test error")
handler.on_llm_error(error=error)
output = log_stream.getvalue()
assert "LLM Error:" in output
assert "test error" in output


@@ -0,0 +1,79 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.nodes.planner import planner_node
from ea_chatbot.graph.nodes.researcher import researcher_node
from ea_chatbot.graph.nodes.summarizer import summarizer_node
from ea_chatbot.schemas import TaskPlanResponse
@pytest.fixture
def mock_state_with_history():
return {
"messages": [
HumanMessage(content="Show me the 2024 results for Florida"),
AIMessage(content="Here are the results for Florida in 2024...")
],
"question": "What about in New Jersey?",
"analysis": {"data_required": ["2024 results", "New Jersey"], "unknowns": [], "ambiguities": [], "conditions": []},
"next_action": "plan",
"summary": "The user is asking about 2024 election results.",
"plan": "Plan steps...",
"code_output": "Code output..."
}
@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
@patch("ea_chatbot.graph.nodes.planner.PLANNER_PROMPT")
def test_planner_uses_history_and_summary(mock_prompt, mock_get_summary, mock_get_llm, mock_state_with_history):
mock_get_summary.return_value = "Data summary"
mock_llm_instance = MagicMock()
mock_get_llm.return_value = mock_llm_instance
mock_structured_llm = MagicMock()
mock_llm_instance.with_structured_output.return_value = mock_structured_llm
mock_structured_llm.invoke.return_value = TaskPlanResponse(
goal="goal",
reflection="reflection",
context={
"initial_context": "context",
"assumptions": [],
"constraints": []
},
steps=["Step 1: test"]
)
planner_node(mock_state_with_history)
mock_prompt.format_messages.assert_called_once()
kwargs = mock_prompt.format_messages.call_args[1]
assert kwargs["question"] == "What about in New Jersey?"
assert kwargs["summary"] == mock_state_with_history["summary"]
assert len(kwargs["history"]) == 2
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
@patch("ea_chatbot.graph.nodes.researcher.RESEARCHER_PROMPT")
def test_researcher_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
mock_llm_instance = MagicMock()
mock_get_llm.return_value = mock_llm_instance
researcher_node(mock_state_with_history)
mock_prompt.format_messages.assert_called_once()
kwargs = mock_prompt.format_messages.call_args[1]
assert kwargs["question"] == "What about in New Jersey?"
assert kwargs["summary"] == mock_state_with_history["summary"]
assert len(kwargs["history"]) == 2
@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model")
@patch("ea_chatbot.graph.nodes.summarizer.SUMMARIZER_PROMPT")
def test_summarizer_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
mock_llm_instance = MagicMock()
mock_get_llm.return_value = mock_llm_instance
summarizer_node(mock_state_with_history)
mock_prompt.format_messages.assert_called_once()
kwargs = mock_prompt.format_messages.call_args[1]
assert kwargs["question"] == "What about in New Jersey?"
assert kwargs["summary"] == mock_state_with_history["summary"]
assert len(kwargs["history"]) == 2


@@ -0,0 +1,76 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
from ea_chatbot.graph.state import AgentState
@pytest.fixture
def mock_state_with_history():
return {
"messages": [
HumanMessage(content="Show me the 2024 results for Florida"),
AIMessage(content="Here are the results for Florida in 2024...")
],
"question": "What about in New Jersey?",
"analysis": None,
"next_action": "",
"summary": "The user is asking about 2024 election results."
}
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.query_analyzer.QUERY_ANALYZER_PROMPT")
def test_query_analyzer_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
"""Test that query_analyzer_node passes history and summary to the prompt."""
mock_llm_instance = MagicMock()
mock_get_llm.return_value = mock_llm_instance
mock_structured_llm = MagicMock()
mock_llm_instance.with_structured_output.return_value = mock_structured_llm
mock_structured_llm.invoke.return_value = QueryAnalysis(
data_required=["2024 results", "New Jersey"],
unknowns=[],
ambiguities=[],
conditions=[],
next_action="plan"
)
query_analyzer_node(mock_state_with_history)
# Verify that the prompt was formatted with the correct variables
mock_prompt.format_messages.assert_called_once()
kwargs = mock_prompt.format_messages.call_args[1]
assert kwargs["question"] == "What about in New Jersey?"
assert "summary" in kwargs
assert kwargs["summary"] == mock_state_with_history["summary"]
assert "history" in kwargs
# History should contain the messages from the state
assert len(kwargs["history"]) == 2
assert kwargs["history"][0].content == "Show me the 2024 results for Florida"
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_query_analyzer_context_window(mock_get_llm):
"""Test that query_analyzer_node only uses the last 6 messages (3 turns)."""
messages = [HumanMessage(content=f"Msg {i}") for i in range(10)]
state = {
"messages": messages,
"question": "Latest question",
"analysis": None,
"next_action": "",
"summary": "Summary"
}
mock_llm_instance = MagicMock()
mock_get_llm.return_value = mock_llm_instance
mock_structured_llm = MagicMock()
mock_llm_instance.with_structured_output.return_value = mock_structured_llm
mock_structured_llm.invoke.return_value = QueryAnalysis(
data_required=[], unknowns=[], ambiguities=[], conditions=[], next_action="plan"
)
with patch("ea_chatbot.graph.nodes.query_analyzer.QUERY_ANALYZER_PROMPT") as mock_prompt:
query_analyzer_node(state)
kwargs = mock_prompt.format_messages.call_args[1]
# Should only have last 6 messages
assert len(kwargs["history"]) == 6
assert kwargs["history"][0].content == "Msg 4"

tests/test_oidc_client.py Normal file

@@ -0,0 +1,87 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.auth import OIDCClient
@pytest.fixture
def oidc_config():
return {
"client_id": "test_id",
"client_secret": "test_secret",
"server_metadata_url": "https://example.com/.well-known/openid-configuration",
"redirect_uri": "http://localhost:8501"
}
@pytest.fixture
def mock_metadata():
return {
"authorization_endpoint": "https://example.com/auth",
"token_endpoint": "https://example.com/token",
"userinfo_endpoint": "https://example.com/userinfo"
}
def test_oidc_fetch_metadata(oidc_config, mock_metadata):
client = OIDCClient(**oidc_config)
with patch("requests.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = mock_metadata
mock_get.return_value = mock_response
metadata = client.fetch_metadata()
assert metadata == mock_metadata
mock_get.assert_called_once_with(oidc_config["server_metadata_url"])
# Second call should use cache
client.fetch_metadata()
assert mock_get.call_count == 1
def test_oidc_get_login_url(oidc_config, mock_metadata):
client = OIDCClient(**oidc_config)
client.metadata = mock_metadata
with patch.object(client.oauth_session, "create_authorization_url") as mock_create_url:
mock_create_url.return_value = ("https://example.com/auth?state=xyz", "xyz")
url = client.get_login_url()
assert url == "https://example.com/auth?state=xyz"
mock_create_url.assert_called_once_with(mock_metadata["authorization_endpoint"])
def test_oidc_get_login_url_missing_endpoint(oidc_config):
client = OIDCClient(**oidc_config)
client.metadata = {"some_other": "field"}
with pytest.raises(ValueError, match="authorization_endpoint not found"):
client.get_login_url()
def test_oidc_exchange_code_for_token(oidc_config, mock_metadata):
client = OIDCClient(**oidc_config)
client.metadata = mock_metadata
with patch.object(client.oauth_session, "fetch_token") as mock_fetch_token:
mock_fetch_token.return_value = {"access_token": "abc"}
token = client.exchange_code_for_token("test_code")
assert token == {"access_token": "abc"}
mock_fetch_token.assert_called_once_with(
mock_metadata["token_endpoint"],
code="test_code",
client_secret=oidc_config["client_secret"]
)
def test_oidc_get_user_info(oidc_config, mock_metadata):
client = OIDCClient(**oidc_config)
client.metadata = mock_metadata
with patch.object(client.oauth_session, "get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = {"sub": "user123"}
mock_get.return_value = mock_response
user_info = client.get_user_info({"access_token": "abc"})
assert user_info == {"sub": "user123"}
assert client.oauth_session.token == {"access_token": "abc"}
mock_get.assert_called_once_with(mock_metadata["userinfo_endpoint"])
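The first two tests pin the caching and error behavior precisely. A dependency-light sketch of those parts of `OIDCClient` (the injectable `fetcher` parameter is an addition for testability, and the real client delegates URL construction and token exchange to an authlib `OAuth2Session`):

```python
import json
from urllib.request import urlopen


class OIDCClient:
    def __init__(self, client_id: str, client_secret: str,
                 server_metadata_url: str, redirect_uri: str, fetcher=None):
        self.client_id = client_id
        self.client_secret = client_secret
        self.server_metadata_url = server_metadata_url
        self.redirect_uri = redirect_uri
        # fetcher is injectable for tests; the default does a plain HTTP GET.
        self._fetch = fetcher or (lambda url: json.load(urlopen(url)))
        self.metadata = None

    def fetch_metadata(self) -> dict:
        # Cache the OIDC discovery document after the first fetch.
        if self.metadata is None:
            self.metadata = self._fetch(self.server_metadata_url)
        return self.metadata

    def get_login_url(self) -> str:
        meta = self.fetch_metadata()
        if "authorization_endpoint" not in meta:
            # Matches the error the missing-endpoint test expects.
            raise ValueError("authorization_endpoint not found in OIDC metadata")
        return (f"{meta['authorization_endpoint']}?client_id={self.client_id}"
                f"&redirect_uri={self.redirect_uri}&response_type=code")
```

Caching the discovery document avoids a network round trip per login, which is what the `mock_get.call_count == 1` assertion above is really guarding.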

tests/test_planner.py Normal file

@@ -0,0 +1,46 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.planner import planner_node
@pytest.fixture
def mock_state():
return {
"messages": [],
"question": "Show me results for New Jersey",
"analysis": {
# "requires_dataset" removed as it's no longer used
"expert": "Data Analyst",
"data": "NJ data",
"unknown": "results",
"condition": "state=NJ"
},
"next_action": "plan",
"plan": None
}
@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
def test_planner_node(mock_get_summary, mock_get_llm, mock_state):
"""Test planner node with unified prompt."""
mock_get_summary.return_value = "Column: Name, Type: text"
mock_llm = MagicMock()
mock_get_llm.return_value = mock_llm
from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext
mock_plan = TaskPlanResponse(
goal="Get NJ results",
reflection="The user wants NJ results",
context=TaskPlanContext(initial_context="NJ data", assumptions=[], constraints=[]),
steps=["Step 1: Load data", "Step 2: Filter by NJ"]
)
mock_llm.with_structured_output.return_value.invoke.return_value = mock_plan
result = planner_node(mock_state)
assert "plan" in result
assert "Step 1: Load data" in result["plan"]
assert "Step 2: Filter by NJ" in result["plan"]
# Verify helper was called
mock_get_summary.assert_called_once()
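The assertions that individual steps appear inside `result["plan"]` imply the node flattens the structured `TaskPlanResponse` into a single string. A guess at that rendering step (field names mirror the fixtures above; the exact layout is an assumption):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class TaskPlanResponse:
    goal: str
    reflection: str
    context: dict
    steps: List[str] = field(default_factory=list)


def render_plan(plan: TaskPlanResponse) -> str:
    """Flatten the structured plan into the string stored in state['plan']."""
    lines = [f"Goal: {plan.goal}", "Steps:"]
    lines += [f"- {step}" for step in plan.steps]
    return "\n".join(lines)
```

Keeping the structured object for routing and the flattened string for the coder prompt is a common split; the tests here only constrain the string side.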


@@ -0,0 +1,80 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
from ea_chatbot.graph.state import AgentState
@pytest.fixture
def mock_state():
return {
"messages": [],
"question": "Show me the 2024 results for Florida",
"analysis": None,
"next_action": ""
}
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_query_analyzer_data_analysis(mock_get_llm, mock_state):
"""Test that a clear data analysis query is routed to the planner."""
# Mock the LLM and the structured output runnable
mock_llm_instance = MagicMock()
mock_get_llm.return_value = mock_llm_instance
mock_structured_llm = MagicMock()
mock_llm_instance.with_structured_output.return_value = mock_structured_llm
# Define the expected Pydantic result
expected_analysis = QueryAnalysis(
data_required=["2024 results", "Florida"],
unknowns=[],
ambiguities=[],
conditions=[],
next_action="plan"
)
# When structured_llm.invoke is called with messages, return the Pydantic object
mock_structured_llm.invoke.return_value = expected_analysis
new_state_update = query_analyzer_node(mock_state)
assert new_state_update["next_action"] == "plan"
assert "2024 results" in new_state_update["analysis"]["data_required"]
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_query_analyzer_ambiguous(mock_get_llm, mock_state):
"""Test that an ambiguous query is routed to clarification."""
mock_state["question"] = "What happened?"
mock_llm_instance = MagicMock()
mock_get_llm.return_value = mock_llm_instance
mock_structured_llm = MagicMock()
mock_llm_instance.with_structured_output.return_value = mock_structured_llm
expected_analysis = QueryAnalysis(
data_required=[],
unknowns=["What event?"],
ambiguities=[],
conditions=[],
next_action="clarify"
)
mock_structured_llm.invoke.return_value = expected_analysis
new_state_update = query_analyzer_node(mock_state)
assert new_state_update["next_action"] == "clarify"
assert len(new_state_update["analysis"]["unknowns"]) > 0
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_query_analyzer_uses_config(mock_get_llm, mock_state, monkeypatch):
"""Test that the node uses the configured LLM settings."""
monkeypatch.setenv("QUERY_ANALYZER_LLM__MODEL", "gpt-3.5-turbo")
mock_llm_instance = MagicMock()
mock_get_llm.return_value = mock_llm_instance
mock_structured_llm = MagicMock()
mock_llm_instance.with_structured_output.return_value = mock_structured_llm
mock_structured_llm.invoke.return_value = QueryAnalysis(
data_required=[], unknowns=[], ambiguities=[], conditions=[], next_action="plan"
)
query_analyzer_node(mock_state)
# Verify get_llm_model was called with the overridden config
called_config = mock_get_llm.call_args[0][0]
assert called_config.model == "gpt-3.5-turbo"


@@ -0,0 +1,45 @@
import pytest
import logging
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
@pytest.fixture
def mock_state():
return {
"messages": [],
"question": "Show me the 2024 results for Florida",
"analysis": None,
"next_action": ""
}
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.query_analyzer.get_logger")
def test_query_analyzer_logs_actions(mock_get_logger, mock_get_llm, mock_state):
"""Test that query_analyzer_node logs its main actions."""
# Mock Logger
mock_logger = MagicMock()
mock_get_logger.return_value = mock_logger
# Mock LLM
mock_llm_instance = MagicMock()
mock_get_llm.return_value = mock_llm_instance
mock_structured_llm = MagicMock()
mock_llm_instance.with_structured_output.return_value = mock_structured_llm
expected_analysis = QueryAnalysis(
data_required=["2024 results", "Florida"],
unknowns=[],
ambiguities=[],
conditions=[],
next_action="plan"
)
mock_structured_llm.invoke.return_value = expected_analysis
query_analyzer_node(mock_state)
# Check that logger was called
# We expect at least one log at the start and one at the end
assert mock_logger.info.called
    # Specific log messages can be asserted once they are finalized;
    # for now, verifying that the logger is called is enough for the red phase.

View File

@@ -0,0 +1,103 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
from ea_chatbot.graph.state import AgentState
@pytest.fixture
def base_state():
return {
"messages": [],
"question": "",
"analysis": None,
"next_action": "",
"summary": ""
}
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_refinement_coreference_from_history(mock_get_llm, base_state):
"""
Test that the analyzer can resolve Year/State from history.
User asks "What about in NJ?" after a Florida 2024 query.
Expected: next_action = 'plan', NOT 'clarify' due to missing year.
"""
state = base_state.copy()
state["messages"] = [
HumanMessage(content="Show me 2024 results for Florida"),
AIMessage(content="Here are the 2024 results for Florida...")
]
state["question"] = "What about in New Jersey?"
state["summary"] = "The user is looking for 2024 election results."
mock_llm = MagicMock()
mock_get_llm.return_value = mock_llm
mock_structured = MagicMock()
mock_llm.with_structured_output.return_value = mock_structured
    # The mocked LLM returns 'plan' because the year and state are resolvable
    # from the conversation history; the real prompt must be lenient enough not
    # to ask for clarification in this case, and this test documents that.
mock_structured.invoke.return_value = QueryAnalysis(
data_required=["2024 results", "New Jersey"],
unknowns=[],
ambiguities=[],
conditions=["state=NJ", "year=2024"],
next_action="plan"
)
result = query_analyzer_node(state)
assert result["next_action"] == "plan"
assert "NJ" in str(result["analysis"]["conditions"])
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_refinement_tolerance_for_missing_format(mock_get_llm, base_state):
"""
Test that the analyzer doesn't flag missing output format or database name.
User asks "Give me a graph of turnout".
Expected: next_action = 'plan', even if 'format' or 'db' is not in query.
"""
state = base_state.copy()
state["question"] = "Give me a graph of voter turnout in 2024 for Florida"
mock_llm = MagicMock()
mock_get_llm.return_value = mock_llm
mock_structured = MagicMock()
mock_llm.with_structured_output.return_value = mock_structured
mock_structured.invoke.return_value = QueryAnalysis(
data_required=["voter turnout", "Florida"],
unknowns=[],
ambiguities=[],
conditions=["year=2024"],
next_action="plan"
)
result = query_analyzer_node(state)
assert result["next_action"] == "plan"
# Ensure no ambiguities were added by the analyzer itself (hallucinated requirement)
assert len(result["analysis"]["ambiguities"]) == 0
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_refinement_enforces_voter_identity_clarification(mock_get_llm, base_state):
"""
Test that 'track the same voter' still triggers clarification.
"""
state = base_state.copy()
state["question"] = "Track the same voter participation in 2020 and 2024."
mock_llm = MagicMock()
mock_get_llm.return_value = mock_llm
mock_structured = MagicMock()
mock_llm.with_structured_output.return_value = mock_structured
# We WANT it to clarify here because voter identity is not defined.
mock_structured.invoke.return_value = QueryAnalysis(
data_required=["voter participation"],
unknowns=[],
ambiguities=["Please define what fields constitute 'the same voter' (e.g. ID, or Name and DOB)."],
conditions=[],
next_action="clarify"
)
result = query_analyzer_node(state)
assert result["next_action"] == "clarify"
assert "identity" in str(result["analysis"]["ambiguities"]).lower() or "same voter" in str(result["analysis"]["ambiguities"]).lower()

tests/test_researcher.py Normal file

@@ -0,0 +1,34 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from ea_chatbot.graph.nodes.researcher import researcher_node
@pytest.fixture
def mock_llm():
with patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_get_llm:
mock_llm_instance = MagicMock(spec=ChatOpenAI)
mock_get_llm.return_value = mock_llm_instance
yield mock_llm_instance
def test_researcher_node_success(mock_llm):
"""Test that researcher_node invokes LLM with web_search tool and returns messages."""
state = {
"question": "What is the capital of France?",
"messages": []
}
mock_llm_with_tools = MagicMock()
mock_llm.bind_tools.return_value = mock_llm_with_tools
mock_llm_with_tools.invoke.return_value = AIMessage(content="The capital of France is Paris.")
result = researcher_node(state)
assert mock_llm.bind_tools.called
# Check that it was called with web_search
args, kwargs = mock_llm.bind_tools.call_args
assert {"type": "web_search"} in args[0]
assert mock_llm_with_tools.invoke.called
assert "messages" in result
assert result["messages"][0].content == "The capital of France is Paris."

View File

@@ -0,0 +1,62 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from ea_chatbot.graph.nodes.researcher import researcher_node
@pytest.fixture
def base_state():
return {
"question": "Who won the 2024 election?",
"messages": [],
"summary": ""
}
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_openai_search(mock_get_llm, base_state):
"""Test that OpenAI LLM binds 'web_search' tool."""
mock_llm = MagicMock(spec=ChatOpenAI)
mock_get_llm.return_value = mock_llm
mock_llm_with_tools = MagicMock()
mock_llm.bind_tools.return_value = mock_llm_with_tools
mock_llm_with_tools.invoke.return_value = AIMessage(content="OpenAI Search Result")
result = researcher_node(base_state)
# Verify bind_tools called with correct OpenAI tool
mock_llm.bind_tools.assert_called_once_with([{"type": "web_search"}])
assert result["messages"][0].content == "OpenAI Search Result"
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_google_search(mock_get_llm, base_state):
"""Test that Google LLM binds 'google_search' tool."""
mock_llm = MagicMock(spec=ChatGoogleGenerativeAI)
mock_get_llm.return_value = mock_llm
mock_llm_with_tools = MagicMock()
mock_llm.bind_tools.return_value = mock_llm_with_tools
mock_llm_with_tools.invoke.return_value = AIMessage(content="Google Search Result")
result = researcher_node(base_state)
# Verify bind_tools called with correct Google tool
mock_llm.bind_tools.assert_called_once_with([{"google_search": {}}])
assert result["messages"][0].content == "Google Search Result"
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_fallback_on_bind_error(mock_get_llm, base_state):
"""Test that researcher falls back to basic LLM if bind_tools fails."""
mock_llm = MagicMock(spec=ChatOpenAI)
mock_get_llm.return_value = mock_llm
# Simulate bind_tools failing (e.g. model doesn't support it)
mock_llm.bind_tools.side_effect = Exception("Not supported")
mock_llm.invoke.return_value = AIMessage(content="Basic Result")
result = researcher_node(base_state)
# Should still succeed using the base LLM
assert result["messages"][0].content == "Basic Result"
mock_llm.invoke.assert_called_once()

41
tests/test_state.py Normal file

@@ -0,0 +1,41 @@
from typing import get_type_hints
from ea_chatbot.graph.state import AgentState
import operator
def test_agent_state_structure():
"""Verify that AgentState has the required fields and types."""
hints = get_type_hints(AgentState)
assert "messages" in hints
    # The reducer annotation (Annotated[List[BaseMessage], operator.add]) is hard
    # to assert strictly on a TypedDict; its behavior is covered separately in
    # test_messages_reducer_behavior, so key existence is enough here.
assert "question" in hints
assert hints["question"] == str
    # The spec only calls analysis a "Dictionary", so just check presence.
assert "analysis" in hints
assert "next_action" in hints
assert hints["next_action"] == str
assert "summary" in hints
    # summary may be str or Optional[str]; presence is enough here.
assert "plots" in hints
assert "dfs" in hints
def test_messages_reducer_behavior():
    """Verify that messages is annotated with operator.add as its reducer."""
    # get_type_hints(include_extras=True) preserves Annotated metadata, so we
    # can assert the reducer directly instead of trusting the implementation.
    hints = get_type_hints(AgentState, include_extras=True)
    metadata = getattr(hints["messages"], "__metadata__", ())
    assert operator.add in metadata
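The assertions above imply a state schema along these lines. A sketch of the expected `AgentState` (the field set is taken from the tests in this commit; `Any` stands in for `BaseMessage` so the sketch runs without langchain-core installed, whereas the real definition would use `List[BaseMessage]`):

```python
import operator
from typing import Annotated, Any, Dict, List, Optional, TypedDict


class AgentState(TypedDict):
    # operator.add is the reducer: each node's returned messages are
    # appended to the running list rather than replacing it.
    messages: Annotated[List[Any], operator.add]
    question: str
    analysis: Optional[Dict[str, Any]]
    next_action: str
    plan: Optional[str]
    code: Optional[str]
    code_output: Optional[str]
    error: Optional[str]
    summary: Optional[str]
    plots: List[bytes]
    dfs: Dict[str, Any]
```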

48
tests/test_summarizer.py Normal file

@@ -0,0 +1,48 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from ea_chatbot.graph.nodes.summarizer import summarizer_node
from ea_chatbot.graph.state import AgentState
@pytest.fixture
def mock_llm():
with patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") as mock_get_llm:
mock_llm_instance = MagicMock()
mock_get_llm.return_value = mock_llm_instance
yield mock_llm_instance
def test_summarizer_node_success(mock_llm):
"""Test that summarizer_node invokes LLM with correct inputs and returns messages."""
state = {
"question": "What is the total count?",
"plan": "1. Run query\n2. Sum results",
"code_output": "The total is 100",
"messages": []
}
mock_llm.invoke.return_value = AIMessage(content="The final answer is 100.")
result = summarizer_node(state)
# Verify LLM was called
assert mock_llm.invoke.called
# Verify result structure
assert "messages" in result
assert len(result["messages"]) == 1
assert isinstance(result["messages"][0], AIMessage)
assert result["messages"][0].content == "The final answer is 100."
def test_summarizer_node_empty_state(mock_llm):
"""Test handling of empty or minimal state."""
state = {
"question": "Empty?",
"messages": []
}
mock_llm.invoke.return_value = AIMessage(content="No data provided.")
result = summarizer_node(state)
assert "messages" in result
assert result["messages"][0].content == "No data provided."
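Both tests feed the node a partial state, so whatever prompt the summarizer assembles must degrade gracefully when keys are missing. A sketch of that assembly (the helper name and section labels are illustrative, not the node's actual prompt):

```python
def build_summary_prompt(state: dict) -> str:
    """Collect whatever context the state carries; absent keys simply
    produce no section instead of raising KeyError."""
    parts = [f"Question: {state.get('question', '')}"]
    if state.get("plan"):
        parts.append(f"Plan:\n{state['plan']}")
    if state.get("code_output"):
        parts.append(f"Execution output:\n{state['code_output']}")
    return "\n\n".join(parts)
```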

93
tests/test_workflow.py Normal file

@@ -0,0 +1,93 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis
from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse
from langchain_core.messages import AIMessage
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.graph.nodes.coder.get_llm_model")
@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model")
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
@patch("ea_chatbot.graph.nodes.executor.Settings")
@patch("ea_chatbot.graph.nodes.executor.DBClient")
def test_workflow_full_flow(mock_db_client, mock_settings, mock_get_summary, mock_researcher_llm, mock_summarizer_llm, mock_coder_llm, mock_planner_llm, mock_qa_llm):
"""Test the flow from query_analyzer through planner to coder."""
# Mock Settings for Executor
mock_settings_instance = MagicMock()
mock_settings_instance.db_host = "localhost"
mock_settings_instance.db_port = 5432
mock_settings_instance.db_user = "user"
mock_settings_instance.db_pswd = "pass"
mock_settings_instance.db_name = "test_db"
mock_settings_instance.db_table = "test_table"
mock_settings.return_value = mock_settings_instance
# Mock DBClient
mock_client_instance = MagicMock()
mock_db_client.return_value = mock_client_instance
# 1. Mock Query Analyzer
mock_qa_instance = MagicMock()
mock_qa_llm.return_value = mock_qa_instance
mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
data_required=["2024 results"],
unknowns=[],
ambiguities=[],
conditions=[],
next_action="plan"
)
# 2. Mock Planner
mock_planner_instance = MagicMock()
mock_planner_llm.return_value = mock_planner_instance
mock_get_summary.return_value = "Data summary"
mock_planner_instance.with_structured_output.return_value.invoke.return_value = TaskPlanResponse(
goal="Task Goal",
reflection="Reflection",
context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]),
steps=["Step 1"]
)
# 3. Mock Coder
mock_coder_instance = MagicMock()
mock_coder_llm.return_value = mock_coder_instance
mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse(
code="print('Hello')",
explanation="Explanation"
)
# 4. Mock Summarizer
mock_summarizer_instance = MagicMock()
mock_summarizer_llm.return_value = mock_summarizer_instance
mock_summarizer_instance.invoke.return_value = AIMessage(content="Summary")
# 5. Mock Researcher (just in case)
mock_researcher_instance = MagicMock()
mock_researcher_llm.return_value = mock_researcher_instance
# Initial state
initial_state = {
"messages": [],
"question": "Show me the 2024 results",
"analysis": None,
"next_action": "",
"plan": None,
"code": None,
"error": None,
"plots": [],
"dfs": {}
}
    # Run the graph; recursion_limit guards against an infinite loop if any
    # placeholder node misroutes
result = app.invoke(initial_state, config={"recursion_limit": 10})
assert result["next_action"] == "plan"
assert "plan" in result and result["plan"] is not None
assert "code" in result and "print('Hello')" in result["code"]
assert "analysis" in result

139
tests/test_workflow_e2e.py Normal file

@@ -0,0 +1,139 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis
from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse
@pytest.fixture
def mock_llms():
with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_qa_llm, \
patch("ea_chatbot.graph.nodes.planner.get_llm_model") as mock_planner_llm, \
patch("ea_chatbot.graph.nodes.coder.get_llm_model") as mock_coder_llm, \
patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") as mock_summarizer_llm, \
patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_researcher_llm, \
patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model") as mock_summary_llm, \
patch("ea_chatbot.utils.database_inspection.get_data_summary") as mock_get_summary:
mock_get_summary.return_value = "Data summary"
# Mock summary LLM to return a simple response
mock_summary_instance = MagicMock()
mock_summary_llm.return_value = mock_summary_instance
mock_summary_instance.invoke.return_value = AIMessage(content="Turn summary")
yield {
"qa": mock_qa_llm,
"planner": mock_planner_llm,
"coder": mock_coder_llm,
"summarizer": mock_summarizer_llm,
"researcher": mock_researcher_llm,
"summary": mock_summary_llm
}
def test_workflow_data_analysis_flow(mock_llms):
"""Test full flow: QueryAnalyzer -> Planner -> Coder -> Executor -> Summarizer."""
# 1. Mock Query Analyzer (routes to plan)
mock_qa_instance = MagicMock()
mock_llms["qa"].return_value = mock_qa_instance
mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
data_required=["2024 results"],
unknowns=[],
ambiguities=[],
conditions=[],
next_action="plan"
)
# 2. Mock Planner
mock_planner_instance = MagicMock()
mock_llms["planner"].return_value = mock_planner_instance
mock_planner_instance.with_structured_output.return_value.invoke.return_value = TaskPlanResponse(
goal="Get results",
reflection="Reflect",
context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]),
steps=["Step 1"]
)
# 3. Mock Coder
mock_coder_instance = MagicMock()
mock_llms["coder"].return_value = mock_coder_instance
mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse(
code="print('Execution Success')",
explanation="Explain"
)
# 4. Mock Summarizer
mock_summarizer_instance = MagicMock()
mock_llms["summarizer"].return_value = mock_summarizer_instance
mock_summarizer_instance.invoke.return_value = AIMessage(content="Final Summary: Success")
# Initial state
initial_state = {
"messages": [],
"question": "Show me 2024 results",
"analysis": None,
"next_action": "",
"plan": None,
"code": None,
"error": None,
"plots": [],
"dfs": {}
}
# Run the graph
result = app.invoke(initial_state, config={"recursion_limit": 15})
assert result["next_action"] == "plan"
assert "Execution Success" in result["code_output"]
assert "Final Summary: Success" in result["messages"][-1].content
def test_workflow_research_flow(mock_llms):
"""Test flow: QueryAnalyzer -> Researcher -> Summarizer."""
# 1. Mock Query Analyzer (routes to research)
mock_qa_instance = MagicMock()
mock_llms["qa"].return_value = mock_qa_instance
mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
data_required=[],
unknowns=[],
ambiguities=[],
conditions=[],
next_action="research"
)
# 2. Mock Researcher
mock_researcher_instance = MagicMock()
mock_llms["researcher"].return_value = mock_researcher_instance
    # The researcher node binds provider tools only for ChatOpenAI/ChatGoogleGenerativeAI;
    # a plain MagicMock fails that check, so the node falls back to the base instance.
    mock_researcher_instance.invoke.return_value = AIMessage(content="Research Results")
    # Mock bind_tools as well, in case a spec'd mock is used later
mock_llm_with_tools = MagicMock()
mock_researcher_instance.bind_tools.return_value = mock_llm_with_tools
mock_llm_with_tools.invoke.return_value = AIMessage(content="Research Results")
# 3. Mock Summarizer (not used in this flow, but kept for completeness)
mock_summarizer_instance = MagicMock()
mock_llms["summarizer"].return_value = mock_summarizer_instance
mock_summarizer_instance.invoke.return_value = AIMessage(content="Final Summary: Research Success")
# Initial state
initial_state = {
"messages": [],
"question": "Who is the governor of Florida?",
"analysis": None,
"next_action": "",
"plan": None,
"code": None,
"error": None,
"plots": [],
"dfs": {}
}
# Run the graph
result = app.invoke(initial_state, config={"recursion_limit": 10})
assert result["next_action"] == "research"
assert "Research Results" in result["messages"][-1].content
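The two flows above exercise both branches of the analyzer's conditional edge. A sketch of that routing function (the node names `planner`, `researcher`, and `summarizer` are assumptions inferred from the patched module paths, not confirmed graph node IDs):

```python
def route_after_analysis(state: dict) -> str:
    # Conditional edge: map the analyzer's verdict to the next node.
    action = state.get("next_action", "")
    if action == "plan":
        return "planner"
    if action == "research":
        return "researcher"
    # Anything else (e.g. a clarification request) skips straight to summarizing.
    return "summarizer"
```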


@@ -0,0 +1,62 @@
import pytest
from ea_chatbot.history.manager import HistoryManager
from ea_chatbot.history.models import User, Conversation, Message, Plot
from ea_chatbot.config import Settings
from sqlalchemy import delete
@pytest.fixture
def history_manager():
settings = Settings()
manager = HistoryManager(settings.history_db_url)
with manager.get_session() as session:
session.execute(delete(Plot))
session.execute(delete(Message))
session.execute(delete(Conversation))
session.execute(delete(User))
return manager
def test_full_history_workflow(history_manager):
# 1. Create and Authenticate User
email = "e2e@example.com"
password = "password123"
history_manager.create_user(email, password, "E2E User")
user = history_manager.authenticate_user(email, password)
assert user is not None
assert user.display_name == "E2E User"
# 2. Create Conversation
conv = history_manager.create_conversation(user.id, "nj", "Test Analytics")
assert conv.id is not None
# 3. Add User Message
history_manager.add_message(conv.id, "user", "How many voters in NJ?")
# 4. Add Assistant Message with Plot
plot_data = b"fake_png_data"
history_manager.add_message(
conv.id,
"assistant",
"There are X voters.",
plots=[plot_data]
)
# 5. Retrieve and Verify History
messages = history_manager.get_messages(conv.id)
assert len(messages) == 2
assert messages[0].role == "user"
assert messages[1].role == "assistant"
assert len(messages[1].plots) == 1
assert messages[1].plots[0].image_data == plot_data
# 6. Verify Conversation listing
convs = history_manager.get_conversations(user.id, "nj")
assert len(convs) == 1
assert convs[0].name == "Test Analytics"
# 7. Update summary
history_manager.update_conversation_summary(conv.id, "Voter count analysis")
# 8. Reload and verify summary
updated_convs = history_manager.get_conversations(user.id, "nj")
assert updated_convs[0].summary == "Voter count analysis"
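The workflow above exercises the whole `HistoryManager` surface. The real manager is SQLAlchemy-backed (note the `get_session`/`delete` setup in the fixture); this in-memory sketch mirrors only the API the test calls, stores plots as raw bytes rather than `Plot` rows, and uses salted PBKDF2 hashing as an illustrative password scheme:

```python
import hashlib
import os
from dataclasses import dataclass, field
from itertools import count
from typing import Dict, List, Optional


@dataclass
class _User:
    id: int
    email: str
    display_name: str
    salt: bytes
    pw_hash: bytes


@dataclass
class _Message:
    role: str
    content: str
    plots: List[bytes] = field(default_factory=list)


@dataclass
class _Conversation:
    id: int
    user_id: int
    state: str
    name: str
    summary: Optional[str] = None
    messages: List[_Message] = field(default_factory=list)


class InMemoryHistoryManager:
    """API-compatible stand-in for the test above; no persistence."""

    def __init__(self) -> None:
        self._users: Dict[str, _User] = {}
        self._convs: Dict[int, _Conversation] = {}
        self._ids = count(1)

    @staticmethod
    def _hash(password: str, salt: bytes) -> bytes:
        # Never store plaintext passwords; salted PBKDF2 is a reasonable baseline.
        return hashlib.pbkdf2_hmac("sha256", password.encode(), salt, 100_000)

    def create_user(self, email: str, password: str, display_name: str) -> None:
        salt = os.urandom(16)
        self._users[email] = _User(next(self._ids), email, display_name,
                                   salt, self._hash(password, salt))

    def authenticate_user(self, email: str, password: str) -> Optional[_User]:
        user = self._users.get(email)
        if user and self._hash(password, user.salt) == user.pw_hash:
            return user
        return None

    def create_conversation(self, user_id: int, state: str, name: str) -> _Conversation:
        conv = _Conversation(next(self._ids), user_id, state, name)
        self._convs[conv.id] = conv
        return conv

    def add_message(self, conv_id: int, role: str, content: str, plots=None) -> None:
        self._convs[conv_id].messages.append(_Message(role, content, plots or []))

    def get_messages(self, conv_id: int) -> List[_Message]:
        return self._convs[conv_id].messages

    def get_conversations(self, user_id: int, state: str) -> List[_Conversation]:
        return [c for c in self._convs.values()
                if c.user_id == user_id and c.state == state]

    def update_conversation_summary(self, conv_id: int, summary: str) -> None:
        self._convs[conv_id].summary = summary
```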

2921
uv.lock generated Normal file

File diff suppressed because it is too large