"""Streamlit front end for the Election Analytics Chatbot.

Provides local (email + password) and OIDC single sign-on login, per-user
conversation history backed by ``HistoryManager``, and a chat UI that streams
the agent graph (``ea_chatbot.graph.workflow.app``) while narrating each node.
"""

import io
import os
from typing import Any

import streamlit as st
from dotenv import load_dotenv

from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.state import AgentState
from ea_chatbot.utils.logging import get_logger
from ea_chatbot.utils.helpers import merge_agent_state
from ea_chatbot.utils.plots import fig_to_bytes
from ea_chatbot.config import Settings
from ea_chatbot.history.manager import HistoryManager
from ea_chatbot.auth import OIDCClient, AuthType, get_user_auth_type

# Load environment variables from a .env file (if present) into os.environ.
load_dotenv()

# Initialize Config and Manager
settings = Settings()
history_manager = HistoryManager(settings.history_db_url)

# Initialize the OIDC client only when all three OIDC settings are configured.
oidc_client = None
if settings.oidc_client_id and settings.oidc_client_secret and settings.oidc_server_metadata_url:
    oidc_client = OIDCClient(
        client_id=settings.oidc_client_id,
        client_secret=settings.oidc_client_secret,
        server_metadata_url=settings.oidc_server_metadata_url,
        # Redirect back to the same page
        redirect_uri=os.getenv("OIDC_REDIRECT_URI", "http://localhost:8501"),
    )

# Initialize Logger
logger = get_logger(level=settings.log_level, log_file="logs/app.jsonl")

# Session-state keys describing the in-flight OIDC authorization request.
_OIDC_SESSION_KEYS = ("oidc_state", "oidc_nonce", "oidc_verifier", "oidc_auth_data")

# Assistant reply used when the graph produced no new message (e.g. it failed
# before answering), so the chat never shows or saves an empty/stale answer.
_NO_RESPONSE_TEXT = "Sorry, I couldn't generate a response. Please try again."


# --- Authentication Helpers ---

def login_user(user):
    """Store *user* as the authenticated user, reset chat state, and rerun the app."""
    st.session_state.user = user
    st.session_state.messages = []
    st.session_state.summary = ""
    st.session_state.current_conversation_id = None
    st.rerun()


def logout_user():
    """Log out by deleting every session-state key, then rerun the app."""
    for key in list(st.session_state.keys()):
        del st.session_state[key]
    st.rerun()


def load_conversation(conv_id):
    """Load conversation *conv_id* from the history DB into session state and rerun.

    DB message models are converted into the plain dicts the chat renderer
    uses; each message's plots are kept as their stored ``image_data``.
    """
    st.session_state.messages = [
        {
            "role": m.role,
            "content": m.content,
            "plots": [p.image_data for p in m.plots],
        }
        for m in history_manager.get_messages(conv_id)
    ]
    st.session_state.current_conversation_id = conv_id

    # Fetch summary from DB (read inside the session while the row is attached).
    with history_manager.get_session() as session:
        # Local import kept from the original; possibly avoids an import cycle — unverified.
        from ea_chatbot.history.models import Conversation
        conv = session.get(Conversation, conv_id)
        # A stored NULL summary is normalized to "" like a missing conversation.
        st.session_state.summary = (conv.summary or "") if conv else ""
    st.rerun()


def _reset_chat():
    """Start a fresh, unsaved conversation in session state."""
    st.session_state.messages = []
    st.session_state.summary = ""
    st.session_state.current_conversation_id = None


def _message_text(msg: Any) -> str:
    """Return the display text of a chat message.

    Accepts the plain ``{"role": ..., "content": ...}`` dicts this app keeps in
    session state as well as graph message objects exposing ``text`` and/or
    ``content``.
    """
    if isinstance(msg, dict):
        return str(msg.get("content", ""))
    text = getattr(msg, "text", None)
    if callable(text) and not isinstance(text, str):
        # Some message versions expose ``text`` as a method rather than a property.
        text = text()
    if text is None:
        text = getattr(msg, "content", msg)
    return text if isinstance(text, str) else str(text)


# --- OIDC Helpers ---

