Refactor: Move backend files to backend/ directory and split .gitignore

This commit is contained in:
Yunxiao Xu
2026-02-11 17:40:44 -08:00
parent 48924affa0
commit 7a69133e26
96 changed files with 144 additions and 176 deletions

View File

View File

View File

@@ -0,0 +1,46 @@
import os
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from ea_chatbot.config import Settings
from ea_chatbot.history.manager import HistoryManager
from ea_chatbot.auth import OIDCClient
from ea_chatbot.api.utils import decode_access_token
from ea_chatbot.history.models import User
# Application settings, loaded once at import time.
settings = Settings()
# Shared instances
# One HistoryManager serves every request handler in this process.
history_manager = HistoryManager(settings.history_db_url)
# OIDC is optional: the client is only built when all three settings are
# present; routes must check for None before using it.
oidc_client = None
if settings.oidc_client_id and settings.oidc_client_secret and settings.oidc_server_metadata_url:
    oidc_client = OIDCClient(
        client_id=settings.oidc_client_id,
        client_secret=settings.oidc_client_secret,
        server_metadata_url=settings.oidc_server_metadata_url,
        # Where the provider redirects after login (frontend callback route).
        redirect_uri=os.getenv("OIDC_REDIRECT_URI", "http://localhost:3000/auth/callback")
    )
# tokenUrl points at the password-login endpoint that issues bearer JWTs.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
    """Dependency to get the current authenticated user from the JWT token."""
    def _unauthorized() -> HTTPException:
        # Single place to build the 401 used for every validation failure.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )

    payload = decode_access_token(token)
    if payload is None:
        raise _unauthorized()
    user_id = payload.get("sub")
    if user_id is None:
        raise _unauthorized()
    user = history_manager.get_user_by_id(user_id)
    if user is None:
        raise _unauthorized()
    return user

View File

@@ -0,0 +1,34 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from ea_chatbot.api.routers import auth, history, artifacts, agent
from dotenv import load_dotenv
# Read a local .env file before any settings are consulted.
load_dotenv()
app = FastAPI(
    title="Election Analytics Chatbot API",
    description="Backend API for the LangGraph-based Election Analytics Chatbot",
    version="0.1.0"
)
# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Feature routers; each declares its own URL prefix.
app.include_router(auth.router)
app.include_router(history.router)
app.include_router(artifacts.router)
app.include_router(agent.router)
@app.get("/health")
async def health_check():
    # Liveness probe: static OK payload, no dependencies touched.
    return {"status": "ok"}
if __name__ == "__main__":
    # Allow running the API directly with `python` for local development.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -0,0 +1,162 @@
import json
import asyncio
from typing import AsyncGenerator, Optional, List
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from ea_chatbot.api.dependencies import get_current_user, history_manager
from ea_chatbot.api.utils import convert_to_json_compatible
from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.checkpoint import get_checkpointer
from ea_chatbot.history.models import User as UserDB, Conversation
from ea_chatbot.api.schemas import ChatRequest
import io
import base64
from langchain_core.messages import BaseMessage
# All agent endpoints live under /chat.
router = APIRouter(prefix="/chat", tags=["agent"])
async def stream_agent_events(
    message: str,
    thread_id: str,
    user_id: str,
    summary: str
) -> AsyncGenerator[str, None]:
    """
    Generator that invokes the LangGraph agent and yields SSE formatted events.
    Persists assistant responses and plots to the database.

    Args:
        message: The user's new chat message (becomes the agent "question").
        thread_id: Conversation id; also used as the LangGraph checkpoint
            thread id, so agent state resumes per conversation.
        user_id: Id of the authenticated caller. NOTE(review): currently
            unused inside the generator; ownership is checked by the caller.
        summary: Running conversation summary fed into the initial state.

    Yields:
        SSE frames ("data: <json>" followed by a blank line). A final frame
        with type "done" is emitted on success, or type "error" on failure.
    """
    # Fresh per-turn state; prior turns are restored by the checkpointer
    # keyed on thread_id, so "messages" starts empty here.
    initial_state = {
        "messages": [],
        "question": message,
        "summary": summary,
        "analysis": None,
        "next_action": "",
        "plan": None,
        "code": None,
        "code_output": None,
        "error": None,
        "plots": [],
        "dfs": {}
    }
    config = {"configurable": {"thread_id": thread_id}}
    # Buffers used to persist the assistant turn once streaming finishes.
    assistant_chunks: List[str] = []
    assistant_plots: List[bytes] = []
    final_response: str = ""
    new_summary: str = ""
    try:
        async with get_checkpointer() as checkpointer:
            async for event in app.astream_events(
                initial_state,
                config,
                version="v2",
                checkpointer=checkpointer
            ):
                kind = event.get("event")
                name = event.get("name")
                data = event.get("data", {})
                # Standardize event for frontend
                output_event = {
                    "type": kind,
                    "name": name,
                    "data": data
                }
                # Buffer assistant chunks (summarizer and researcher might stream)
                if kind == "on_chat_model_stream" and name in ["summarizer", "researcher"]:
                    chunk = data.get("chunk", "")
                    # Use utility to safely extract text content from the chunk
                    chunk_data = convert_to_json_compatible(chunk)
                    if isinstance(chunk_data, dict) and "content" in chunk_data:
                        assistant_chunks.append(str(chunk_data["content"]))
                    else:
                        assistant_chunks.append(str(chunk_data))
                # Buffer and encode plots
                if kind == "on_chain_end" and name == "executor":
                    output = data.get("output", {})
                    if isinstance(output, dict) and "plots" in output:
                        plots = output["plots"]
                        encoded_plots = []
                        for fig in plots:
                            # Render each figure to PNG bytes once: raw bytes
                            # for the DB, base64 for the SSE payload.
                            buf = io.BytesIO()
                            fig.savefig(buf, format="png")
                            plot_bytes = buf.getvalue()
                            assistant_plots.append(plot_bytes)
                            encoded_plots.append(base64.b64encode(plot_bytes).decode('utf-8'))
                        output_event["data"]["encoded_plots"] = encoded_plots
                # Collect final response from terminal nodes
                if kind == "on_chain_end" and name in ["summarizer", "researcher", "clarification"]:
                    output = data.get("output", {})
                    if isinstance(output, dict) and "messages" in output:
                        last_msg = output["messages"][-1]
                        # Use centralized utility to extract clean text content
                        # Since convert_to_json_compatible returns a dict for BaseMessage,
                        # we can extract 'content' from it.
                        msg_data = convert_to_json_compatible(last_msg)
                        if isinstance(msg_data, dict) and "content" in msg_data:
                            final_response = msg_data["content"]
                        else:
                            final_response = str(msg_data)
                # Collect new summary
                if kind == "on_chain_end" and name == "summarize_conversation":
                    output = data.get("output", {})
                    if isinstance(output, dict) and "summary" in output:
                        new_summary = output["summary"]
                # Convert to JSON compatible format to avoid serialization errors
                compatible_output = convert_to_json_compatible(output_event)
                yield f"data: {json.dumps(compatible_output)}\n\n"
        # If we didn't get a final_response from node output, use buffered chunks
        if not final_response and assistant_chunks:
            final_response = "".join(assistant_chunks)
        # Save assistant message to DB
        if final_response:
            history_manager.add_message(thread_id, "assistant", final_response, plots=assistant_plots)
        # Update summary in DB
        if new_summary:
            history_manager.update_conversation_summary(thread_id, new_summary)
        yield "data: {\"type\": \"done\"}\n\n"
    except Exception as e:
        # Record the failure in history so the conversation shows what went
        # wrong, then surface it to the client as an SSE error frame.
        error_msg = f"Agent execution failed: {str(e)}"
        history_manager.add_message(thread_id, "assistant", error_msg)
        yield f"data: {json.dumps({'type': 'error', 'message': error_msg})}\n\n"
@router.post("/stream")
async def chat_stream(
    request: ChatRequest,
    current_user: UserDB = Depends(get_current_user)
):
    """
    Stream agent execution events via SSE.

    Verifies the target conversation exists and belongs to the caller,
    persists the user's message, then streams agent events.
    """
    # Ownership check: 404 for unknown threads, 403 for other users' threads.
    with history_manager.get_session() as session:
        conversation = session.get(Conversation, request.thread_id)
        if conversation is None:
            raise HTTPException(status_code=404, detail="Conversation not found")
        if conversation.user_id != current_user.id:
            raise HTTPException(status_code=403, detail="Not authorized to access this conversation")
    # Save user message immediately
    history_manager.add_message(request.thread_id, "user", request.message)
    event_stream = stream_agent_events(
        request.message,
        request.thread_id,
        current_user.id,
        request.summary or ""
    )
    return StreamingResponse(event_stream, media_type="text/event-stream")

View File

@@ -0,0 +1,45 @@
from fastapi import APIRouter, Depends, HTTPException, Response, status
from ea_chatbot.api.dependencies import get_current_user, history_manager
from ea_chatbot.history.models import Plot, Message, User as UserDB
import io
# All artifact (plot/data) endpoints live under /artifacts.
router = APIRouter(prefix="/artifacts", tags=["artifacts"])
@router.get("/plots/{plot_id}")
async def get_plot(
    plot_id: str,
    current_user: UserDB = Depends(get_current_user)
):
    """Retrieve a binary plot image (PNG)."""
    with history_manager.get_session() as session:
        stored_plot = session.get(Plot, plot_id)
        if stored_plot is None:
            raise HTTPException(status_code=404, detail="Plot not found")
        # Walk plot -> message -> conversation to confirm the caller owns it.
        parent_message = session.get(Message, stored_plot.message_id)
        if parent_message is None:
            raise HTTPException(status_code=404, detail="Associated message not found")
        if parent_message.conversation.user_id != current_user.id:
            raise HTTPException(status_code=403, detail="Not authorized to access this artifact")
        return Response(content=stored_plot.image_data, media_type="image/png")
@router.get("/data/{message_id}")
async def get_message_data(
    message_id: str,
    current_user: UserDB = Depends(get_current_user)
):
    """
    Retrieve structured dataframe data associated with a message.
    Currently returns 404 as dataframes are not yet persisted in the DB.
    """
    # TODO: Implement persistence for DataFrames in Phase 4 or a future track
    # Placeholder endpoint: always reports the data as missing for now.
    missing = HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Structured data not found for this message"
    )
    raise missing

View File

@@ -0,0 +1,86 @@
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from ea_chatbot.api.utils import create_access_token
from ea_chatbot.api.dependencies import history_manager, oidc_client, get_current_user
from ea_chatbot.history.models import User as UserDB
from ea_chatbot.api.schemas import Token, UserCreate, UserResponse
# All authentication endpoints live under /auth.
router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(user_in: UserCreate):
    """Register a new user."""
    # Reject duplicate registrations up front.
    existing = history_manager.get_user(user_in.email)
    if existing:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="User already exists"
        )
    created = history_manager.create_user(
        email=user_in.email,
        password=user_in.password,
        display_name=user_in.display_name
    )
    # The stored account keeps the email in its `username` attribute.
    return {
        "id": str(created.id),
        "email": created.username,
        "display_name": created.display_name
    }
