477 lines
21 KiB
Python
477 lines
21 KiB
Python
import streamlit as st
|
||
import os
|
||
import io
|
||
from dotenv import load_dotenv
|
||
from ea_chatbot.graph.workflow import app
|
||
from ea_chatbot.graph.state import AgentState
|
||
from ea_chatbot.utils.logging import get_logger
|
||
from ea_chatbot.utils.helpers import merge_agent_state
|
||
from ea_chatbot.utils.plots import fig_to_bytes
|
||
from ea_chatbot.config import Settings
|
||
from ea_chatbot.history.manager import HistoryManager
|
||
from ea_chatbot.auth import OIDCClient, AuthType, get_user_auth_type
|
||
|
||
# Load environment variables (python-dotenv) before any setting is read.
load_dotenv()

# Configuration object and the history manager built from its DB URL.
settings = Settings()
history_manager = HistoryManager(settings.history_db_url)

# The OIDC client is created only when the client id, client secret and
# server metadata URL are all set; otherwise it stays None and the UI
# treats SSO as "not configured".
oidc_client = (
    OIDCClient(
        client_id=settings.oidc_client_id,
        client_secret=settings.oidc_client_secret,
        server_metadata_url=settings.oidc_server_metadata_url,
        # Redirect back to the same page after the provider round trip.
        redirect_uri=os.getenv("OIDC_REDIRECT_URI", "http://localhost:8501"),
    )
    if (
        settings.oidc_client_id
        and settings.oidc_client_secret
        and settings.oidc_server_metadata_url
    )
    else None
)

# Module logger, configured with the settings' log level and a log file path.
logger = get_logger(level=settings.log_level, log_file="logs/app.jsonl")
|
||
|
||
# --- Authentication Helpers ---
|
||
|
||
def login_user(user):
    """Record ``user`` as the signed-in user and start from a blank chat.

    Stores the user in Streamlit session state, resets the message list,
    the running summary and the active conversation id, then triggers a
    script rerun.
    """
    fresh_session = {
        "user": user,
        "messages": [],
        "summary": "",
        "current_conversation_id": None,
    }
    for state_key, state_value in fresh_session.items():
        st.session_state[state_key] = state_value
    st.rerun()
|
||
|
||
def logout_user():
    """Drop every entry from Streamlit session state and rerun the script."""
    # Snapshot the keys first: entries are deleted while we walk over them.
    stale_keys = tuple(st.session_state.keys())
    for stale_key in stale_keys:
        del st.session_state[stale_key]
    st.rerun()
|
||
|
||
def load_conversation(conv_id):
    """Make the stored conversation ``conv_id`` the active chat and rerun.

    Reads the conversation's messages through ``history_manager``, stores
    them in session state as ``{"role", "content", "plots"}`` dicts, marks
    the conversation as current, copies its summary (empty string when the
    conversation row is not found) and triggers a script rerun.
    """
    # Convert DB models to the dict shape kept in session state.
    st.session_state.messages = [
        {
            "role": record.role,
            "content": record.content,
            "plots": [plot.image_data for plot in record.plots],
        }
        for record in history_manager.get_messages(conv_id)
    ]
    st.session_state.current_conversation_id = conv_id

    # Imported lazily, right where it is needed, as in the original code.
    from ea_chatbot.history.models import Conversation

    # Fetch the summary from the DB.
    with history_manager.get_session() as db_session:
        conversation = db_session.get(Conversation, conv_id)
        if conversation:
            st.session_state.summary = conversation.summary
        else:
            st.session_state.summary = ""

    st.rerun()
|
||
|
||
def _store_oidc_auth_data(auth_data):
    """Keep the OIDC state, nonce and PKCE verifier for callback validation."""
    st.session_state.oidc_state = auth_data["state"]
    st.session_state.oidc_nonce = auth_data["nonce"]
    st.session_state.oidc_verifier = auth_data["code_verifier"]


def _reset_chat_state():
    """Clear the active conversation, its messages and its summary."""
    st.session_state.messages = []
    st.session_state.summary = ""
    st.session_state.current_conversation_id = None


