chore: Finalize cleanup phases (docstrings, utility consolidation, dev app isolation)

This commit is contained in:
Yunxiao Xu
2026-02-17 02:50:08 -08:00
parent 1b15a4e18c
commit 16d8e81b6b
6 changed files with 46 additions and 41 deletions

476
backend/dev_app.py Normal file
View File

@@ -0,0 +1,476 @@
import streamlit as st
import os
import io
from dotenv import load_dotenv
from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.state import AgentState
from ea_chatbot.utils.logging import get_logger
from ea_chatbot.utils.helpers import merge_agent_state
from ea_chatbot.utils.plots import fig_to_bytes
from ea_chatbot.config import Settings
from ea_chatbot.history.manager import HistoryManager
from ea_chatbot.auth import OIDCClient, AuthType, get_user_auth_type
# Populate the process environment from a local .env file. This runs before
# Settings() is constructed and before the os.getenv() lookups below.
load_dotenv()

# Application configuration and the conversation/user history store.
settings = Settings()
history_manager = HistoryManager(settings.history_db_url)

# The OIDC client only exists when the client id, client secret and server
# metadata URL are all set; otherwise it stays None and SSO is reported as
# "not configured" by the UI.
oidc_client = (
    OIDCClient(
        client_id=settings.oidc_client_id,
        client_secret=settings.oidc_client_secret,
        server_metadata_url=settings.oidc_server_metadata_url,
        # The provider redirects back to this same Streamlit page.
        redirect_uri=os.getenv("OIDC_REDIRECT_URI", "http://localhost:8501"),
    )
    if (
        settings.oidc_client_id
        and settings.oidc_client_secret
        and settings.oidc_server_metadata_url
    )
    else None
)

# Application logger, configured with the level from settings and a log file path.
logger = get_logger(level=settings.log_level, log_file="logs/app.jsonl")
# --- Authentication Helpers ---
def login_user(user):
    """Store *user* as the logged-in user, reset the chat state, and rerun.

    The message list, running summary and active conversation id are all
    reset so the new session starts from an empty chat.
    """
    session = st.session_state
    session["user"] = user
    session["messages"] = []
    session["summary"] = ""
    session["current_conversation_id"] = None
    st.rerun()
def logout_user():
    """Remove every key from the Streamlit session state and rerun the script."""
    session = st.session_state
    # Snapshot the keys first: entries are deleted while walking them.
    for name in tuple(session.keys()):
        del session[name]
    st.rerun()
def load_conversation(conv_id):
    """Load a stored conversation into the session state and rerun.

    Messages are fetched from the history manager and converted to the dict
    shape used in ``st.session_state.messages`` (``role``, ``content``,
    ``plots``). The conversation's stored summary is copied into
    ``st.session_state.summary``.

    Args:
        conv_id: Identifier of the conversation to load.
    """
    # Imported locally, as in the original code.
    from ea_chatbot.history.models import Conversation

    # Convert DB models to session state dicts.
    st.session_state.messages = [
        {
            "role": m.role,
            "content": m.content,
            "plots": [p.image_data for p in m.plots],
        }
        for m in history_manager.get_messages(conv_id)
    ]
    st.session_state.current_conversation_id = conv_id

    # Fetch the summary from the DB. The session summary is a string
    # everywhere else in this file, so a missing conversation or an unset
    # (None) summary is normalized to "".
    with history_manager.get_session() as session:
        conv = session.get(Conversation, conv_id)
        st.session_state.summary = (conv.summary or "") if conv else ""

    # Rerun only after the DB session has been closed.
    st.rerun()
def _message_text(msg):
    """Return the displayable text of a chat message as a string.

    Handles the message shapes this app can encounter:
    plain history dicts (``{"role": ..., "content": ...}``), objects whose
    ``text`` attribute is a string, objects whose ``text`` is a method, and
    objects that only expose ``content``. Anything else yields ``""``.
    """
    if isinstance(msg, dict):
        return str(msg.get("content") or "")
    text = getattr(msg, "text", None)
    # Check for str first: a string-like ``text`` must not be called.
    if isinstance(text, str):
        return text
    if callable(text):
        return str(text())
    return str(getattr(msg, "content", "") or "")