@router.post("/login", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Login with email and password to get a JWT."""
    account = history_manager.authenticate_user(form_data.username, form_data.password)
    if not account:
        # Same message for unknown user and wrong password.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    token = create_access_token(data={"sub": str(account.id)})
    return {"access_token": token, "token_type": "bearer"}
@router.get("/oidc/login")
async def oidc_login():
    """Get the OIDC authorization URL."""
    # NOTE(review): 510 Not Extended is an unusual status for "feature not
    # configured"; kept as-is for client compatibility.
    if oidc_client is None:
        raise HTTPException(
            status_code=status.HTTP_510_NOT_EXTENDED,
            detail="OIDC is not configured"
        )
    return {"url": oidc_client.get_login_url()}
@router.get("/oidc/callback", response_model=Token)
async def oidc_callback(code: str):
    """
    Handle the OIDC callback and issue a JWT.

    Exchanges the authorization code for tokens, fetches the user's profile,
    upserts the local user record, and returns a bearer JWT.

    Raises:
        HTTPException: 510 if OIDC is not configured, 400 if the provider did
            not supply an email, 401 if the token exchange or userinfo call
            fails.
    """
    if not oidc_client:
        raise HTTPException(status_code=status.HTTP_510_NOT_EXTENDED, detail="OIDC not configured")
    try:
        token = oidc_client.exchange_code_for_token(code)
        user_info = oidc_client.get_user_info(token)
        email = user_info.get("email")
        name = user_info.get("name") or user_info.get("preferred_username")
        if not email:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email not provided by OIDC")
        user = history_manager.sync_user_from_oidc(email=email, display_name=name)
        access_token = create_access_token(data={"sub": str(user.id)})
        return {"access_token": access_token, "token_type": "bearer"}
    except HTTPException:
        # Bug fix: the broad handler below used to swallow the deliberate 400
        # raised above and rewrap it as a generic 401. Re-raise HTTP errors
        # unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"OIDC authentication failed: {str(e)}")
@router.get("/me", response_model=UserResponse)
async def get_me(current_user: UserDB = Depends(get_current_user)):
    """Get the current authenticated user's profile."""
    # The account's `username` attribute holds the email address.
    profile = {
        "id": str(current_user.id),
        "email": current_user.username,
        "display_name": current_user.display_name
    }
    return profile

View File

@@ -0,0 +1,99 @@
from fastapi import APIRouter, Depends, HTTPException, status, Response
from typing import List, Optional
from ea_chatbot.api.dependencies import get_current_user, history_manager, settings
from ea_chatbot.history.models import User as UserDB
from ea_chatbot.api.schemas import ConversationResponse, MessageResponse, ConversationUpdate, ConversationCreate
# All conversation-history endpoints live under /conversations.
router = APIRouter(prefix="/conversations", tags=["history"])
@router.post("", response_model=ConversationResponse, status_code=status.HTTP_201_CREATED)
async def create_conversation(
    conv_in: ConversationCreate,
    current_user: UserDB = Depends(get_current_user)
):
    """Create a new conversation."""
    # Fall back to the globally configured data state when none is supplied.
    effective_state = conv_in.data_state or settings.data_state
    created = history_manager.create_conversation(
        user_id=current_user.id,
        data_state=effective_state,
        name=conv_in.name,
        summary=conv_in.summary
    )
    return {
        "id": str(created.id),
        "name": created.name,
        "summary": created.summary,
        "created_at": created.created_at,
        "data_state": created.data_state
    }
@router.get("", response_model=List[ConversationResponse])
async def list_conversations(
    current_user: UserDB = Depends(get_current_user),
    data_state: Optional[str] = None
):
    """List all conversations for the authenticated user."""
    # Use settings default if not provided
    effective_state = data_state or settings.data_state

    def _serialize(conv):
        # Shape a DB conversation row into the response schema.
        return {
            "id": str(conv.id),
            "name": conv.name,
            "summary": conv.summary,
            "created_at": conv.created_at,
            "data_state": conv.data_state
        }

    conversations = history_manager.get_conversations(current_user.id, effective_state)
    return [_serialize(c) for c in conversations]
@router.get("/{conversation_id}/messages", response_model=List[MessageResponse])
async def get_conversation_messages(
    conversation_id: str,
    current_user: UserDB = Depends(get_current_user)
):
    """
    Get all messages for a specific conversation.

    Raises:
        HTTPException: 404 if the conversation does not exist, 403 if it
            belongs to a different user.
    """
    # Local import so this fix doesn't widen the module's import surface.
    from ea_chatbot.history.models import Conversation

    # Security fix (was a TODO): verify the conversation belongs to the
    # caller before exposing its messages.
    with history_manager.get_session() as session:
        conv = session.get(Conversation, conversation_id)
        if not conv:
            raise HTTPException(status_code=404, detail="Conversation not found")
        if conv.user_id != current_user.id:
            raise HTTPException(status_code=403, detail="Not authorized to access this conversation")
    messages = history_manager.get_messages(conversation_id)
    return [
        {
            "id": str(m.id),
            "role": m.role,
            "content": m.content,
            "created_at": m.created_at
        } for m in messages
    ]
@router.patch("/{conversation_id}", response_model=ConversationResponse)
async def update_conversation(
    conversation_id: str,
    update: ConversationUpdate,
    current_user: UserDB = Depends(get_current_user)
):
    """Rename or update the summary of a conversation."""
    updated = None
    if update.name:
        updated = history_manager.rename_conversation(conversation_id, update.name)
    if update.summary:
        updated = history_manager.update_conversation_summary(conversation_id, update.summary)
    # NOTE: an empty patch (neither field set) also lands here and reports 404.
    if not updated:
        raise HTTPException(status_code=404, detail="Conversation not found")
    return {
        "id": str(updated.id),
        "name": updated.name,
        "summary": updated.summary,
        "created_at": updated.created_at,
        "data_state": updated.data_state
    }
@router.delete("/{conversation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_conversation(
    conversation_id: str,
    current_user: UserDB = Depends(get_current_user)
):
    """
    Delete a conversation.

    Raises:
        HTTPException: 404 if the conversation does not exist, 403 if it
            belongs to a different user.
    """
    # Local import so this fix doesn't widen the module's import surface.
    from ea_chatbot.history.models import Conversation

    # Security fix: previously any authenticated user could delete any
    # conversation by id; verify ownership first.
    with history_manager.get_session() as session:
        conv = session.get(Conversation, conversation_id)
        if conv is not None and conv.user_id != current_user.id:
            raise HTTPException(status_code=403, detail="Not authorized to delete this conversation")
    success = history_manager.delete_conversation(conversation_id)
    if not success:
        raise HTTPException(status_code=404, detail="Conversation not found")
    return Response(status_code=status.HTTP_204_NO_CONTENT)

View File

@@ -0,0 +1,51 @@
from pydantic import BaseModel, EmailStr
from typing import List, Optional
from datetime import datetime
# --- Auth Schemas ---
class Token(BaseModel):
    # Bearer JWT issued by /auth/login and /auth/oidc/callback.
    access_token: str
    token_type: str  # both issuers return "bearer"
class UserCreate(BaseModel):
    # Request payload for POST /auth/register.
    email: EmailStr
    password: str
    display_name: Optional[str] = None  # optional; may be absent at signup
class UserResponse(BaseModel):
    """Public user profile returned by /auth/register and /auth/me."""
    id: str
    email: str
    # Bug fix: display_name is optional at registration (see UserCreate), so
    # a user created without one would fail response validation if this
    # field were a required `str`.
    display_name: Optional[str] = None
# --- History Schemas ---
class ConversationResponse(BaseModel):
    # Conversation metadata returned by the /conversations endpoints.
    id: str
    name: str
    summary: Optional[str] = None  # running LLM summary, empty for new chats
    created_at: datetime
    data_state: str  # data-state partition; defaulted from Settings.data_state
class ConversationCreate(BaseModel):
    # Request payload for POST /conversations.
    name: str
    data_state: Optional[str] = None  # None -> server falls back to Settings.data_state
    summary: Optional[str] = None
class MessageResponse(BaseModel):
    # One chat message within a conversation.
    id: str
    role: str  # "user" or "assistant" as written by the routers
    content: str
    created_at: datetime
    # Plots are fetched separately via artifact endpoints
class ConversationUpdate(BaseModel):
    # PATCH payload; unset fields are left unchanged by the server.
    name: Optional[str] = None
    summary: Optional[str] = None
# --- Agent Schemas ---
class ChatRequest(BaseModel):
    # Request payload for POST /chat/stream.
    message: str
    thread_id: str # Maps to conversation_id
    summary: Optional[str] = ""  # client-held running summary, may be empty

View File

@@ -0,0 +1,94 @@
from datetime import datetime, timedelta, timezone
from typing import Optional, Union, Any, List, Dict
from jose import JWTError, jwt
from pydantic import BaseModel
from langchain_core.messages import BaseMessage
from ea_chatbot.config import Settings
# Module-level settings; reads env configuration once at import time.
settings = Settings()
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a JWT access token.

    Args:
        data: The payload data to encode.
        expires_delta: Optional expiration time delta; when omitted, the
            configured ``access_token_expire_minutes`` is used.

    Returns:
        str: The encoded JWT token.
    """
    issued_at = datetime.now(timezone.utc)
    lifetime = expires_delta if expires_delta else timedelta(minutes=settings.access_token_expire_minutes)
    # Copy the caller's payload so we never mutate their dict.
    claims = dict(data)
    claims.update({
        "exp": issued_at + lifetime,
        "iat": issued_at,
        "iss": "ea-chatbot-api"
    })
    return jwt.encode(claims, settings.secret_key, algorithm=settings.algorithm)
def decode_access_token(token: str) -> Optional[dict]:
    """
    Decode a JWT access token.

    Args:
        token: The token to decode.

    Returns:
        Optional[dict]: The decoded payload if valid, None otherwise.
    """
    try:
        return jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm])
    except JWTError:
        # Invalid signature, malformed token, or failed claim checks.
        return None
def convert_to_json_compatible(obj: Any) -> Any:
    """Recursively convert LangChain objects, Pydantic models, and others to JSON compatible formats.

    Args:
        obj: Arbitrary value (possibly nested lists/dicts).

    Returns:
        A structure containing only JSON-serializable primitives where a
        conversion rule applies; unknown scalars are returned unchanged.
    """
    if isinstance(obj, list):
        return [convert_to_json_compatible(item) for item in obj]
    elif isinstance(obj, dict):
        return {k: convert_to_json_compatible(v) for k, v in obj.items()}
    elif isinstance(obj, datetime):
        # Checked early so datetimes become ISO-8601 strings before any of
        # the duck-typed fallbacks below could intercept them.
        return obj.isoformat()
    elif isinstance(obj, timezone):
        # Bug fix: timezone has no isoformat(); the old code grouped it with
        # datetime and raised AttributeError. Use its readable name instead.
        return str(obj)
    elif isinstance(obj, BaseMessage):
        # Handle content that might be a list of blocks (e.g. from Gemini or OpenAI tools)
        content = obj.content
        if isinstance(content, list):
            text_parts = []
            for block in content:
                if isinstance(block, str):
                    text_parts.append(block)
                elif isinstance(block, dict):
                    if block.get("type") == "text":
                        text_parts.append(block.get("text", ""))
                # You could also handle other block types if needed
            content = "".join(text_parts)
        # Prefer .text property if available (common in some message types)
        if hasattr(obj, "text") and isinstance(obj.text, str) and obj.text:
            content = obj.text
        return {"type": obj.type, "content": content, **convert_to_json_compatible(obj.additional_kwargs)}
    elif isinstance(obj, BaseModel):
        return convert_to_json_compatible(obj.model_dump())
    elif hasattr(obj, "model_dump"):  # For Pydantic v2 if not caught by BaseModel
        try:
            return convert_to_json_compatible(obj.model_dump())
        except Exception:
            return str(obj)
    elif hasattr(obj, "dict"):  # Fallback for Pydantic v1 or other objects
        try:
            return convert_to_json_compatible(obj.dict())
        except Exception:
            return str(obj)
    elif hasattr(obj, "content"):
        return str(obj.content)
    return obj

View File

@@ -0,0 +1,451 @@
import streamlit as st
import asyncio
import os
import io
from dotenv import load_dotenv
from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.state import AgentState
from ea_chatbot.utils.logging import get_logger
from ea_chatbot.utils.helpers import merge_agent_state
from ea_chatbot.config import Settings
from ea_chatbot.history.manager import HistoryManager
from ea_chatbot.auth import OIDCClient, AuthType, get_user_auth_type
# Load environment variables
load_dotenv()
# Initialize Config and Manager
settings = Settings()
history_manager = HistoryManager(settings.history_db_url)
# Initialize OIDC Client if configured
# OIDC stays None unless all three settings are present; the UI checks this.
oidc_client = None
if settings.oidc_client_id and settings.oidc_client_secret and settings.oidc_server_metadata_url:
    oidc_client = OIDCClient(
        client_id=settings.oidc_client_id,
        client_secret=settings.oidc_client_secret,
        server_metadata_url=settings.oidc_server_metadata_url,
        # Redirect back to the same page
        redirect_uri=os.getenv("OIDC_REDIRECT_URI", "http://localhost:8501")
    )
# Initialize Logger
logger = get_logger(level=settings.log_level, log_file="logs/app.jsonl")
# --- Authentication Helpers ---
def login_user(user):
    """Store the authenticated user and reset per-conversation state."""
    fresh_state = {
        "user": user,
        "messages": [],
        "summary": "",
        "current_conversation_id": None,
    }
    for key, value in fresh_state.items():
        st.session_state[key] = value
    st.rerun()
def logout_user():
    """Drop every session key, returning the app to the login screen."""
    # Snapshot the keys first: deleting while iterating the live view is unsafe.
    for key in tuple(st.session_state.keys()):
        st.session_state.pop(key)
    st.rerun()
def load_conversation(conv_id):
    """Hydrate session state from a stored conversation, then rerun."""
    # Convert DB models to the plain dicts the chat renderer understands.
    st.session_state.messages = [
        {
            "role": m.role,
            "content": m.content,
            "plots": [p.image_data for p in m.plots],
        }
        for m in history_manager.get_messages(conv_id)
    ]
    st.session_state.current_conversation_id = conv_id
    # Fetch summary from DB
    from ea_chatbot.history.models import Conversation
    with history_manager.get_session() as session:
        conv = session.get(Conversation, conv_id)
        st.session_state.summary = conv.summary if conv else ""
    st.rerun()
def main():
    """Streamlit entry point.

    Handles, in order: page config, the OIDC redirect callback, the
    multi-step login/registration screen, the sidebar (history, rename/
    delete, settings), and finally the chat loop that streams the LangGraph
    agent and persists results.
    """
    st.set_page_config(
        page_title="Election Analytics Chatbot",
        page_icon="🗳️",
        layout="wide"
    )
    # Check for OIDC Callback
    if "code" in st.query_params and oidc_client:
        code = st.query_params["code"]
        try:
            token = oidc_client.exchange_code_for_token(code)
            user_info = oidc_client.get_user_info(token)
            email = user_info.get("email")
            name = user_info.get("name") or user_info.get("preferred_username")
            if email:
                user = history_manager.sync_user_from_oidc(email=email, display_name=name)
                # Clear query params
                st.query_params.clear()
                login_user(user)
        except Exception as e:
            st.error(f"OIDC Login failed: {str(e)}")
    # Display Login Screen if not authenticated
    if "user" not in st.session_state:
        st.title("🗳️ Election Analytics Chatbot")
        # Initialize Login State
        if "login_step" not in st.session_state:
            st.session_state.login_step = "email"
        if "login_email" not in st.session_state:
            st.session_state.login_email = ""
        col1, col2 = st.columns([1, 1])
        with col1:
            st.header("Login")
            # Step 1: Identification
            if st.session_state.login_step == "email":
                st.write("Please enter your email to begin:")
                with st.form("email_form"):
                    email_input = st.text_input("Email", value=st.session_state.login_email)
                    submitted = st.form_submit_button("Next")
                    if submitted:
                        if not email_input.strip():
                            st.error("Email cannot be empty.")
                        else:
                            # Route to password, SSO, or registration based on
                            # what the DB knows about this email.
                            st.session_state.login_email = email_input.strip()
                            auth_type = get_user_auth_type(st.session_state.login_email, history_manager)
                            if auth_type == AuthType.LOCAL:
                                st.session_state.login_step = "login_password"
                            elif auth_type == AuthType.OIDC:
                                st.session_state.login_step = "oidc_login"
                            else: # AuthType.NEW
                                st.session_state.login_step = "register_details"
                            st.rerun()
            # Step 2a: Local Login
            elif st.session_state.login_step == "login_password":
                st.info(f"Welcome back, **{st.session_state.login_email}**!")
                with st.form("password_form"):
                    password = st.text_input("Password", type="password")
                    col_login, col_back = st.columns([1, 1])
                    submitted = col_login.form_submit_button("Login")
                    back = col_back.form_submit_button("Back")
                    if back:
                        st.session_state.login_step = "email"
                        st.rerun()
                    if submitted:
                        user = history_manager.authenticate_user(st.session_state.login_email, password)
                        if user:
                            login_user(user)
                        else:
                            st.error("Invalid email or password")
            # Step 2b: Registration
            elif st.session_state.login_step == "register_details":
                st.info(f"Create an account for **{st.session_state.login_email}**")
                with st.form("register_form"):
                    reg_name = st.text_input("Display Name")
                    reg_password = st.text_input("Password", type="password")
                    col_reg, col_back = st.columns([1, 1])
                    submitted = col_reg.form_submit_button("Register & Login")
                    back = col_back.form_submit_button("Back")
                    if back:
                        st.session_state.login_step = "email"
                        st.rerun()
                    if submitted:
                        if not reg_password:
                            st.error("Password is required for registration.")
                        else:
                            user = history_manager.create_user(st.session_state.login_email, reg_password, reg_name)
                            st.success("Registered! Logging in...")
                            login_user(user)
            # Step 2c: OIDC Redirection
            elif st.session_state.login_step == "oidc_login":
                st.info(f"**{st.session_state.login_email}** is configured for Single Sign-On (SSO).")
                col_sso, col_back = st.columns([1, 1])
                with col_sso:
                    if oidc_client:
                        login_url = oidc_client.get_login_url()
                        st.link_button("Login with SSO", login_url, type="primary", use_container_width=True)
                    else:
                        st.error("OIDC is not configured.")
                with col_back:
                    if st.button("Back", use_container_width=True):
                        st.session_state.login_step = "email"
                        st.rerun()
        with col2:
            if oidc_client:
                st.header("Single Sign-On")
                st.write("Login with your organizational account.")
                if st.button("Login with SSO"):
                    login_url = oidc_client.get_login_url()
                    st.link_button("Go to **YXXU**", login_url, type="primary")
            else:
                st.info("SSO is not configured.")
        # Halt rendering here so the authenticated UI below never shows.
        st.stop()
    # --- Main App (Authenticated) ---
    user = st.session_state.user
    # Sidebar configuration
    with st.sidebar:
        st.title(f"Hi, {user.display_name or user.username}!")
        if st.button("Logout"):
            logout_user()
        st.divider()
        st.header("History")
        if st.button(" New Chat", use_container_width=True):
            st.session_state.messages = []
            st.session_state.summary = ""
            st.session_state.current_conversation_id = None
            st.rerun()
        # List conversations for the current user and data state
        conversations = history_manager.get_conversations(user.id, settings.data_state)
        for conv in conversations:
            col_c, col_r, col_d = st.columns([0.7, 0.15, 0.15])
            is_current = st.session_state.get("current_conversation_id") == conv.id
            label = f"💬 {conv.name}" if not is_current else f"👉 {conv.name}"
            if col_c.button(label, key=f"conv_{conv.id}", use_container_width=True):
                load_conversation(conv.id)
            if col_r.button("✏️", key=f"ren_{conv.id}"):
                st.session_state.renaming_id = conv.id
            if col_d.button("🗑️", key=f"del_{conv.id}"):
                if history_manager.delete_conversation(conv.id):
                    # If we deleted the open conversation, clear the chat view.
                    if is_current:
                        st.session_state.current_conversation_id = None
                        st.session_state.messages = []
                    st.rerun()
        # Rename dialog
        if st.session_state.get("renaming_id"):
            rid = st.session_state.renaming_id
            with st.form("rename_form"):
                new_name = st.text_input("New Name")
                if st.form_submit_button("Save"):
                    history_manager.rename_conversation(rid, new_name)
                    del st.session_state.renaming_id
                    st.rerun()
                if st.form_submit_button("Cancel"):
                    del st.session_state.renaming_id
                    st.rerun()
        st.divider()
        st.header("Settings")
        # Check for DEV_MODE env var (defaults to False)
        default_dev_mode = os.getenv("DEV_MODE", "false").lower() == "true"
        dev_mode = st.checkbox("Dev Mode", value=default_dev_mode, help="Enable to see code generation and raw reasoning steps.")
    st.title("🗳️ Election Analytics Chatbot")
    # Initialize chat history state
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "summary" not in st.session_state:
        st.session_state.summary = ""
    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            if message.get("plan") and dev_mode:
                with st.expander("Reasoning Plan"):
                    st.code(message["plan"], language="yaml")
            if message.get("code") and dev_mode:
                with st.expander("Generated Code"):
                    st.code(message["code"], language="python")
            st.markdown(message["content"])
            if message.get("plots"):
                for plot_data in message["plots"]:
                    # If plot_data is bytes, convert to image
                    if isinstance(plot_data, bytes):
                        st.image(plot_data)
                    else:
                        # Fallback for old session state or non-binary
                        st.pyplot(plot_data)
            if message.get("dfs"):
                for df_name, df in message["dfs"].items():
                    st.subheader(f"Data: {df_name}")
                    st.dataframe(df)
    # Accept user input
    if prompt := st.chat_input("Ask a question about election data..."):
        # Ensure we have a conversation ID
        if not st.session_state.get("current_conversation_id"):
            # Auto-create conversation
            conv_name = (prompt[:30] + '...') if len(prompt) > 30 else prompt
            conv = history_manager.create_conversation(user.id, settings.data_state, conv_name)
            st.session_state.current_conversation_id = conv.id
        conv_id = st.session_state.current_conversation_id
        # Save user message to DB
        history_manager.add_message(conv_id, "user", prompt)
        # Add user message to session state
        st.session_state.messages.append({"role": "user", "content": prompt})
        # Display user message in chat message container
        with st.chat_message("user"):
            st.markdown(prompt)
        # Prepare graph input
        initial_state: AgentState = {
            "messages": st.session_state.messages[:-1], # Pass history (excluding the current prompt)
            "question": prompt,
            "summary": st.session_state.summary,
            "analysis": None,
            "next_action": "",
            "plan": None,
            "code": None,
            "code_output": None,
            "error": None,
            "plots": [],
            "dfs": {}
        }
        # Placeholder for graph output
        with st.chat_message("assistant"):
            final_state = initial_state
            # Real-time node updates
            with st.status("Thinking...", expanded=True) as status:
                try:
                    # Use app.stream to capture node transitions
                    for event in app.stream(initial_state):
                        for node_name, state_update in event.items():
                            prev_error = final_state.get("error")
                            # Use helper to merge state correctly (appending messages/plots, updating dfs)
                            final_state = merge_agent_state(final_state, state_update)
                            if node_name == "query_analyzer":
                                analysis = state_update.get("analysis", {})
                                next_action = state_update.get("next_action", "unknown")
                                status.write(f"🔍 **Analyzed Query:**")
                                for k,v in analysis.items():
                                    status.write(f"- {k:<8}: {v}")
                                status.markdown(f"Next Step: {next_action.capitalize()}")
                            elif node_name == "planner":
                                status.write("📋 **Plan Generated**")
                                # Render artifacts
                                if state_update.get("plan") and dev_mode:
                                    with st.expander("Reasoning Plan", expanded=True):
                                        st.code(state_update["plan"], language="yaml")
                            elif node_name == "researcher":
                                status.write("🌐 **Research Complete**")
                                if state_update.get("messages") and dev_mode:
                                    for msg in state_update["messages"]:
                                        # Extract content from BaseMessage or show raw string
                                        content = getattr(msg, "text", msg.content)
                                        status.markdown(content)
                            elif node_name == "coder":
                                status.write("💻 **Code Generated**")
                                if state_update.get("code") and dev_mode:
                                    with st.expander("Generated Code"):
                                        st.code(state_update["code"], language="python")
                            elif node_name == "error_corrector":
                                status.write("🛠️ **Fixing Execution Error...**")
                                if prev_error:
                                    # Show a truncated view of the error being fixed.
                                    truncated_error = prev_error.strip()
                                    if len(truncated_error) > 180:
                                        truncated_error = truncated_error[:180] + "..."
                                    status.write(f"Previous error: {truncated_error}")
                                if state_update.get("code") and dev_mode:
                                    with st.expander("Corrected Code"):
                                        st.code(state_update["code"], language="python")
                            elif node_name == "executor":
                                if state_update.get("error"):
                                    if dev_mode:
                                        status.write(f"❌ **Execution Error:** {state_update.get('error')}...")
                                    else:
                                        status.write(f"❌ **Execution Error:** {state_update.get('error')[:100]}...")
                                else:
                                    status.write("✅ **Execution Successful**")
                                    if state_update.get("plots"):
                                        status.write(f"📊 Generated {len(state_update['plots'])} plot(s)")
                            elif node_name == "summarizer":
                                status.write("📝 **Summarizing Results...**")
                    status.update(label="Complete!", state="complete", expanded=False)
                except Exception as e:
                    status.update(label="Error!", state="error")
                    st.error(f"Error during graph execution: {str(e)}")
            # Extract results
            response_text: str = ""
            if final_state.get("messages"):
                # The last message is the Assistant's response
                last_msg = final_state["messages"][-1]
                response_text = getattr(last_msg, "text", str(last_msg.content))
            st.markdown(response_text)
            # Collect plot bytes for saving to DB
            plot_bytes_list = []
            if final_state.get("plots"):
                for fig in final_state["plots"]:
                    st.pyplot(fig)
                    # Convert fig to bytes
                    buf = io.BytesIO()
                    fig.savefig(buf, format="png")
                    plot_bytes_list.append(buf.getvalue())
            if final_state.get("dfs"):
                for df_name, df in final_state["dfs"].items():
                    st.subheader(f"Data: {df_name}")
                    st.dataframe(df)
            # Save assistant message to DB
            history_manager.add_message(conv_id, "assistant", response_text, plots=plot_bytes_list)
            # Update summary in DB
            new_summary = final_state.get("summary", "")
            if new_summary:
                history_manager.update_conversation_summary(conv_id, new_summary)
            # Store assistant response in session history
            st.session_state.messages.append({
                "role": "assistant",
                "content": response_text,
                "plan": final_state.get("plan"),
                "code": final_state.get("code"),
                "plots": plot_bytes_list,
                "dfs": final_state.get("dfs")
            })
            st.session_state.summary = new_summary
if __name__ == "__main__":
    # Streamlit runs the module top-to-bottom; this guard also allows running
    # the script directly.
    main()

View File

@@ -0,0 +1,97 @@
import requests
from enum import Enum
from typing import Dict, Any, Optional
from authlib.integrations.requests_client import OAuth2Session
class AuthType(Enum):
    """How a given account should authenticate."""
    LOCAL = "local"
    OIDC = "oidc"
    NEW = "new"


def get_user_auth_type(email: str, history_manager: Any) -> AuthType:
    """Classify how the account behind *email* should authenticate.

    Args:
        email: The user's email address.
        history_manager: Instance of HistoryManager used to look up the user.

    Returns:
        AuthType.NEW when no user record exists, AuthType.LOCAL when the
        record carries a password hash, AuthType.OIDC otherwise.
    """
    record = history_manager.get_user(email)
    if not record:
        # Unknown email -> sign-up flow.
        return AuthType.NEW
    # A stored password hash marks a local-credentials account; an account
    # without one was provisioned through the OIDC provider.
    return AuthType.LOCAL if record.password_hash else AuthType.OIDC
class OIDCClient:
    """
    Client for OIDC Authentication using Authlib.
    Designed to work within a Streamlit environment.
    """
    def __init__(
        self,
        client_id: str,
        client_secret: str,
        server_metadata_url: str,
        redirect_uri: str = "http://localhost:8501"
    ):
        """Store credentials and build the underlying OAuth2 session.

        Args:
            client_id: OIDC client identifier.
            client_secret: OIDC client secret.
            server_metadata_url: URL of the provider's discovery document.
            redirect_uri: Where the provider redirects after login.
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.server_metadata_url = server_metadata_url
        self.redirect_uri = redirect_uri
        self.oauth_session = OAuth2Session(
            client_id=client_id,
            client_secret=client_secret,
            redirect_uri=redirect_uri,
            scope="openid email profile"
        )
        # Provider discovery document; fetched lazily and cached.
        self.metadata: Dict[str, Any] = {}
    def fetch_metadata(self) -> Dict[str, Any]:
        """Fetch OIDC provider metadata if not already fetched.

        Raises:
            requests.HTTPError: if the metadata endpoint returns an error status.
            requests.Timeout: if the provider does not respond within 10s.
        """
        if not self.metadata:
            # timeout= prevents hanging forever on an unreachable provider;
            # raise_for_status surfaces HTTP errors explicitly instead of a
            # confusing JSON decode failure on an error page.
            response = requests.get(self.server_metadata_url, timeout=10)
            response.raise_for_status()
            self.metadata = response.json()
        return self.metadata
    def get_login_url(self) -> str:
        """Generate the authorization URL."""
        metadata = self.fetch_metadata()
        authorization_endpoint = metadata.get("authorization_endpoint")
        if not authorization_endpoint:
            raise ValueError("authorization_endpoint not found in OIDC metadata")
        # NOTE(review): the returned `state` is discarded here; confirm CSRF
        # state validation happens in the callback handler.
        uri, state = self.oauth_session.create_authorization_url(authorization_endpoint)
        return uri
    def exchange_code_for_token(self, code: str) -> Dict[str, Any]:
        """Exchange the authorization code for an access token."""
        metadata = self.fetch_metadata()
        token_endpoint = metadata.get("token_endpoint")
        if not token_endpoint:
            raise ValueError("token_endpoint not found in OIDC metadata")
        token = self.oauth_session.fetch_token(
            token_endpoint,
            code=code,
            client_secret=self.client_secret
        )
        return token
    def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
        """Fetch user information (claims) using the access token."""
        metadata = self.fetch_metadata()
        userinfo_endpoint = metadata.get("userinfo_endpoint")
        if not userinfo_endpoint:
            raise ValueError("userinfo_endpoint not found in OIDC metadata")
        # Set the token on the session so it's used in the request
        self.oauth_session.token = token
        resp = self.oauth_session.get(userinfo_endpoint)
        resp.raise_for_status()
        return resp.json()

View File

@@ -0,0 +1,52 @@
from typing import Dict, Any, Optional
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from dotenv import load_dotenv
load_dotenv()
class LLMConfig(BaseModel):
    """Configuration for a specific LLM node."""
    provider: str = "openai"  # provider key understood by the LLM factory
    model: str = "gpt-5-mini"  # model name passed to the provider
    temperature: float = 0.0  # sampling temperature (0.0 for deterministic-leaning output)
    max_tokens: Optional[int] = None  # None = use the provider's default limit
    provider_specific: Dict[str, Any] = Field(default_factory=dict)  # extra kwargs forwarded to the provider
class Settings(BaseSettings):
    """Global application settings.

    Values are loaded from environment variables via the ``alias`` on each
    field; nested LLM configs can be overridden with double-underscore
    variables such as ``QUERY_ANALYZER_LLM__MODEL``.
    """
    data_dir: str = "data"
    data_state: str = "new_jersey"
    log_level: str = Field(default="INFO", alias="LOG_LEVEL")
    # Voter Database configuration
    db_host: str = Field(default="localhost", alias="DB_HOST")
    db_port: int = Field(default=5432, alias="DB_PORT")
    db_name: str = Field(default="blockdata", alias="DB_NAME")
    db_user: str = Field(default="user", alias="DB_USER")
    db_pswd: str = Field(default="password", alias="DB_PSWD")
    db_table: str = Field(default="rd_gc_voters_nj", alias="DB_TABLE")
    # Application/History Database
    history_db_url: str = Field(default="postgresql://user:password@localhost:5433/ea_history", alias="HISTORY_DB_URL")
    # JWT Configuration
    # NOTE: the default secret is a placeholder and must be overridden in production.
    secret_key: str = Field(default="change-me-in-production", alias="SECRET_KEY")
    algorithm: str = Field(default="HS256", alias="ALGORITHM")
    access_token_expire_minutes: int = Field(default=30, alias="ACCESS_TOKEN_EXPIRE_MINUTES")
    # OIDC Configuration (all three must be set for OIDC login to be enabled)
    oidc_client_id: Optional[str] = Field(default=None, alias="OIDC_CLIENT_ID")
    oidc_client_secret: Optional[str] = Field(default=None, alias="OIDC_CLIENT_SECRET")
    oidc_server_metadata_url: Optional[str] = Field(default=None, alias="OIDC_SERVER_METADATA_URL")
    # Default configurations for each node
    query_analyzer_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
    planner_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
    coder_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
    summarizer_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
    researcher_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
    # Allow nested env vars like QUERY_ANALYZER_LLM__MODEL
    model_config = SettingsConfigDict(env_nested_delimiter='__', env_prefix='')

View File

View File

@@ -0,0 +1,44 @@
import contextlib
from typing import AsyncGenerator
from psycopg_pool import AsyncConnectionPool
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from ea_chatbot.config import Settings
settings = Settings()
_pool = None
def get_pool() -> AsyncConnectionPool:
    """Return the process-wide async connection pool, creating it lazily."""
    global _pool
    if _pool is not None:
        return _pool
    _pool = AsyncConnectionPool(
        conninfo=settings.history_db_url,
        max_size=20,
        kwargs={"autocommit": True, "prepare_threshold": 0},
        open=False,  # opened explicitly by the first consumer, not on init
    )
    return _pool
@contextlib.asynccontextmanager
async def get_checkpointer() -> AsyncGenerator[AsyncPostgresSaver, None]:
    """Yield an AsyncPostgresSaver backed by a pooled connection.

    Opens the shared pool on first use and runs the saver's table setup
    before handing it to the caller; the connection is returned to the
    pool when the context exits.
    """
    pool = get_pool()
    if pool.closed:
        # Pools are created with open=False, so open lazily here.
        await pool.open()
    async with pool.connection() as connection:
        saver = AsyncPostgresSaver(connection)
        # setup() creates the checkpoint tables if they do not exist yet.
        await saver.setup()
        yield saver
async def close_pool():
    """Close the shared connection pool if it was ever created and opened."""
    global _pool
    if _pool is None or _pool.closed:
        return
    await _pool.close()

View File

@@ -0,0 +1,45 @@
from langchain_core.messages import AIMessage
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
def clarification_node(state: AgentState) -> dict:
    """Ask the user for missing information or clarifications.

    Reads the ambiguities identified by the query analyzer and asks the LLM
    to phrase a polite, concise follow-up question for the user.

    Returns:
        dict with the clarification message appended to ``messages`` and
        ``next_action`` set to ``"end"`` so the graph stops for user input.

    Raises:
        Exception: re-raises any LLM invocation failure after logging it.
    """
    question = state["question"]
    analysis = state.get("analysis", {})
    ambiguities = analysis.get("ambiguities", [])
    settings = Settings()
    logger = get_logger("clarification")
    logger.info(f"Generating clarification for {len(ambiguities)} ambiguities.")
    # Reuses the query-analyzer model; there is no dedicated clarification config.
    llm = get_llm_model(
        settings.query_analyzer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    system_prompt = """You are a Clarification Specialist. Your role is to identify what information is missing from a user's request to perform a data analysis or research task.
Based on the analysis of the user's question, formulate a polite and concise request for the missing information."""
    prompt = f"""Original Question: {question}
Missing/Ambiguous Information: {', '.join(ambiguities) if ambiguities else 'Unknown ambiguities'}
Please ask the user for the necessary details."""
    messages = [
        ("system", system_prompt),
        ("user", prompt)
    ]
    try:
        response = llm.invoke(messages)
        logger.info("[bold green]Clarification generated.[/bold green]")
        return {
            "messages": [response],
            "next_action": "end"  # stop the graph and wait for the user's reply
        }
    except Exception as e:
        logger.error(f"Failed to generate clarification: {str(e)}")
        # Bare raise preserves the original traceback (idiomatic re-raise).
        raise

View File

@@ -0,0 +1,47 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils import helpers, database_inspection
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT
from ea_chatbot.schemas import CodeGenerationResponse
def coder_node(state: AgentState) -> dict:
    """Generate Python code implementing the plan against the data summary.

    Uses the coder LLM with structured output (CodeGenerationResponse);
    previous execution results, if any, are included in the prompt so the
    model can adjust.

    Returns:
        dict with the generated ``code`` and ``error`` reset to None.

    Raises:
        Exception: re-raises any LLM invocation failure after logging it.
    """
    question = state["question"]
    plan = state.get("plan", "")
    code_output = state.get("code_output", "None")
    settings = Settings()
    logger = get_logger("coder")
    logger.info("Generating Python code...")
    llm = get_llm_model(
        settings.coder_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(CodeGenerationResponse)
    # Always provide data summary; the LLM decides what is relevant.
    database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."
    example_code = ""  # Placeholder for future few-shot examples
    messages = CODE_GENERATOR_PROMPT.format_messages(
        question=question,
        plan=plan,
        database_description=database_description,
        code_exec_results=code_output,
        example_code=example_code
    )
    try:
        response = structured_llm.invoke(messages)
        logger.info("[bold green]Code generated.[/bold green]")
        return {
            "code": response.parsed_code,
            "error": None  # Clear previous errors on new code generation
        }
    except Exception as e:
        logger.error(f"Failed to generate code: {str(e)}")
        # Bare raise preserves the original traceback (idiomatic re-raise).
        raise

View File

@@ -0,0 +1,44 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.coder import ERROR_CORRECTOR_PROMPT
from ea_chatbot.schemas import CodeGenerationResponse
def error_corrector_node(state: AgentState) -> dict:
    """Fix the generated code based on the execution error.

    Feeds the failed code and the filtered traceback to the LLM and returns
    a corrected script, incrementing the ``iterations`` retry counter.

    Returns:
        dict with the corrected ``code``, ``error`` cleared to None, and
        ``iterations`` incremented by one.

    Raises:
        Exception: re-raises any LLM invocation failure after logging it.
    """
    code = state.get("code", "")
    error = state.get("error", "Unknown error")
    settings = Settings()
    logger = get_logger("error_corrector")
    logger.warning(f"[bold red]Execution error detected:[/bold red] {error[:100]}...")
    logger.info("Attempting to correct the code...")
    # Reuse coder LLM config or add a new one. Using coder_llm for now.
    llm = get_llm_model(
        settings.coder_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(CodeGenerationResponse)
    messages = ERROR_CORRECTOR_PROMPT.format_messages(
        code=code,
        error=error
    )
    try:
        response = structured_llm.invoke(messages)
        logger.info("[bold green]Correction generated.[/bold green]")
        current_iterations = state.get("iterations", 0)
        return {
            "code": response.parsed_code,
            "error": None,  # Clear error after fix attempt
            "iterations": current_iterations + 1
        }
    except Exception as e:
        logger.error(f"Failed to correct code: {str(e)}")
        # Bare raise preserves the original traceback (idiomatic re-raise).
        raise

View File

@@ -0,0 +1,102 @@
import io
import sys
import traceback
from contextlib import redirect_stdout
from typing import Any, Dict, List, TYPE_CHECKING
import pandas as pd
from matplotlib.figure import Figure
from ea_chatbot.graph.state import AgentState
from ea_chatbot.utils.db_client import DBClient
from ea_chatbot.utils.logging import get_logger
from ea_chatbot.config import Settings
if TYPE_CHECKING:
from ea_chatbot.types import DBSettings
def executor_node(state: AgentState) -> dict:
    """Execute the generated Python code and capture output, plots, and dataframes.

    The code runs with a namespace exposing:
      - ``db``: a DBClient instance for SQL access,
      - ``plots``: a list the code appends Matplotlib figures to,
      - ``pd``: pandas.

    Returns:
        dict with ``code_output`` (captured stdout, truncated to 32 lines),
        ``error`` (filtered traceback string or None), ``plots`` (Figure
        objects) and ``dfs`` (DataFrames referenced by the code).
    """
    code = state.get("code")
    logger = get_logger("executor")
    if not code:
        logger.error("No code provided to executor.")
        return {"error": "No code provided to executor."}
    logger.info("Executing Python code...")
    settings = Settings()
    db_settings: "DBSettings" = {
        "host": settings.db_host,
        "port": settings.db_port,
        "user": settings.db_user,
        "pswd": settings.db_pswd,
        "db": settings.db_name,
        "table": settings.db_table
    }
    db_client = DBClient(settings=db_settings)
    # Single namespace used as BOTH globals and locals for exec().
    # With the exec(code, {}, local_vars) form, functions and comprehensions
    # defined by the generated code resolve names against the empty globals
    # dict and fail with NameError on 'db', 'pd', etc.
    exec_namespace: Dict[str, Any] = {
        'db': db_client,
        'plots': [],
        'pd': pd
    }
    stdout_buffer = io.StringIO()
    error = None
    code_output = ""
    plots = []
    dfs = {}
    try:
        with redirect_stdout(stdout_buffer):
            # SECURITY NOTE: exec() of LLM-generated code is inherently unsafe;
            # this assumes the deployment sandboxes or trusts the generated code.
            exec(code, exec_namespace)
        code_output = stdout_buffer.getvalue()
        # Limit the output length if it's too long
        if code_output.count('\n') > 32:
            code_output = '\n'.join(code_output.split('\n')[:32]) + '\n...'
        # Extract plots
        raw_plots = exec_namespace.get('plots', [])
        if isinstance(raw_plots, list):
            plots = [p for p in raw_plots if isinstance(p, Figure)]
        # Extract DataFrames that were likely intended for display.
        # Heuristic: a DataFrame whose variable name appears in the code text
        # is probably a result the user should see.
        for key, value in exec_namespace.items():
            if isinstance(value, pd.DataFrame) and key in code:
                dfs[key] = value
        logger.info(f"[bold green]Execution complete.[/bold green] Captured {len(plots)} plots and {len(dfs)} dataframes.")
    except Exception as e:
        # Capture the traceback and keep only the frames from the executed
        # string, so the error corrector sees the generated code's failure.
        exc_type, exc_value, tb = sys.exc_info()
        full_traceback = traceback.format_exc()
        filtered_tb_lines = [line for line in full_traceback.split('\n') if '<string>' in line]
        error = '\n'.join(filtered_tb_lines)
        if error:
            error += '\n'
        error += f"{exc_type.__name__ if exc_type else 'Exception'}: {exc_value}"
        logger.error(f"Execution failed: {str(e)}")
        # Partial stdout is still useful context even when execution failed.
        code_output = stdout_buffer.getvalue()
    return {
        "code_output": code_output,
        "error": error,
        "plots": plots,
        "dfs": dfs
    }

View File

@@ -0,0 +1,51 @@
import yaml
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils import helpers, database_inspection
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.planner import PLANNER_PROMPT
from ea_chatbot.schemas import TaskPlanResponse
def planner_node(state: AgentState) -> dict:
    """Generate a structured, step-by-step plan for the user's question.

    Invokes the planner LLM with structured output (TaskPlanResponse) and
    stores the plan back into the state as a YAML string.

    Returns:
        dict with the ``plan`` key set to the YAML-serialized plan.

    Raises:
        Exception: re-raises any LLM invocation failure after logging it.
    """
    question = state["question"]
    history = state.get("messages", [])[-6:]  # last 3 turns for context
    summary = state.get("summary", "")
    settings = Settings()
    logger = get_logger("planner")
    logger.info("Generating task plan...")
    llm = get_llm_model(
        settings.planner_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(TaskPlanResponse)
    date_str = helpers.get_readable_date()
    # Always provide data summary; LLM decides relevance.
    database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."
    example_plan = ""  # Placeholder for future few-shot examples
    messages = PLANNER_PROMPT.format_messages(
        date=date_str,
        question=question,
        history=history,
        summary=summary,
        database_description=database_description,
        example_plan=example_plan
    )
    # Generate the structured plan
    try:
        response = structured_llm.invoke(messages)
        # Convert the structured response back to YAML string for the state
        plan_yaml = yaml.dump(response.model_dump(), sort_keys=False)
        logger.info("[bold green]Plan generated successfully.[/bold green]")
        return {"plan": plan_yaml}
    except Exception as e:
        logger.error(f"Failed to generate plan: {str(e)}")
        # Bare raise preserves the original traceback (idiomatic re-raise).
        raise

View File

@@ -0,0 +1,73 @@
from typing import List, Literal
from pydantic import BaseModel, Field
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.query_analyzer import QUERY_ANALYZER_PROMPT
class QueryAnalysis(BaseModel):
    """Analysis of the user's query.

    Structured-output schema for the query-analyzer LLM. The ``next_action``
    field is routing metadata; the node stores it separately from the rest
    of the analysis in the graph state.
    """
    data_required: List[str] = Field(description="List of data points or entities mentioned (e.g., ['2024 results', 'Florida']).")
    unknowns: List[str] = Field(description="List of target information the user wants to know or needed for final answer (e.g., 'who won', 'total votes').")
    ambiguities: List[str] = Field(description="List of CRITICAL missing details that prevent ANY analysis. Do NOT include database names or plot types if defaults can be used.")
    conditions: List[str] = Field(description="List of any filters or constraints (e.g., ['year=2024', 'state=Florida']). Include context resolved from history.")
    next_action: Literal["plan", "clarify", "research"] = Field(description="The next action to take. 'plan' for data analysis (even with defaults), 'research' for general knowledge, or 'clarify' ONLY for critical ambiguities.")
def query_analyzer_node(state: AgentState) -> dict:
    """Analyze the user's question and determine the next course of action."""
    question = state["question"]
    summary = state.get("summary", "")
    # Keep only the last 3 turns (6 messages) of history for the prompt.
    history = state.get("messages", [])[-6:]
    settings = Settings()
    logger = get_logger("query_analyzer")
    logger.info(f"Analyzing question: [italic]\"{question}\"[/italic]")
    # Structured-output LLM built via the factory, with a logging callback
    # to track LLM usage.
    llm = get_llm_model(
        settings.query_analyzer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(QueryAnalysis)
    messages = QUERY_ANALYZER_PROMPT.format_messages(
        question=question,
        history=history,
        summary=summary
    )
    try:
        raw_result = structured_llm.invoke(messages)
        parsed = QueryAnalysis.model_validate(raw_result)
        next_action = parsed.next_action
        # The routing hint lives in its own state key, not inside the analysis.
        analysis_dict = parsed.model_dump(exclude={"next_action"})
        logger.info(f"Analysis complete. Next action: [bold magenta]{next_action}[/bold magenta]")
    except Exception as e:
        # On any failure, fall back to asking the user to clarify.
        logger.error(f"Error during query analysis: {str(e)}")
        analysis_dict = {
            "data_required": [],
            "unknowns": [],
            "ambiguities": [f"Error during analysis: {str(e)}"],
            "conditions": []
        }
        next_action = "clarify"
    return {
        "analysis": analysis_dict,
        "next_action": next_action,
        "iterations": 0
    }

View File

@@ -0,0 +1,60 @@
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils import helpers
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.researcher import RESEARCHER_PROMPT
def researcher_node(state: AgentState) -> dict:
    """Handle general research queries, binding a web-search tool when available.

    Provider-aware: Gemini models get the native ``google_search`` tool,
    OpenAI models the built-in ``web_search`` tool; any other provider falls
    back to the plain LLM without search.

    Returns:
        dict with the research response appended to ``messages``.

    Raises:
        Exception: re-raises any LLM invocation failure after logging it.
    """
    question = state["question"]
    history = state.get("messages", [])[-6:]  # last 3 turns for context
    summary = state.get("summary", "")
    settings = Settings()
    logger = get_logger("researcher")
    logger.info(f"Researching question: [italic]\"{question}\"[/italic]")
    # Use researcher_llm from settings
    llm = get_llm_model(
        settings.researcher_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    date_str = helpers.get_readable_date()
    messages = RESEARCHER_PROMPT.format_messages(
        date=date_str,
        question=question,
        history=history,
        summary=summary
    )
    # Provider-aware tool binding
    try:
        if isinstance(llm, ChatGoogleGenerativeAI):
            # Native Google Search for Gemini
            llm_with_tools = llm.bind_tools([{"google_search": {}}])
        elif isinstance(llm, ChatOpenAI):
            # Native Web Search for OpenAI (built-in tool)
            llm_with_tools = llm.bind_tools([{"type": "web_search"}])
        else:
            # Fallback for other providers that might not support these specific search tools
            llm_with_tools = llm
    except Exception as e:
        # Tool binding is best-effort: fall back rather than fail the node.
        logger.warning(f"Failed to bind search tools: {str(e)}. Falling back to base LLM.")
        llm_with_tools = llm
    try:
        response = llm_with_tools.invoke(messages)
        logger.info("[bold green]Research complete.[/bold green]")
        return {
            "messages": [response]
        }
    except Exception as e:
        logger.error(f"Research failed: {str(e)}")
        # Bare raise preserves the original traceback (idiomatic re-raise).
        raise

View File

@@ -0,0 +1,52 @@
from langchain_core.messages import SystemMessage, HumanMessage
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
def summarize_conversation_node(state: AgentState) -> dict:
    """Update the running conversation summary from the latest interaction."""
    messages = state.get("messages", [])
    # Nothing to summarize without any messages.
    if not messages:
        return {}
    summary = state.get("summary", "")
    last_turn = messages[-2:]  # the most recent user + assistant pair
    settings = Settings()
    logger = get_logger("summarize_conversation")
    logger.info("Updating conversation summary...")
    # Reuse the summarizer model for this task as well.
    llm = get_llm_model(
        settings.summarizer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    if summary:
        prompt = (
            f"This is a summary of the conversation so far: {summary}\n\n"
            "Extend the summary by taking into account the new messages above."
        )
    else:
        prompt = "Create a summary of the conversation above."
    # Construct the messages for the summarization LLM
    summarization_messages = [
        SystemMessage(content=f"Current summary: {summary}" if summary else "You are a helpful assistant that summarizes conversations."),
        HumanMessage(content=f"Recent messages:\n{last_turn}\n\n{prompt}\n\nKeep the summary concise and focused on the key topics and data points discussed.")
    ]
    try:
        response = llm.invoke(summarization_messages)
        logger.info("[bold green]Conversation summary updated.[/bold green]")
        return {"summary": response.content}
    except Exception as e:
        # If summarization fails, keep the previous summary instead of losing it.
        logger.error(f"Failed to update summary: {str(e)}")
        return {"summary": summary}

View File

@@ -0,0 +1,44 @@
from langchain_core.messages import AIMessage
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.summarizer import SUMMARIZER_PROMPT
def summarizer_node(state: AgentState) -> dict:
    """Summarize the code execution results into a final user-facing answer.

    Combines the question, the plan, and the captured code output into the
    summarizer prompt and returns the LLM's answer as a new message.

    Returns:
        dict with the final response appended to ``messages``.

    Raises:
        Exception: re-raises any LLM invocation failure after logging it.
    """
    question = state["question"]
    plan = state.get("plan", "")
    code_output = state.get("code_output", "")
    history = state.get("messages", [])[-6:]  # last 3 turns for context
    summary = state.get("summary", "")
    settings = Settings()
    logger = get_logger("summarizer")
    logger.info("Generating final summary...")
    llm = get_llm_model(
        settings.summarizer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    messages = SUMMARIZER_PROMPT.format_messages(
        question=question,
        plan=plan,
        code_output=code_output,
        history=history,
        summary=summary
    )
    try:
        response = llm.invoke(messages)
        logger.info("[bold green]Summary generated.[/bold green]")
        # Return the final message to be added to the state
        return {
            "messages": [response]
        }
    except Exception as e:
        logger.error(f"Failed to generate summary: {str(e)}")
        # Bare raise preserves the original traceback (idiomatic re-raise).
        raise

View File

@@ -0,0 +1,10 @@
from .query_analyzer import QUERY_ANALYZER_PROMPT
from .planner import PLANNER_PROMPT
from .coder import CODE_GENERATOR_PROMPT, ERROR_CORRECTOR_PROMPT
# Public prompt-template API of the prompts package.
__all__ = [
    "QUERY_ANALYZER_PROMPT",
    "PLANNER_PROMPT",
    "CODE_GENERATOR_PROMPT",
    "ERROR_CORRECTOR_PROMPT",
]

View File

@@ -0,0 +1,64 @@
from langchain_core.prompts import ChatPromptTemplate
# System prompt for the code-generator LLM: defines the `db` client contract,
# the `plots` list convention, and general code-completeness requirements.
CODE_GENERATOR_SYSTEM = """You are an AI data analyst and your job is to assist users with data analysis and coding tasks.
The user will provide a task and a plan.
**Data Access:**
- A database client is available as a variable named `db`.
- You MUST use `db.query_df(sql_query)` to execute SQL queries and retrieve data as a Pandas DataFrame.
- Do NOT assume a dataframe `df` is already loaded unless explicitly stated. You usually need to query it first.
- The database schema is described in the prompt. Use it to construct valid SQL queries.
**Plotting:**
- If you need to plot any data, use the `plots` list to store the figures.
- Example: `plots.append(fig)` or `plots.append(plt.gcf())`.
- Do not use `plt.show()` as it will render the plot and cause an error.
**Code Requirements:**
- Produce FULL, COMPLETE CODE that includes all steps and solves the task!
- Always include the import statements at the top of the code (e.g., `import pandas as pd`, `import matplotlib.pyplot as plt`).
- Always include print statements to output the results of your code.
- Use `db.query_df("SELECT ...")` to get data."""
# Human message: the task, the plan (YAML), the schema summary, previous
# execution results, and an optional example-code slot.
CODE_GENERATOR_USER = """TASK:
{question}
PLAN:
```yaml
{plan}
```
AVAILABLE DATA SUMMARY (Database Schema):
{database_description}
CODE EXECUTION OF THE PREVIOUS TASK RESULTED IN:
{code_exec_results}
{example_code}"""
# System prompt for the error-corrector LLM: asks for a complete corrected script.
ERROR_CORRECTOR_SYSTEM = """The execution of the code resulted in an error.
Return a complete, corrected python code that incorporates the fixes for the error.
**Reminders:**
- You have access to a database client via the variable `db`.
- Use `db.query_df(sql)` to run queries.
- Use `plots.append(fig)` for plots.
- Always include imports and print statements."""
# Human message: the failing code plus the (filtered) error text.
ERROR_CORRECTOR_USER = """FAILED CODE:
```python
{code}
```
ERROR:
{error}"""
# Assembled chat templates consumed by the coder and error-corrector nodes.
CODE_GENERATOR_PROMPT = ChatPromptTemplate.from_messages([
    ("system", CODE_GENERATOR_SYSTEM),
    ("human", CODE_GENERATOR_USER),
])
ERROR_CORRECTOR_PROMPT = ChatPromptTemplate.from_messages([
    ("system", ERROR_CORRECTOR_SYSTEM),
    ("human", ERROR_CORRECTOR_USER),
])

View File

@@ -0,0 +1,46 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
# NOTE(review): this system prompt says "Research Specialist", identical to the
# researcher prompt's persona — possibly a copy-paste; confirm whether the
# planner should introduce itself as a planning/analysis specialist instead.
PLANNER_SYSTEM = """You are a Research Specialist and your job is to find answers and educate the user.
Provide factual information responding directly to the user's question. Include key details and context to ensure your response comprehensively answers their query.
Today's Date is: {date}"""
# Human message: walks the model through a four-stage reasoning process and
# requires the final algorithm to be emitted as a fenced YAML block.
PLANNER_USER = """Conversation Summary: {summary}
TASK:
{question}
AVAILABLE DATA SUMMARY (Use only if relevant to the task):
{database_description}
First: Evaluate whether you have all necessary and requested information to provide a solution.
Use the dataset description above to determine what data and in what format you have available to you.
You are able to search internet if the user asks for it, or you require any information that you can not derive from the given dataset or the instruction.
Second: Incorporate any additional relevant context, reasoning, or details from previous interactions or internal chain-of-thought that may impact the solution.
Ensure that all such information is fully included in your response rather than referring to previous answers indirectly.
Third: Reflect on the problem and briefly describe it, while addressing the problem goal, inputs, outputs,
rules, constraints, and other relevant details that appear in the problem description.
Fourth: Based on the preceding steps, formulate your response as an algorithm, breaking the solution in up to eight simple concise yet descriptive, clear English steps.
You MUST Include all values or instructions as described in the above task, or retrieved using internet search!
If fewer steps suffice, that's acceptable. If more are needed, please include them.
Remember to explain steps rather than write code.
This algorithm will be later converted to Python code.
If a dataframe is required, assume it is named 'df' and is already defined/populated based on the data summary.
There is a list variable called `plots` that you need to use to store any plots you generate. Do not use `plt.show()` as it will render the plot and cause an error.
Output the algorithm as a YAML string. Always enclose the YAML string within ```yaml tags.
**Note: Ensure that any necessary context from prior interactions is fully embedded in the plan. Do not use phrases like "refer to previous answer"; instead, provide complete details inline.**
{example_plan}"""
# Chat template with conversation history injected between system and human turns.
PLANNER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", PLANNER_SYSTEM),
    MessagesPlaceholder(variable_name="history"),
    ("human", PLANNER_USER),
])

View File

@@ -0,0 +1,33 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
SYSTEM_PROMPT = """You are an expert election data analyst. Decompose the user's question into key elements to determine the next action.
### Context & Defaults
- **History:** Use the conversation history and summary to resolve coreferences (e.g., "those results", "that state"). Assume the current question inherits missing context (Year, State, County) from history.
- **Data Access:** You have access to voter and election databases. Proceed to planning without asking for database or table names.
- **Downstream Capabilities:** Visualizations are generated as Matplotlib figures. Proceed to planning for "graphs" or "plots" without asking for file formats or plot types.
- **Trends:** For trend requests without a specified interval, allow the Planner to use a sensible default (e.g., by election cycle).
### Instructions:
1. **Analyze:** Identify if the request is for data analysis, general facts (web research), or is critically ambiguous.
2. **Extract Entities & Conditions:**
- **Data Required:** e.g., "vote count", "demographics".
- **Conditions:** e.g., "Year=2024". Include context from history.
3. **Identify Target & Critical Ambiguities:**
- **Unknowns:** The core target question.
- **Critical Ambiguities:** ONLY list issues that PREVENT any analysis.
- Examples: No timeframe/geography in query OR history; "track the same voter" without an identity definition.
4. **Determine Action:**
- `plan`: For data analysis where defaults or history provide sufficient context.
- `research`: For general knowledge.
- `clarify`: ONLY for CRITICAL ambiguities."""
USER_PROMPT_TEMPLATE = """Conversation Summary: {summary}
Analyze the following question: {question}"""
QUERY_ANALYZER_PROMPT = ChatPromptTemplate.from_messages([
("system", SYSTEM_PROMPT),
MessagesPlaceholder(variable_name="history"),
("human", USER_PROMPT_TEMPLATE),
])

View File

@@ -0,0 +1,12 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
# Research prompt: system persona + conversation-history placeholder +
# summary-prefixed user question. {date} is filled in by the researcher node.
RESEARCHER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", """You are a Research Specialist and your job is to find answers and educate the user.
Provide factual information responding directly to the user's question. Include key details and context to ensure your response comprehensively answers their query.
Today's Date is: {date}"""),
    MessagesPlaceholder(variable_name="history"),
    ("user", """Conversation Summary: {summary}
{question}""")
])

View File

@@ -0,0 +1,27 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
SUMMARIZER_PROMPT = ChatPromptTemplate.from_messages([
("system", """You are an expert election data analyst providing a final answer to the user.
Use the provided conversation history and summary to ensure your response is contextually relevant and flows naturally from previous turns.
Conversation Summary: {summary}"""),
MessagesPlaceholder(variable_name="history"),
("user", """The user presented you with the following question.
Question: {question}
To address this, you have designed an algorithm.
Algorithm: {plan}.
You have crafted a Python code based on this algorithm, and the output generated by the code's execution is as follows.
Output: {code_output}.
Please produce a comprehensive, easy-to-understand answer that:
1. Summarizes the main insights or conclusions achieved through your method's implementation. Include execution results if necessary.
2. Includes relevant findings from the code execution in a clear format (e.g., text explanation, tables, lists, bullet points).
- Avoid referencing the code or output as 'the above results' or saying 'it's in the code output.'
- Instead, present the actual key data or statistics within your explanation.
3. If the user requested specific information that does not appear in the code's output but you can provide it, include that information directly in your summary.
4. Present any data or tables that might have been generated by the code in full, since the user cannot directly see the execution output.
Your goal is to give a final answer that stands on its own without requiring the user to see the code or raw output directly.""")
])

View File

@@ -0,0 +1,36 @@
from typing import TypedDict, Annotated, List, Dict, Any, Optional
from langchain_core.messages import BaseMessage
from langchain.agents import AgentState as AS
import operator
class AgentState(AS):
    """Shared LangGraph state threaded through every node of the workflow.

    Fields annotated with ``operator.add`` are accumulated (appended) by
    LangGraph's reducer instead of being overwritten on each node update.
    """
    # Conversation history (appended across node updates)
    messages: Annotated[List[BaseMessage], operator.add]
    # Task context: the user's current question
    question: str
    # Query Analysis (Decomposition results)
    analysis: Optional[Dict[str, Any]]
    # Expected keys: "requires_dataset", "expert", "data", "unknown", "condition"
    # Step-by-step reasoning produced by the planner
    plan: Optional[str]
    # Code execution context
    code: Optional[str]
    code_output: Optional[str]
    error: Optional[str]
    # Artifacts (for UI display)
    plots: Annotated[List[Any], operator.add]  # Matplotlib figures (appended)
    dfs: Dict[str, Any]  # Pandas DataFrames
    # Conversation summary
    summary: Optional[str]
    # Routing hint: "clarify", "plan", "research", "end"
    next_action: str
    # Number of execution attempts (compared against MAX_ITERATIONS)
    iterations: int

View File

@@ -0,0 +1,92 @@
from langgraph.graph import StateGraph, END
from ea_chatbot.graph.state import AgentState
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node
from ea_chatbot.graph.nodes.planner import planner_node
from ea_chatbot.graph.nodes.coder import coder_node
from ea_chatbot.graph.nodes.error_corrector import error_corrector_node
from ea_chatbot.graph.nodes.executor import executor_node
from ea_chatbot.graph.nodes.summarizer import summarizer_node
from ea_chatbot.graph.nodes.researcher import researcher_node
from ea_chatbot.graph.nodes.clarification import clarification_node
from ea_chatbot.graph.nodes.summarize_conversation import summarize_conversation_node
# Maximum number of error-correction retries before the executor gives up
# and routes to the summarizer anyway.
MAX_ITERATIONS = 3
def router(state: AgentState) -> str:
    """Route to the next node based on the analysis.

    Maps the ``next_action`` hint to its target node name; any unknown or
    missing hint terminates the graph.
    """
    targets = {
        "plan": "planner",
        "research": "researcher",
        "clarify": "clarification",
    }
    return targets.get(state.get("next_action"), END)
def create_workflow():
    """Create the LangGraph workflow.

    Topology:
        query_analyzer -> {planner | researcher | clarification | END}
        planner -> coder -> executor
        executor -> error_corrector -> executor   (up to MAX_ITERATIONS retries)
        executor -> summarizer -> summarize_conversation -> END
        researcher -> summarize_conversation -> END
        clarification -> END
    """
    workflow = StateGraph(AgentState)
    # Add nodes
    workflow.add_node("query_analyzer", query_analyzer_node)
    workflow.add_node("planner", planner_node)
    workflow.add_node("coder", coder_node)
    workflow.add_node("error_corrector", error_corrector_node)
    workflow.add_node("researcher", researcher_node)
    workflow.add_node("clarification", clarification_node)
    workflow.add_node("executor", executor_node)
    workflow.add_node("summarizer", summarizer_node)
    workflow.add_node("summarize_conversation", summarize_conversation_node)
    # Set entry point
    workflow.set_entry_point("query_analyzer")
    # Add conditional edges from query_analyzer (see router for the mapping)
    workflow.add_conditional_edges(
        "query_analyzer",
        router,
        {
            "planner": "planner",
            "researcher": "researcher",
            "clarification": "clarification",
            END: END
        }
    )
    # Linear flow for planning and coding
    workflow.add_edge("planner", "coder")
    workflow.add_edge("coder", "executor")

    # Executor routing: retry through the error corrector on failure, but
    # bail out to the summarizer once the retry budget is exhausted.
    def executor_router(state: AgentState) -> str:
        if state.get("error"):
            # Check for iteration limit to prevent infinite loops
            if state.get("iterations", 0) >= MAX_ITERATIONS:
                return "summarizer"
            return "error_corrector"
        return "summarizer"

    workflow.add_conditional_edges(
        "executor",
        executor_router,
        {
            "error_corrector": "error_corrector",
            "summarizer": "summarizer"
        }
    )
    workflow.add_edge("error_corrector", "executor")
    workflow.add_edge("researcher", "summarize_conversation")
    workflow.add_edge("clarification", END)
    workflow.add_edge("summarizer", "summarize_conversation")
    workflow.add_edge("summarize_conversation", END)
    # Compile the graph
    app = workflow.compile()
    return app
# Initialize the app: module-level singleton holding the compiled graph.
app = create_workflow()

View File

@@ -0,0 +1,188 @@
from contextlib import contextmanager
from typing import Optional, List
from sqlalchemy import create_engine, select, delete
from sqlalchemy.orm import sessionmaker, Session
from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError
from ea_chatbot.history.models import User, Conversation, Message, Plot
# Argon2 Password Hasher (module-level instance used for local-account
# password hashing/verification)
ph = PasswordHasher()
class HistoryManager:
    """Manages database sessions and operations for history and user data.

    Sessions are created with ``expire_on_commit=False`` so returned ORM
    objects remain usable after their originating session has closed.
    """

    def __init__(self, db_url: str):
        self.engine = create_engine(db_url)
        # expire_on_commit=False is important so we can use objects after session closes
        self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine, expire_on_commit=False)

    @contextmanager
    def get_session(self):
        """Context manager for database sessions.

        Commits on clean exit, rolls back and re-raises on exception, and
        always closes the session.
        """
        session = self.SessionLocal()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    # --- User Management ---
    def get_user(self, email: str) -> Optional[User]:
        """Fetch a user by their email (username)."""
        with self.get_session() as session:
            result = session.execute(select(User).where(User.username == email))
            return result.scalar_one_or_none()

    def get_user_by_id(self, user_id: str) -> Optional[User]:
        """Fetch a user by their ID."""
        with self.get_session() as session:
            return session.get(User, user_id)

    def create_user(self, email: str, password: Optional[str] = None, display_name: Optional[str] = None) -> User:
        """Create a new local user.

        Args:
            email: Login identifier, stored as ``username``.
            password: Optional plaintext password; Argon2-hashed when given,
                left as None for externally-authenticated (OIDC) users.
            display_name: Optional display name; defaults to the part of the
                email before the '@'.
        """
        hashed_password = ph.hash(password) if password else None
        user = User(
            username=email,
            password_hash=hashed_password,
            display_name=display_name or email.split("@")[0]
        )
        with self.get_session() as session:
            session.add(user)
            session.commit()
            # refresh() loads DB-generated values (e.g. the default id)
            session.refresh(user)
            return user

    def authenticate_user(self, email: str, password: str) -> Optional[User]:
        """Authenticate a user by email and password.

        Returns the user on success; None for unknown users, password-less
        (OIDC-only) users, or a mismatched password.
        """
        user = self.get_user(email)
        if not user or not user.password_hash:
            return None
        try:
            ph.verify(user.password_hash, password)
            return user
        except VerifyMismatchError:
            return None

    def sync_user_from_oidc(self, email: str, display_name: Optional[str] = None) -> User:
        """
        Synchronize a user from an OIDC provider.
        If a user with the same email exists, update their display name.
        Otherwise, create a new user.
        """
        user = self.get_user(email)
        if user:
            # Update existing user if needed
            if display_name and user.display_name != display_name:
                with self.get_session() as session:
                    db_user = session.get(User, user.id)
                    db_user.display_name = display_name
                    session.commit()
                    session.refresh(db_user)
                    return db_user
            return user
        else:
            # Create new user (no password for OIDC users initially)
            return self.create_user(email=email, display_name=display_name)

    # --- Conversation Management ---
    def create_conversation(self, user_id: str, data_state: str, name: str, summary: Optional[str] = None) -> Conversation:
        """Create a new conversation for a user."""
        conv = Conversation(
            user_id=user_id,
            data_state=data_state,
            name=name,
            summary=summary
        )
        with self.get_session() as session:
            session.add(conv)
            session.commit()
            session.refresh(conv)
            return conv

    def get_conversations(self, user_id: str, data_state: str) -> List[Conversation]:
        """Get all conversations for a user and data state, ordered by creation time (newest first)."""
        with self.get_session() as session:
            stmt = (
                select(Conversation)
                .where(Conversation.user_id == user_id, Conversation.data_state == data_state)
                .order_by(Conversation.created_at.desc())
            )
            result = session.execute(stmt)
            return list(result.scalars().all())

    def rename_conversation(self, conversation_id: str, new_name: str) -> Optional[Conversation]:
        """Rename an existing conversation. Returns None if it does not exist."""
        with self.get_session() as session:
            conv = session.get(Conversation, conversation_id)
            if conv:
                conv.name = new_name
                session.commit()
                session.refresh(conv)
            return conv

    def delete_conversation(self, conversation_id: str) -> bool:
        """Delete a conversation and its associated messages/plots (via cascade).

        Returns True when a conversation was found and deleted, else False.
        """
        with self.get_session() as session:
            conv = session.get(Conversation, conversation_id)
            if conv:
                session.delete(conv)
                session.commit()
                return True
            return False

    def update_conversation_summary(self, conversation_id: str, summary: str) -> Optional[Conversation]:
        """Update the summary of a conversation. Returns None if it does not exist."""
        with self.get_session() as session:
            conv = session.get(Conversation, conversation_id)
            if conv:
                conv.summary = summary
                session.commit()
                session.refresh(conv)
            return conv

    # --- Message & Plot Management ---
    def add_message(self, conversation_id: str, role: str, content: str, plots: Optional[List[bytes]] = None) -> Message:
        """Add a message to a conversation, optionally with plots (raw image bytes)."""
        msg = Message(
            conversation_id=conversation_id,
            role=role,
            content=content
        )
        with self.get_session() as session:
            session.add(msg)
            session.flush()  # Populate msg.id for plots
            if plots:
                for plot_data in plots:
                    plot = Plot(message_id=msg.id, image_data=plot_data)
                    session.add(plot)
            session.commit()
            session.refresh(msg)
            # Ensure plots are loaded before session closes if we need them
            _ = msg.plots
            return msg

    def get_messages(self, conversation_id: str) -> List[Message]:
        """Get all messages for a conversation, ordered by creation time."""
        with self.get_session() as session:
            stmt = (
                select(Message)
                .where(Message.conversation_id == conversation_id)
                .order_by(Message.created_at.asc())
            )
            result = session.execute(stmt)
            messages = list(result.scalars().all())
            # Pre-load plots for each message (avoids lazy-load after close)
            for m in messages:
                _ = m.plots
            return messages

View File

@@ -0,0 +1,52 @@
from typing import List, Optional
from datetime import datetime, timezone
import uuid
from sqlalchemy import String, ForeignKey, DateTime, LargeBinary, Text
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
class Base(DeclarativeBase):
    """Declarative base shared by all ORM models in this module."""
    pass
class User(Base):
    """Account record; password_hash is None for externally-authenticated users."""
    __tablename__ = "users"
    # UUID4 string primary key, generated client-side.
    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    # Login identifier (an email address per HistoryManager usage).
    username: Mapped[str] = mapped_column(String, unique=True, index=True)
    # Hashed password; nullable to support password-less (OIDC) accounts.
    password_hash: Mapped[Optional[str]] = mapped_column(String, nullable=True)
    display_name: Mapped[Optional[str]] = mapped_column(String, nullable=True)
    # Deleting a user deletes their conversations (delete-orphan cascade).
    conversations: Mapped[List["Conversation"]] = relationship(back_populates="user", cascade="all, delete-orphan")
class Conversation(Base):
    """A chat thread owned by a user, scoped to a data_state."""
    __tablename__ = "conversations"
    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    user_id: Mapped[str] = mapped_column(ForeignKey("users.id"))
    # Scoping key used when listing conversations — presumably identifies the
    # dataset/election the thread refers to; confirm semantics with callers.
    data_state: Mapped[str] = mapped_column(String)
    name: Mapped[str] = mapped_column(String)
    # Rolling conversation summary maintained by the agent workflow.
    summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    # UTC timestamp (stored naive by the DateTime column).
    created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc))
    user: Mapped["User"] = relationship(back_populates="conversations")
    # Deleting a conversation deletes its messages (delete-orphan cascade).
    messages: Mapped[List["Message"]] = relationship(back_populates="conversation", cascade="all, delete-orphan")
class Message(Base):
    """A single chat message within a conversation."""
    __tablename__ = "messages"
    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    conversation_id: Mapped[str] = mapped_column(ForeignKey("conversations.id"))
    # Speaker role (e.g. user/assistant — confirm the exact values with callers).
    role: Mapped[str] = mapped_column(String)
    content: Mapped[str] = mapped_column(Text)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc))
    conversation: Mapped["Conversation"] = relationship(back_populates="messages")
    # Deleting a message deletes its attached plots (delete-orphan cascade).
    plots: Mapped[List["Plot"]] = relationship(back_populates="message", cascade="all, delete-orphan")