def _begin_oidc_flow():
    """Return the pending OIDC authorization request, creating one if needed.

    Requires ``oidc_client`` to be configured. The request (``state``,
    ``nonce``, ``code_verifier`` and ``url``) is cached in session state so
    every SSO button rendered in this session carries the same ``state`` that
    the callback validates against; generating a fresh request per button would
    let the last one overwrite the state the other links point at.
    """
    auth_data = st.session_state.get("oidc_auth_data")
    if auth_data is None:
        auth_data = oidc_client.get_auth_data()
        st.session_state.oidc_auth_data = auth_data
    # Store PKCE/Nonce in session state for callback validation
    st.session_state.oidc_state = auth_data["state"]
    st.session_state.oidc_nonce = auth_data["nonce"]
    st.session_state.oidc_verifier = auth_data["code_verifier"]
    return auth_data


def _clear_oidc_request():
    """Drop the OIDC callback query params and any pending authorization request."""
    st.query_params.clear()
    for key in _OIDC_SESSION_KEYS:
        if key in st.session_state:
            del st.session_state[key]


def _handle_oidc_callback():
    """Complete an OIDC authorization-code login from the redirect query params.

    Validates ``state`` against the value stored when the flow started,
    exchanges the code using the stored PKCE verifier, validates the ID token
    with the stored nonce, syncs the user and logs them in (``login_user``
    reruns the app). On any failure the query params are cleared so later
    reruns do not retry the same authorization code, and an error is shown
    above the login screen.

    NOTE(review): the expected state/nonce/verifier live in ``st.session_state``,
    which Streamlit does not keep across a full page load. If the IdP redirect
    lands in a new browser session (e.g. if the SSO link opens a new tab), the
    stored state will be missing and login fails with a state mismatch —
    confirm the end-to-end flow; a cookie- or server-side store may be needed.
    """
    code = st.query_params["code"]
    state = st.query_params["state"]

    # Validate state
    stored_state = st.session_state.get("oidc_state")
    if not stored_state or state != stored_state:
        st.query_params.clear()
        st.error("OIDC state mismatch. Please try again.")
        return

    try:
        # 1. Exchange code using PKCE verifier
        verifier = st.session_state.get("oidc_verifier")
        token = oidc_client.exchange_code_for_token(code, code_verifier=verifier)

        # 2. Validate ID Token using Nonce
        id_token = token.get("id_token")
        nonce = st.session_state.get("oidc_nonce")
        claims = oidc_client.validate_id_token(id_token, nonce=nonce)

        email = claims.get("email")
        name = claims.get("name") or claims.get("preferred_username")
        user = (
            history_manager.sync_user_from_oidc(email=email, display_name=name)
            if email
            else None
        )
    except Exception as e:
        # UI boundary: tell the user, keep the traceback in the log.
        logger.exception("OIDC login failed")
        _clear_oidc_request()
        st.error(f"OIDC Login failed: {str(e)}")
        return

    # Whatever the outcome, this callback (and its code) has been consumed.
    _clear_oidc_request()
    if user is None:
        st.error("OIDC Login failed: the identity provider returned no email claim.")
        return
    # Called outside the try so st.rerun()'s control-flow exception is never swallowed.
    login_user(user)


# --- Login Screen ---

def _render_email_step():
    """Step 1: ask for an email and route to the matching auth step."""
    st.write("Please enter your email to begin:")
    with st.form("email_form"):
        email_input = st.text_input("Email", value=st.session_state.login_email)
        submitted = st.form_submit_button("Next")

    if submitted:
        if not email_input.strip():
            st.error("Email cannot be empty.")
        else:
            st.session_state.login_email = email_input.strip()
            auth_type = get_user_auth_type(st.session_state.login_email, history_manager)
            if auth_type == AuthType.LOCAL:
                st.session_state.login_step = "login_password"
            elif auth_type == AuthType.OIDC:
                st.session_state.login_step = "oidc_login"
            else:  # AuthType.NEW
                st.session_state.login_step = "register_details"
            st.rerun()


