Refactor: Move backend files to backend/ directory and split .gitignore
This commit is contained in:
56
backend/.env.example
Normal file
56
backend/.env.example
Normal file
@@ -0,0 +1,56 @@
|
||||
# API Keys
|
||||
OPENAI_API_KEY=your_openai_api_key_here
|
||||
GOOGLE_API_KEY=your_google_api_key_here
|
||||
|
||||
# App Configuration
|
||||
DATA_DIR=data
|
||||
DATA_STATE=new_jersey
|
||||
LOG_LEVEL=INFO
|
||||
DEV_MODE=false
|
||||
|
||||
# Security & JWT Configuration
|
||||
# Generate a strong random value for production, e.g. `openssl rand -hex 32`.
SECRET_KEY=change-me-in-production
|
||||
ALGORITHM=HS256
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||
|
||||
# Voter Database Configuration
|
||||
DB_HOST=localhost
|
||||
DB_PORT=5432
|
||||
DB_NAME=blockdata
|
||||
DB_USER=user
|
||||
DB_PSWD=password
|
||||
DB_TABLE=rd_gc_voters_nj
|
||||
|
||||
# Application/History Database Configuration
|
||||
HISTORY_DB_URL=postgresql://user:password@localhost:5433/ea_history
|
||||
|
||||
# OIDC Configuration (Authentik/SSO)
|
||||
OIDC_CLIENT_ID=your_client_id
|
||||
OIDC_CLIENT_SECRET=your_client_secret
|
||||
OIDC_SERVER_METADATA_URL=https://your-authentik.example.com/application/o/ea-chatbot/.well-known/openid-configuration
|
||||
# NOTE: must match the redirect URI registered with your OIDC provider.
# The backend code default (api/dependencies.py) is http://localhost:3000/auth/callback —
# confirm which value is correct for your deployment.
OIDC_REDIRECT_URI=http://localhost:8501
|
||||
|
||||
# Node Configuration Overrides (Optional)
|
||||
# Format: <NODE_NAME>_LLM__<PARAMETER>
|
||||
# Possible parameters: PROVIDER, MODEL, TEMPERATURE, MAX_TOKENS
|
||||
|
||||
# Query Analyzer
|
||||
# QUERY_ANALYZER_LLM__PROVIDER=openai
|
||||
# QUERY_ANALYZER_LLM__MODEL=gpt-5-mini
|
||||
# QUERY_ANALYZER_LLM__TEMPERATURE=0.0
|
||||
|
||||
# Planner
|
||||
# PLANNER_LLM__PROVIDER=openai
|
||||
# PLANNER_LLM__MODEL=gpt-5-mini
|
||||
|
||||
# Coder
|
||||
# CODER_LLM__PROVIDER=openai
|
||||
# CODER_LLM__MODEL=gpt-5-mini
|
||||
|
||||
# Summarizer
|
||||
# SUMMARIZER_LLM__PROVIDER=openai
|
||||
# SUMMARIZER_LLM__MODEL=gpt-5-mini
|
||||
|
||||
# Researcher
|
||||
# RESEARCHER_LLM__PROVIDER=google
|
||||
# RESEARCHER_LLM__MODEL=gemini-2.0-flash
|
||||
126
backend/.gitignore
vendored
Normal file
126
backend/.gitignore
vendored
Normal file
@@ -0,0 +1,126 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# .python-version
|
||||
|
||||
# UV
|
||||
#uv.lock
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
# Project specific
|
||||
data/
|
||||
logs/
|
||||
149
backend/alembic.ini
Normal file
149
backend/alembic.ini
Normal file
@@ -0,0 +1,149 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts.
|
||||
# this is typically a path given in POSIX (e.g. forward slashes)
|
||||
# format, relative to the token %(here)s which refers to the location of this
|
||||
# ini file
|
||||
script_location = %(here)s/alembic
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
# Or organize into date-based subdirectories (requires recursive_version_locations = true)
|
||||
# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory. for multiple paths, the path separator
|
||||
# is defined by "path_separator" below.
|
||||
prepend_sys_path = .
|
||||
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the tzdata library which can be installed by adding
|
||||
# `alembic[tz]` to the pip requirements.
|
||||
# string value is passed to ZoneInfo()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to <script_location>/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "path_separator"
|
||||
# below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
|
||||
|
||||
# path_separator; This indicates what character is used to split lists of file
|
||||
# paths, including version_locations and prepend_sys_path within configparser
|
||||
# files such as alembic.ini.
|
||||
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
|
||||
# to provide os-dependent path splitting.
|
||||
#
|
||||
# Note that in order to support legacy alembic.ini files, this default does NOT
|
||||
# take place if path_separator is not present in alembic.ini. If this
|
||||
# option is omitted entirely, fallback logic is as follows:
|
||||
#
|
||||
# 1. Parsing of the version_locations option falls back to using the legacy
|
||||
# "version_path_separator" key, which if absent then falls back to the legacy
|
||||
# behavior of splitting on spaces and/or commas.
|
||||
# 2. Parsing of the prepend_sys_path option falls back to the legacy
|
||||
# behavior of splitting on spaces, commas, or colons.
|
||||
#
|
||||
# Valid values for path_separator are:
|
||||
#
|
||||
# path_separator = :
|
||||
# path_separator = ;
|
||||
# path_separator = space
|
||||
# path_separator = newline
|
||||
#
|
||||
# Use os.pathsep. Default configuration used for new projects.
|
||||
path_separator = os
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
# database URL. This is consumed by the user-maintained env.py script only.
|
||||
# other means of configuring database URLs may be customized within the env.py
|
||||
# file.
|
||||
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
|
||||
# hooks = ruff
|
||||
# ruff.type = module
|
||||
# ruff.module = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Alternatively, use the exec runner to execute a binary found on your PATH
|
||||
# hooks = ruff
|
||||
# ruff.type = exec
|
||||
# ruff.executable = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration. This is also consumed by the user-maintained
|
||||
# env.py script only.
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARNING
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
1
backend/alembic/README
Normal file
1
backend/alembic/README
Normal file
@@ -0,0 +1 @@
|
||||
Generic single-database configuration.
|
||||
83
backend/alembic/env.py
Normal file
83
backend/alembic/env.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from logging.config import fileConfig
|
||||
import os
|
||||
import sys
|
||||
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
from alembic import context
|
||||
|
||||
# Add src to path to ensure we can import the app
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../src'))
|
||||
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.history.models import Base
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# Override sqlalchemy.url from settings
|
||||
settings = Settings()
|
||||
config.set_main_option("sqlalchemy.url", settings.history_db_url)
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode.
|
||||
|
||||
In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
28
backend/alembic/script.py.mako
Normal file
28
backend/alembic/script.py.mako
Normal file
@@ -0,0 +1,28 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Add summary to conversation
|
||||
|
||||
Revision ID: 63886baa1255
|
||||
Revises: a15fde9a62df
|
||||
Create Date: 2026-02-07 22:34:47.254569
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '63886baa1255'
|
||||
down_revision: Union[str, Sequence[str], None] = 'a15fde9a62df'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('conversations', sa.Column('summary', sa.Text(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('conversations', 'summary')
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,68 @@
|
||||
"""Initial history tables
|
||||
|
||||
Revision ID: a15fde9a62df
|
||||
Revises:
|
||||
Create Date: 2026-02-07 21:18:57.524534
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'a15fde9a62df'
|
||||
down_revision: Union[str, Sequence[str], None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('users',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('username', sa.String(), nullable=False),
|
||||
sa.Column('password_hash', sa.String(), nullable=True),
|
||||
sa.Column('display_name', sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True)
|
||||
op.create_table('conversations',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=False),
|
||||
sa.Column('data_state', sa.String(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('messages',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('conversation_id', sa.String(), nullable=False),
|
||||
sa.Column('role', sa.String(), nullable=False),
|
||||
sa.Column('content', sa.Text(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['conversation_id'], ['conversations.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('plots',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('message_id', sa.String(), nullable=False),
|
||||
sa.Column('image_data', sa.LargeBinary(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['message_id'], ['messages.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('plots')
|
||||
op.drop_table('messages')
|
||||
op.drop_table('conversations')
|
||||
op.drop_index(op.f('ix_users_username'), table_name='users')
|
||||
op.drop_table('users')
|
||||
# ### end Alembic commands ###
|
||||
47
backend/pyproject.toml
Normal file
47
backend/pyproject.toml
Normal file
@@ -0,0 +1,47 @@
|
||||
[project]
|
||||
name = "ea-chatbot"
|
||||
version = "0.1.0"
|
||||
description = "An election analytics chatbot using LangGraph."
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"langgraph",
|
||||
"langchain",
|
||||
"langchain-openai",
|
||||
"langchain-google-genai",
|
||||
"langchain-community",
|
||||
"pandas",
|
||||
"matplotlib",
|
||||
"psycopg2-binary",
|
||||
"psycopg[binary]>=3.1.0",
|
||||
"termcolor>=3.3.0",
|
||||
"sqlalchemy>=2.0.46",
|
||||
"pyyaml>=6.0.3",
|
||||
"streamlit>=1.54.0",
|
||||
"rich>=14.3.2",
|
||||
"alembic>=1.13.0",
|
||||
"authlib>=1.3.0",
|
||||
"httpx>=0.27.0",
|
||||
"argon2-cffi>=23.1.0",
|
||||
"email-validator>=2.1.0",
|
||||
"fastapi>=0.109.0",
|
||||
"uvicorn>=0.27.0",
|
||||
"python-jose[cryptography]>=3.3.0",
|
||||
"python-multipart>=0.0.9",
|
||||
"langgraph-checkpoint-postgres>=2.0.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/ea_chatbot"]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest>=9.0.2",
|
||||
"pytest-cov>=7.0.0",
|
||||
"pytest-asyncio>=0.23.0",
|
||||
"ruff>=0.9.3",
|
||||
"mypy>=1.14.1",
|
||||
]
|
||||
0
backend/src/ea_chatbot/__init__.py
Normal file
0
backend/src/ea_chatbot/__init__.py
Normal file
0
backend/src/ea_chatbot/api/__init__.py
Normal file
0
backend/src/ea_chatbot/api/__init__.py
Normal file
46
backend/src/ea_chatbot/api/dependencies.py
Normal file
46
backend/src/ea_chatbot/api/dependencies.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import os
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.history.manager import HistoryManager
|
||||
from ea_chatbot.auth import OIDCClient
|
||||
from ea_chatbot.api.utils import decode_access_token
|
||||
from ea_chatbot.history.models import User
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# Shared instances
|
||||
history_manager = HistoryManager(settings.history_db_url)
|
||||
|
||||
oidc_client = None
|
||||
if settings.oidc_client_id and settings.oidc_client_secret and settings.oidc_server_metadata_url:
|
||||
oidc_client = OIDCClient(
|
||||
client_id=settings.oidc_client_id,
|
||||
client_secret=settings.oidc_client_secret,
|
||||
server_metadata_url=settings.oidc_server_metadata_url,
|
||||
redirect_uri=os.getenv("OIDC_REDIRECT_URI", "http://localhost:3000/auth/callback")
|
||||
)
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
|
||||
|
||||
async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
|
||||
"""Dependency to get the current authenticated user from the JWT token."""
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
payload = decode_access_token(token)
|
||||
if payload is None:
|
||||
raise credentials_exception
|
||||
|
||||
user_id: str | None = payload.get("sub")
|
||||
if user_id is None:
|
||||
raise credentials_exception
|
||||
|
||||
user = history_manager.get_user_by_id(user_id)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
|
||||
return user
|
||||
34
backend/src/ea_chatbot/api/main.py
Normal file
34
backend/src/ea_chatbot/api/main.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from ea_chatbot.api.routers import auth, history, artifacts, agent
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
app = FastAPI(
|
||||
title="Election Analytics Chatbot API",
|
||||
description="Backend API for the LangGraph-based Election Analytics Chatbot",
|
||||
version="0.1.0"
|
||||
)
|
||||
|
||||
# Configure CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Adjust for production
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(auth.router)
|
||||
app.include_router(history.router)
|
||||
app.include_router(artifacts.router)
|
||||
app.include_router(agent.router)
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {"status": "ok"}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
0
backend/src/ea_chatbot/api/routers/__init__.py
Normal file
0
backend/src/ea_chatbot/api/routers/__init__.py
Normal file
162
backend/src/ea_chatbot/api/routers/agent.py
Normal file
162
backend/src/ea_chatbot/api/routers/agent.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import json
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Optional, List
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from ea_chatbot.api.dependencies import get_current_user, history_manager
|
||||
from ea_chatbot.api.utils import convert_to_json_compatible
|
||||
from ea_chatbot.graph.workflow import app
|
||||
from ea_chatbot.graph.checkpoint import get_checkpointer
|
||||
from ea_chatbot.history.models import User as UserDB, Conversation
|
||||
from ea_chatbot.api.schemas import ChatRequest
|
||||
import io
|
||||
import base64
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["agent"])
|
||||
|
||||
async def stream_agent_events(
|
||||
message: str,
|
||||
thread_id: str,
|
||||
user_id: str,
|
||||
summary: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Generator that invokes the LangGraph agent and yields SSE formatted events.
|
||||
Persists assistant responses and plots to the database.
|
||||
"""
|
||||
initial_state = {
|
||||
"messages": [],
|
||||
"question": message,
|
||||
"summary": summary,
|
||||
"analysis": None,
|
||||
"next_action": "",
|
||||
"plan": None,
|
||||
"code": None,
|
||||
"code_output": None,
|
||||
"error": None,
|
||||
"plots": [],
|
||||
"dfs": {}
|
||||
}
|
||||
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
|
||||
assistant_chunks: List[str] = []
|
||||
assistant_plots: List[bytes] = []
|
||||
final_response: str = ""
|
||||
new_summary: str = ""
|
||||
|
||||
try:
|
||||
async with get_checkpointer() as checkpointer:
|
||||
async for event in app.astream_events(
|
||||
initial_state,
|
||||
config,
|
||||
version="v2",
|
||||
checkpointer=checkpointer
|
||||
):
|
||||
kind = event.get("event")
|
||||
name = event.get("name")
|
||||
data = event.get("data", {})
|
||||
|
||||
# Standardize event for frontend
|
||||
output_event = {
|
||||
"type": kind,
|
||||
"name": name,
|
||||
"data": data
|
||||
}
|
||||
|
||||
# Buffer assistant chunks (summarizer and researcher might stream)
|
||||
if kind == "on_chat_model_stream" and name in ["summarizer", "researcher"]:
|
||||
chunk = data.get("chunk", "")
|
||||
# Use utility to safely extract text content from the chunk
|
||||
chunk_data = convert_to_json_compatible(chunk)
|
||||
if isinstance(chunk_data, dict) and "content" in chunk_data:
|
||||
assistant_chunks.append(str(chunk_data["content"]))
|
||||
else:
|
||||
assistant_chunks.append(str(chunk_data))
|
||||
|
||||
# Buffer and encode plots
|
||||
if kind == "on_chain_end" and name == "executor":
|
||||
output = data.get("output", {})
|
||||
if isinstance(output, dict) and "plots" in output:
|
||||
plots = output["plots"]
|
||||
encoded_plots = []
|
||||
for fig in plots:
|
||||
buf = io.BytesIO()
|
||||
fig.savefig(buf, format="png")
|
||||
plot_bytes = buf.getvalue()
|
||||
assistant_plots.append(plot_bytes)
|
||||
encoded_plots.append(base64.b64encode(plot_bytes).decode('utf-8'))
|
||||
output_event["data"]["encoded_plots"] = encoded_plots
|
||||
|
||||
# Collect final response from terminal nodes
|
||||
if kind == "on_chain_end" and name in ["summarizer", "researcher", "clarification"]:
|
||||
output = data.get("output", {})
|
||||
if isinstance(output, dict) and "messages" in output:
|
||||
last_msg = output["messages"][-1]
|
||||
|
||||
# Use centralized utility to extract clean text content
|
||||
# Since convert_to_json_compatible returns a dict for BaseMessage,
|
||||
# we can extract 'content' from it.
|
||||
msg_data = convert_to_json_compatible(last_msg)
|
||||
if isinstance(msg_data, dict) and "content" in msg_data:
|
||||
final_response = msg_data["content"]
|
||||
else:
|
||||
final_response = str(msg_data)
|
||||
|
||||
# Collect new summary
|
||||
if kind == "on_chain_end" and name == "summarize_conversation":
|
||||
output = data.get("output", {})
|
||||
if isinstance(output, dict) and "summary" in output:
|
||||
new_summary = output["summary"]
|
||||
|
||||
# Convert to JSON compatible format to avoid serialization errors
|
||||
compatible_output = convert_to_json_compatible(output_event)
|
||||
yield f"data: {json.dumps(compatible_output)}\n\n"
|
||||
|
||||
# If we didn't get a final_response from node output, use buffered chunks
|
||||
if not final_response and assistant_chunks:
|
||||
final_response = "".join(assistant_chunks)
|
||||
|
||||
# Save assistant message to DB
|
||||
if final_response:
|
||||
history_manager.add_message(thread_id, "assistant", final_response, plots=assistant_plots)
|
||||
|
||||
# Update summary in DB
|
||||
if new_summary:
|
||||
history_manager.update_conversation_summary(thread_id, new_summary)
|
||||
|
||||
yield "data: {\"type\": \"done\"}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Agent execution failed: {str(e)}"
|
||||
history_manager.add_message(thread_id, "assistant", error_msg)
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': error_msg})}\n\n"
|
||||
|
||||
@router.post("/stream")
|
||||
async def chat_stream(
|
||||
request: ChatRequest,
|
||||
current_user: UserDB = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Stream agent execution events via SSE.
|
||||
"""
|
||||
with history_manager.get_session() as session:
|
||||
conv = session.get(Conversation, request.thread_id)
|
||||
if not conv:
|
||||
raise HTTPException(status_code=404, detail="Conversation not found")
|
||||
if conv.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to access this conversation")
|
||||
|
||||
# Save user message immediately
|
||||
history_manager.add_message(request.thread_id, "user", request.message)
|
||||
|
||||
return StreamingResponse(
|
||||
stream_agent_events(
|
||||
request.message,
|
||||
request.thread_id,
|
||||
current_user.id,
|
||||
request.summary or ""
|
||||
),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
45
backend/src/ea_chatbot/api/routers/artifacts.py
Normal file
45
backend/src/ea_chatbot/api/routers/artifacts.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||
from ea_chatbot.api.dependencies import get_current_user, history_manager
|
||||
from ea_chatbot.history.models import Plot, Message, User as UserDB
|
||||
import io
|
||||
|
||||
router = APIRouter(prefix="/artifacts", tags=["artifacts"])
|
||||
|
||||
@router.get("/plots/{plot_id}")
|
||||
async def get_plot(
|
||||
plot_id: str,
|
||||
current_user: UserDB = Depends(get_current_user)
|
||||
):
|
||||
"""Retrieve a binary plot image (PNG)."""
|
||||
with history_manager.get_session() as session:
|
||||
plot = session.get(Plot, plot_id)
|
||||
if not plot:
|
||||
raise HTTPException(status_code=404, detail="Plot not found")
|
||||
|
||||
# Verify ownership via message -> conversation -> user
|
||||
message = session.get(Message, plot.message_id)
|
||||
if not message:
|
||||
raise HTTPException(status_code=404, detail="Associated message not found")
|
||||
|
||||
# In a real app, we should check message.conversation.user_id == current_user.id
|
||||
# For now, we'll assume the client has the ID correctly.
|
||||
# But let's do a basic check since it's "secure artifact access".
|
||||
if message.conversation.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to access this artifact")
|
||||
|
||||
return Response(content=plot.image_data, media_type="image/png")
|
||||
|
||||
@router.get("/data/{message_id}")
|
||||
async def get_message_data(
|
||||
message_id: str,
|
||||
current_user: UserDB = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Retrieve structured dataframe data associated with a message.
|
||||
Currently returns 404 as dataframes are not yet persisted in the DB.
|
||||
"""
|
||||
# TODO: Implement persistence for DataFrames in Phase 4 or a future track
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Structured data not found for this message"
|
||||
)
|
||||
86
backend/src/ea_chatbot/api/routers/auth.py
Normal file
86
backend/src/ea_chatbot/api/routers/auth.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from ea_chatbot.api.utils import create_access_token
|
||||
from ea_chatbot.api.dependencies import history_manager, oidc_client, get_current_user
|
||||
from ea_chatbot.history.models import User as UserDB
|
||||
from ea_chatbot.api.schemas import Token, UserCreate, UserResponse
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register(user_in: UserCreate):
|
||||
"""Register a new user."""
|
||||
user = history_manager.get_user(user_in.email)
|
||||
if user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="User already exists"
|
||||
)
|
||||
|
||||
user = history_manager.create_user(
|
||||
email=user_in.email,
|
||||
password=user_in.password,
|
||||
display_name=user_in.display_name
|
||||
)
|
||||
return {
|
||||
"id": str(user.id),
|
||||
"email": user.username,
|
||||
"display_name": user.display_name
|
||||
}
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
"""Login with email and password to get a JWT."""
|
||||
user = history_manager.authenticate_user(form_data.username, form_data.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect email or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
access_token = create_access_token(data={"sub": str(user.id)})
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
@router.get("/oidc/login")
|
||||
async def oidc_login():
|
||||
"""Get the OIDC authorization URL."""
|
||||
if not oidc_client:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_510_NOT_EXTENDED,
|
||||
detail="OIDC is not configured"
|
||||
)
|
||||
|
||||
url = oidc_client.get_login_url()
|
||||
return {"url": url}
|
||||
|
||||
@router.get("/oidc/callback", response_model=Token)
|
||||
async def oidc_callback(code: str):
|
||||
"""Handle the OIDC callback and issue a JWT."""
|
||||
if not oidc_client:
|
||||
raise HTTPException(status_code=status.HTTP_510_NOT_EXTENDED, detail="OIDC not configured")
|
||||
|
||||
try:
|
||||
token = oidc_client.exchange_code_for_token(code)
|
||||
user_info = oidc_client.get_user_info(token)
|
||||
email = user_info.get("email")
|
||||
name = user_info.get("name") or user_info.get("preferred_username")
|
||||
|
||||
if not email:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email not provided by OIDC")
|
||||
|
||||
user = history_manager.sync_user_from_oidc(email=email, display_name=name)
|
||||
|
||||
access_token = create_access_token(data={"sub": str(user.id)})
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"OIDC authentication failed: {str(e)}")
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_me(current_user: UserDB = Depends(get_current_user)):
|
||||
"""Get the current authenticated user's profile."""
|
||||
return {
|
||||
"id": str(current_user.id),
|
||||
"email": current_user.username,
|
||||
"display_name": current_user.display_name
|
||||
}
|
||||
99
backend/src/ea_chatbot/api/routers/history.py
Normal file
99
backend/src/ea_chatbot/api/routers/history.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Response
|
||||
from typing import List, Optional
|
||||
from ea_chatbot.api.dependencies import get_current_user, history_manager, settings
|
||||
from ea_chatbot.history.models import User as UserDB
|
||||
from ea_chatbot.api.schemas import ConversationResponse, MessageResponse, ConversationUpdate, ConversationCreate
|
||||
|
||||
router = APIRouter(prefix="/conversations", tags=["history"])
|
||||
|
||||
@router.post("", response_model=ConversationResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_conversation(
|
||||
conv_in: ConversationCreate,
|
||||
current_user: UserDB = Depends(get_current_user)
|
||||
):
|
||||
"""Create a new conversation."""
|
||||
state = conv_in.data_state or settings.data_state
|
||||
conv = history_manager.create_conversation(
|
||||
user_id=current_user.id,
|
||||
data_state=state,
|
||||
name=conv_in.name,
|
||||
summary=conv_in.summary
|
||||
)
|
||||
return {
|
||||
"id": str(conv.id),
|
||||
"name": conv.name,
|
||||
"summary": conv.summary,
|
||||
"created_at": conv.created_at,
|
||||
"data_state": conv.data_state
|
||||
}
|
||||
|
||||
@router.get("", response_model=List[ConversationResponse])
|
||||
async def list_conversations(
|
||||
current_user: UserDB = Depends(get_current_user),
|
||||
data_state: Optional[str] = None
|
||||
):
|
||||
"""List all conversations for the authenticated user."""
|
||||
# Use settings default if not provided
|
||||
state = data_state or settings.data_state
|
||||
conversations = history_manager.get_conversations(current_user.id, state)
|
||||
return [
|
||||
{
|
||||
"id": str(c.id),
|
||||
"name": c.name,
|
||||
"summary": c.summary,
|
||||
"created_at": c.created_at,
|
||||
"data_state": c.data_state
|
||||
} for c in conversations
|
||||
]
|
||||
|
||||
@router.get("/{conversation_id}/messages", response_model=List[MessageResponse])
|
||||
async def get_conversation_messages(
|
||||
conversation_id: str,
|
||||
current_user: UserDB = Depends(get_current_user)
|
||||
):
|
||||
"""Get all messages for a specific conversation."""
|
||||
# TODO: Verify that the conversation belongs to the user
|
||||
messages = history_manager.get_messages(conversation_id)
|
||||
return [
|
||||
{
|
||||
"id": str(m.id),
|
||||
"role": m.role,
|
||||
"content": m.content,
|
||||
"created_at": m.created_at
|
||||
} for m in messages
|
||||
]
|
||||
|
||||
@router.patch("/{conversation_id}", response_model=ConversationResponse)
|
||||
async def update_conversation(
|
||||
conversation_id: str,
|
||||
update: ConversationUpdate,
|
||||
current_user: UserDB = Depends(get_current_user)
|
||||
):
|
||||
"""Rename or update the summary of a conversation."""
|
||||
conv = None
|
||||
if update.name:
|
||||
conv = history_manager.rename_conversation(conversation_id, update.name)
|
||||
if update.summary:
|
||||
conv = history_manager.update_conversation_summary(conversation_id, update.summary)
|
||||
|
||||
if not conv:
|
||||
raise HTTPException(status_code=404, detail="Conversation not found")
|
||||
|
||||
return {
|
||||
"id": str(conv.id),
|
||||
"name": conv.name,
|
||||
"summary": conv.summary,
|
||||
"created_at": conv.created_at,
|
||||
"data_state": conv.data_state
|
||||
}
|
||||
|
||||
@router.delete("/{conversation_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_conversation(
|
||||
conversation_id: str,
|
||||
current_user: UserDB = Depends(get_current_user)
|
||||
):
|
||||
"""Delete a conversation."""
|
||||
success = history_manager.delete_conversation(conversation_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Conversation not found")
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
51
backend/src/ea_chatbot/api/schemas.py
Normal file
51
backend/src/ea_chatbot/api/schemas.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
# --- Auth Schemas ---
|
||||
|
||||
class Token(BaseModel):
    """Bearer token payload returned by the login/callback endpoints."""
    access_token: str  # the encoded JWT
    token_type: str    # e.g. "bearer"


class UserCreate(BaseModel):
    """Request body for local-account registration."""
    email: EmailStr
    password: str
    display_name: Optional[str] = None


class UserResponse(BaseModel):
    """Public user profile returned by the auth endpoints."""
    id: str
    email: str
    display_name: str
|
||||
|
||||
# --- History Schemas ---
|
||||
|
||||
class ConversationResponse(BaseModel):
    """A conversation as returned to the client."""
    id: str
    name: str
    summary: Optional[str] = None
    created_at: datetime
    data_state: str


class ConversationCreate(BaseModel):
    """Request body for creating a conversation."""
    name: str
    data_state: Optional[str] = None  # None -> server's configured default state
    summary: Optional[str] = None


class MessageResponse(BaseModel):
    """A single chat message within a conversation."""
    id: str
    role: str  # "user" or "assistant"
    content: str
    created_at: datetime
    # Plots are fetched separately via artifact endpoints


class ConversationUpdate(BaseModel):
    """Partial update for a conversation; omitted fields are left unchanged."""
    name: Optional[str] = None
    summary: Optional[str] = None


# --- Agent Schemas ---


class ChatRequest(BaseModel):
    """Request body for one chat turn sent to the agent."""
    message: str
    thread_id: str  # Maps to conversation_id
    summary: Optional[str] = ""  # running conversation summary, if any
|
||||
94
backend/src/ea_chatbot/api/utils.py
Normal file
94
backend/src/ea_chatbot/api/utils.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Union, Any, List, Dict
|
||||
from jose import JWTError, jwt
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.messages import BaseMessage
|
||||
from ea_chatbot.config import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Build and sign a JWT access token.

    Args:
        data: Claims to embed in the token payload.
        expires_delta: Custom lifetime; when omitted, the configured
            ``access_token_expire_minutes`` default is used.

    Returns:
        str: The signed, encoded JWT.
    """
    now = datetime.now(timezone.utc)
    lifetime = expires_delta if expires_delta else timedelta(
        minutes=settings.access_token_expire_minutes
    )

    payload = dict(data)
    payload.update({
        "exp": now + lifetime,
        "iat": now,
        "iss": "ea-chatbot-api"
    })
    return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm)
|
||||
|
||||
def decode_access_token(token: str) -> Optional[dict]:
    """
    Validate and decode a JWT access token.

    Args:
        token: The encoded JWT.

    Returns:
        Optional[dict]: The claims payload, or None when the token is
        invalid (bad signature, malformed, expired, ...).
    """
    try:
        return jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm])
    except JWTError:
        return None
|
||||
|
||||
def convert_to_json_compatible(obj: Any) -> Any:
    """Recursively convert LangChain objects, Pydantic models, and others to
    JSON compatible formats.

    Handles, in order: lists, dicts, LangChain ``BaseMessage`` instances,
    Pydantic models (v2 then v1 fallback), objects exposing ``.content``,
    and ``datetime`` values. Anything else is returned unchanged.
    """
    if isinstance(obj, list):
        return [convert_to_json_compatible(item) for item in obj]
    elif isinstance(obj, dict):
        return {k: convert_to_json_compatible(v) for k, v in obj.items()}
    elif isinstance(obj, BaseMessage):
        # Handle content that might be a list of blocks (e.g. from Gemini or OpenAI tools)
        content = obj.content
        if isinstance(content, list):
            text_parts = []
            for block in content:
                if isinstance(block, str):
                    text_parts.append(block)
                elif isinstance(block, dict):
                    if block.get("type") == "text":
                        text_parts.append(block.get("text", ""))
                # Other block types (tool calls, images, ...) are intentionally dropped.
            content = "".join(text_parts)

        # Prefer .text property if available (common in some message types)
        if hasattr(obj, "text") and isinstance(obj.text, str) and obj.text:
            content = obj.text

        return {"type": obj.type, "content": content, **convert_to_json_compatible(obj.additional_kwargs)}
    elif isinstance(obj, BaseModel):
        return convert_to_json_compatible(obj.model_dump())
    elif hasattr(obj, "model_dump"):  # For Pydantic v2 if not caught by BaseModel
        try:
            return convert_to_json_compatible(obj.model_dump())
        except Exception:
            return str(obj)
    elif hasattr(obj, "dict"):  # Fallback for Pydantic v1 or other objects
        try:
            return convert_to_json_compatible(obj.dict())
        except Exception:
            return str(obj)
    elif hasattr(obj, "content"):
        return str(obj.content)
    elif isinstance(obj, datetime):
        # Bug fix: the original also matched ``timezone`` here, but
        # ``datetime.timezone`` instances have no ``isoformat()`` and
        # would raise AttributeError.
        return obj.isoformat()
    return obj
|
||||
451
backend/src/ea_chatbot/app.py
Normal file
451
backend/src/ea_chatbot/app.py
Normal file
@@ -0,0 +1,451 @@
|
||||
import streamlit as st
|
||||
import asyncio
|
||||
import os
|
||||
import io
|
||||
from dotenv import load_dotenv
|
||||
from ea_chatbot.graph.workflow import app
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.utils.logging import get_logger
|
||||
from ea_chatbot.utils.helpers import merge_agent_state
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.history.manager import HistoryManager
|
||||
from ea_chatbot.auth import OIDCClient, AuthType, get_user_auth_type
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Initialize Config and Manager
|
||||
settings = Settings()
|
||||
history_manager = HistoryManager(settings.history_db_url)
|
||||
|
||||
# Initialize OIDC Client if configured
|
||||
oidc_client = None
|
||||
if settings.oidc_client_id and settings.oidc_client_secret and settings.oidc_server_metadata_url:
|
||||
oidc_client = OIDCClient(
|
||||
client_id=settings.oidc_client_id,
|
||||
client_secret=settings.oidc_client_secret,
|
||||
server_metadata_url=settings.oidc_server_metadata_url,
|
||||
# Redirect back to the same page
|
||||
redirect_uri=os.getenv("OIDC_REDIRECT_URI", "http://localhost:8501")
|
||||
)
|
||||
|
||||
# Initialize Logger
|
||||
logger = get_logger(level=settings.log_level, log_file="logs/app.jsonl")
|
||||
|
||||
# --- Authentication Helpers ---
|
||||
|
||||
def login_user(user):
    """Store the authenticated user and start a fresh chat session."""
    fresh_state = {
        "user": user,
        "messages": [],
        "summary": "",
        "current_conversation_id": None,
    }
    for state_key, state_value in fresh_state.items():
        st.session_state[state_key] = state_value
    st.rerun()
|
||||
|
||||
def logout_user():
    """Drop every session key (user, chat, login flow state) and rerun."""
    # Snapshot the keys first: deleting while iterating the live view is unsafe.
    for session_key in list(st.session_state.keys()):
        del st.session_state[session_key]
    st.rerun()
|
||||
|
||||
def load_conversation(conv_id):
    """Replace the on-screen chat with the stored history of *conv_id*."""
    # Convert DB models into the plain dicts the UI renders from.
    st.session_state.messages = [
        {
            "role": msg.role,
            "content": msg.content,
            "plots": [plot.image_data for plot in msg.plots],
        }
        for msg in history_manager.get_messages(conv_id)
    ]
    st.session_state.current_conversation_id = conv_id

    # Pull the stored summary so the agent keeps its running context.
    with history_manager.get_session() as session:
        from ea_chatbot.history.models import Conversation
        conv = session.get(Conversation, conv_id)
        st.session_state.summary = conv.summary if conv else ""
    st.rerun()
|
||||
|
||||
def main():
    """Streamlit entry point: OIDC callback handling, multi-step login UI,
    conversation sidebar, and the chat loop driving the LangGraph agent.

    NOTE(review): indentation below was reconstructed from a garbled diff
    render — confirm block nesting against the original file.
    """
    st.set_page_config(
        page_title="Election Analytics Chatbot",
        page_icon="🗳️",
        layout="wide"
    )

    # Check for OIDC Callback: the IdP redirects back with ?code=...
    if "code" in st.query_params and oidc_client:
        code = st.query_params["code"]
        try:
            token = oidc_client.exchange_code_for_token(code)
            user_info = oidc_client.get_user_info(token)
            email = user_info.get("email")
            name = user_info.get("name") or user_info.get("preferred_username")

            if email:
                user = history_manager.sync_user_from_oidc(email=email, display_name=name)
                # Clear query params so a rerun does not replay the code exchange
                st.query_params.clear()
                login_user(user)
        except Exception as e:
            st.error(f"OIDC Login failed: {str(e)}")

    # Display Login Screen if not authenticated
    if "user" not in st.session_state:
        st.title("🗳️ Election Analytics Chatbot")

        # Initialize Login State (email-first, multi-step flow)
        if "login_step" not in st.session_state:
            st.session_state.login_step = "email"
        if "login_email" not in st.session_state:
            st.session_state.login_email = ""

        col1, col2 = st.columns([1, 1])

        with col1:
            st.header("Login")

            # Step 1: Identification — route to password/SSO/registration
            if st.session_state.login_step == "email":
                st.write("Please enter your email to begin:")
                with st.form("email_form"):
                    email_input = st.text_input("Email", value=st.session_state.login_email)
                    submitted = st.form_submit_button("Next")

                    if submitted:
                        if not email_input.strip():
                            st.error("Email cannot be empty.")
                        else:
                            st.session_state.login_email = email_input.strip()
                            auth_type = get_user_auth_type(st.session_state.login_email, history_manager)

                            if auth_type == AuthType.LOCAL:
                                st.session_state.login_step = "login_password"
                            elif auth_type == AuthType.OIDC:
                                st.session_state.login_step = "oidc_login"
                            else:  # AuthType.NEW
                                st.session_state.login_step = "register_details"
                            st.rerun()

            # Step 2a: Local Login
            elif st.session_state.login_step == "login_password":
                st.info(f"Welcome back, **{st.session_state.login_email}**!")
                with st.form("password_form"):
                    password = st.text_input("Password", type="password")

                    col_login, col_back = st.columns([1, 1])
                    submitted = col_login.form_submit_button("Login")
                    back = col_back.form_submit_button("Back")

                    if back:
                        st.session_state.login_step = "email"
                        st.rerun()

                    if submitted:
                        user = history_manager.authenticate_user(st.session_state.login_email, password)
                        if user:
                            login_user(user)
                        else:
                            st.error("Invalid email or password")

            # Step 2b: Registration
            elif st.session_state.login_step == "register_details":
                st.info(f"Create an account for **{st.session_state.login_email}**")
                with st.form("register_form"):
                    reg_name = st.text_input("Display Name")
                    reg_password = st.text_input("Password", type="password")

                    col_reg, col_back = st.columns([1, 1])
                    submitted = col_reg.form_submit_button("Register & Login")
                    back = col_back.form_submit_button("Back")

                    if back:
                        st.session_state.login_step = "email"
                        st.rerun()

                    if submitted:
                        if not reg_password:
                            st.error("Password is required for registration.")
                        else:
                            user = history_manager.create_user(st.session_state.login_email, reg_password, reg_name)
                            st.success("Registered! Logging in...")
                            login_user(user)

            # Step 2c: OIDC Redirection
            elif st.session_state.login_step == "oidc_login":
                st.info(f"**{st.session_state.login_email}** is configured for Single Sign-On (SSO).")

                col_sso, col_back = st.columns([1, 1])

                with col_sso:
                    if oidc_client:
                        login_url = oidc_client.get_login_url()
                        st.link_button("Login with SSO", login_url, type="primary", use_container_width=True)
                    else:
                        st.error("OIDC is not configured.")

                with col_back:
                    if st.button("Back", use_container_width=True):
                        st.session_state.login_step = "email"
                        st.rerun()

        with col2:
            if oidc_client:
                st.header("Single Sign-On")
                st.write("Login with your organizational account.")
                # NOTE(review): a link_button nested under a button requires two
                # clicks and "YXXU" looks like placeholder text — confirm intent.
                if st.button("Login with SSO"):
                    login_url = oidc_client.get_login_url()
                    st.link_button("Go to **YXXU**", login_url, type="primary")
            else:
                st.info("SSO is not configured.")

        # Halt rendering here; the authenticated UI below never runs.
        st.stop()

    # --- Main App (Authenticated) ---

    user = st.session_state.user

    # Sidebar configuration
    with st.sidebar:
        st.title(f"Hi, {user.display_name or user.username}!")

        if st.button("Logout"):
            logout_user()

        st.divider()

        st.header("History")
        if st.button("➕ New Chat", use_container_width=True):
            st.session_state.messages = []
            st.session_state.summary = ""
            st.session_state.current_conversation_id = None
            st.rerun()

        # List conversations for the current user and data state
        conversations = history_manager.get_conversations(user.id, settings.data_state)

        for conv in conversations:
            # open / rename / delete controls per conversation row
            col_c, col_r, col_d = st.columns([0.7, 0.15, 0.15])

            is_current = st.session_state.get("current_conversation_id") == conv.id
            label = f"💬 {conv.name}" if not is_current else f"👉 {conv.name}"

            if col_c.button(label, key=f"conv_{conv.id}", use_container_width=True):
                load_conversation(conv.id)

            if col_r.button("✏️", key=f"ren_{conv.id}"):
                st.session_state.renaming_id = conv.id

            if col_d.button("🗑️", key=f"del_{conv.id}"):
                if history_manager.delete_conversation(conv.id):
                    # Deleting the open conversation also clears the chat pane.
                    if is_current:
                        st.session_state.current_conversation_id = None
                        st.session_state.messages = []
                    st.rerun()

        # Rename dialog
        if st.session_state.get("renaming_id"):
            rid = st.session_state.renaming_id
            with st.form("rename_form"):
                new_name = st.text_input("New Name")
                if st.form_submit_button("Save"):
                    history_manager.rename_conversation(rid, new_name)
                    del st.session_state.renaming_id
                    st.rerun()
                if st.form_submit_button("Cancel"):
                    del st.session_state.renaming_id
                    st.rerun()

        st.divider()
        st.header("Settings")
        # Check for DEV_MODE env var (defaults to False)
        default_dev_mode = os.getenv("DEV_MODE", "false").lower() == "true"
        dev_mode = st.checkbox("Dev Mode", value=default_dev_mode, help="Enable to see code generation and raw reasoning steps.")

    st.title("🗳️ Election Analytics Chatbot")

    # Initialize chat history state
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "summary" not in st.session_state:
        st.session_state.summary = ""

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            if message.get("plan") and dev_mode:
                with st.expander("Reasoning Plan"):
                    st.code(message["plan"], language="yaml")
            if message.get("code") and dev_mode:
                with st.expander("Generated Code"):
                    st.code(message["code"], language="python")

            st.markdown(message["content"])
            if message.get("plots"):
                for plot_data in message["plots"]:
                    # If plot_data is bytes, convert to image
                    if isinstance(plot_data, bytes):
                        st.image(plot_data)
                    else:
                        # Fallback for old session state or non-binary
                        st.pyplot(plot_data)
            if message.get("dfs"):
                for df_name, df in message["dfs"].items():
                    st.subheader(f"Data: {df_name}")
                    st.dataframe(df)

    # Accept user input
    if prompt := st.chat_input("Ask a question about election data..."):
        # Ensure we have a conversation ID
        if not st.session_state.get("current_conversation_id"):
            # Auto-create conversation, named after the truncated prompt
            conv_name = (prompt[:30] + '...') if len(prompt) > 30 else prompt
            conv = history_manager.create_conversation(user.id, settings.data_state, conv_name)
            st.session_state.current_conversation_id = conv.id

        conv_id = st.session_state.current_conversation_id

        # Save user message to DB
        history_manager.add_message(conv_id, "user", prompt)

        # Add user message to session state
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Display user message in chat message container
        with st.chat_message("user"):
            st.markdown(prompt)

        # Prepare graph input
        initial_state: AgentState = {
            "messages": st.session_state.messages[:-1],  # Pass history (excluding the current prompt)
            "question": prompt,
            "summary": st.session_state.summary,
            "analysis": None,
            "next_action": "",
            "plan": None,
            "code": None,
            "code_output": None,
            "error": None,
            "plots": [],
            "dfs": {}
        }

        # Placeholder for graph output
        with st.chat_message("assistant"):
            final_state = initial_state
            # Real-time node updates
            with st.status("Thinking...", expanded=True) as status:
                try:
                    # Use app.stream to capture node transitions
                    for event in app.stream(initial_state):
                        for node_name, state_update in event.items():
                            # Keep the pre-merge error so error_corrector can show it
                            prev_error = final_state.get("error")
                            # Use helper to merge state correctly (appending messages/plots, updating dfs)
                            final_state = merge_agent_state(final_state, state_update)

                            if node_name == "query_analyzer":
                                analysis = state_update.get("analysis", {})
                                next_action = state_update.get("next_action", "unknown")
                                status.write(f"🔍 **Analyzed Query:**")
                                for k,v in analysis.items():
                                    status.write(f"- {k:<8}: {v}")
                                status.markdown(f"Next Step: {next_action.capitalize()}")

                            elif node_name == "planner":
                                status.write("📋 **Plan Generated**")
                                # Render artifacts
                                if state_update.get("plan") and dev_mode:
                                    with st.expander("Reasoning Plan", expanded=True):
                                        st.code(state_update["plan"], language="yaml")

                            elif node_name == "researcher":
                                status.write("🌐 **Research Complete**")
                                if state_update.get("messages") and dev_mode:
                                    for msg in state_update["messages"]:
                                        # Extract content from BaseMessage or show raw string
                                        content = getattr(msg, "text", msg.content)
                                        status.markdown(content)

                            elif node_name == "coder":
                                status.write("💻 **Code Generated**")
                                if state_update.get("code") and dev_mode:
                                    with st.expander("Generated Code"):
                                        st.code(state_update["code"], language="python")

                            elif node_name == "error_corrector":
                                status.write("🛠️ **Fixing Execution Error...**")
                                if prev_error:
                                    truncated_error = prev_error.strip()
                                    if len(truncated_error) > 180:
                                        truncated_error = truncated_error[:180] + "..."
                                    status.write(f"Previous error: {truncated_error}")
                                if state_update.get("code") and dev_mode:
                                    with st.expander("Corrected Code"):
                                        st.code(state_update["code"], language="python")

                            elif node_name == "executor":
                                if state_update.get("error"):
                                    # Dev mode shows the full error; otherwise truncate
                                    if dev_mode:
                                        status.write(f"❌ **Execution Error:** {state_update.get('error')}...")
                                    else:
                                        status.write(f"❌ **Execution Error:** {state_update.get('error')[:100]}...")
                                else:
                                    status.write("✅ **Execution Successful**")
                                    if state_update.get("plots"):
                                        status.write(f"📊 Generated {len(state_update['plots'])} plot(s)")

                            elif node_name == "summarizer":
                                status.write("📝 **Summarizing Results...**")


                    status.update(label="Complete!", state="complete", expanded=False)

                except Exception as e:
                    status.update(label="Error!", state="error")
                    st.error(f"Error during graph execution: {str(e)}")

            # Extract results
            response_text: str = ""
            if final_state.get("messages"):
                # The last message is the Assistant's response
                last_msg = final_state["messages"][-1]
                response_text = getattr(last_msg, "text", str(last_msg.content))
                st.markdown(response_text)

            # Collect plot bytes for saving to DB
            plot_bytes_list = []
            if final_state.get("plots"):
                for fig in final_state["plots"]:
                    st.pyplot(fig)
                    # Convert fig to bytes
                    buf = io.BytesIO()
                    fig.savefig(buf, format="png")
                    plot_bytes_list.append(buf.getvalue())

            if final_state.get("dfs"):
                for df_name, df in final_state["dfs"].items():
                    st.subheader(f"Data: {df_name}")
                    st.dataframe(df)

            # Save assistant message to DB
            history_manager.add_message(conv_id, "assistant", response_text, plots=plot_bytes_list)

            # Update summary in DB
            new_summary = final_state.get("summary", "")
            if new_summary:
                history_manager.update_conversation_summary(conv_id, new_summary)

            # Store assistant response in session history
            st.session_state.messages.append({
                "role": "assistant",
                "content": response_text,
                "plan": final_state.get("plan"),
                "code": final_state.get("code"),
                "plots": plot_bytes_list,
                "dfs": final_state.get("dfs")
            })
            st.session_state.summary = new_summary
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
97
backend/src/ea_chatbot/auth.py
Normal file
97
backend/src/ea_chatbot/auth.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import requests
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, Optional
|
||||
from authlib.integrations.requests_client import OAuth2Session
|
||||
|
||||
class AuthType(Enum):
    """How an identified email should authenticate."""
    LOCAL = "local"
    OIDC = "oidc"
    NEW = "new"


def get_user_auth_type(email: str, history_manager: Any) -> AuthType:
    """
    Classify an email for the login flow.

    Args:
        email: The user's email address.
        history_manager: HistoryManager instance used to look up the account.

    Returns:
        AuthType: ``LOCAL`` when a password hash is stored, ``OIDC`` when the
        account exists without one (SSO-provisioned), ``NEW`` when no account
        exists yet.
    """
    record = history_manager.get_user(email)
    if not record:
        return AuthType.NEW
    # SSO-provisioned accounts carry no local password hash.
    return AuthType.LOCAL if record.password_hash else AuthType.OIDC
|
||||
|
||||
class OIDCClient:
    """
    Client for OIDC Authentication using Authlib.
    Designed to work within a Streamlit environment.
    """

    # Seconds to wait for the provider's discovery document.
    METADATA_TIMEOUT = 10

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        server_metadata_url: str,
        redirect_uri: str = "http://localhost:8501"
    ):
        self.client_id = client_id
        self.client_secret = client_secret
        self.server_metadata_url = server_metadata_url
        self.redirect_uri = redirect_uri
        self.oauth_session = OAuth2Session(
            client_id=client_id,
            client_secret=client_secret,
            redirect_uri=redirect_uri,
            scope="openid email profile"
        )
        # Lazily-populated discovery document cache.
        self.metadata: Dict[str, Any] = {}

    def fetch_metadata(self) -> Dict[str, Any]:
        """Fetch OIDC provider metadata if not already fetched.

        Raises:
            requests.HTTPError: When the metadata endpoint returns an error.
            requests.Timeout: When the endpoint does not answer in time.
        """
        if not self.metadata:
            # Bug fix: the original request had no timeout (could block the
            # UI indefinitely) and never checked the HTTP status, so an
            # error page could be cached as "metadata".
            resp = requests.get(self.server_metadata_url, timeout=self.METADATA_TIMEOUT)
            resp.raise_for_status()
            self.metadata = resp.json()
        return self.metadata

    def get_login_url(self) -> str:
        """Generate the authorization URL the user should be redirected to."""
        metadata = self.fetch_metadata()
        authorization_endpoint = metadata.get("authorization_endpoint")
        if not authorization_endpoint:
            raise ValueError("authorization_endpoint not found in OIDC metadata")

        # The CSRF ``state`` is tracked inside the OAuth2Session.
        uri, state = self.oauth_session.create_authorization_url(authorization_endpoint)
        return uri

    def exchange_code_for_token(self, code: str) -> Dict[str, Any]:
        """Exchange the authorization code for an access token."""
        metadata = self.fetch_metadata()
        token_endpoint = metadata.get("token_endpoint")
        if not token_endpoint:
            raise ValueError("token_endpoint not found in OIDC metadata")

        token = self.oauth_session.fetch_token(
            token_endpoint,
            code=code,
            client_secret=self.client_secret
        )
        return token

    def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
        """Fetch user information (email, name, ...) using the access token."""
        metadata = self.fetch_metadata()
        userinfo_endpoint = metadata.get("userinfo_endpoint")
        if not userinfo_endpoint:
            raise ValueError("userinfo_endpoint not found in OIDC metadata")

        # Set the token on the session so it's used in the request
        self.oauth_session.token = token
        resp = self.oauth_session.get(userinfo_endpoint)
        resp.raise_for_status()
        return resp.json()
|
||||
52
backend/src/ea_chatbot/config.py
Normal file
52
backend/src/ea_chatbot/config.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
class LLMConfig(BaseModel):
    """Configuration for a specific LLM node."""
    # Which backend serves this node (e.g. "openai", "google").
    provider: str = "openai"
    # Model identifier passed to the provider's client.
    model: str = "gpt-5-mini"
    # Sampling temperature; 0.0 favors deterministic node output.
    temperature: float = 0.0
    # Optional cap on completion length; None defers to the provider default.
    max_tokens: Optional[int] = None
    # Extra kwargs forwarded untouched to the provider client —
    # presumably consumed by the LLM factory; confirm against get_llm_model.
    provider_specific: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
class Settings(BaseSettings):
    """Global application settings.

    Values are read from the environment (populated from .env via
    load_dotenv at import time) using the upper-case aliases declared on
    each field; unset variables fall back to the defaults below. Per-node
    LLM settings can be overridden with nested env vars such as
    QUERY_ANALYZER_LLM__MODEL (see model_config).
    """

    # Local data directory and which state's dataset to use.
    data_dir: str = "data"
    data_state: str = "new_jersey"
    log_level: str = Field(default="INFO", alias="LOG_LEVEL")

    # Voter Database configuration
    db_host: str = Field(default="localhost", alias="DB_HOST")
    db_port: int = Field(default=5432, alias="DB_PORT")
    db_name: str = Field(default="blockdata", alias="DB_NAME")
    db_user: str = Field(default="user", alias="DB_USER")
    db_pswd: str = Field(default="password", alias="DB_PSWD")
    db_table: str = Field(default="rd_gc_voters_nj", alias="DB_TABLE")

    # Application/History Database (conversation checkpoints live here)
    history_db_url: str = Field(default="postgresql://user:password@localhost:5433/ea_history", alias="HISTORY_DB_URL")

    # JWT Configuration
    # NOTE(review): the SECRET_KEY default is a placeholder and must be
    # overridden in any real deployment.
    secret_key: str = Field(default="change-me-in-production", alias="SECRET_KEY")
    algorithm: str = Field(default="HS256", alias="ALGORITHM")
    access_token_expire_minutes: int = Field(default=30, alias="ACCESS_TOKEN_EXPIRE_MINUTES")

    # OIDC Configuration (all optional: SSO is disabled when unset)
    oidc_client_id: Optional[str] = Field(default=None, alias="OIDC_CLIENT_ID")
    oidc_client_secret: Optional[str] = Field(default=None, alias="OIDC_CLIENT_SECRET")
    oidc_server_metadata_url: Optional[str] = Field(default=None, alias="OIDC_SERVER_METADATA_URL")

    # Default configurations for each node
    query_analyzer_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
    planner_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
    coder_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
    summarizer_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
    researcher_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))

    # Allow nested env vars like QUERY_ANALYZER_LLM__MODEL
    model_config = SettingsConfigDict(env_nested_delimiter='__', env_prefix='')
|
||||
0
backend/src/ea_chatbot/graph/__init__.py
Normal file
0
backend/src/ea_chatbot/graph/__init__.py
Normal file
44
backend/src/ea_chatbot/graph/checkpoint.py
Normal file
44
backend/src/ea_chatbot/graph/checkpoint.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import contextlib
|
||||
from typing import AsyncGenerator
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from ea_chatbot.config import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
_pool = None
|
||||
|
||||
def get_pool() -> AsyncConnectionPool:
    """Return the process-wide async connection pool, creating it lazily."""
    global _pool
    if _pool is not None:
        return _pool

    _pool = AsyncConnectionPool(
        conninfo=settings.history_db_url,
        max_size=20,
        # autocommit + no prepared statements: required by the LangGraph
        # Postgres checkpointer's connection expectations.
        kwargs={"autocommit": True, "prepare_threshold": 0},
        # Deferred open: the pool is opened on first use inside the running
        # event loop (see get_checkpointer), not at construction time.
        open=False,
    )
    return _pool
|
||||
|
||||
@contextlib.asynccontextmanager
async def get_checkpointer() -> AsyncGenerator[AsyncPostgresSaver, None]:
    """
    Context manager to get a PostgresSaver checkpointer.
    Ensures that the checkpointer is properly initialized and the connection is managed.

    Yields an AsyncPostgresSaver bound to a connection checked out from the
    shared pool; the connection is returned to the pool when the context exits.
    """
    pool = get_pool()
    # Ensure pool is open — the pool is created with open=False, so the
    # first caller running inside the event loop opens it here.
    if pool.closed:
        await pool.open()

    async with pool.connection() as conn:
        checkpointer = AsyncPostgresSaver(conn)
        # Ensure the necessary tables exist.
        # NOTE(review): setup() runs on every checkout; presumably idempotent,
        # but could be hoisted to once-per-process — confirm before changing.
        await checkpointer.setup()
        yield checkpointer
|
||||
|
||||
async def close_pool():
    """Shut down the shared connection pool if one was created and is still open."""
    global _pool
    if _pool is None or _pool.closed:
        return
    await _pool.close()
|
||||
0
backend/src/ea_chatbot/graph/nodes/__init__.py
Normal file
0
backend/src/ea_chatbot/graph/nodes/__init__.py
Normal file
45
backend/src/ea_chatbot/graph/nodes/clarification.py
Normal file
45
backend/src/ea_chatbot/graph/nodes/clarification.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from langchain_core.messages import AIMessage
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
|
||||
def clarification_node(state: AgentState) -> dict:
    """Ask the user for missing information or clarifications."""
    question = state["question"]
    ambiguities = state.get("analysis", {}).get("ambiguities", [])

    settings = Settings()
    logger = get_logger("clarification")

    logger.info(f"Generating clarification for {len(ambiguities)} ambiguities.")

    # The clarifier shares the query analyzer's model configuration.
    llm = get_llm_model(
        settings.query_analyzer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)],
    )

    system_prompt = """You are a Clarification Specialist. Your role is to identify what information is missing from a user's request to perform a data analysis or research task.
Based on the analysis of the user's question, formulate a polite and concise request for the missing information."""

    prompt = f"""Original Question: {question}
Missing/Ambiguous Information: {', '.join(ambiguities) if ambiguities else 'Unknown ambiguities'}

Please ask the user for the necessary details."""

    try:
        response = llm.invoke([("system", system_prompt), ("user", prompt)])
        logger.info("[bold green]Clarification generated.[/bold green]")
        # "end" signals the graph that this turn finishes by handing the
        # question back to the user.
        return {"messages": [response], "next_action": "end"}
    except Exception as e:
        logger.error(f"Failed to generate clarification: {str(e)}")
        raise e
|
||||
47
backend/src/ea_chatbot/graph/nodes/coder.py
Normal file
47
backend/src/ea_chatbot/graph/nodes/coder.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils import helpers, database_inspection
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT
|
||||
from ea_chatbot.schemas import CodeGenerationResponse
|
||||
|
||||
def coder_node(state: AgentState) -> dict:
    """Generate Python code based on the plan and data summary."""
    question = state["question"]

    settings = Settings()
    logger = get_logger("coder")

    logger.info("Generating Python code...")

    structured_llm = get_llm_model(
        settings.coder_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)],
    ).with_structured_output(CodeGenerationResponse)

    # Always pass the schema summary; the model decides what is relevant.
    database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."

    messages = CODE_GENERATOR_PROMPT.format_messages(
        question=question,
        plan=state.get("plan", ""),
        database_description=database_description,
        code_exec_results=state.get("code_output", "None"),
        example_code="",  # Placeholder until curated examples are wired in.
    )

    try:
        response = structured_llm.invoke(messages)
        logger.info("[bold green]Code generated.[/bold green]")
        # Reset `error` so a stale failure does not re-trigger the corrector.
        return {"code": response.parsed_code, "error": None}
    except Exception as e:
        logger.error(f"Failed to generate code: {str(e)}")
        raise e
|
||||
44
backend/src/ea_chatbot/graph/nodes/error_corrector.py
Normal file
44
backend/src/ea_chatbot/graph/nodes/error_corrector.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
from ea_chatbot.graph.prompts.coder import ERROR_CORRECTOR_PROMPT
|
||||
from ea_chatbot.schemas import CodeGenerationResponse
|
||||
|
||||
def error_corrector_node(state: AgentState) -> dict:
    """Fix the code based on the execution error."""
    code = state.get("code", "")
    error = state.get("error", "Unknown error")

    settings = Settings()
    logger = get_logger("error_corrector")

    logger.warning(f"[bold red]Execution error detected:[/bold red] {error[:100]}...")
    logger.info("Attempting to correct the code...")

    # No dedicated corrector model is configured yet; reuse the coder's LLM.
    structured_llm = get_llm_model(
        settings.coder_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)],
    ).with_structured_output(CodeGenerationResponse)

    messages = ERROR_CORRECTOR_PROMPT.format_messages(code=code, error=error)

    try:
        response = structured_llm.invoke(messages)
        logger.info("[bold green]Correction generated.[/bold green]")
        return {
            "code": response.parsed_code,
            # Clear the error so the corrected code is re-executed, and count
            # the attempt so the graph can cap correction loops.
            "error": None,
            "iterations": state.get("iterations", 0) + 1,
        }
    except Exception as e:
        logger.error(f"Failed to correct code: {str(e)}")
        raise e