class Plot(Base):
    """Binary image blob attached to a message (rendered plot)."""
    __tablename__ = "plots"
    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    message_id: Mapped[str] = mapped_column(ForeignKey("messages.id"))
    # Raw image bytes; encoding/format is decided by the producer.
    image_data: Mapped[bytes] = mapped_column(LargeBinary)
    message: Mapped["Message"] = relationship(back_populates="plots")

View File

@@ -0,0 +1,83 @@
from pydantic import BaseModel, Field, computed_field
from typing import Sequence, Optional
import re
class TaskPlanContext(BaseModel):
    '''Background context relevant to the task plan'''
    # Non-empty free-text description of the data and prior conversation.
    initial_context: str = Field(
        min_length=1,
        description="Background information about the database/tables and previous conversations relevant to the task.",
    )
    assumptions: Sequence[str] = Field(
        description="Assumptions made while working on the task.",
    )
    # Optional: None when no constraints apply.
    constraints: Optional[Sequence[str]] = Field(
        description="Constraints that apply to the task.",
    )
class TaskPlanResponse(BaseModel):
    '''Structured plan to achieve the task objective'''
    goal: str = Field(
        min_length=1,
        description="Single-sentence objective the plan must achieve.",
    )
    reflection: str = Field(
        min_length=1,
        description="High-level natural-language reasoning describing the user's request and the intended solution approach.",
    )
    context: TaskPlanContext = Field(
        description="Background context relevant to the task plan.",
    )
    # min_length=1 here constrains the sequence length (at least one step).
    steps: Sequence[str] = Field(
        min_length=1,
        description="Ordered list of steps to execute that follow the 'Step <number>: <detail>' pattern.",
    )