def _begin_oidc_flow():
    """Start an OIDC authorization request and return the provider URL.

    The state, nonce and PKCE code verifier are stored in the session state
    so the callback handler can validate the response.
    """
    auth_data = oidc_client.get_auth_data()
    st.session_state.oidc_state = auth_data["state"]
    st.session_state.oidc_nonce = auth_data["nonce"]
    st.session_state.oidc_verifier = auth_data["code_verifier"]
    return auth_data["url"]


def _handle_oidc_callback():
    """Complete an OIDC login when the page was opened with ``code``/``state``.

    Does nothing unless both query parameters are present and an OIDC client
    is configured. On a state mismatch the script is stopped. On success the
    user is synced into the history store and logged in (which reruns).
    """
    if not ("code" in st.query_params and "state" in st.query_params and oidc_client):
        return

    code = st.query_params["code"]
    state = st.query_params["state"]

    # Validate state against the value stored when the flow was started.
    stored_state = st.session_state.get("oidc_state")
    if not stored_state or state != stored_state:
        st.error("OIDC state mismatch. Please try again.")
        st.stop()

    try:
        # 1. Exchange code using PKCE verifier
        verifier = st.session_state.get("oidc_verifier")
        token = oidc_client.exchange_code_for_token(code, code_verifier=verifier)

        # 2. Validate ID Token using Nonce
        id_token = token.get("id_token")
        nonce = st.session_state.get("oidc_nonce")
        claims = oidc_client.validate_id_token(id_token, nonce=nonce)

        email = claims.get("email")
        name = claims.get("name") or claims.get("preferred_username")

        if email:
            user = history_manager.sync_user_from_oidc(email=email, display_name=name)
            # Clear query params and the one-time OIDC session data.
            st.query_params.clear()
            for key in ["oidc_state", "oidc_nonce", "oidc_verifier"]:
                if key in st.session_state:
                    del st.session_state[key]
            login_user(user)
        else:
            # Previously this case fell through silently to the login form.
            st.error("OIDC Login failed: the identity provider did not return an email address.")
    except Exception as e:
        st.error(f"OIDC Login failed: {str(e)}")


def _render_login_screen():
    """Render the multi-step login / registration / SSO screen."""
    st.title("🗳️ Election Analytics Chatbot")

    # Initialize Login State
    if "login_step" not in st.session_state:
        st.session_state.login_step = "email"
    if "login_email" not in st.session_state:
        st.session_state.login_email = ""

    col1, col2 = st.columns([1, 1])

    with col1:
        st.header("Login")

        # Step 1: Identification
        if st.session_state.login_step == "email":
            st.write("Please enter your email to begin:")
            with st.form("email_form"):
                email_input = st.text_input("Email", value=st.session_state.login_email)
                submitted = st.form_submit_button("Next")

                if submitted:
                    if not email_input.strip():
                        st.error("Email cannot be empty.")
                    else:
                        st.session_state.login_email = email_input.strip()
                        auth_type = get_user_auth_type(st.session_state.login_email, history_manager)

                        if auth_type == AuthType.LOCAL:
                            st.session_state.login_step = "login_password"
                        elif auth_type == AuthType.OIDC:
                            st.session_state.login_step = "oidc_login"
                        else:  # AuthType.NEW
                            st.session_state.login_step = "register_details"
                        st.rerun()

        # Step 2a: Local Login
        elif st.session_state.login_step == "login_password":
            st.info(f"Welcome back, **{st.session_state.login_email}**!")
            with st.form("password_form"):
                password = st.text_input("Password", type="password")

                col_login, col_back = st.columns([1, 1])
                submitted = col_login.form_submit_button("Login")
                back = col_back.form_submit_button("Back")

                if back:
                    st.session_state.login_step = "email"
                    st.rerun()

                if submitted:
                    user = history_manager.authenticate_user(st.session_state.login_email, password)
                    if user:
                        login_user(user)
                    else:
                        st.error("Invalid email or password")

        # Step 2b: Registration
        elif st.session_state.login_step == "register_details":
            st.info(f"Create an account for **{st.session_state.login_email}**")
            with st.form("register_form"):
                reg_name = st.text_input("Display Name")
                reg_password = st.text_input("Password", type="password")

                col_reg, col_back = st.columns([1, 1])
                submitted = col_reg.form_submit_button("Register & Login")
                back = col_back.form_submit_button("Back")

                if back:
                    st.session_state.login_step = "email"
                    st.rerun()

                if submitted:
                    if not reg_password:
                        st.error("Password is required for registration.")
                    else:
                        user = history_manager.create_user(st.session_state.login_email, reg_password, reg_name)
                        st.success("Registered! Logging in...")
                        login_user(user)

        # Step 2c: OIDC Redirection
        elif st.session_state.login_step == "oidc_login":
            st.info(f"**{st.session_state.login_email}** is configured for Single Sign-On (SSO).")

            col_sso, col_back = st.columns([1, 1])
            with col_sso:
                if oidc_client:
                    # Store PKCE/Nonce in session state for callback validation
                    sso_url = _begin_oidc_flow()
                    st.link_button("Login with SSO", sso_url, type="primary", use_container_width=True)
                else:
                    st.error("OIDC is not configured.")

            with col_back:
                if st.button("Back", use_container_width=True):
                    st.session_state.login_step = "email"
                    st.rerun()

    with col2:
        if oidc_client:
            st.header("Single Sign-On")
            st.write("Login with your organizational account.")
            if st.button("Login with SSO"):
                sso_url = _begin_oidc_flow()
                st.link_button("Go to **Login Provider**", sso_url, type="primary")
        else:
            st.info("SSO is not configured.")