def _render_password_step():
    """Step 2a: password login for an existing local account."""
    st.info(f"Welcome back, **{st.session_state.login_email}**!")
    with st.form("password_form"):
        password = st.text_input("Password", type="password")
        col_login, col_back = st.columns([1, 1])
        submitted = col_login.form_submit_button("Login")
        back = col_back.form_submit_button("Back")

    if back:
        st.session_state.login_step = "email"
        st.rerun()
    if submitted:
        user = history_manager.authenticate_user(st.session_state.login_email, password)
        if user:
            login_user(user)
        else:
            st.error("Invalid email or password")


def _render_register_step():
    """Step 2b: create a local account for a new email, then log in."""
    st.info(f"Create an account for **{st.session_state.login_email}**")
    with st.form("register_form"):
        reg_name = st.text_input("Display Name")
        reg_password = st.text_input("Password", type="password")
        col_reg, col_back = st.columns([1, 1])
        submitted = col_reg.form_submit_button("Register & Login")
        back = col_back.form_submit_button("Back")

    if back:
        st.session_state.login_step = "email"
        st.rerun()
    if submitted:
        if not reg_password:
            st.error("Password is required for registration.")
        else:
            user = history_manager.create_user(st.session_state.login_email, reg_password, reg_name)
            st.success("Registered! Logging in...")
            login_user(user)


def _render_oidc_step():
    """Step 2c: offer the SSO redirect for an email configured for OIDC."""
    st.info(f"**{st.session_state.login_email}** is configured for Single Sign-On (SSO).")
    col_sso, col_back = st.columns([1, 1])
    with col_sso:
        if oidc_client:
            auth_data = _begin_oidc_flow()
            st.link_button("Login with SSO", auth_data["url"], type="primary", use_container_width=True)
        else:
            st.error("OIDC is not configured.")
    with col_back:
        if st.button("Back", use_container_width=True):
            st.session_state.login_step = "email"
            st.rerun()


def _render_sso_panel():
    """Right-hand panel: a direct SSO entry point when OIDC is configured."""
    if oidc_client:
        st.header("Single Sign-On")
        st.write("Login with your organizational account.")
        if st.button("Login with SSO"):
            auth_data = _begin_oidc_flow()
            st.link_button("Go to **Login Provider**", auth_data["url"], type="primary")
    else:
        st.info("SSO is not configured.")


def _render_login_screen():
    """Render the multi-step login UI for an unauthenticated session."""
    st.title("🗳️ Election Analytics Chatbot")

    # Initialize Login State
    if "login_step" not in st.session_state:
        st.session_state.login_step = "email"
    if "login_email" not in st.session_state:
        st.session_state.login_email = ""

    col1, col2 = st.columns([1, 1])
    with col1:
        st.header("Login")
        step = st.session_state.login_step
        if step == "email":
            _render_email_step()
        elif step == "login_password":
            _render_password_step()
        elif step == "register_details":
            _render_register_step()
        elif step == "oidc_login":
            _render_oidc_step()
    with col2:
        _render_sso_panel()


# --- Sidebar ---

def _render_conversation_row(conv):
    """Render one history entry with open / rename / delete buttons."""
    col_c, col_r, col_d = st.columns([0.7, 0.15, 0.15])
    is_current = st.session_state.get("current_conversation_id") == conv.id
    label = f"👉 {conv.name}" if is_current else f"💬 {conv.name}"

    if col_c.button(label, key=f"conv_{conv.id}", use_container_width=True):
        load_conversation(conv.id)
    if col_r.button("✏️", key=f"ren_{conv.id}"):
        st.session_state.renaming_id = conv.id
    if col_d.button("🗑️", key=f"del_{conv.id}"):
        if history_manager.delete_conversation(conv.id):
            if is_current:
                # Reset the summary too, so it can't leak into the next new chat.
                _reset_chat()
            if st.session_state.get("renaming_id") == conv.id:
                del st.session_state.renaming_id
            st.rerun()


def _render_rename_form(rid):
    """Rename dialog for conversation *rid*; empty names are rejected."""
    with st.form("rename_form"):
        new_name = st.text_input("New Name")
        if st.form_submit_button("Save"):
            if new_name.strip():
                history_manager.rename_conversation(rid, new_name.strip())
                del st.session_state.renaming_id
                st.rerun()
            else:
                st.warning("Name cannot be empty.")
        if st.form_submit_button("Cancel"):
            del st.session_state.renaming_id
            st.rerun()