# LLM output sometimes contains the chat separator token where a code fence
# should be; it is normalised back to ``` before extraction.
_IM_SEP_TOKEN_PATTERN = re.compile(re.escape("<|im_sep|>"))
# Captures the body of the first fenced code block, with or without a
# "python" language tag.
_CODE_BLOCK_PATTERN = re.compile(r"```(?:python\s*)?(.*?)\s*```", re.DOTALL)
# Names that generated code must not reference (modules and builtins alike);
# matching lines are commented out by CodeGenerationResponse.parsed_code.
_FORBIDDEN_MODULES = (
    "subprocess",
    "sys",
    "eval",
    "exec",
    "socket",
    "urllib",
    "shutil",
    "pickle",
    "ctypes",
    "multiprocessing",
    "tempfile",
    "glob",
    "pty",
    "commands",
    "cgi",
    "cgitb",
    "xml.etree.ElementTree",
    "builtins",
)
# Matches any whole line that mentions a forbidden name as a word, as long as
# the line does not start with '#' (already-commented lines are left alone).
_FORBIDDEN_MODULE_PATTERN = re.compile(
    r"^((?:[^#].*)?\b(" + "|".join(map(re.escape, _FORBIDDEN_MODULES)) + r")\b.*)$",
    flags=re.MULTILINE,
)
class CodeGenerationResponse(BaseModel):
    '''Code generation response structure'''
    code: str = Field(description="The generated code snippet to accomplish the task")
    explanation: str = Field(description="Explanation of the generated code and its functionality")

    @computed_field(return_type=str)
    @property
    def parsed_code(self) -> str:
        '''Extracts the code snippet without any surrounding text'''
        # Replace stray chat-separator tokens with code fences, then pull out
        # the first fenced block; fall back to the whole string if none.
        normalised = _IM_SEP_TOKEN_PATTERN.sub("```", self.code).strip()
        match = _CODE_BLOCK_PATTERN.search(normalised)
        candidate = match.group(1).strip() if match else normalised
        # Neutralise any line referencing a forbidden module/builtin by
        # turning it into a "# not allowed ..." comment.
        sanitised = _FORBIDDEN_MODULE_PATTERN.sub(r"# not allowed \1", candidate)
        return sanitised.strip()