|
||||
102
backend/src/ea_chatbot/graph/nodes/executor.py
Normal file
102
backend/src/ea_chatbot/graph/nodes/executor.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import io
|
||||
import sys
|
||||
import traceback
|
||||
from contextlib import redirect_stdout
|
||||
from typing import Any, Dict, List, TYPE_CHECKING
|
||||
import pandas as pd
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.utils.db_client import DBClient
|
||||
from ea_chatbot.utils.logging import get_logger
|
||||
from ea_chatbot.config import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ea_chatbot.types import DBSettings
|
||||
|
||||
def executor_node(state: AgentState) -> dict:
    """Execute the generated Python code and capture output, plots, and dataframes.

    Returns a partial state update with:
        code_output: captured stdout (truncated to 32 lines),
        error: filtered traceback string, or None on success,
        plots: matplotlib Figures the code appended to the injected `plots` list,
        dfs: DataFrames bound to variables whose names appear in the code text.
    """
    code = state.get("code")
    logger = get_logger("executor")

    if not code:
        logger.error("No code provided to executor.")
        return {"error": "No code provided to executor."}

    logger.info("Executing Python code...")
    settings = Settings()

    db_settings: "DBSettings" = {
        "host": settings.db_host,
        "port": settings.db_port,
        "user": settings.db_user,
        "pswd": settings.db_pswd,
        "db": settings.db_name,
        "table": settings.db_table
    }

    db_client = DBClient(settings=db_settings)

    # Single namespace used as BOTH globals and locals for exec().
    # The previous exec(code, {}, local_vars) split the two: imports and
    # top-level names landed in locals, while functions defined by the
    # generated code resolve names through globals — so any helper function
    # the model wrote raised NameError on `pd`, `db`, or its own imports.
    # NOTE(review): exec() of LLM-generated code is inherently unsafe; this
    # relies on upstream trust/sandboxing of the model output.
    exec_env: Dict[str, Any] = {
        'db': db_client,
        'plots': [],
        'pd': pd
    }

    stdout_buffer = io.StringIO()
    error = None
    code_output = ""
    plots = []
    dfs = {}

    try:
        with redirect_stdout(stdout_buffer):
            exec(code, exec_env)

        code_output = stdout_buffer.getvalue()

        # Limit the output length if it's too long
        if code_output.count('\n') > 32:
            code_output = '\n'.join(code_output.split('\n')[:32]) + '\n...'

        # Extract plots the executed code appended to the injected list.
        raw_plots = exec_env.get('plots', [])
        if isinstance(raw_plots, list):
            plots = [p for p in raw_plots if isinstance(p, Figure)]

        # Extract DataFrames that were likely intended for display.
        # Heuristic: a DataFrame whose variable name appears in the code
        # text is treated as a result the user should see.
        for key, value in exec_env.items():
            if key.startswith('__'):
                continue  # skip __builtins__ and friends injected by exec
            if isinstance(value, pd.DataFrame) and key in code:
                dfs[key] = value

        logger.info(f"[bold green]Execution complete.[/bold green] Captured {len(plots)} plots and {len(dfs)} dataframes.")

    except Exception as e:
        # Capture and filter the traceback to frames from the executed
        # string, so the error corrector sees the generated code's failure
        # rather than this wrapper's frames.
        exc_type, exc_value, _tb = sys.exc_info()
        full_traceback = traceback.format_exc()
        filtered_tb_lines = [line for line in full_traceback.split('\n') if '<string>' in line]
        error = '\n'.join(filtered_tb_lines)
        if error:
            error += '\n'
        error += f"{exc_type.__name__ if exc_type else 'Exception'}: {exc_value}"

        logger.error(f"Execution failed: {str(e)}")

        # Partial stdout up to the failure is still useful context.
        code_output = stdout_buffer.getvalue()

    return {
        "code_output": code_output,
        "error": error,
        "plots": plots,
        "dfs": dfs
    }