def _render_sidebar(user):
    """Render the sidebar (logout, conversation history, settings).

    Args:
        user: The logged-in user; ``display_name``, ``username`` and ``id``
            are read from it.

    Returns:
        The current value of the "Dev Mode" checkbox.
    """
    with st.sidebar:
        st.title(f"Hi, {user.display_name or user.username}!")
        if st.button("Logout"):
            logout_user()

        st.divider()
        st.header("History")

        if st.button(" New Chat", use_container_width=True):
            st.session_state.messages = []
            st.session_state.summary = ""
            st.session_state.current_conversation_id = None
            st.rerun()

        # List conversations for the current user and data state
        conversations = history_manager.get_conversations(user.id, settings.data_state)
        for conv in conversations:
            col_c, col_r, col_d = st.columns([0.7, 0.15, 0.15])

            is_current = st.session_state.get("current_conversation_id") == conv.id
            label = f"💬 {conv.name}" if not is_current else f"👉 {conv.name}"

            if col_c.button(label, key=f"conv_{conv.id}", use_container_width=True):
                load_conversation(conv.id)

            if col_r.button("✏️", key=f"ren_{conv.id}"):
                st.session_state.renaming_id = conv.id

            if col_d.button("🗑️", key=f"del_{conv.id}"):
                if history_manager.delete_conversation(conv.id):
                    if is_current:
                        st.session_state.current_conversation_id = None
                        st.session_state.messages = []
                        # Reset the summary too (as "New Chat" does) so the
                        # deleted conversation's summary is not fed into the
                        # next chat.
                        st.session_state.summary = ""
                    # Drop a pending rename that targets the deleted conversation.
                    if st.session_state.get("renaming_id") == conv.id:
                        del st.session_state.renaming_id
                    st.rerun()

        # Rename dialog
        if st.session_state.get("renaming_id"):
            rid = st.session_state.renaming_id
            with st.form("rename_form"):
                new_name = st.text_input("New Name")
                if st.form_submit_button("Save"):
                    cleaned_name = new_name.strip()
                    if cleaned_name:
                        history_manager.rename_conversation(rid, cleaned_name)
                        del st.session_state.renaming_id
                        st.rerun()
                    else:
                        # Do not store a blank conversation name.
                        st.error("Conversation name cannot be empty.")
                if st.form_submit_button("Cancel"):
                    del st.session_state.renaming_id
                    st.rerun()

        st.divider()
        st.header("Settings")
        # Check for DEV_MODE env var (defaults to False)
        default_dev_mode = os.getenv("DEV_MODE", "false").lower() == "true"
        dev_mode = st.checkbox("Dev Mode", value=default_dev_mode, help="Enable to see code generation and raw reasoning steps.")

    return dev_mode