class RankResponse(BaseModel):
    '''Code ranking response structure'''
    # Constrained to the inclusive range [1, 10].
    rank: int = Field(
        ge=1, le=10,
        description="Rank of the code snippet from 1 (best) to 10 (worst)"
    )

View File

@@ -0,0 +1,24 @@
from typing import TypedDict, Optional
from enum import StrEnum
class DBSettings(TypedDict):
    """Connection settings for the PostgreSQL data source."""
    host: str
    port: int
    user: str
    # Password for the database user.
    pswd: str
    # Database name.
    db: str
    # Optional default table to operate on (may be None).
    table: Optional[str]
class Agent(StrEnum):
    """Display names of the agents/roles that make up the chatbot pipeline."""
    EXPERT_SELECTOR = "Expert Selector"
    ANALYST_SELECTOR = "Analyst Selector"
    THEORIST = "Theorist"
    THEORIST_WEB = "Theorist-Web"
    THEORIST_CLARIFICATION = "Theorist-Clarification"
    PLANNER = "Planner"
    CODE_GENERATOR = "Code Generator"
    CODE_DEBUGGER = "Code Debugger"
    CODE_EXECUTOR = "Code Executor"
    ERROR_CORRECTOR = "Error Corrector"
    CODE_RANKER = "Code Ranker"
    SOLUTION_SUMMARIZER = "Solution Summarizer"