|
||||
51
backend/src/ea_chatbot/graph/nodes/planner.py
Normal file
51
backend/src/ea_chatbot/graph/nodes/planner.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import yaml
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils import helpers, database_inspection
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
from ea_chatbot.graph.prompts.planner import PLANNER_PROMPT
|
||||
from ea_chatbot.schemas import TaskPlanResponse
|
||||
|
||||
def planner_node(state: AgentState) -> dict:
    """Generate a structured plan based on the query analysis."""
    question = state["question"]
    history = state.get("messages", [])[-6:]
    summary = state.get("summary", "")

    settings = Settings()
    logger = get_logger("planner")

    logger.info("Generating task plan...")

    structured_llm = get_llm_model(
        settings.planner_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)],
    ).with_structured_output(TaskPlanResponse)

    # Always pass the data summary; the LLM decides whether it is relevant.
    database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."

    messages = PLANNER_PROMPT.format_messages(
        date=helpers.get_readable_date(),
        question=question,
        history=history,
        summary=summary,
        database_description=database_description,
        example_plan="",
    )

    try:
        response = structured_llm.invoke(messages)
        # Downstream nodes consume the plan as YAML text, not a model object.
        plan_yaml = yaml.dump(response.model_dump(), sort_keys=False)
        logger.info("[bold green]Plan generated successfully.[/bold green]")
        return {"plan": plan_yaml}
    except Exception as e:
        logger.error(f"Failed to generate plan: {str(e)}")
        raise e
|
||||
73
backend/src/ea_chatbot/graph/nodes/query_analyzer.py
Normal file
73
backend/src/ea_chatbot/graph/nodes/query_analyzer.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from typing import List, Literal
|
||||
from pydantic import BaseModel, Field
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
from ea_chatbot.graph.prompts.query_analyzer import QUERY_ANALYZER_PROMPT
|
||||
|
||||
class QueryAnalysis(BaseModel):
    """Analysis of the user's query.

    Structured-output schema returned by the query-analyzer LLM; the Field
    descriptions below are prompt material sent to the model, not just docs.
    """
    data_required: List[str] = Field(description="List of data points or entities mentioned (e.g., ['2024 results', 'Florida']).")
    unknowns: List[str] = Field(description="List of target information the user wants to know or needed for final answer (e.g., 'who won', 'total votes').")
    ambiguities: List[str] = Field(description="List of CRITICAL missing details that prevent ANY analysis. Do NOT include database names or plot types if defaults can be used.")
    conditions: List[str] = Field(description="List of any filters or constraints (e.g., ['year=2024', 'state=Florida']). Include context resolved from history.")
    # next_action drives the graph's routing edge out of the analyzer node.
    next_action: Literal["plan", "clarify", "research"] = Field(description="The next action to take. 'plan' for data analysis (even with defaults), 'research' for general knowledge, or 'clarify' ONLY for critical ambiguities.")
|
||||
|
||||
def query_analyzer_node(state: AgentState) -> dict:
    """Analyze the user's question and determine the next course of action."""
    question = state["question"]
    # Keep only the last 3 turns (6 messages) of history for context.
    history = state.get("messages", [])[-6:]
    summary = state.get("summary", "")

    settings = Settings()
    logger = get_logger("query_analyzer")

    logger.info(f"Analyzing question: [italic]\"{question}\"[/italic]")

    # Structured output plus a logging callback to trace LLM usage.
    structured_llm = get_llm_model(
        settings.query_analyzer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)],
    ).with_structured_output(QueryAnalysis)

    messages = QUERY_ANALYZER_PROMPT.format_messages(
        question=question,
        history=history,
        summary=summary,
    )

    try:
        # Some providers return a dict rather than a model instance;
        # validate to guarantee a QueryAnalysis either way.
        result = QueryAnalysis.model_validate(structured_llm.invoke(messages))

        next_action = result.next_action
        analysis_dict = result.model_dump()
        # Routing is returned separately; keep it out of the analysis payload.
        del analysis_dict["next_action"]

        logger.info(f"Analysis complete. Next action: [bold magenta]{next_action}[/bold magenta]")

    except Exception as e:
        logger.error(f"Error during query analysis: {str(e)}")
        # Fail safe: surface the failure to the user as a clarification turn.
        next_action = "clarify"
        analysis_dict = {
            "data_required": [],
            "unknowns": [],
            "ambiguities": [f"Error during analysis: {str(e)}"],
            "conditions": [],
        }

    return {
        "analysis": analysis_dict,
        "next_action": next_action,
        "iterations": 0,
    }
|
||||
|
||||
|
||||
60
backend/src/ea_chatbot/graph/nodes/researcher.py
Normal file
60
backend/src/ea_chatbot/graph/nodes/researcher.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils import helpers
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
from ea_chatbot.graph.prompts.researcher import RESEARCHER_PROMPT
|
||||
|
||||
def researcher_node(state: AgentState) -> dict:
    """Handle general research queries or web searches."""
    question = state["question"]
    history = state.get("messages", [])[-6:]
    summary = state.get("summary", "")

    settings = Settings()
    logger = get_logger("researcher")

    logger.info(f"Researching question: [italic]\"{question}\"[/italic]")

    llm = get_llm_model(
        settings.researcher_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)],
    )

    messages = RESEARCHER_PROMPT.format_messages(
        date=helpers.get_readable_date(),
        question=question,
        history=history,
        summary=summary,
    )

    # Provider-aware tool binding: attach the provider's native web-search
    # tool when one exists; other providers answer from model knowledge.
    llm_with_tools = llm
    try:
        if isinstance(llm, ChatGoogleGenerativeAI):
            # Native Google Search for Gemini
            llm_with_tools = llm.bind_tools([{"google_search": {}}])
        elif isinstance(llm, ChatOpenAI):
            # Native Web Search for OpenAI (built-in tool)
            llm_with_tools = llm.bind_tools([{"type": "web_search"}])
    except Exception as e:
        logger.warning(f"Failed to bind search tools: {str(e)}. Falling back to base LLM.")
        llm_with_tools = llm

    try:
        response = llm_with_tools.invoke(messages)
        logger.info("[bold green]Research complete.[/bold green]")
        return {"messages": [response]}
    except Exception as e:
        logger.error(f"Research failed: {str(e)}")
        raise e
|
||||
52
backend/src/ea_chatbot/graph/nodes/summarize_conversation.py
Normal file
52
backend/src/ea_chatbot/graph/nodes/summarize_conversation.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from langchain_core.messages import SystemMessage, HumanMessage
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
|
||||
def summarize_conversation_node(state: AgentState) -> dict:
    """Update the conversation summary based on the latest interaction."""
    summary = state.get("summary", "")
    messages = state.get("messages", [])

    # Nothing to fold into the summary yet.
    if not messages:
        return {}

    # Only the most recent user/assistant exchange is summarized.
    last_turn = messages[-2:]

    settings = Settings()
    logger = get_logger("summarize_conversation")

    logger.info("Updating conversation summary...")

    # Use summarizer_llm for this task as well.
    llm = get_llm_model(
        settings.summarizer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)],
    )

    if summary:
        prompt = (
            f"This is a summary of the conversation so far: {summary}\n\n"
            "Extend the summary by taking into account the new messages above."
        )
        system_content = f"Current summary: {summary}"
    else:
        prompt = "Create a summary of the conversation above."
        system_content = "You are a helpful assistant that summarizes conversations."

    summarization_messages = [
        SystemMessage(content=system_content),
        HumanMessage(content=f"Recent messages:\n{last_turn}\n\n{prompt}\n\nKeep the summary concise and focused on the key topics and data points discussed."),
    ]

    try:
        response = llm.invoke(summarization_messages)
        logger.info("[bold green]Conversation summary updated.[/bold green]")
        return {"summary": response.content}
    except Exception as e:
        # If summarization fails, keep the previous summary instead of crashing.
        logger.error(f"Failed to update summary: {str(e)}")
        return {"summary": summary}
|
||||
44
backend/src/ea_chatbot/graph/nodes/summarizer.py
Normal file
44
backend/src/ea_chatbot/graph/nodes/summarizer.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from langchain_core.messages import AIMessage
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
from ea_chatbot.graph.prompts.summarizer import SUMMARIZER_PROMPT
|
||||
|
||||
def summarizer_node(state: AgentState) -> dict:
    """Summarize the code execution results into a final answer."""
    question = state["question"]

    settings = Settings()
    logger = get_logger("summarizer")

    logger.info("Generating final summary...")

    llm = get_llm_model(
        settings.summarizer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)],
    )

    messages = SUMMARIZER_PROMPT.format_messages(
        question=question,
        plan=state.get("plan", ""),
        code_output=state.get("code_output", ""),
        history=state.get("messages", [])[-6:],
        summary=state.get("summary", ""),
    )

    try:
        response = llm.invoke(messages)
        logger.info("[bold green]Summary generated.[/bold green]")
        # The final answer message is appended to the conversation state.
        return {"messages": [response]}
    except Exception as e:
        logger.error(f"Failed to generate summary: {str(e)}")
        raise e
|
||||
10
backend/src/ea_chatbot/graph/prompts/__init__.py
Normal file
10
backend/src/ea_chatbot/graph/prompts/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from .query_analyzer import QUERY_ANALYZER_PROMPT
|
||||
from .planner import PLANNER_PROMPT
|
||||
from .coder import CODE_GENERATOR_PROMPT, ERROR_CORRECTOR_PROMPT
|
||||
|
||||
__all__ = [
|
||||
"QUERY_ANALYZER_PROMPT",
|
||||
"PLANNER_PROMPT",
|
||||
"CODE_GENERATOR_PROMPT",
|
||||
"ERROR_CORRECTOR_PROMPT",
|
||||
]
|
||||
64
backend/src/ea_chatbot/graph/prompts/coder.py
Normal file
64
backend/src/ea_chatbot/graph/prompts/coder.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
|
||||
# System message for code generation. NOTE: these strings are runtime prompt
# material — `db` and `plots` below must match the names injected by
# executor_node's exec environment.
CODE_GENERATOR_SYSTEM = """You are an AI data analyst and your job is to assist users with data analysis and coding tasks.
The user will provide a task and a plan.

**Data Access:**
- A database client is available as a variable named `db`.
- You MUST use `db.query_df(sql_query)` to execute SQL queries and retrieve data as a Pandas DataFrame.
- Do NOT assume a dataframe `df` is already loaded unless explicitly stated. You usually need to query it first.
- The database schema is described in the prompt. Use it to construct valid SQL queries.

**Plotting:**
- If you need to plot any data, use the `plots` list to store the figures.
- Example: `plots.append(fig)` or `plots.append(plt.gcf())`.
- Do not use `plt.show()` as it will render the plot and cause an error.

**Code Requirements:**
- Produce FULL, COMPLETE CODE that includes all steps and solves the task!
- Always include the import statements at the top of the code (e.g., `import pandas as pd`, `import matplotlib.pyplot as plt`).
- Always include print statements to output the results of your code.
- Use `db.query_df("SELECT ...")` to get data."""

# Human message template; placeholders are filled by
# CODE_GENERATOR_PROMPT.format_messages(...) in coder_node.
CODE_GENERATOR_USER = """TASK:
{question}

PLAN:
```yaml
{plan}
```

AVAILABLE DATA SUMMARY (Database Schema):
{database_description}

CODE EXECUTION OF THE PREVIOUS TASK RESULTED IN:
{code_exec_results}

{example_code}"""

# System message for the error-correction retry loop (error_corrector_node).
ERROR_CORRECTOR_SYSTEM = """The execution of the code resulted in an error.
Return a complete, corrected python code that incorporates the fixes for the error.

**Reminders:**
- You have access to a database client via the variable `db`.
- Use `db.query_df(sql)` to run queries.
- Use `plots.append(fig)` for plots.
- Always include imports and print statements."""

# Human message carrying the failed code and the filtered traceback.
ERROR_CORRECTOR_USER = """FAILED CODE:
```python
{code}
```

ERROR:
{error}"""

# Assembled chat templates consumed by the coder and error-corrector nodes.
CODE_GENERATOR_PROMPT = ChatPromptTemplate.from_messages([
    ("system", CODE_GENERATOR_SYSTEM),
    ("human", CODE_GENERATOR_USER),
])

ERROR_CORRECTOR_PROMPT = ChatPromptTemplate.from_messages([
    ("system", ERROR_CORRECTOR_SYSTEM),
    ("human", ERROR_CORRECTOR_USER),
])
|
||||
46
backend/src/ea_chatbot/graph/prompts/planner.py
Normal file
46
backend/src/ea_chatbot/graph/prompts/planner.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
|
||||
# System message: frames the model and pins the current date so relative
# time references ("last election") resolve correctly. Runtime prompt text —
# do not edit casually.
PLANNER_SYSTEM = """You are a Research Specialist and your job is to find answers and educate the user.
Provide factual information responding directly to the user's question. Include key details and context to ensure your response comprehensively answers their query.

Today's Date is: {date}"""

# Human message template; placeholders are filled by
# PLANNER_PROMPT.format_messages(...) in planner_node. The resulting plan is
# parsed as YAML downstream, hence the ```yaml fencing instruction.
PLANNER_USER = """Conversation Summary: {summary}

TASK:
{question}

AVAILABLE DATA SUMMARY (Use only if relevant to the task):
{database_description}

First: Evaluate whether you have all necessary and requested information to provide a solution.
Use the dataset description above to determine what data and in what format you have available to you.
You are able to search internet if the user asks for it, or you require any information that you can not derive from the given dataset or the instruction.

Second: Incorporate any additional relevant context, reasoning, or details from previous interactions or internal chain-of-thought that may impact the solution.
Ensure that all such information is fully included in your response rather than referring to previous answers indirectly.

Third: Reflect on the problem and briefly describe it, while addressing the problem goal, inputs, outputs,
rules, constraints, and other relevant details that appear in the problem description.

Fourth: Based on the preceding steps, formulate your response as an algorithm, breaking the solution in up to eight simple concise yet descriptive, clear English steps.
You MUST Include all values or instructions as described in the above task, or retrieved using internet search!
If fewer steps suffice, that's acceptable. If more are needed, please include them.
Remember to explain steps rather than write code.

This algorithm will be later converted to Python code.
If a dataframe is required, assume it is named 'df' and is already defined/populated based on the data summary.

There is a list variable called `plots` that you need to use to store any plots you generate. Do not use `plt.show()` as it will render the plot and cause an error.

Output the algorithm as a YAML string. Always enclose the YAML string within ```yaml tags.

**Note: Ensure that any necessary context from prior interactions is fully embedded in the plan. Do not use phrases like "refer to previous answer"; instead, provide complete details inline.**

{example_plan}"""

# Assembled template: system + prior conversation turns + the task message.
PLANNER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", PLANNER_SYSTEM),
    MessagesPlaceholder(variable_name="history"),
    ("human", PLANNER_USER),
])
|
||||
33
backend/src/ea_chatbot/graph/prompts/query_analyzer.py
Normal file
33
backend/src/ea_chatbot/graph/prompts/query_analyzer.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
|
||||
# System instructions for the query-analyzer node: decompose the question and
# decide whether to route to planning, web research, or clarification.
SYSTEM_PROMPT = """You are an expert election data analyst. Decompose the user's question into key elements to determine the next action.

### Context & Defaults
- **History:** Use the conversation history and summary to resolve coreferences (e.g., "those results", "that state"). Assume the current question inherits missing context (Year, State, County) from history.
- **Data Access:** You have access to voter and election databases. Proceed to planning without asking for database or table names.
- **Downstream Capabilities:** Visualizations are generated as Matplotlib figures. Proceed to planning for "graphs" or "plots" without asking for file formats or plot types.
- **Trends:** For trend requests without a specified interval, allow the Planner to use a sensible default (e.g., by election cycle).

### Instructions:
1. **Analyze:** Identify if the request is for data analysis, general facts (web research), or is critically ambiguous.
2. **Extract Entities & Conditions:**
- **Data Required:** e.g., "vote count", "demographics".
- **Conditions:** e.g., "Year=2024". Include context from history.
3. **Identify Target & Critical Ambiguities:**
- **Unknowns:** The core target question.
- **Critical Ambiguities:** ONLY list issues that PREVENT any analysis.
- Examples: No timeframe/geography in query OR history; "track the same voter" without an identity definition.
4. **Determine Action:**
- `plan`: For data analysis where defaults or history provide sufficient context.
- `research`: For general knowledge.
- `clarify`: ONLY for CRITICAL ambiguities."""

# Human-turn template; {summary} and {question} are filled in by the node.
USER_PROMPT_TEMPLATE = """Conversation Summary: {summary}

Analyze the following question: {question}"""

# Full prompt: system instructions, running chat history, templated human turn.
QUERY_ANALYZER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_PROMPT),
    MessagesPlaceholder(variable_name="history"),
    ("human", USER_PROMPT_TEMPLATE),
])
|
||||
12
backend/src/ea_chatbot/graph/prompts/researcher.py
Normal file
12
backend/src/ea_chatbot/graph/prompts/researcher.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
|
||||
# Prompt for the researcher node (general-knowledge / web-research answers).
# Template variables: {date} in the system turn, {summary} and {question} in
# the user turn; "history" is filled from the conversation messages.
RESEARCHER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", """You are a Research Specialist and your job is to find answers and educate the user.
Provide factual information responding directly to the user's question. Include key details and context to ensure your response comprehensively answers their query.

Today's Date is: {date}"""),
    MessagesPlaceholder(variable_name="history"),
    ("user", """Conversation Summary: {summary}

{question}""")
])
|
||||
27
backend/src/ea_chatbot/graph/prompts/summarizer.py
Normal file
27
backend/src/ea_chatbot/graph/prompts/summarizer.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
|
||||
# Prompt for the summarizer node: turn the plan + code-execution output into a
# self-contained final answer. Template variables: {summary}, {question},
# {plan}, {code_output}; "history" is filled from the conversation messages.
SUMMARIZER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", """You are an expert election data analyst providing a final answer to the user.
Use the provided conversation history and summary to ensure your response is contextually relevant and flows naturally from previous turns.

Conversation Summary: {summary}"""),
    MessagesPlaceholder(variable_name="history"),
    ("user", """The user presented you with the following question.
Question: {question}

To address this, you have designed an algorithm.
Algorithm: {plan}.

You have crafted a Python code based on this algorithm, and the output generated by the code's execution is as follows.
Output: {code_output}.

Please produce a comprehensive, easy-to-understand answer that:
1. Summarizes the main insights or conclusions achieved through your method's implementation. Include execution results if necessary.
2. Includes relevant findings from the code execution in a clear format (e.g., text explanation, tables, lists, bullet points).
- Avoid referencing the code or output as 'the above results' or saying 'it's in the code output.'
- Instead, present the actual key data or statistics within your explanation.
3. If the user requested specific information that does not appear in the code's output but you can provide it, include that information directly in your summary.
4. Present any data or tables that might have been generated by the code in full, since the user cannot directly see the execution output.