def _render_sidebar(user):
    """Render the sidebar (greeting, logout, history, settings); return the Dev Mode flag."""
    with st.sidebar:
        st.title(f"Hi, {user.display_name or user.username}!")
        if st.button("Logout"):
            logout_user()

        st.divider()
        st.header("History")
        if st.button("➕ New Chat", use_container_width=True):
            _reset_chat()
            st.rerun()

        # List conversations for the current user and data state
        for conv in history_manager.get_conversations(user.id, settings.data_state):
            _render_conversation_row(conv)

        if st.session_state.get("renaming_id"):
            _render_rename_form(st.session_state.renaming_id)

        st.divider()
        st.header("Settings")
        # Check for DEV_MODE env var (defaults to False)
        default_dev_mode = os.getenv("DEV_MODE", "false").lower() == "true"
        return st.checkbox(
            "Dev Mode",
            value=default_dev_mode,
            help="Enable to see code generation and raw reasoning steps.",
        )


# --- Chat ---

def _render_chat_history(dev_mode):
    """Re-render every message kept in session state, with artifacts when present."""
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            if message.get("plan") and dev_mode:
                with st.expander("Reasoning Plan"):
                    st.code(message["plan"], language="yaml")
            if message.get("code") and dev_mode:
                with st.expander("Generated Code"):
                    st.code(message["code"], language="python")
            st.markdown(message["content"])
            for plot_data in message.get("plots") or []:
                if isinstance(plot_data, bytes):
                    st.image(plot_data)
                else:
                    # Fallback for old session state or non-binary
                    st.pyplot(plot_data)
            for df_name, df in (message.get("dfs") or {}).items():
                st.subheader(f"Data: {df_name}")
                st.dataframe(df)


def _report_node(status, node_name, state_update, prev_error, dev_mode):
    """Write a progress line (and, in dev mode, artifacts) for one graph node update.

    ``prev_error`` is the accumulated ``error`` from before this update was
    merged, so the error corrector can show which error it is fixing.

    NOTE(review): in dev mode this opens ``st.expander`` blocks while inside the
    ``st.status`` container; some Streamlit versions reject nested expanders —
    confirm against the pinned version.
    """
    if node_name == "query_analyzer":
        # These keys may be present but None (the initial state sets analysis=None).
        analysis = state_update.get("analysis") or {}
        next_action = state_update.get("next_action") or "unknown"
        status.write("🔍 **Analyzed Query:**")
        for key, value in analysis.items():
            status.write(f"- {key:<8}: {value}")
        status.markdown(f"Next Step: {next_action.capitalize()}")

    elif node_name == "planner":
        status.write("📋 **Plan Generated**")
        if state_update.get("plan") and dev_mode:
            with st.expander("Reasoning Plan", expanded=True):
                st.code(state_update["plan"], language="yaml")

    elif node_name == "researcher":
        status.write("🌐 **Research Complete**")
        if state_update.get("messages") and dev_mode:
            for msg in state_update["messages"]:
                status.markdown(_message_text(msg))

    elif node_name == "coder":
        status.write("💻 **Code Generated**")
        if state_update.get("code") and dev_mode:
            with st.expander("Generated Code"):
                st.code(state_update["code"], language="python")

    elif node_name == "error_corrector":
        status.write("🛠️ **Fixing Execution Error...**")
        if prev_error:
            truncated_error = str(prev_error).strip()
            if len(truncated_error) > 180:
                truncated_error = truncated_error[:180] + "..."
            status.write(f"Previous error: {truncated_error}")
        if state_update.get("code") and dev_mode:
            with st.expander("Corrected Code"):
                st.code(state_update["code"], language="python")

    elif node_name == "executor":
        error = state_update.get("error")
        if error:
            # Dev mode shows the full error; otherwise a 100-char preview.
            shown = str(error) if dev_mode else f"{str(error)[:100]}..."
            status.write(f"❌ **Execution Error:** {shown}")
        else:
            status.write("✅ **Execution Successful**")
        if state_update.get("plots"):
            status.write(f"📊 Generated {len(state_update['plots'])} plot(s)")

    elif node_name == "summarizer":
        status.write("📝 **Summarizing Results...**")


