Files
ea-chatbot-lg/backend/dev_app.py

477 lines
21 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import streamlit as st
import os
import io
from dotenv import load_dotenv
from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.state import AgentState
from ea_chatbot.utils.logging import get_logger
from ea_chatbot.utils.helpers import merge_agent_state
from ea_chatbot.utils.plots import fig_to_bytes
from ea_chatbot.config import Settings
from ea_chatbot.history.manager import HistoryManager
from ea_chatbot.auth import OIDCClient, AuthType, get_user_auth_type
# Populate os.environ (e.g. from a local .env file) before any configuration is read below.
load_dotenv()

# Shared configuration and the chat-history persistence layer used by every page render.
settings = Settings()
history_manager = HistoryManager(settings.history_db_url)

# Single sign-on is enabled only when client id, client secret and metadata URL are all set;
# otherwise oidc_client stays None and the login screen reports that SSO is not configured.
oidc_client = (
    OIDCClient(
        client_id=settings.oidc_client_id,
        client_secret=settings.oidc_client_secret,
        server_metadata_url=settings.oidc_server_metadata_url,
        # Where the identity provider sends the browser back to; overridable via OIDC_REDIRECT_URI.
        redirect_uri=os.getenv("OIDC_REDIRECT_URI", "http://localhost:8501"),
    )
    if settings.oidc_client_id and settings.oidc_client_secret and settings.oidc_server_metadata_url
    else None
)

# Application logger, configured with settings.log_level and the log file logs/app.jsonl.
logger = get_logger(level=settings.log_level, log_file="logs/app.jsonl")
# --- Authentication Helpers ---
def login_user(user):
    """Mark *user* as authenticated for this session, start an empty chat, and rerun.

    Args:
        user: The account object returned by the history manager.
    """
    st.session_state.user = user
    # A freshly logged-in user always starts with no open conversation.
    fresh_chat_state = {
        "messages": [],
        "summary": "",
        "current_conversation_id": None,
    }
    for state_key, initial_value in fresh_chat_state.items():
        st.session_state[state_key] = initial_value
    st.rerun()
def logout_user():
    """Log out by removing every key from this browser session's state, then rerun."""
    # Snapshot the keys first: deleting from the live mapping while iterating it is unsafe.
    stale_keys = tuple(st.session_state.keys())
    for stale_key in stale_keys:
        del st.session_state[stale_key]
    st.rerun()
def load_conversation(conv_id):
    """Make a stored conversation the current one and rerun to display it.

    Replaces ``st.session_state.messages`` with the conversation's messages as
    ``{"role", "content", "plots"}`` dicts (``plots`` holds each stored plot's
    ``image_data``), sets ``current_conversation_id`` and restores the
    conversation summary. Always ends with ``st.rerun()``.

    Args:
        conv_id: Primary key of the conversation to load.
    """
    # Convert DB models to the dict shape the chat view renders from session state.
    st.session_state.messages = [
        {
            "role": m.role,
            "content": m.content,
            "plots": [p.image_data for p in m.plots],
        }
        for m in history_manager.get_messages(conv_id)
    ]
    st.session_state.current_conversation_id = conv_id

    from ea_chatbot.history.models import Conversation

    # Fetch summary from DB
    with history_manager.get_session() as session:
        conv = session.get(Conversation, conv_id)
        # The summary column may be NULL (assumption: nullable column — confirm in the model);
        # normalise it so session state and the graph input always carry a string.
        st.session_state.summary = (conv.summary or "") if conv else ""
    st.rerun()
# Session-state keys that hold this session's in-flight OIDC handshake.
_OIDC_SESSION_KEYS = ("oidc_state", "oidc_nonce", "oidc_verifier", "oidc_auth_url")
# Outside dev mode, executor errors longer than this are shortened in the status panel.
_ERROR_PREVIEW_CHARS = 100
# The previous error shown while the error corrector runs is shortened to this length.
_PREV_ERROR_PREVIEW_CHARS = 180


def _truncate(text, limit):
    """Return *text* if it has at most *limit* characters, else its first *limit* chars + "..."."""
    return text if len(text) <= limit else text[:limit] + "..."