Your goal is to give a final answer that stands on its own without requiring the user to see the code or raw output directly.""")
])
|
||||
36
backend/src/ea_chatbot/graph/state.py
Normal file
36
backend/src/ea_chatbot/graph/state.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from typing import TypedDict, Annotated, List, Dict, Any, Optional
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain.agents import AgentState as AS
|
||||
import operator
|
||||
|
||||
class AgentState(AS):
    """Shared state flowing through the LangGraph workflow.

    Extends langchain's ``AgentState`` (imported as ``AS``). Fields wrapped in
    ``Annotated[..., operator.add]`` use list-concatenation as their reducer,
    so node updates append rather than replace; all other fields are
    overwritten wholesale by whichever node returns them.
    """
    # Conversation history; new messages are appended via the operator.add reducer.
    messages: Annotated[List[BaseMessage], operator.add]

    # Task context: the user's current question.
    question: str

    # Query Analysis (Decomposition results)
    analysis: Optional[Dict[str, Any]]
    # Expected keys: "requires_dataset", "expert", "data", "unknown", "condition"

    # Step-by-step reasoning produced by the planner node.
    plan: Optional[str]

    # Code execution context: generated code, captured stdout/result, last error.
    code: Optional[str]
    code_output: Optional[str]
    error: Optional[str]

    # Artifacts (for UI display); plots accumulate across nodes via operator.add.
    plots: Annotated[List[Any], operator.add]  # Matplotlib figures
    dfs: Dict[str, Any]  # Pandas DataFrames

    # Rolling conversation summary maintained by summarize_conversation.
    summary: Optional[str]

    # Routing hint consumed by the graph router: "clarify", "plan", "research", "end"
    next_action: str

    # Number of execution attempts (used against MAX_ITERATIONS in the workflow).
    iterations: int
|
||||
92
backend/src/ea_chatbot/graph/workflow.py
Normal file
92
backend/src/ea_chatbot/graph/workflow.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from langgraph.graph import StateGraph, END
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node
|
||||
from ea_chatbot.graph.nodes.planner import planner_node
|
||||
from ea_chatbot.graph.nodes.coder import coder_node
|
||||
from ea_chatbot.graph.nodes.error_corrector import error_corrector_node
|
||||
from ea_chatbot.graph.nodes.executor import executor_node
|
||||
from ea_chatbot.graph.nodes.summarizer import summarizer_node
|
||||
from ea_chatbot.graph.nodes.researcher import researcher_node
|
||||
from ea_chatbot.graph.nodes.clarification import clarification_node
|
||||
from ea_chatbot.graph.nodes.summarize_conversation import summarize_conversation_node
|
||||
|
||||
MAX_ITERATIONS = 3


def router(state: AgentState) -> str:
    """Map the analyzer's ``next_action`` hint to the next graph node.

    Any unrecognised or missing hint terminates the graph (END).
    """
    destinations = {
        "plan": "planner",
        "research": "researcher",
        "clarify": "clarification",
    }
    return destinations.get(state.get("next_action"), END)
|
||||
|
||||
def create_workflow():
    """Create and compile the LangGraph workflow.

    Topology: query_analyzer routes (via ``router``) to planner, researcher,
    clarification, or END. The analysis path runs planner -> coder -> executor;
    executor loops through error_corrector on failure (bounded by
    MAX_ITERATIONS) and otherwise proceeds to summarizer. Researcher and
    summarizer both feed summarize_conversation before END; clarification
    ends immediately so the user can answer.

    Returns:
        The compiled LangGraph application.
    """
    workflow = StateGraph(AgentState)

    # Add nodes
    workflow.add_node("query_analyzer", query_analyzer_node)
    workflow.add_node("planner", planner_node)
    workflow.add_node("coder", coder_node)
    workflow.add_node("error_corrector", error_corrector_node)
    workflow.add_node("researcher", researcher_node)
    workflow.add_node("clarification", clarification_node)
    workflow.add_node("executor", executor_node)
    workflow.add_node("summarizer", summarizer_node)
    workflow.add_node("summarize_conversation", summarize_conversation_node)

    # Set entry point
    workflow.set_entry_point("query_analyzer")

    # Add conditional edges from query_analyzer
    workflow.add_conditional_edges(
        "query_analyzer",
        router,
        {
            "planner": "planner",
            "researcher": "researcher",
            "clarification": "clarification",
            END: END
        }
    )

    # Linear flow for planning and coding
    workflow.add_edge("planner", "coder")
    workflow.add_edge("coder", "executor")

    # Executor routing: retry through the error corrector on failure, but give
    # up (and summarize whatever we have) once MAX_ITERATIONS is reached.
    def executor_router(state: AgentState) -> str:
        if state.get("error"):
            # Check for iteration limit to prevent infinite loops
            if state.get("iterations", 0) >= MAX_ITERATIONS:
                return "summarizer"
            return "error_corrector"
        return "summarizer"

    workflow.add_conditional_edges(
        "executor",
        executor_router,
        {
            "error_corrector": "error_corrector",
            "summarizer": "summarizer"
        }
    )

    # Corrected code goes straight back to the executor.
    workflow.add_edge("error_corrector", "executor")

    workflow.add_edge("researcher", "summarize_conversation")
    workflow.add_edge("clarification", END)
    workflow.add_edge("summarizer", "summarize_conversation")
    workflow.add_edge("summarize_conversation", END)

    # Compile the graph
    app = workflow.compile()

    return app
|
||||
|
||||
# Initialize the app
# Module-level compiled graph, built once at import time — presumably imported
# by the serving layer (verify against callers).
app = create_workflow()
|
||||
0
backend/src/ea_chatbot/history/__init__.py
Normal file
0
backend/src/ea_chatbot/history/__init__.py
Normal file
188
backend/src/ea_chatbot/history/manager.py
Normal file
188
backend/src/ea_chatbot/history/manager.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, List
|
||||
from sqlalchemy import create_engine, select, delete
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from argon2 import PasswordHasher
|
||||
from argon2.exceptions import VerifyMismatchError
|
||||
|
||||
from ea_chatbot.history.models import User, Conversation, Message, Plot
|
||||
|
||||
# Argon2 Password Hasher (module-level singleton; PasswordHasher is stateless)
ph = PasswordHasher()


class HistoryManager:
    """Manages database sessions and operations for history and user data.

    All public methods open their own session via ``get_session`` and return
    detached ORM instances; ``expire_on_commit=False`` keeps their loaded
    attributes readable after the session closes. Relationship attributes that
    callers need (e.g. ``Message.plots``) are touched inside the session to
    force loading before detachment.
    """

    def __init__(self, db_url: str):
        # db_url: SQLAlchemy database URL for the history database.
        self.engine = create_engine(db_url)
        # expire_on_commit=False is important so we can use objects after session closes
        self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine, expire_on_commit=False)

    @contextmanager
    def get_session(self):
        """Context manager for database sessions.

        Commits on clean exit, rolls back and re-raises on error, always
        closes. NOTE(review): several methods also commit explicitly inside
        the block; the extra commit here is then a harmless no-op.
        """
        session = self.SessionLocal()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    # --- User Management ---

    def get_user(self, email: str) -> Optional[User]:
        """Fetch a user by their email (username). Returns a detached User or None."""
        with self.get_session() as session:
            result = session.execute(select(User).where(User.username == email))
            return result.scalar_one_or_none()

    def get_user_by_id(self, user_id: str) -> Optional[User]:
        """Fetch a user by their ID."""
        with self.get_session() as session:
            return session.get(User, user_id)

    def create_user(self, email: str, password: Optional[str] = None, display_name: Optional[str] = None) -> User:
        """Create a new local user.

        A None password (OIDC-provisioned accounts) is stored as a NULL hash,
        which ``authenticate_user`` treats as "cannot log in locally".
        """
        hashed_password = ph.hash(password) if password else None
        user = User(
            username=email,
            password_hash=hashed_password,
            # Fall back to the mailbox part of the email for the display name.
            display_name=display_name or email.split("@")[0]
        )
        with self.get_session() as session:
            session.add(user)
            session.commit()
            session.refresh(user)
            return user

    def authenticate_user(self, email: str, password: str) -> Optional[User]:
        """Authenticate a user by email and password.

        Returns the User on success, None for unknown users, password-less
        (OIDC-only) accounts, or a wrong password.
        """
        user = self.get_user(email)
        if not user or not user.password_hash:
            return None

        try:
            ph.verify(user.password_hash, password)
            return user
        except VerifyMismatchError:
            return None

    def sync_user_from_oidc(self, email: str, display_name: Optional[str] = None) -> User:
        """
        Synchronize a user from an OIDC provider.
        If a user with the same email exists, update their display name.
        Otherwise, create a new user.
        """
        user = self.get_user(email)
        if user:
            # Update existing user if needed
            if display_name and user.display_name != display_name:
                with self.get_session() as session:
                    # Re-fetch inside this session; `user` is detached.
                    db_user = session.get(User, user.id)
                    db_user.display_name = display_name
                    session.commit()
                    session.refresh(db_user)
                    return db_user
            return user
        else:
            # Create new user (no password for OIDC users initially)
            return self.create_user(email=email, display_name=display_name)

    # --- Conversation Management ---

    def create_conversation(self, user_id: str, data_state: str, name: str, summary: Optional[str] = None) -> Conversation:
        """Create a new conversation for a user."""
        conv = Conversation(
            user_id=user_id,
            data_state=data_state,
            name=name,
            summary=summary
        )
        with self.get_session() as session:
            session.add(conv)
            session.commit()
            session.refresh(conv)
            return conv

    def get_conversations(self, user_id: str, data_state: str) -> List[Conversation]:
        """Get all conversations for a user and data state, ordered by creation time (newest first)."""
        with self.get_session() as session:
            stmt = (
                select(Conversation)
                .where(Conversation.user_id == user_id, Conversation.data_state == data_state)
                .order_by(Conversation.created_at.desc())
            )
            result = session.execute(stmt)
            return list(result.scalars().all())

    def rename_conversation(self, conversation_id: str, new_name: str) -> Optional[Conversation]:
        """Rename an existing conversation. Returns the updated Conversation, or None if not found."""
        with self.get_session() as session:
            conv = session.get(Conversation, conversation_id)
            if conv:
                conv.name = new_name
                session.commit()
                session.refresh(conv)
            return conv

    def delete_conversation(self, conversation_id: str) -> bool:
        """Delete a conversation and its associated messages/plots (via cascade).

        Returns True if a conversation was deleted, False if the id was unknown.
        """
        with self.get_session() as session:
            conv = session.get(Conversation, conversation_id)
            if conv:
                session.delete(conv)
                session.commit()
                return True
            return False

    def update_conversation_summary(self, conversation_id: str, summary: str) -> Optional[Conversation]:
        """Update the summary of a conversation. Returns None if the id is unknown."""
        with self.get_session() as session:
            conv = session.get(Conversation, conversation_id)
            if conv:
                conv.summary = summary
                session.commit()
                session.refresh(conv)
            return conv

    # --- Message & Plot Management ---

    def add_message(self, conversation_id: str, role: str, content: str, plots: Optional[List[bytes]] = None) -> Message:
        """Add a message to a conversation, optionally with plots.

        Args:
            conversation_id: Parent conversation id.
            role: Message role (e.g. "user"/"assistant" — verify against callers).
            content: Message text.
            plots: Optional list of encoded image bytes, stored as Plot rows.
        """
        msg = Message(
            conversation_id=conversation_id,
            role=role,
            content=content
        )
        with self.get_session() as session:
            session.add(msg)
            session.flush()  # Populate msg.id for plots

            if plots:
                for plot_data in plots:
                    plot = Plot(message_id=msg.id, image_data=plot_data)
                    session.add(plot)

            session.commit()
            session.refresh(msg)
            # Ensure plots are loaded before session closes if we need them
            _ = msg.plots
            return msg

    def get_messages(self, conversation_id: str) -> List[Message]:
        """Get all messages for a conversation, ordered by creation time (oldest first)."""
        with self.get_session() as session:
            stmt = (
                select(Message)
                .where(Message.conversation_id == conversation_id)
                .order_by(Message.created_at.asc())
            )
            result = session.execute(stmt)
            messages = list(result.scalars().all())
            # Pre-load plots for each message (they lazy-load, and the session is about to close)
            for m in messages:
                _ = m.plots
            return messages
|
||||
52
backend/src/ea_chatbot/history/models.py
Normal file
52
backend/src/ea_chatbot/history/models.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
from sqlalchemy import String, ForeignKey, DateTime, LargeBinary, Text
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
|
||||
class Base(DeclarativeBase):
    """Declarative base shared by all ORM models in this module."""
    pass
|
||||
|
||||
class User(Base):
    """Account record; ``username`` holds the user's email address."""
    __tablename__ = "users"

    # UUID string primary key, generated client-side.
    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    username: Mapped[str] = mapped_column(String, unique=True, index=True)
    # NULL for OIDC-provisioned accounts that have no local password.
    password_hash: Mapped[Optional[str]] = mapped_column(String, nullable=True)
    display_name: Mapped[Optional[str]] = mapped_column(String, nullable=True)

    # Deleting a user deletes their conversations (and, transitively, messages/plots).
    conversations: Mapped[List["Conversation"]] = relationship(back_populates="user", cascade="all, delete-orphan")
|
||||
|
||||
class Conversation(Base):
    """A chat thread belonging to one user, scoped to a ``data_state`` dataset."""
    __tablename__ = "conversations"

    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    user_id: Mapped[str] = mapped_column(ForeignKey("users.id"))
    # Dataset key (e.g. the configured DATA_STATE) this conversation is tied to.
    data_state: Mapped[str] = mapped_column(String)
    name: Mapped[str] = mapped_column(String)
    # Rolling LLM-maintained summary of the conversation; optional.
    summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    # Stored as UTC; column itself is timezone-naive DateTime.
    created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc))

    user: Mapped["User"] = relationship(back_populates="conversations")
    messages: Mapped[List["Message"]] = relationship(back_populates="conversation", cascade="all, delete-orphan")
|
||||
|
||||
class Message(Base):
    """A single turn in a conversation, optionally carrying rendered plots."""
    __tablename__ = "messages"

    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    conversation_id: Mapped[str] = mapped_column(ForeignKey("conversations.id"))
    # Speaker role string as supplied by the caller (e.g. "user"/"assistant").
    role: Mapped[str] = mapped_column(String)
    content: Mapped[str] = mapped_column(Text)
    # Stored as UTC; column itself is timezone-naive DateTime.
    created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc))

    conversation: Mapped["Conversation"] = relationship(back_populates="messages")
    plots: Mapped[List["Plot"]] = relationship(back_populates="message", cascade="all, delete-orphan")
|
||||
|
||||
class Plot(Base):
    """Binary image blob attached to a message (encoded plot bytes)."""
    __tablename__ = "plots"

    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    message_id: Mapped[str] = mapped_column(ForeignKey("messages.id"))
    # Raw encoded image bytes; format is whatever the caller stored (likely PNG — verify).
    image_data: Mapped[bytes] = mapped_column(LargeBinary)

    message: Mapped["Message"] = relationship(back_populates="plots")
|
||||
83
backend/src/ea_chatbot/schemas.py
Normal file
83
backend/src/ea_chatbot/schemas.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
from typing import Sequence, Optional
|
||||
import re
|
||||
|
||||
class TaskPlanContext(BaseModel):
    '''Background context relevant to the task plan'''
    initial_context: str = Field(
        min_length=1,
        description="Background information about the database/tables and previous conversations relevant to the task.",
    )
    assumptions: Sequence[str] = Field(
        description="Assumptions made while working on the task.",
    )
    # NOTE(review): Optional here only allows None as a *value*; with no
    # default the field is still required — confirm this is intended.
    constraints: Optional[Sequence[str]] = Field(
        description="Constraints that apply to the task.",
    )
|
||||
|
||||
class TaskPlanResponse(BaseModel):
    '''Structured plan to achieve the task objective.

    Structured-output schema for the planner LLM; field descriptions double
    as instructions to the model.
    '''
    goal: str = Field(
        min_length=1,
        description="Single-sentence objective the plan must achieve.",
    )
    reflection: str = Field(
        min_length=1,
        description="High-level natural-language reasoning describing the user's request and the intended solution approach.",
    )
    context: TaskPlanContext = Field(
        description="Background context relevant to the task plan.",
    )
    # min_length on a Sequence constrains the number of steps (at least one).
    steps: Sequence[str] = Field(
        min_length=1,
        description="Ordered list of steps to execute that follow the 'Step <number>: <detail>' pattern.",
    )
|
||||
|
||||
# Some chat models emit the literal "<|im_sep|>" token where a markdown code
# fence should be; it is normalised to "```" before fence extraction.
_IM_SEP_TOKEN_PATTERN = re.compile(re.escape("<|im_sep|>"))
# Captures the contents of the first ``` or ```python fenced block.
_CODE_BLOCK_PATTERN = re.compile(r"```(?:python\s*)?(.*?)\s*```", re.DOTALL)
# Words that must not appear in LLM-generated code (modules and builtins with
# filesystem/process/network/eval reach).
_FORBIDDEN_MODULES = (
    "subprocess",
    "sys",
    "eval",
    "exec",
    "socket",
    "urllib",
    "shutil",
    "pickle",
    "ctypes",
    "multiprocessing",
    "tempfile",
    "glob",
    "pty",
    "commands",
    "cgi",
    "cgitb",
    "xml.etree.ElementTree",
    "builtins",
)
# MULTILINE pattern capturing any whole line that contains a forbidden word;
# lines whose first character is '#' cannot match, so already-commented lines
# are skipped. NOTE(review): this is a crude token blocklist — it also fires
# on the words inside strings or identifiers (e.g. "system"-like names are
# safe, but a variable literally named `glob` is not), and it is not a
# substitute for real sandboxing.
_FORBIDDEN_MODULE_PATTERN = re.compile(
    r"^((?:[^#].*)?\b(" + "|".join(map(re.escape, _FORBIDDEN_MODULES)) + r")\b.*)$",
    flags=re.MULTILINE,
)
|
||||
|
||||
class CodeGenerationResponse(BaseModel):
    '''Code generation response structure'''
    code: str = Field(description="The generated code snippet to accomplish the task")
    explanation: str = Field(description="Explanation of the generated code and its functionality")

    @computed_field(return_type=str)
    @property
    def parsed_code(self) -> str:
        '''Extracts the code snippet without any surrounding text.

        Pipeline: normalise stray "<|im_sep|>" tokens to fences, take the
        first fenced block (or the whole text if none), then comment out any
        line containing a forbidden module/builtin word by prefixing it with
        "# not allowed ".
        '''
        normalised = _IM_SEP_TOKEN_PATTERN.sub("```", self.code).strip()
        match = _CODE_BLOCK_PATTERN.search(normalised)
        candidate = match.group(1).strip() if match else normalised
        sanitised = _FORBIDDEN_MODULE_PATTERN.sub(r"# not allowed \1", candidate)
        return sanitised.strip()
|
||||
|
||||
class RankResponse(BaseModel):
    '''Code ranking response structure'''
    # Lower is better: 1 = best, 10 = worst.
    rank: int = Field(
        ge=1, le=10,
        description="Rank of the code snippet from 1 (best) to 10 (worst)"
    )
|
||||
24
backend/src/ea_chatbot/types.py
Normal file
24
backend/src/ea_chatbot/types.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from typing import TypedDict, Optional
|
||||
from enum import StrEnum
|
||||
|
||||
class DBSettings(TypedDict):
    """Connection settings for the voter PostgreSQL database."""
    host: str
    port: int
    user: str
    pswd: str
    # Database name.
    db: str
    # Default table to inspect/query; may be None when not applicable.
    table: Optional[str]
|
||||
|
||||
class Agent(StrEnum):
    """Human-readable agent/role labels; as a StrEnum each member compares
    equal to its display string (used for logging/labelling)."""
    EXPERT_SELECTOR = "Expert Selector"
    ANALYST_SELECTOR = "Analyst Selector"
    THEORIST = "Theorist"
    THEORIST_WEB = "Theorist-Web"
    THEORIST_CLARIFICATION = "Theorist-Clarification"
    PLANNER = "Planner"
    CODE_GENERATOR = "Code Generator"
    CODE_DEBUGGER = "Code Debugger"
    CODE_EXECUTOR = "Code Executor"
    ERROR_CORRECTOR = "Error Corrector"
    CODE_RANKER = "Code Ranker"
    SOLUTION_SUMMARIZER = "Solution Summarizer"
|
||||
12
backend/src/ea_chatbot/utils/__init__.py
Normal file
12
backend/src/ea_chatbot/utils/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from .db_client import DBClient
|
||||
from .llm_factory import get_llm_model
|
||||
from .logging import get_logger, LangChainLoggingHandler
|
||||
from . import helpers
|
||||
|
||||
__all__ = [
|
||||
"DBClient",
|
||||
"get_llm_model",
|
||||
"get_logger",
|
||||
"LangChainLoggingHandler",
|
||||
"helpers"
|
||||
]
|
||||
234
backend/src/ea_chatbot/utils/database_inspection.py
Normal file
234
backend/src/ea_chatbot/utils/database_inspection.py
Normal file
@@ -0,0 +1,234 @@
|
||||
from typing import Optional, Dict, Any, List, TYPE_CHECKING
|
||||
import yaml
|
||||
import json
|
||||
import os
|
||||
from ea_chatbot.utils.db_client import DBClient
|
||||
if TYPE_CHECKING:
|
||||
from ea_chatbot.types import DBSettings
|
||||
|
||||
def _get_table_checksum(db_client: DBClient, table: str) -> str:
|
||||
"""Calculates the checksum of the table using DML statistics from pg_stat_user_tables."""
|
||||
query = f"""
|
||||
SELECT md5(concat_ws('|', n_tup_ins, n_tup_upd, n_tup_del)) AS dml_hash
|
||||
FROM pg_stat_user_tables
|
||||
WHERE schemaname = 'public' AND relname = '{table}';"""
|
||||
try:
|
||||
return str(db_client.query_df(query).iloc[0, 0])
|
||||
except Exception:
|
||||
return "unknown_checksum"
|
||||
|
||||
def _update_checksum_file(filepath: str, table: str, checksum: str):
|
||||
"""Updates the checksum file with the new checksum for the table."""
|
||||
checksums = {}
|
||||
if os.path.exists(filepath):
|
||||
with open(filepath, 'r') as f:
|
||||
for line in f:
|
||||
if ':' in line:
|
||||
k, v = line.strip().split(':', 1)
|
||||
checksums[k] = v
|
||||
|
||||
checksums[table] = checksum
|
||||
|
||||
with open(filepath, 'w') as f:
|
||||
for k, v in checksums.items():
|
||||
f.write(f"{k}:{v}")
|
||||
|
||||
def get_data_summary(data_dir: str = "data") -> Optional[str]:
    """Return the raw text of ``<data_dir>/inspection.yaml``.

    Returns None when the file does not exist.
    """
    summary_path = os.path.join(data_dir, "inspection.yaml")
    if not os.path.exists(summary_path):
        return None
    with open(summary_path, 'r') as handle:
        return handle.read()
|
||||
|
||||
def get_primary_key(db_client: "DBClient", table_name: str) -> Optional[str]:
    """
    Dynamically identifies the primary key of the table.
    Returns the column name of the primary key, or None if not found.

    For a composite primary key only the first column reported by
    information_schema is returned. The table name is bound as a query
    parameter rather than interpolated into the SQL, so unusual identifiers
    cannot break (or inject into) the query.
    """
    query = """
    SELECT kcu.column_name
    FROM information_schema.key_column_usage AS kcu
    JOIN information_schema.table_constraints AS tc
        ON kcu.constraint_name = tc.constraint_name
        AND kcu.table_schema = tc.table_schema
    WHERE kcu.table_name = :table_name
        AND tc.constraint_type = 'PRIMARY KEY'
    LIMIT 1;
    """
    try:
        df = db_client.query_df(query, {"table_name": table_name})
        if not df.empty:
            return str(df.iloc[0, 0])
    except Exception as e:
        # Best-effort: callers treat "no PK" and "couldn't determine" alike.
        print(f"Warning: Could not determine primary key for {table_name}: {e}")
    return None
|
||||
|
||||
def inspect_db_table(
    db_client: Optional[DBClient]=None,
    db_settings: Optional["DBSettings"]=None,
    data_dir: str = "data",
    force_update: bool = False
) -> Optional[str]:
    """
    Inspects the database table, generates statistics for each column,
    and saves the inspection results to a YAML file locally.

    Improvements:
    - Dynamic Primary Key Discovery
    - Cardinality (Unique Counts)
    - Categorical Sample Values for low cardinality columns
    - Robust Quoting

    Args:
        db_client: Existing client; if None, one is built from db_settings.
        db_settings: Connection settings used only when db_client is None.
        data_dir: Directory holding inspection.yaml and the checksum file.
        force_update: Regenerate even if the stored checksum matches.

    Returns:
        The YAML inspection text, or None on any unrecoverable error.

    NOTE(review): table/column names are interpolated into SQL. Identifiers
    are double-quoted, which handles odd characters, but this is only safe
    while the names come from trusted configuration/information_schema —
    confirm they are never user-supplied.
    """
    inspection_file = os.path.join(data_dir, "inspection.yaml")
    checksum_file = os.path.join(data_dir, "checksum")

    # Initialize DB Client
    if db_client is None:
        if db_settings is None:
            print("Error: Either db_client or db_settings must be provided.")
            return None
        try:
            db_client = DBClient(db_settings)
        except Exception as e:
            print(f"Failed to create DBClient: {e}")
            return None

    table_name = db_client.settings.get('table')
    if not table_name:
        print("Error: Table name must be specified in DBSettings.")
        return None

    os.makedirs(data_dir, exist_ok=True)

    # Checksum verification: skip the (expensive) full inspection when the
    # table's DML counters haven't changed since the last run.
    new_checksum = _get_table_checksum(db_client, table_name)
    has_changed = True

    if os.path.exists(checksum_file):
        try:
            with open(checksum_file, 'r') as f:
                saved_checksums = f.read().strip()
                # Substring test against the whole file; entries are "table:checksum".
                if f"{table_name}:{new_checksum}" in saved_checksums:
                    has_changed = False
        except Exception:
            pass # Force update on read error

    if not has_changed and not force_update:
        return get_data_summary(data_dir)

    print(f"Regenerating inspection file for table '{table_name}'...")

    # Fetch Table Metadata
    try:
        # Get columns and types
        columns_query = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';"
        columns_df = db_client.query_df(columns_query)

        # Get Row Counts
        total_rows_df = db_client.query_df(f'SELECT COUNT(*) FROM "{table_name}"')
        total_rows = int(total_rows_df.iloc[0, 0])

        # Dynamic Primary Key
        primary_key = get_primary_key(db_client, table_name)

        # Get First/Last Rows (if PK exists)
        first_row_df = None
        last_row_df = None
        if primary_key:
            first_row_df = db_client.query_df(f'SELECT * FROM "{table_name}" ORDER BY "{primary_key}" ASC LIMIT 1')
            last_row_df = db_client.query_df(f'SELECT * FROM "{table_name}" ORDER BY "{primary_key}" DESC LIMIT 1')

    except Exception as e:
        print(f"Failed to retrieve basic table info: {e}")
        return None

    stats_dict: Dict[str, Any] = {}
    if primary_key:
        stats_dict['primary_key'] = primary_key

    # Per-column statistics; failures on one column don't abort the others.
    for _, row in columns_df.iterrows():
        col_name = row['column_name']
        dtype = row['data_type']

        try:
            # Count Values
            # Using robust quoting
            count_df = db_client.query_df(f'SELECT COUNT("{col_name}") FROM "{table_name}"')
            count_val = int(count_df.iloc[0,0])

            # Count Unique (Cardinality)
            unique_df = db_client.query_df(f'SELECT COUNT(DISTINCT "{col_name}") FROM "{table_name}"')
            unique_count = int(unique_df.iloc[0,0])

            col_stats: Dict[str, Any] = {
                'dtype': dtype,
                'count_of_values': count_val,
                'count_of_nulls': total_rows - count_val,
                'unique_count': unique_count
            }

            # All-NULL column: record the counts and skip value statistics.
            if count_val == 0:
                stats_dict[col_name] = col_stats
                continue

            # Numerical Stats (substring match on the PostgreSQL type name)
            if any(t in dtype for t in ('int', 'float', 'numeric', 'double', 'real', 'decimal')):
                stats_query = f'SELECT AVG("{col_name}"), MIN("{col_name}"), MAX("{col_name}") FROM "{table_name}"'
                stats_df = db_client.query_df(stats_query)
                if not stats_df.empty:
                    col_stats['mean'] = float(stats_df.iloc[0,0]) if stats_df.iloc[0,0] is not None else None
                    col_stats['min'] = float(stats_df.iloc[0,1]) if stats_df.iloc[0,1] is not None else None
                    col_stats['max'] = float(stats_df.iloc[0,2]) if stats_df.iloc[0,2] is not None else None

            # Temporal Stats
            elif any(t in dtype for t in ('date', 'timestamp')):
                stats_query = f'SELECT MIN("{col_name}"), MAX("{col_name}") FROM "{table_name}"'
                stats_df = db_client.query_df(stats_query)
                if not stats_df.empty:
                    col_stats['min'] = str(stats_df.iloc[0,0])
                    col_stats['max'] = str(stats_df.iloc[0,1])

            # Categorical/Text Stats
            else:
                # Sample values if cardinality is low (< 20)
                if 0 < unique_count < 20:
                    distinct_query = f'SELECT DISTINCT "{col_name}" FROM "{table_name}" ORDER BY "{col_name}" LIMIT 20'
                    distinct_df = db_client.query_df(distinct_query)
                    col_stats['distinct_values'] = distinct_df.iloc[:, 0].tolist()

            # Example values taken from the PK-ordered first/last rows.
            if first_row_df is not None and not first_row_df.empty and col_name in first_row_df.columns:
                col_stats['first_value'] = str(first_row_df.iloc[0][col_name])
            if last_row_df is not None and not last_row_df.empty and col_name in last_row_df.columns:
                col_stats['last_value'] = str(last_row_df.iloc[0][col_name])

            stats_dict[col_name] = col_stats

        except Exception as e:
            print(f"Warning: Could not process column {col_name}: {e}")

    # Load existing inspections to merge (if multiple tables)
    existing_inspections = {}
    if os.path.exists(inspection_file):
        try:
            with open(inspection_file, 'r') as f:
                existing_inspections = yaml.safe_load(f) or {}
            # Backup old file
            os.rename(inspection_file, inspection_file + ".old")
        except Exception:
            pass

    existing_inspections[table_name] = stats_dict

    # Save new inspection
    inspection_content = yaml.dump(existing_inspections, sort_keys=False, default_flow_style=False)
    with open(inspection_file, 'w') as f:
        f.write(inspection_content)

    # Update Checksum
    _update_checksum_file(checksum_file, table_name, new_checksum)

    print(f"Inspection saved to {inspection_file}")
    return inspection_content
|
||||
21
backend/src/ea_chatbot/utils/db_client.py
Normal file
21
backend/src/ea_chatbot/utils/db_client.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
import pandas as pd
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ea_chatbot.types import DBSettings
|
||||
|
||||
class DBClient:
    """Thin wrapper around a SQLAlchemy engine for the voter database.

    Connection parameters come from a ``DBSettings`` mapping with keys
    ``user``, ``pswd``, ``host``, ``port`` and ``db``.
    """

    def __init__(self, settings: "DBSettings"):
        self.settings = settings
        self._engine = self._create_engine()

    def _create_engine(self):
        """Build a SQLAlchemy engine from the configured settings.

        Credentials are URL-encoded so a user name or password containing
        reserved characters (e.g. '@', '/', ':') does not corrupt the DSN.
        """
        from urllib.parse import quote_plus

        user = quote_plus(self.settings['user'])
        pswd = quote_plus(self.settings['pswd'])
        url = (
            f"postgresql://{user}:{pswd}"
            f"@{self.settings['host']}:{self.settings['port']}/{self.settings['db']}"
        )
        return create_engine(url)

    def query_df(self, sql: str, params: Optional[dict] = None) -> pd.DataFrame:
        """Execute *sql* (with optional bound parameters) and return the
        result set as a pandas DataFrame."""
        with self._engine.connect() as conn:
            result = conn.execute(text(sql), params or {})
            df = pd.DataFrame(result.fetchall(), columns=result.keys())
        return df
|
||||
73
backend/src/ea_chatbot/utils/helpers.py
Normal file
73
backend/src/ea_chatbot/utils/helpers.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from typing import Optional, TYPE_CHECKING, Dict, Any
|
||||
from datetime import datetime, timezone
|
||||
import yaml
|
||||
import json
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
def ordinal(n: int) -> str:
    """Return *n* followed by its English ordinal suffix (1st, 2nd, 3rd, 4th...).

    Checks ``n % 100`` for the teen range so 111/112/113 (and any other
    *-teen* remainder) correctly get "th" instead of "st"/"nd"/"rd".
    """
    if 11 <= n % 100 <= 13:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    return f"{n}{suffix}"