View File

@@ -0,0 +1,12 @@
from .db_client import DBClient
from .llm_factory import get_llm_model
from .logging import get_logger, LangChainLoggingHandler
from . import helpers
# Public API of the utils package.
__all__ = [
    "DBClient",
    "get_llm_model",
    "get_logger",
    "LangChainLoggingHandler",
    "helpers"
]

View File

@@ -0,0 +1,234 @@
from typing import Optional, Dict, Any, List, TYPE_CHECKING
import yaml
import json
import os
from ea_chatbot.utils.db_client import DBClient
if TYPE_CHECKING:
from ea_chatbot.types import DBSettings
def _get_table_checksum(db_client: DBClient, table: str) -> str:
"""Calculates the checksum of the table using DML statistics from pg_stat_user_tables."""
query = f"""
SELECT md5(concat_ws('|', n_tup_ins, n_tup_upd, n_tup_del)) AS dml_hash
FROM pg_stat_user_tables
WHERE schemaname = 'public' AND relname = '{table}';"""
try:
return str(db_client.query_df(query).iloc[0, 0])
except Exception:
return "unknown_checksum"
def _update_checksum_file(filepath: str, table: str, checksum: str):
"""Updates the checksum file with the new checksum for the table."""
checksums = {}
if os.path.exists(filepath):
with open(filepath, 'r') as f:
for line in f:
if ':' in line:
k, v = line.strip().split(':', 1)
checksums[k] = v
checksums[table] = checksum
with open(filepath, 'w') as f:
for k, v in checksums.items():
f.write(f"{k}:{v}")
def get_data_summary(data_dir: str = "data") -> Optional[str]:
    """
    Return the raw contents of <data_dir>/inspection.yaml, or None if the
    file does not exist.
    """
    summary_path = os.path.join(data_dir, "inspection.yaml")
    if not os.path.exists(summary_path):
        return None
    with open(summary_path, 'r') as handle:
        return handle.read()