def _message_text(msg):
    """Return the displayable text of a message emitted by the graph.

    Accepts message objects exposing ``text`` (a string, or a method) or only
    ``content``, the ``{"role", "content"}`` dicts kept in session state, and
    anything else (rendered with ``str``). ``msg.content`` is only read when no
    usable ``text`` exists, so plain strings do not raise AttributeError.
    """
    if isinstance(msg, dict):
        return str(msg.get("content", ""))
    text = getattr(msg, "text", None)
    if isinstance(text, str):
        return str(text)
    if callable(text):
        # Presumably an older message API where ``text`` is a method — confirm.
        return str(text())
    return str(getattr(msg, "content", msg))


def _reset_chat():
    """Start a new, not-yet-saved conversation: clear messages, summary and current id."""
    st.session_state.messages = []
    st.session_state.summary = ""
    st.session_state.current_conversation_id = None


def _get_oidc_auth_url():
    """Return this session's IdP authorization URL, creating the handshake on first use.

    ``state``, ``nonce`` and the PKCE ``code_verifier`` are stored in
    ``st.session_state`` for validation by ``_handle_oidc_callback``. They are
    generated once per session rather than on every rerun, so every SSO link
    rendered in the session carries the same ``state`` that will be checked.
    Only call this when ``oidc_client`` is configured.
    """
    if "oidc_auth_url" not in st.session_state:
        auth_data = oidc_client.get_auth_data()
        st.session_state.oidc_state = auth_data["state"]
        st.session_state.oidc_nonce = auth_data["nonce"]
        st.session_state.oidc_verifier = auth_data["code_verifier"]
        st.session_state.oidc_auth_url = auth_data["url"]
    return st.session_state.oidc_auth_url


def _handle_oidc_callback():
    """Finish an OIDC login if the IdP redirected back with ``code`` and ``state``.

    Does nothing unless both query parameters are present and an OIDC client is
    configured. On success the user is synced and logged in (``login_user``
    reruns the script). On any failure an error is shown and the query
    parameters are cleared, so the next rerun shows the login screen instead of
    replaying the failed callback.
    """
    params = st.query_params
    if not (oidc_client and "code" in params and "state" in params):
        return
    code = params["code"]
    state = params["state"]

    # NOTE(review): st.session_state is per browser session; a full-page redirect to the
    # IdP and back may start a new session without these keys — confirm the flow keeps it.
    stored_state = st.session_state.get("oidc_state")
    if not stored_state or state != stored_state:
        st.query_params.clear()
        st.error("OIDC state mismatch. Please try again.")
        st.stop()

    try:
        # 1. Exchange code using PKCE verifier
        token = oidc_client.exchange_code_for_token(
            code, code_verifier=st.session_state.get("oidc_verifier")
        )
        # 2. Validate ID Token using Nonce
        claims = oidc_client.validate_id_token(
            token.get("id_token"), nonce=st.session_state.get("oidc_nonce")
        )
        email = claims.get("email")
        name = claims.get("name") or claims.get("preferred_username")
        user = (
            history_manager.sync_user_from_oidc(email=email, display_name=name)
            if email
            else None
        )
    except Exception as e:
        logger.exception("OIDC login failed")
        st.query_params.clear()
        st.error(f"OIDC Login failed: {str(e)}")
        return

    st.query_params.clear()
    if not email:
        st.error("OIDC Login failed: the identity provider did not return an email address.")
        return
    for key in _OIDC_SESSION_KEYS:
        if key in st.session_state:
            del st.session_state[key]
    # Outside the try block: login_user() calls st.rerun(), which must not be swallowed.
    login_user(user)


def _render_email_step():
    """Login step 1: ask for the email and route to password login, registration or SSO."""
    st.write("Please enter your email to begin:")
    with st.form("email_form"):
        email_input = st.text_input("Email", value=st.session_state.login_email)
        submitted = st.form_submit_button("Next")
    if not submitted:
        return
    email = email_input.strip()
    if not email:
        st.error("Email cannot be empty.")
        return
    st.session_state.login_email = email
    auth_type = get_user_auth_type(email, history_manager)
    if auth_type == AuthType.LOCAL:
        st.session_state.login_step = "login_password"
    elif auth_type == AuthType.OIDC:
        st.session_state.login_step = "oidc_login"
    else:  # AuthType.NEW
        st.session_state.login_step = "register_details"
    st.rerun()