|
||||
|
||||
def get_readable_date(date_obj: Optional[datetime] = None, tz: Optional[timezone] = None) -> str:
    """Format a datetime as e.g. "Mon 3rd of Jan 2025".

    Defaults to the current UTC time when *date_obj* is omitted; when *tz*
    is supplied, the datetime is converted to that zone before formatting.
    """
    dt = datetime.now(timezone.utc) if date_obj is None else date_obj
    if tz:
        dt = dt.astimezone(tz)
    return dt.strftime(f"%a {ordinal(dt.day)} of %b %Y")
|
||||
|
||||
def to_yaml(json_str: str, indent: int = 2) -> str:
    """
    Best-effort conversion of a JSON string (potentially malformed LLM
    output) into a YAML string.

    Repair strategy: parse as-is; if that fails, retry with single quotes
    swapped for double quotes; if still unparseable, hand back the raw
    input unchanged.
    """
    if not json_str:
        return ""

    try:
        data = json.loads(json_str)
    except json.JSONDecodeError:
        try:
            data = json.loads(json_str.replace("'", '"'))
        except Exception:
            return json_str

    return yaml.dump(data, indent=indent, sort_keys=False)
|
||||
|
||||
def merge_agent_state(current_state: "AgentState", update: Dict[str, Any]) -> "AgentState":
    """
    Merge a partial state update into the current state, mimicking the
    reduction logic LangGraph applies between nodes.

    - ``messages`` and ``plots`` lists are concatenated.
    - The ``dfs`` dictionary is shallow-merged (update wins per key).
    - ``None`` values and all other keys simply overwrite.
    """
    merged = current_state.copy()

    for key, value in update.items():
        # Explicit None always overwrites, even for list/dict keys.
        if value is None:
            merged[key] = None
        elif key in ("messages", "plots") and isinstance(value, list):
            existing = merged.get(key, [])
            if not isinstance(existing, list):
                existing = []
            merged[key] = existing + value
        elif key == "dfs" and isinstance(value, dict):
            existing = merged.get(key, {})
            combined = existing.copy() if isinstance(existing, dict) else {}
            combined.update(value)
            merged[key] = combined
        else:
            merged[key] = value

    return merged
|
||||
36
backend/src/ea_chatbot/utils/llm_factory.py
Normal file
36
backend/src/ea_chatbot/utils/llm_factory.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from typing import Optional, cast, TYPE_CHECKING, Literal, Dict, List, Tuple, Any
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from ea_chatbot.config import LLMConfig
|
||||
|
||||
def get_llm_model(config: LLMConfig, callbacks: Optional[List[BaseCallbackHandler]] = None) -> BaseChatModel:
    """
    Factory function to get a LangChain chat model based on configuration.

    Args:
        config: LLMConfig object containing model settings.
        callbacks: Optional list of LangChain callback handlers.

    Returns:
        Initialized BaseChatModel instance.

    Raises:
        ValueError: If the provider is not supported.
    """
    # Drop unset (None) parameters so each provider's own defaults apply.
    candidate_params = {
        "temperature": config.temperature,
        "max_tokens": config.max_tokens,
        **config.provider_specific,
    }
    params = {name: val for name, val in candidate_params.items() if val is not None}

    provider = config.provider.lower()
    if provider == "openai":
        return ChatOpenAI(model=config.model, callbacks=callbacks, **params)
    if provider in ("google", "google_genai"):
        return ChatGoogleGenerativeAI(model=config.model, callbacks=callbacks, **params)
    raise ValueError(f"Unsupported LLM provider: {config.provider}")
|
||||
141
backend/src/ea_chatbot/utils/logging.py
Normal file
141
backend/src/ea_chatbot/utils/logging.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from rich.logging import RichHandler
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from typing import Any, Optional, Dict, List
|
||||
|
||||
class LangChainLoggingHandler(BaseCallbackHandler):
    """Callback handler for logging LangChain events."""

    def __init__(self, logger: Optional[logging.Logger] = None):
        self.logger = logger or get_logger("langchain")

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
        # The serialized payload may be empty or lack a name depending on
        # how the callback is invoked; fall back to kwargs, then a literal.
        name = serialized.get("name") or kwargs.get("name") or "LLM"
        self.logger.info(f"[bold blue]LLM Started:[/bold blue] {name}")

    def on_llm_end(self, response: Any, **kwargs: Any) -> Any:
        output = getattr(response, "llm_output", {}) or {}
        # Model name may or may not be reported in the provider output.
        name = output.get("model_name") or "LLM"
        usage = output.get("token_usage", {})

        message = f"[bold green]LLM Ended:[/bold green] {name}"
        if usage:
            prompt_toks = usage.get("prompt_tokens", 0)
            completion_toks = usage.get("completion_tokens", 0)
            total_toks = usage.get("total_tokens", 0)
            message += f" | [yellow]Tokens: {total_toks}[/yellow] ({prompt_toks} prompt, {completion_toks} completion)"

        self.logger.info(message)

    def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> Any:
        self.logger.error(f"[bold red]LLM Error:[/bold red] {str(error)}")
|
||||
|
||||
class ContextLoggerAdapter(logging.LoggerAdapter):
    """Adapter to inject contextual metadata into log records."""

    def process(self, msg: Any, kwargs: Any) -> tuple[Any, Any]:
        # Start from the adapter's base context, then layer any per-call
        # "extra" on top so call-site values win on key collisions.
        merged = self.extra.copy()
        if "extra" in kwargs:
            merged.update(kwargs.pop("extra"))
        kwargs["extra"] = merged
        return msg, kwargs
|
||||
|
||||
class FlexibleJSONEncoder(json.JSONEncoder):
    """JSON encoder that degrades gracefully for Pydantic models and
    arbitrary objects exposing ``__dict__``."""

    def default(self, obj: Any) -> Any:
        if hasattr(obj, 'model_dump'):  # Pydantic v2
            return obj.model_dump()
        if hasattr(obj, 'dict'):  # Pydantic v1
            return obj.dict()
        if hasattr(obj, '__dict__'):
            return self.serialize_custom_object(obj)
        if isinstance(obj, dict):
            return {k: self.default(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [self.default(item) for item in obj]
        return super().default(obj)

    def serialize_custom_object(self, obj: Any) -> dict:
        """Serialize via ``__dict__``, tagging the payload with the class name."""
        payload = dict(obj.__dict__)
        payload['__custom_class__'] = type(obj).__name__
        return payload
|
||||
|
||||
class JsonFormatter(logging.Formatter):
    """Custom JSON formatter for structured logging.

    Emits one JSON object per record with standard fields plus any
    caller-supplied ``extra`` attributes.
    """

    # Attributes every LogRecord carries; anything NOT listed here is
    # treated as user-supplied context and copied into the JSON payload.
    # 'taskName' (added to LogRecord in Python 3.12) is excluded as well so
    # it does not leak a null field into every record on newer interpreters.
    _STANDARD_ATTRS = {
        'args', 'asctime', 'created', 'exc_info', 'exc_text', 'filename',
        'funcName', 'levelname', 'levelno', 'lineno', 'module',
        'msecs', 'message', 'msg', 'name', 'pathname', 'process',
        'processName', 'relativeCreated', 'stack_info', 'thread', 'threadName',
        'taskName',
    }

    def format(self, record: logging.LogRecord) -> str:
        """Render *record* as a single JSON document string."""
        # Standard fields
        log_record = {
            "timestamp": self.formatTime(record, self.datefmt),
            "level": record.levelname,
            "message": record.getMessage(),
            "module": record.module,
            "name": record.name,
        }

        # Add exception info if present
        if record.exc_info:
            log_record["exception"] = self.formatException(record.exc_info)

        # Copy through all non-standard (i.e. `extra`) attributes.
        for key, value in record.__dict__.items():
            if key not in self._STANDARD_ATTRS:
                log_record[key] = value

        return json.dumps(log_record, cls=FlexibleJSONEncoder)
|
||||
|
||||
def get_logger(name: str = "ea_chatbot", level: Optional[str] = None, log_file: Optional[str] = None) -> logging.Logger:
    """Get a configured logger with RichHandler and optional JSON file handler.

    All loggers are parented under the "ea_chatbot" root logger, which is
    configured (console handler, level, propagation) only on first use.
    """
    # Normalise the requested name into the ea_chatbot.* hierarchy.
    if name == "ea_chatbot" or name.startswith("ea_chatbot."):
        full_name = name
    else:
        full_name = f"ea_chatbot.{name}"

    root_logger = logging.getLogger("ea_chatbot")

    # One-time configuration of the package root logger.
    if not root_logger.handlers:
        root_logger.setLevel(getattr(logging, (level or "INFO").upper(), logging.INFO))
        root_logger.addHandler(
            RichHandler(
                rich_tracebacks=True,
                markup=True,
                show_time=False,
                show_path=False,
            )
        )
        root_logger.propagate = False

    # A file handler may be requested even after the root is configured,
    # so this check runs every call — but only one file handler is added.
    if log_file:
        has_file_handler = any(
            isinstance(handler, RotatingFileHandler) for handler in root_logger.handlers
        )
        if not has_file_handler:
            os.makedirs(os.path.dirname(log_file), exist_ok=True)
            file_handler = RotatingFileHandler(
                log_file, maxBytes=5 * 1024 * 1024, backupCount=3
            )
            file_handler.setFormatter(JsonFormatter())
            root_logger.addHandler(file_handler)

    # Re-fetch now that the hierarchy is configured.
    logger = logging.getLogger(full_name)

    # An explicit level on a sub-logger overrides inheritance from the root.
    if level:
        logger.setLevel(getattr(logging, level.upper(), logging.INFO))

    return logger
|
||||
64
backend/tests/api/test_agent.py
Normal file
64
backend/tests/api/test_agent.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
from ea_chatbot.api.main import app
|
||||
from ea_chatbot.api.dependencies import get_current_user
|
||||
from ea_chatbot.history.models import User
|
||||
from ea_chatbot.api.utils import create_access_token
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
@pytest.fixture
def mock_user():
    """Fake authenticated user injected via dependency overrides."""
    return User(
        id="user-123",
        username="test@example.com",
        display_name="Test User",
    )
|
||||
|
||||
@pytest.fixture
def auth_header(mock_user):
    """Yield an Authorization header for mock_user while overriding auth.

    The dependency override makes every request in the test resolve to
    mock_user; the code after ``yield`` runs as teardown.
    """
    app.dependency_overrides[get_current_user] = lambda: mock_user
    token = create_access_token(data={"sub": mock_user.id})
    yield {"Authorization": f"Bearer {token}"}
    # Teardown: clear the override so other tests see real auth behavior.
    app.dependency_overrides.clear()
|
||||
|
||||
def test_stream_agent_unauthorized():
    """Test that streaming requires authentication."""
    payload = {"message": "hello"}
    resp = client.post("/chat/stream", json=payload)
    assert resp.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_agent_success(auth_header, mock_user):
    """Test successful agent streaming with SSE."""
    # We need to mock the LangGraph app.astream_events
    mock_events = [
        {"event": "on_chat_model_start", "name": "gpt-5", "data": {"input": "..."}},
        {"event": "on_chat_model_stream", "name": "gpt-5", "data": {"chunk": "Hello"}},
        {"event": "on_chain_end", "name": "agent", "data": {"output": "..."}}
    ]

    async def mock_astream_events(*args, **kwargs):
        for event in mock_events:
            yield event

    with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events), \
         patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
         patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:

        mock_cp.return_value.__aenter__.return_value = AsyncMock()

        # Mock session and DB objects
        mock_session = MagicMock()
        mock_hm.get_session.return_value.__enter__.return_value = mock_session
        from ea_chatbot.history.models import Conversation
        # Conversation owned by mock_user so the ownership check passes.
        mock_conv = Conversation(id="t1", user_id=mock_user.id)
        mock_session.get.return_value = mock_conv

        # Using TestClient with a stream context
        with client.stream("POST", "/chat/stream",
                           json={"message": "hello", "thread_id": "t1"},
                           headers=auth_header) as response:
            assert response.status_code == 200
            assert "text/event-stream" in response.headers["content-type"]

            lines = list(response.iter_lines())
            # Each event should start with 'data: ' and be valid JSON
            data_lines = [line for line in lines if line.startswith("data: ")]
            assert len(data_lines) >= len(mock_events)
|
||||
107
backend/tests/api/test_api_auth.py
Normal file
107
backend/tests/api/test_api_auth.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.api.main import app
|
||||
from ea_chatbot.history.models import User
|
||||
|
||||
# We will need to mock HistoryManager and get_db dependencies later
|
||||
# For now, we define the expected behavior of the auth endpoints.
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
@pytest.fixture
def mock_user():
    """Fake stored user (with password hash) for auth-endpoint tests."""
    fields = dict(
        id="user-123",
        username="test@example.com",
        display_name="Test User",
        password_hash="hashed_password",
    )
    return User(**fields)
|
||||
|
||||
def test_register_user_success():
    """Test successful user registration."""
    # We mock it where it is used in the router
    with patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm:
        # No existing user with this email, so registration proceeds.
        mock_hm.get_user.return_value = None
        mock_hm.create_user.return_value = User(id="1", username="new@example.com", display_name="New")

        response = client.post(
            "/auth/register",
            json={"email": "new@example.com", "password": "password123", "display_name": "New"}
        )

        assert response.status_code == 201
        assert response.json()["email"] == "new@example.com"
|
||||
|
||||
def test_login_success():
    """Test successful login and JWT return."""
    with patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm:
        mock_hm.authenticate_user.return_value = User(id="1", username="test@example.com")

        # OAuth2 password flow posts form data (data=), not JSON.
        response = client.post(
            "/auth/login",
            data={"username": "test@example.com", "password": "password123"}
        )

        assert response.status_code == 200
        assert "access_token" in response.json()
        assert response.json()["token_type"] == "bearer"
|
||||
|
||||
def test_login_invalid_credentials():
    """Test login with wrong password."""
    with patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm:
        # authenticate_user returning None signals bad credentials.
        mock_hm.authenticate_user.return_value = None

        response = client.post(
            "/auth/login",
            data={"username": "test@example.com", "password": "wrongpassword"}
        )

        assert response.status_code == 401
        assert "detail" in response.json()
|
||||
|
||||
def test_protected_route_without_token():
    """Test that protected routes reject requests with no bearer token."""
    resp = client.get("/auth/me")
    assert resp.status_code == 401
|
||||
|
||||
def test_oidc_login_redirect():
    """Test that OIDC login returns a redirect URL."""
    with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc:
        mock_oidc.get_login_url.return_value = "https://oidc-provider.com/auth"

        # The endpoint returns the provider URL in the body rather than a 302.
        response = client.get("/auth/oidc/login")
        assert response.status_code == 200
        assert response.json()["url"] == "https://oidc-provider.com/auth"
|
||||
|
||||
def test_oidc_callback_success():
    """Test successful OIDC callback and JWT issuance."""
    with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc, \
         patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm:

        # Chain: code exchange -> provider token -> user info -> local user sync.
        mock_oidc.exchange_code_for_token.return_value = {"access_token": "oidc-token"}
        mock_oidc.get_user_info.return_value = {"email": "sso@example.com", "name": "SSO User"}
        mock_hm.sync_user_from_oidc.return_value = User(id="sso-123", username="sso@example.com", display_name="SSO User")

        response = client.get("/auth/oidc/callback?code=some-code")

        assert response.status_code == 200
        assert "access_token" in response.json()
        assert response.json()["token_type"] == "bearer"
|
||||
|
||||
def test_get_me_success():
    """Test getting current user with a valid token."""
    from ea_chatbot.api.utils import create_access_token
    token = create_access_token(data={"sub": "123"})

    # Patch where the dependency looks the user up (dependencies module),
    # not the router module.
    with patch("ea_chatbot.api.dependencies.history_manager") as mock_hm:
        mock_hm.get_user_by_id.return_value = User(id="123", username="test@example.com", display_name="Test")

        response = client.get(
            "/auth/me",
            headers={"Authorization": f"Bearer {token}"}
        )

        assert response.status_code == 200
        assert response.json()["email"] == "test@example.com"
        assert response.json()["id"] == "123"
|
||||
116
backend/tests/api/test_api_history.py
Normal file
116
backend/tests/api/test_api_history.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.api.main import app
|
||||
from ea_chatbot.api.dependencies import get_current_user
|
||||
from ea_chatbot.history.models import Conversation, Message, Plot, User
|
||||
from ea_chatbot.api.utils import create_access_token
|
||||
from datetime import datetime, timezone
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
@pytest.fixture
def mock_user():
    """Fake authenticated user for the history-endpoint tests."""
    return User(
        id="user-123",
        username="test@example.com",
        display_name="Test User",
    )
|
||||
|
||||
@pytest.fixture
def auth_header(mock_user):
    """Yield an Authorization header for mock_user while overriding auth."""
    # Override get_current_user to return our mock user
    app.dependency_overrides[get_current_user] = lambda: mock_user
    token = create_access_token(data={"sub": mock_user.id})
    yield {"Authorization": f"Bearer {token}"}
    # Teardown: clear the override so other tests see real auth behavior.
    app.dependency_overrides.clear()
|
||||
|
||||
def test_get_conversations_success(auth_header, mock_user):
    """Test retrieving list of conversations."""
    with patch("ea_chatbot.api.routers.history.history_manager") as mock_hm:
        # One conversation owned by the authenticated mock user.
        mock_hm.get_conversations.return_value = [
            Conversation(
                id="c1",
                name="Conv 1",
                user_id=mock_user.id,
                data_state="nj",
                created_at=datetime.now(timezone.utc)
            )
        ]

        response = client.get("/conversations", headers=auth_header)

        assert response.status_code == 200
        assert len(response.json()) == 1
        assert response.json()[0]["name"] == "Conv 1"
|
||||
|
||||
def test_create_conversation_success(auth_header, mock_user):
    """Test creating a new conversation."""
    with patch("ea_chatbot.api.routers.history.history_manager") as mock_hm:
        # The manager is the one assigning the id and timestamps.
        mock_hm.create_conversation.return_value = Conversation(
            id="c2",
            name="New Conv",
            user_id=mock_user.id,
            data_state="nj",
            created_at=datetime.now(timezone.utc)
        )

        response = client.post(
            "/conversations",
            json={"name": "New Conv"},
            headers=auth_header
        )

        assert response.status_code == 201
        assert response.json()["name"] == "New Conv"
        assert response.json()["id"] == "c2"
|
||||
|
||||
def test_get_messages_success(auth_header):
    """Test retrieving messages for a conversation."""
    with patch("ea_chatbot.api.routers.history.history_manager") as mock_hm:
        # Single stored user message in conversation c1.
        mock_hm.get_messages.return_value = [
            Message(
                id="m1",
                role="user",
                content="Hello",
                conversation_id="c1",
                created_at=datetime.now(timezone.utc)
            )
        ]

        response = client.get("/conversations/c1/messages", headers=auth_header)

        assert response.status_code == 200
        assert len(response.json()) == 1
        assert response.json()[0]["content"] == "Hello"
|
||||
|
||||
def test_delete_conversation_success(auth_header):
    """Deleting a conversation returns 204 No Content."""
    with patch("ea_chatbot.api.routers.history.history_manager") as mock_hm:
        mock_hm.delete_conversation.return_value = True
        resp = client.delete("/conversations/c1", headers=auth_header)
        assert resp.status_code == 204
|
||||
|
||||
def test_get_plot_success(auth_header, mock_user):
    """Test retrieving a plot artifact."""
    with patch("ea_chatbot.api.routers.artifacts.history_manager") as mock_hm:
        # Mocking finding a plot by ID
        mock_session = MagicMock()
        mock_hm.get_session.return_value.__enter__.return_value = mock_session

        # Mocking the models and their relationships so the ownership chain
        # plot -> message -> conversation -> user resolves to mock_user.
        mock_conv = Conversation(id="c1", user_id=mock_user.id, user=mock_user)
        mock_msg = Message(id="m1", conversation_id="c1", conversation=mock_conv)
        mock_plot = Plot(id="p1", image_data=b"fake-image-data", message_id="m1", message=mock_msg)

        def mock_get(model, id):
            # Dispatch table emulating session.get(Model, pk).
            if model == Plot: return mock_plot
            if model == Message: return mock_msg
            if model == Conversation: return mock_conv
            return None

        mock_session.get.side_effect = mock_get

        response = client.get("/artifacts/plots/p1", headers=auth_header)

        assert response.status_code == 200
        assert response.content == b"fake-image-data"
        assert response.headers["content-type"] == "image/png"
|
||||
9
backend/tests/api/test_api_main.py
Normal file
9
backend/tests/api/test_api_main.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from fastapi.testclient import TestClient
|
||||
from ea_chatbot.api.main import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
def test_health_check():
    """The /health endpoint responds 200 with a static status payload."""
    resp = client.get("/health")
    assert resp.status_code == 200
    body = resp.json()
    assert body == {"status": "ok"}
|
||||
63
backend/tests/api/test_persistence.py
Normal file
63
backend/tests/api/test_persistence.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
from ea_chatbot.api.main import app
|
||||
from ea_chatbot.api.dependencies import get_current_user
|
||||
from ea_chatbot.history.models import User, Conversation, Message, Plot
|
||||
from ea_chatbot.api.utils import create_access_token
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
@pytest.fixture
def mock_user():
    """Fake authenticated user for the persistence tests."""
    return User(
        id="user-123",
        username="test@example.com",
        display_name="Test User",
    )
|
||||
|
||||
@pytest.fixture
def auth_header(mock_user):
    """Yield an Authorization header for mock_user while overriding auth."""
    app.dependency_overrides[get_current_user] = lambda: mock_user
    token = create_access_token(data={"sub": mock_user.id})
    yield {"Authorization": f"Bearer {token}"}
    # Teardown: clear the override so other tests see real auth behavior.
    app.dependency_overrides.clear()
|
||||
|
||||
def test_persistence_integration_success(auth_header, mock_user):
    """Test that messages and plots are persisted correctly during streaming."""
    # Minimal event sequence: streamed chunk, final summarizer output,
    # then the conversation-summary node finishing.
    mock_events = [
        {"event": "on_chat_model_stream", "name": "summarizer", "data": {"chunk": "Final answer"}},
        {"event": "on_chain_end", "name": "summarizer", "data": {"output": {"messages": [{"content": "Final answer"}]}}},
        {"event": "on_chain_end", "name": "summarize_conversation", "data": {"output": {"summary": "New summary"}}}
    ]

    async def mock_astream_events(*args, **kwargs):
        for event in mock_events:
            yield event

    with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events), \
         patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
         patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:

        mock_cp.return_value.__aenter__.return_value = AsyncMock()

        # Mock session and DB objects
        mock_session = MagicMock()
        mock_hm.get_session.return_value.__enter__.return_value = mock_session
        # Conversation owned by mock_user so the ownership check passes.
        mock_conv = Conversation(id="t1", user_id=mock_user.id)
        mock_session.get.return_value = mock_conv

        # Act
        with client.stream("POST", "/chat/stream",
                           json={"message": "persistence test", "thread_id": "t1"},
                           headers=auth_header) as response:
            assert response.status_code == 200
            list(response.iter_lines())  # Consume stream

        # Assertions
        # 1. User message should be saved immediately
        mock_hm.add_message.assert_any_call("t1", "user", "persistence test")

        # 2. Assistant message should be saved at the end
        mock_hm.add_message.assert_any_call("t1", "assistant", "Final answer", plots=[])

        # 3. Summary should be updated
        mock_hm.update_conversation_summary.assert_called_once_with("t1", "New summary")
|
||||
51
backend/tests/api/test_utils.py
Normal file
51
backend/tests/api/test_utils.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from datetime import timedelta
|
||||
from ea_chatbot.api.utils import create_access_token, decode_access_token, convert_to_json_compatible
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
def test_create_and_decode_access_token():
    """Test that a token can be created and then decoded."""
    data = {"sub": "test@example.com", "user_id": "123"}
    token = create_access_token(data)

    decoded = decode_access_token(token)
    assert decoded["sub"] == data["sub"]
    assert decoded["user_id"] == data["user_id"]
    # An expiry claim must always be embedded by create_access_token.
    assert "exp" in decoded
|
||||
|
||||
def test_decode_invalid_token():
    """Decoding a malformed token yields None rather than raising."""
    result = decode_access_token("invalid-token")
    assert result is None
|
||||
|
||||
def test_expired_token():
    """Test that an expired token returns None."""
    data = {"sub": "test@example.com"}
    # Create a token that expired 1 minute ago (negative delta).
    token = create_access_token(data, expires_delta=timedelta(minutes=-1))

    assert decode_access_token(token) is None
|
||||
|
||||
def test_convert_to_json_compatible_complex_message():
    """Test that list-based message content is handled correctly."""
    # Mock a message with list-based content (blocks); only "text"-type
    # blocks should be concatenated, other block types dropped.
    msg = AIMessage(content=[
        {"type": "text", "text": "Hello "},
        {"type": "text", "text": "world!"},
        {"type": "other", "data": "ignore me"}
    ])

    result = convert_to_json_compatible(msg)
    assert result["content"] == "Hello world!"
    assert result["type"] == "ai"
|
||||
|
||||
def test_convert_to_json_compatible_message_with_text_prop():
    """Test that .text property is prioritized if available."""
    # Using a MagicMock to simulate the property safely (AIMessage.text is
    # a property and cannot be assigned on a real instance).
    from unittest.mock import MagicMock
    msg = MagicMock(spec=AIMessage)
    msg.content = "Raw content"
    msg.text = "Just the text"
    msg.type = "ai"
    msg.additional_kwargs = {}

    result = convert_to_json_compatible(msg)
    assert result["content"] == "Just the text"
|
||||
49
backend/tests/graph/test_checkpoint.py
Normal file
49
backend/tests/graph/test_checkpoint.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from ea_chatbot.graph.checkpoint import get_checkpointer
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_checkpointer_initialization():
    """Test that the checkpointer setup is called.

    The connection pool, its connection context manager, and the
    AsyncPostgresSaver class are all replaced with mocks; the real module
    attributes are swapped out and restored in a try/finally.
    """
    mock_conn = AsyncMock()
    mock_pool = MagicMock()  # MagicMock, not AsyncMock: only .open is awaited
    mock_pool.closed = True
    mock_pool.open = AsyncMock()  # Ensure open() is awaitable

    # pool.connection() must return an async context manager yielding the conn
    mock_cm = MagicMock()
    mock_cm.__aenter__ = AsyncMock(return_value=mock_conn)
    mock_cm.__aexit__ = AsyncMock(return_value=None)
    mock_pool.connection.return_value = mock_cm

    import ea_chatbot.graph.checkpoint as checkpoint

    # Plain mock objects (the previous `with MagicMock() as ...:` form only
    # bound the mocks' __enter__ results and managed no resource).
    mock_get_pool = MagicMock(return_value=mock_pool)
    mock_saver_instance = AsyncMock()
    mock_saver_class = MagicMock(return_value=mock_saver_instance)

    # Swap the module attributes directly; AsyncPostgresSaver is referenced
    # by name inside checkpoint.py, so patching there is what matters.
    original_get_pool = checkpoint.get_pool
    original_saver = checkpoint.AsyncPostgresSaver
    checkpoint.get_pool = mock_get_pool
    checkpoint.AsyncPostgresSaver = mock_saver_class

    try:
        async with get_checkpointer() as checkpointer:
            assert checkpointer == mock_saver_instance
            # Verify setup was called
            mock_saver_instance.setup.assert_called_once()

            # Verify pool was opened
            mock_pool.open.assert_called_once()
            # Verify connection was requested
            mock_pool.connection.assert_called_once()
    finally:
        checkpoint.get_pool = original_get_pool
        checkpoint.AsyncPostgresSaver = original_saver