def _render_message_history(dev_mode):
    """Replay the messages stored in the session state.

    Plan and code expanders are only shown when *dev_mode* is true.
    """
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            if message.get("plan") and dev_mode:
                with st.expander("Reasoning Plan"):
                    st.code(message["plan"], language="yaml")
            if message.get("code") and dev_mode:
                with st.expander("Generated Code"):
                    st.code(message["code"], language="python")

            st.markdown(message["content"])

            if message.get("plots"):
                for plot_data in message["plots"]:
                    # If plot_data is bytes, render it as an image
                    if isinstance(plot_data, bytes):
                        st.image(plot_data)
                    else:
                        # Fallback for old session state or non-binary
                        st.pyplot(plot_data)

            if message.get("dfs"):
                for df_name, df in message["dfs"].items():
                    st.subheader(f"Data: {df_name}")
                    st.dataframe(df)


def _report_node_update(status, node_name, state_update, prev_error, dev_mode):
    """Write a progress entry to *status* for one graph node update.

    Args:
        status: The ``st.status`` container to write into.
        node_name: Name of the graph node that produced the update.
        state_update: The partial state emitted by that node.
        prev_error: The ``error`` value held in the state before this update.
        dev_mode: When true, plans, code and research messages are shown.
    """
    if node_name == "query_analyzer":
        # ``or`` guards cover keys that are present but set to None.
        analysis = state_update.get("analysis") or {}
        next_action = state_update.get("next_action") or "unknown"

        status.write("🔍 **Analyzed Query:**")
        for k, v in analysis.items():
            status.write(f"- {k:<8}: {v}")
        status.markdown(f"Next Step: {str(next_action).capitalize()}")

    elif node_name == "planner":
        status.write("📋 **Plan Generated**")
        # Render artifacts
        if state_update.get("plan") and dev_mode:
            with st.expander("Reasoning Plan", expanded=True):
                st.code(state_update["plan"], language="yaml")

    elif node_name == "researcher":
        status.write("🌐 **Research Complete**")
        if state_update.get("messages") and dev_mode:
            for msg in state_update["messages"]:
                status.markdown(_message_text(msg))

    elif node_name == "coder":
        status.write("💻 **Code Generated**")
        if state_update.get("code") and dev_mode:
            with st.expander("Generated Code"):
                st.code(state_update["code"], language="python")

    elif node_name == "error_corrector":
        status.write("🛠️ **Fixing Execution Error...**")
        if prev_error:
            truncated_error = str(prev_error).strip()
            if len(truncated_error) > 180:
                truncated_error = truncated_error[:180] + "..."
            status.write(f"Previous error: {truncated_error}")
        if state_update.get("code") and dev_mode:
            with st.expander("Corrected Code"):
                st.code(state_update["code"], language="python")

    elif node_name == "executor":
        error = state_update.get("error")
        if error:
            # Outside dev mode only the first 100 characters are shown.
            shown_error = error if dev_mode else error[:100]
            status.write(f"❌ **Execution Error:** {shown_error}...")
        else:
            status.write("✅ **Execution Successful**")
            if state_update.get("plots"):
                status.write(f"📊 Generated {len(state_update['plots'])} plot(s)")

    elif node_name == "summarizer":
        status.write("📝 **Summarizing Results...**")