def _render_password_step():
    """Login step 2a: password login for an existing local account."""
    st.info(f"Welcome back, **{st.session_state.login_email}**!")
    with st.form("password_form"):
        password = st.text_input("Password", type="password")
        col_login, col_back = st.columns([1, 1])
        submitted = col_login.form_submit_button("Login")
        back = col_back.form_submit_button("Back")
    if back:
        st.session_state.login_step = "email"
        st.rerun()
    if submitted:
        user = history_manager.authenticate_user(st.session_state.login_email, password)
        if user:
            login_user(user)
        else:
            st.error("Invalid email or password")


def _render_register_step():
    """Login step 2b: create a local account for an unknown email, then log in."""
    st.info(f"Create an account for **{st.session_state.login_email}**")
    with st.form("register_form"):
        reg_name = st.text_input("Display Name")
        reg_password = st.text_input("Password", type="password")
        col_reg, col_back = st.columns([1, 1])
        submitted = col_reg.form_submit_button("Register & Login")
        back = col_back.form_submit_button("Back")
    if back:
        st.session_state.login_step = "email"
        st.rerun()
    if not submitted:
        return
    if not reg_password:
        st.error("Password is required for registration.")
        return
    user = history_manager.create_user(st.session_state.login_email, reg_password, reg_name)
    if not user:
        # Guard against a falsy result rather than crashing later on user attributes.
        st.error("Registration failed. Please try again.")
        return
    st.success("Registered! Logging in...")
    login_user(user)


def _render_oidc_step():
    """Login step 2c: the email belongs to an SSO account; offer the IdP link."""
    st.info(f"**{st.session_state.login_email}** is configured for Single Sign-On (SSO).")
    col_sso, col_back = st.columns([1, 1])
    with col_sso:
        if oidc_client:
            st.link_button(
                "Login with SSO", _get_oidc_auth_url(), type="primary", use_container_width=True
            )
        else:
            st.error("OIDC is not configured.")
    with col_back:
        if st.button("Back", use_container_width=True):
            st.session_state.login_step = "email"
            st.rerun()


def _render_sso_panel():
    """Right-hand login column: a direct SSO entry point when OIDC is configured."""
    if not oidc_client:
        st.info("SSO is not configured.")
        return
    st.header("Single Sign-On")
    st.write("Login with your organizational account.")
    if st.button("Login with SSO"):
        # Reuses the session's cached handshake, so it cannot invalidate other SSO links.
        st.link_button("Go to **Login Provider**", _get_oidc_auth_url(), type="primary")


def _render_login():
    """Render the multi-step login/registration screen for an unauthenticated visitor."""
    st.title("🗳️ Election Analytics Chatbot")
    if "login_step" not in st.session_state:
        st.session_state.login_step = "email"
    if "login_email" not in st.session_state:
        st.session_state.login_email = ""
    col1, col2 = st.columns([1, 1])
    with col1:
        st.header("Login")
        step = st.session_state.login_step
        if step == "email":
            _render_email_step()
        elif step == "login_password":
            _render_password_step()
        elif step == "register_details":
            _render_register_step()
        elif step == "oidc_login":
            _render_oidc_step()
    with col2:
        _render_sso_panel()


def _render_conversation_row(conv):
    """Render one history entry with open / rename / delete buttons (inside the sidebar)."""
    col_c, col_r, col_d = st.columns([0.7, 0.15, 0.15])
    is_current = st.session_state.get("current_conversation_id") == conv.id
    label = f"👉 {conv.name}" if is_current else f"💬 {conv.name}"
    if col_c.button(label, key=f"conv_{conv.id}", use_container_width=True):
        load_conversation(conv.id)
    if col_r.button("✏️", key=f"ren_{conv.id}"):
        st.session_state.renaming_id = conv.id
    if col_d.button("🗑️", key=f"del_{conv.id}"):
        if history_manager.delete_conversation(conv.id):
            if is_current:
                # Same reset as "New Chat" so the deleted chat's summary cannot leak forward.
                _reset_chat()
            if st.session_state.get("renaming_id") == conv.id:
                del st.session_state.renaming_id
            st.rerun()
        st.error("Failed to delete conversation.")


def _render_rename_form():
    """Render the rename form while a conversation is being renamed."""
    rid = st.session_state.get("renaming_id")
    if not rid:
        return
    with st.form("rename_form"):
        new_name = st.text_input("New Name")
        save = st.form_submit_button("Save")
        cancel = st.form_submit_button("Cancel")
    if save:
        if new_name.strip():
            history_manager.rename_conversation(rid, new_name.strip())
            del st.session_state.renaming_id
            st.rerun()
        st.error("Name cannot be empty.")
    if cancel:
        del st.session_state.renaming_id
        st.rerun()