|
||||
90
backend/tests/test_app.py
Normal file
90
backend/tests/test_app.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from streamlit.testing.v1 import AppTest
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
# Ensure src is in python path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
|
||||
|
||||
@pytest.fixture(autouse=True)
def mock_history_manager():
    """Globally mock HistoryManager to avoid DB calls during AppTest."""
    # autouse: every test in this module runs with the patched manager.
    with patch("ea_chatbot.history.manager.HistoryManager") as mock_cls:
        instance = mock_cls.return_value
        instance.create_conversation.return_value = MagicMock(id="conv_123")
        instance.get_conversations.return_value = []
        instance.get_messages.return_value = []
        instance.add_message.return_value = MagicMock()
        instance.update_conversation_summary.return_value = MagicMock()
        yield instance
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_stream():
|
||||
with patch("ea_chatbot.graph.workflow.app.stream") as mock_stream:
|
||||
# Mock events from app.stream
|
||||
mock_stream.return_value = [
|
||||
{"query_analyzer": {"next_action": "research"}},
|
||||
{"researcher": {"messages": [AIMessage(content="Research result")]}}
|
||||
]
|
||||
yield mock_stream
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
user = MagicMock()
|
||||
user.id = "test_id"
|
||||
user.username = "test@example.com"
|
||||
user.display_name = "Test User"
|
||||
return user
|
||||
|
||||
def test_app_initial_state(mock_app_stream, mock_user):
|
||||
"""Test that the app initializes with the correct title and empty history."""
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
|
||||
# Simulate logged-in user
|
||||
at.session_state["user"] = mock_user
|
||||
|
||||
at.run()
|
||||
|
||||
assert not at.exception
|
||||
assert at.title[0].value == "🗳️ Election Analytics Chatbot"
|
||||
|
||||
# Check session state initialization
|
||||
assert "messages" in at.session_state
|
||||
assert len(at.session_state["messages"]) == 0
|
||||
|
||||
def test_app_dev_mode_toggle(mock_app_stream, mock_user):
|
||||
"""Test that the dev mode toggle exists in the sidebar."""
|
||||
with patch.dict(os.environ, {"DEV_MODE": "false"}):
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
at.session_state["user"] = mock_user
|
||||
at.run()
|
||||
|
||||
# Check for sidebar toggle (checkbox)
|
||||
assert len(at.sidebar.checkbox) > 0
|
||||
dev_mode_toggle = at.sidebar.checkbox[0]
|
||||
assert dev_mode_toggle.label == "Dev Mode"
|
||||
assert dev_mode_toggle.value is False
|
||||
|
||||
def test_app_graph_execution_streaming(mock_app_stream, mock_user, mock_history_manager):
|
||||
"""Test that entering a prompt triggers the graph stream and displays response."""
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
at.session_state["user"] = mock_user
|
||||
at.run()
|
||||
|
||||
# Input a question
|
||||
at.chat_input[0].set_value("Test question").run()
|
||||
|
||||
# Verify graph stream was called
|
||||
assert mock_app_stream.called
|
||||
|
||||
# Message should be added to history
|
||||
assert len(at.session_state["messages"]) == 2
|
||||
assert at.session_state["messages"][0]["role"] == "user"
|
||||
assert at.session_state["messages"][1]["role"] == "assistant"
|
||||
assert "Research result" in at.session_state["messages"][1]["content"]
|
||||
|
||||
# Verify history manager was used
|
||||
assert mock_history_manager.create_conversation.called
|
||||
assert mock_history_manager.add_message.called
|
||||
83
backend/tests/test_app_auth.py
Normal file
83
backend/tests/test_app_auth.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""AppTest UI tests for the multi-step login flow in src/ea_chatbot/app.py.

The flow under test: Step 1 collects an email, then branches on the stored
user's auth type — password login (2a), registration (2b), or OIDC/SSO (2c).
"""
import pytest
from unittest.mock import MagicMock, patch
from streamlit.testing.v1 import AppTest
from ea_chatbot.auth import AuthType  # NOTE(review): appears unused in this module

@pytest.fixture
def mock_history_manager_instance():
    """Patch HistoryManager before AppTest imports app.py; yield the instance."""
    # We need to patch before the AppTest loads the module
    with patch("ea_chatbot.history.manager.HistoryManager") as mock_cls:
        instance = mock_cls.return_value
        yield instance

def test_auth_ui_flow_step1_to_password(mock_history_manager_instance):
    """Test UI transition from Step 1 (email) to Step 2a (password) for LOCAL user."""
    # Patch BEFORE creating AppTest
    mock_user = MagicMock()
    mock_user.password_hash = "hashed_password"
    mock_history_manager_instance.get_user.return_value = mock_user

    at = AppTest.from_file("src/ea_chatbot/app.py")
    at.run()

    # Step 1: Identification
    assert at.session_state["login_step"] == "email"
    at.text_input[0].set_value("local@example.com")

    at.button[0].click().run()

    # Verify transition to password step
    assert at.session_state["login_step"] == "login_password"
    assert at.session_state["login_email"] == "local@example.com"
    assert "Welcome back" in at.info[0].value

def test_auth_ui_flow_step1_to_register(mock_history_manager_instance):
    """Test UI transition from Step 1 (email) to Step 2b (registration) for NEW user."""
    mock_history_manager_instance.get_user.return_value = None

    at = AppTest.from_file("src/ea_chatbot/app.py")
    at.run()

    # Step 1: Identification
    at.text_input[0].set_value("new@example.com")

    at.button[0].click().run()

    # Verify transition to registration step
    assert at.session_state["login_step"] == "register_details"
    assert at.session_state["login_email"] == "new@example.com"
    assert "Create an account" in at.info[0].value

def test_auth_ui_flow_step1_to_oidc(mock_history_manager_instance):
    """Test UI transition from Step 1 (email) to Step 2c (OIDC) for OIDC user."""
    # Mock history_manager.get_user to return a user WITHOUT a password
    mock_user = MagicMock()
    mock_user.password_hash = None
    mock_history_manager_instance.get_user.return_value = mock_user

    at = AppTest.from_file("src/ea_chatbot/app.py")
    at.run()

    # Step 1: Identification
    at.text_input[0].set_value("oidc@example.com")

    at.button[0].click().run()

    # Verify transition to OIDC step
    assert at.session_state["login_step"] == "oidc_login"
    assert at.session_state["login_email"] == "oidc@example.com"
    assert "configured for Single Sign-On" in at.info[0].value

def test_auth_ui_flow_back_button(mock_history_manager_instance):
    """Test that the 'Back' button returns to Step 1."""
    at = AppTest.from_file("src/ea_chatbot/app.py")
    # Simulate being on Step 2a
    at.session_state["login_step"] = "login_password"
    at.session_state["login_email"] = "local@example.com"
    at.run()

    # Click Back (index 1 in Step 2a)
    at.button[1].click().run()

    # Verify return to email step
    assert at.session_state["login_step"] == "email"
|
||||
43
backend/tests/test_auth.py
Normal file
43
backend/tests/test_auth.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Unit tests for the OIDC client in ea_chatbot.auth.

Both the OAuth plumbing (``OAuth2Session``) and HTTP access (``requests``)
are mocked so no real network traffic occurs.
"""
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.auth import OIDCClient

@patch("ea_chatbot.auth.requests")
@patch("ea_chatbot.auth.OAuth2Session")
def test_oidc_client_initialization(mock_oauth, mock_requests):
    """Constructing an OIDCClient produces a usable OAuth session.

    Fix: ``ea_chatbot.auth.requests`` is now patched here as well. The test
    below shows the client fetches server metadata via ``requests.get``; this
    test previously left ``requests`` unpatched and could hit the network.
    """
    client = OIDCClient(
        client_id="test_id",
        client_secret="test_secret",
        server_metadata_url="https://test.server/.well-known/openid-configuration"
    )
    assert client.oauth_session is not None

@patch("ea_chatbot.auth.requests")
@patch("ea_chatbot.auth.OAuth2Session")
def test_get_login_url(mock_oauth_cls, mock_requests):
    """get_login_url returns the authorization URL built from fetched metadata."""
    # Setup mock session (innermost decorator -> first parameter)
    mock_session = MagicMock()
    mock_oauth_cls.return_value = mock_session

    # Mock metadata response
    mock_response = MagicMock()
    mock_response.json.return_value = {
        "authorization_endpoint": "https://test.server/auth",
        "token_endpoint": "https://test.server/token",
        "userinfo_endpoint": "https://test.server/userinfo"
    }
    mock_requests.get.return_value = mock_response

    # Mock authorization url generation
    mock_session.create_authorization_url.return_value = ("https://test.server/auth?response_type=code", "state")

    client = OIDCClient(
        client_id="test_id",
        client_secret="test_secret",
        server_metadata_url="https://test.server/.well-known/openid-configuration"
    )

    url = client.get_login_url()

    assert url == "https://test.server/auth?response_type=code"
    # Verify metadata was fetched via requests
    mock_requests.get.assert_called_with("https://test.server/.well-known/openid-configuration")
|
||||
49
backend/tests/test_auth_flow.py
Normal file
49
backend/tests/test_auth_flow.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Classification tests for get_user_auth_type (LOCAL / OIDC / NEW)."""
import pytest
from unittest.mock import MagicMock
from ea_chatbot.history.manager import HistoryManager
from ea_chatbot.auth import get_user_auth_type, AuthType

# Mocks
@pytest.fixture
def mock_history_manager():
    """A HistoryManager mock spec'd against the real class, so attribute
    typos in the code under test fail loudly instead of auto-mocking."""
    return MagicMock(spec=HistoryManager)

def test_auth_flow_existing_local_user(mock_history_manager):
    """An existing user that carries a password hash is classified LOCAL."""
    stored = MagicMock(password_hash="hashed_secret")
    mock_history_manager.get_user.return_value = stored

    result = get_user_auth_type("test@example.com", mock_history_manager)

    assert result == AuthType.LOCAL
    mock_history_manager.get_user.assert_called_once_with("test@example.com")

def test_auth_flow_existing_oidc_user(mock_history_manager):
    """An existing user with no password hash is classified OIDC (SSO)."""
    stored = MagicMock(password_hash=None)  # no local password implies SSO
    mock_history_manager.get_user.return_value = stored

    result = get_user_auth_type("sso@example.com", mock_history_manager)

    assert result == AuthType.OIDC
    mock_history_manager.get_user.assert_called_once_with("sso@example.com")

def test_auth_flow_new_user(mock_history_manager):
    """An unknown email (no stored user at all) is classified NEW."""
    mock_history_manager.get_user.return_value = None

    result = get_user_auth_type("new@example.com", mock_history_manager)

    assert result == AuthType.NEW
    mock_history_manager.get_user.assert_called_once_with("new@example.com")
|
||||
62
backend/tests/test_coder.py
Normal file
62
backend/tests/test_coder.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""Tests for the coder and error-corrector graph nodes.

Both nodes are driven through a mocked structured-output LLM; no real model
or database is contacted.
"""
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.coder import coder_node
from ea_chatbot.graph.nodes.error_corrector import error_corrector_node

@pytest.fixture
def mock_state():
    """A minimal agent state dict with a plan but no code yet."""
    return {
        "messages": [],
        "question": "Show me results for New Jersey",
        "plan": "Step 1: Load data\nStep 2: Filter by NJ",
        "code": None,
        "error": None,
        "plots": [],
        "dfs": {},
        "next_action": "plan"
    }