def get_primary_key(db_client: "DBClient", table_name: str) -> Optional[str]:
    """
    Dynamically identifies the primary key of the table.
    Returns the column name of the primary key, or None if not found
    (including on query failure, which is logged as a warning).
    """
    # Bind table_name as a query parameter rather than f-string interpolation,
    # so odd table names cannot break (or inject into) the statement.
    query = """
    SELECT kcu.column_name
    FROM information_schema.key_column_usage AS kcu
    JOIN information_schema.table_constraints AS tc
        ON kcu.constraint_name = tc.constraint_name
        AND kcu.table_schema = tc.table_schema
    WHERE kcu.table_name = :table_name
        AND tc.constraint_type = 'PRIMARY KEY'
    LIMIT 1;
    """
    try:
        df = db_client.query_df(query, {"table_name": table_name})
        if not df.empty:
            return str(df.iloc[0, 0])
    except Exception as e:
        print(f"Warning: Could not determine primary key for {table_name}: {e}")
    return None
def inspect_db_table(
    db_client: Optional[DBClient]=None,
    db_settings: Optional["DBSettings"]=None,
    data_dir: str = "data",
    force_update: bool = False
) -> Optional[str]:
    """
    Inspects the database table, generates statistics for each column,
    and saves the inspection results to a YAML file locally.

    Improvements:
    - Dynamic Primary Key Discovery
    - Cardinality (Unique Counts)
    - Categorical Sample Values for low cardinality columns
    - Robust Quoting

    Either ``db_client`` or ``db_settings`` must be supplied; the table name
    is taken from the client's settings. Returns the YAML inspection text,
    or None on failure.
    """
    inspection_file = os.path.join(data_dir, "inspection.yaml")
    checksum_file = os.path.join(data_dir, "checksum")
    # Initialize DB Client
    if db_client is None:
        if db_settings is None:
            print("Error: Either db_client or db_settings must be provided.")
            return None
        try:
            db_client = DBClient(db_settings)
        except Exception as e:
            print(f"Failed to create DBClient: {e}")
            return None
    table_name = db_client.settings.get('table')
    if not table_name:
        print("Error: Table name must be specified in DBSettings.")
        return None
    os.makedirs(data_dir, exist_ok=True)
    # Checksum verification: skip regeneration when the table's DML stats are
    # unchanged since the last run (unless force_update is set).
    new_checksum = _get_table_checksum(db_client, table_name)
    has_changed = True
    if os.path.exists(checksum_file):
        try:
            with open(checksum_file, 'r') as f:
                saved_checksums = f.read().strip()
                if f"{table_name}:{new_checksum}" in saved_checksums:
                    has_changed = False
        except Exception:
            pass  # Force update on read error
    if not has_changed and not force_update:
        return get_data_summary(data_dir)
    print(f"Regenerating inspection file for table '{table_name}'...")
    # Fetch Table Metadata
    try:
        # Get columns and types
        columns_query = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';"
        columns_df = db_client.query_df(columns_query)
        # Get Row Counts
        total_rows_df = db_client.query_df(f'SELECT COUNT(*) FROM "{table_name}"')
        total_rows = int(total_rows_df.iloc[0, 0])
        # Dynamic Primary Key
        primary_key = get_primary_key(db_client, table_name)
        # Get First/Last Rows (if PK exists)
        first_row_df = None
        last_row_df = None
        if primary_key:
            first_row_df = db_client.query_df(f'SELECT * FROM "{table_name}" ORDER BY "{primary_key}" ASC LIMIT 1')
            last_row_df = db_client.query_df(f'SELECT * FROM "{table_name}" ORDER BY "{primary_key}" DESC LIMIT 1')
    except Exception as e:
        print(f"Failed to retrieve basic table info: {e}")
        return None
    stats_dict: Dict[str, Any] = {}
    if primary_key:
        stats_dict['primary_key'] = primary_key
    # Per-column statistics; a failure on one column is logged and skipped so
    # the remaining columns are still inspected.
    for _, row in columns_df.iterrows():
        col_name = row['column_name']
        dtype = row['data_type']
        try:
            # Count Values
            # Using robust quoting
            count_df = db_client.query_df(f'SELECT COUNT("{col_name}") FROM "{table_name}"')
            count_val = int(count_df.iloc[0,0])
            # Count Unique (Cardinality)
            unique_df = db_client.query_df(f'SELECT COUNT(DISTINCT "{col_name}") FROM "{table_name}"')
            unique_count = int(unique_df.iloc[0,0])
            col_stats: Dict[str, Any] = {
                'dtype': dtype,
                'count_of_values': count_val,
                'count_of_nulls': total_rows - count_val,
                'unique_count': unique_count
            }
            # All-NULL column: record the basic counts only.
            if count_val == 0:
                stats_dict[col_name] = col_stats
                continue
            # Numerical Stats
            if any(t in dtype for t in ('int', 'float', 'numeric', 'double', 'real', 'decimal')):
                stats_query = f'SELECT AVG("{col_name}"), MIN("{col_name}"), MAX("{col_name}") FROM "{table_name}"'
                stats_df = db_client.query_df(stats_query)
                if not stats_df.empty:
                    col_stats['mean'] = float(stats_df.iloc[0,0]) if stats_df.iloc[0,0] is not None else None
                    col_stats['min'] = float(stats_df.iloc[0,1]) if stats_df.iloc[0,1] is not None else None
                    col_stats['max'] = float(stats_df.iloc[0,2]) if stats_df.iloc[0,2] is not None else None
            # Temporal Stats
            elif any(t in dtype for t in ('date', 'timestamp')):
                stats_query = f'SELECT MIN("{col_name}"), MAX("{col_name}") FROM "{table_name}"'
                stats_df = db_client.query_df(stats_query)
                if not stats_df.empty:
                    col_stats['min'] = str(stats_df.iloc[0,0])
                    col_stats['max'] = str(stats_df.iloc[0,1])
            # Categorical/Text Stats
            else:
                # Sample values if cardinality is low (< 20)
                if 0 < unique_count < 20:
                    distinct_query = f'SELECT DISTINCT "{col_name}" FROM "{table_name}" ORDER BY "{col_name}" LIMIT 20'
                    distinct_df = db_client.query_df(distinct_query)
                    col_stats['distinct_values'] = distinct_df.iloc[:, 0].tolist()
            if first_row_df is not None and not first_row_df.empty and col_name in first_row_df.columns:
                col_stats['first_value'] = str(first_row_df.iloc[0][col_name])
            if last_row_df is not None and not last_row_df.empty and col_name in last_row_df.columns:
                col_stats['last_value'] = str(last_row_df.iloc[0][col_name])
            stats_dict[col_name] = col_stats
        except Exception as e:
            print(f"Warning: Could not process column {col_name}: {e}")
    # Load existing inspections to merge (if multiple tables)
    existing_inspections = {}
    if os.path.exists(inspection_file):
        try:
            with open(inspection_file, 'r') as f:
                existing_inspections = yaml.safe_load(f) or {}
            # Backup old file
            os.rename(inspection_file, inspection_file + ".old")
        except Exception:
            pass
    existing_inspections[table_name] = stats_dict
    # Save new inspection
    inspection_content = yaml.dump(existing_inspections, sort_keys=False, default_flow_style=False)
    with open(inspection_file, 'w') as f:
        f.write(inspection_content)
    # Update Checksum
    _update_checksum_file(checksum_file, table_name, new_checksum)
    print(f"Inspection saved to {inspection_file}")
    return inspection_content