def _message_text(msg):
    """Return the display text of a chat message as a string.

    Accepts the plain ``{"role", "content"}`` dicts kept in session state,
    bare strings, and message objects exposing ``text`` and/or ``content``.
    """
    if isinstance(msg, str):
        return msg
    if isinstance(msg, dict):
        return str(msg.get("content", ""))
    text = getattr(msg, "text", None)
    if isinstance(text, str):
        return text
    # NOTE(review): some message classes may expose ``text`` as a method
    # rather than an attribute - confirm against the installed version.
    if callable(text):
        return str(text())
    return str(getattr(msg, "content", ""))


def _handle_oidc_callback():
    """Complete an OIDC login when the URL carries ``code`` and ``state``.

    Does nothing unless both query parameters are present and an OIDC client
    is configured. Stops the script on a state mismatch; on success the user
    is logged in (which reruns the script). Failures are shown with
    ``st.error`` and the caller continues to the login screen.
    """
    if not ("code" in st.query_params and "state" in st.query_params and oidc_client):
        return

    code = st.query_params["code"]
    state = st.query_params["state"]

    # Validate state against the value stored before the redirect.
    stored_state = st.session_state.get("oidc_state")
    if not stored_state or state != stored_state:
        st.error("OIDC state mismatch. Please try again.")
        st.stop()

    try:
        # 1. Exchange code using PKCE verifier
        verifier = st.session_state.get("oidc_verifier")
        token = oidc_client.exchange_code_for_token(code, code_verifier=verifier)

        # 2. Validate ID Token using Nonce
        id_token = token.get("id_token")
        nonce = st.session_state.get("oidc_nonce")
        claims = oidc_client.validate_id_token(id_token, nonce=nonce)

        email = claims.get("email")
        name = claims.get("name") or claims.get("preferred_username")

        if email:
            user = history_manager.sync_user_from_oidc(email=email, display_name=name)
            # Clear query params and one-shot OIDC session data
            st.query_params.clear()
            for key in ["oidc_state", "oidc_nonce", "oidc_verifier"]:
                if key in st.session_state:
                    del st.session_state[key]
            login_user(user)
        else:
            # Without an email there is no account to map the login to;
            # tell the user instead of silently showing the login form again.
            st.error("OIDC Login failed: the identity provider did not return an email address.")
    except Exception as e:
        st.error(f"OIDC Login failed: {str(e)}")


def _render_email_step():
    """Login step 1: ask for the email and route to the matching next step."""
    st.write("Please enter your email to begin:")
    with st.form("email_form"):
        email_input = st.text_input("Email", value=st.session_state.login_email)
        submitted = st.form_submit_button("Next")

        if submitted:
            if not email_input.strip():
                st.error("Email cannot be empty.")
            else:
                st.session_state.login_email = email_input.strip()
                auth_type = get_user_auth_type(st.session_state.login_email, history_manager)

                if auth_type == AuthType.LOCAL:
                    st.session_state.login_step = "login_password"
                elif auth_type == AuthType.OIDC:
                    st.session_state.login_step = "oidc_login"
                else:  # AuthType.NEW
                    st.session_state.login_step = "register_details"
                st.rerun()


def _render_password_step():
    """Login step 2a: password form for an existing local account."""
    st.info(f"Welcome back, **{st.session_state.login_email}**!")
    with st.form("password_form"):
        password = st.text_input("Password", type="password")

        col_login, col_back = st.columns([1, 1])
        submitted = col_login.form_submit_button("Login")
        back = col_back.form_submit_button("Back")

        if back:
            st.session_state.login_step = "email"
            st.rerun()

        if submitted:
            user = history_manager.authenticate_user(st.session_state.login_email, password)
            if user:
                login_user(user)
            else:
                st.error("Invalid email or password")