# NOTE(review): get_data_summary is patched in its defining module
# (ea_chatbot.utils.database_inspection); this only intercepts the call if
# coder.py resolves it through that module attribute at call time, not via a
# `from ... import get_data_summary` binding — confirm against coder.py.
@patch("ea_chatbot.graph.nodes.coder.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
def test_coder_node(mock_get_summary, mock_get_llm, mock_state):
    """Test coder node generates code from plan."""
    mock_get_summary.return_value = "Column: Name, Type: text"

    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm

    from ea_chatbot.schemas import CodeGenerationResponse
    mock_response = CodeGenerationResponse(
        code="import pandas as pd\nprint('Hello')",
        explanation="Generated code"
    )
    mock_llm.with_structured_output.return_value.invoke.return_value = mock_response

    result = coder_node(mock_state)

    assert "code" in result
    assert "import pandas as pd" in result["code"]
    assert "error" in result
    assert result["error"] is None

@patch("ea_chatbot.graph.nodes.error_corrector.get_llm_model")
def test_error_corrector_node(mock_get_llm, mock_state):
    """Test error corrector node fixes code."""
    mock_state["code"] = "import pandas as pd\nprint(undefined_var)"
    mock_state["error"] = "NameError: name 'undefined_var' is not defined"

    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm

    from ea_chatbot.schemas import CodeGenerationResponse
    mock_response = CodeGenerationResponse(
        code="import pandas as pd\nprint('Defined')",
        explanation="Fixed variable"
    )
    mock_llm.with_structured_output.return_value.invoke.return_value = mock_response

    result = error_corrector_node(mock_state)

    assert "code" in result
    assert "print('Defined')" in result["code"]
    assert result["error"] is None
|
||||
47
backend/tests/test_config.py
Normal file
47
backend/tests/test_config.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Tests for Settings / LLMConfig defaults and environment overrides.

Settings() reads the process environment at construction time, so these
tests assume no ambient env vars (or .env file) override the defaults —
an OIDC or LLM override in the shell would break test_default_settings.
"""
import pytest
from pydantic import ValidationError  # NOTE(review): appears unused in this module
from ea_chatbot.config import Settings, LLMConfig

def test_default_settings():
    """Test that default settings are loaded correctly."""
    settings = Settings()

    # Check default config for query analyzer
    assert isinstance(settings.query_analyzer_llm, LLMConfig)
    assert settings.query_analyzer_llm.provider == "openai"
    assert settings.query_analyzer_llm.model == "gpt-5-mini"
    assert settings.query_analyzer_llm.temperature == 0.0

    # Check default config for planner
    assert isinstance(settings.planner_llm, LLMConfig)
    assert settings.planner_llm.provider == "openai"
    assert settings.planner_llm.model == "gpt-5-mini"

def test_env_override(monkeypatch):
    """Test that environment variables override defaults."""
    # Double-underscore delimits the nested LLMConfig field (pydantic-settings style).
    monkeypatch.setenv("QUERY_ANALYZER_LLM__MODEL", "gpt-3.5-turbo")
    monkeypatch.setenv("QUERY_ANALYZER_LLM__TEMPERATURE", "0.7")

    settings = Settings()
    assert settings.query_analyzer_llm.model == "gpt-3.5-turbo"
    assert settings.query_analyzer_llm.temperature == 0.7

def test_provider_specific_params():
    """Test that provider specific parameters can be set."""
    config = LLMConfig(
        provider="openai",
        model="o1-preview",
        provider_specific={"reasoning_effort": "high"}
    )
    assert config.provider_specific["reasoning_effort"] == "high"

def test_oidc_settings(monkeypatch):
    """Test OIDC settings configuration."""
    monkeypatch.setenv("OIDC_CLIENT_ID", "test_client_id")
    monkeypatch.setenv("OIDC_CLIENT_SECRET", "test_client_secret")
    monkeypatch.setenv("OIDC_SERVER_METADATA_URL", "https://test.server/.well-known/openid-configuration")

    settings = Settings()
    assert settings.oidc_client_id == "test_client_id"
    assert settings.oidc_client_secret == "test_client_secret"
    assert settings.oidc_server_metadata_url == "https://test.server/.well-known/openid-configuration"
|
||||
56
backend/tests/test_conversation_summary.py
Normal file
56
backend/tests/test_conversation_summary.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Tests for the conversation-summarization graph node (mocked LLM)."""
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.nodes.summarize_conversation import summarize_conversation_node

@pytest.fixture
def mock_state_with_history():
    """A state with one completed user/assistant turn and an existing summary."""
    return {
        "messages": [
            HumanMessage(content="Show me the 2024 results for Florida"),
            AIMessage(content="Here are the results for Florida in 2024...")
        ],
        "summary": "The user is asking about 2024 election results."
    }

@patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model")
def test_summarize_conversation_node_updates_summary(mock_get_llm, mock_state_with_history):
    """When a summary already exists, the node folds the new turn into it."""
    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance

    # Mock LLM response for updating summary
    mock_llm_instance.invoke.return_value = AIMessage(content="Updated summary including NJ results.")

    # Add new messages to simulate a completed turn
    mock_state_with_history["messages"].extend([
        HumanMessage(content="What about in New Jersey?"),
        AIMessage(content="In New Jersey, the 2024 results were...")
    ])

    result = summarize_conversation_node(mock_state_with_history)

    assert "summary" in result
    assert result["summary"] == "Updated summary including NJ results."

    # Verify LLM was called with the correct context
    call_messages = mock_llm_instance.invoke.call_args[0][0]
    # Should include current summary and last turn messages
    assert "Current summary: The user is asking about 2024 election results." in call_messages[0].content

@patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model")
def test_summarize_conversation_node_initial_summary(mock_get_llm):
    """With an empty summary, the node produces an initial one from scratch."""
    state = {
        "messages": [
            HumanMessage(content="Hi"),
            AIMessage(content="Hello! How can I help you today?")
        ],
        "summary": ""
    }

    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance
    mock_llm_instance.invoke.return_value = AIMessage(content="Initial greeting.")

    result = summarize_conversation_node(state)

    assert result["summary"] == "Initial greeting."
|
||||
195
backend/tests/test_database_inspection.py
Normal file
195
backend/tests/test_database_inspection.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""Tests for database_inspection: PK discovery, table profiling, checksum skip.

The DB client is fully mocked; query routing is done by substring matching on
the SQL text inside each test's ``side_effect`` function. That dispatch is
order-sensitive (e.g. the generic ``COUNT(*)`` check must precede the
column-specific ``COUNT("...")`` checks), so preserve branch order when editing.
"""
import pytest
import pandas as pd
import os
from unittest.mock import MagicMock, patch  # NOTE(review): patch appears unused here
from ea_chatbot.utils.database_inspection import get_primary_key, inspect_db_table, get_data_summary

@pytest.fixture
def mock_db_client():
    """A DB client mock exposing only the ``settings['table']`` the code reads."""
    mock_client = MagicMock()
    mock_client.settings = {"table": "test_table"}
    return mock_client

def test_get_primary_key(mock_db_client):
    """Test dynamic primary key discovery."""
    # Mock response for primary key query
    mock_df = pd.DataFrame({"column_name": ["my_pk"]})
    mock_db_client.query_df.return_value = mock_df

    pk = get_primary_key(mock_db_client, "test_table")

    assert pk == "my_pk"
    # Verify the query was called (at least once)
    assert mock_db_client.query_df.called

def test_inspect_db_table_improved(mock_db_client, tmp_path):
    """Test improved inspect_db_table with cardinality and sampling."""
    data_dir = str(tmp_path)

    # 1. Mock columns and types
    columns_df = pd.DataFrame({
        "column_name": ["id", "category", "count"],
        "data_type": ["integer", "text", "integer"]
    })

    # 2. Mock row count
    total_rows_df = pd.DataFrame([{"count": 100}])

    # 3. Mock PK discovery
    pk_df = pd.DataFrame({"column_name": ["id"]})

    # 4. Mock stats for columns
    # We need to handle multiple calls to query_df
    def side_effect(query):
        if "information_schema.columns" in query:
            return columns_df
        if "COUNT(*)" in query:
            return total_rows_df
        if "information_schema.key_column_usage" in query:
            return pk_df

        # Category stats
        if 'COUNT("category")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'COUNT(DISTINCT "category")' in query:
            return pd.DataFrame([{"count": 5}])
        if 'SELECT DISTINCT "category"' in query:
            return pd.DataFrame({"category": ["A", "B", "C", "D", "E"]})

        # Count stats
        if 'COUNT("count")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'COUNT(DISTINCT "count")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'AVG("count")' in query:
            return pd.DataFrame([{"avg": 10.0, "min": 1, "max": 20}])

        # ID stats (fix for IndexError)
        if 'COUNT("id")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'COUNT(DISTINCT "id")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'AVG("id")' in query:
            return pd.DataFrame([{"avg": 50.0, "min": 1, "max": 100}])

        return pd.DataFrame()
    mock_db_client.query_df.side_effect = side_effect

    # Run inspection
    inspect_db_table(mock_db_client, data_dir=data_dir)

    # Read summary to verify
    summary = get_data_summary(data_dir)
    assert summary is not None
    assert "test_table" in summary
    assert "category" in summary
    assert "distinct_values" in summary
    assert "unique_count: 5" in summary
    assert "- A" in summary
    assert "- E" in summary
    assert "primary_key: id" in summary

def test_get_data_summary_none(tmp_path):
    """Test get_data_summary when file doesn't exist."""
    assert get_data_summary(str(tmp_path)) is None

def test_inspect_db_table_temporal(mock_db_client, tmp_path):
    """Test inspect_db_table with temporal columns."""
    data_dir = str(tmp_path)

    columns_df = pd.DataFrame({
        "column_name": ["created_at"],
        "data_type": ["timestamp without time zone"]
    })
    total_rows_df = pd.DataFrame([{"count": 50}])
    pk_df = pd.DataFrame() # No PK

    def side_effect(query):
        if "information_schema.columns" in query:
            return columns_df
        if "COUNT(*)" in query:
            return total_rows_df
        if "information_schema.key_column_usage" in query:
            return pk_df
        if 'COUNT("created_at")' in query:
            return pd.DataFrame([{"count": 50}])
        if 'COUNT(DISTINCT "created_at")' in query:
            return pd.DataFrame([{"count": 50}])
        if 'MIN("created_at")' in query:
            return pd.DataFrame([{"min": "2023-01-01", "max": "2023-12-31"}])
        return pd.DataFrame()

    mock_db_client.query_df.side_effect = side_effect

    inspect_db_table(mock_db_client, data_dir=data_dir)

    summary = get_data_summary(data_dir)
    assert "created_at" in summary
    assert "min: '2023-01-01'" in summary
    assert "max: '2023-12-31'" in summary

def test_inspect_db_table_high_cardinality(mock_db_client, tmp_path):
    """Test inspect_db_table with high cardinality categorical column (no sample values)."""
    data_dir = str(tmp_path)

    columns_df = pd.DataFrame({
        "column_name": ["user_id"],
        "data_type": ["text"]
    })
    total_rows_df = pd.DataFrame([{"count": 100}])
    pk_df = pd.DataFrame()

    def side_effect(query):
        if "information_schema.columns" in query:
            return columns_df
        if "COUNT(*)" in query:
            return total_rows_df
        if "information_schema.key_column_usage" in query:
            return pk_df
        if 'COUNT("user_id")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'COUNT(DISTINCT "user_id")' in query:
            # High cardinality > 20
            return pd.DataFrame([{"count": 50}])
        return pd.DataFrame()

    mock_db_client.query_df.side_effect = side_effect

    inspect_db_table(mock_db_client, data_dir=data_dir)

    summary = get_data_summary(data_dir)
    assert "user_id" in summary
    assert "unique_count: 50" in summary
    # Should NOT have distinct_values
    assert "distinct_values" not in summary

def test_inspect_db_table_checksum_skip(mock_db_client, tmp_path):
    """Test that inspection is skipped if checksum matches."""
    data_dir = str(tmp_path)
    table = "test_table"

    # 1. Create a fake checksum file
    os.makedirs(data_dir, exist_ok=True)
    # Checksum is md5 of "ins|upd|del". Let's say mock returns "my_hash"

    # Mock checksum query
    mock_db_client.query_df.return_value = pd.DataFrame([{"dml_hash": "my_hash"}])

    # Write existing checksum
    with open(os.path.join(data_dir, "checksum"), "w") as f:
        f.write(f"{table}:my_hash\n")

    # Write existing inspection
    with open(os.path.join(data_dir, "inspection.yaml"), "w") as f:
        f.write(f"{table}: {{ existing: true }}")

    # Run inspection
    result = inspect_db_table(mock_db_client, data_dir=data_dir)

    # Should return existing content
    assert "existing: true" in result
    # query_df should be called ONLY for checksum (once)
    # Logic: 1 call for checksum. If match, return.
    assert mock_db_client.query_df.call_count == 1
|
||||
|
||||
123
backend/tests/test_executor.py
Normal file
123
backend/tests/test_executor.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Tests for executor_node: stdout capture, DataFrame/plot collection, errors.

Settings and DBClient are patched inside the executor module, so the node runs
the supplied code snippets without touching a real database.
"""
import pytest
import pandas as pd
from unittest.mock import MagicMock, patch
from matplotlib.figure import Figure
from ea_chatbot.graph.nodes.executor import executor_node
from ea_chatbot.graph.state import AgentState  # NOTE(review): appears unused here

@pytest.fixture
def mock_settings():
    """Patch Settings in the executor module with dummy DB connection values."""
    with patch("ea_chatbot.graph.nodes.executor.Settings") as MockSettings:
        mock_settings_instance = MagicMock()
        mock_settings_instance.db_host = "localhost"
        mock_settings_instance.db_port = 5432
        mock_settings_instance.db_user = "user"
        mock_settings_instance.db_pswd = "pass"
        mock_settings_instance.db_name = "test_db"
        mock_settings_instance.db_table = "test_table"
        MockSettings.return_value = mock_settings_instance
        yield mock_settings_instance

@pytest.fixture
def mock_db_client():
    """Patch DBClient in the executor module so no real connection is opened."""
    with patch("ea_chatbot.graph.nodes.executor.DBClient") as MockDBClient:
        mock_client_instance = MagicMock()
        MockDBClient.return_value = mock_client_instance
        yield mock_client_instance

def test_executor_node_success_simple_print(mock_settings, mock_db_client):
    """Test executing simple code that prints to stdout."""
    state = {
        "code": "print('Hello, World!')",
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert "code_output" in result
    assert "Hello, World!" in result["code_output"]
    assert result["error"] is None
    assert result["plots"] == []
    assert result["dfs"] == {}

def test_executor_node_success_dataframe(mock_settings, mock_db_client):
    """Test executing code that creates a DataFrame."""
    code = """
import pandas as pd
df = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
print(df)
"""
    state = {
        "code": code,
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert "code_output" in result
    # NOTE(review): pandas renders the header with multiple spaces ("   a  b");
    # if this single-space substring fails, whitespace was likely lost in a
    # copy/paste of this file — confirm against the original expected string.
    assert "a b" in result["code_output"] # Check part of DF string representation
    assert "dfs" in result
    assert "df" in result["dfs"]
    assert isinstance(result["dfs"]["df"], pd.DataFrame)

def test_executor_node_success_plot(mock_settings, mock_db_client):
    """Test executing code that generates a plot."""
    # The executed snippet appends to a `plots` list injected by the executor.
    code = """
import matplotlib.pyplot as plt
fig = plt.figure()
plots.append(fig)
print('Plot generated')
"""
    state = {
        "code": code,
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert "Plot generated" in result["code_output"]
    assert "plots" in result
    assert len(result["plots"]) == 1
    assert isinstance(result["plots"][0], Figure)

def test_executor_node_error_syntax(mock_settings, mock_db_client):
    """Test executing code with a syntax error."""
    state = {
        "code": "print('Hello World", # Missing closing quote
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert result["error"] is not None
    assert "SyntaxError" in result["error"]

def test_executor_node_error_runtime(mock_settings, mock_db_client):
    """Test executing code with a runtime error."""
    state = {
        "code": "print(1 / 0)",
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert result["error"] is not None
    assert "ZeroDivisionError" in result["error"]

def test_executor_node_no_code(mock_settings, mock_db_client):
    """Test handling when no code is provided."""
    state = {
        "code": None,
        "question": "test",
        "messages": []
    }

    result = executor_node(state)

    assert "error" in result
    assert "No code provided" in result["error"]
|
||||
77
backend/tests/test_helpers.py
Normal file
77
backend/tests/test_helpers.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Tests for merge_agent_state: list accumulation, dict merge, scalar overwrite."""
import pytest
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.utils.helpers import merge_agent_state

def test_merge_agent_state_list_accumulation():
    """List-typed fields (messages, plots) are appended to, not replaced."""
    base = {
        "messages": [HumanMessage(content="hello")],
        "plots": ["plot1"],
    }
    delta = {
        "messages": [AIMessage(content="hi")],
        "plots": ["plot2"],
    }

    result = merge_agent_state(base, delta)

    assert len(result["messages"]) == 2
    assert [m.content for m in result["messages"]] == ["hello", "hi"]

    assert len(result["plots"]) == 2
    assert result["plots"] == ["plot1", "plot2"]

def test_merge_agent_state_dict_update():
    """Dict-typed fields (dfs) are shallow-merged, with the update winning."""
    result = merge_agent_state({"dfs": {"df1": "data1"}}, {"dfs": {"df2": "data2"}})
    assert result["dfs"] == {"df1": "data1", "df2": "data2"}

    # A second merge overwrites an existing key inside the nested dict.
    result = merge_agent_state(result, {"dfs": {"df1": "new_data1"}})
    assert result["dfs"] == {"df1": "new_data1", "df2": "data2"}

def test_merge_agent_state_standard_overwrite():
    """Scalar fields are simply replaced by the update's values."""
    base = {
        "question": "old question",
        "next_action": "old action",
        "plan": "old plan",
    }
    delta = {
        "question": "new question",
        "next_action": "new action",
        "plan": "new plan",
    }

    result = merge_agent_state(base, delta)

    for field in ("question", "next_action", "plan"):
        assert result[field] == delta[field]

def test_merge_agent_state_none_handling():
    """Empty updates are no-ops; an explicit None still overwrites its field."""
    base = {
        "question": "test",
        "messages": ["msg1"],
    }

    # Empty update leaves the state unchanged.
    assert merge_agent_state(base, {}) == base

    # An explicit None overwrites, while untouched fields survive.
    result = merge_agent_state(base, {"question": None})
    assert result["question"] is None
    assert result["messages"] == ["msg1"]
|
||||
145
backend/tests/test_history_manager.py
Normal file
145
backend/tests/test_history_manager.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import pytest
|
||||
from ea_chatbot.history.manager import HistoryManager
|
||||
from ea_chatbot.history.models import User, Conversation, Message, Plot
|
||||
from ea_chatbot.config import Settings
|
||||
from sqlalchemy import delete
|
||||
|
||||
@pytest.fixture
def history_manager():
    """HistoryManager wired to the configured test DB, with all tables emptied."""
    manager = HistoryManager(Settings().history_db_url)
    # Delete children before parents so foreign-key constraints are satisfied.
    with manager.get_session() as session:
        for model in (Plot, Message, Conversation, User):
            session.execute(delete(model))
    return manager
|
||||
|
||||
def test_history_manager_initialization(history_manager):
    """Constructing the manager must set up an engine and a session factory."""
    for attr in ("engine", "SessionLocal"):
        assert getattr(history_manager, attr) is not None
|
||||
|
||||
def test_history_manager_session_context(history_manager):
    """get_session() works as a context manager and yields a live session."""
    with history_manager.get_session() as db_session:
        assert db_session is not None
|
||||
|
||||
def test_get_user_not_found(history_manager):
    """Looking up an unknown email returns None rather than raising."""
    assert history_manager.get_user("nonexistent@example.com") is None
|
||||
|
||||
def test_authenticate_user_success(history_manager):
    """A user created with a password can authenticate with that password."""
    email, password = "test@example.com", "secretpassword"
    history_manager.create_user(email=email, password=password)

    authenticated = history_manager.authenticate_user(email, password)

    assert authenticated is not None
    # The username is stored as the email address.
    assert authenticated.username == email
|
||||
|
||||
def test_authenticate_user_failure(history_manager):
    """Authentication with a wrong password yields None, not an exception."""
    email = "test@example.com"
    history_manager.create_user(email=email, password="correctpassword")

    assert history_manager.authenticate_user(email, "wrongpassword") is None
|
||||
|
||||
def test_sync_user_from_oidc_new_user(history_manager):
    """Syncing an unseen OIDC identity creates a user with that email and name."""
    synced = history_manager.sync_user_from_oidc(
        email="new@example.com", display_name="New User"
    )

    assert synced is not None
    assert synced.username == "new@example.com"
    assert synced.display_name == "New User"
|
||||
|
||||
def test_sync_user_from_oidc_existing_user(history_manager):
    """Re-syncing an existing OIDC user refreshes its profile fields."""
    email = "existing@example.com"
    history_manager.sync_user_from_oidc(email=email, display_name="First Name")

    # A second sync for the same email must update the record, not duplicate it.
    resynced = history_manager.sync_user_from_oidc(
        email=email, display_name="Updated Name"
    )
    assert resynced.display_name == "Updated Name"
|
||||
|
||||
# --- Conversation Management Tests ---
|
||||
|
||||
@pytest.fixture
def user(history_manager):
    """A fresh user to own conversations in the tests below."""
    return history_manager.create_user(email="conv_user@example.com")
|
||||
|
||||
def test_create_conversation(history_manager, user):
    """A conversation is created with the given name, summary and owner."""
    created = history_manager.create_conversation(
        user_id=user.id,
        data_state="new_jersey",
        name="Test Chat",
        summary="A test conversation summary",
    )

    assert created is not None
    assert created.name == "Test Chat"
    assert created.summary == "A test conversation summary"
    assert created.user_id == user.id
|
||||
|
||||
def test_get_conversations(history_manager, user):
    """Listing conversations filters by the requested data_state."""
    for state, name in (("nj", "C1"), ("nj", "C2"), ("ny", "C3")):
        history_manager.create_conversation(
            user_id=user.id, data_state=state, name=name
        )

    assert len(history_manager.get_conversations(user_id=user.id, data_state="nj")) == 2
    assert len(history_manager.get_conversations(user_id=user.id, data_state="ny")) == 1
|
||||
|
||||
def test_rename_conversation(history_manager, user):
    """rename_conversation returns the conversation carrying its new name."""
    conversation = history_manager.create_conversation(user.id, "nj", "Old Name")
    renamed = history_manager.rename_conversation(conversation.id, "New Name")
    assert renamed.name == "New Name"
|
||||
|
||||
def test_delete_conversation(history_manager, user):
    """A deleted conversation no longer appears in the user's listing."""
    conversation = history_manager.create_conversation(user.id, "nj", "To Delete")
    history_manager.delete_conversation(conversation.id)

    assert not history_manager.get_conversations(user.id, "nj")
|
||||
|
||||
# --- Message Management Tests ---
|
||||
|
||||
@pytest.fixture
def conversation(history_manager, user):
    """A conversation owned by the test user, used for message tests."""
    return history_manager.create_conversation(user.id, "nj", "Msg Test Conv")
|
||||
|
||||
def test_add_message(history_manager, conversation):
    """add_message persists role, content and the owning conversation id."""
    message = history_manager.add_message(
        conversation_id=conversation.id,
        role="user",
        content="Hello world",
    )

    assert message is not None
    assert (message.role, message.content) == ("user", "Hello world")
    assert message.conversation_id == conversation.id
|
||||
|
||||
def test_add_message_with_plots(history_manager, conversation):
    """Plot blobs attached to a message are stored alongside it, in order."""
    blobs = [b"fake_plot_1", b"fake_plot_2"]
    message = history_manager.add_message(
        conversation_id=conversation.id,
        role="assistant",
        content="Here are plots",
        plots=blobs,
    )

    assert len(message.plots) == 2
    assert message.plots[0].image_data == b"fake_plot_1"
|
||||
|
||||
def test_get_messages(history_manager, conversation):
    """Messages come back in insertion order."""
    history_manager.add_message(conversation.id, "user", "Q1")
    history_manager.add_message(conversation.id, "assistant", "A1")

    fetched = history_manager.get_messages(conversation.id)

    assert [m.content for m in fetched] == ["Q1", "A1"]
|
||||
55
backend/tests/test_history_models.py
Normal file
55
backend/tests/test_history_models.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker, DeclarativeBase
|
||||
|
||||
# We anticipate these imports will fail initially
# TDD scaffold: if the models module does not exist yet, define all names as
# None so test collection succeeds and test_models_exist can report a clean
# assertion failure instead of a collection-time ImportError.
try:
    from ea_chatbot.history.models import Base, User, Conversation, Message, Plot
except ImportError:
    Base = None
    User = None
    Conversation = None
    Message = None
    Plot = None
|
||||
|
||||
def test_models_exist():
    """All ORM models (and the declarative Base) must have been importable."""
    for obj, label in (
        (User, "User model"),
        (Conversation, "Conversation model"),
        (Message, "Message model"),
        (Plot, "Plot model"),
        (Base, "Base declarative class"),
    ):
        assert obj is not None, f"{label} not found"
|
||||
|
||||
def test_user_model_columns():
    """The User table exposes the expected columns."""
    if not User:
        pytest.fail("User model undefined")
    # Inspect the mapped table rather than instantiating the model.
    columns = User.__table__.columns.keys()
    for expected in ("id", "username", "password_hash", "display_name"):
        assert expected in columns
|
||||
|
||||
def test_conversation_model_columns():
    """The Conversation table exposes the expected columns."""
    if not Conversation:
        pytest.fail("Conversation model undefined")
    columns = Conversation.__table__.columns.keys()
    for expected in ("id", "user_id", "data_state", "name", "summary", "created_at"):
        assert expected in columns
|
||||
|
||||
def test_message_model_columns():
    """The Message table exposes the expected columns."""
    if not Message:
        pytest.fail("Message model undefined")
    columns = Message.__table__.columns.keys()
    for expected in ("id", "role", "content", "conversation_id", "created_at"):
        assert expected in columns
|
||||
|
||||
def test_plot_model_columns():
    """The Plot table exposes the expected columns."""
    if not Plot:
        pytest.fail("Plot model undefined")
    columns = Plot.__table__.columns.keys()
    for expected in ("id", "message_id", "image_data"):
        assert expected in columns
|
||||
4
backend/tests/test_history_module.py
Normal file
4
backend/tests/test_history_module.py
Normal file
@@ -0,0 +1,4 @@
|
||||
import ea_chatbot.history
|
||||
|
||||
def test_history_module_importable():
    """Smoke test: the ea_chatbot.history package can be imported at all."""
    assert ea_chatbot.history is not None
|
||||
54
backend/tests/test_llm_factory.py
Normal file
54
backend/tests/test_llm_factory.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import pytest
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from ea_chatbot.config import LLMConfig
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
|
||||
def test_get_openai_model(monkeypatch):
    """An 'openai' config yields a ChatOpenAI with the config values applied."""
    # A dummy key avoids any real-credential requirement at construction time.
    monkeypatch.setenv("OPENAI_API_KEY", "dummy")
    cfg = LLMConfig(
        provider="openai",
        model="gpt-4o",
        temperature=0.5,
        max_tokens=100,
    )

    llm = get_llm_model(cfg)

    assert isinstance(llm, ChatOpenAI)
    assert (llm.model_name, llm.temperature, llm.max_tokens) == ("gpt-4o", 0.5, 100)
|
||||
|
||||
def test_get_google_model(monkeypatch):
    """A 'google' config yields a ChatGoogleGenerativeAI with config applied."""
    monkeypatch.setenv("GOOGLE_API_KEY", "dummy")
    cfg = LLMConfig(provider="google", model="gemini-1.5-pro", temperature=0.7)

    llm = get_llm_model(cfg)

    assert isinstance(llm, ChatGoogleGenerativeAI)
    assert llm.model == "gemini-1.5-pro"
    assert llm.temperature == 0.7
|
||||
|
||||
def test_unsupported_provider():
    """An unknown provider name must raise a descriptive ValueError."""
    bad_config = LLMConfig(provider="unknown", model="test")
    with pytest.raises(ValueError, match="Unsupported LLM provider: unknown"):
        get_llm_model(bad_config)
|
||||
|
||||
def test_provider_specific_params(monkeypatch):
    """Entries in provider_specific are forwarded as model kwargs."""
    monkeypatch.setenv("OPENAI_API_KEY", "dummy")
    cfg = LLMConfig(
        provider="openai",
        model="o1-preview",
        provider_specific={"reasoning_effort": "high"},
    )

    # reasoning_effort support depends on the langchain-openai version;
    # we only verify the kwarg reaches the constructed model.
    llm = get_llm_model(cfg)

    assert isinstance(llm, ChatOpenAI)
    assert getattr(llm, "reasoning_effort", None) == "high"
|
||||
19
backend/tests/test_llm_factory_callbacks.py
Normal file
19
backend/tests/test_llm_factory_callbacks.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import pytest
|
||||
from langchain_openai import ChatOpenAI
|
||||
from ea_chatbot.config import LLMConfig
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
|
||||
class MockHandler(BaseCallbackHandler):
    """Minimal no-op callback handler, used only for identity membership checks."""
    pass
|
||||
|
||||
def test_get_llm_model_with_callbacks(monkeypatch):
    """Callbacks supplied to the factory end up attached to the model."""
    monkeypatch.setenv("OPENAI_API_KEY", "dummy")
    my_handler = MockHandler()

    llm = get_llm_model(
        LLMConfig(provider="openai", model="gpt-4o"), callbacks=[my_handler]
    )

    assert isinstance(llm, ChatOpenAI)
    assert my_handler in llm.callbacks
|
||||
44
backend/tests/test_logging_context.py
Normal file
44
backend/tests/test_logging_context.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import logging
|
||||
import pytest
|
||||
import io
|
||||
import json
|
||||
from ea_chatbot.utils.logging import ContextLoggerAdapter, JsonFormatter
|
||||
|
||||
@pytest.fixture
def json_log_capture():
    """A dedicated logger writing JSON-formatted records into a StringIO."""
    stream = io.StringIO()
    capture_logger = logging.getLogger("test_context")
    capture_logger.setLevel(logging.INFO)
    # Drop handlers left over from other tests so output isn't duplicated.
    for stale in capture_logger.handlers[:]:
        capture_logger.removeHandler(stale)

    stream_handler = logging.StreamHandler(stream)
    stream_handler.setFormatter(JsonFormatter())
    capture_logger.addHandler(stream_handler)
    return capture_logger, stream
|
||||
|
||||
def test_context_logger_adapter_injects_metadata(json_log_capture):
    """Adapter context (run_id, node_name) is embedded in each JSON record."""
    logger, stream = json_log_capture
    adapter = ContextLoggerAdapter(logger, {"run_id": "123", "node_name": "test_node"})

    adapter.info("test message")

    record = json.loads(stream.getvalue())
    assert record["message"] == "test message"
    assert record["run_id"] == "123"
    assert record["node_name"] == "test_node"
|
||||
|
||||
def test_context_logger_adapter_override_metadata(json_log_capture):
    """Per-call 'extra' metadata is merged with the adapter's base context."""
    logger, stream = json_log_capture
    adapter = ContextLoggerAdapter(logger, {"run_id": "123"})

    # Uses the standard logging 'extra' mechanism; the adapter must merge it.
    adapter.info("test message", extra={"node_name": "dynamic_node"})

    record = json.loads(stream.getvalue())
    assert record["run_id"] == "123"
    assert record["node_name"] == "dynamic_node"
|
||||
67
backend/tests/test_logging_core.py
Normal file
67
backend/tests/test_logging_core.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import logging
|
||||
import pytest
|
||||
from ea_chatbot.utils.logging import get_logger
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_logging():
    """Strip all handlers from the package logger before and after each test."""
    package_logger = logging.getLogger("ea_chatbot")
    for stale in package_logger.handlers[:]:
        package_logger.removeHandler(stale)
    yield
    # Clean up whatever the test itself installed.
    for stale in package_logger.handlers[:]:
        package_logger.removeHandler(stale)
|
||||
|
||||
def test_get_logger_singleton():
    """Repeated get_logger calls with the same name share a single instance."""
    first = get_logger("test_logger")
    second = get_logger("test_logger")
    assert first is second
|
||||
|
||||
def test_get_logger_rich_handler():
    """get_logger installs a RichHandler on the 'ea_chatbot' root logger."""
    get_logger("test_rich")
    root = logging.getLogger("ea_chatbot")
    # Compare by class name to avoid importing rich here.
    assert any(type(h).__name__ == "RichHandler" for h in root.handlers)
|
||||
|
||||
def test_get_logger_level():
    """The string level argument is applied to the returned logger."""
    assert get_logger("test_level", level="DEBUG").level == logging.DEBUG
|
||||
|
||||
def test_json_formatter_serializes_dict():
    """JsonFormatter output is valid JSON carrying message, level and timestamp."""
    from ea_chatbot.utils.logging import JsonFormatter
    import json

    record = logging.LogRecord(
        name="test",
        level=logging.INFO,
        pathname="test.py",
        lineno=10,
        msg="test message",
        args=(),
        exc_info=None,
    )
    payload = json.loads(JsonFormatter().format(record))

    assert payload["message"] == "test message"
    assert payload["level"] == "INFO"
    assert "timestamp" in payload
|
||||
|
||||
def test_get_logger_file_handler(tmp_path):
    """A log_file argument adds a RotatingFileHandler and writes records to disk."""
    log_path = tmp_path / "test.json"
    file_logger = get_logger("test_file", log_file=str(log_path))

    root = logging.getLogger("ea_chatbot")
    assert any(type(h).__name__ == "RotatingFileHandler" for h in root.handlers)

    file_logger.info("file log test")

    # The record must have been flushed to the configured file.
    assert log_path.exists()
    assert "file log test" in log_path.read_text()
|
||||
83
backend/tests/test_logging_e2e.py
Normal file
83
backend/tests/test_logging_e2e.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import os
|
||||
import json
|
||||
import pytest
|
||||
import logging
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.graph.workflow import app
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.utils.logging import get_logger
|
||||
from langchain_community.chat_models import FakeListChatModel
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_logging():
    """Reset handlers on the root ea_chatbot logger around every test."""
    package_root = logging.getLogger("ea_chatbot")
    for stale in package_root.handlers[:]:
        package_root.removeHandler(stale)
    yield
    for stale in package_root.handlers[:]:
        package_root.removeHandler(stale)
|
||||
|
||||
class FakeStructuredModel(FakeListChatModel):
    """FakeListChatModel that also supports with_structured_output.

    The canned response (responses[0]) is parsed as JSON and, when the schema
    is a pydantic model (has model_validate), validated into an instance of it.
    """

    def with_structured_output(self, schema, **kwargs):
        # Return a runnable that parses the canned response into `schema`.
        def _invoke(input, config=None, **kwargs):
            # json is imported at module level; the previous local re-import
            # was redundant.
            data = json.loads(self.responses[0])
            if hasattr(schema, "model_validate"):
                return schema.model_validate(data)
            return data

        return RunnableLambda(_invoke)
|
||||
|
||||
def test_logging_e2e_json_output(tmp_path):
    """Test that a full graph run produces structured JSON logs from multiple nodes.

    End-to-end: configures the package root logger with a JSONL file sink,
    streams the graph once with both LLM factories patched to canned fakes,
    then asserts the file contains records from at least two distinct nodes.
    """
    log_file = tmp_path / "e2e_test.jsonl"

    # Configure the root logger
    get_logger("ea_chatbot", log_file=str(log_file))

    initial_state: AgentState = {
        "messages": [],
        "question": "Who won in 2024?",
        "analysis": None,
        "next_action": "",
        "plan": None,
        "code": None,
        "code_output": None,
        "error": None,
        "plots": [],
        "dfs": {}
    }

    # Create fake models that support callbacks and structured output
    # The analyzer response carries an ambiguity and next_action="clarify",
    # which should route the graph to the clarification node.
    fake_analyzer_response = """{"data_required": [], "unknowns": [], "ambiguities": ["Which year?"], "conditions": [], "next_action": "clarify"}"""
    fake_analyzer = FakeStructuredModel(responses=[fake_analyzer_response])

    fake_clarify = FakeListChatModel(responses=["Please specify."])

    # Patch both node-level factories; nesting keeps both active for the run.
    with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_llm_factory:
        mock_llm_factory.return_value = fake_analyzer

        with patch("ea_chatbot.graph.nodes.clarification.get_llm_model") as mock_clarify_llm_factory:
            mock_clarify_llm_factory.return_value = fake_clarify

            # Run the graph
            # list() drains the stream so all nodes execute before assertions.
            list(app.stream(initial_state))

    # Verify file content
    assert log_file.exists()
    lines = log_file.read_text().splitlines()
    assert len(lines) > 0

    # Verify we have logs from different nodes
    node_names = [json.loads(line)["name"] for line in lines]
    assert "ea_chatbot.query_analyzer" in node_names
    assert "ea_chatbot.clarification" in node_names

    # Verify events
    messages = [json.loads(line)["message"] for line in lines]
    assert any("Analyzing question" in m for m in messages)
    assert any("Clarification generated" in m for m in messages)
|
||||
64
backend/tests/test_logging_langchain.py
Normal file
64
backend/tests/test_logging_langchain.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import logging
|
||||
import pytest
|
||||
import io
|
||||
from unittest.mock import MagicMock
|
||||
from ea_chatbot.utils.logging import LangChainLoggingHandler
|
||||
|
||||
@pytest.fixture
def log_capture():
    """A clean 'test_langchain' logger writing plain records to a StringIO."""
    stream = io.StringIO()
    capture_logger = logging.getLogger("test_langchain")
    capture_logger.setLevel(logging.INFO)
    # Remove handlers left by earlier tests to avoid duplicated output.
    for stale in capture_logger.handlers[:]:
        capture_logger.removeHandler(stale)

    capture_logger.addHandler(logging.StreamHandler(stream))
    return capture_logger, stream
|
||||
|
||||
def test_langchain_logging_handler_on_llm_start(log_capture):
    """on_llm_start emits a line naming the model being invoked."""
    logger, stream = log_capture
    LangChainLoggingHandler(logger=logger).on_llm_start(
        serialized={"name": "test_model"}, prompts=["test prompt"]
    )

    logged = stream.getvalue()
    assert "LLM Started:" in logged
    assert "test_model" in logged
|
||||
|
||||
def test_langchain_logging_handler_on_llm_end(log_capture):
    """on_llm_end reports the model name and the full token-usage breakdown."""
    logger, stream = log_capture
    handler = LangChainLoggingHandler(logger=logger)

    # Fake response exposing llm_output in the OpenAI-style shape.
    fake_response = MagicMock()
    fake_response.llm_output = {
        "token_usage": {
            "prompt_tokens": 10,
            "completion_tokens": 20,
            "total_tokens": 30,
        },
        "model_name": "test_model",
    }

    handler.on_llm_end(response=fake_response)

    logged = stream.getvalue()
    for fragment in (
        "LLM Ended:",
        "test_model",
        "Tokens: 30",
        "10 prompt",
        "20 completion",
    ):
        assert fragment in logged
|
||||
|
||||
def test_langchain_logging_handler_on_llm_error(log_capture):
    """on_llm_error logs the exception text."""
    logger, stream = log_capture
    LangChainLoggingHandler(logger=logger).on_llm_error(error=Exception("test error"))

    logged = stream.getvalue()
    assert "LLM Error:" in logged
    assert "test error" in logged
|
||||
79
backend/tests/test_multi_turn_planner_researcher.py
Normal file
79
backend/tests/test_multi_turn_planner_researcher.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from ea_chatbot.graph.nodes.planner import planner_node
|
||||
from ea_chatbot.graph.nodes.researcher import researcher_node
|
||||
from ea_chatbot.graph.nodes.summarizer import summarizer_node
|
||||
from ea_chatbot.schemas import TaskPlanResponse
|
||||
|
||||
@pytest.fixture
def mock_state_with_history():
    """Agent state mid-conversation: a Florida turn, then a follow-up about NJ."""
    return {
        "messages": [
            HumanMessage(content="Show me the 2024 results for Florida"),
            AIMessage(content="Here are the results for Florida in 2024..."),
        ],
        "question": "What about in New Jersey?",
        "analysis": {
            "data_required": ["2024 results", "New Jersey"],
            "unknowns": [],
            "ambiguities": [],
            "conditions": [],
        },
        "next_action": "plan",
        "summary": "The user is asking about 2024 election results.",
        "plan": "Plan steps...",
        "code_output": "Code output...",
    }
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
|
||||
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
|
||||
@patch("ea_chatbot.graph.nodes.planner.PLANNER_PROMPT")
|
||||
def test_planner_uses_history_and_summary(mock_prompt, mock_get_summary, mock_get_llm, mock_state_with_history):
|
||||
mock_get_summary.return_value = "Data summary"
|
||||
mock_llm_instance = MagicMock()
|
||||
mock_get_llm.return_value = mock_llm_instance
|
||||
mock_structured_llm = MagicMock()
|
||||
mock_llm_instance.with_structured_output.return_value = mock_structured_llm
|
||||
|
||||
mock_structured_llm.invoke.return_value = TaskPlanResponse(
|
||||
goal="goal",
|
||||
reflection="reflection",
|
||||
context={
|
||||
"initial_context": "context",
|
||||
"assumptions": [],
|
||||
"constraints": []
|
||||
},
|
||||
steps=["Step 1: test"]
|
||||
)
|
||||
|
||||
planner_node(mock_state_with_history)
|
||||
|
||||
mock_prompt.format_messages.assert_called_once()
|
||||
kwargs = mock_prompt.format_messages.call_args[1]
|
||||
assert kwargs["question"] == "What about in New Jersey?"
|
||||
assert kwargs["summary"] == mock_state_with_history["summary"]
|
||||
assert len(kwargs["history"]) == 2
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
|
||||
@patch("ea_chatbot.graph.nodes.researcher.RESEARCHER_PROMPT")
|
||||
def test_researcher_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
|
||||
mock_llm_instance = MagicMock()
|
||||
mock_get_llm.return_value = mock_llm_instance
|
||||
|
||||
researcher_node(mock_state_with_history)
|
||||
|
||||
mock_prompt.format_messages.assert_called_once()
|
||||
kwargs = mock_prompt.format_messages.call_args[1]
|
||||
assert kwargs["question"] == "What about in New Jersey?"
|
||||
assert kwargs["summary"] == mock_state_with_history["summary"]
|
||||
assert len(kwargs["history"]) == 2
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model")
|
||||
@patch("ea_chatbot.graph.nodes.summarizer.SUMMARIZER_PROMPT")
|
||||
def test_summarizer_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
|
||||
mock_llm_instance = MagicMock()
|
||||
mock_get_llm.return_value = mock_llm_instance
|
||||
|
||||
summarizer_node(mock_state_with_history)
|
||||
|
||||
mock_prompt.format_messages.assert_called_once()
|
||||
kwargs = mock_prompt.format_messages.call_args[1]
|
||||
assert kwargs["question"] == "What about in New Jersey?"
|
||||
assert kwargs["summary"] == mock_state_with_history["summary"]
|
||||
assert len(kwargs["history"]) == 2
|
||||
76
backend/tests/test_multi_turn_query_analyzer.py
Normal file
76
backend/tests/test_multi_turn_query_analyzer.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
@pytest.fixture
def mock_state_with_history():
    """State with one completed turn (Florida) and a follow-up question (NJ)."""
    return {
        "messages": [
            HumanMessage(content="Show me the 2024 results for Florida"),
            AIMessage(content="Here are the results for Florida in 2024..."),
        ],
        "question": "What about in New Jersey?",
        "analysis": None,
        "next_action": "",
        "summary": "The user is asking about 2024 election results.",
    }
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
|
||||
@patch("ea_chatbot.graph.nodes.query_analyzer.QUERY_ANALYZER_PROMPT")
|
||||
def test_query_analyzer_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
|
||||
"""Test that query_analyzer_node passes history and summary to the prompt."""
|
||||
mock_llm_instance = MagicMock()
|
||||
mock_get_llm.return_value = mock_llm_instance
|
||||
mock_structured_llm = MagicMock()
|
||||
mock_llm_instance.with_structured_output.return_value = mock_structured_llm
|
||||
|
||||
mock_structured_llm.invoke.return_value = QueryAnalysis(
|
||||
data_required=["2024 results", "New Jersey"],
|
||||
unknowns=[],
|
||||
ambiguities=[],
|
||||
conditions=[],
|
||||
next_action="plan"
|
||||
)
|
||||
|
||||
query_analyzer_node(mock_state_with_history)
|
||||
|
||||
# Verify that the prompt was formatted with the correct variables
|
||||
mock_prompt.format_messages.assert_called_once()
|
||||
kwargs = mock_prompt.format_messages.call_args[1]
|
||||
|
||||
assert kwargs["question"] == "What about in New Jersey?"
|
||||
assert "summary" in kwargs
|
||||
assert kwargs["summary"] == mock_state_with_history["summary"]
|
||||
assert "history" in kwargs
|
||||
# History should contain the messages from the state
|
||||
assert len(kwargs["history"]) == 2
|
||||
assert kwargs["history"][0].content == "Show me the 2024 results for Florida"
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
|
||||
def test_query_analyzer_context_window(mock_get_llm):
|
||||
"""Test that query_analyzer_node only uses the last 6 messages (3 turns)."""
|
||||
messages = [HumanMessage(content=f"Msg {i}") for i in range(10)]
|
||||
state = {
|
||||
"messages": messages,
|
||||
"question": "Latest question",
|
||||
"analysis": None,
|
||||
"next_action": "",
|
||||
"summary": "Summary"
|
||||
}
|
||||
|
||||
mock_llm_instance = MagicMock()
|
||||
mock_get_llm.return_value = mock_llm_instance
|
||||
mock_structured_llm = MagicMock()
|
||||
mock_llm_instance.with_structured_output.return_value = mock_structured_llm
|
||||
mock_structured_llm.invoke.return_value = QueryAnalysis(
|
||||
data_required=[], unknowns=[], ambiguities=[], conditions=[], next_action="plan"
|
||||
)
|
||||
|
||||
with patch("ea_chatbot.graph.nodes.query_analyzer.QUERY_ANALYZER_PROMPT") as mock_prompt:
|
||||
query_analyzer_node(state)
|
||||
kwargs = mock_prompt.format_messages.call_args[1]
|
||||
# Should only have last 6 messages
|
||||
assert len(kwargs["history"]) == 6
|
||||
assert kwargs["history"][0].content == "Msg 4"
|
||||
87
backend/tests/test_oidc_client.py
Normal file
87
backend/tests/test_oidc_client.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.auth import OIDCClient
|
||||
|
||||
@pytest.fixture
def oidc_config():
    """Constructor kwargs for an OIDCClient pointed at a fake provider."""
    return {
        "client_id": "test_id",
        "client_secret": "test_secret",
        "server_metadata_url": "https://example.com/.well-known/openid-configuration",
        "redirect_uri": "http://localhost:8501",
    }
|
||||
|
||||
@pytest.fixture
def mock_metadata():
    """Discovery-document subset covering the endpoints the client uses."""
    return {
        "authorization_endpoint": "https://example.com/auth",
        "token_endpoint": "https://example.com/token",
        "userinfo_endpoint": "https://example.com/userinfo",
    }
|
||||
|
||||
def test_oidc_fetch_metadata(oidc_config, mock_metadata):
    """fetch_metadata hits the discovery URL once and caches the result."""
    client = OIDCClient(**oidc_config)

    with patch("requests.get") as mock_get:
        fake_response = MagicMock()
        fake_response.json.return_value = mock_metadata
        mock_get.return_value = fake_response

        fetched = client.fetch_metadata()

        assert fetched == mock_metadata
        mock_get.assert_called_once_with(oidc_config["server_metadata_url"])

        # A repeat call must be served from the cache, not the network.
        client.fetch_metadata()
        assert mock_get.call_count == 1
|
||||
|
||||
def test_oidc_get_login_url(oidc_config, mock_metadata):
    """The login URL comes from create_authorization_url on the auth endpoint."""
    client = OIDCClient(**oidc_config)
    client.metadata = mock_metadata

    with patch.object(client.oauth_session, "create_authorization_url") as mock_create:
        mock_create.return_value = ("https://example.com/auth?state=xyz", "xyz")

        assert client.get_login_url() == "https://example.com/auth?state=xyz"
        mock_create.assert_called_once_with(mock_metadata["authorization_endpoint"])
|
||||
|
||||
def test_oidc_get_login_url_missing_endpoint(oidc_config):
    """Metadata lacking an authorization_endpoint raises a ValueError."""
    client = OIDCClient(**oidc_config)
    client.metadata = {"some_other": "field"}

    with pytest.raises(ValueError, match="authorization_endpoint not found"):
        client.get_login_url()
|
||||
|
||||
def test_oidc_exchange_code_for_token(oidc_config, mock_metadata):
    """The auth code is traded at the token endpoint with the client secret."""
    client = OIDCClient(**oidc_config)
    client.metadata = mock_metadata

    with patch.object(client.oauth_session, "fetch_token") as mock_fetch:
        mock_fetch.return_value = {"access_token": "abc"}

        assert client.exchange_code_for_token("test_code") == {"access_token": "abc"}
        mock_fetch.assert_called_once_with(
            mock_metadata["token_endpoint"],
            code="test_code",
            client_secret=oidc_config["client_secret"],
        )
|
||||
|
||||
def test_oidc_get_user_info(oidc_config, mock_metadata):
    """get_user_info installs the token on the session and queries userinfo."""
    client = OIDCClient(**oidc_config)
    client.metadata = mock_metadata

    with patch.object(client.oauth_session, "get") as mock_get:
        fake_response = MagicMock()
        fake_response.json.return_value = {"sub": "user123"}
        mock_get.return_value = fake_response

        info = client.get_user_info({"access_token": "abc"})

        assert info == {"sub": "user123"}
        # The token must have been set on the session before the request.
        assert client.oauth_session.token == {"access_token": "abc"}
        mock_get.assert_called_once_with(mock_metadata["userinfo_endpoint"])
|
||||
46
backend/tests/test_planner.py
Normal file
46
backend/tests/test_planner.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.graph.nodes.planner import planner_node
|
||||
|
||||
@pytest.fixture
def mock_state():
    """Minimal planner-input state: analysis already done, next action is 'plan'."""
    analysis = {
        # "requires_dataset" removed as it's no longer used
        "expert": "Data Analyst",
        "data": "NJ data",
        "unknown": "results",
        "condition": "state=NJ"
    }
    return {
        "messages": [],
        "question": "Show me results for New Jersey",
        "analysis": analysis,
        "next_action": "plan",
        "plan": None
    }
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
def test_planner_node(mock_get_summary, mock_get_llm, mock_state):
    """Test planner node with unified prompt."""
    from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext

    mock_get_summary.return_value = "Column: Name, Type: text"

    llm = MagicMock()
    mock_get_llm.return_value = llm

    # Structured-output call returns a fully-formed plan.
    plan = TaskPlanResponse(
        goal="Get NJ results",
        reflection="The user wants NJ results",
        context=TaskPlanContext(initial_context="NJ data", assumptions=[], constraints=[]),
        steps=["Step 1: Load data", "Step 2: Filter by NJ"]
    )
    llm.with_structured_output.return_value.invoke.return_value = plan

    result = planner_node(mock_state)

    # The plan string should contain every generated step.
    assert "plan" in result
    for step in ("Step 1: Load data", "Step 2: Filter by NJ"):
        assert step in result["plan"]

    # The database summary helper must have been consulted exactly once.
    mock_get_summary.assert_called_once()
|
||||
80
backend/tests/test_query_analyzer.py
Normal file
80
backend/tests/test_query_analyzer.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
@pytest.fixture
def mock_state():
    """Fresh analyzer-input state for an unambiguous data question."""
    return {
        "messages": [],
        "question": "Show me the 2024 results for Florida",
        "analysis": None,
        "next_action": ""
    }
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_query_analyzer_data_analysis(mock_get_llm, mock_state):
    """Test that a clear data analysis query is routed to the planner."""
    # Wire up the LLM so with_structured_output() yields a controllable runnable.
    llm = MagicMock()
    structured = MagicMock()
    llm.with_structured_output.return_value = structured
    mock_get_llm.return_value = llm

    # The structured call returns a fully-resolved analysis routed to "plan".
    structured.invoke.return_value = QueryAnalysis(
        data_required=["2024 results", "Florida"],
        unknowns=[],
        ambiguities=[],
        conditions=[],
        next_action="plan"
    )

    update = query_analyzer_node(mock_state)

    assert update["next_action"] == "plan"
    assert "2024 results" in update["analysis"]["data_required"]
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_query_analyzer_ambiguous(mock_get_llm, mock_state):
    """Test that an ambiguous query is routed to clarification."""
    mock_state["question"] = "What happened?"

    llm = MagicMock()
    structured = MagicMock()
    llm.with_structured_output.return_value = structured
    mock_get_llm.return_value = llm

    # The analyzer reports an open unknown and asks for clarification.
    structured.invoke.return_value = QueryAnalysis(
        data_required=[],
        unknowns=["What event?"],
        ambiguities=[],
        conditions=[],
        next_action="clarify"
    )

    update = query_analyzer_node(mock_state)

    assert update["next_action"] == "clarify"
    assert len(update["analysis"]["unknowns"]) > 0
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_query_analyzer_uses_config(mock_get_llm, mock_state, monkeypatch):
    """Test that the node uses the configured LLM settings."""
    # Environment override should flow through to the LLM factory.
    monkeypatch.setenv("QUERY_ANALYZER_LLM__MODEL", "gpt-3.5-turbo")

    llm = MagicMock()
    structured = MagicMock()
    llm.with_structured_output.return_value = structured
    mock_get_llm.return_value = llm
    structured.invoke.return_value = QueryAnalysis(
        data_required=[], unknowns=[], ambiguities=[], conditions=[], next_action="plan"
    )

    query_analyzer_node(mock_state)

    # The first positional argument to get_llm_model is the node's LLM config.
    passed_config = mock_get_llm.call_args[0][0]
    assert passed_config.model == "gpt-3.5-turbo"
|
||||
45
backend/tests/test_query_analyzer_logging.py
Normal file
45
backend/tests/test_query_analyzer_logging.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import pytest
|
||||
import logging
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
|
||||
|
||||
@pytest.fixture
def mock_state():
    """Analyzer-input state used to exercise logging behaviour."""
    return {
        "messages": [],
        "question": "Show me the 2024 results for Florida",
        "analysis": None,
        "next_action": ""
    }
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.query_analyzer.get_logger")
def test_query_analyzer_logs_actions(mock_get_logger, mock_get_llm, mock_state):
    """Test that query_analyzer_node logs its main actions."""
    # Replace the module logger with a spy.
    logger_spy = MagicMock()
    mock_get_logger.return_value = logger_spy

    # Replace the LLM with a canned structured response.
    llm = MagicMock()
    structured = MagicMock()
    llm.with_structured_output.return_value = structured
    mock_get_llm.return_value = llm
    structured.invoke.return_value = QueryAnalysis(
        data_required=["2024 results", "Florida"],
        unknowns=[],
        ambiguities=[],
        conditions=[],
        next_action="plan"
    )

    query_analyzer_node(mock_state)

    # At minimum, the node should emit info-level logs while running.
    # (Specific messages are not pinned yet; presence is enough for the Red phase.)
    assert logger_spy.info.called
|
||||
103
backend/tests/test_query_analyzer_refinement.py
Normal file
103
backend/tests/test_query_analyzer_refinement.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
@pytest.fixture
def base_state():
    """Blank analyzer state; individual tests fill in question/messages/summary."""
    return {
        "messages": [],
        "question": "",
        "analysis": None,
        "next_action": "",
        "summary": ""
    }
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_refinement_coreference_from_history(mock_get_llm, base_state):
    """
    Test that the analyzer can resolve Year/State from history.
    User asks "What about in NJ?" after a Florida 2024 query.
    Expected: next_action = 'plan', NOT 'clarify' due to missing year.
    """
    state = base_state.copy()
    state["question"] = "What about in New Jersey?"
    state["summary"] = "The user is looking for 2024 election results."
    state["messages"] = [
        HumanMessage(content="Show me 2024 results for Florida"),
        AIMessage(content="Here are the 2024 results for Florida...")
    ]

    llm = MagicMock()
    structured = MagicMock()
    llm.with_structured_output.return_value = structured
    mock_get_llm.return_value = llm

    # The LLM should return 'plan' because the year is recoverable from context.
    # If the prompt is too strict and yields 'clarify', this test must fail.
    structured.invoke.return_value = QueryAnalysis(
        data_required=["2024 results", "New Jersey"],
        unknowns=[],
        ambiguities=[],
        conditions=["state=NJ", "year=2024"],
        next_action="plan"
    )

    result = query_analyzer_node(state)

    assert result["next_action"] == "plan"
    assert "NJ" in str(result["analysis"]["conditions"])
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_refinement_tolerance_for_missing_format(mock_get_llm, base_state):
    """
    Test that the analyzer doesn't flag missing output format or database name.
    User asks "Give me a graph of turnout".
    Expected: next_action = 'plan', even if 'format' or 'db' is not in query.
    """
    state = base_state.copy()
    state["question"] = "Give me a graph of voter turnout in 2024 for Florida"

    llm = MagicMock()
    structured = MagicMock()
    llm.with_structured_output.return_value = structured
    mock_get_llm.return_value = llm

    structured.invoke.return_value = QueryAnalysis(
        data_required=["voter turnout", "Florida"],
        unknowns=[],
        ambiguities=[],
        conditions=["year=2024"],
        next_action="plan"
    )

    result = query_analyzer_node(state)

    assert result["next_action"] == "plan"
    # The analyzer must not invent requirements of its own (hallucinated ambiguity).
    assert len(result["analysis"]["ambiguities"]) == 0
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_refinement_enforces_voter_identity_clarification(mock_get_llm, base_state):
    """
    Test that 'track the same voter' still triggers clarification.
    """
    state = base_state.copy()
    state["question"] = "Track the same voter participation in 2020 and 2024."

    llm = MagicMock()
    structured = MagicMock()
    llm.with_structured_output.return_value = structured
    mock_get_llm.return_value = llm

    # Clarification is desired here: "the same voter" has no defined identity key.
    structured.invoke.return_value = QueryAnalysis(
        data_required=["voter participation"],
        unknowns=[],
        ambiguities=["Please define what fields constitute 'the same voter' (e.g. ID, or Name and DOB)."],
        conditions=[],
        next_action="clarify"
    )

    result = query_analyzer_node(state)

    assert result["next_action"] == "clarify"
    ambiguity_text = str(result["analysis"]["ambiguities"]).lower()
    assert "identity" in ambiguity_text or "same voter" in ambiguity_text
|
||||
34
backend/tests/test_researcher.py
Normal file
34
backend/tests/test_researcher.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
from ea_chatbot.graph.nodes.researcher import researcher_node
|
||||
|
||||
@pytest.fixture
def mock_llm():
    """Patch the researcher's LLM factory and yield a ChatOpenAI-spec'd mock."""
    with patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as factory:
        llm = MagicMock(spec=ChatOpenAI)
        factory.return_value = llm
        yield llm
|
||||
|
||||
def test_researcher_node_success(mock_llm):
    """Test that researcher_node invokes LLM with web_search tool and returns messages."""
    bound = MagicMock()
    mock_llm.bind_tools.return_value = bound
    bound.invoke.return_value = AIMessage(content="The capital of France is Paris.")

    state = {"question": "What is the capital of France?", "messages": []}
    result = researcher_node(state)

    # The web_search tool must appear in the tool list bound to the LLM.
    assert mock_llm.bind_tools.called
    bound_args, _ = mock_llm.bind_tools.call_args
    assert {"type": "web_search"} in bound_args[0]

    # The bound LLM is what actually gets invoked, and its reply flows back.
    assert bound.invoke.called
    assert "messages" in result
    assert result["messages"][0].content == "The capital of France is Paris."
|
||||
62
backend/tests/test_researcher_search_tools.py
Normal file
62
backend/tests/test_researcher_search_tools.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from ea_chatbot.graph.nodes.researcher import researcher_node
|
||||
|
||||
@pytest.fixture
def base_state():
    """Researcher-input state with a web-answerable question."""
    return {
        "question": "Who won the 2024 election?",
        "messages": [],
        "summary": ""
    }
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_openai_search(mock_get_llm, base_state):
    """Test that OpenAI LLM binds 'web_search' tool."""
    llm = MagicMock(spec=ChatOpenAI)
    mock_get_llm.return_value = llm

    bound = MagicMock()
    llm.bind_tools.return_value = bound
    bound.invoke.return_value = AIMessage(content="OpenAI Search Result")

    result = researcher_node(base_state)

    # For ChatOpenAI the OpenAI-native web_search tool spec must be bound.
    llm.bind_tools.assert_called_once_with([{"type": "web_search"}])
    assert result["messages"][0].content == "OpenAI Search Result"
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_google_search(mock_get_llm, base_state):
    """Test that Google LLM binds 'google_search' tool."""
    llm = MagicMock(spec=ChatGoogleGenerativeAI)
    mock_get_llm.return_value = llm

    bound = MagicMock()
    llm.bind_tools.return_value = bound
    bound.invoke.return_value = AIMessage(content="Google Search Result")

    result = researcher_node(base_state)

    # For Gemini models the google_search tool spec must be bound instead.
    llm.bind_tools.assert_called_once_with([{"google_search": {}}])
    assert result["messages"][0].content == "Google Search Result"
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_fallback_on_bind_error(mock_get_llm, base_state):
    """Test that researcher falls back to basic LLM if bind_tools fails."""
    llm = MagicMock(spec=ChatOpenAI)
    mock_get_llm.return_value = llm

    # Simulate a model that rejects tool binding entirely.
    llm.bind_tools.side_effect = Exception("Not supported")
    llm.invoke.return_value = AIMessage(content="Basic Result")

    result = researcher_node(base_state)

    # The node should degrade gracefully to a plain (tool-less) invocation.
    llm.invoke.assert_called_once()
    assert result["messages"][0].content == "Basic Result"
|
||||
41
backend/tests/test_state.py
Normal file
41
backend/tests/test_state.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import pytest
|
||||
from typing import get_type_hints, List
|
||||
from langchain_core.messages import BaseMessage, HumanMessage
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
import operator
|
||||
|
||||
def test_agent_state_structure():
    """Verify that AgentState has the required fields and types."""
    hints = get_type_hints(AgentState)

    # Key presence is the contract here; the exact Annotated metadata on
    # `messages` (List[BaseMessage] + operator.add) is hard to assert strictly
    # against a TypedDict, so only existence is checked.
    for field in ("messages", "question", "analysis", "next_action",
                  "summary", "plots", "dfs"):
        assert field in hints

    # `question` and `next_action` are plain strings.
    assert hints["question"] == str
    assert hints["next_action"] == str
    # `analysis` should be a dict-like Optional; `summary` Optional[str] or str —
    # flexibility is allowed, so no strict type assertion for those.
|
||||
|
||||
def test_messages_reducer_behavior():
    """Verify that the `messages` field carries a reducer via Annotated metadata.

    The original version computed the hints and then `pass`ed, so it could
    never fail. We cannot instantiate the TypedDict in a graph context here,
    but we CAN inspect the annotation: with include_extras=True,
    Annotated[List[BaseMessage], <reducer>] exposes the reducer through
    `__metadata__`. We only assert that *some* reducer is attached
    (operator.add or an equivalent like langgraph's add_messages).
    """
    hints = get_type_hints(AgentState, include_extras=True)
    messages_hint = hints["messages"]
    # Annotated types expose their extras via __metadata__; a bare List would not.
    reducer_metadata = getattr(messages_hint, "__metadata__", ())
    assert reducer_metadata, (
        "AgentState.messages must be Annotated[..., reducer] so LangGraph "
        "accumulates messages instead of overwriting them"
    )
|
||||
48
backend/tests/test_summarizer.py
Normal file
48
backend/tests/test_summarizer.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import AIMessage
|
||||
from ea_chatbot.graph.nodes.summarizer import summarizer_node
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
@pytest.fixture
def mock_llm():
    """Patch the summarizer's LLM factory and yield the mock model."""
    with patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") as factory:
        llm = MagicMock()
        factory.return_value = llm
        yield llm
|
||||
|
||||
def test_summarizer_node_success(mock_llm):
    """Test that summarizer_node invokes LLM with correct inputs and returns messages."""
    mock_llm.invoke.return_value = AIMessage(content="The final answer is 100.")

    state = {
        "question": "What is the total count?",
        "plan": "1. Run query\n2. Sum results",
        "code_output": "The total is 100",
        "messages": []
    }

    result = summarizer_node(state)

    # The LLM must have been consulted.
    assert mock_llm.invoke.called

    # Exactly one AIMessage with the summarized answer flows back into state.
    assert "messages" in result
    messages = result["messages"]
    assert len(messages) == 1
    assert isinstance(messages[0], AIMessage)
    assert messages[0].content == "The final answer is 100."
|
||||
|
||||
def test_summarizer_node_empty_state(mock_llm):
    """Test handling of empty or minimal state."""
    mock_llm.invoke.return_value = AIMessage(content="No data provided.")

    # No plan / code_output keys at all — the node must still respond.
    minimal_state = {"question": "Empty?", "messages": []}

    result = summarizer_node(minimal_state)

    assert "messages" in result
    assert result["messages"][0].content == "No data provided."
|
||||
93
backend/tests/test_workflow.py
Normal file
93
backend/tests/test_workflow.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.graph.workflow import app
|
||||
from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis
|
||||
from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.graph.nodes.coder.get_llm_model")
@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model")
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
@patch("ea_chatbot.graph.nodes.executor.Settings")
@patch("ea_chatbot.graph.nodes.executor.DBClient")
def test_workflow_full_flow(mock_db_client, mock_settings, mock_get_summary, mock_researcher_llm, mock_summarizer_llm, mock_coder_llm, mock_planner_llm, mock_qa_llm):
    """Test the flow from query_analyzer through planner to coder."""

    def structured_llm(factory, response):
        # Install a mock LLM on `factory` whose structured-output call returns `response`.
        llm = MagicMock()
        factory.return_value = llm
        llm.with_structured_output.return_value.invoke.return_value = response
        return llm

    # Executor settings: fake DB connection parameters.
    settings = MagicMock()
    settings.db_host = "localhost"
    settings.db_port = 5432
    settings.db_user = "user"
    settings.db_pswd = "pass"
    settings.db_name = "test_db"
    settings.db_table = "test_table"
    mock_settings.return_value = settings
    mock_db_client.return_value = MagicMock()

    # 1. Query analyzer routes straight to planning.
    structured_llm(mock_qa_llm, QueryAnalysis(
        data_required=["2024 results"],
        unknowns=[],
        ambiguities=[],
        conditions=[],
        next_action="plan"
    ))

    # 2. Planner produces a one-step plan from the data summary.
    mock_get_summary.return_value = "Data summary"
    structured_llm(mock_planner_llm, TaskPlanResponse(
        goal="Task Goal",
        reflection="Reflection",
        context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]),
        steps=["Step 1"]
    ))

    # 3. Coder emits trivially-runnable code.
    structured_llm(mock_coder_llm, CodeGenerationResponse(
        code="print('Hello')",
        explanation="Explanation"
    ))

    # 4. Summarizer returns a plain message.
    summarizer = MagicMock()
    mock_summarizer_llm.return_value = summarizer
    summarizer.invoke.return_value = AIMessage(content="Summary")

    # 5. Researcher is stubbed defensively even though this flow skips it.
    mock_researcher_llm.return_value = MagicMock()

    initial_state = {
        "messages": [],
        "question": "Show me the 2024 results",
        "analysis": None,
        "next_action": "",
        "plan": None,
        "code": None,
        "error": None,
        "plots": [],
        "dfs": {}
    }

    # recursion_limit guards against infinite loops from placeholder nodes.
    result = app.invoke(initial_state, config={"recursion_limit": 10})

    assert result["next_action"] == "plan"
    assert "plan" in result and result["plan"] is not None
    assert "code" in result and "print('Hello')" in result["code"]
    assert "analysis" in result
