chore: Finalize cleanup phases (docstrings, utility consolidation, dev app isolation)
This commit is contained in:
476
backend/dev_app.py
Normal file
476
backend/dev_app.py
Normal file
@@ -0,0 +1,476 @@
|
||||
import streamlit as st
|
||||
import os
|
||||
import io
|
||||
from dotenv import load_dotenv
|
||||
from ea_chatbot.graph.workflow import app
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.utils.logging import get_logger
|
||||
from ea_chatbot.utils.helpers import merge_agent_state
|
||||
from ea_chatbot.utils.plots import fig_to_bytes
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.history.manager import HistoryManager
|
||||
from ea_chatbot.auth import OIDCClient, AuthType, get_user_auth_type
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Initialize Config and Manager
|
||||
settings = Settings()
|
||||
history_manager = HistoryManager(settings.history_db_url)
|
||||
|
||||
# Initialize OIDC Client if configured
|
||||
oidc_client = None
|
||||
if settings.oidc_client_id and settings.oidc_client_secret and settings.oidc_server_metadata_url:
|
||||
oidc_client = OIDCClient(
|
||||
client_id=settings.oidc_client_id,
|
||||
client_secret=settings.oidc_client_secret,
|
||||
server_metadata_url=settings.oidc_server_metadata_url,
|
||||
# Redirect back to the same page
|
||||
redirect_uri=os.getenv("OIDC_REDIRECT_URI", "http://localhost:8501")
|
||||
)
|
||||
|
||||
# Initialize Logger
|
||||
logger = get_logger(level=settings.log_level, log_file="logs/app.jsonl")
|
||||
|
||||
# --- Authentication Helpers ---
|
||||
|
||||
def login_user(user):
|
||||
st.session_state.user = user
|
||||
st.session_state.messages = []
|
||||
st.session_state.summary = ""
|
||||
st.session_state.current_conversation_id = None
|
||||
st.rerun()
|
||||
|
||||
def logout_user():
|
||||
for key in list(st.session_state.keys()):
|
||||
del st.session_state[key]
|
||||
st.rerun()
|
||||
|
||||
def load_conversation(conv_id):
|
||||
messages = history_manager.get_messages(conv_id)
|
||||
formatted_messages = []
|
||||
for m in messages:
|
||||
# Convert DB models to session state dicts
|
||||
msg_dict = {
|
||||
"role": m.role,
|
||||
"content": m.content,
|
||||
"plots": [p.image_data for p in m.plots]
|
||||
}
|
||||
formatted_messages.append(msg_dict)
|
||||
|
||||
st.session_state.messages = formatted_messages
|
||||
st.session_state.current_conversation_id = conv_id
|
||||
# Fetch summary from DB
|
||||
with history_manager.get_session() as session:
|
||||
from ea_chatbot.history.models import Conversation
|
||||
conv = session.get(Conversation, conv_id)
|
||||
st.session_state.summary = conv.summary if conv else ""
|
||||
st.rerun()
|
||||
|
||||
def main():
|
||||
st.set_page_config(
|
||||
page_title="Election Analytics Chatbot",
|
||||
page_icon="🗳️",
|
||||
layout="wide"
|
||||
)
|
||||
|
||||
# Check for OIDC Callback
|
||||
if "code" in st.query_params and "state" in st.query_params and oidc_client:
|
||||
code = st.query_params["code"]
|
||||
state = st.query_params["state"]
|
||||
|
||||
# Validate state
|
||||
stored_state = st.session_state.get("oidc_state")
|
||||
if not stored_state or state != stored_state:
|
||||
st.error("OIDC state mismatch. Please try again.")
|
||||
st.stop()
|
||||
|
||||
try:
|
||||
# 1. Exchange code using PKCE verifier
|
||||
verifier = st.session_state.get("oidc_verifier")
|
||||
token = oidc_client.exchange_code_for_token(code, code_verifier=verifier)
|
||||
|
||||
# 2. Validate ID Token using Nonce
|
||||
id_token = token.get("id_token")
|
||||
nonce = st.session_state.get("oidc_nonce")
|
||||
claims = oidc_client.validate_id_token(id_token, nonce=nonce)
|
||||
|
||||
email = claims.get("email")
|
||||
name = claims.get("name") or claims.get("preferred_username")
|
||||
|
||||
if email:
|
||||
user = history_manager.sync_user_from_oidc(email=email, display_name=name)
|
||||
# Clear query params and session data
|
||||
st.query_params.clear()
|
||||
for key in ["oidc_state", "oidc_nonce", "oidc_verifier"]:
|
||||
if key in st.session_state:
|
||||
del st.session_state[key]
|
||||
login_user(user)
|
||||
except Exception as e:
|
||||
st.error(f"OIDC Login failed: {str(e)}")
|
||||
|
||||
# Display Login Screen if not authenticated
|
||||
if "user" not in st.session_state:
|
||||
st.title("🗳️ Election Analytics Chatbot")
|
||||
|
||||
# Initialize Login State
|
||||
if "login_step" not in st.session_state:
|
||||
st.session_state.login_step = "email"
|
||||
if "login_email" not in st.session_state:
|
||||
st.session_state.login_email = ""
|
||||
|
||||
col1, col2 = st.columns([1, 1])
|
||||
|
||||
with col1:
|
||||
st.header("Login")
|
||||
|
||||
# Step 1: Identification
|
||||
if st.session_state.login_step == "email":
|
||||
st.write("Please enter your email to begin:")
|
||||
with st.form("email_form"):
|
||||
email_input = st.text_input("Email", value=st.session_state.login_email)
|
||||
submitted = st.form_submit_button("Next")
|
||||
|
||||
if submitted:
|
||||
if not email_input.strip():
|
||||
st.error("Email cannot be empty.")
|
||||
else:
|
||||
st.session_state.login_email = email_input.strip()
|
||||
auth_type = get_user_auth_type(st.session_state.login_email, history_manager)
|
||||
|
||||
if auth_type == AuthType.LOCAL:
|
||||
st.session_state.login_step = "login_password"
|
||||
elif auth_type == AuthType.OIDC:
|
||||
st.session_state.login_step = "oidc_login"
|
||||
else: # AuthType.NEW
|
||||
st.session_state.login_step = "register_details"
|
||||
st.rerun()
|
||||
|
||||
# Step 2a: Local Login
|
||||
elif st.session_state.login_step == "login_password":
|
||||
st.info(f"Welcome back, **{st.session_state.login_email}**!")
|
||||
with st.form("password_form"):
|
||||
password = st.text_input("Password", type="password")
|
||||
|
||||
col_login, col_back = st.columns([1, 1])
|
||||
submitted = col_login.form_submit_button("Login")
|
||||
back = col_back.form_submit_button("Back")
|
||||
|
||||
if back:
|
||||
st.session_state.login_step = "email"
|
||||
st.rerun()
|
||||
|
||||
if submitted:
|
||||
user = history_manager.authenticate_user(st.session_state.login_email, password)
|
||||
if user:
|
||||
login_user(user)
|
||||
else:
|
||||
st.error("Invalid email or password")
|
||||
|
||||
# Step 2b: Registration
|
||||
elif st.session_state.login_step == "register_details":
|
||||
st.info(f"Create an account for **{st.session_state.login_email}**")
|
||||
with st.form("register_form"):
|
||||
reg_name = st.text_input("Display Name")
|
||||
reg_password = st.text_input("Password", type="password")
|
||||
|
||||
col_reg, col_back = st.columns([1, 1])
|
||||
submitted = col_reg.form_submit_button("Register & Login")
|
||||
back = col_back.form_submit_button("Back")
|
||||
|
||||
if back:
|
||||
st.session_state.login_step = "email"
|
||||
st.rerun()
|
||||
|
||||
if submitted:
|
||||
if not reg_password:
|
||||
st.error("Password is required for registration.")
|
||||
else:
|
||||
user = history_manager.create_user(st.session_state.login_email, reg_password, reg_name)
|
||||
st.success("Registered! Logging in...")
|
||||
login_user(user)
|
||||
|
||||
# Step 2c: OIDC Redirection
|
||||
elif st.session_state.login_step == "oidc_login":
|
||||
st.info(f"**{st.session_state.login_email}** is configured for Single Sign-On (SSO).")
|
||||
|
||||
col_sso, col_back = st.columns([1, 1])
|
||||
|
||||
with col_sso:
|
||||
if oidc_client:
|
||||
auth_data = oidc_client.get_auth_data()
|
||||
# Store PKCE/Nonce in session state for callback validation
|
||||
st.session_state.oidc_state = auth_data["state"]
|
||||
st.session_state.oidc_nonce = auth_data["nonce"]
|
||||
st.session_state.oidc_verifier = auth_data["code_verifier"]
|
||||
|
||||
st.link_button("Login with SSO", auth_data["url"], type="primary", use_container_width=True)
|
||||
else:
|
||||
st.error("OIDC is not configured.")
|
||||
|
||||
with col_back:
|
||||
if st.button("Back", use_container_width=True):
|
||||
st.session_state.login_step = "email"
|
||||
st.rerun()
|
||||
|
||||
with col2:
|
||||
if oidc_client:
|
||||
st.header("Single Sign-On")
|
||||
st.write("Login with your organizational account.")
|
||||
if st.button("Login with SSO"):
|
||||
auth_data = oidc_client.get_auth_data()
|
||||
st.session_state.oidc_state = auth_data["state"]
|
||||
st.session_state.oidc_nonce = auth_data["nonce"]
|
||||
st.session_state.oidc_verifier = auth_data["code_verifier"]
|
||||
st.link_button("Go to **Login Provider**", auth_data["url"], type="primary")
|
||||
else:
|
||||
st.info("SSO is not configured.")
|
||||
|
||||
st.stop()
|
||||
|
||||
# --- Main App (Authenticated) ---
|
||||
|
||||
user = st.session_state.user
|
||||
|
||||
# Sidebar configuration
|
||||
with st.sidebar:
|
||||
st.title(f"Hi, {user.display_name or user.username}!")
|
||||
|
||||
if st.button("Logout"):
|
||||
logout_user()
|
||||
|
||||
st.divider()
|
||||
|
||||
st.header("History")
|
||||
if st.button("➕ New Chat", use_container_width=True):
|
||||
st.session_state.messages = []
|
||||
st.session_state.summary = ""
|
||||
st.session_state.current_conversation_id = None
|
||||
st.rerun()
|
||||
|
||||
# List conversations for the current user and data state
|
||||
conversations = history_manager.get_conversations(user.id, settings.data_state)
|
||||
|
||||
for conv in conversations:
|
||||
col_c, col_r, col_d = st.columns([0.7, 0.15, 0.15])
|
||||
|
||||
is_current = st.session_state.get("current_conversation_id") == conv.id
|
||||
label = f"💬 {conv.name}" if not is_current else f"👉 {conv.name}"
|
||||
|
||||
if col_c.button(label, key=f"conv_{conv.id}", use_container_width=True):
|
||||
load_conversation(conv.id)
|
||||
|
||||
if col_r.button("✏️", key=f"ren_{conv.id}"):
|
||||
st.session_state.renaming_id = conv.id
|
||||
|
||||
if col_d.button("🗑️", key=f"del_{conv.id}"):
|
||||
if history_manager.delete_conversation(conv.id):
|
||||
if is_current:
|
||||
st.session_state.current_conversation_id = None
|
||||
st.session_state.messages = []
|
||||
st.rerun()
|
||||
|
||||
# Rename dialog
|
||||
if st.session_state.get("renaming_id"):
|
||||
rid = st.session_state.renaming_id
|
||||
with st.form("rename_form"):
|
||||
new_name = st.text_input("New Name")
|
||||
if st.form_submit_button("Save"):
|
||||
history_manager.rename_conversation(rid, new_name)
|
||||
del st.session_state.renaming_id
|
||||
st.rerun()
|
||||
if st.form_submit_button("Cancel"):
|
||||
del st.session_state.renaming_id
|
||||
st.rerun()
|
||||
|
||||
st.divider()
|
||||
st.header("Settings")
|
||||
# Check for DEV_MODE env var (defaults to False)
|
||||
default_dev_mode = os.getenv("DEV_MODE", "false").lower() == "true"
|
||||
dev_mode = st.checkbox("Dev Mode", value=default_dev_mode, help="Enable to see code generation and raw reasoning steps.")
|
||||
|
||||
st.title("🗳️ Election Analytics Chatbot")
|
||||
|
||||
# Initialize chat history state
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state.messages = []
|
||||
if "summary" not in st.session_state:
|
||||
st.session_state.summary = ""
|
||||
|
||||
# Display chat messages from history on app rerun
|
||||
for message in st.session_state.messages:
|
||||
with st.chat_message(message["role"]):
|
||||
if message.get("plan") and dev_mode:
|
||||
with st.expander("Reasoning Plan"):
|
||||
st.code(message["plan"], language="yaml")
|
||||
if message.get("code") and dev_mode:
|
||||
with st.expander("Generated Code"):
|
||||
st.code(message["code"], language="python")
|
||||
|
||||
st.markdown(message["content"])
|
||||
if message.get("plots"):
|
||||
for plot_data in message["plots"]:
|
||||
# If plot_data is bytes, convert to image
|
||||
if isinstance(plot_data, bytes):
|
||||
st.image(plot_data)
|
||||
else:
|
||||
# Fallback for old session state or non-binary
|
||||
st.pyplot(plot_data)
|
||||
if message.get("dfs"):
|
||||
for df_name, df in message["dfs"].items():
|
||||
st.subheader(f"Data: {df_name}")
|
||||
st.dataframe(df)
|
||||
|
||||
# Accept user input
|
||||
if prompt := st.chat_input("Ask a question about election data..."):
|
||||
# Ensure we have a conversation ID
|
||||
if not st.session_state.get("current_conversation_id"):
|
||||
# Auto-create conversation
|
||||
conv_name = (prompt[:30] + '...') if len(prompt) > 30 else prompt
|
||||
conv = history_manager.create_conversation(user.id, settings.data_state, conv_name)
|
||||
st.session_state.current_conversation_id = conv.id
|
||||
|
||||
conv_id = st.session_state.current_conversation_id
|
||||
|
||||
# Save user message to DB
|
||||
history_manager.add_message(conv_id, "user", prompt)
|
||||
|
||||
# Add user message to session state
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# Display user message in chat message container
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
|
||||
# Prepare graph input
|
||||
initial_state: AgentState = {
|
||||
"messages": st.session_state.messages[:-1], # Pass history (excluding the current prompt)
|
||||
"question": prompt,
|
||||
"summary": st.session_state.summary,
|
||||
"analysis": None,
|
||||
"next_action": "",
|
||||
"plan": None,
|
||||
"code": None,
|
||||
"code_output": None,
|
||||
"error": None,
|
||||
"plots": [],
|
||||
"dfs": {},
|
||||
"iterations": 0
|
||||
}
|
||||
|
||||
# Placeholder for graph output
|
||||
with st.chat_message("assistant"):
|
||||
final_state = initial_state
|
||||
# Real-time node updates
|
||||
with st.status("Thinking...", expanded=True) as status:
|
||||
try:
|
||||
# Use app.stream to capture node transitions
|
||||
for event in app.stream(initial_state):
|
||||
for node_name, state_update in event.items():
|
||||
prev_error = final_state.get("error")
|
||||
# Use helper to merge state correctly (appending messages/plots, updating dfs)
|
||||
final_state = merge_agent_state(final_state, state_update)
|
||||
|
||||
if node_name == "query_analyzer":
|
||||
analysis = state_update.get("analysis", {})
|
||||
next_action = state_update.get("next_action", "unknown")
|
||||
status.write("🔍 **Analyzed Query:**")
|
||||
for k,v in analysis.items():
|
||||
status.write(f"- {k:<8}: {v}")
|
||||
status.markdown(f"Next Step: {next_action.capitalize()}")
|
||||
|
||||
elif node_name == "planner":
|
||||
status.write("📋 **Plan Generated**")
|
||||
# Render artifacts
|
||||
if state_update.get("plan") and dev_mode:
|
||||
with st.expander("Reasoning Plan", expanded=True):
|
||||
st.code(state_update["plan"], language="yaml")
|
||||
|
||||
elif node_name == "researcher":
|
||||
status.write("🌐 **Research Complete**")
|
||||
if state_update.get("messages") and dev_mode:
|
||||
for msg in state_update["messages"]:
|
||||
# Extract content from BaseMessage or show raw string
|
||||
content = getattr(msg, "text", msg.content)
|
||||
status.markdown(content)
|
||||
|
||||
elif node_name == "coder":
|
||||
status.write("💻 **Code Generated**")
|
||||
if state_update.get("code") and dev_mode:
|
||||
with st.expander("Generated Code"):
|
||||
st.code(state_update["code"], language="python")
|
||||
|
||||
elif node_name == "error_corrector":
|
||||
status.write("🛠️ **Fixing Execution Error...**")
|
||||
if prev_error:
|
||||
truncated_error = prev_error.strip()
|
||||
if len(truncated_error) > 180:
|
||||
truncated_error = truncated_error[:180] + "..."
|
||||
status.write(f"Previous error: {truncated_error}")
|
||||
if state_update.get("code") and dev_mode:
|
||||
with st.expander("Corrected Code"):
|
||||
st.code(state_update["code"], language="python")
|
||||
|
||||
elif node_name == "executor":
|
||||
if state_update.get("error"):
|
||||
if dev_mode:
|
||||
status.write(f"❌ **Execution Error:** {state_update.get('error')}...")
|
||||
else:
|
||||
status.write(f"❌ **Execution Error:** {state_update.get('error')[:100]}...")
|
||||
else:
|
||||
status.write("✅ **Execution Successful**")
|
||||
if state_update.get("plots"):
|
||||
status.write(f"📊 Generated {len(state_update['plots'])} plot(s)")
|
||||
|
||||
elif node_name == "summarizer":
|
||||
status.write("📝 **Summarizing Results...**")
|
||||
|
||||
|
||||
status.update(label="Complete!", state="complete", expanded=False)
|
||||
|
||||
except Exception as e:
|
||||
status.update(label="Error!", state="error")
|
||||
st.error(f"Error during graph execution: {str(e)}")
|
||||
|
||||
# Extract results
|
||||
response_text: str = ""
|
||||
if final_state.get("messages"):
|
||||
# The last message is the Assistant's response
|
||||
last_msg = final_state["messages"][-1]
|
||||
response_text = getattr(last_msg, "text", str(last_msg.content))
|
||||
st.markdown(response_text)
|
||||
|
||||
# Collect plot bytes for saving to DB
|
||||
plot_bytes_list = []
|
||||
if final_state.get("plots"):
|
||||
for fig in final_state["plots"]:
|
||||
st.pyplot(fig)
|
||||
# Convert fig to bytes
|
||||
plot_bytes_list.append(fig_to_bytes(fig))
|
||||
|
||||
if final_state.get("dfs"):
|
||||
for df_name, df in final_state["dfs"].items():
|
||||
st.subheader(f"Data: {df_name}")
|
||||
st.dataframe(df)
|
||||
|
||||
# Save assistant message to DB
|
||||
history_manager.add_message(conv_id, "assistant", response_text, plots=plot_bytes_list)
|
||||
|
||||
# Update summary in DB
|
||||
new_summary = final_state.get("summary", "")
|
||||
if new_summary:
|
||||
history_manager.update_conversation_summary(conv_id, new_summary)
|
||||
|
||||
# Store assistant response in session history
|
||||
st.session_state.messages.append({
|
||||
"role": "assistant",
|
||||
"content": response_text,
|
||||
"plan": final_state.get("plan"),
|
||||
"code": final_state.get("code"),
|
||||
"plots": plot_bytes_list,
|
||||
"dfs": final_state.get("dfs")
|
||||
})
|
||||
st.session_state.summary = new_summary
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user