def _render_register_step():
    """Login step 2b: registration form for an email with no account yet."""
    st.info(f"Create an account for **{st.session_state.login_email}**")
    with st.form("register_form"):
        reg_name = st.text_input("Display Name")
        reg_password = st.text_input("Password", type="password")

        col_reg, col_back = st.columns([1, 1])
        submitted = col_reg.form_submit_button("Register & Login")
        back = col_back.form_submit_button("Back")

        if back:
            st.session_state.login_step = "email"
            st.rerun()

        if submitted:
            if not reg_password:
                st.error("Password is required for registration.")
            else:
                user = history_manager.create_user(st.session_state.login_email, reg_password, reg_name)
                st.success("Registered! Logging in...")
                login_user(user)


def _render_oidc_step():
    """Login step 2c: send an SSO-configured account to the identity provider."""
    st.info(f"**{st.session_state.login_email}** is configured for Single Sign-On (SSO).")

    col_sso, col_back = st.columns([1, 1])

    with col_sso:
        if oidc_client:
            auth_data = oidc_client.get_auth_data()
            # Store PKCE/Nonce in session state for callback validation
            _store_oidc_auth_data(auth_data)
            st.link_button("Login with SSO", auth_data["url"], type="primary", use_container_width=True)
        else:
            st.error("OIDC is not configured.")

    with col_back:
        if st.button("Back", use_container_width=True):
            st.session_state.login_step = "email"
            st.rerun()


def _render_login_screen():
    """Draw the two-column login page for an unauthenticated visitor."""
    st.title("🗳️ Election Analytics Chatbot")

    # Initialize Login State
    if "login_step" not in st.session_state:
        st.session_state.login_step = "email"
    if "login_email" not in st.session_state:
        st.session_state.login_email = ""

    col1, col2 = st.columns([1, 1])

    with col1:
        st.header("Login")
        step = st.session_state.login_step
        if step == "email":
            _render_email_step()
        elif step == "login_password":
            _render_password_step()
        elif step == "register_details":
            _render_register_step()
        elif step == "oidc_login":
            _render_oidc_step()

    with col2:
        if oidc_client:
            st.header("Single Sign-On")
            st.write("Login with your organizational account.")
            if st.button("Login with SSO"):
                auth_data = oidc_client.get_auth_data()
                _store_oidc_auth_data(auth_data)
                st.link_button("Go to **Login Provider**", auth_data["url"], type="primary")
        else:
            st.info("SSO is not configured.")


def _render_conversation_list(user):
    """List the user's conversations with open / rename / delete buttons."""
    # Conversations for the current user and data state
    conversations = history_manager.get_conversations(user.id, settings.data_state)

    for conv in conversations:
        col_c, col_r, col_d = st.columns([0.7, 0.15, 0.15])

        is_current = st.session_state.get("current_conversation_id") == conv.id
        label = f"💬 {conv.name}" if not is_current else f"👉 {conv.name}"

        if col_c.button(label, key=f"conv_{conv.id}", use_container_width=True):
            load_conversation(conv.id)

        if col_r.button("✏️", key=f"ren_{conv.id}"):
            st.session_state.renaming_id = conv.id

        if col_d.button("🗑️", key=f"del_{conv.id}"):
            if history_manager.delete_conversation(conv.id):
                if is_current:
                    # Also drops the summary so the next chat does not
                    # inherit context from the deleted conversation.
                    _reset_chat_state()
                st.rerun()


def _render_rename_form():
    """Show the rename form while a conversation is marked for renaming."""
    if not st.session_state.get("renaming_id"):
        return

    rid = st.session_state.renaming_id
    with st.form("rename_form"):
        new_name = st.text_input("New Name")
        if st.form_submit_button("Save"):
            cleaned_name = new_name.strip()
            if cleaned_name:
                history_manager.rename_conversation(rid, cleaned_name)
                del st.session_state.renaming_id
                st.rerun()
            else:
                # A blank name would leave an unlabeled entry in the list.
                st.error("Name cannot be empty.")
        if st.form_submit_button("Cancel"):
            del st.session_state.renaming_id
            st.rerun()