def _render_sidebar(user):
    """Render the sidebar (account, conversation history, settings); return the dev-mode flag."""
    with st.sidebar:
        st.title(f"Hi, {user.display_name or user.username}!")
        if st.button("Logout"):
            logout_user()
        st.divider()
        st.header("History")
        if st.button(" New Chat", use_container_width=True):
            _reset_chat()
            st.rerun()
        # List conversations for the current user and data state
        for conv in history_manager.get_conversations(user.id, settings.data_state):
            _render_conversation_row(conv)
        _render_rename_form()
        st.divider()
        st.header("Settings")
        # DEV_MODE env var only sets the checkbox's initial value (defaults to False).
        default_dev_mode = os.getenv("DEV_MODE", "false").lower() == "true"
        return st.checkbox(
            "Dev Mode",
            value=default_dev_mode,
            help="Enable to see code generation and raw reasoning steps.",
        )


def _render_stored_message(message, dev_mode):
    """Render one entry of ``st.session_state.messages`` as a chat bubble."""
    with st.chat_message(message["role"]):
        if message.get("plan") and dev_mode:
            with st.expander("Reasoning Plan"):
                st.code(message["plan"], language="yaml")
        if message.get("code") and dev_mode:
            with st.expander("Generated Code"):
                st.code(message["code"], language="python")
        st.markdown(message["content"])
        for plot_data in message.get("plots") or []:
            if isinstance(plot_data, bytes):
                st.image(plot_data)
            else:
                # Fallback for old session state or non-binary plot objects
                st.pyplot(plot_data)
        for df_name, df in (message.get("dfs") or {}).items():
            st.subheader(f"Data: {df_name}")
            st.dataframe(df)


def _render_node_update(status, node_name, state_update, prev_error, dev_mode):
    """Report one streamed graph-node update in the status panel.

    Args:
        status: The ``st.status`` container progress lines are written to.
        node_name: Name of the graph node that produced ``state_update``.
        state_update: The partial state the node returned.
        prev_error: The merged state's ``error`` value before this update.
        dev_mode: Whether to also show plans, code and full error text.
    """
    # NOTE(review): this runs inside `with st.status(...)`; st.status may count as an
    # expander for Streamlit's no-nested-expanders rule — confirm st.expander works here.
    if node_name == "query_analyzer":
        analysis = state_update.get("analysis") or {}
        next_action = state_update.get("next_action") or "unknown"
        status.write("🔍 **Analyzed Query:**")
        for key, value in analysis.items():
            status.write(f"- {key:<8}: {value}")
        status.markdown(f"Next Step: {str(next_action).capitalize()}")
    elif node_name == "planner":
        status.write("📋 **Plan Generated**")
        if state_update.get("plan") and dev_mode:
            with st.expander("Reasoning Plan", expanded=True):
                st.code(state_update["plan"], language="yaml")
    elif node_name == "researcher":
        status.write("🌐 **Research Complete**")
        if state_update.get("messages") and dev_mode:
            for msg in state_update["messages"]:
                status.markdown(_message_text(msg))
    elif node_name == "coder":
        status.write("💻 **Code Generated**")
        if state_update.get("code") and dev_mode:
            with st.expander("Generated Code"):
                st.code(state_update["code"], language="python")
    elif node_name == "error_corrector":
        status.write("🛠️ **Fixing Execution Error...**")
        if prev_error:
            excerpt = _truncate(str(prev_error).strip(), _PREV_ERROR_PREVIEW_CHARS)
            status.write(f"Previous error: {excerpt}")
        if state_update.get("code") and dev_mode:
            with st.expander("Corrected Code"):
                st.code(state_update["code"], language="python")
    elif node_name == "executor":
        error = state_update.get("error")
        if error:
            # "..." only appears when the text was actually shortened.
            detail = str(error) if dev_mode else _truncate(str(error), _ERROR_PREVIEW_CHARS)
            status.write(f"❌ **Execution Error:** {detail}")
        else:
            status.write("✅ **Execution Successful**")
        if state_update.get("plots"):
            status.write(f"📊 Generated {len(state_update['plots'])} plot(s)")
    elif node_name == "summarizer":
        status.write("📝 **Summarizing Results...**")


