feat: implement mvp with email-first login flow and langgraph architecture

.env.example (new file, 51 lines)
@@ -0,0 +1,51 @@
# API Keys
OPENAI_API_KEY=your_openai_api_key_here
GOOGLE_API_KEY=your_google_api_key_here

# App Configuration
DATA_DIR=data
DATA_STATE=new_jersey
LOG_LEVEL=INFO
DEV_MODE=false

# Voter Database Configuration
DB_HOST=localhost
DB_PORT=5432
DB_NAME=blockdata
DB_USER=user
DB_PSWD=password
DB_TABLE=rd_gc_voters_nj

# Application/History Database Configuration
HISTORY_DB_URL=postgresql://user:password@localhost:5433/ea_history

# OIDC Configuration (Authentik/SSO)
OIDC_CLIENT_ID=your_client_id
OIDC_CLIENT_SECRET=your_client_secret
OIDC_SERVER_METADATA_URL=https://your-authentik.example.com/application/o/ea-chatbot/.well-known/openid-configuration
OIDC_REDIRECT_URI=http://localhost:8501

# Node Configuration Overrides (Optional)
# Format: <NODE_NAME>_LLM__<PARAMETER>
# Possible parameters: PROVIDER, MODEL, TEMPERATURE, MAX_TOKENS

# Query Analyzer
# QUERY_ANALYZER_LLM__PROVIDER=openai
# QUERY_ANALYZER_LLM__MODEL=gpt-5-mini
# QUERY_ANALYZER_LLM__TEMPERATURE=0.0

# Planner
# PLANNER_LLM__PROVIDER=openai
# PLANNER_LLM__MODEL=gpt-5-mini

# Coder
# CODER_LLM__PROVIDER=openai
# CODER_LLM__MODEL=gpt-5-mini

# Summarizer
# SUMMARIZER_LLM__PROVIDER=openai
# SUMMARIZER_LLM__MODEL=gpt-5-mini

# Researcher
# RESEARCHER_LLM__PROVIDER=google
# RESEARCHER_LLM__MODEL=gemini-2.0-flash

.gitignore (vendored, +9 lines)
@@ -169,3 +169,12 @@ cython_debug/

# PyPI configuration file
.pypirc

conductor/

data/

# Logs
logs/
postgres-data/
langchain-docs/

GEMINI.md (new file, 169 lines)
@@ -0,0 +1,169 @@
# Election Analytics Chatbot - Project Guide

## Overview
This document serves as a guide for rewriting the current "BambooAI"-based chatbot system into a modern, stateful, graph-based architecture using **LangGraph**. The goal is to improve the maintainability, observability, and flexibility of the agentic workflows.

## 1. Migration Goals
- **Framework Switch**: Move from the custom linear `ChatBot` class (in `src/ea_chatbot/bambooai/core/chatbot.py`) to `LangGraph`.
- **State Management**: Explicit state management using LangGraph's `StateGraph`.
- **Modularity**: Break down monolithic methods (`pd_agent_converse`, `execute_code`) into distinct nodes.
- **Observability**: Easier debugging of the decision process (Routing -> Planning -> Coding -> Executing).

## 2. Architecture Proposal

### 2.1. The Graph State
The state will track the conversation and execution context.

```python
from typing import TypedDict, Annotated, List, Dict, Any, Optional
import operator

from langchain_core.messages import BaseMessage
from matplotlib.figure import Figure
from pandas import DataFrame


class AgentState(TypedDict):
    # Conversation history
    messages: Annotated[List[BaseMessage], operator.add]

    # Task context
    question: str

    # Query Analysis (decomposition results)
    analysis: Optional[Dict[str, Any]]
    # Expected keys: "requires_dataset", "expert", "data", "unknown", "condition"

    # Step-by-step reasoning
    plan: Optional[str]

    # Code execution context
    code: Optional[str]
    code_output: Optional[str]
    error: Optional[str]

    # Artifacts (for UI display)
    plots: List[Figure]        # Matplotlib figures
    dfs: Dict[str, DataFrame]  # Pandas DataFrames

    # Control flow
    iterations: int
    next_action: str  # Routing hint: "clarify", "plan", "research", "end"
```

### 2.2. Nodes (The Actors)
We will map existing logic to these nodes (a sketch of the shared node contract follows the list):

1. **`query_analyzer_node`** (Router & Refiner):
    * **Logic**: Replaces `Expert Selector` and `Analyst Selector`.
    * **Function**:
        1. Decomposes the user's query into key elements (Data, Unknowns, Conditions).
        2. Determines if the query is ambiguous or missing critical information.
    * **Output**: Updates `messages`. Returns a routing decision:
        * `clarification_node` (if ambiguous).
        * `planner_node` (if clear data task).
        * `researcher_node` (if general/web task).

2. **`clarification_node`** (Human-in-the-loop):
    * **Logic**: Replaces `Theorist-Clarification`.
    * **Function**: Formulates a specific question to ask the user for missing details.
    * **Output**: Returns a message to the user and **interrupts** the graph execution to await user input.

3. **`researcher_node`** (Theorist):
    * **Logic**: Handles general queries or web searches.
    * **Function**: Uses the `GoogleSearch` tool if necessary.
    * **Output**: Final answer.

4. **`planner_node`**:
    * **Logic**: Replaces `Planner`.
    * **Function**: Generates a step-by-step plan based on the decomposed query elements and the dataframe ontology.
    * **Output**: Updates `plan`.

5. **`coder_node`**:
    * **Logic**: Replaces `Code Generator` & `Error Corrector`.
    * **Function**: Generates Python code. If `error` exists in state, it attempts to fix it.
    * **Output**: Updates `code`.

6. **`executor_node`**:
    * **Logic**: Replaces `Code Executor`.
    * **Function**: Executes the Python code in a safe(r) environment. It needs access to the `DBClient`.
    * **Output**: Updates `code_output`, `plots`, `dfs`. On exception, updates `error`.

7. **`summarizer_node`**:
    * **Logic**: Replaces `Solution Summarizer`.
    * **Function**: Interprets the code output and generates a natural language response.
    * **Output**: Final response message.
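
A minimal sketch of the shared node contract (illustrative only; the concrete implementations live in `src/ea_chatbot/graph/nodes/`): each node receives the full `AgentState` and returns a partial update dict, which LangGraph merges back into the state.

```python
from ea_chatbot.graph.state import AgentState  # module path per Step 2 below

def planner_node(state: AgentState) -> dict:
    """Illustrative only: read what you need from state, return only the keys you change."""
    plan = f"1. Inspect data relevant to: {state['question']}\n2. Aggregate and visualize."
    return {"plan": plan}
```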

### 2.3. The Workflow (Graph)

```mermaid
graph TD
    Start --> QueryAnalyzer
    QueryAnalyzer -->|Ambiguous| Clarification
    Clarification -->|User Input| QueryAnalyzer
    QueryAnalyzer -->|General/Web| Researcher
    QueryAnalyzer -->|Data Analysis| Planner
    Planner --> Coder
    Coder --> Executor
    Executor -->|Success| Summarizer
    Executor -->|Error| Coder
    Researcher --> End
    Summarizer --> End
```
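
To make the wiring concrete, here is a minimal sketch of assembling this graph with LangGraph's `StateGraph`. The two routing helpers are assumptions for illustration; the real conditions come from the `next_action` and `error` fields the nodes write:

```python
from langgraph.graph import StateGraph, END

workflow = StateGraph(AgentState)
workflow.add_node("query_analyzer", query_analyzer_node)
workflow.add_node("clarification", clarification_node)
workflow.add_node("researcher", researcher_node)
workflow.add_node("planner", planner_node)
workflow.add_node("coder", coder_node)
workflow.add_node("executor", executor_node)
workflow.add_node("summarizer", summarizer_node)

workflow.set_entry_point("query_analyzer")

# Route on the hint written by the analyzer (assumed convention: state["next_action"]).
def route_after_analysis(state: AgentState) -> str:
    return state["next_action"]  # "clarify" | "plan" | "research"

workflow.add_conditional_edges("query_analyzer", route_after_analysis, {
    "clarify": "clarification",
    "plan": "planner",
    "research": "researcher",
})

workflow.add_edge("planner", "coder")
workflow.add_edge("coder", "executor")

# Retry the coder on failure, otherwise summarize.
def route_after_execution(state: AgentState) -> str:
    return "error" if state.get("error") else "success"

workflow.add_conditional_edges("executor", route_after_execution, {
    "error": "coder",
    "success": "summarizer",
})

workflow.add_edge("clarification", END)  # interrupt point; resumes on user input
workflow.add_edge("researcher", END)
workflow.add_edge("summarizer", END)

app = workflow.compile()
```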

## 3. Implementation Steps

### Step 1: Dependencies
Add the following packages to `pyproject.toml`:
* `langgraph`
* `langchain`
* `langchain-openai`
* `langchain-google-genai`
* `langchain-community`

### Step 2: Directory Structure
Create a new package for the graph logic to keep it separate from the old one during the migration.

```
src/ea_chatbot/
├── graph/
│   ├── __init__.py
│   ├── state.py          # State definition
│   ├── nodes/            # Individual node implementations
│   │   ├── __init__.py
│   │   ├── router.py
│   │   ├── planner.py
│   │   ├── coder.py
│   │   ├── executor.py
│   │   └── ...
│   ├── workflow.py       # Graph construction
│   └── tools/            # DB and Search tools wrapped for LangChain
└── ...
```

### Step 3: Tool Wrapping
Wrap the existing `DBClient` (from `src/ea_chatbot/bambooai/utils/db_client.py`) into a structure accessible by the `executor_node`. The `executor_node` will likely keep the existing `exec()`-based approach initially for compatibility with the generated code, but structured as a graph node, as the sketch below shows.
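
A minimal sketch of one option, assuming a closure over a single `DBClient` instance (illustrative, not the final implementation):

```python
import pandas as pd

def make_executor_node(db_client):
    """Close over one shared DBClient so the node does not reconnect per call."""
    def executor_node(state: AgentState) -> dict:
        # Expose the client (and pandas) to the generated code, as today.
        local_vars = {"db": db_client, "pd": pd, "plots": []}
        try:
            exec(state["code"], {}, local_vars)
            return {"error": None, "plots": local_vars["plots"]}
        except Exception as exc:
            return {"error": str(exc)}
    return executor_node
```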

### Step 4: Prompt Migration
Port the prompts from `data/PROMPT_TEMPLATES.json` or `src/ea_chatbot/bambooai/prompts/strings.py` into the respective nodes. Use LangChain's `ChatPromptTemplate` for better management, for example:
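
A small sketch of the target shape; the template text and variable names below are placeholders, not the ported prompts:

```python
from langchain_core.prompts import ChatPromptTemplate

# Placeholder template; the real prompt text comes from PROMPT_TEMPLATES.json.
PLANNER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", "You are a data analysis planner. Today is {date}."),
    ("user", "Question: {question}\n\nAvailable data:\n{database_description}"),
])

messages = PLANNER_PROMPT.format_messages(
    date="2026-02-07",
    question="Turnout by county?",
    database_description="rd_gc_voters_nj: one row per registered voter",
)
```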

### Step 5: Streamlit Integration
Update `src/ea_chatbot/app.py` to use the new `workflow.compile()` runnable.
* Instead of `chatbot.pd_agent_converse(...)`, use `app.stream(...)` (the compiled LangGraph app).
* Handle the streaming output to update the UI progressively, as sketched below.
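
A minimal sketch of the streaming loop: the compiled graph yields one `{node_name: state_update}` event per completed node, which maps naturally onto a Streamlit status container (the merge here is naive and for illustration only):

```python
import streamlit as st

with st.status("Thinking...") as status:
    final_state = dict(initial_state)
    for event in app.stream(initial_state):
        for node_name, state_update in event.items():
            status.write(f"Finished node: {node_name}")
            final_state.update(state_update)  # naive merge, for illustration
```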

## 4. Key Considerations for Refactoring

* **Database Connection**: Ensure `DBClient` is initialized once and passed to the `Executor` node efficiently (e.g., via `configurable` parameters or a closure, as sketched under Step 3).
* **Prompt Templating**: The current system uses simple `format` strings. Switching to LangChain templates allows for easier model switching and partial formatting.
* **Token Management**: LangGraph provides built-in tracing (if LangSmith is enabled), but we should ensure the `OutputManager` logic (printing costs/tokens) is preserved or adapted if still needed for the CLI/logs.
* **Vector DB**: The current system has `PineconeWrapper` for RAG. This should be integrated into the `Planner` or `Coder` node to fetch few-shot examples or context, e.g. as sketched below.
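
One possible shape for that integration (the `query` call below is a hypothetical signature; substitute the wrapper's real API):

```python
# Hypothetical sketch: fetch few-shot plan examples before prompting the planner.
def fetch_example_plan(pinecone_wrapper, question: str) -> str:
    matches = pinecone_wrapper.query(question, top_k=3)  # assumed method and shape
    return "\n---\n".join(m["plan"] for m in matches)
```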

## 5. Next Actions
1. **Initialize**: Create the folder structure.
2. **Define State**: Create `src/ea_chatbot/graph/state.py`.
3. **Implement Router**: Create the first node to replicate the `Expert Selector` logic.
4. **Implement Executor**: Port the `exec()` logic to a node.

## 6. Git Operations
- Branches should be used for specific features or bug fixes.
- New branches should be created from the `main` branch or the `conductor` branch.
- The conductor should always use the `conductor` branch and branches derived from it.
- When a feature or fix is complete, use rebase to keep the commit history clean before merging.
- Conductor-related changes should never be merged into the `main` branch.

alembic.ini (new file, 149 lines)
@@ -0,0 +1,149 @@
# A generic, single database configuration.

[alembic]
# path to migration scripts.
# this is typically a path given in POSIX (e.g. forward slashes)
# format, relative to the token %(here)s which refers to the location of this
# ini file
script_location = %(here)s/alembic

# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# Or organize into date-based subdirectories (requires recursive_version_locations = true)
# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s

# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. for multiple paths, the path separator
# is defined by "path_separator" below.
prepend_sys_path = .


# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the tzdata library which can be installed by adding
# `alembic[tz]` to the pip requirements.
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =

# max length of characters to apply to the "slug" field
# truncate_slug_length = 40

# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false

# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false

# version location specification; This defaults
# to <script_location>/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "path_separator"
# below.
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions

# path_separator; This indicates what character is used to split lists of file
# paths, including version_locations and prepend_sys_path within configparser
# files such as alembic.ini.
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
# to provide os-dependent path splitting.
#
# Note that in order to support legacy alembic.ini files, this default does NOT
# take place if path_separator is not present in alembic.ini. If this
# option is omitted entirely, fallback logic is as follows:
#
# 1. Parsing of the version_locations option falls back to using the legacy
#    "version_path_separator" key, which if absent then falls back to the legacy
#    behavior of splitting on spaces and/or commas.
# 2. Parsing of the prepend_sys_path option falls back to the legacy
#    behavior of splitting on spaces, commas, or colons.
#
# Valid values for path_separator are:
#
# path_separator = :
# path_separator = ;
# path_separator = space
# path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
path_separator = os

# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false

# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8

# database URL. This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file.
sqlalchemy.url = driver://user:pass@localhost/dbname


[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples

# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME

# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
# hooks = ruff
# ruff.type = module
# ruff.module = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME

# Alternatively, use the exec runner to execute a binary found on your PATH
# hooks = ruff
# ruff.type = exec
# ruff.executable = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME

# Logging configuration. This is also consumed by the user-maintained
# env.py script only.
[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARNING
handlers = console
qualname =

[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

alembic/README (new file, 1 line)
@@ -0,0 +1 @@
Generic single-database configuration.

alembic/env.py (new file, 83 lines)
@@ -0,0 +1,83 @@
from logging.config import fileConfig
import os
import sys

from sqlalchemy import engine_from_config
from sqlalchemy import pool

from alembic import context

# Add src to path to ensure we can import the app
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../src'))

from ea_chatbot.config import Settings
from ea_chatbot.history.models import Base

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# add your model's MetaData object here
# for 'autogenerate' support
target_metadata = Base.metadata

# Override sqlalchemy.url from settings
settings = Settings()
config.set_main_option("sqlalchemy.url", settings.history_db_url)


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.

    """
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context.

    """
    connectable = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(
            connection=connection, target_metadata=target_metadata
        )

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()

alembic/script.py.mako (new file, 28 lines)
@@ -0,0 +1,28 @@
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
${downgrades if downgrades else "pass"}
|
||||

alembic/versions/63886baa1255_add_summary_to_conversation.py (new file, 32 lines)
@@ -0,0 +1,32 @@
"""Add summary to conversation
|
||||
|
||||
Revision ID: 63886baa1255
|
||||
Revises: a15fde9a62df
|
||||
Create Date: 2026-02-07 22:34:47.254569
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '63886baa1255'
|
||||
down_revision: Union[str, Sequence[str], None] = 'a15fde9a62df'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('conversations', sa.Column('summary', sa.Text(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('conversations', 'summary')
|
||||
# ### end Alembic commands ###
|
||||

alembic/versions/a15fde9a62df_initial_history_tables.py (new file, 68 lines)
@@ -0,0 +1,68 @@
"""Initial history tables
|
||||
|
||||
Revision ID: a15fde9a62df
|
||||
Revises:
|
||||
Create Date: 2026-02-07 21:18:57.524534
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'a15fde9a62df'
|
||||
down_revision: Union[str, Sequence[str], None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('users',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('username', sa.String(), nullable=False),
|
||||
sa.Column('password_hash', sa.String(), nullable=True),
|
||||
sa.Column('display_name', sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True)
|
||||
op.create_table('conversations',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=False),
|
||||
sa.Column('data_state', sa.String(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('messages',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('conversation_id', sa.String(), nullable=False),
|
||||
sa.Column('role', sa.String(), nullable=False),
|
||||
sa.Column('content', sa.Text(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['conversation_id'], ['conversations.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('plots',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('message_id', sa.String(), nullable=False),
|
||||
sa.Column('image_data', sa.LargeBinary(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['message_id'], ['messages.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('plots')
|
||||
op.drop_table('messages')
|
||||
op.drop_table('conversations')
|
||||
op.drop_index(op.f('ix_users_username'), table_name='users')
|
||||
op.drop_table('users')
|
||||
# ### end Alembic commands ###
|
||||

docker-compose.yml (new file, 18 lines)
@@ -0,0 +1,18 @@
services:
  history-db:
    image: postgres:16
    container_name: ea-history-db
    restart: unless-stopped
    environment:
      POSTGRES_DB: ${HISTORY_DB_NAME:-ea_history}
      POSTGRES_USER: ${HISTORY_DB_USER:-user}
      POSTGRES_PASSWORD: ${HISTORY_DB_PSWD:-password}
    ports:
      - "5433:5432"
    volumes:
      - ./postgres-data:/var/lib/postgresql/data
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U $${POSTGRES_USER} -d $${POSTGRES_DB}"]
      interval: 10s
      timeout: 5s
      retries: 5

pyproject.toml (new file, 39 lines)
@@ -0,0 +1,39 @@
[project]
name = "ea-chatbot"
version = "0.1.0"
description = "An election analytics chatbot using LangGraph."
requires-python = ">=3.13"
dependencies = [
    "langgraph",
    "langchain",
    "langchain-openai",
    "langchain-google-genai",
    "langchain-community",
    "pandas",
    "matplotlib",
    "psycopg2-binary",
    "termcolor>=3.3.0",
    "sqlalchemy>=2.0.46",
    "pyyaml>=6.0.3",
    "streamlit>=1.54.0",
    "rich>=14.3.2",
    "alembic>=1.13.0",
    "authlib>=1.3.0",
    "httpx>=0.27.0",
    "argon2-cffi>=23.1.0",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src/ea_chatbot"]

[dependency-groups]
dev = [
    "pytest>=9.0.2",
    "pytest-cov>=7.0.0",
    "ruff>=0.9.3",
    "mypy>=1.14.1",
]

src/ea_chatbot/__init__.py (new file, empty)

src/ea_chatbot/app.py (new file, 451 lines)
@@ -0,0 +1,451 @@
import streamlit as st
import asyncio
import os
import io
from dotenv import load_dotenv
from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.state import AgentState
from ea_chatbot.utils.logging import get_logger
from ea_chatbot.utils.helpers import merge_agent_state
from ea_chatbot.config import Settings
from ea_chatbot.history.manager import HistoryManager
from ea_chatbot.auth import OIDCClient, AuthType, get_user_auth_type

# Load environment variables
load_dotenv()

# Initialize Config and Manager
settings = Settings()
history_manager = HistoryManager(settings.history_db_url)

# Initialize OIDC Client if configured
oidc_client = None
if settings.oidc_client_id and settings.oidc_client_secret and settings.oidc_server_metadata_url:
    oidc_client = OIDCClient(
        client_id=settings.oidc_client_id,
        client_secret=settings.oidc_client_secret,
        server_metadata_url=settings.oidc_server_metadata_url,
        # Redirect back to the same page
        redirect_uri=os.getenv("OIDC_REDIRECT_URI", "http://localhost:8501")
    )

# Initialize Logger
logger = get_logger(level=settings.log_level, log_file="logs/app.jsonl")

# --- Authentication Helpers ---

def login_user(user):
    st.session_state.user = user
    st.session_state.messages = []
    st.session_state.summary = ""
    st.session_state.current_conversation_id = None
    st.rerun()

def logout_user():
    for key in list(st.session_state.keys()):
        del st.session_state[key]
    st.rerun()

def load_conversation(conv_id):
    messages = history_manager.get_messages(conv_id)
    formatted_messages = []
    for m in messages:
        # Convert DB models to session state dicts
        msg_dict = {
            "role": m.role,
            "content": m.content,
            "plots": [p.image_data for p in m.plots]
        }
        formatted_messages.append(msg_dict)

    st.session_state.messages = formatted_messages
    st.session_state.current_conversation_id = conv_id
    # Fetch summary from DB
    with history_manager.get_session() as session:
        from ea_chatbot.history.models import Conversation
        conv = session.get(Conversation, conv_id)
        st.session_state.summary = conv.summary if conv else ""
    st.rerun()

def main():
    st.set_page_config(
        page_title="Election Analytics Chatbot",
        page_icon="🗳️",
        layout="wide"
    )

    # Check for OIDC Callback
    if "code" in st.query_params and oidc_client:
        code = st.query_params["code"]
        try:
            token = oidc_client.exchange_code_for_token(code)
            user_info = oidc_client.get_user_info(token)
            email = user_info.get("email")
            name = user_info.get("name") or user_info.get("preferred_username")

            if email:
                user = history_manager.sync_user_from_oidc(email=email, display_name=name)
                # Clear query params
                st.query_params.clear()
                login_user(user)
        except Exception as e:
            st.error(f"OIDC Login failed: {str(e)}")

    # Display Login Screen if not authenticated
    if "user" not in st.session_state:
        st.title("🗳️ Election Analytics Chatbot")

        # Initialize Login State
        if "login_step" not in st.session_state:
            st.session_state.login_step = "email"
        if "login_email" not in st.session_state:
            st.session_state.login_email = ""

        col1, col2 = st.columns([1, 1])

        with col1:
            st.header("Login")

            # Step 1: Identification
            if st.session_state.login_step == "email":
                st.write("Please enter your email to begin:")
                with st.form("email_form"):
                    email_input = st.text_input("Email", value=st.session_state.login_email)
                    submitted = st.form_submit_button("Next")

                    if submitted:
                        if not email_input.strip():
                            st.error("Email cannot be empty.")
                        else:
                            st.session_state.login_email = email_input.strip()
                            auth_type = get_user_auth_type(st.session_state.login_email, history_manager)

                            if auth_type == AuthType.LOCAL:
                                st.session_state.login_step = "login_password"
                            elif auth_type == AuthType.OIDC:
                                st.session_state.login_step = "oidc_login"
                            else:  # AuthType.NEW
                                st.session_state.login_step = "register_details"
                            st.rerun()

            # Step 2a: Local Login
            elif st.session_state.login_step == "login_password":
                st.info(f"Welcome back, **{st.session_state.login_email}**!")
                with st.form("password_form"):
                    password = st.text_input("Password", type="password")

                    col_login, col_back = st.columns([1, 1])
                    submitted = col_login.form_submit_button("Login")
                    back = col_back.form_submit_button("Back")

                    if back:
                        st.session_state.login_step = "email"
                        st.rerun()

                    if submitted:
                        user = history_manager.authenticate_user(st.session_state.login_email, password)
                        if user:
                            login_user(user)
                        else:
                            st.error("Invalid email or password")

            # Step 2b: Registration
            elif st.session_state.login_step == "register_details":
                st.info(f"Create an account for **{st.session_state.login_email}**")
                with st.form("register_form"):
                    reg_name = st.text_input("Display Name")
                    reg_password = st.text_input("Password", type="password")

                    col_reg, col_back = st.columns([1, 1])
                    submitted = col_reg.form_submit_button("Register & Login")
                    back = col_back.form_submit_button("Back")

                    if back:
                        st.session_state.login_step = "email"
                        st.rerun()

                    if submitted:
                        if not reg_password:
                            st.error("Password is required for registration.")
                        else:
                            user = history_manager.create_user(st.session_state.login_email, reg_password, reg_name)
                            st.success("Registered! Logging in...")
                            login_user(user)

            # Step 2c: OIDC Redirection
            elif st.session_state.login_step == "oidc_login":
                st.info(f"**{st.session_state.login_email}** is configured for Single Sign-On (SSO).")

                col_sso, col_back = st.columns([1, 1])

                with col_sso:
                    if oidc_client:
                        login_url = oidc_client.get_login_url()
                        st.link_button("Login with SSO", login_url, type="primary", use_container_width=True)
                    else:
                        st.error("OIDC is not configured.")

                with col_back:
                    if st.button("Back", use_container_width=True):
                        st.session_state.login_step = "email"
                        st.rerun()

        with col2:
            if oidc_client:
                st.header("Single Sign-On")
                st.write("Login with your organizational account.")
                login_url = oidc_client.get_login_url()
                st.link_button("Login with SSO", login_url, type="primary")
            else:
                st.info("SSO is not configured.")

        st.stop()

    # --- Main App (Authenticated) ---

    user = st.session_state.user

    # Sidebar configuration
    with st.sidebar:
        st.title(f"Hi, {user.display_name or user.username}!")

        if st.button("Logout"):
            logout_user()

        st.divider()

        st.header("History")
        if st.button("➕ New Chat", use_container_width=True):
            st.session_state.messages = []
            st.session_state.summary = ""
            st.session_state.current_conversation_id = None
            st.rerun()

        # List conversations for the current user and data state
        conversations = history_manager.get_conversations(user.id, settings.data_state)

        for conv in conversations:
            col_c, col_r, col_d = st.columns([0.7, 0.15, 0.15])

            is_current = st.session_state.get("current_conversation_id") == conv.id
            label = f"💬 {conv.name}" if not is_current else f"👉 {conv.name}"

            if col_c.button(label, key=f"conv_{conv.id}", use_container_width=True):
                load_conversation(conv.id)

            if col_r.button("✏️", key=f"ren_{conv.id}"):
                st.session_state.renaming_id = conv.id

            if col_d.button("🗑️", key=f"del_{conv.id}"):
                if history_manager.delete_conversation(conv.id):
                    if is_current:
                        st.session_state.current_conversation_id = None
                        st.session_state.messages = []
                    st.rerun()

        # Rename dialog
        if st.session_state.get("renaming_id"):
            rid = st.session_state.renaming_id
            with st.form("rename_form"):
                new_name = st.text_input("New Name")
                if st.form_submit_button("Save"):
                    history_manager.rename_conversation(rid, new_name)
                    del st.session_state.renaming_id
                    st.rerun()
                if st.form_submit_button("Cancel"):
                    del st.session_state.renaming_id
                    st.rerun()

        st.divider()
        st.header("Settings")
        # Check for DEV_MODE env var (defaults to False)
        default_dev_mode = os.getenv("DEV_MODE", "false").lower() == "true"
        dev_mode = st.checkbox("Dev Mode", value=default_dev_mode, help="Enable to see code generation and raw reasoning steps.")

    st.title("🗳️ Election Analytics Chatbot")

    # Initialize chat history state
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "summary" not in st.session_state:
        st.session_state.summary = ""

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            if message.get("plan") and dev_mode:
                with st.expander("Reasoning Plan"):
                    st.code(message["plan"], language="yaml")
            if message.get("code") and dev_mode:
                with st.expander("Generated Code"):
                    st.code(message["code"], language="python")

            st.markdown(message["content"])
            if message.get("plots"):
                for plot_data in message["plots"]:
                    # If plot_data is bytes, convert to image
                    if isinstance(plot_data, bytes):
                        st.image(plot_data)
                    else:
                        # Fallback for old session state or non-binary
                        st.pyplot(plot_data)
            if message.get("dfs"):
                for df_name, df in message["dfs"].items():
                    st.subheader(f"Data: {df_name}")
                    st.dataframe(df)

    # Accept user input
    if prompt := st.chat_input("Ask a question about election data..."):
        # Ensure we have a conversation ID
        if not st.session_state.get("current_conversation_id"):
            # Auto-create conversation
            conv_name = (prompt[:30] + '...') if len(prompt) > 30 else prompt
            conv = history_manager.create_conversation(user.id, settings.data_state, conv_name)
            st.session_state.current_conversation_id = conv.id

        conv_id = st.session_state.current_conversation_id

        # Save user message to DB
        history_manager.add_message(conv_id, "user", prompt)

        # Add user message to session state
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Display user message in chat message container
        with st.chat_message("user"):
            st.markdown(prompt)

        # Prepare graph input
        initial_state: AgentState = {
            "messages": st.session_state.messages[:-1],  # Pass history (excluding the current prompt)
            "question": prompt,
            "summary": st.session_state.summary,
            "analysis": None,
            "next_action": "",
            "plan": None,
            "code": None,
            "code_output": None,
            "error": None,
            "plots": [],
            "dfs": {}
        }

        # Placeholder for graph output
        with st.chat_message("assistant"):
            final_state = initial_state
            # Real-time node updates
            with st.status("Thinking...", expanded=True) as status:
                try:
                    # Use app.stream to capture node transitions
                    for event in app.stream(initial_state):
                        for node_name, state_update in event.items():
                            prev_error = final_state.get("error")
                            # Use helper to merge state correctly (appending messages/plots, updating dfs)
                            final_state = merge_agent_state(final_state, state_update)

                            if node_name == "query_analyzer":
                                analysis = state_update.get("analysis", {})
                                next_action = state_update.get("next_action", "unknown")
                                status.write("🔍 **Analyzed Query:**")
                                for k, v in analysis.items():
                                    status.write(f"- {k:<8}: {v}")
                                status.markdown(f"Next Step: {next_action.capitalize()}")

                            elif node_name == "planner":
                                status.write("📋 **Plan Generated**")
                                # Render artifacts
                                if state_update.get("plan") and dev_mode:
                                    with st.expander("Reasoning Plan", expanded=True):
                                        st.code(state_update["plan"], language="yaml")

                            elif node_name == "researcher":
                                status.write("🌐 **Research Complete**")
                                if state_update.get("messages") and dev_mode:
                                    for msg in state_update["messages"]:
                                        # Extract content from BaseMessage or show raw string
                                        content = getattr(msg, "text", msg.content)
                                        status.markdown(content)

                            elif node_name == "coder":
                                status.write("💻 **Code Generated**")
                                if state_update.get("code") and dev_mode:
                                    with st.expander("Generated Code"):
                                        st.code(state_update["code"], language="python")

                            elif node_name == "error_corrector":
                                status.write("🛠️ **Fixing Execution Error...**")
                                if prev_error:
                                    truncated_error = prev_error.strip()
                                    if len(truncated_error) > 180:
                                        truncated_error = truncated_error[:180] + "..."
                                    status.write(f"Previous error: {truncated_error}")
                                if state_update.get("code") and dev_mode:
                                    with st.expander("Corrected Code"):
                                        st.code(state_update["code"], language="python")

                            elif node_name == "executor":
                                if state_update.get("error"):
                                    if dev_mode:
                                        status.write(f"❌ **Execution Error:** {state_update.get('error')}...")
                                    else:
                                        status.write(f"❌ **Execution Error:** {state_update.get('error')[:100]}...")
                                else:
                                    status.write("✅ **Execution Successful**")
                                    if state_update.get("plots"):
                                        status.write(f"📊 Generated {len(state_update['plots'])} plot(s)")

                            elif node_name == "summarizer":
                                status.write("📝 **Summarizing Results...**")

                    status.update(label="Complete!", state="complete", expanded=False)

                except Exception as e:
                    status.update(label="Error!", state="error")
                    st.error(f"Error during graph execution: {str(e)}")

            # Extract results
            response_text: str = ""
            if final_state.get("messages"):
                # The last message is the Assistant's response
                last_msg = final_state["messages"][-1]
                response_text = getattr(last_msg, "text", str(last_msg.content))
                st.markdown(response_text)

            # Collect plot bytes for saving to DB
            plot_bytes_list = []
            if final_state.get("plots"):
                for fig in final_state["plots"]:
                    st.pyplot(fig)
                    # Convert fig to bytes
                    buf = io.BytesIO()
                    fig.savefig(buf, format="png")
                    plot_bytes_list.append(buf.getvalue())

            if final_state.get("dfs"):
                for df_name, df in final_state["dfs"].items():
                    st.subheader(f"Data: {df_name}")
                    st.dataframe(df)

            # Save assistant message to DB
            history_manager.add_message(conv_id, "assistant", response_text, plots=plot_bytes_list)

            # Update summary in DB
            new_summary = final_state.get("summary", "")
            if new_summary:
                history_manager.update_conversation_summary(conv_id, new_summary)

            # Store assistant response in session history
            st.session_state.messages.append({
                "role": "assistant",
                "content": response_text,
                "plan": final_state.get("plan"),
                "code": final_state.get("code"),
                "plots": plot_bytes_list,
                "dfs": final_state.get("dfs")
            })
            st.session_state.summary = new_summary


if __name__ == "__main__":
    main()

src/ea_chatbot/auth.py (new file, 97 lines)
@@ -0,0 +1,97 @@
import requests
from enum import Enum
from typing import Dict, Any, Optional
from authlib.integrations.requests_client import OAuth2Session

class AuthType(Enum):
    LOCAL = "local"
    OIDC = "oidc"
    NEW = "new"

def get_user_auth_type(email: str, history_manager: Any) -> AuthType:
    """
    Determine the authentication type for a given email.

    Args:
        email: The user's email address.
        history_manager: Instance of HistoryManager to check the DB.

    Returns:
        AuthType: LOCAL if a password exists, OIDC if the user exists but has no password, NEW otherwise.
    """
    user = history_manager.get_user(email)

    if not user:
        return AuthType.NEW

    if user.password_hash:
        return AuthType.LOCAL

    return AuthType.OIDC

class OIDCClient:
    """
    Client for OIDC authentication using Authlib.
    Designed to work within a Streamlit environment.
    """
    def __init__(
        self,
        client_id: str,
        client_secret: str,
        server_metadata_url: str,
        redirect_uri: str = "http://localhost:8501"
    ):
        self.client_id = client_id
        self.client_secret = client_secret
        self.server_metadata_url = server_metadata_url
        self.redirect_uri = redirect_uri
        self.oauth_session = OAuth2Session(
            client_id=client_id,
            client_secret=client_secret,
            redirect_uri=redirect_uri,
            scope="openid email profile"
        )
        self.metadata: Dict[str, Any] = {}

    def fetch_metadata(self) -> Dict[str, Any]:
        """Fetch OIDC provider metadata if not already fetched."""
        if not self.metadata:
            self.metadata = requests.get(self.server_metadata_url).json()
        return self.metadata

    def get_login_url(self) -> str:
        """Generate the authorization URL."""
        metadata = self.fetch_metadata()
        authorization_endpoint = metadata.get("authorization_endpoint")
        if not authorization_endpoint:
            raise ValueError("authorization_endpoint not found in OIDC metadata")

        uri, state = self.oauth_session.create_authorization_url(authorization_endpoint)
        return uri

    def exchange_code_for_token(self, code: str) -> Dict[str, Any]:
        """Exchange the authorization code for an access token."""
        metadata = self.fetch_metadata()
        token_endpoint = metadata.get("token_endpoint")
        if not token_endpoint:
            raise ValueError("token_endpoint not found in OIDC metadata")

        token = self.oauth_session.fetch_token(
            token_endpoint,
            code=code,
            client_secret=self.client_secret
        )
        return token

    def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
        """Fetch user information using the access token."""
        metadata = self.fetch_metadata()
        userinfo_endpoint = metadata.get("userinfo_endpoint")
        if not userinfo_endpoint:
            raise ValueError("userinfo_endpoint not found in OIDC metadata")

        # Set the token on the session so it's used in the request
        self.oauth_session.token = token
        resp = self.oauth_session.get(userinfo_endpoint)
        resp.raise_for_status()
        return resp.json()

src/ea_chatbot/config.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from typing import Dict, Any, Optional
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict

class LLMConfig(BaseModel):
    """Configuration for a specific LLM node."""
    provider: str = "openai"
    model: str = "gpt-5-mini"
    temperature: float = 0.0
    max_tokens: Optional[int] = None
    provider_specific: Dict[str, Any] = Field(default_factory=dict)

class Settings(BaseSettings):
    """Global application settings."""

    data_dir: str = "data"
    data_state: str = "new_jersey"
    log_level: str = Field(default="INFO", alias="LOG_LEVEL")

    # Voter Database configuration
    db_host: str = Field(default="localhost", alias="DB_HOST")
    db_port: int = Field(default=5432, alias="DB_PORT")
    db_name: str = Field(default="blockdata", alias="DB_NAME")
    db_user: str = Field(default="user", alias="DB_USER")
    db_pswd: str = Field(default="password", alias="DB_PSWD")
    db_table: str = Field(default="rd_gc_voters_nj", alias="DB_TABLE")

    # Application/History Database
    history_db_url: str = Field(default="postgresql://user:password@localhost:5433/ea_history", alias="HISTORY_DB_URL")

    # OIDC Configuration
    oidc_client_id: Optional[str] = Field(default=None, alias="OIDC_CLIENT_ID")
    oidc_client_secret: Optional[str] = Field(default=None, alias="OIDC_CLIENT_SECRET")
    oidc_server_metadata_url: Optional[str] = Field(default=None, alias="OIDC_SERVER_METADATA_URL")

    # Default configurations for each node
    query_analyzer_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
    planner_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
    coder_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
    summarizer_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
    researcher_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))

    # Allow nested env vars like QUERY_ANALYZER_LLM__MODEL
    model_config = SettingsConfigDict(env_nested_delimiter='__', env_prefix='')

src/ea_chatbot/graph/__init__.py (new file, empty)

src/ea_chatbot/graph/nodes/__init__.py (new file, empty)

src/ea_chatbot/graph/nodes/clarification.py (new file, 45 lines)
@@ -0,0 +1,45 @@
from langchain_core.messages import AIMessage
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler

def clarification_node(state: AgentState) -> dict:
    """Ask the user for missing information or clarifications."""
    question = state["question"]
    analysis = state.get("analysis", {})
    ambiguities = analysis.get("ambiguities", [])

    settings = Settings()
    logger = get_logger("clarification")

    logger.info(f"Generating clarification for {len(ambiguities)} ambiguities.")

    llm = get_llm_model(
        settings.query_analyzer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )

    system_prompt = """You are a Clarification Specialist. Your role is to identify what information is missing from a user's request to perform a data analysis or research task.
Based on the analysis of the user's question, formulate a polite and concise request for the missing information."""

    prompt = f"""Original Question: {question}
Missing/Ambiguous Information: {', '.join(ambiguities) if ambiguities else 'Unknown ambiguities'}

Please ask the user for the necessary details."""

    messages = [
        ("system", system_prompt),
        ("user", prompt)
    ]

    try:
        response = llm.invoke(messages)
        logger.info("[bold green]Clarification generated.[/bold green]")
        return {
            "messages": [response],
            "next_action": "end"  # To indicate we are done for now
        }
    except Exception as e:
        logger.error(f"Failed to generate clarification: {str(e)}")
        raise e

src/ea_chatbot/graph/nodes/coder.py (new file, 47 lines)
@@ -0,0 +1,47 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils import helpers, database_inspection
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT
from ea_chatbot.schemas import CodeGenerationResponse

def coder_node(state: AgentState) -> dict:
    """Generate Python code based on the plan and data summary."""
    question = state["question"]
    plan = state.get("plan", "")
    code_output = state.get("code_output", "None")

    settings = Settings()
    logger = get_logger("coder")

    logger.info("Generating Python code...")

    llm = get_llm_model(
        settings.coder_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(CodeGenerationResponse)

    # Always provide data summary
    database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."
    example_code = ""  # Placeholder

    messages = CODE_GENERATOR_PROMPT.format_messages(
        question=question,
        plan=plan,
        database_description=database_description,
        code_exec_results=code_output,
        example_code=example_code
    )

    try:
        response = structured_llm.invoke(messages)
        logger.info("[bold green]Code generated.[/bold green]")
        return {
            "code": response.parsed_code,
            "error": None  # Clear previous errors on new code generation
        }
    except Exception as e:
        logger.error(f"Failed to generate code: {str(e)}")
        raise e

src/ea_chatbot/graph/nodes/error_corrector.py (new file, 40 lines)
@@ -0,0 +1,40 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.coder import ERROR_CORRECTOR_PROMPT
from ea_chatbot.schemas import CodeGenerationResponse

def error_corrector_node(state: AgentState) -> dict:
    """Fix the code based on the execution error."""
    code = state.get("code", "")
    error = state.get("error", "Unknown error")

    settings = Settings()
    logger = get_logger("error_corrector")

    logger.warning(f"[bold red]Execution error detected:[/bold red] {error[:100]}...")
    logger.info("Attempting to correct the code...")

    # Reuse the coder LLM config or add a new one. Using coder_llm for now.
    llm = get_llm_model(
        settings.coder_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(CodeGenerationResponse)

    messages = ERROR_CORRECTOR_PROMPT.format_messages(
        code=code,
        error=error
    )

    try:
        response = structured_llm.invoke(messages)
        logger.info("[bold green]Correction generated.[/bold green]")
        return {
            "code": response.parsed_code,
            "error": None  # Clear error after fix attempt
        }
    except Exception as e:
        logger.error(f"Failed to correct code: {str(e)}")
        raise e

src/ea_chatbot/graph/nodes/executor.py (new file, 102 lines)
@@ -0,0 +1,102 @@
import io
|
||||
import sys
|
||||
import traceback
|
||||
from contextlib import redirect_stdout
|
||||
from typing import Any, Dict, List, TYPE_CHECKING
|
||||
import pandas as pd
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.utils.db_client import DBClient
|
||||
from ea_chatbot.utils.logging import get_logger
|
||||
from ea_chatbot.config import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ea_chatbot.types import DBSettings
|
||||
|
||||
def executor_node(state: AgentState) -> dict:
|
||||
"""Execute the Python code and capture output, plots, and dataframes."""
|
||||
code = state.get("code")
|
||||
logger = get_logger("executor")
|
||||
|
||||
if not code:
|
||||
logger.error("No code provided to executor.")
|
||||
return {"error": "No code provided to executor."}
|
||||
|
||||
logger.info("Executing Python code...")
|
||||
settings = Settings()
|
||||
|
||||
db_settings: "DBSettings" = {
|
||||
"host": settings.db_host,
|
||||
"port": settings.db_port,
|
||||
"user": settings.db_user,
|
||||
"pswd": settings.db_pswd,
|
||||
"db": settings.db_name,
|
||||
"table": settings.db_table
|
||||
}
|
||||
|
||||
db_client = DBClient(settings=db_settings)
|
||||
|
||||
# Initialize local variables for execution
|
||||
# 'db' is the DBClient instance, 'plots' is for matplotlib figures
|
||||
local_vars = {
|
||||
'db': db_client,
|
||||
'plots': [],
|
||||
'pd': pd
|
||||
}
|
||||
|
||||
stdout_buffer = io.StringIO()
|
||||
error = None
|
||||
code_output = ""
|
||||
plots = []
|
||||
dfs = {}
|
||||
|
||||
try:
|
||||
with redirect_stdout(stdout_buffer):
|
||||
# Execute the code in the context of local_vars
|
||||
exec(code, {}, local_vars)
|
||||
|
||||
code_output = stdout_buffer.getvalue()
|
||||
|
||||
# Limit the output length if it's too long
|
||||
if code_output.count('\n') > 32:
|
||||
code_output = '\n'.join(code_output.split('\n')[:32]) + '\n...'
|
||||
|
||||
# Extract plots
|
||||
raw_plots = local_vars.get('plots', [])
|
||||
if isinstance(raw_plots, list):
|
||||
plots = [p for p in raw_plots if isinstance(p, Figure)]
|
||||
|
||||
# Extract DataFrames that were likely intended for display
|
||||
# We look for DataFrames in local_vars that were mentioned in the code
|
||||
for key, value in local_vars.items():
|
||||
if isinstance(value, pd.DataFrame):
|
||||
# Heuristic: if the variable name is in the code, it might be a result DF
|
||||
if key in code:
|
||||
dfs[key] = value
|
||||
|
||||
logger.info(f"[bold green]Execution complete.[/bold green] Captured {len(plots)} plots and {len(dfs)} dataframes.")
|
||||
|
||||
except Exception as e:
|
||||
# Capture the traceback
|
||||
exc_type, exc_value, tb = sys.exc_info()
|
||||
full_traceback = traceback.format_exc()
|
||||
|
||||
# Filter traceback to show only the relevant part (the executed string)
|
||||
filtered_tb_lines = [line for line in full_traceback.split('\n') if '<string>' in line]
|
||||
error = '\n'.join(filtered_tb_lines)
|
||||
if error:
|
||||
error += '\n'
|
||||
error += f"{exc_type.__name__ if exc_type else 'Exception'}: {exc_value}"
|
||||
|
||||
logger.error(f"Execution failed: {str(e)}")
|
||||
|
||||
# If we have an error, we still might want to see partial stdout
|
||||
code_output = stdout_buffer.getvalue()
|
||||
|
||||
return {
|
||||
"code_output": code_output,
|
||||
"error": error,
|
||||
"plots": plots,
|
||||
"dfs": dfs
|
||||
}
|
||||
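For reference, here is the execution contract seen from the generated code's side. This is a hand-written sketch, not real model output; the table and column names in the snippet are invented, and it assumes the database settings are configured:

```python
from ea_chatbot.graph.nodes.executor import executor_node

# Hypothetical generated snippet: 'db', 'plots', and 'pd' are injected via local_vars
snippet = """
import matplotlib.pyplot as plt
df = db.query_df("SELECT county, COUNT(*) AS voters FROM voters GROUP BY county")
print(df.head())  # captured through the redirected stdout buffer
fig, ax = plt.subplots()
ax.bar(df["county"], df["voters"])
plots.append(fig)  # collected by executor_node; no plt.show()
"""
result = executor_node({"code": snippet})
print(result["code_output"], result["error"])
```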
51
src/ea_chatbot/graph/nodes/planner.py
Normal file
@@ -0,0 +1,51 @@
import yaml

from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils import helpers, database_inspection
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.planner import PLANNER_PROMPT
from ea_chatbot.schemas import TaskPlanResponse


def planner_node(state: AgentState) -> dict:
    """Generate a structured plan based on the query analysis."""
    question = state["question"]
    history = state.get("messages", [])[-6:]
    summary = state.get("summary", "")

    settings = Settings()
    logger = get_logger("planner")

    logger.info("Generating task plan...")

    llm = get_llm_model(
        settings.planner_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(TaskPlanResponse)

    date_str = helpers.get_readable_date()

    # Always provide data summary; LLM decides relevance.
    database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."
    example_plan = ""

    messages = PLANNER_PROMPT.format_messages(
        date=date_str,
        question=question,
        history=history,
        summary=summary,
        database_description=database_description,
        example_plan=example_plan
    )

    # Generate the structured plan
    try:
        response = structured_llm.invoke(messages)
        # Convert the structured response back to YAML string for the state
        plan_yaml = yaml.dump(response.model_dump(), sort_keys=False)
        logger.info("[bold green]Plan generated successfully.[/bold green]")
        return {"plan": plan_yaml}
    except Exception as e:
        logger.error(f"Failed to generate plan: {str(e)}")
        raise e
72
src/ea_chatbot/graph/nodes/query_analyzer.py
Normal file
@@ -0,0 +1,72 @@
from typing import List, Literal

from pydantic import BaseModel, Field

from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.query_analyzer import QUERY_ANALYZER_PROMPT


class QueryAnalysis(BaseModel):
    """Analysis of the user's query."""
    data_required: List[str] = Field(description="List of data points or entities mentioned (e.g., ['2024 results', 'Florida']).")
    unknowns: List[str] = Field(description="List of target information the user wants to know or that is needed for the final answer (e.g., 'who won', 'total votes').")
    ambiguities: List[str] = Field(description="List of CRITICAL missing details that prevent ANY analysis. Do NOT include database names or plot types if defaults can be used.")
    conditions: List[str] = Field(description="List of any filters or constraints (e.g., ['year=2024', 'state=Florida']). Include context resolved from history.")
    next_action: Literal["plan", "clarify", "research"] = Field(description="The next action to take. 'plan' for data analysis (even with defaults), 'research' for general knowledge, or 'clarify' ONLY for critical ambiguities.")


def query_analyzer_node(state: AgentState) -> dict:
    """Analyze the user's question and determine the next course of action."""
    question = state["question"]
    history = state.get("messages", [])
    summary = state.get("summary", "")

    # Keep last 3 turns (6 messages)
    history = history[-6:]

    settings = Settings()
    logger = get_logger("query_analyzer")

    logger.info(f"Analyzing question: [italic]\"{question}\"[/italic]")

    # Initialize the LLM with structured output using the factory
    # Pass logging callback to track LLM usage
    llm = get_llm_model(
        settings.query_analyzer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(QueryAnalysis)

    # Prepare messages using the prompt template
    messages = QUERY_ANALYZER_PROMPT.format_messages(
        question=question,
        history=history,
        summary=summary
    )

    try:
        # Invoke the structured LLM directly with the list of messages
        analysis_result = structured_llm.invoke(messages)
        analysis_result = QueryAnalysis.model_validate(analysis_result)

        analysis_dict = analysis_result.model_dump()
        analysis_dict.pop("next_action")
        next_action = analysis_result.next_action

        logger.info(f"Analysis complete. Next action: [bold magenta]{next_action}[/bold magenta]")

    except Exception as e:
        logger.error(f"Error during query analysis: {str(e)}")
        analysis_dict = {
            "data_required": [],
            "unknowns": [],
            "ambiguities": [f"Error during analysis: {str(e)}"],
            "conditions": []
        }
        next_action = "clarify"

    return {
        "analysis": analysis_dict,
        "next_action": next_action
    }
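For orientation, a hand-written example (not actual model output) of the partial state update this node returns for a question like "How many Democrats voted in Bergen County in 2024?":

```python
example_update = {
    "analysis": {
        "data_required": ["vote counts", "party affiliation"],
        "unknowns": ["number of Democratic voters"],
        "ambiguities": [],
        "conditions": ["year=2024", "county=Bergen"],
    },
    "next_action": "plan",  # routed to the planner by the graph's router
}
```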
60
src/ea_chatbot/graph/nodes/researcher.py
Normal file
@@ -0,0 +1,60 @@
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI

from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils import helpers
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.researcher import RESEARCHER_PROMPT


def researcher_node(state: AgentState) -> dict:
    """Handle general research queries or web searches."""
    question = state["question"]
    history = state.get("messages", [])[-6:]
    summary = state.get("summary", "")

    settings = Settings()
    logger = get_logger("researcher")

    logger.info(f"Researching question: [italic]\"{question}\"[/italic]")

    # Use researcher_llm from settings
    llm = get_llm_model(
        settings.researcher_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )

    date_str = helpers.get_readable_date()

    messages = RESEARCHER_PROMPT.format_messages(
        date=date_str,
        question=question,
        history=history,
        summary=summary
    )

    # Provider-aware tool binding
    try:
        if isinstance(llm, ChatGoogleGenerativeAI):
            # Native Google Search for Gemini
            llm_with_tools = llm.bind_tools([{"google_search": {}}])
        elif isinstance(llm, ChatOpenAI):
            # Native Web Search for OpenAI (built-in tool)
            llm_with_tools = llm.bind_tools([{"type": "web_search"}])
        else:
            # Fallback for other providers that might not support these specific search tools
            llm_with_tools = llm
    except Exception as e:
        logger.warning(f"Failed to bind search tools: {str(e)}. Falling back to base LLM.")
        llm_with_tools = llm

    try:
        response = llm_with_tools.invoke(messages)
        logger.info("[bold green]Research complete.[/bold green]")
        return {
            "messages": [response]
        }
    except Exception as e:
        logger.error(f"Research failed: {str(e)}")
        raise e
52
src/ea_chatbot/graph/nodes/summarize_conversation.py
Normal file
@@ -0,0 +1,52 @@
from langchain_core.messages import SystemMessage, HumanMessage

from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler


def summarize_conversation_node(state: AgentState) -> dict:
    """Update the conversation summary based on the latest interaction."""
    summary = state.get("summary", "")
    messages = state.get("messages", [])

    # We only summarize if there are messages
    if not messages:
        return {}

    # Get the last turn (User + Assistant)
    last_turn = messages[-2:]

    settings = Settings()
    logger = get_logger("summarize_conversation")

    logger.info("Updating conversation summary...")

    # Use summarizer_llm for this task as well
    llm = get_llm_model(
        settings.summarizer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )

    if summary:
        prompt = (
            f"This is a summary of the conversation so far: {summary}\n\n"
            "Extend the summary by taking into account the new messages above."
        )
    else:
        prompt = "Create a summary of the conversation above."

    # Construct the messages for the summarization LLM
    summarization_messages = [
        SystemMessage(content=f"Current summary: {summary}" if summary else "You are a helpful assistant that summarizes conversations."),
        HumanMessage(content=f"Recent messages:\n{last_turn}\n\n{prompt}\n\nKeep the summary concise and focused on the key topics and data points discussed.")
    ]

    try:
        response = llm.invoke(summarization_messages)
        new_summary = response.content
        logger.info("[bold green]Conversation summary updated.[/bold green]")
        return {"summary": new_summary}
    except Exception as e:
        logger.error(f"Failed to update summary: {str(e)}")
        # If summarization fails, we keep the old one
        return {"summary": summary}
44
src/ea_chatbot/graph/nodes/summarizer.py
Normal file
@@ -0,0 +1,44 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.summarizer import SUMMARIZER_PROMPT


def summarizer_node(state: AgentState) -> dict:
    """Summarize the code execution results into a final answer."""
    question = state["question"]
    plan = state.get("plan", "")
    code_output = state.get("code_output", "")
    history = state.get("messages", [])[-6:]
    summary = state.get("summary", "")

    settings = Settings()
    logger = get_logger("summarizer")

    logger.info("Generating final summary...")

    llm = get_llm_model(
        settings.summarizer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )

    messages = SUMMARIZER_PROMPT.format_messages(
        question=question,
        plan=plan,
        code_output=code_output,
        history=history,
        summary=summary
    )

    try:
        response = llm.invoke(messages)
        logger.info("[bold green]Summary generated.[/bold green]")

        # Return the final message to be added to the state
        return {
            "messages": [response]
        }
    except Exception as e:
        logger.error(f"Failed to generate summary: {str(e)}")
        raise e
10
src/ea_chatbot/graph/prompts/__init__.py
Normal file
@@ -0,0 +1,10 @@
from .query_analyzer import QUERY_ANALYZER_PROMPT
from .planner import PLANNER_PROMPT
from .coder import CODE_GENERATOR_PROMPT, ERROR_CORRECTOR_PROMPT

__all__ = [
    "QUERY_ANALYZER_PROMPT",
    "PLANNER_PROMPT",
    "CODE_GENERATOR_PROMPT",
    "ERROR_CORRECTOR_PROMPT",
]
64
src/ea_chatbot/graph/prompts/coder.py
Normal file
@@ -0,0 +1,64 @@
from langchain_core.prompts import ChatPromptTemplate


CODE_GENERATOR_SYSTEM = """You are an AI data analyst and your job is to assist users with data analysis and coding tasks.
The user will provide a task and a plan.

**Data Access:**
- A database client is available as a variable named `db`.
- You MUST use `db.query_df(sql_query)` to execute SQL queries and retrieve data as a Pandas DataFrame.
- Do NOT assume a dataframe `df` is already loaded unless explicitly stated. You usually need to query it first.
- The database schema is described in the prompt. Use it to construct valid SQL queries.

**Plotting:**
- If you need to plot any data, use the `plots` list to store the figures.
- Example: `plots.append(fig)` or `plots.append(plt.gcf())`.
- Do not use `plt.show()` as it will render the plot and cause an error.

**Code Requirements:**
- Produce FULL, COMPLETE CODE that includes all steps and solves the task!
- Always include the import statements at the top of the code (e.g., `import pandas as pd`, `import matplotlib.pyplot as plt`).
- Always include print statements to output the results of your code.
- Use `db.query_df("SELECT ...")` to get data."""

CODE_GENERATOR_USER = """TASK:
{question}

PLAN:
```yaml
{plan}
```

AVAILABLE DATA SUMMARY (Database Schema):
{database_description}

CODE EXECUTION OF THE PREVIOUS TASK RESULTED IN:
{code_exec_results}

{example_code}"""

ERROR_CORRECTOR_SYSTEM = """The execution of the code resulted in an error.
Return a complete, corrected python code that incorporates the fixes for the error.

**Reminders:**
- You have access to a database client via the variable `db`.
- Use `db.query_df(sql)` to run queries.
- Use `plots.append(fig)` for plots.
- Always include imports and print statements."""

ERROR_CORRECTOR_USER = """FAILED CODE:
```python
{code}
```

ERROR:
{error}"""

CODE_GENERATOR_PROMPT = ChatPromptTemplate.from_messages([
    ("system", CODE_GENERATOR_SYSTEM),
    ("human", CODE_GENERATOR_USER),
])

ERROR_CORRECTOR_PROMPT = ChatPromptTemplate.from_messages([
    ("system", ERROR_CORRECTOR_SYSTEM),
    ("human", ERROR_CORRECTOR_USER),
])
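Both templates are ordinary ChatPromptTemplate objects, so their variables are bound with format_messages. A quick sketch with placeholder values only:

```python
from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT

msgs = CODE_GENERATOR_PROMPT.format_messages(
    question="Count voters per county.",           # placeholder task
    plan="steps:\n  - 'Step 1: ...'",              # YAML produced by the planner
    database_description="(schema summary here)",  # from database_inspection
    code_exec_results="",                          # empty on the first attempt
    example_code="",
)
# msgs is a [SystemMessage, HumanMessage] pair ready for llm.invoke(msgs)
```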
46
src/ea_chatbot/graph/prompts/planner.py
Normal file
@@ -0,0 +1,46 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder


PLANNER_SYSTEM = """You are an expert planning agent for a data analysis assistant. Your job is to design a clear, step-by-step plan that a coding agent can later turn into Python code.

Today's Date is: {date}"""

PLANNER_USER = """Conversation Summary: {summary}

TASK:
{question}

AVAILABLE DATA SUMMARY (Use only if relevant to the task):
{database_description}

First: Evaluate whether you have all necessary and requested information to provide a solution.
Use the dataset description above to determine what data you have available and in what format.
You are able to search the internet if the user asks for it, or if you require information that you cannot derive from the given dataset or the instructions.

Second: Incorporate any additional relevant context, reasoning, or details from previous interactions or internal chain-of-thought that may impact the solution.
Ensure that all such information is fully included in your response rather than referring to previous answers indirectly.

Third: Reflect on the problem and briefly describe it, while addressing the problem goal, inputs, outputs,
rules, constraints, and other relevant details that appear in the problem description.

Fourth: Based on the preceding steps, formulate your response as an algorithm, breaking the solution into up to eight simple, concise yet descriptive, clear English steps.
You MUST include all values or instructions as described in the above task, or retrieved using internet search!
If fewer steps suffice, that's acceptable. If more are needed, please include them.
Remember to explain steps rather than write code.

This algorithm will be later converted to Python code.
If a dataframe is required, plan to retrieve it with the database client (`db.query_df(...)`) rather than assuming it is already loaded.

There is a list variable called `plots` that you need to use to store any plots you generate. Do not use `plt.show()` as it will render the plot and cause an error.

Output the algorithm as a YAML string. Always enclose the YAML string within ```yaml tags.

**Note: Ensure that any necessary context from prior interactions is fully embedded in the plan. Do not use phrases like "refer to previous answer"; instead, provide complete details inline.**

{example_plan}"""

PLANNER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", PLANNER_SYSTEM),
    MessagesPlaceholder(variable_name="history"),
    ("human", PLANNER_USER),
])
33
src/ea_chatbot/graph/prompts/query_analyzer.py
Normal file
@@ -0,0 +1,33 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder


SYSTEM_PROMPT = """You are an expert election data analyst. Decompose the user's question into key elements to determine the next action.

### Context & Defaults
- **History:** Use the conversation history and summary to resolve coreferences (e.g., "those results", "that state"). Assume the current question inherits missing context (Year, State, County) from history.
- **Data Access:** You have access to voter and election databases. Proceed to planning without asking for database or table names.
- **Downstream Capabilities:** Visualizations are generated as Matplotlib figures. Proceed to planning for "graphs" or "plots" without asking for file formats or plot types.
- **Trends:** For trend requests without a specified interval, allow the Planner to use a sensible default (e.g., by election cycle).

### Instructions:
1. **Analyze:** Identify if the request is for data analysis, general facts (web research), or is critically ambiguous.
2. **Extract Entities & Conditions:**
    - **Data Required:** e.g., "vote count", "demographics".
    - **Conditions:** e.g., "Year=2024". Include context from history.
3. **Identify Target & Critical Ambiguities:**
    - **Unknowns:** The core target question.
    - **Critical Ambiguities:** ONLY list issues that PREVENT any analysis.
    - Examples: No timeframe/geography in query OR history; "track the same voter" without an identity definition.
4. **Determine Action:**
    - `plan`: For data analysis where defaults or history provide sufficient context.
    - `research`: For general knowledge.
    - `clarify`: ONLY for CRITICAL ambiguities."""

USER_PROMPT_TEMPLATE = """Conversation Summary: {summary}

Analyze the following question: {question}"""

QUERY_ANALYZER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_PROMPT),
    MessagesPlaceholder(variable_name="history"),
    ("human", USER_PROMPT_TEMPLATE),
])
12
src/ea_chatbot/graph/prompts/researcher.py
Normal file
@@ -0,0 +1,12 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder


RESEARCHER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", """You are a Research Specialist and your job is to find answers and educate the user.
Provide factual information responding directly to the user's question. Include key details and context to ensure your response comprehensively answers their query.

Today's Date is: {date}"""),
    MessagesPlaceholder(variable_name="history"),
    ("user", """Conversation Summary: {summary}

{question}""")
])
27
src/ea_chatbot/graph/prompts/summarizer.py
Normal file
@@ -0,0 +1,27 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder


SUMMARIZER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", """You are an expert election data analyst providing a final answer to the user.
Use the provided conversation history and summary to ensure your response is contextually relevant and flows naturally from previous turns.

Conversation Summary: {summary}"""),
    MessagesPlaceholder(variable_name="history"),
    ("user", """The user presented you with the following question.
Question: {question}

To address this, you have designed an algorithm.
Algorithm: {plan}

You have crafted Python code based on this algorithm, and the output generated by the code's execution is as follows.
Output: {code_output}

Please produce a comprehensive, easy-to-understand answer that:
1. Summarizes the main insights or conclusions achieved through your method's implementation. Include execution results if necessary.
2. Includes relevant findings from the code execution in a clear format (e.g., text explanation, tables, lists, bullet points).
    - Avoid referencing the code or output as 'the above results' or saying 'it's in the code output.'
    - Instead, present the actual key data or statistics within your explanation.
3. If the user requested specific information that does not appear in the code's output but you can provide it, include that information directly in your summary.
4. Present any data or tables that might have been generated by the code in full, since the user cannot directly see the execution output.

Your goal is to give a final answer that stands on its own without requiring the user to see the code or raw output directly.""")
])
33
src/ea_chatbot/graph/state.py
Normal file
@@ -0,0 +1,33 @@
from typing import TypedDict, Annotated, List, Dict, Any, Optional
from langchain_core.messages import BaseMessage
import operator


class AgentState(TypedDict):
    # Conversation history
    messages: Annotated[List[BaseMessage], operator.add]

    # Task context
    question: str

    # Query Analysis (Decomposition results)
    analysis: Optional[Dict[str, Any]]
    # Expected keys: "data_required", "unknowns", "ambiguities", "conditions"

    # Step-by-step reasoning
    plan: Optional[str]

    # Code execution context
    code: Optional[str]
    code_output: Optional[str]
    error: Optional[str]

    # Artifacts (for UI display)
    plots: Annotated[List[Any], operator.add]  # Matplotlib figures
    dfs: Dict[str, Any]  # Pandas DataFrames

    # Conversation summary
    summary: Optional[str]

    # Routing hint: "clarify", "plan", "research", "end"
    next_action: str
87
src/ea_chatbot/graph/workflow.py
Normal file
@@ -0,0 +1,87 @@
from langgraph.graph import StateGraph, END

from ea_chatbot.graph.state import AgentState
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node
from ea_chatbot.graph.nodes.planner import planner_node
from ea_chatbot.graph.nodes.coder import coder_node
from ea_chatbot.graph.nodes.error_corrector import error_corrector_node
from ea_chatbot.graph.nodes.executor import executor_node
from ea_chatbot.graph.nodes.summarizer import summarizer_node
from ea_chatbot.graph.nodes.researcher import researcher_node
from ea_chatbot.graph.nodes.clarification import clarification_node
from ea_chatbot.graph.nodes.summarize_conversation import summarize_conversation_node


def router(state: AgentState) -> str:
    """Route to the next node based on the analysis."""
    next_action = state.get("next_action")
    if next_action == "plan":
        return "planner"
    elif next_action == "research":
        return "researcher"
    elif next_action == "clarify":
        return "clarification"
    else:
        return END


def create_workflow():
    """Create the LangGraph workflow."""
    workflow = StateGraph(AgentState)

    # Add nodes
    workflow.add_node("query_analyzer", query_analyzer_node)
    workflow.add_node("planner", planner_node)
    workflow.add_node("coder", coder_node)
    workflow.add_node("error_corrector", error_corrector_node)
    workflow.add_node("researcher", researcher_node)
    workflow.add_node("clarification", clarification_node)
    workflow.add_node("executor", executor_node)
    workflow.add_node("summarizer", summarizer_node)
    workflow.add_node("summarize_conversation", summarize_conversation_node)

    # Set entry point
    workflow.set_entry_point("query_analyzer")

    # Add conditional edges from query_analyzer
    workflow.add_conditional_edges(
        "query_analyzer",
        router,
        {
            "planner": "planner",
            "researcher": "researcher",
            "clarification": "clarification",
            END: END
        }
    )

    # Linear flow for planning and coding
    workflow.add_edge("planner", "coder")
    workflow.add_edge("coder", "executor")

    # Executor routing
    def executor_router(state: AgentState) -> str:
        if state.get("error"):
            return "error_corrector"
        return "summarizer"

    workflow.add_conditional_edges(
        "executor",
        executor_router,
        {
            "error_corrector": "error_corrector",
            "summarizer": "summarizer"
        }
    )

    workflow.add_edge("error_corrector", "executor")

    workflow.add_edge("researcher", "summarize_conversation")
    workflow.add_edge("clarification", END)
    workflow.add_edge("summarizer", "summarize_conversation")
    workflow.add_edge("summarize_conversation", END)

    # Compile the graph
    app = workflow.compile()

    return app


# Initialize the app
app = create_workflow()
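A minimal invocation sketch, assuming the environment configuration is in place; the question text and the seeded state values are placeholders:

```python
from langchain_core.messages import HumanMessage
from ea_chatbot.graph.workflow import app

question = "Which county had the highest turnout?"  # placeholder question
initial_state = {
    "messages": [HumanMessage(content=question)],
    "question": question,
    "plots": [],
    "dfs": {},
    "next_action": "",
}
final_state = app.invoke(initial_state)
print(final_state["messages"][-1].content)  # the summarizer's (or researcher's) answer
```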
0
src/ea_chatbot/history/__init__.py
Normal file
183
src/ea_chatbot/history/manager.py
Normal file
@@ -0,0 +1,183 @@
from contextlib import contextmanager
from typing import Optional, List

from sqlalchemy import create_engine, select
from sqlalchemy.orm import sessionmaker
from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError

from ea_chatbot.history.models import User, Conversation, Message, Plot

# Argon2 Password Hasher
ph = PasswordHasher()


class HistoryManager:
    """Manages database sessions and operations for history and user data."""

    def __init__(self, db_url: str):
        self.engine = create_engine(db_url)
        # expire_on_commit=False is important so we can use objects after session closes
        self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine, expire_on_commit=False)

    @contextmanager
    def get_session(self):
        """Context manager for database sessions."""
        session = self.SessionLocal()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    # --- User Management ---

    def get_user(self, email: str) -> Optional[User]:
        """Fetch a user by their email (username)."""
        with self.get_session() as session:
            result = session.execute(select(User).where(User.username == email))
            return result.scalar_one_or_none()

    def create_user(self, email: str, password: Optional[str] = None, display_name: Optional[str] = None) -> User:
        """Create a new local user."""
        hashed_password = ph.hash(password) if password else None
        user = User(
            username=email,
            password_hash=hashed_password,
            display_name=display_name or email.split("@")[0]
        )
        with self.get_session() as session:
            session.add(user)
            session.commit()
            session.refresh(user)
            return user

    def authenticate_user(self, email: str, password: str) -> Optional[User]:
        """Authenticate a user by email and password."""
        user = self.get_user(email)
        if not user or not user.password_hash:
            return None

        try:
            ph.verify(user.password_hash, password)
            return user
        except VerifyMismatchError:
            return None

    def sync_user_from_oidc(self, email: str, display_name: Optional[str] = None) -> User:
        """
        Synchronize a user from an OIDC provider.
        If a user with the same email exists, update their display name.
        Otherwise, create a new user.
        """
        user = self.get_user(email)
        if user:
            # Update existing user if needed
            if display_name and user.display_name != display_name:
                with self.get_session() as session:
                    db_user = session.get(User, user.id)
                    db_user.display_name = display_name
                    session.commit()
                    session.refresh(db_user)
                    return db_user
            return user
        else:
            # Create new user (no password for OIDC users initially)
            return self.create_user(email=email, display_name=display_name)

    # --- Conversation Management ---

    def create_conversation(self, user_id: str, data_state: str, name: str, summary: Optional[str] = None) -> Conversation:
        """Create a new conversation for a user."""
        conv = Conversation(
            user_id=user_id,
            data_state=data_state,
            name=name,
            summary=summary
        )
        with self.get_session() as session:
            session.add(conv)
            session.commit()
            session.refresh(conv)
            return conv

    def get_conversations(self, user_id: str, data_state: str) -> List[Conversation]:
        """Get all conversations for a user and data state, ordered by creation time."""
        with self.get_session() as session:
            stmt = (
                select(Conversation)
                .where(Conversation.user_id == user_id, Conversation.data_state == data_state)
                .order_by(Conversation.created_at.desc())
            )
            result = session.execute(stmt)
            return list(result.scalars().all())

    def rename_conversation(self, conversation_id: str, new_name: str) -> Optional[Conversation]:
        """Rename an existing conversation."""
        with self.get_session() as session:
            conv = session.get(Conversation, conversation_id)
            if conv:
                conv.name = new_name
                session.commit()
                session.refresh(conv)
            return conv

    def delete_conversation(self, conversation_id: str) -> bool:
        """Delete a conversation and its associated messages/plots (via cascade)."""
        with self.get_session() as session:
            conv = session.get(Conversation, conversation_id)
            if conv:
                session.delete(conv)
                session.commit()
                return True
            return False

    def update_conversation_summary(self, conversation_id: str, summary: str) -> Optional[Conversation]:
        """Update the summary of a conversation."""
        with self.get_session() as session:
            conv = session.get(Conversation, conversation_id)
            if conv:
                conv.summary = summary
                session.commit()
                session.refresh(conv)
            return conv

    # --- Message & Plot Management ---

    def add_message(self, conversation_id: str, role: str, content: str, plots: Optional[List[bytes]] = None) -> Message:
        """Add a message to a conversation, optionally with plots."""
        msg = Message(
            conversation_id=conversation_id,
            role=role,
            content=content
        )
        with self.get_session() as session:
            session.add(msg)
            session.flush()  # Populate msg.id for plots

            if plots:
                for plot_data in plots:
                    plot = Plot(message_id=msg.id, image_data=plot_data)
                    session.add(plot)

            session.commit()
            session.refresh(msg)
            # Ensure plots are loaded before session closes if we need them
            _ = msg.plots
            return msg

    def get_messages(self, conversation_id: str) -> List[Message]:
        """Get all messages for a conversation, ordered by creation time."""
        with self.get_session() as session:
            stmt = (
                select(Message)
                .where(Message.conversation_id == conversation_id)
                .order_by(Message.created_at.asc())
            )
            result = session.execute(stmt)
            messages = list(result.scalars().all())
            # Pre-load plots for each message
            for m in messages:
                _ = m.plots
            return messages
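A usage sketch, assuming a reachable Postgres instance; the URL, credentials, and data_state value are placeholders:

```python
from ea_chatbot.history.manager import HistoryManager

hm = HistoryManager("postgresql://user:password@localhost:5433/ea_history")  # placeholder URL
user = hm.create_user("alice@example.com", password="s3cret")
conv = hm.create_conversation(user.id, data_state="new_jersey", name="First chat")
hm.add_message(conv.id, role="user", content="Hello")
assert hm.authenticate_user("alice@example.com", "s3cret") is not None
```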
52
src/ea_chatbot/history/models.py
Normal file
@@ -0,0 +1,52 @@
from typing import List, Optional
from datetime import datetime, timezone
import uuid

from sqlalchemy import String, ForeignKey, DateTime, LargeBinary, Text
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "users"

    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    username: Mapped[str] = mapped_column(String, unique=True, index=True)
    password_hash: Mapped[Optional[str]] = mapped_column(String, nullable=True)
    display_name: Mapped[Optional[str]] = mapped_column(String, nullable=True)

    conversations: Mapped[List["Conversation"]] = relationship(back_populates="user", cascade="all, delete-orphan")


class Conversation(Base):
    __tablename__ = "conversations"

    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    user_id: Mapped[str] = mapped_column(ForeignKey("users.id"))
    data_state: Mapped[str] = mapped_column(String)
    name: Mapped[str] = mapped_column(String)
    summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc))

    user: Mapped["User"] = relationship(back_populates="conversations")
    messages: Mapped[List["Message"]] = relationship(back_populates="conversation", cascade="all, delete-orphan")


class Message(Base):
    __tablename__ = "messages"

    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    conversation_id: Mapped[str] = mapped_column(ForeignKey("conversations.id"))
    role: Mapped[str] = mapped_column(String)
    content: Mapped[str] = mapped_column(Text)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc))

    conversation: Mapped["Conversation"] = relationship(back_populates="messages")
    plots: Mapped[List["Plot"]] = relationship(back_populates="message", cascade="all, delete-orphan")


class Plot(Base):
    __tablename__ = "plots"

    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    message_id: Mapped[str] = mapped_column(ForeignKey("messages.id"))
    image_data: Mapped[bytes] = mapped_column(LargeBinary)

    message: Mapped["Message"] = relationship(back_populates="plots")
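For local development the tables can be created straight from the metadata; a one-off sketch (a migration tool such as Alembic would be the more durable choice, and the URL is a placeholder):

```python
from sqlalchemy import create_engine
from ea_chatbot.history.models import Base

engine = create_engine("postgresql://user:password@localhost:5433/ea_history")  # placeholder URL
Base.metadata.create_all(engine)  # creates users, conversations, messages, plots
```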
83
src/ea_chatbot/schemas.py
Normal file
@@ -0,0 +1,83 @@
from pydantic import BaseModel, Field, computed_field
from typing import Sequence, Optional
import re


class TaskPlanContext(BaseModel):
    '''Background context relevant to the task plan'''
    initial_context: str = Field(
        min_length=1,
        description="Background information about the database/tables and previous conversations relevant to the task.",
    )
    assumptions: Sequence[str] = Field(
        description="Assumptions made while working on the task.",
    )
    constraints: Optional[Sequence[str]] = Field(
        description="Constraints that apply to the task.",
    )


class TaskPlanResponse(BaseModel):
    '''Structured plan to achieve the task objective'''
    goal: str = Field(
        min_length=1,
        description="Single-sentence objective the plan must achieve.",
    )
    reflection: str = Field(
        min_length=1,
        description="High-level natural-language reasoning describing the user's request and the intended solution approach.",
    )
    context: TaskPlanContext = Field(
        description="Background context relevant to the task plan.",
    )
    steps: Sequence[str] = Field(
        min_length=1,
        description="Ordered list of steps to execute that follow the 'Step <number>: <detail>' pattern.",
    )


_IM_SEP_TOKEN_PATTERN = re.compile(re.escape("<|im_sep|>"))
_CODE_BLOCK_PATTERN = re.compile(r"```(?:python\s*)?(.*?)\s*```", re.DOTALL)
_FORBIDDEN_MODULES = (
    "subprocess",
    "sys",
    "eval",
    "exec",
    "socket",
    "urllib",
    "shutil",
    "pickle",
    "ctypes",
    "multiprocessing",
    "tempfile",
    "glob",
    "pty",
    "commands",
    "cgi",
    "cgitb",
    "xml.etree.ElementTree",
    "builtins",
)
_FORBIDDEN_MODULE_PATTERN = re.compile(
    r"^((?:[^#].*)?\b(" + "|".join(map(re.escape, _FORBIDDEN_MODULES)) + r")\b.*)$",
    flags=re.MULTILINE,
)


class CodeGenerationResponse(BaseModel):
    '''Code generation response structure'''
    code: str = Field(description="The generated code snippet to accomplish the task")
    explanation: str = Field(description="Explanation of the generated code and its functionality")

    @computed_field(return_type=str)
    @property
    def parsed_code(self) -> str:
        '''Extracts the code snippet without any surrounding text'''
        normalised = _IM_SEP_TOKEN_PATTERN.sub("```", self.code).strip()
        match = _CODE_BLOCK_PATTERN.search(normalised)
        candidate = match.group(1).strip() if match else normalised
        sanitised = _FORBIDDEN_MODULE_PATTERN.sub(r"# not allowed \1", candidate)
        return sanitised.strip()


class RankResponse(BaseModel):
    '''Code ranking response structure'''
    rank: int = Field(
        ge=1, le=10,
        description="Rank of the code snippet from 1 (best) to 10 (worst)"
    )
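To illustrate the sanitisation pipeline end to end (fence extraction followed by the forbidden-module filter), a small hand-run example; the snippet content is invented:

```python
from ea_chatbot.schemas import CodeGenerationResponse

resp = CodeGenerationResponse(
    code="Here is the code:\n```python\nimport sys\nprint('hi')\n```",
    explanation="prints a greeting",
)
print(resp.parsed_code)
# -> "# not allowed import sys" on the first line, "print('hi')" on the second:
#    the fence is stripped and the line touching a forbidden module is commented out
```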
24
src/ea_chatbot/types.py
Normal file
@@ -0,0 +1,24 @@
from typing import TypedDict, Optional
from enum import StrEnum


class DBSettings(TypedDict):
    host: str
    port: int
    user: str
    pswd: str
    db: str
    table: Optional[str]


class Agent(StrEnum):
    EXPERT_SELECTOR = "Expert Selector"
    ANALYST_SELECTOR = "Analyst Selector"
    THEORIST = "Theorist"
    THEORIST_WEB = "Theorist-Web"
    THEORIST_CLARIFICATION = "Theorist-Clarification"
    PLANNER = "Planner"
    CODE_GENERATOR = "Code Generator"
    CODE_DEBUGGER = "Code Debugger"
    CODE_EXECUTOR = "Code Executor"
    ERROR_CORRECTOR = "Error Corrector"
    CODE_RANKER = "Code Ranker"
    SOLUTION_SUMMARIZER = "Solution Summarizer"
12
src/ea_chatbot/utils/__init__.py
Normal file
@@ -0,0 +1,12 @@
from .db_client import DBClient
from .llm_factory import get_llm_model
from .logging import get_logger, LangChainLoggingHandler
from . import helpers

__all__ = [
    "DBClient",
    "get_llm_model",
    "get_logger",
    "LangChainLoggingHandler",
    "helpers"
]
234
src/ea_chatbot/utils/database_inspection.py
Normal file
@@ -0,0 +1,234 @@
from typing import Optional, Dict, Any, TYPE_CHECKING

import yaml
import os

from ea_chatbot.utils.db_client import DBClient

if TYPE_CHECKING:
    from ea_chatbot.types import DBSettings


def _get_table_checksum(db_client: DBClient, table: str) -> str:
    """Calculates the checksum of the table using DML statistics from pg_stat_user_tables."""
    query = f"""
    SELECT md5(concat_ws('|', n_tup_ins, n_tup_upd, n_tup_del)) AS dml_hash
    FROM pg_stat_user_tables
    WHERE schemaname = 'public' AND relname = '{table}';"""
    try:
        return str(db_client.query_df(query).iloc[0, 0])
    except Exception:
        return "unknown_checksum"


def _update_checksum_file(filepath: str, table: str, checksum: str):
    """Updates the checksum file with the new checksum for the table."""
    checksums = {}
    if os.path.exists(filepath):
        with open(filepath, 'r') as f:
            for line in f:
                if ':' in line:
                    k, v = line.strip().split(':', 1)
                    checksums[k] = v

    checksums[table] = checksum

    with open(filepath, 'w') as f:
        for k, v in checksums.items():
            # Newline-terminated so the line-based reader above can parse multiple entries
            f.write(f"{k}:{v}\n")


def get_data_summary(data_dir: str = "data") -> Optional[str]:
    """
    Reads the inspection.yaml file and returns its content as a string.
    """
    inspection_file = os.path.join(data_dir, "inspection.yaml")
    if os.path.exists(inspection_file):
        with open(inspection_file, 'r') as f:
            return f.read()
    return None


def get_primary_key(db_client: DBClient, table_name: str) -> Optional[str]:
    """
    Dynamically identifies the primary key of the table.
    Returns the column name of the primary key, or None if not found.
    """
    query = f"""
    SELECT kcu.column_name
    FROM information_schema.key_column_usage AS kcu
    JOIN information_schema.table_constraints AS tc
        ON kcu.constraint_name = tc.constraint_name
        AND kcu.table_schema = tc.table_schema
    WHERE kcu.table_name = '{table_name}'
        AND tc.constraint_type = 'PRIMARY KEY'
    LIMIT 1;
    """
    try:
        df = db_client.query_df(query)
        if not df.empty:
            return str(df.iloc[0, 0])
    except Exception as e:
        print(f"Warning: Could not determine primary key for {table_name}: {e}")
    return None


def inspect_db_table(
    db_client: Optional[DBClient] = None,
    db_settings: Optional["DBSettings"] = None,
    data_dir: str = "data",
    force_update: bool = False
) -> Optional[str]:
    """
    Inspects the database table, generates statistics for each column,
    and saves the inspection results to a YAML file locally.

    Improvements:
    - Dynamic Primary Key Discovery
    - Cardinality (Unique Counts)
    - Categorical Sample Values for low cardinality columns
    - Robust Quoting
    """
    inspection_file = os.path.join(data_dir, "inspection.yaml")
    checksum_file = os.path.join(data_dir, "checksum")

    # Initialize DB Client
    if db_client is None:
        if db_settings is None:
            print("Error: Either db_client or db_settings must be provided.")
            return None
        try:
            db_client = DBClient(db_settings)
        except Exception as e:
            print(f"Failed to create DBClient: {e}")
            return None

    table_name = db_client.settings.get('table')
    if not table_name:
        print("Error: Table name must be specified in DBSettings.")
        return None

    os.makedirs(data_dir, exist_ok=True)

    # Checksum verification
    new_checksum = _get_table_checksum(db_client, table_name)
    has_changed = True

    if os.path.exists(checksum_file):
        try:
            with open(checksum_file, 'r') as f:
                saved_checksums = f.read().strip()
            if f"{table_name}:{new_checksum}" in saved_checksums:
                has_changed = False
        except Exception:
            pass  # Force update on read error

    if not has_changed and not force_update:
        return get_data_summary(data_dir)

    print(f"Regenerating inspection file for table '{table_name}'...")

    # Fetch Table Metadata
    try:
        # Get columns and types
        columns_query = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';"
        columns_df = db_client.query_df(columns_query)

        # Get Row Counts
        total_rows_df = db_client.query_df(f'SELECT COUNT(*) FROM "{table_name}"')
        total_rows = int(total_rows_df.iloc[0, 0])

        # Dynamic Primary Key
        primary_key = get_primary_key(db_client, table_name)

        # Get First/Last Rows (if PK exists)
        first_row_df = None
        last_row_df = None
        if primary_key:
            first_row_df = db_client.query_df(f'SELECT * FROM "{table_name}" ORDER BY "{primary_key}" ASC LIMIT 1')
            last_row_df = db_client.query_df(f'SELECT * FROM "{table_name}" ORDER BY "{primary_key}" DESC LIMIT 1')

    except Exception as e:
        print(f"Failed to retrieve basic table info: {e}")
        return None

    stats_dict: Dict[str, Any] = {}
    if primary_key:
        stats_dict['primary_key'] = primary_key

    for _, row in columns_df.iterrows():
        col_name = row['column_name']
        dtype = row['data_type']

        try:
            # Count Values (using robust quoting)
            count_df = db_client.query_df(f'SELECT COUNT("{col_name}") FROM "{table_name}"')
            count_val = int(count_df.iloc[0, 0])

            # Count Unique (Cardinality)
            unique_df = db_client.query_df(f'SELECT COUNT(DISTINCT "{col_name}") FROM "{table_name}"')
            unique_count = int(unique_df.iloc[0, 0])

            col_stats: Dict[str, Any] = {
                'dtype': dtype,
                'count_of_values': count_val,
                'count_of_nulls': total_rows - count_val,
                'unique_count': unique_count
            }

            if count_val == 0:
                stats_dict[col_name] = col_stats
                continue

            # Numerical Stats
            if any(t in dtype for t in ('int', 'float', 'numeric', 'double', 'real', 'decimal')):
                stats_query = f'SELECT AVG("{col_name}"), MIN("{col_name}"), MAX("{col_name}") FROM "{table_name}"'
                stats_df = db_client.query_df(stats_query)
                if not stats_df.empty:
                    col_stats['mean'] = float(stats_df.iloc[0, 0]) if stats_df.iloc[0, 0] is not None else None
                    col_stats['min'] = float(stats_df.iloc[0, 1]) if stats_df.iloc[0, 1] is not None else None
                    col_stats['max'] = float(stats_df.iloc[0, 2]) if stats_df.iloc[0, 2] is not None else None

            # Temporal Stats
            elif any(t in dtype for t in ('date', 'timestamp')):
                stats_query = f'SELECT MIN("{col_name}"), MAX("{col_name}") FROM "{table_name}"'
                stats_df = db_client.query_df(stats_query)
                if not stats_df.empty:
                    col_stats['min'] = str(stats_df.iloc[0, 0])
                    col_stats['max'] = str(stats_df.iloc[0, 1])

            # Categorical/Text Stats
            else:
                # Sample values if cardinality is low (< 20)
                if 0 < unique_count < 20:
                    distinct_query = f'SELECT DISTINCT "{col_name}" FROM "{table_name}" ORDER BY "{col_name}" LIMIT 20'
                    distinct_df = db_client.query_df(distinct_query)
                    col_stats['distinct_values'] = distinct_df.iloc[:, 0].tolist()

            if first_row_df is not None and not first_row_df.empty and col_name in first_row_df.columns:
                col_stats['first_value'] = str(first_row_df.iloc[0][col_name])
            if last_row_df is not None and not last_row_df.empty and col_name in last_row_df.columns:
                col_stats['last_value'] = str(last_row_df.iloc[0][col_name])

            stats_dict[col_name] = col_stats

        except Exception as e:
            print(f"Warning: Could not process column {col_name}: {e}")

    # Load existing inspections to merge (if multiple tables)
    existing_inspections = {}
    if os.path.exists(inspection_file):
        try:
            with open(inspection_file, 'r') as f:
                existing_inspections = yaml.safe_load(f) or {}
            # Backup old file
            os.rename(inspection_file, inspection_file + ".old")
        except Exception:
            pass

    existing_inspections[table_name] = stats_dict

    # Save new inspection
    inspection_content = yaml.dump(existing_inspections, sort_keys=False, default_flow_style=False)
    with open(inspection_file, 'w') as f:
        f.write(inspection_content)

    # Update Checksum
    _update_checksum_file(checksum_file, table_name, new_checksum)

    print(f"Inspection saved to {inspection_file}")
    return inspection_content
21
src/ea_chatbot/utils/db_client.py
Normal file
@@ -0,0 +1,21 @@
from typing import Optional, TYPE_CHECKING

import pandas as pd
from sqlalchemy import create_engine, text

if TYPE_CHECKING:
    from ea_chatbot.types import DBSettings


class DBClient:
    def __init__(self, settings: "DBSettings"):
        self.settings = settings
        self._engine = self._create_engine()

    def _create_engine(self):
        url = f"postgresql://{self.settings['user']}:{self.settings['pswd']}@{self.settings['host']}:{self.settings['port']}/{self.settings['db']}"
        return create_engine(url)

    def query_df(self, sql: str, params: Optional[dict] = None) -> pd.DataFrame:
        with self._engine.connect() as conn:
            result = conn.execute(text(sql), params or {})
            df = pd.DataFrame(result.fetchall(), columns=result.keys())
        return df
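Because query_df wraps the SQL in SQLAlchemy's text(), named bind parameters (:name) are supported. A sketch with placeholder connection settings and an invented table:

```python
from ea_chatbot.utils.db_client import DBClient

db = DBClient(settings={
    "host": "localhost", "port": 5432, "user": "user",
    "pswd": "password", "db": "mydb", "table": None,  # placeholder settings
})
df = db.query_df(
    "SELECT * FROM voters WHERE county = :county LIMIT 5",  # 'voters' is hypothetical
    params={"county": "Bergen"},
)
```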
73
src/ea_chatbot/utils/helpers.py
Normal file
@@ -0,0 +1,73 @@
from typing import Optional, TYPE_CHECKING, Dict, Any
from datetime import datetime, timezone
import yaml
import json

if TYPE_CHECKING:
    from ea_chatbot.graph.state import AgentState


def ordinal(n: int) -> str:
    return f"{n}{'th' if 11 <= n <= 13 else {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')}"


def get_readable_date(date_obj: Optional[datetime] = None, tz: Optional[timezone] = None) -> str:
    if date_obj is None:
        date_obj = datetime.now(timezone.utc)
    if tz:
        date_obj = date_obj.astimezone(tz)
    return date_obj.strftime(f"%a {ordinal(date_obj.day)} of %b %Y")


def to_yaml(json_str: str, indent: int = 2) -> str:
    """
    Attempts to convert a JSON string (potentially malformed from LLM) to a YAML string.
    """
    if not json_str:
        return ""

    try:
        # Try direct parse
        data = json.loads(json_str)
    except json.JSONDecodeError:
        # Try simplified repair: replace single quotes
        try:
            cleaned = json_str.replace("'", '"')
            data = json.loads(cleaned)
        except Exception:
            # Fallback: return raw string if unparseable
            return json_str

    return yaml.dump(data, indent=indent, sort_keys=False)


def merge_agent_state(current_state: "AgentState", update: Dict[str, Any]) -> "AgentState":
    """
    Merges a partial state update into the current state, mimicking LangGraph reduction logic.
    - Lists (messages, plots) are appended.
    - Dictionaries (dfs) are shallow merged.
    - Other fields are overwritten.
    """
    new_state = current_state.copy()

    for key, value in update.items():
        if value is None:
            new_state[key] = None
            continue

        # Accumulate lists (messages, plots)
        if key in ["messages", "plots"] and isinstance(value, list):
            current_list = new_state.get(key, [])
            if not isinstance(current_list, list):
                current_list = []
            new_state[key] = current_list + value

        # Shallow merge dictionaries (dfs)
        elif key == "dfs" and isinstance(value, dict):
            current_dict = new_state.get(key, {})
            if not isinstance(current_dict, dict):
                current_dict = {}
            merged_dict = current_dict.copy()
            merged_dict.update(value)
            new_state[key] = merged_dict

        # Overwrite everything else
        else:
            new_state[key] = value

    return new_state
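A quick illustration of the reduction behaviour of merge_agent_state; the values are stand-ins, not real message or figure objects:

```python
from ea_chatbot.utils.helpers import merge_agent_state

state = {"messages": [], "plots": [], "dfs": {}, "question": "q"}
update = {"plots": ["fig-placeholder"], "dfs": {"result": "df-placeholder"}, "error": None}

merged = merge_agent_state(state, update)
# plots is appended, dfs is shallow-merged, error is overwritten (here: set to None)
assert merged["plots"] == ["fig-placeholder"] and merged["dfs"] == {"result": "df-placeholder"}
```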
36
src/ea_chatbot/utils/llm_factory.py
Normal file
@@ -0,0 +1,36 @@
from typing import Optional, cast, TYPE_CHECKING, Literal, Dict, List, Tuple, Any
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.callbacks import BaseCallbackHandler
from ea_chatbot.config import LLMConfig


def get_llm_model(config: LLMConfig, callbacks: Optional[List[BaseCallbackHandler]] = None) -> BaseChatModel:
    """
    Factory function to get a LangChain chat model based on configuration.

    Args:
        config: LLMConfig object containing model settings.
        callbacks: Optional list of LangChain callback handlers.

    Returns:
        Initialized BaseChatModel instance.

    Raises:
        ValueError: If the provider is not supported.
    """
    params = {
        "temperature": config.temperature,
        "max_tokens": config.max_tokens,
        **config.provider_specific
    }

    # Filter out None values to allow defaults to take over if not specified
    params = {k: v for k, v in params.items() if v is not None}

    provider = config.provider.lower()
    if provider == "openai":
        return ChatOpenAI(model=config.model, callbacks=callbacks, **params)
    elif provider in ("google", "google_genai"):
        return ChatGoogleGenerativeAI(model=config.model, callbacks=callbacks, **params)
    else:
        raise ValueError(f"Unsupported LLM provider: {config.provider}")
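A minimal usage sketch, assuming `OPENAI_API_KEY` is set in the environment; from here any standard LangChain chat-model call applies:

```python
from ea_chatbot.config import LLMConfig
from ea_chatbot.utils.llm_factory import get_llm_model

config = LLMConfig(provider="openai", model="gpt-4o", temperature=0.2)
llm = get_llm_model(config)
reply = llm.invoke("In one sentence, what is voter turnout?")
print(reply.content)
```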
141
src/ea_chatbot/utils/logging.py
Normal file
141
src/ea_chatbot/utils/logging.py
Normal file
@@ -0,0 +1,141 @@
import json
import os
import logging
from rich.logging import RichHandler
from langchain_core.callbacks import BaseCallbackHandler
from logging.handlers import RotatingFileHandler
from typing import Any, Optional, Dict, List


class LangChainLoggingHandler(BaseCallbackHandler):
    """Callback handler for logging LangChain events."""

    def __init__(self, logger: Optional[logging.Logger] = None):
        self.logger = logger or get_logger("langchain")

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
        # Serialized might be empty or missing the name depending on how it's called
        model_name = serialized.get("name") or kwargs.get("name") or "LLM"
        self.logger.info(f"[bold blue]LLM Started:[/bold blue] {model_name}")

    def on_llm_end(self, response: Any, **kwargs: Any) -> Any:
        llm_output = getattr(response, "llm_output", {}) or {}
        # Try to find the model name in the output, or use a fallback
        model_name = llm_output.get("model_name") or "LLM"
        token_usage = llm_output.get("token_usage", {})

        msg = f"[bold green]LLM Ended:[/bold green] {model_name}"
        if token_usage:
            prompt = token_usage.get("prompt_tokens", 0)
            completion = token_usage.get("completion_tokens", 0)
            total = token_usage.get("total_tokens", 0)
            msg += f" | [yellow]Tokens: {total}[/yellow] ({prompt} prompt, {completion} completion)"

        self.logger.info(msg)

    def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> Any:
        self.logger.error(f"[bold red]LLM Error:[/bold red] {str(error)}")


class ContextLoggerAdapter(logging.LoggerAdapter):
    """Adapter to inject contextual metadata into log records."""

    def process(self, msg: Any, kwargs: Any) -> tuple[Any, Any]:
        extra = self.extra.copy()
        if "extra" in kwargs:
            extra.update(kwargs.pop("extra"))
        kwargs["extra"] = extra
        return msg, kwargs


class FlexibleJSONEncoder(json.JSONEncoder):
    def default(self, obj: Any) -> Any:
        if hasattr(obj, 'model_dump'):  # Pydantic v2
            return obj.model_dump()
        if hasattr(obj, 'dict'):  # Pydantic v1
            return obj.dict()
        if hasattr(obj, '__dict__'):
            return self.serialize_custom_object(obj)
        # json.dumps handles dicts and lists natively, so these branches only
        # guard direct recursive calls; JSON-native values pass through as-is
        # (calling default() on them would raise TypeError).
        elif isinstance(obj, dict):
            return {k: v if isinstance(v, (str, int, float, bool, type(None))) else self.default(v)
                    for k, v in obj.items()}
        elif isinstance(obj, list):
            return [item if isinstance(item, (str, int, float, bool, type(None))) else self.default(item)
                    for item in obj]
        return super().default(obj)

    def serialize_custom_object(self, obj: Any) -> dict:
        obj_dict = obj.__dict__.copy()
        obj_dict['__custom_class__'] = obj.__class__.__name__
        return obj_dict


class JsonFormatter(logging.Formatter):
    """Custom JSON formatter for structured logging."""

    def format(self, record: logging.LogRecord) -> str:
        # Standard fields
        log_record = {
            "timestamp": self.formatTime(record, self.datefmt),
            "level": record.levelname,
            "message": record.getMessage(),
            "module": record.module,
            "name": record.name,
        }

        # Add exception info if present
        if record.exc_info:
            log_record["exception"] = self.formatException(record.exc_info)

        # Add all other extra fields from the record,
        # filtering out standard logging attributes
        standard_attrs = {
            'args', 'asctime', 'created', 'exc_info', 'exc_text', 'filename',
            'funcName', 'levelname', 'levelno', 'lineno', 'module',
            'msecs', 'message', 'msg', 'name', 'pathname', 'process',
            'processName', 'relativeCreated', 'stack_info', 'thread', 'threadName'
        }
        for key, value in record.__dict__.items():
            if key not in standard_attrs:
                log_record[key] = value

        return json.dumps(log_record, cls=FlexibleJSONEncoder)


def get_logger(name: str = "ea_chatbot", level: Optional[str] = None, log_file: Optional[str] = None) -> logging.Logger:
    """Get a configured logger with a RichHandler and an optional JSON file handler."""
    # Ensure the name starts with ea_chatbot for hierarchy, if it doesn't already
    if name != "ea_chatbot" and not name.startswith("ea_chatbot."):
        full_name = f"ea_chatbot.{name}"
    else:
        full_name = name

    logger = logging.getLogger(full_name)

    # Configure the root ea_chatbot logger if it hasn't been configured
    root_logger = logging.getLogger("ea_chatbot")
    if not root_logger.handlers:
        # Default to INFO if level not provided
        log_level = getattr(logging, (level or "INFO").upper(), logging.INFO)
        root_logger.setLevel(log_level)

        # Console handler (Rich)
        rich_handler = RichHandler(
            rich_tracebacks=True,
            markup=True,
            show_time=False,
            show_path=False
        )
        root_logger.addHandler(rich_handler)
        root_logger.propagate = False

    # Always check whether a FileHandler is needed, even if the root is already configured
    if log_file:
        existing_file_handlers = [h for h in root_logger.handlers if isinstance(h, RotatingFileHandler)]
        if not existing_file_handlers:
            log_dir = os.path.dirname(log_file)
            if log_dir:  # guard against bare filenames (makedirs("") raises)
                os.makedirs(log_dir, exist_ok=True)
            file_handler = RotatingFileHandler(
                log_file, maxBytes=5*1024*1024, backupCount=3
            )
            file_handler.setFormatter(JsonFormatter())
            root_logger.addHandler(file_handler)

    # Refresh the logger object in case it was created before the root was configured
    logger = logging.getLogger(full_name)

    # If a level is explicitly provided for a sub-logger, set it
    if level:
        logger.setLevel(getattr(logging, level.upper(), logging.INFO))

    return logger
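How these pieces are intended to compose — Rich console output plus a rotating JSONL file on the root logger, with the callback handler forwarding LLM events into the same stream; the node name and file path here are illustrative:

```python
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler

logger = get_logger("planner", level="DEBUG", log_file="logs/app.jsonl")
logger.info("planning started")

# Forward LLM start/end/error events into the same logger:
handler = LangChainLoggingHandler(logger=logger)
# e.g. llm = get_llm_model(config, callbacks=[handler])
```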
90
tests/test_app.py
Normal file
90
tests/test_app.py
Normal file
@@ -0,0 +1,90 @@
import os
import sys
import pytest
from unittest.mock import MagicMock, patch
from streamlit.testing.v1 import AppTest
from langchain_core.messages import AIMessage

# Ensure src is in python path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))


@pytest.fixture(autouse=True)
def mock_history_manager():
    """Globally mock HistoryManager to avoid DB calls during AppTest."""
    with patch("ea_chatbot.history.manager.HistoryManager") as mock_cls:
        instance = mock_cls.return_value
        instance.create_conversation.return_value = MagicMock(id="conv_123")
        instance.get_conversations.return_value = []
        instance.get_messages.return_value = []
        instance.add_message.return_value = MagicMock()
        instance.update_conversation_summary.return_value = MagicMock()
        yield instance


@pytest.fixture
def mock_app_stream():
    with patch("ea_chatbot.graph.workflow.app.stream") as mock_stream:
        # Mock events from app.stream
        mock_stream.return_value = [
            {"query_analyzer": {"next_action": "research"}},
            {"researcher": {"messages": [AIMessage(content="Research result")]}}
        ]
        yield mock_stream


@pytest.fixture
def mock_user():
    user = MagicMock()
    user.id = "test_id"
    user.username = "test@example.com"
    user.display_name = "Test User"
    return user


def test_app_initial_state(mock_app_stream, mock_user):
    """Test that the app initializes with the correct title and empty history."""
    at = AppTest.from_file("src/ea_chatbot/app.py")

    # Simulate logged-in user
    at.session_state["user"] = mock_user

    at.run()

    assert not at.exception
    assert at.title[0].value == "🗳️ Election Analytics Chatbot"

    # Check session state initialization
    assert "messages" in at.session_state
    assert len(at.session_state["messages"]) == 0


def test_app_dev_mode_toggle(mock_app_stream, mock_user):
    """Test that the dev mode toggle exists in the sidebar."""
    with patch.dict(os.environ, {"DEV_MODE": "false"}):
        at = AppTest.from_file("src/ea_chatbot/app.py")
        at.session_state["user"] = mock_user
        at.run()

        # Check for sidebar toggle (checkbox)
        assert len(at.sidebar.checkbox) > 0
        dev_mode_toggle = at.sidebar.checkbox[0]
        assert dev_mode_toggle.label == "Dev Mode"
        assert dev_mode_toggle.value is False


def test_app_graph_execution_streaming(mock_app_stream, mock_user, mock_history_manager):
    """Test that entering a prompt triggers the graph stream and displays response."""
    at = AppTest.from_file("src/ea_chatbot/app.py")
    at.session_state["user"] = mock_user
    at.run()

    # Input a question
    at.chat_input[0].set_value("Test question").run()

    # Verify graph stream was called
    assert mock_app_stream.called

    # Message should be added to history
    assert len(at.session_state["messages"]) == 2
    assert at.session_state["messages"][0]["role"] == "user"
    assert at.session_state["messages"][1]["role"] == "assistant"
    assert "Research result" in at.session_state["messages"][1]["content"]

    # Verify history manager was used
    assert mock_history_manager.create_conversation.called
    assert mock_history_manager.add_message.called
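The mocked return value above mirrors the event shape LangGraph's `.stream()` yields: one dict per step, keyed by node name, mapping to that node's partial state update. A sketch of the consuming loop the app layer presumably runs (`render_progress` is a hypothetical UI hook, not part of this commit):

```python
for event in app.stream(initial_state):
    for node_name, partial_update in event.items():
        # e.g. ("researcher", {"messages": [AIMessage(content="Research result")]})
        render_progress(node_name, partial_update)
```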
83
tests/test_app_auth.py
Normal file
83
tests/test_app_auth.py
Normal file
@@ -0,0 +1,83 @@
import pytest
from unittest.mock import MagicMock, patch
from streamlit.testing.v1 import AppTest
from ea_chatbot.auth import AuthType


@pytest.fixture
def mock_history_manager_instance():
    # We need to patch before the AppTest loads the module
    with patch("ea_chatbot.history.manager.HistoryManager") as mock_cls:
        instance = mock_cls.return_value
        yield instance


def test_auth_ui_flow_step1_to_password(mock_history_manager_instance):
    """Test UI transition from Step 1 (email) to Step 2a (password) for LOCAL user."""
    # Patch BEFORE creating AppTest
    mock_user = MagicMock()
    mock_user.password_hash = "hashed_password"
    mock_history_manager_instance.get_user.return_value = mock_user

    at = AppTest.from_file("src/ea_chatbot/app.py")
    at.run()

    # Step 1: Identification
    assert at.session_state["login_step"] == "email"
    at.text_input[0].set_value("local@example.com")

    at.button[0].click().run()

    # Verify transition to password step
    assert at.session_state["login_step"] == "login_password"
    assert at.session_state["login_email"] == "local@example.com"
    assert "Welcome back" in at.info[0].value


def test_auth_ui_flow_step1_to_register(mock_history_manager_instance):
    """Test UI transition from Step 1 (email) to Step 2b (registration) for NEW user."""
    mock_history_manager_instance.get_user.return_value = None

    at = AppTest.from_file("src/ea_chatbot/app.py")
    at.run()

    # Step 1: Identification
    at.text_input[0].set_value("new@example.com")

    at.button[0].click().run()

    # Verify transition to registration step
    assert at.session_state["login_step"] == "register_details"
    assert at.session_state["login_email"] == "new@example.com"
    assert "Create an account" in at.info[0].value


def test_auth_ui_flow_step1_to_oidc(mock_history_manager_instance):
    """Test UI transition from Step 1 (email) to Step 2c (OIDC) for OIDC user."""
    # Mock history_manager.get_user to return a user WITHOUT a password
    mock_user = MagicMock()
    mock_user.password_hash = None
    mock_history_manager_instance.get_user.return_value = mock_user

    at = AppTest.from_file("src/ea_chatbot/app.py")
    at.run()

    # Step 1: Identification
    at.text_input[0].set_value("oidc@example.com")

    at.button[0].click().run()

    # Verify transition to OIDC step
    assert at.session_state["login_step"] == "oidc_login"
    assert at.session_state["login_email"] == "oidc@example.com"
    assert "configured for Single Sign-On" in at.info[0].value


def test_auth_ui_flow_back_button(mock_history_manager_instance):
    """Test that the 'Back' button returns to Step 1."""
    at = AppTest.from_file("src/ea_chatbot/app.py")
    # Simulate being on Step 2a
    at.session_state["login_step"] = "login_password"
    at.session_state["login_email"] = "local@example.com"
    at.run()

    # Click Back (index 1 in Step 2a)
    at.button[1].click().run()

    # Verify return to email step
    assert at.session_state["login_step"] == "email"
43
tests/test_auth.py
Normal file
43
tests/test_auth.py
Normal file
@@ -0,0 +1,43 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.auth import OIDCClient


@patch("ea_chatbot.auth.OAuth2Session")
def test_oidc_client_initialization(mock_oauth):
    client = OIDCClient(
        client_id="test_id",
        client_secret="test_secret",
        server_metadata_url="https://test.server/.well-known/openid-configuration"
    )
    assert client.oauth_session is not None


@patch("ea_chatbot.auth.requests")
@patch("ea_chatbot.auth.OAuth2Session")
def test_get_login_url(mock_oauth_cls, mock_requests):
    # Setup mock session
    mock_session = MagicMock()
    mock_oauth_cls.return_value = mock_session

    # Mock metadata response
    mock_response = MagicMock()
    mock_response.json.return_value = {
        "authorization_endpoint": "https://test.server/auth",
        "token_endpoint": "https://test.server/token",
        "userinfo_endpoint": "https://test.server/userinfo"
    }
    mock_requests.get.return_value = mock_response

    # Mock authorization url generation
    mock_session.create_authorization_url.return_value = ("https://test.server/auth?response_type=code", "state")

    client = OIDCClient(
        client_id="test_id",
        client_secret="test_secret",
        server_metadata_url="https://test.server/.well-known/openid-configuration"
    )

    url = client.get_login_url()

    assert url == "https://test.server/auth?response_type=code"
    # Verify metadata was fetched via requests
    mock_requests.get.assert_called_with("https://test.server/.well-known/openid-configuration")
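For context, a minimal sketch of the metadata discovery these mocks emulate — the client fetches the `.well-known` document once and reads the endpoints it needs from the JSON:

```python
import requests

meta = requests.get(
    "https://test.server/.well-known/openid-configuration"
).json()
authorize_url = meta["authorization_endpoint"]
token_url = meta["token_endpoint"]
userinfo_url = meta["userinfo_endpoint"]
```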
49
tests/test_auth_flow.py
Normal file
49
tests/test_auth_flow.py
Normal file
@@ -0,0 +1,49 @@
import pytest
from unittest.mock import MagicMock
from ea_chatbot.history.manager import HistoryManager
from ea_chatbot.auth import get_user_auth_type, AuthType


# Mocks
@pytest.fixture
def mock_history_manager():
    return MagicMock(spec=HistoryManager)


def test_auth_flow_existing_local_user(mock_history_manager):
    """Test that an existing user with a password returns LOCAL auth type."""
    # Setup
    mock_user = MagicMock()
    mock_user.password_hash = "hashed_secret"
    mock_history_manager.get_user.return_value = mock_user

    # Execute
    auth_type = get_user_auth_type("test@example.com", mock_history_manager)

    # Verify
    assert auth_type == AuthType.LOCAL
    mock_history_manager.get_user.assert_called_once_with("test@example.com")


def test_auth_flow_existing_oidc_user(mock_history_manager):
    """Test that an existing user WITHOUT a password returns OIDC auth type."""
    # Setup
    mock_user = MagicMock()
    mock_user.password_hash = None  # No password implies OIDC
    mock_history_manager.get_user.return_value = mock_user

    # Execute
    auth_type = get_user_auth_type("sso@example.com", mock_history_manager)

    # Verify
    assert auth_type == AuthType.OIDC
    mock_history_manager.get_user.assert_called_once_with("sso@example.com")


def test_auth_flow_new_user(mock_history_manager):
    """Test that a non-existent user returns NEW auth type."""
    # Setup
    mock_history_manager.get_user.return_value = None

    # Execute
    auth_type = get_user_auth_type("new@example.com", mock_history_manager)

    # Verify
    assert auth_type == AuthType.NEW
    mock_history_manager.get_user.assert_called_once_with("new@example.com")
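The three cases reduce to one rule: no user means registration, a stored hash means password login, no hash means SSO. A hedged mirror of the contract these tests pin down (the real `get_user_auth_type` may differ in details):

```python
def sketch_get_user_auth_type(email, history_manager):
    # Illustrative only; mirrors the tested behavior, not the implementation
    user = history_manager.get_user(email)
    if user is None:
        return AuthType.NEW
    if user.password_hash:
        return AuthType.LOCAL
    return AuthType.OIDC
```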
62
tests/test_coder.py
Normal file
62
tests/test_coder.py
Normal file
@@ -0,0 +1,62 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.coder import coder_node
from ea_chatbot.graph.nodes.error_corrector import error_corrector_node


@pytest.fixture
def mock_state():
    return {
        "messages": [],
        "question": "Show me results for New Jersey",
        "plan": "Step 1: Load data\nStep 2: Filter by NJ",
        "code": None,
        "error": None,
        "plots": [],
        "dfs": {},
        "next_action": "plan"
    }


@patch("ea_chatbot.graph.nodes.coder.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
def test_coder_node(mock_get_summary, mock_get_llm, mock_state):
    """Test coder node generates code from plan."""
    mock_get_summary.return_value = "Column: Name, Type: text"

    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm

    from ea_chatbot.schemas import CodeGenerationResponse
    mock_response = CodeGenerationResponse(
        code="import pandas as pd\nprint('Hello')",
        explanation="Generated code"
    )
    mock_llm.with_structured_output.return_value.invoke.return_value = mock_response

    result = coder_node(mock_state)

    assert "code" in result
    assert "import pandas as pd" in result["code"]
    assert "error" in result
    assert result["error"] is None


@patch("ea_chatbot.graph.nodes.error_corrector.get_llm_model")
def test_error_corrector_node(mock_get_llm, mock_state):
    """Test error corrector node fixes code."""
    mock_state["code"] = "import pandas as pd\nprint(undefined_var)"
    mock_state["error"] = "NameError: name 'undefined_var' is not defined"

    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm

    from ea_chatbot.schemas import CodeGenerationResponse
    mock_response = CodeGenerationResponse(
        code="import pandas as pd\nprint('Defined')",
        explanation="Fixed variable"
    )
    mock_llm.with_structured_output.return_value.invoke.return_value = mock_response

    result = error_corrector_node(mock_state)

    assert "code" in result
    assert "print('Defined')" in result["code"]
    assert result["error"] is None
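The `with_structured_output(...).invoke(...)` chain mocked here is the standard LangChain pattern for schema-constrained generation; a sketch assuming an `llm` built via `get_llm_model`:

```python
from ea_chatbot.schemas import CodeGenerationResponse

structured_llm = llm.with_structured_output(CodeGenerationResponse)
response = structured_llm.invoke("Write pandas code to filter voters by county.")
print(response.code)         # generated Python source
print(response.explanation)  # the model's rationale
```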
47
tests/test_config.py
Normal file
47
tests/test_config.py
Normal file
@@ -0,0 +1,47 @@
import pytest
from pydantic import ValidationError
from ea_chatbot.config import Settings, LLMConfig


def test_default_settings():
    """Test that default settings are loaded correctly."""
    settings = Settings()

    # Check default config for query analyzer
    assert isinstance(settings.query_analyzer_llm, LLMConfig)
    assert settings.query_analyzer_llm.provider == "openai"
    assert settings.query_analyzer_llm.model == "gpt-5-mini"
    assert settings.query_analyzer_llm.temperature == 0.0

    # Check default config for planner
    assert isinstance(settings.planner_llm, LLMConfig)
    assert settings.planner_llm.provider == "openai"
    assert settings.planner_llm.model == "gpt-5-mini"


def test_env_override(monkeypatch):
    """Test that environment variables override defaults."""
    monkeypatch.setenv("QUERY_ANALYZER_LLM__MODEL", "gpt-3.5-turbo")
    monkeypatch.setenv("QUERY_ANALYZER_LLM__TEMPERATURE", "0.7")

    settings = Settings()
    assert settings.query_analyzer_llm.model == "gpt-3.5-turbo"
    assert settings.query_analyzer_llm.temperature == 0.7


def test_provider_specific_params():
    """Test that provider specific parameters can be set."""
    config = LLMConfig(
        provider="openai",
        model="o1-preview",
        provider_specific={"reasoning_effort": "high"}
    )
    assert config.provider_specific["reasoning_effort"] == "high"


def test_oidc_settings(monkeypatch):
    """Test OIDC settings configuration."""
    monkeypatch.setenv("OIDC_CLIENT_ID", "test_client_id")
    monkeypatch.setenv("OIDC_CLIENT_SECRET", "test_client_secret")
    monkeypatch.setenv("OIDC_SERVER_METADATA_URL", "https://test.server/.well-known/openid-configuration")

    settings = Settings()
    assert settings.oidc_client_id == "test_client_id"
    assert settings.oidc_client_secret == "test_client_secret"
    assert settings.oidc_server_metadata_url == "https://test.server/.well-known/openid-configuration"
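The double-underscore overrides work because pydantic-settings splits env var names on a nested delimiter. A self-contained sketch of the wiring this implies (the actual `Settings` class is not shown in this commit, so treat this as an assumption):

```python
from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict

class LLMConfig(BaseModel):
    provider: str = "openai"
    model: str = "gpt-5-mini"
    temperature: float = 0.0

class Settings(BaseSettings):
    # QUERY_ANALYZER_LLM__MODEL=gpt-3.5-turbo overrides the nested field
    model_config = SettingsConfigDict(env_nested_delimiter="__")
    query_analyzer_llm: LLMConfig = LLMConfig()
```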
56
tests/test_conversation_summary.py
Normal file
56
tests/test_conversation_summary.py
Normal file
@@ -0,0 +1,56 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.nodes.summarize_conversation import summarize_conversation_node


@pytest.fixture
def mock_state_with_history():
    return {
        "messages": [
            HumanMessage(content="Show me the 2024 results for Florida"),
            AIMessage(content="Here are the results for Florida in 2024...")
        ],
        "summary": "The user is asking about 2024 election results."
    }


@patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model")
def test_summarize_conversation_node_updates_summary(mock_get_llm, mock_state_with_history):
    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance

    # Mock LLM response for updating summary
    mock_llm_instance.invoke.return_value = AIMessage(content="Updated summary including NJ results.")

    # Add new messages to simulate a completed turn
    mock_state_with_history["messages"].extend([
        HumanMessage(content="What about in New Jersey?"),
        AIMessage(content="In New Jersey, the 2024 results were...")
    ])

    result = summarize_conversation_node(mock_state_with_history)

    assert "summary" in result
    assert result["summary"] == "Updated summary including NJ results."

    # Verify LLM was called with the correct context
    call_messages = mock_llm_instance.invoke.call_args[0][0]
    # Should include current summary and last turn messages
    assert "Current summary: The user is asking about 2024 election results." in call_messages[0].content


@patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model")
def test_summarize_conversation_node_initial_summary(mock_get_llm):
    state = {
        "messages": [
            HumanMessage(content="Hi"),
            AIMessage(content="Hello! How can I help you today?")
        ],
        "summary": ""
    }

    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance
    mock_llm_instance.invoke.return_value = AIMessage(content="Initial greeting.")

    result = summarize_conversation_node(state)

    assert result["summary"] == "Initial greeting."
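The node under test follows a rolling-summary pattern: fold only the latest turn into the running summary instead of re-summarizing the whole transcript. A loose sketch — the prompt wording is guessed; the tests only pin the "Current summary: ..." prefix:

```python
from langchain_core.messages import SystemMessage

def sketch_summarize(state, llm):
    # Hypothetical shape of the node; not the actual implementation
    prompt = SystemMessage(content=(
        f"Current summary: {state.get('summary', '')}\n"
        "Extend the summary to cover the most recent exchange."
    ))
    response = llm.invoke([prompt] + state["messages"][-2:])
    return {"summary": response.content}
```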
195
tests/test_database_inspection.py
Normal file
195
tests/test_database_inspection.py
Normal file
@@ -0,0 +1,195 @@
import pytest
import pandas as pd
import os
from unittest.mock import MagicMock, patch
from ea_chatbot.utils.database_inspection import get_primary_key, inspect_db_table, get_data_summary


@pytest.fixture
def mock_db_client():
    mock_client = MagicMock()
    mock_client.settings = {"table": "test_table"}
    return mock_client


def test_get_primary_key(mock_db_client):
    """Test dynamic primary key discovery."""
    # Mock response for primary key query
    mock_df = pd.DataFrame({"column_name": ["my_pk"]})
    mock_db_client.query_df.return_value = mock_df

    pk = get_primary_key(mock_db_client, "test_table")

    assert pk == "my_pk"
    # Verify the query was called (at least once)
    assert mock_db_client.query_df.called


def test_inspect_db_table_improved(mock_db_client, tmp_path):
    """Test improved inspect_db_table with cardinality and sampling."""
    data_dir = str(tmp_path)

    # 1. Mock columns and types
    columns_df = pd.DataFrame({
        "column_name": ["id", "category", "count"],
        "data_type": ["integer", "text", "integer"]
    })

    # 2. Mock row count
    total_rows_df = pd.DataFrame([{"count": 100}])

    # 3. Mock PK discovery
    pk_df = pd.DataFrame({"column_name": ["id"]})

    # 4. Mock stats for columns
    # We need to handle multiple calls to query_df
    def side_effect(query):
        if "information_schema.columns" in query:
            return columns_df
        if "COUNT(*)" in query:
            return total_rows_df
        if "information_schema.key_column_usage" in query:
            return pk_df

        # Category stats
        if 'COUNT("category")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'COUNT(DISTINCT "category")' in query:
            return pd.DataFrame([{"count": 5}])
        if 'SELECT DISTINCT "category"' in query:
            return pd.DataFrame({"category": ["A", "B", "C", "D", "E"]})

        # Count stats
        if 'COUNT("count")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'COUNT(DISTINCT "count")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'AVG("count")' in query:
            return pd.DataFrame([{"avg": 10.0, "min": 1, "max": 20}])

        # ID stats (fix for IndexError)
        if 'COUNT("id")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'COUNT(DISTINCT "id")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'AVG("id")' in query:
            return pd.DataFrame([{"avg": 50.0, "min": 1, "max": 100}])

        return pd.DataFrame()

    mock_db_client.query_df.side_effect = side_effect

    # Run inspection
    inspect_db_table(mock_db_client, data_dir=data_dir)

    # Read summary to verify
    summary = get_data_summary(data_dir)
    assert summary is not None
    assert "test_table" in summary
    assert "category" in summary
    assert "distinct_values" in summary
    assert "unique_count: 5" in summary
    assert "- A" in summary
    assert "- E" in summary
    assert "primary_key: id" in summary


def test_get_data_summary_none(tmp_path):
    """Test get_data_summary when the file doesn't exist."""
    assert get_data_summary(str(tmp_path)) is None


def test_inspect_db_table_temporal(mock_db_client, tmp_path):
    """Test inspect_db_table with temporal columns."""
    data_dir = str(tmp_path)

    columns_df = pd.DataFrame({
        "column_name": ["created_at"],
        "data_type": ["timestamp without time zone"]
    })
    total_rows_df = pd.DataFrame([{"count": 50}])
    pk_df = pd.DataFrame()  # No PK

    def side_effect(query):
        if "information_schema.columns" in query:
            return columns_df
        if "COUNT(*)" in query:
            return total_rows_df
        if "information_schema.key_column_usage" in query:
            return pk_df
        if 'COUNT("created_at")' in query:
            return pd.DataFrame([{"count": 50}])
        if 'COUNT(DISTINCT "created_at")' in query:
            return pd.DataFrame([{"count": 50}])
        if 'MIN("created_at")' in query:
            return pd.DataFrame([{"min": "2023-01-01", "max": "2023-12-31"}])
        return pd.DataFrame()

    mock_db_client.query_df.side_effect = side_effect

    inspect_db_table(mock_db_client, data_dir=data_dir)

    summary = get_data_summary(data_dir)
    assert "created_at" in summary
    assert "min: '2023-01-01'" in summary
    assert "max: '2023-12-31'" in summary


def test_inspect_db_table_high_cardinality(mock_db_client, tmp_path):
    """Test inspect_db_table with a high-cardinality categorical column (no sample values)."""
    data_dir = str(tmp_path)

    columns_df = pd.DataFrame({
        "column_name": ["user_id"],
        "data_type": ["text"]
    })
    total_rows_df = pd.DataFrame([{"count": 100}])
    pk_df = pd.DataFrame()

    def side_effect(query):
        if "information_schema.columns" in query:
            return columns_df
        if "COUNT(*)" in query:
            return total_rows_df
        if "information_schema.key_column_usage" in query:
            return pk_df
        if 'COUNT("user_id")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'COUNT(DISTINCT "user_id")' in query:
            # High cardinality > 20
            return pd.DataFrame([{"count": 50}])
        return pd.DataFrame()

    mock_db_client.query_df.side_effect = side_effect

    inspect_db_table(mock_db_client, data_dir=data_dir)

    summary = get_data_summary(data_dir)
    assert "user_id" in summary
    assert "unique_count: 50" in summary
    # Should NOT have distinct_values
    assert "distinct_values" not in summary


def test_inspect_db_table_checksum_skip(mock_db_client, tmp_path):
    """Test that inspection is skipped if the checksum matches."""
    data_dir = str(tmp_path)
    table = "test_table"

    # 1. Create a fake checksum file
    os.makedirs(data_dir, exist_ok=True)
    # The checksum is the md5 of "ins|upd|del"; here the mock returns "my_hash"

    # Mock the checksum query
    mock_db_client.query_df.return_value = pd.DataFrame([{"dml_hash": "my_hash"}])

    # Write the existing checksum
    with open(os.path.join(data_dir, "checksum"), "w") as f:
        f.write(f"{table}:my_hash\n")

    # Write the existing inspection
    with open(os.path.join(data_dir, "inspection.yaml"), "w") as f:
        f.write(f"{table}: {{ existing: true }}")

    # Run inspection
    result = inspect_db_table(mock_db_client, data_dir=data_dir)

    # Should return the existing content
    assert "existing: true" in result
    # query_df should be called only once (for the checksum); on a match the
    # cached inspection is returned without re-querying column stats
    assert mock_db_client.query_df.call_count == 1
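The skip path hinges on a cheap change-detection hash; a sketch of the checksum the comment above describes, computed over insert/update/delete counters (the counter source, e.g. `pg_stat_user_tables`, is an assumption):

```python
import hashlib

def sketch_dml_hash(n_ins: int, n_upd: int, n_del: int) -> str:
    # Any DML against the table moves at least one counter, so the hash
    # changes and the expensive column-by-column inspection reruns.
    return hashlib.md5(f"{n_ins}|{n_upd}|{n_del}".encode()).hexdigest()
```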
123
tests/test_executor.py
Normal file
123
tests/test_executor.py
Normal file
@@ -0,0 +1,123 @@
import pytest
import pandas as pd
from unittest.mock import MagicMock, patch
from matplotlib.figure import Figure
from ea_chatbot.graph.nodes.executor import executor_node
from ea_chatbot.graph.state import AgentState


@pytest.fixture
def mock_settings():
    with patch("ea_chatbot.graph.nodes.executor.Settings") as MockSettings:
        mock_settings_instance = MagicMock()
        mock_settings_instance.db_host = "localhost"
        mock_settings_instance.db_port = 5432
        mock_settings_instance.db_user = "user"
        mock_settings_instance.db_pswd = "pass"
        mock_settings_instance.db_name = "test_db"
        mock_settings_instance.db_table = "test_table"
        MockSettings.return_value = mock_settings_instance
        yield mock_settings_instance


@pytest.fixture
def mock_db_client():
    with patch("ea_chatbot.graph.nodes.executor.DBClient") as MockDBClient:
        mock_client_instance = MagicMock()
        MockDBClient.return_value = mock_client_instance
        yield mock_client_instance


def test_executor_node_success_simple_print(mock_settings, mock_db_client):
    """Test executing simple code that prints to stdout."""
    state = {
        "code": "print('Hello, World!')",
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert "code_output" in result
    assert "Hello, World!" in result["code_output"]
    assert result["error"] is None
    assert result["plots"] == []
    assert result["dfs"] == {}


def test_executor_node_success_dataframe(mock_settings, mock_db_client):
    """Test executing code that creates a DataFrame."""
    code = """
import pandas as pd
df = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
print(df)
"""
    state = {
        "code": code,
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert "code_output" in result
    assert "a  b" in result["code_output"]  # Check part of the DataFrame's string representation
    assert "dfs" in result
    assert "df" in result["dfs"]
    assert isinstance(result["dfs"]["df"], pd.DataFrame)


def test_executor_node_success_plot(mock_settings, mock_db_client):
    """Test executing code that generates a plot."""
    code = """
import matplotlib.pyplot as plt
fig = plt.figure()
plots.append(fig)
print('Plot generated')
"""
    state = {
        "code": code,
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert "Plot generated" in result["code_output"]
    assert "plots" in result
    assert len(result["plots"]) == 1
    assert isinstance(result["plots"][0], Figure)


def test_executor_node_error_syntax(mock_settings, mock_db_client):
    """Test executing code with a syntax error."""
    state = {
        "code": "print('Hello World",  # Missing closing quote
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert result["error"] is not None
    assert "SyntaxError" in result["error"]


def test_executor_node_error_runtime(mock_settings, mock_db_client):
    """Test executing code with a runtime error."""
    state = {
        "code": "print(1 / 0)",
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert result["error"] is not None
    assert "ZeroDivisionError" in result["error"]


def test_executor_node_no_code(mock_settings, mock_db_client):
    """Test handling when no code is provided."""
    state = {
        "code": None,
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert "error" in result
    assert "No code provided" in result["error"]
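These tests imply an executor built on `exec` with captured stdout and a pre-seeded `plots` list, harvesting DataFrames from the execution namespace afterwards. A minimal sketch under those assumptions, not the actual node:

```python
import contextlib
import io
import pandas as pd

def sketch_execute(code: str) -> dict:
    env = {"plots": []}  # seeded so generated code can call plots.append(fig)
    buffer = io.StringIO()
    try:
        with contextlib.redirect_stdout(buffer):
            exec(code, env)
    except Exception as exc:  # SyntaxError also surfaces here, at compile time
        return {"code_output": buffer.getvalue(), "plots": [], "dfs": {},
                "error": f"{type(exc).__name__}: {exc}"}
    dfs = {k: v for k, v in env.items() if isinstance(v, pd.DataFrame)}
    return {"code_output": buffer.getvalue(), "plots": env["plots"],
            "dfs": dfs, "error": None}
```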
77
tests/test_helpers.py
Normal file
77
tests/test_helpers.py
Normal file
@@ -0,0 +1,77 @@
import pytest
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.utils.helpers import merge_agent_state


def test_merge_agent_state_list_accumulation():
    """Verify that list fields (messages, plots) are accumulated (appended)."""
    current_state = {
        "messages": [HumanMessage(content="hello")],
        "plots": ["plot1"]
    }
    update = {
        "messages": [AIMessage(content="hi")],
        "plots": ["plot2"]
    }

    merged = merge_agent_state(current_state, update)

    assert len(merged["messages"]) == 2
    assert merged["messages"][0].content == "hello"
    assert merged["messages"][1].content == "hi"

    assert len(merged["plots"]) == 2
    assert merged["plots"] == ["plot1", "plot2"]


def test_merge_agent_state_dict_update():
    """Verify that dictionary fields (dfs) are updated (shallow merge)."""
    current_state = {
        "dfs": {"df1": "data1"}
    }
    update = {
        "dfs": {"df2": "data2"}
    }

    merged = merge_agent_state(current_state, update)

    assert merged["dfs"] == {"df1": "data1", "df2": "data2"}

    # Verify overwrite within dict
    update_overwrite = {
        "dfs": {"df1": "new_data1"}
    }
    merged_overwrite = merge_agent_state(merged, update_overwrite)
    assert merged_overwrite["dfs"] == {"df1": "new_data1", "df2": "data2"}


def test_merge_agent_state_standard_overwrite():
    """Verify that standard fields are overwritten."""
    current_state = {
        "question": "old question",
        "next_action": "old action",
        "plan": "old plan"
    }
    update = {
        "question": "new question",
        "next_action": "new action",
        "plan": "new plan"
    }

    merged = merge_agent_state(current_state, update)

    assert merged["question"] == "new question"
    assert merged["next_action"] == "new action"
    assert merged["plan"] == "new plan"


def test_merge_agent_state_none_handling():
    """Verify that None updates or missing keys in update don't break things."""
    current_state = {
        "question": "test",
        "messages": ["msg1"]
    }

    # Empty update
    assert merge_agent_state(current_state, {}) == current_state

    # Update with None value for overwritable field
    merged_none = merge_agent_state(current_state, {"question": None})
    assert merged_none["question"] is None
    assert merged_none["messages"] == ["msg1"]
145
tests/test_history_manager.py
Normal file
145
tests/test_history_manager.py
Normal file
@@ -0,0 +1,145 @@
import pytest
from ea_chatbot.history.manager import HistoryManager
from ea_chatbot.history.models import User, Conversation, Message, Plot
from ea_chatbot.config import Settings
from sqlalchemy import delete


@pytest.fixture
def history_manager():
    settings = Settings()
    manager = HistoryManager(settings.history_db_url)
    # Clean up tables before tests (order matters because of foreign keys)
    with manager.get_session() as session:
        session.execute(delete(Plot))
        session.execute(delete(Message))
        session.execute(delete(Conversation))
        session.execute(delete(User))
    return manager


def test_history_manager_initialization(history_manager):
    assert history_manager.engine is not None
    assert history_manager.SessionLocal is not None


def test_history_manager_session_context(history_manager):
    with history_manager.get_session() as session:
        assert session is not None


def test_get_user_not_found(history_manager):
    user = history_manager.get_user("nonexistent@example.com")
    assert user is None


def test_authenticate_user_success(history_manager):
    email = "test@example.com"
    password = "secretpassword"
    history_manager.create_user(email=email, password=password)

    user = history_manager.authenticate_user(email, password)
    assert user is not None
    assert user.username == email


def test_authenticate_user_failure(history_manager):
    email = "test@example.com"
    history_manager.create_user(email=email, password="correctpassword")

    user = history_manager.authenticate_user(email, "wrongpassword")
    assert user is None


def test_sync_user_from_oidc_new_user(history_manager):
    user = history_manager.sync_user_from_oidc(
        email="new@example.com",
        display_name="New User"
    )
    assert user is not None
    assert user.username == "new@example.com"
    assert user.display_name == "New User"


def test_sync_user_from_oidc_existing_user(history_manager):
    # First sync
    history_manager.sync_user_from_oidc(
        email="existing@example.com",
        display_name="First Name"
    )
    # Second sync should update or return the same user
    user = history_manager.sync_user_from_oidc(
        email="existing@example.com",
        display_name="Updated Name"
    )
    assert user.display_name == "Updated Name"


# --- Conversation Management Tests ---

@pytest.fixture
def user(history_manager):
    return history_manager.create_user(email="conv_user@example.com")


def test_create_conversation(history_manager, user):
    conv = history_manager.create_conversation(
        user_id=user.id,
        data_state="new_jersey",
        name="Test Chat",
        summary="A test conversation summary"
    )
    assert conv is not None
    assert conv.name == "Test Chat"
    assert conv.summary == "A test conversation summary"
    assert conv.user_id == user.id


def test_get_conversations(history_manager, user):
    history_manager.create_conversation(user_id=user.id, data_state="nj", name="C1")
    history_manager.create_conversation(user_id=user.id, data_state="nj", name="C2")
    history_manager.create_conversation(user_id=user.id, data_state="ny", name="C3")

    nj_convs = history_manager.get_conversations(user_id=user.id, data_state="nj")
    assert len(nj_convs) == 2

    ny_convs = history_manager.get_conversations(user_id=user.id, data_state="ny")
    assert len(ny_convs) == 1


def test_rename_conversation(history_manager, user):
    conv = history_manager.create_conversation(user.id, "nj", "Old Name")
    updated = history_manager.rename_conversation(conv.id, "New Name")
    assert updated.name == "New Name"


def test_delete_conversation(history_manager, user):
    conv = history_manager.create_conversation(user.id, "nj", "To Delete")
    history_manager.delete_conversation(conv.id)

    convs = history_manager.get_conversations(user.id, "nj")
    assert len(convs) == 0


# --- Message Management Tests ---

@pytest.fixture
def conversation(history_manager, user):
    return history_manager.create_conversation(user.id, "nj", "Msg Test Conv")


def test_add_message(history_manager, conversation):
    msg = history_manager.add_message(
        conversation_id=conversation.id,
        role="user",
        content="Hello world"
    )
    assert msg is not None
    assert msg.content == "Hello world"
    assert msg.role == "user"
    assert msg.conversation_id == conversation.id


def test_add_message_with_plots(history_manager, conversation):
    plots_data = [b"fake_plot_1", b"fake_plot_2"]
    msg = history_manager.add_message(
        conversation_id=conversation.id,
        role="assistant",
        content="Here are plots",
        plots=plots_data
    )
    assert len(msg.plots) == 2
    assert msg.plots[0].image_data == b"fake_plot_1"


def test_get_messages(history_manager, conversation):
    history_manager.add_message(conversation.id, "user", "Q1")
    history_manager.add_message(conversation.id, "assistant", "A1")

    messages = history_manager.get_messages(conversation.id)
    assert len(messages) == 2
    assert messages[0].content == "Q1"
    assert messages[1].content == "A1"
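Taken together, the tested surface supports a straightforward session flow; a usage sketch against the API these tests establish (assumes a reachable history database):

```python
from ea_chatbot.config import Settings
from ea_chatbot.history.manager import HistoryManager

manager = HistoryManager(Settings().history_db_url)
user = manager.create_user(email="analyst@example.com", password="s3cret")
conv = manager.create_conversation(user_id=user.id, data_state="new_jersey", name="Turnout Q&A")
manager.add_message(conv.id, "user", "How did turnout change since 2020?")
manager.add_message(conv.id, "assistant", "Turnout rose by ...", plots=[b"png_bytes"])
```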
55
tests/test_history_models.py
Normal file
55
tests/test_history_models.py
Normal file
@@ -0,0 +1,55 @@
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, DeclarativeBase

# We anticipate these imports will fail initially
try:
    from ea_chatbot.history.models import Base, User, Conversation, Message, Plot
except ImportError:
    Base = None
    User = None
    Conversation = None
    Message = None
    Plot = None


def test_models_exist():
    assert User is not None, "User model not found"
    assert Conversation is not None, "Conversation model not found"
    assert Message is not None, "Message model not found"
    assert Plot is not None, "Plot model not found"
    assert Base is not None, "Base declarative class not found"


def test_user_model_columns():
    if not User: pytest.fail("User model undefined")
    # Basic check if columns exist (by inspecting __table__.columns)
    columns = User.__table__.columns.keys()
    assert "id" in columns
    assert "username" in columns
    assert "password_hash" in columns
    assert "display_name" in columns


def test_conversation_model_columns():
    if not Conversation: pytest.fail("Conversation model undefined")
    columns = Conversation.__table__.columns.keys()
    assert "id" in columns
    assert "user_id" in columns
    assert "data_state" in columns
    assert "name" in columns
    assert "summary" in columns
    assert "created_at" in columns


def test_message_model_columns():
    if not Message: pytest.fail("Message model undefined")
    columns = Message.__table__.columns.keys()
    assert "id" in columns
    assert "role" in columns
    assert "content" in columns
    assert "conversation_id" in columns
    assert "created_at" in columns


def test_plot_model_columns():
    if not Plot: pytest.fail("Plot model undefined")
    columns = Plot.__table__.columns.keys()
    assert "id" in columns
    assert "message_id" in columns
    assert "image_data" in columns
4
tests/test_history_module.py
Normal file
4
tests/test_history_module.py
Normal file
@@ -0,0 +1,4 @@
import ea_chatbot.history


def test_history_module_importable():
    assert ea_chatbot.history is not None
54
tests/test_llm_factory.py
Normal file
54
tests/test_llm_factory.py
Normal file
@@ -0,0 +1,54 @@
import pytest
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from ea_chatbot.config import LLMConfig
from ea_chatbot.utils.llm_factory import get_llm_model


def test_get_openai_model(monkeypatch):
    """Test creating an OpenAI model."""
    monkeypatch.setenv("OPENAI_API_KEY", "dummy")
    config = LLMConfig(
        provider="openai",
        model="gpt-4o",
        temperature=0.5,
        max_tokens=100
    )
    model = get_llm_model(config)
    assert isinstance(model, ChatOpenAI)
    assert model.model_name == "gpt-4o"
    assert model.temperature == 0.5
    assert model.max_tokens == 100


def test_get_google_model(monkeypatch):
    """Test creating a Google model."""
    monkeypatch.setenv("GOOGLE_API_KEY", "dummy")
    config = LLMConfig(
        provider="google",
        model="gemini-1.5-pro",
        temperature=0.7
    )
    model = get_llm_model(config)
    assert isinstance(model, ChatGoogleGenerativeAI)
    assert model.model == "gemini-1.5-pro"
    assert model.temperature == 0.7


def test_unsupported_provider():
    """Test that an unsupported provider raises an error."""
    config = LLMConfig(provider="unknown", model="test")
    with pytest.raises(ValueError, match="Unsupported LLM provider: unknown"):
        get_llm_model(config)


def test_provider_specific_params(monkeypatch):
    """Test passing provider specific params."""
    monkeypatch.setenv("OPENAI_API_KEY", "dummy")
    config = LLMConfig(
        provider="openai",
        model="o1-preview",
        provider_specific={"reasoning_effort": "high"}
    )
    # Note: reasoning_effort support depends on the langchain-openai version,
    # but we check if kwargs are passed.
    model = get_llm_model(config)
    assert isinstance(model, ChatOpenAI)
    # Check if reasoning_effort was passed correctly
    assert getattr(model, "reasoning_effort", None) == "high"
19
tests/test_llm_factory_callbacks.py
Normal file
19
tests/test_llm_factory_callbacks.py
Normal file
@@ -0,0 +1,19 @@
import pytest
from langchain_openai import ChatOpenAI
from ea_chatbot.config import LLMConfig
from ea_chatbot.utils.llm_factory import get_llm_model
from langchain_core.callbacks import BaseCallbackHandler


class MockHandler(BaseCallbackHandler):
    pass


def test_get_llm_model_with_callbacks(monkeypatch):
    """Test that callbacks are passed to the model."""
    monkeypatch.setenv("OPENAI_API_KEY", "dummy")
    config = LLMConfig(provider="openai", model="gpt-4o")
    handler = MockHandler()

    model = get_llm_model(config, callbacks=[handler])

    assert isinstance(model, ChatOpenAI)
    assert handler in model.callbacks
44
tests/test_logging_context.py
Normal file
44
tests/test_logging_context.py
Normal file
@@ -0,0 +1,44 @@
import logging
import pytest
import io
import json
from ea_chatbot.utils.logging import ContextLoggerAdapter, JsonFormatter


@pytest.fixture
def json_log_capture():
    """Fixture to capture JSON logs."""
    log_stream = io.StringIO()
    logger = logging.getLogger("test_context")
    logger.setLevel(logging.INFO)
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)

    handler = logging.StreamHandler(log_stream)
    handler.setFormatter(JsonFormatter())
    logger.addHandler(handler)
    return logger, log_stream


def test_context_logger_adapter_injects_metadata(json_log_capture):
    """Test that ContextLoggerAdapter injects metadata into the log record."""
    logger, log_stream = json_log_capture
    adapter = ContextLoggerAdapter(logger, {"run_id": "123", "node_name": "test_node"})

    adapter.info("test message")

    data = json.loads(log_stream.getvalue())
    assert data["message"] == "test message"
    assert data["run_id"] == "123"
    assert data["node_name"] == "test_node"


def test_context_logger_adapter_override_metadata(json_log_capture):
    """Test that extra metadata can be provided during a call."""
    logger, log_stream = json_log_capture
    adapter = ContextLoggerAdapter(logger, {"run_id": "123"})

    # Passing extra context via the 'extra' parameter in standard logging
    # Note: Our adapter should handle merging this.
    adapter.info("test message", extra={"node_name": "dynamic_node"})

    data = json.loads(log_stream.getvalue())
    assert data["run_id"] == "123"
    assert data["node_name"] == "dynamic_node"
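The intended use is one adapter per graph run, so every record a node emits carries correlation fields; a short sketch:

```python
from ea_chatbot.utils.logging import ContextLoggerAdapter, get_logger

log = ContextLoggerAdapter(get_logger("coder"), {"run_id": "run-42", "node_name": "coder"})
log.info("generating code")                 # JSON record gains run_id and node_name
log.info("retrying", extra={"attempt": 2})  # per-call extras merge in as well
```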
67
tests/test_logging_core.py
Normal file
67
tests/test_logging_core.py
Normal file
@@ -0,0 +1,67 @@
import logging
import pytest
from ea_chatbot.utils.logging import get_logger


@pytest.fixture(autouse=True)
def reset_logging():
    """Reset the ea_chatbot logger handlers before each test."""
    logger = logging.getLogger("ea_chatbot")
    # Remove all existing handlers
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
    yield
    # Also clean up after the test
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)


def test_get_logger_singleton():
    """Test that get_logger returns the same logger instance for the same name."""
    logger1 = get_logger("test_logger")
    logger2 = get_logger("test_logger")
    assert logger1 is logger2


def test_get_logger_rich_handler():
    """Test that get_logger configures a RichHandler on root."""
    get_logger("test_rich")
    root = logging.getLogger("ea_chatbot")
    # Check if any handler is a RichHandler
    handler_names = [h.__class__.__name__ for h in root.handlers]
    assert "RichHandler" in handler_names


def test_get_logger_level():
    """Test that get_logger sets the correct log level."""
    logger = get_logger("test_level", level="DEBUG")
    assert logger.level == logging.DEBUG


def test_json_formatter_serializes_dict():
    """Test that JsonFormatter serializes log records to JSON."""
    from ea_chatbot.utils.logging import JsonFormatter
    import json

    formatter = JsonFormatter()
    record = logging.LogRecord(
        name="test", level=logging.INFO, pathname="test.py", lineno=10,
        msg="test message", args=(), exc_info=None
    )
    formatted = formatter.format(record)
    data = json.loads(formatted)

    assert data["message"] == "test message"
    assert data["level"] == "INFO"
    assert "timestamp" in data


def test_get_logger_file_handler(tmp_path):
    """Test that get_logger configures a file handler on root."""
    log_file = tmp_path / "test.json"
    logger = get_logger("test_file", log_file=str(log_file))

    root = logging.getLogger("ea_chatbot")
    handler_names = [h.__class__.__name__ for h in root.handlers]
    assert "RotatingFileHandler" in handler_names

    logger.info("file log test")

    # Check if the file exists and has content
    assert log_file.exists()
    content = log_file.read_text()
    assert "file log test" in content
83
tests/test_logging_e2e.py
Normal file
83
tests/test_logging_e2e.py
Normal file
@@ -0,0 +1,83 @@
import os
import json
import pytest
import logging
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.state import AgentState
from ea_chatbot.utils.logging import get_logger
from langchain_community.chat_models import FakeListChatModel
from langchain_core.runnables import RunnableLambda


@pytest.fixture(autouse=True)
def reset_logging():
    """Reset handlers on the root ea_chatbot logger."""
    root = logging.getLogger("ea_chatbot")
    for handler in root.handlers[:]:
        root.removeHandler(handler)
    yield
    for handler in root.handlers[:]:
        root.removeHandler(handler)


class FakeStructuredModel(FakeListChatModel):
    def with_structured_output(self, schema, **kwargs):
        # Return a runnable that yields a parsed object
        def _invoke(input, config=None, **kwargs):
            data = json.loads(self.responses[0])
            if hasattr(schema, "model_validate"):
                return schema.model_validate(data)
            return data

        return RunnableLambda(_invoke)


def test_logging_e2e_json_output(tmp_path):
    """Test that a full graph run produces structured JSON logs from multiple nodes."""
    log_file = tmp_path / "e2e_test.jsonl"

    # Configure the root logger
    get_logger("ea_chatbot", log_file=str(log_file))

    initial_state: AgentState = {
        "messages": [],
        "question": "Who won in 2024?",
        "analysis": None,
        "next_action": "",
        "plan": None,
        "code": None,
        "code_output": None,
        "error": None,
        "plots": [],
        "dfs": {}
    }

    # Create fake models that support callbacks and structured output
    fake_analyzer_response = """{"data_required": [], "unknowns": [], "ambiguities": ["Which year?"], "conditions": [], "next_action": "clarify"}"""
    fake_analyzer = FakeStructuredModel(responses=[fake_analyzer_response])

    fake_clarify = FakeListChatModel(responses=["Please specify."])

    with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_llm_factory:
        mock_llm_factory.return_value = fake_analyzer

        with patch("ea_chatbot.graph.nodes.clarification.get_llm_model") as mock_clarify_llm_factory:
            mock_clarify_llm_factory.return_value = fake_clarify

            # Run the graph
            list(app.stream(initial_state))

    # Verify file content
    assert log_file.exists()
    lines = log_file.read_text().splitlines()
    assert len(lines) > 0

    # Verify we have logs from different nodes
    node_names = [json.loads(line)["name"] for line in lines]
    assert "ea_chatbot.query_analyzer" in node_names
    assert "ea_chatbot.clarification" in node_names

    # Verify events
    messages = [json.loads(line)["message"] for line in lines]
    assert any("Analyzing question" in m for m in messages)
    assert any("Clarification generated" in m for m in messages)
64
tests/test_logging_langchain.py
Normal file
@@ -0,0 +1,64 @@
import logging
import pytest
import io
from unittest.mock import MagicMock
from ea_chatbot.utils.logging import LangChainLoggingHandler


@pytest.fixture
def log_capture():
    """Fixture to capture logs from a logger."""
    log_stream = io.StringIO()
    logger = logging.getLogger("test_langchain")
    logger.setLevel(logging.INFO)
    # Remove existing handlers
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)

    handler = logging.StreamHandler(log_stream)
    logger.addHandler(handler)
    return logger, log_stream


def test_langchain_logging_handler_on_llm_start(log_capture):
    """Test that on_llm_start logs the correct message."""
    logger, log_stream = log_capture
    handler = LangChainLoggingHandler(logger=logger)
    handler.on_llm_start(serialized={"name": "test_model"}, prompts=["test prompt"])

    output = log_stream.getvalue()
    assert "LLM Started:" in output
    assert "test_model" in output


def test_langchain_logging_handler_on_llm_end(log_capture):
    """Test that on_llm_end logs token usage."""
    logger, log_stream = log_capture
    handler = LangChainLoggingHandler(logger=logger)
    response = MagicMock()
    response.llm_output = {
        "token_usage": {
            "prompt_tokens": 10,
            "completion_tokens": 20,
            "total_tokens": 30
        },
        "model_name": "test_model"
    }

    handler.on_llm_end(response=response)

    output = log_stream.getvalue()
    assert "LLM Ended:" in output
    assert "test_model" in output
    assert "Tokens: 30" in output
    assert "10 prompt" in output
    assert "20 completion" in output


def test_langchain_logging_handler_on_llm_error(log_capture):
    """Test that on_llm_error logs the error."""
    logger, log_stream = log_capture
    handler = LangChainLoggingHandler(logger=logger)
    error = Exception("test error")

    handler.on_llm_error(error=error)

    output = log_stream.getvalue()
    assert "LLM Error:" in output
    assert "test error" in output
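
# For reference, the callback surface these tests assume: a handler that
# subclasses LangChain's BaseCallbackHandler and logs start/end/error events.
# A sketch under that assumption only; the real handler lives in
# ea_chatbot.utils.logging and may format its messages differently.
from langchain_core.callbacks import BaseCallbackHandler


class _SketchLangChainHandler(BaseCallbackHandler):
    def __init__(self, logger):
        self.logger = logger

    def on_llm_start(self, serialized, prompts, **kwargs):
        self.logger.info("LLM Started: %s", (serialized or {}).get("name"))

    def on_llm_end(self, response, **kwargs):
        # llm_output carries provider metadata such as token usage.
        output = response.llm_output or {}
        usage = output.get("token_usage", {})
        self.logger.info(
            "LLM Ended: %s | Tokens: %s (%s prompt, %s completion)",
            output.get("model_name"),
            usage.get("total_tokens"),
            usage.get("prompt_tokens"),
            usage.get("completion_tokens"),
        )

    def on_llm_error(self, error, **kwargs):
        self.logger.error("LLM Error: %s", error)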
79
tests/test_multi_turn_planner_researcher.py
Normal file
@@ -0,0 +1,79 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.nodes.planner import planner_node
from ea_chatbot.graph.nodes.researcher import researcher_node
from ea_chatbot.graph.nodes.summarizer import summarizer_node
from ea_chatbot.schemas import TaskPlanResponse


@pytest.fixture
def mock_state_with_history():
    return {
        "messages": [
            HumanMessage(content="Show me the 2024 results for Florida"),
            AIMessage(content="Here are the results for Florida in 2024...")
        ],
        "question": "What about in New Jersey?",
        "analysis": {"data_required": ["2024 results", "New Jersey"], "unknowns": [], "ambiguities": [], "conditions": []},
        "next_action": "plan",
        "summary": "The user is asking about 2024 election results.",
        "plan": "Plan steps...",
        "code_output": "Code output..."
    }


@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
@patch("ea_chatbot.graph.nodes.planner.PLANNER_PROMPT")
def test_planner_uses_history_and_summary(mock_prompt, mock_get_summary, mock_get_llm, mock_state_with_history):
    mock_get_summary.return_value = "Data summary"
    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance
    mock_structured_llm = MagicMock()
    mock_llm_instance.with_structured_output.return_value = mock_structured_llm

    mock_structured_llm.invoke.return_value = TaskPlanResponse(
        goal="goal",
        reflection="reflection",
        context={
            "initial_context": "context",
            "assumptions": [],
            "constraints": []
        },
        steps=["Step 1: test"]
    )

    planner_node(mock_state_with_history)

    mock_prompt.format_messages.assert_called_once()
    kwargs = mock_prompt.format_messages.call_args[1]
    assert kwargs["question"] == "What about in New Jersey?"
    assert kwargs["summary"] == mock_state_with_history["summary"]
    assert len(kwargs["history"]) == 2


@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
@patch("ea_chatbot.graph.nodes.researcher.RESEARCHER_PROMPT")
def test_researcher_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance

    researcher_node(mock_state_with_history)

    mock_prompt.format_messages.assert_called_once()
    kwargs = mock_prompt.format_messages.call_args[1]
    assert kwargs["question"] == "What about in New Jersey?"
    assert kwargs["summary"] == mock_state_with_history["summary"]
    assert len(kwargs["history"]) == 2


@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model")
@patch("ea_chatbot.graph.nodes.summarizer.SUMMARIZER_PROMPT")
def test_summarizer_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance

    summarizer_node(mock_state_with_history)

    mock_prompt.format_messages.assert_called_once()
    kwargs = mock_prompt.format_messages.call_args[1]
    assert kwargs["question"] == "What about in New Jersey?"
    assert kwargs["summary"] == mock_state_with_history["summary"]
    assert len(kwargs["history"]) == 2
76
tests/test_multi_turn_query_analyzer.py
Normal file
@@ -0,0 +1,76 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
from ea_chatbot.graph.state import AgentState


@pytest.fixture
def mock_state_with_history():
    return {
        "messages": [
            HumanMessage(content="Show me the 2024 results for Florida"),
            AIMessage(content="Here are the results for Florida in 2024...")
        ],
        "question": "What about in New Jersey?",
        "analysis": None,
        "next_action": "",
        "summary": "The user is asking about 2024 election results."
    }


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.query_analyzer.QUERY_ANALYZER_PROMPT")
def test_query_analyzer_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
    """Test that query_analyzer_node passes history and summary to the prompt."""
    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance
    mock_structured_llm = MagicMock()
    mock_llm_instance.with_structured_output.return_value = mock_structured_llm

    mock_structured_llm.invoke.return_value = QueryAnalysis(
        data_required=["2024 results", "New Jersey"],
        unknowns=[],
        ambiguities=[],
        conditions=[],
        next_action="plan"
    )

    query_analyzer_node(mock_state_with_history)

    # Verify that the prompt was formatted with the correct variables
    mock_prompt.format_messages.assert_called_once()
    kwargs = mock_prompt.format_messages.call_args[1]

    assert kwargs["question"] == "What about in New Jersey?"
    assert "summary" in kwargs
    assert kwargs["summary"] == mock_state_with_history["summary"]
    assert "history" in kwargs
    # History should contain the messages from the state
    assert len(kwargs["history"]) == 2
    assert kwargs["history"][0].content == "Show me the 2024 results for Florida"


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_query_analyzer_context_window(mock_get_llm):
    """Test that query_analyzer_node only uses the last 6 messages (3 turns)."""
    messages = [HumanMessage(content=f"Msg {i}") for i in range(10)]
    state = {
        "messages": messages,
        "question": "Latest question",
        "analysis": None,
        "next_action": "",
        "summary": "Summary"
    }

    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance
    mock_structured_llm = MagicMock()
    mock_llm_instance.with_structured_output.return_value = mock_structured_llm
    mock_structured_llm.invoke.return_value = QueryAnalysis(
        data_required=[], unknowns=[], ambiguities=[], conditions=[], next_action="plan"
    )

    with patch("ea_chatbot.graph.nodes.query_analyzer.QUERY_ANALYZER_PROMPT") as mock_prompt:
        query_analyzer_node(state)
        kwargs = mock_prompt.format_messages.call_args[1]
        # Should only have last 6 messages
        assert len(kwargs["history"]) == 6
        assert kwargs["history"][0].content == "Msg 4"
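
# The window asserted above implies node-side trimming along these lines
# (a sketch of the assumed logic; the actual slicing lives in
# query_analyzer_node):
#
#     history = state["messages"][-6:]  # keep the last 3 user/assistant turns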
87
tests/test_oidc_client.py
Normal file
@@ -0,0 +1,87 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.auth import OIDCClient


@pytest.fixture
def oidc_config():
    return {
        "client_id": "test_id",
        "client_secret": "test_secret",
        "server_metadata_url": "https://example.com/.well-known/openid-configuration",
        "redirect_uri": "http://localhost:8501"
    }


@pytest.fixture
def mock_metadata():
    return {
        "authorization_endpoint": "https://example.com/auth",
        "token_endpoint": "https://example.com/token",
        "userinfo_endpoint": "https://example.com/userinfo"
    }


def test_oidc_fetch_metadata(oidc_config, mock_metadata):
    client = OIDCClient(**oidc_config)

    with patch("requests.get") as mock_get:
        mock_response = MagicMock()
        mock_response.json.return_value = mock_metadata
        mock_get.return_value = mock_response

        metadata = client.fetch_metadata()

        assert metadata == mock_metadata
        mock_get.assert_called_once_with(oidc_config["server_metadata_url"])

        # Second call should use cache
        client.fetch_metadata()
        assert mock_get.call_count == 1


def test_oidc_get_login_url(oidc_config, mock_metadata):
    client = OIDCClient(**oidc_config)
    client.metadata = mock_metadata

    with patch.object(client.oauth_session, "create_authorization_url") as mock_create_url:
        mock_create_url.return_value = ("https://example.com/auth?state=xyz", "xyz")

        url = client.get_login_url()

        assert url == "https://example.com/auth?state=xyz"
        mock_create_url.assert_called_once_with(mock_metadata["authorization_endpoint"])


def test_oidc_get_login_url_missing_endpoint(oidc_config):
    client = OIDCClient(**oidc_config)
    client.metadata = {"some_other": "field"}

    with pytest.raises(ValueError, match="authorization_endpoint not found"):
        client.get_login_url()


def test_oidc_exchange_code_for_token(oidc_config, mock_metadata):
    client = OIDCClient(**oidc_config)
    client.metadata = mock_metadata

    with patch.object(client.oauth_session, "fetch_token") as mock_fetch_token:
        mock_fetch_token.return_value = {"access_token": "abc"}

        token = client.exchange_code_for_token("test_code")

        assert token == {"access_token": "abc"}
        mock_fetch_token.assert_called_once_with(
            mock_metadata["token_endpoint"],
            code="test_code",
            client_secret=oidc_config["client_secret"]
        )


def test_oidc_get_user_info(oidc_config, mock_metadata):
    client = OIDCClient(**oidc_config)
    client.metadata = mock_metadata

    with patch.object(client.oauth_session, "get") as mock_get:
        mock_response = MagicMock()
        mock_response.json.return_value = {"sub": "user123"}
        mock_get.return_value = mock_response

        user_info = client.get_user_info({"access_token": "abc"})

        assert user_info == {"sub": "user123"}
        assert client.oauth_session.token == {"access_token": "abc"}
        mock_get.assert_called_once_with(mock_metadata["userinfo_endpoint"])
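
# For reference, a client skeleton consistent with the mocks above. The
# sketch assumes Authlib's OAuth2Session; the actual ea_chatbot.auth.OIDCClient
# may use a different OAuth library, so treat the names as assumptions:
#
#     import requests
#     from authlib.integrations.requests_client import OAuth2Session
#
#     class OIDCClient:
#         def __init__(self, client_id, client_secret, server_metadata_url, redirect_uri):
#             self.server_metadata_url = server_metadata_url
#             self.client_secret = client_secret
#             self.metadata = None
#             self.oauth_session = OAuth2Session(
#                 client_id, client_secret, redirect_uri=redirect_uri
#             )
#
#         def fetch_metadata(self):
#             if self.metadata is None:  # cache: only one HTTP round-trip
#                 self.metadata = requests.get(self.server_metadata_url).json()
#             return self.metadata
#
#         def get_login_url(self):
#             endpoint = self.metadata.get("authorization_endpoint")
#             if not endpoint:
#                 raise ValueError("authorization_endpoint not found in metadata")
#             url, _state = self.oauth_session.create_authorization_url(endpoint)
#             return url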
46
tests/test_planner.py
Normal file
@@ -0,0 +1,46 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.planner import planner_node


@pytest.fixture
def mock_state():
    return {
        "messages": [],
        "question": "Show me results for New Jersey",
        "analysis": {
            # "requires_dataset" removed as it's no longer used
            "expert": "Data Analyst",
            "data": "NJ data",
            "unknown": "results",
            "condition": "state=NJ"
        },
        "next_action": "plan",
        "plan": None
    }


@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
def test_planner_node(mock_get_summary, mock_get_llm, mock_state):
    """Test planner node with unified prompt."""
    mock_get_summary.return_value = "Column: Name, Type: text"

    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm

    from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext
    mock_plan = TaskPlanResponse(
        goal="Get NJ results",
        reflection="The user wants NJ results",
        context=TaskPlanContext(initial_context="NJ data", assumptions=[], constraints=[]),
        steps=["Step 1: Load data", "Step 2: Filter by NJ"]
    )
    mock_llm.with_structured_output.return_value.invoke.return_value = mock_plan

    result = planner_node(mock_state)

    assert "plan" in result
    assert "Step 1: Load data" in result["plan"]
    assert "Step 2: Filter by NJ" in result["plan"]

    # Verify helper was called
    mock_get_summary.assert_called_once()
80
tests/test_query_analyzer.py
Normal file
@@ -0,0 +1,80 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
from ea_chatbot.graph.state import AgentState


@pytest.fixture
def mock_state():
    return {
        "messages": [],
        "question": "Show me the 2024 results for Florida",
        "analysis": None,
        "next_action": ""
    }


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_query_analyzer_data_analysis(mock_get_llm, mock_state):
    """Test that a clear data analysis query is routed to the planner."""
    # Mock the LLM and the structured output runnable
    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance
    mock_structured_llm = MagicMock()
    mock_llm_instance.with_structured_output.return_value = mock_structured_llm
    # Define the expected Pydantic result
    expected_analysis = QueryAnalysis(
        data_required=["2024 results", "Florida"],
        unknowns=[],
        ambiguities=[],
        conditions=[],
        next_action="plan"
    )
    # When structured_llm.invoke is called with messages, return the Pydantic object
    mock_structured_llm.invoke.return_value = expected_analysis

    new_state_update = query_analyzer_node(mock_state)

    assert new_state_update["next_action"] == "plan"
    assert "2024 results" in new_state_update["analysis"]["data_required"]


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_query_analyzer_ambiguous(mock_get_llm, mock_state):
    """Test that an ambiguous query is routed to clarification."""
    mock_state["question"] = "What happened?"

    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance
    mock_structured_llm = MagicMock()
    mock_llm_instance.with_structured_output.return_value = mock_structured_llm
    expected_analysis = QueryAnalysis(
        data_required=[],
        unknowns=["What event?"],
        ambiguities=[],
        conditions=[],
        next_action="clarify"
    )

    mock_structured_llm.invoke.return_value = expected_analysis

    new_state_update = query_analyzer_node(mock_state)

    assert new_state_update["next_action"] == "clarify"
    assert len(new_state_update["analysis"]["unknowns"]) > 0


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_query_analyzer_uses_config(mock_get_llm, mock_state, monkeypatch):
    """Test that the node uses the configured LLM settings."""
    monkeypatch.setenv("QUERY_ANALYZER_LLM__MODEL", "gpt-3.5-turbo")

    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance
    mock_structured_llm = MagicMock()
    mock_llm_instance.with_structured_output.return_value = mock_structured_llm
    mock_structured_llm.invoke.return_value = QueryAnalysis(
        data_required=[], unknowns=[], ambiguities=[], conditions=[], next_action="plan"
    )

    query_analyzer_node(mock_state)

    # Verify get_llm_model was called with the overridden config
    called_config = mock_get_llm.call_args[0][0]
    assert called_config.model == "gpt-3.5-turbo"
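
# The override asserted above implies env-based config resolution roughly
# like this (a sketch; the actual names live in the project's config module):
#
#     model = os.getenv("QUERY_ANALYZER_LLM__MODEL", default_config.model)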
45
tests/test_query_analyzer_logging.py
Normal file
@@ -0,0 +1,45 @@
import pytest
import logging
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis


@pytest.fixture
def mock_state():
    return {
        "messages": [],
        "question": "Show me the 2024 results for Florida",
        "analysis": None,
        "next_action": ""
    }


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.query_analyzer.get_logger")
def test_query_analyzer_logs_actions(mock_get_logger, mock_get_llm, mock_state):
    """Test that query_analyzer_node logs its main actions."""
    # Mock Logger
    mock_logger = MagicMock()
    mock_get_logger.return_value = mock_logger

    # Mock LLM
    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance
    mock_structured_llm = MagicMock()
    mock_llm_instance.with_structured_output.return_value = mock_structured_llm

    expected_analysis = QueryAnalysis(
        data_required=["2024 results", "Florida"],
        unknowns=[],
        ambiguities=[],
        conditions=[],
        next_action="plan"
    )
    mock_structured_llm.invoke.return_value = expected_analysis

    query_analyzer_node(mock_state)

    # Check that logger was called
    # We expect at least one log at the start and one at the end
    assert mock_logger.info.called

    # Verify specific log messages if we decide on them
    # For now, just ensuring it's called is enough for Red phase
103
tests/test_query_analyzer_refinement.py
Normal file
@@ -0,0 +1,103 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
from ea_chatbot.graph.state import AgentState


@pytest.fixture
def base_state():
    return {
        "messages": [],
        "question": "",
        "analysis": None,
        "next_action": "",
        "summary": ""
    }


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_refinement_coreference_from_history(mock_get_llm, base_state):
    """
    Test that the analyzer can resolve the year and state from history.
    The user asks "What about in NJ?" after a Florida 2024 query.
    Expected: next_action = 'plan'; the analyzer must NOT route to 'clarify'
    just because the year is missing from the latest question.
    """
    state = base_state.copy()
    state["messages"] = [
        HumanMessage(content="Show me 2024 results for Florida"),
        AIMessage(content="Here are the 2024 results for Florida...")
    ]
    state["question"] = "What about in New Jersey?"
    state["summary"] = "The user is looking for 2024 election results."

    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm
    mock_structured = MagicMock()
    mock_llm.with_structured_output.return_value = mock_structured

    # We expect the LLM to return 'plan' because it sees the context.
    # Until the prompt is made less strict, a 'clarify' response here means
    # this test fails.
    mock_structured.invoke.return_value = QueryAnalysis(
        data_required=["2024 results", "New Jersey"],
        unknowns=[],
        ambiguities=[],
        conditions=["state=NJ", "year=2024"],
        next_action="plan"
    )

    result = query_analyzer_node(state)
    assert result["next_action"] == "plan"
    assert "NJ" in str(result["analysis"]["conditions"])


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_refinement_tolerance_for_missing_format(mock_get_llm, base_state):
    """
    Test that the analyzer doesn't flag a missing output format or database name.
    The user asks "Give me a graph of turnout".
    Expected: next_action = 'plan', even if 'format' or 'db' is not in the query.
    """
    state = base_state.copy()
    state["question"] = "Give me a graph of voter turnout in 2024 for Florida"

    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm
    mock_structured = MagicMock()
    mock_llm.with_structured_output.return_value = mock_structured

    mock_structured.invoke.return_value = QueryAnalysis(
        data_required=["voter turnout", "Florida"],
        unknowns=[],
        ambiguities=[],
        conditions=["year=2024"],
        next_action="plan"
    )

    result = query_analyzer_node(state)
    assert result["next_action"] == "plan"
    # Ensure no ambiguities were added by the analyzer itself (hallucinated requirement)
    assert len(result["analysis"]["ambiguities"]) == 0


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_refinement_enforces_voter_identity_clarification(mock_get_llm, base_state):
    """
    Test that 'track the same voter' still triggers clarification.
    """
    state = base_state.copy()
    state["question"] = "Track the same voter participation in 2020 and 2024."

    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm
    mock_structured = MagicMock()
    mock_llm.with_structured_output.return_value = mock_structured

    # We WANT it to clarify here because voter identity is not defined.
    mock_structured.invoke.return_value = QueryAnalysis(
        data_required=["voter participation"],
        unknowns=[],
        ambiguities=["Please define what fields constitute 'the same voter' (e.g. ID, or Name and DOB)."],
        conditions=[],
        next_action="clarify"
    )

    result = query_analyzer_node(state)
    assert result["next_action"] == "clarify"
    ambiguities = str(result["analysis"]["ambiguities"]).lower()
    assert "identity" in ambiguities or "same voter" in ambiguities
34
tests/test_researcher.py
Normal file
@@ -0,0 +1,34 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from ea_chatbot.graph.nodes.researcher import researcher_node


@pytest.fixture
def mock_llm():
    with patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_get_llm:
        mock_llm_instance = MagicMock(spec=ChatOpenAI)
        mock_get_llm.return_value = mock_llm_instance
        yield mock_llm_instance


def test_researcher_node_success(mock_llm):
    """Test that researcher_node invokes LLM with web_search tool and returns messages."""
    state = {
        "question": "What is the capital of France?",
        "messages": []
    }

    mock_llm_with_tools = MagicMock()
    mock_llm.bind_tools.return_value = mock_llm_with_tools
    mock_llm_with_tools.invoke.return_value = AIMessage(content="The capital of France is Paris.")

    result = researcher_node(state)

    assert mock_llm.bind_tools.called
    # Check that it was called with web_search
    args, kwargs = mock_llm.bind_tools.call_args
    assert {"type": "web_search"} in args[0]

    assert mock_llm_with_tools.invoke.called
    assert "messages" in result
    assert result["messages"][0].content == "The capital of France is Paris."
62
tests/test_researcher_search_tools.py
Normal file
@@ -0,0 +1,62 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from ea_chatbot.graph.nodes.researcher import researcher_node


@pytest.fixture
def base_state():
    return {
        "question": "Who won the 2024 election?",
        "messages": [],
        "summary": ""
    }


@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_openai_search(mock_get_llm, base_state):
    """Test that OpenAI LLM binds 'web_search' tool."""
    mock_llm = MagicMock(spec=ChatOpenAI)
    mock_get_llm.return_value = mock_llm

    mock_llm_with_tools = MagicMock()
    mock_llm.bind_tools.return_value = mock_llm_with_tools
    mock_llm_with_tools.invoke.return_value = AIMessage(content="OpenAI Search Result")

    result = researcher_node(base_state)

    # Verify bind_tools called with correct OpenAI tool
    mock_llm.bind_tools.assert_called_once_with([{"type": "web_search"}])
    assert result["messages"][0].content == "OpenAI Search Result"


@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_google_search(mock_get_llm, base_state):
    """Test that Google LLM binds 'google_search' tool."""
    mock_llm = MagicMock(spec=ChatGoogleGenerativeAI)
    mock_get_llm.return_value = mock_llm

    mock_llm_with_tools = MagicMock()
    mock_llm.bind_tools.return_value = mock_llm_with_tools
    mock_llm_with_tools.invoke.return_value = AIMessage(content="Google Search Result")

    result = researcher_node(base_state)

    # Verify bind_tools called with correct Google tool
    mock_llm.bind_tools.assert_called_once_with([{"google_search": {}}])
    assert result["messages"][0].content == "Google Search Result"


@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_fallback_on_bind_error(mock_get_llm, base_state):
    """Test that researcher falls back to basic LLM if bind_tools fails."""
    mock_llm = MagicMock(spec=ChatOpenAI)
    mock_get_llm.return_value = mock_llm

    # Simulate bind_tools failing (e.g. model doesn't support it)
    mock_llm.bind_tools.side_effect = Exception("Not supported")
    mock_llm.invoke.return_value = AIMessage(content="Basic Result")

    result = researcher_node(base_state)

    # Should still succeed using the base LLM
    assert result["messages"][0].content == "Basic Result"
    mock_llm.invoke.assert_called_once()
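
# For reference, the provider dispatch these tests assume. A sketch of the
# presumed researcher_node tool selection (the real node may differ in
# details; helper name here is illustrative):
def _sketch_bind_search_tools(llm):
    """Pick the provider-native web-search tool, falling back to the bare LLM."""
    try:
        if isinstance(llm, ChatOpenAI):
            return llm.bind_tools([{"type": "web_search"}])
        if isinstance(llm, ChatGoogleGenerativeAI):
            return llm.bind_tools([{"google_search": {}}])
        return llm
    except Exception:
        # Some models do not support tool binding; degrade gracefully.
        return llm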
41
tests/test_state.py
Normal file
@@ -0,0 +1,41 @@
import pytest
from typing import get_type_hints, List
from langchain_core.messages import BaseMessage, HumanMessage
from ea_chatbot.graph.state import AgentState
import operator


def test_agent_state_structure():
    """Verify that AgentState has the required fields and types."""
    hints = get_type_hints(AgentState)

    assert "messages" in hints
    # Asserting the exact Annotated[List[BaseMessage], operator.add] type on a
    # TypedDict is complex (it requires inspecting __metadata__), so checking
    # that the key exists is a good start.

    assert "question" in hints
    assert hints["question"] == str

    # analysis should be Optional[Dict[str, Any]] or similar; the spec only
    # says "Dictionary", so just check that it exists.
    assert "analysis" in hints

    assert "next_action" in hints
    assert hints["next_action"] == str

    assert "summary" in hints
    # summary should be Optional[str] or str; we assume Optional[str] for flexibility.

    assert "plots" in hints
    assert "dfs" in hints


def test_messages_reducer_behavior():
    """Verify that the messages field allows adding lists (simulation of operator.add)."""
    # This is hard to test directly on the TypedDict definition without running
    # it in a graph context, but we can verify that the type hint implies a list.
    hints = get_type_hints(AgentState)
    # We expect messages to be Annotated[List[BaseMessage], operator.add]; the
    # annotation metadata can be inspected directly (see the sketch below).
    pass
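
# For reference, the annotation inspection mentioned above. A sketch using
# typing's include_extras flag (assumes messages is declared as
# Annotated[List[BaseMessage], operator.add] in AgentState):
def _sketch_inspect_messages_reducer():
    hints = get_type_hints(AgentState, include_extras=True)
    metadata = getattr(hints["messages"], "__metadata__", ())
    # The reducer registered on the Annotated type should be operator.add.
    assert operator.add in metadata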
48
tests/test_summarizer.py
Normal file
@@ -0,0 +1,48 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from ea_chatbot.graph.nodes.summarizer import summarizer_node
from ea_chatbot.graph.state import AgentState


@pytest.fixture
def mock_llm():
    with patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") as mock_get_llm:
        mock_llm_instance = MagicMock()
        mock_get_llm.return_value = mock_llm_instance
        yield mock_llm_instance


def test_summarizer_node_success(mock_llm):
    """Test that summarizer_node invokes LLM with correct inputs and returns messages."""
    state = {
        "question": "What is the total count?",
        "plan": "1. Run query\n2. Sum results",
        "code_output": "The total is 100",
        "messages": []
    }

    mock_llm.invoke.return_value = AIMessage(content="The final answer is 100.")

    result = summarizer_node(state)

    # Verify LLM was called
    assert mock_llm.invoke.called

    # Verify result structure
    assert "messages" in result
    assert len(result["messages"]) == 1
    assert isinstance(result["messages"][0], AIMessage)
    assert result["messages"][0].content == "The final answer is 100."


def test_summarizer_node_empty_state(mock_llm):
    """Test handling of empty or minimal state."""
    state = {
        "question": "Empty?",
        "messages": []
    }

    mock_llm.invoke.return_value = AIMessage(content="No data provided.")

    result = summarizer_node(state)

    assert "messages" in result
    assert result["messages"][0].content == "No data provided."
93
tests/test_workflow.py
Normal file
@@ -0,0 +1,93 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis
from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse

from langchain_core.messages import AIMessage


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.graph.nodes.coder.get_llm_model")
@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model")
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
@patch("ea_chatbot.graph.nodes.executor.Settings")
@patch("ea_chatbot.graph.nodes.executor.DBClient")
def test_workflow_full_flow(mock_db_client, mock_settings, mock_get_summary, mock_researcher_llm, mock_summarizer_llm, mock_coder_llm, mock_planner_llm, mock_qa_llm):
    """Test the flow from query_analyzer through planner to coder."""

    # Mock Settings for Executor
    mock_settings_instance = MagicMock()
    mock_settings_instance.db_host = "localhost"
    mock_settings_instance.db_port = 5432
    mock_settings_instance.db_user = "user"
    mock_settings_instance.db_pswd = "pass"
    mock_settings_instance.db_name = "test_db"
    mock_settings_instance.db_table = "test_table"
    mock_settings.return_value = mock_settings_instance

    # Mock DBClient
    mock_client_instance = MagicMock()
    mock_db_client.return_value = mock_client_instance

    # 1. Mock Query Analyzer
    mock_qa_instance = MagicMock()
    mock_qa_llm.return_value = mock_qa_instance
    mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
        data_required=["2024 results"],
        unknowns=[],
        ambiguities=[],
        conditions=[],
        next_action="plan"
    )

    # 2. Mock Planner
    mock_planner_instance = MagicMock()
    mock_planner_llm.return_value = mock_planner_instance
    mock_get_summary.return_value = "Data summary"
    mock_planner_instance.with_structured_output.return_value.invoke.return_value = TaskPlanResponse(
        goal="Task Goal",
        reflection="Reflection",
        context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]),
        steps=["Step 1"]
    )

    # 3. Mock Coder
    mock_coder_instance = MagicMock()
    mock_coder_llm.return_value = mock_coder_instance
    mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse(
        code="print('Hello')",
        explanation="Explanation"
    )

    # 4. Mock Summarizer
    mock_summarizer_instance = MagicMock()
    mock_summarizer_llm.return_value = mock_summarizer_instance
    mock_summarizer_instance.invoke.return_value = AIMessage(content="Summary")

    # 5. Mock Researcher (just in case)
    mock_researcher_instance = MagicMock()
    mock_researcher_llm.return_value = mock_researcher_instance

    # Initial state
    initial_state = {
        "messages": [],
        "question": "Show me the 2024 results",
        "analysis": None,
        "next_action": "",
        "plan": None,
        "code": None,
        "error": None,
        "plots": [],
        "dfs": {}
    }

    # Run the graph
    # We use recursion_limit to avoid infinite loops in placeholders if any
    result = app.invoke(initial_state, config={"recursion_limit": 10})

    assert result["next_action"] == "plan"
    assert "plan" in result and result["plan"] is not None
    assert "code" in result and "print('Hello')" in result["code"]
    assert "analysis" in result
139
tests/test_workflow_e2e.py
Normal file
@@ -0,0 +1,139 @@
import pytest
import yaml
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis
from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse


@pytest.fixture
def mock_llms():
    with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_qa_llm, \
         patch("ea_chatbot.graph.nodes.planner.get_llm_model") as mock_planner_llm, \
         patch("ea_chatbot.graph.nodes.coder.get_llm_model") as mock_coder_llm, \
         patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") as mock_summarizer_llm, \
         patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_researcher_llm, \
         patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model") as mock_summary_llm, \
         patch("ea_chatbot.utils.database_inspection.get_data_summary") as mock_get_summary:
        mock_get_summary.return_value = "Data summary"

        # Mock summary LLM to return a simple response
        mock_summary_instance = MagicMock()
        mock_summary_llm.return_value = mock_summary_instance
        mock_summary_instance.invoke.return_value = AIMessage(content="Turn summary")

        yield {
            "qa": mock_qa_llm,
            "planner": mock_planner_llm,
            "coder": mock_coder_llm,
            "summarizer": mock_summarizer_llm,
            "researcher": mock_researcher_llm,
            "summary": mock_summary_llm
        }


def test_workflow_data_analysis_flow(mock_llms):
    """Test full flow: QueryAnalyzer -> Planner -> Coder -> Executor -> Summarizer."""

    # 1. Mock Query Analyzer (routes to plan)
    mock_qa_instance = MagicMock()
    mock_llms["qa"].return_value = mock_qa_instance
    mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
        data_required=["2024 results"],
        unknowns=[],
        ambiguities=[],
        conditions=[],
        next_action="plan"
    )

    # 2. Mock Planner
    mock_planner_instance = MagicMock()
    mock_llms["planner"].return_value = mock_planner_instance
    mock_planner_instance.with_structured_output.return_value.invoke.return_value = TaskPlanResponse(
        goal="Get results",
        reflection="Reflect",
        context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]),
        steps=["Step 1"]
    )

    # 3. Mock Coder
    mock_coder_instance = MagicMock()
    mock_llms["coder"].return_value = mock_coder_instance
    mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse(
        code="print('Execution Success')",
        explanation="Explain"
    )

    # 4. Mock Summarizer
    mock_summarizer_instance = MagicMock()
    mock_llms["summarizer"].return_value = mock_summarizer_instance
    mock_summarizer_instance.invoke.return_value = AIMessage(content="Final Summary: Success")

    # Initial state
    initial_state = {
        "messages": [],
        "question": "Show me 2024 results",
        "analysis": None,
        "next_action": "",
        "plan": None,
        "code": None,
        "error": None,
        "plots": [],
        "dfs": {}
    }

    # Run the graph
    result = app.invoke(initial_state, config={"recursion_limit": 15})

    assert result["next_action"] == "plan"
    assert "Execution Success" in result["code_output"]
    assert "Final Summary: Success" in result["messages"][-1].content


def test_workflow_research_flow(mock_llms):
    """Test flow: QueryAnalyzer -> Researcher -> Summarizer."""

    # 1. Mock Query Analyzer (routes to research)
    mock_qa_instance = MagicMock()
    mock_llms["qa"].return_value = mock_qa_instance
    mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
        data_required=[],
        unknowns=[],
        ambiguities=[],
        conditions=[],
        next_action="research"
    )

    # 2. Mock Researcher
    mock_researcher_instance = MagicMock()
    mock_llms["researcher"].return_value = mock_researcher_instance
    # The researcher node uses bind_tools if the LLM is ChatOpenAI/ChatGoogleGenerativeAI.
    # Since this is a plain MagicMock, the node will fall back to the base instance.
    mock_researcher_instance.invoke.return_value = AIMessage(content="Research Results")

    # Also mock bind_tools in case we ever use spec
    mock_llm_with_tools = MagicMock()
    mock_researcher_instance.bind_tools.return_value = mock_llm_with_tools
    mock_llm_with_tools.invoke.return_value = AIMessage(content="Research Results")

    # 3. Mock Summarizer (not used in this flow, but kept for completeness)
    mock_summarizer_instance = MagicMock()
    mock_llms["summarizer"].return_value = mock_summarizer_instance
    mock_summarizer_instance.invoke.return_value = AIMessage(content="Final Summary: Research Success")

    # Initial state
    initial_state = {
        "messages": [],
        "question": "Who is the governor of Florida?",
        "analysis": None,
        "next_action": "",
        "plan": None,
        "code": None,
        "error": None,
        "plots": [],
        "dfs": {}
    }

    # Run the graph
    result = app.invoke(initial_state, config={"recursion_limit": 10})

    assert result["next_action"] == "research"
    assert "Research Results" in result["messages"][-1].content
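
# For reference, the routing these two flows exercise. A sketch of the
# assumed conditional edge in ea_chatbot.graph.workflow (node and key names
# are assumptions; the actual wiring may differ):
#
#     workflow.add_conditional_edges(
#         "query_analyzer",
#         lambda state: state["next_action"],
#         {"plan": "planner", "research": "researcher", "clarify": "clarification"},
#     )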
62
tests/test_workflow_history.py
Normal file
@@ -0,0 +1,62 @@
import pytest
from ea_chatbot.history.manager import HistoryManager
from ea_chatbot.history.models import User, Conversation, Message, Plot
from ea_chatbot.config import Settings
from sqlalchemy import delete


@pytest.fixture
def history_manager():
    settings = Settings()
    manager = HistoryManager(settings.history_db_url)
    with manager.get_session() as session:
        session.execute(delete(Plot))
        session.execute(delete(Message))
        session.execute(delete(Conversation))
        session.execute(delete(User))
    return manager


def test_full_history_workflow(history_manager):
    # 1. Create and Authenticate User
    email = "e2e@example.com"
    password = "password123"
    history_manager.create_user(email, password, "E2E User")

    user = history_manager.authenticate_user(email, password)
    assert user is not None
    assert user.display_name == "E2E User"

    # 2. Create Conversation
    conv = history_manager.create_conversation(user.id, "nj", "Test Analytics")
    assert conv.id is not None

    # 3. Add User Message
    history_manager.add_message(conv.id, "user", "How many voters in NJ?")

    # 4. Add Assistant Message with Plot
    plot_data = b"fake_png_data"
    history_manager.add_message(
        conv.id,
        "assistant",
        "There are X voters.",
        plots=[plot_data]
    )

    # 5. Retrieve and Verify History
    messages = history_manager.get_messages(conv.id)
    assert len(messages) == 2
    assert messages[0].role == "user"
    assert messages[1].role == "assistant"
    assert len(messages[1].plots) == 1
    assert messages[1].plots[0].image_data == plot_data

    # 6. Verify Conversation listing
    convs = history_manager.get_conversations(user.id, "nj")
    assert len(convs) == 1
    assert convs[0].name == "Test Analytics"

    # 7. Update summary
    history_manager.update_conversation_summary(conv.id, "Voter count analysis")

    # 8. Reload and verify summary
    updated_convs = history_manager.get_conversations(user.id, "nj")
    assert updated_convs[0].summary == "Voter count analysis"