def _render_sidebar(user):
    """Draw the sidebar and return whether Dev Mode is switched on."""
    with st.sidebar:
        st.title(f"Hi, {user.display_name or user.username}!")

        if st.button("Logout"):
            logout_user()

        st.divider()

        st.header("History")
        if st.button("➕ New Chat", use_container_width=True):
            _reset_chat_state()
            st.rerun()

        _render_conversation_list(user)
        _render_rename_form()

        st.divider()
        st.header("Settings")
        # Check for DEV_MODE env var (defaults to False)
        default_dev_mode = os.getenv("DEV_MODE", "false").lower() == "true"
        dev_mode = st.checkbox("Dev Mode", value=default_dev_mode, help="Enable to see code generation and raw reasoning steps.")

    return dev_mode


def _render_history(dev_mode):
    """Replay the messages kept in session state (runs on every rerun)."""
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            if message.get("plan") and dev_mode:
                with st.expander("Reasoning Plan"):
                    st.code(message["plan"], language="yaml")
            if message.get("code") and dev_mode:
                with st.expander("Generated Code"):
                    st.code(message["code"], language="python")

            st.markdown(message["content"])

            for plot_data in message.get("plots") or []:
                if isinstance(plot_data, bytes):
                    # Stored plots are image bytes
                    st.image(plot_data)
                else:
                    # Fallback for old session state or non-binary
                    st.pyplot(plot_data)

            for df_name, df in (message.get("dfs") or {}).items():
                st.subheader(f"Data: {df_name}")
                st.dataframe(df)


def _report_node_update(status, node_name, state_update, prev_error, dev_mode):
    """Write one graph node's progress into the status container."""
    if node_name == "query_analyzer":
        # ``or`` also covers keys that are present but set to None.
        analysis = state_update.get("analysis") or {}
        next_action = state_update.get("next_action") or "unknown"
        status.write("🔍 **Analyzed Query:**")
        for k, v in analysis.items():
            status.write(f"- {k:<8}: {v}")
        status.markdown(f"Next Step: {next_action.capitalize()}")

    elif node_name == "planner":
        status.write("📋 **Plan Generated**")
        # Render artifacts
        if state_update.get("plan") and dev_mode:
            with st.expander("Reasoning Plan", expanded=True):
                st.code(state_update["plan"], language="yaml")

    elif node_name == "researcher":
        status.write("🌐 **Research Complete**")
        if state_update.get("messages") and dev_mode:
            for msg in state_update["messages"]:
                status.markdown(_message_text(msg))

    elif node_name == "coder":
        status.write("💻 **Code Generated**")
        if state_update.get("code") and dev_mode:
            with st.expander("Generated Code"):
                st.code(state_update["code"], language="python")

    elif node_name == "error_corrector":
        status.write("🛠️ **Fixing Execution Error...**")
        if prev_error:
            truncated_error = prev_error.strip()
            if len(truncated_error) > 180:
                truncated_error = truncated_error[:180] + "..."
            status.write(f"Previous error: {truncated_error}")
        if state_update.get("code") and dev_mode:
            with st.expander("Corrected Code"):
                st.code(state_update["code"], language="python")

    elif node_name == "executor":
        error = state_update.get("error")
        if error:
            if dev_mode:
                status.write(f"❌ **Execution Error:** {error}...")
            else:
                status.write(f"❌ **Execution Error:** {error[:100]}...")
        else:
            status.write("✅ **Execution Successful**")
            if state_update.get("plots"):
                status.write(f"📊 Generated {len(state_update['plots'])} plot(s)")

    elif node_name == "summarizer":
        status.write("📝 **Summarizing Results...**")