def _handle_prompt(prompt, user, dev_mode):
    """Run the agent graph for *prompt*, render the result and persist it.

    Creates a conversation if none is active, stores the user message,
    streams the graph while reporting node progress, then renders and saves
    the assistant reply, plots, dataframes and summary.
    """
    # Ensure we have a conversation ID
    if not st.session_state.get("current_conversation_id"):
        # Auto-create conversation, named after the first 30 prompt characters
        conv_name = (prompt[:30] + '...') if len(prompt) > 30 else prompt
        conv = history_manager.create_conversation(user.id, settings.data_state, conv_name)
        st.session_state.current_conversation_id = conv.id

    conv_id = st.session_state.current_conversation_id

    # Save user message to DB
    history_manager.add_message(conv_id, "user", prompt)

    # Add user message to session state
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(prompt)

    # Prepare graph input
    initial_state: AgentState = {
        "messages": st.session_state.messages[:-1],  # Pass history (excluding the current prompt)
        "question": prompt,
        "summary": st.session_state.summary,
        "analysis": None,
        "next_action": "",
        "plan": None,
        "code": None,
        "code_output": None,
        "error": None,
        "plots": [],
        "dfs": {},
        "iterations": 0
    }

    with st.chat_message("assistant"):
        final_state = initial_state

        # Real-time node updates
        with st.status("Thinking...", expanded=True) as status:
            try:
                # Use app.stream to capture node transitions
                for event in app.stream(initial_state):
                    for node_name, state_update in event.items():
                        prev_error = final_state.get("error")
                        # Use helper to merge state correctly (appending messages/plots, updating dfs)
                        final_state = merge_agent_state(final_state, state_update)
                        _report_node_update(status, node_name, state_update, prev_error, dev_mode)

                status.update(label="Complete!", state="complete", expanded=False)

            except Exception as e:
                status.update(label="Error!", state="error")
                st.error(f"Error during graph execution: {str(e)}")

        # Extract results
        response_text: str = ""
        messages = final_state.get("messages")
        # History entries are plain dicts. If the last message is still one,
        # the graph added no reply (e.g. it failed), so there is nothing to
        # show. Reading ``.content`` on such a dict used to raise
        # AttributeError here.
        if messages and not isinstance(messages[-1], dict):
            response_text = _message_text(messages[-1])

        if response_text:
            st.markdown(response_text)

        # Collect plot bytes for saving to DB
        plot_bytes_list = []
        for fig in final_state.get("plots") or []:
            st.pyplot(fig)
            # Convert fig to bytes
            plot_bytes_list.append(fig_to_bytes(fig))

        for df_name, df in (final_state.get("dfs") or {}).items():
            st.subheader(f"Data: {df_name}")
            st.dataframe(df)

    # Save assistant message to DB
    history_manager.add_message(conv_id, "assistant", response_text, plots=plot_bytes_list)

    # Update summary in DB; a None summary is normalized to "" so the session
    # summary always stays a string.
    new_summary = final_state.get("summary") or ""
    if new_summary:
        history_manager.update_conversation_summary(conv_id, new_summary)

    # Store assistant response in session history
    st.session_state.messages.append({
        "role": "assistant",
        "content": response_text,
        "plan": final_state.get("plan"),
        "code": final_state.get("code"),
        "plots": plot_bytes_list,
        "dfs": final_state.get("dfs")
    })
    st.session_state.summary = new_summary


def main():
    """Entry point of the Streamlit app.

    Configures the page, completes a pending OIDC callback, shows the login
    screen when no user is in the session, and otherwise renders the sidebar,
    the chat history and handles a newly submitted prompt.
    """
    st.set_page_config(
        page_title="Election Analytics Chatbot",
        page_icon="🗳️",
        layout="wide"
    )

    # Check for OIDC Callback
    _handle_oidc_callback()

    # Display Login Screen if not authenticated
    if "user" not in st.session_state:
        _render_login_screen()
        st.stop()

    # --- Main App (Authenticated) ---
    user = st.session_state.user
    dev_mode = _render_sidebar(user)

    st.title("🗳️ Election Analytics Chatbot")

    # Initialize chat history state
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "summary" not in st.session_state:
        st.session_state.summary = ""

    # Display chat messages from history on app rerun
    _render_message_history(dev_mode)

    # Accept user input
    if prompt := st.chat_input("Ask a question about election data..."):
        _handle_prompt(prompt, user, dev_mode)
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()