def _run_graph(initial_state, dev_mode):
    """Stream the agent graph from *initial_state*, reporting progress in a status panel.

    Returns:
        ``(final_state, reply)``: ``final_state`` is ``initial_state`` with every
        node update folded in via ``merge_agent_state``; ``reply`` is the last
        message any node emitted during this run, or ``None`` if none did (for
        example when the graph failed early). Graph errors are logged and shown,
        not raised.
    """
    final_state = initial_state
    reply = None
    with st.status("Thinking...", expanded=True) as status:
        try:
            for event in app.stream(initial_state):
                for node_name, state_update in event.items():
                    # A node may emit no update (None); treat that as an empty update.
                    state_update = state_update or {}
                    prev_error = final_state.get("error")
                    # Merge helper appends messages/plots and updates dfs.
                    final_state = merge_agent_state(final_state, state_update)
                    if state_update.get("messages"):
                        reply = state_update["messages"][-1]
                    _render_node_update(status, node_name, state_update, prev_error, dev_mode)
            status.update(label="Complete!", state="complete", expanded=False)
        except Exception as e:
            logger.exception("Error during graph execution")
            status.update(label="Error!", state="error")
            st.error(f"Error during graph execution: {str(e)}")
    return final_state, reply


def _handle_prompt(prompt, user, dev_mode):
    """Answer one prompt: persist it, run the graph, then render and persist the reply.

    Creates a conversation (named after the first 30 characters of the prompt)
    when none is current.
    """
    if not st.session_state.get("current_conversation_id"):
        conv_name = (prompt[:30] + "...") if len(prompt) > 30 else prompt
        conv = history_manager.create_conversation(user.id, settings.data_state, conv_name)
        st.session_state.current_conversation_id = conv.id
    conv_id = st.session_state.current_conversation_id

    history_manager.add_message(conv_id, "user", prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    initial_state: AgentState = {
        "messages": st.session_state.messages[:-1],  # History, excluding the current prompt
        "question": prompt,
        "summary": st.session_state.summary,
        "analysis": None,
        "next_action": "",
        "plan": None,
        "code": None,
        "code_output": None,
        "error": None,
        "plots": [],
        "dfs": {},
        "iterations": 0,
    }

    with st.chat_message("assistant"):
        final_state, reply = _run_graph(initial_state, dev_mode)
        # Only a message emitted during this run is the answer; final_state["messages"][-1]
        # could be a stale history entry (a session-state dict) when no node produced one.
        response_text = "" if reply is None else _message_text(reply)
        st.markdown(response_text)
        # Collect plot bytes for saving to DB
        plot_bytes_list = []
        for fig in final_state.get("plots") or []:
            st.pyplot(fig)
            plot_bytes_list.append(fig_to_bytes(fig))
        dfs = final_state.get("dfs")
        for df_name, df in (dfs or {}).items():
            st.subheader(f"Data: {df_name}")
            st.dataframe(df)

    history_manager.add_message(conv_id, "assistant", response_text, plots=plot_bytes_list)
    new_summary = final_state.get("summary") or ""
    if new_summary:
        history_manager.update_conversation_summary(conv_id, new_summary)
    st.session_state.messages.append({
        "role": "assistant",
        "content": response_text,
        "plan": final_state.get("plan"),
        "code": final_state.get("code"),
        "plots": plot_bytes_list,
        "dfs": dfs,
    })
    st.session_state.summary = new_summary


def main():
    """Streamlit entry point: finish any OIDC callback, authenticate, then run the chat UI."""
    st.set_page_config(
        page_title="Election Analytics Chatbot",
        page_icon="🗳️",
        layout="wide",
    )
    _handle_oidc_callback()
    if "user" not in st.session_state:
        _render_login()
        st.stop()

    # --- Main App (Authenticated) ---
    user = st.session_state.user
    dev_mode = _render_sidebar(user)
    st.title("🗳️ Election Analytics Chatbot")
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "summary" not in st.session_state:
        st.session_state.summary = ""
    # Streamlit reruns the whole script on every interaction, so replay the history.
    for message in st.session_state.messages:
        _render_stored_message(message, dev_mode)
    if prompt := st.chat_input("Ask a question about election data..."):
        _handle_prompt(prompt, user, dev_mode)


if __name__ == "__main__":
    main()