def _run_graph(initial_state, dev_mode):
    """Stream the agent graph inside a status box and return the accumulated state.

    The returned state is ``initial_state`` merged with every node update
    received before the graph completed or raised; errors are logged and
    shown to the user rather than propagated.
    """
    final_state = initial_state
    with st.status("Thinking...", expanded=True) as status:
        try:
            # Use app.stream to capture node transitions
            for event in app.stream(initial_state):
                for node_name, state_update in event.items():
                    prev_error = final_state.get("error")
                    # Use helper to merge state correctly (appending messages/plots, updating dfs)
                    final_state = merge_agent_state(final_state, state_update)
                    _report_node(status, node_name, state_update, prev_error, dev_mode)
            status.update(label="Complete!", state="complete", expanded=False)
        except Exception as e:
            logger.exception("Graph execution failed")
            status.update(label="Error!", state="error")
            st.error(f"Error during graph execution: {str(e)}")
    return final_state


def _handle_prompt(prompt, user, dev_mode):
    """Persist *prompt*, run the graph on it, render the answer and persist it."""
    # Ensure we have a conversation ID (auto-create one named after the prompt).
    if not st.session_state.get("current_conversation_id"):
        conv_name = (prompt[:30] + '...') if len(prompt) > 30 else prompt
        conv = history_manager.create_conversation(user.id, settings.data_state, conv_name)
        st.session_state.current_conversation_id = conv.id
    conv_id = st.session_state.current_conversation_id

    history_manager.add_message(conv_id, "user", prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Prepare graph input
    initial_state: AgentState = {
        "messages": st.session_state.messages[:-1],  # Pass history (excluding the current prompt)
        "question": prompt,
        "summary": st.session_state.summary,
        "analysis": None,
        "next_action": "",
        "plan": None,
        "code": None,
        "code_output": None,
        "error": None,
        "plots": [],
        "dfs": {},
        "iterations": 0,
    }

    plot_bytes_list = []
    with st.chat_message("assistant"):
        final_state = _run_graph(initial_state, dev_mode)

        # Only messages appended by the graph count as the answer; falling back to
        # the last history entry would repeat (or crash on) the previous turn.
        history_len = len(initial_state["messages"])
        new_messages = (final_state.get("messages") or [])[history_len:]
        response_text = _message_text(new_messages[-1]) if new_messages else _NO_RESPONSE_TEXT
        st.markdown(response_text)

        # Render plots and collect their bytes for saving to the DB
        for fig in final_state.get("plots") or []:
            st.pyplot(fig)
            plot_bytes_list.append(fig_to_bytes(fig))

        for df_name, df in (final_state.get("dfs") or {}).items():
            st.subheader(f"Data: {df_name}")
            st.dataframe(df)

    # Save assistant message to DB
    history_manager.add_message(conv_id, "assistant", response_text, plots=plot_bytes_list)

    # Update summary in DB
    new_summary = final_state.get("summary") or ""
    if new_summary:
        history_manager.update_conversation_summary(conv_id, new_summary)

    # Store assistant response in session history
    st.session_state.messages.append({
        "role": "assistant",
        "content": response_text,
        "plan": final_state.get("plan"),
        "code": final_state.get("code"),
        "plots": plot_bytes_list,
        "dfs": final_state.get("dfs"),
    })
    st.session_state.summary = new_summary


def main():
    """Streamlit entry point: authenticate, then render the chat application."""
    st.set_page_config(
        page_title="Election Analytics Chatbot",
        page_icon="🗳️",
        layout="wide",
    )

    # Check for OIDC Callback
    if "code" in st.query_params and "state" in st.query_params and oidc_client:
        _handle_oidc_callback()

    # Display Login Screen if not authenticated
    if "user" not in st.session_state:
        _render_login_screen()
        st.stop()

    # --- Main App (Authenticated) ---
    user = st.session_state.user
    dev_mode = _render_sidebar(user)

    st.title("🗳️ Election Analytics Chatbot")

    # Initialize chat history state
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "summary" not in st.session_state:
        st.session_state.summary = ""

    _render_chat_history(dev_mode)

    # Accept user input
    if prompt := st.chat_input("Ask a question about election data..."):
        _handle_prompt(prompt, user, dev_mode)


if __name__ == "__main__":
    main()