View File

@@ -0,0 +1,21 @@
from typing import Optional, TYPE_CHECKING
import pandas as pd
from sqlalchemy import create_engine, text
if TYPE_CHECKING:
from ea_chatbot.types import DBSettings
class DBClient:
    """Thin wrapper around a SQLAlchemy engine for one PostgreSQL database.

    Connection parameters come from a ``DBSettings`` mapping (host, port,
    user, pswd, db, and an optional default table).
    """

    def __init__(self, settings: "DBSettings"):
        self.settings = settings
        self._engine = self._create_engine()

    def _create_engine(self):
        """Build the engine from the stored settings.

        Uses URL.create so special characters in the username or password
        (e.g. '@', '/', ':') are escaped correctly instead of corrupting a
        hand-concatenated connection string.
        """
        from sqlalchemy import URL  # local import: only needed here

        url = URL.create(
            drivername="postgresql",
            username=self.settings['user'],
            password=self.settings['pswd'],
            host=self.settings['host'],
            port=self.settings['port'],
            database=self.settings['db'],
        )
        return create_engine(url)

    def query_df(self, sql: str, params: Optional[dict] = None) -> pd.DataFrame:
        """Execute *sql* (optionally with bound ``:name`` params) and return
        the full result set as a pandas DataFrame."""
        with self._engine.connect() as conn:
            result = conn.execute(text(sql), params or {})
            df = pd.DataFrame(result.fetchall(), columns=result.keys())
            return df

View File

@@ -0,0 +1,73 @@
from typing import Optional, TYPE_CHECKING, Dict, Any
from datetime import datetime, timezone
import yaml
import json
if TYPE_CHECKING:
from ea_chatbot.graph.state import AgentState
def ordinal(n: int) -> str:
    """Render *n* with its English ordinal suffix (1st, 2nd, 3rd, 4th, ...)."""
    # 11-13 are irregular: they take "th" despite ending in 1/2/3.
    if 11 <= n <= 13:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    return f"{n}{suffix}"
def get_readable_date(date_obj: Optional[datetime] = None, tz: Optional[timezone] = None) -> str:
    """Format a datetime as e.g. 'Fri 5th of Jan 2024'.

    Defaults to the current UTC time; an optional *tz* converts the
    timestamp before formatting.
    """
    when = date_obj if date_obj is not None else datetime.now(timezone.utc)
    if tz:
        when = when.astimezone(tz)
    day = when.day
    # English ordinal suffix; 11-13 are irregular and always take "th".
    suffix = "th" if 11 <= day <= 13 else {1: "st", 2: "nd", 3: "rd"}.get(day % 10, "th")
    return when.strftime(f"%a {day}{suffix} of %b %Y")
def to_yaml(json_str: str, indent: int = 2) -> str:
    """
    Convert a JSON string (potentially malformed LLM output) to a YAML string.

    Repair strategy: if the raw string fails to parse, retry with single
    quotes swapped for double quotes; if that also fails, the input is
    returned unchanged rather than raising.
    """
    if not json_str:
        return ""
    try:
        data = json.loads(json_str)
    except json.JSONDecodeError:
        repaired = json_str.replace("'", '"')
        try:
            data = json.loads(repaired)
        except Exception:
            # Unparseable: hand back the raw input.
            return json_str
    return yaml.dump(data, indent=indent, sort_keys=False)
def merge_agent_state(current_state: "AgentState", update: Dict[str, Any]) -> "AgentState":
    """
    Apply a partial state update, mimicking LangGraph's reducer semantics.

    - "messages" / "plots" lists are concatenated onto the existing values.
    - The "dfs" dict is shallow-merged (incoming keys win on collision).
    - An explicit None always overwrites, clearing the field.
    - Every other field is simply overwritten.
    The input state is not mutated; a merged copy is returned.
    """
    merged = current_state.copy()
    for field, incoming in update.items():
        if incoming is None:
            merged[field] = None
        elif field in ("messages", "plots") and isinstance(incoming, list):
            # Accumulate lists, tolerating a missing/non-list current value.
            existing = merged.get(field, [])
            if not isinstance(existing, list):
                existing = []
            merged[field] = existing + incoming
        elif field == "dfs" and isinstance(incoming, dict):
            # Shallow-merge DataFrame registries.
            base = merged.get(field, {})
            combined = dict(base) if isinstance(base, dict) else {}
            combined.update(incoming)
            merged[field] = combined
        else:
            merged[field] = incoming
    return merged

View File

@@ -0,0 +1,36 @@
from typing import Optional, cast, TYPE_CHECKING, Literal, Dict, List, Tuple, Any
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.callbacks import BaseCallbackHandler
from ea_chatbot.config import LLMConfig
def get_llm_model(config: LLMConfig, callbacks: Optional[List[BaseCallbackHandler]] = None) -> BaseChatModel:
    """
    Factory function to get a LangChain chat model based on configuration.

    Args:
        config: LLMConfig object containing model settings.
        callbacks: Optional list of LangChain callback handlers.

    Returns:
        Initialized BaseChatModel instance.

    Raises:
        ValueError: If the provider is not supported.
    """
    params = {
        "temperature": config.temperature,
        "max_tokens": config.max_tokens,
        **config.provider_specific
    }
    # Filter out None values to allow defaults to take over if not specified
    params = {k: v for k, v in params.items() if v is not None}
    # Normalise the provider once instead of lowercasing per comparison.
    provider = config.provider.lower()
    if provider == "openai":
        return ChatOpenAI(model=config.model, callbacks=callbacks, **params)
    if provider in ("google", "google_genai"):
        return ChatGoogleGenerativeAI(model=config.model, callbacks=callbacks, **params)
    raise ValueError(f"Unsupported LLM provider: {config.provider}")

View File

@@ -0,0 +1,141 @@
import json
import os
import logging
from rich.logging import RichHandler
from langchain_core.callbacks import BaseCallbackHandler
from logging.handlers import RotatingFileHandler
from typing import Any, Optional, Dict, List
class LangChainLoggingHandler(BaseCallbackHandler):
    """Callback handler for logging LangChain events.

    Emits Rich-markup log lines for LLM start/end/error, including token
    usage when the provider reports it in ``llm_output``.
    """

    def __init__(self, logger: Optional[logging.Logger] = None):
        # Fall back to the shared "langchain" child logger when none is given.
        self.logger = logger or get_logger("langchain")

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
        # Serialized might be empty or missing name depending on how it's called
        model_name = serialized.get("name") or kwargs.get("name") or "LLM"
        self.logger.info(f"[bold blue]LLM Started:[/bold blue] {model_name}")

    def on_llm_end(self, response: Any, **kwargs: Any) -> Any:
        # llm_output may be absent or None depending on the provider.
        llm_output = getattr(response, "llm_output", {}) or {}
        # Try to find model name in output or use fallback
        model_name = llm_output.get("model_name") or "LLM"
        token_usage = llm_output.get("token_usage", {})
        msg = f"[bold green]LLM Ended:[/bold green] {model_name}"
        if token_usage:
            prompt = token_usage.get("prompt_tokens", 0)
            completion = token_usage.get("completion_tokens", 0)
            total = token_usage.get("total_tokens", 0)
            msg += f" | [yellow]Tokens: {total}[/yellow] ({prompt} prompt, {completion} completion)"
        self.logger.info(msg)

    def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> Any:
        self.logger.error(f"[bold red]LLM Error:[/bold red] {str(error)}")
class ContextLoggerAdapter(logging.LoggerAdapter):
    """Adapter that injects contextual metadata into every log record."""

    def process(self, msg: Any, kwargs: Any) -> tuple[Any, Any]:
        # Start from the adapter-level context, then let any per-call
        # "extra" entries override it.
        merged = dict(self.extra)
        if "extra" in kwargs:
            merged.update(kwargs.pop("extra"))
        kwargs["extra"] = merged
        return msg, kwargs
class FlexibleJSONEncoder(json.JSONEncoder):
    """Best-effort encoder for objects json cannot serialize natively.

    Tries Pydantic dump methods first, then falls back to a shallow
    ``__dict__`` dump tagged with the original class name.
    """

    def default(self, obj: Any) -> Any:
        # json.JSONEncoder calls default() only for objects it cannot encode.
        if hasattr(obj, 'model_dump'):  # Pydantic v2
            return obj.model_dump()
        if hasattr(obj, 'dict'):  # Pydantic v1
            return obj.dict()
        if hasattr(obj, '__dict__'):
            return self.serialize_custom_object(obj)
        elif isinstance(obj, dict):
            # NOTE(review): plain dicts/lists are encoded natively and would
            # normally never reach default(); these branches look unreachable
            # in practice — confirm before relying on them.
            return {k: self.default(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [self.default(item) for item in obj]
        return super().default(obj)

    def serialize_custom_object(self, obj: Any) -> dict:
        # Shallow copy of the instance dict, tagged with the class name so
        # the original type can be identified when reading the logs.
        obj_dict = obj.__dict__.copy()
        obj_dict['__custom_class__'] = obj.__class__.__name__
        return obj_dict
class JsonFormatter(logging.Formatter):
    """Custom JSON formatter for structured logging.

    Emits one JSON object per record with the standard fields plus any
    caller-supplied ``extra`` attributes.
    """

    def format(self, record: logging.LogRecord) -> str:
        # Standard fields
        log_record = {
            "timestamp": self.formatTime(record, self.datefmt),
            "level": record.levelname,
            "message": record.getMessage(),
            "module": record.module,
            "name": record.name,
        }
        # Add exception info if present
        if record.exc_info:
            log_record["exception"] = self.formatException(record.exc_info)
        # Add all other extra fields from the record
        # Filter out standard logging attributes
        # NOTE(review): Python 3.12 adds a 'taskName' LogRecord attribute that
        # is not listed here and would therefore appear in every JSON line —
        # confirm whether that is intended.
        standard_attrs = {
            'args', 'asctime', 'created', 'exc_info', 'exc_text', 'filename',
            'funcName', 'levelname', 'levelno', 'lineno', 'module',
            'msecs', 'message', 'msg', 'name', 'pathname', 'process',
            'processName', 'relativeCreated', 'stack_info', 'thread', 'threadName'
        }
        for key, value in record.__dict__.items():
            if key not in standard_attrs:
                log_record[key] = value
        return json.dumps(log_record, cls=FlexibleJSONEncoder)
def get_logger(name: str = "ea_chatbot", level: Optional[str] = None, log_file: Optional[str] = None) -> logging.Logger:
    """Get a configured logger with RichHandler and optional Json FileHandler.

    Args:
        name: Logger name; placed under the "ea_chatbot" hierarchy if not
            already there.
        level: Optional level name ("DEBUG", "INFO", ...). Used for the root
            handler on first configuration and applied to the returned
            sub-logger when given.
        log_file: Optional path for a rotating JSON log file (added once).

    Returns:
        The configured :class:`logging.Logger`.
    """
    # Ensure name starts with ea_chatbot for hierarchy if not already
    if name != "ea_chatbot" and not name.startswith("ea_chatbot."):
        full_name = f"ea_chatbot.{name}"
    else:
        full_name = name
    logger = logging.getLogger(full_name)
    # Configure root ea_chatbot logger if it hasn't been configured
    root_logger = logging.getLogger("ea_chatbot")
    if not root_logger.handlers:
        # Default to INFO if level not provided
        log_level = getattr(logging, (level or "INFO").upper(), logging.INFO)
        root_logger.setLevel(log_level)
        # Console Handler (Rich)
        rich_handler = RichHandler(
            rich_tracebacks=True,
            markup=True,
            show_time=False,
            show_path=False
        )
        root_logger.addHandler(rich_handler)
        # Keep ea_chatbot records out of any root-logger handlers.
        root_logger.propagate = False
    # Always check if we need to add a FileHandler, even if root is already configured
    if log_file:
        existing_file_handlers = [h for h in root_logger.handlers if isinstance(h, RotatingFileHandler)]
        if not existing_file_handlers:
            # BUG FIX: os.path.dirname returns "" for a bare filename, and
            # os.makedirs("") raises — only create the directory when the
            # path actually has one.
            if log_dir := os.path.dirname(log_file):
                os.makedirs(log_dir, exist_ok=True)
            file_handler = RotatingFileHandler(
                log_file, maxBytes=5*1024*1024, backupCount=3
            )
            file_handler.setFormatter(JsonFormatter())
            root_logger.addHandler(file_handler)
    # Refresh logger object in case it was created before root was configured
    logger = logging.getLogger(full_name)
    # If level is explicitly provided for a sub-logger, set it
    if level:
        logger.setLevel(getattr(logging, level.upper(), logging.INFO))
    return logger