|
||||
139
backend/tests/test_workflow_e2e.py
Normal file
139
backend/tests/test_workflow_e2e.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import pytest
|
||||
import yaml
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import AIMessage
|
||||
from ea_chatbot.graph.workflow import app
|
||||
from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis
|
||||
from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse
|
||||
|
||||
@pytest.fixture
def mock_llms():
    """Patch every node's LLM factory (plus the data-summary helper) in one go.

    Yields the patch objects keyed by node name so each test can wire up
    only the nodes its flow actually exercises.
    """
    with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as qa_factory, \
         patch("ea_chatbot.graph.nodes.planner.get_llm_model") as planner_factory, \
         patch("ea_chatbot.graph.nodes.coder.get_llm_model") as coder_factory, \
         patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") as summarizer_factory, \
         patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as researcher_factory, \
         patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model") as summary_factory, \
         patch("ea_chatbot.utils.database_inspection.get_data_summary") as data_summary:
        data_summary.return_value = "Data summary"

        # Conversation-summary LLM is always live; give it a canned reply.
        summary_llm = MagicMock()
        summary_factory.return_value = summary_llm
        summary_llm.invoke.return_value = AIMessage(content="Turn summary")

        yield {
            "qa": qa_factory,
            "planner": planner_factory,
            "coder": coder_factory,
            "summarizer": summarizer_factory,
            "researcher": researcher_factory,
            "summary": summary_factory
        }
|
||||
|
||||
def test_workflow_data_analysis_flow(mock_llms):
    """Test full flow: QueryAnalyzer -> Planner -> Coder -> Executor -> Summarizer."""

    def structured_llm(factory, response):
        # Install a mock LLM on `factory` whose structured-output call returns `response`.
        llm = MagicMock()
        factory.return_value = llm
        llm.with_structured_output.return_value.invoke.return_value = response
        return llm

    # 1. Analyzer routes the question to the planner.
    structured_llm(mock_llms["qa"], QueryAnalysis(
        data_required=["2024 results"],
        unknowns=[],
        ambiguities=[],
        conditions=[],
        next_action="plan"
    ))

    # 2. Planner yields a single-step plan.
    structured_llm(mock_llms["planner"], TaskPlanResponse(
        goal="Get results",
        reflection="Reflect",
        context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]),
        steps=["Step 1"]
    ))

    # 3. Coder emits code whose output is easy to spot downstream.
    structured_llm(mock_llms["coder"], CodeGenerationResponse(
        code="print('Execution Success')",
        explanation="Explain"
    ))

    # 4. Summarizer wraps everything up.
    summarizer = MagicMock()
    mock_llms["summarizer"].return_value = summarizer
    summarizer.invoke.return_value = AIMessage(content="Final Summary: Success")

    initial_state = {
        "messages": [],
        "question": "Show me 2024 results",
        "analysis": None,
        "next_action": "",
        "plan": None,
        "code": None,
        "error": None,
        "plots": [],
        "dfs": {}
    }

    result = app.invoke(initial_state, config={"recursion_limit": 15})

    assert result["next_action"] == "plan"
    # Executor really ran the generated code...
    assert "Execution Success" in result["code_output"]
    # ...and the summarizer's answer is the last message in the conversation.
    assert "Final Summary: Success" in result["messages"][-1].content
||||
|
||||
def test_workflow_research_flow(mock_llms):
    """Test flow: QueryAnalyzer -> Researcher -> Summarizer."""

    # 1. Analyzer routes the question to research instead of planning.
    qa_llm = MagicMock()
    mock_llms["qa"].return_value = qa_llm
    qa_llm.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
        data_required=[],
        unknowns=[],
        ambiguities=[],
        conditions=[],
        next_action="research"
    )

    # 2. Researcher: plain MagicMock fails the ChatOpenAI/Gemini isinstance
    #    checks, so the node falls back to invoking the base LLM directly —
    #    but bind_tools is stubbed too in case a spec is ever added.
    researcher = MagicMock()
    mock_llms["researcher"].return_value = researcher
    researcher.invoke.return_value = AIMessage(content="Research Results")
    bound = MagicMock()
    researcher.bind_tools.return_value = bound
    bound.invoke.return_value = AIMessage(content="Research Results")

    # 3. Summarizer is stubbed for completeness; this flow shouldn't reach it.
    summarizer = MagicMock()
    mock_llms["summarizer"].return_value = summarizer
    summarizer.invoke.return_value = AIMessage(content="Final Summary: Research Success")

    initial_state = {
        "messages": [],
        "question": "Who is the governor of Florida?",
        "analysis": None,
        "next_action": "",
        "plan": None,
        "code": None,
        "error": None,
        "plots": [],
        "dfs": {}
    }

    result = app.invoke(initial_state, config={"recursion_limit": 10})

    assert result["next_action"] == "research"
    assert "Research Results" in result["messages"][-1].content
|
||||
72
backend/tests/test_workflow_history.py
Normal file
72
backend/tests/test_workflow_history.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import pytest
|
||||
from ea_chatbot.history.manager import HistoryManager
|
||||
from ea_chatbot.history.models import User, Conversation, Message, Plot
|
||||
from ea_chatbot.config import Settings
|
||||
from sqlalchemy import delete
|
||||
|
||||
@pytest.fixture
def history_manager():
    """Yield a HistoryManager against the configured DB, wiping tables before/after.

    NOTE(review): this is an integration fixture — it requires a reachable
    history database (Settings().history_db_url).
    """
    manager = HistoryManager(Settings().history_db_url)

    def wipe_tables():
        # Delete children first so FK constraints are never violated.
        with manager.get_session() as session:
            for model in (Plot, Message, Conversation, User):
                session.execute(delete(model))

    wipe_tables()
    yield manager
    wipe_tables()
||||
|
||||
def test_full_history_workflow(history_manager):
    """End-to-end check of the history API: user auth, conversation, messages with plots, and summaries."""
    # 1. Create and Authenticate User
    email = "e2e@example.com"
    password = "password123"
    history_manager.create_user(email, password, "E2E User")

    authed = history_manager.authenticate_user(email, password)
    assert authed is not None
    assert authed.display_name == "E2E User"

    # 1.1 Verify get_user_by_id
    by_id = history_manager.get_user_by_id(authed.id)
    assert by_id is not None
    assert by_id.username == email

    # 2. Create Conversation
    conversation = history_manager.create_conversation(authed.id, "nj", "Test Analytics")
    assert conversation.id is not None

    # 3. Add User Message
    history_manager.add_message(conversation.id, "user", "How many voters in NJ?")

    # 4. Add Assistant Message with Plot
    fake_png = b"fake_png_data"
    history_manager.add_message(
        conversation.id,
        "assistant",
        "There are X voters.",
        plots=[fake_png],
    )

    # 5. Retrieve and Verify History
    stored = history_manager.get_messages(conversation.id)
    assert len(stored) == 2
    user_msg, assistant_msg = stored
    assert user_msg.role == "user"
    assert assistant_msg.role == "assistant"
    assert len(assistant_msg.plots) == 1
    assert assistant_msg.plots[0].image_data == fake_png

    # 6. Verify Conversation listing
    listed = history_manager.get_conversations(authed.id, "nj")
    assert len(listed) == 1
    assert listed[0].name == "Test Analytics"

    # 7. Update summary
    history_manager.update_conversation_summary(conversation.id, "Voter count analysis")

    # 8. Reload and verify summary
    relisted = history_manager.get_conversations(authed.id, "nj")
    assert relisted[0].summary == "Voter count analysis"
|
||||
3134
backend/uv.lock
generated
Normal file
3134
backend/uv.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user