def _run_graph(initial_state, dev_mode):
    """Stream the agent graph and report each node as it finishes.

    Returns ``(final_state, answer)``. ``final_state`` is the initial state
    with every node update merged in. ``answer`` is the last message emitted
    by a node during this run, or ``None`` when the run produced no message
    (for example because it failed early), so that an older history entry is
    never mistaken for the new reply.
    """
    final_state = initial_state
    answer = None
    with st.status("Thinking...", expanded=True) as status:
        try:
            # app.stream yields one {node_name: state_update} mapping per step
            for event in app.stream(initial_state):
                for node_name, state_update in event.items():
                    prev_error = final_state.get("error")
                    # Helper merges state (appending messages/plots, updating dfs)
                    final_state = merge_agent_state(final_state, state_update)

                    emitted = state_update.get("messages") if state_update else None
                    if emitted:
                        answer = emitted[-1]

                    _report_node_update(status, node_name, state_update, prev_error, dev_mode)

            status.update(label="Complete!", state="complete", expanded=False)

        except Exception as e:
            status.update(label="Error!", state="error")
            st.error(f"Error during graph execution: {str(e)}")

    return final_state, answer


def _handle_prompt(prompt, user, dev_mode):
    """Run one chat turn: persist the question, run the graph, store the reply."""
    # Ensure we have a conversation ID (auto-create on the first message)
    if not st.session_state.get("current_conversation_id"):
        conv_name = (prompt[:30] + '...') if len(prompt) > 30 else prompt
        conv = history_manager.create_conversation(user.id, settings.data_state, conv_name)
        st.session_state.current_conversation_id = conv.id

    conv_id = st.session_state.current_conversation_id

    # Save the user message to the DB and to session state, then show it
    history_manager.add_message(conv_id, "user", prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Prepare graph input
    initial_state: AgentState = {
        "messages": st.session_state.messages[:-1],  # Pass history (excluding the current prompt)
        "question": prompt,
        "summary": st.session_state.summary,
        "analysis": None,
        "next_action": "",
        "plan": None,
        "code": None,
        "code_output": None,
        "error": None,
        "plots": [],
        "dfs": {},
        "iterations": 0,
    }

    with st.chat_message("assistant"):
        final_state, answer = _run_graph(initial_state, dev_mode)

        # Extract results; only a message produced by this run is the reply.
        response_text: str = ""
        if answer is not None:
            response_text = _message_text(answer)
            st.markdown(response_text)

        # Show plots and collect their bytes for saving to the DB
        plot_bytes_list = []
        for fig in final_state.get("plots") or []:
            st.pyplot(fig)
            plot_bytes_list.append(fig_to_bytes(fig))

        for df_name, df in (final_state.get("dfs") or {}).items():
            st.subheader(f"Data: {df_name}")
            st.dataframe(df)

    # Save assistant message to DB
    history_manager.add_message(conv_id, "assistant", response_text, plots=plot_bytes_list)

    # Update summary in DB
    new_summary = final_state.get("summary", "")
    if new_summary:
        history_manager.update_conversation_summary(conv_id, new_summary)

    # Store assistant response in session history
    st.session_state.messages.append({
        "role": "assistant",
        "content": response_text,
        "plan": final_state.get("plan"),
        "code": final_state.get("code"),
        "plots": plot_bytes_list,
        "dfs": final_state.get("dfs"),
    })
    st.session_state.summary = new_summary


def main():
    """Streamlit entry point for the Election Analytics Chatbot.

    Configures the page, completes a pending OIDC callback, shows the login
    screen (and stops) while nobody is signed in, and otherwise renders the
    sidebar, the stored chat history and the chat input, running the agent
    graph for each new question.
    """
    st.set_page_config(
        page_title="Election Analytics Chatbot",
        page_icon="🗳️",
        layout="wide"
    )

    # Check for OIDC Callback
    _handle_oidc_callback()

    # Display Login Screen if not authenticated
    if "user" not in st.session_state:
        _render_login_screen()
        st.stop()

    # --- Main App (Authenticated) ---
    user = st.session_state.user
    dev_mode = _render_sidebar(user)

    st.title("🗳️ Election Analytics Chatbot")

    # Initialize chat history state
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "summary" not in st.session_state:
        st.session_state.summary = ""

    # Display chat messages from history on app rerun
    _render_history(dev_mode)

    # Accept user input
    if prompt := st.chat_input("Ask a question about election data..."):
        _handle_prompt(prompt, user, dev_mode)
||
|
||
# Script entry point: build the page when this file is run as a script.
if __name__ == "__main__":
    main()
|