fix: Restore OIDC login in Streamlit app using PKCE/Nonce flow

This commit is contained in:
Yunxiao Xu
2026-02-17 02:36:49 -08:00
parent a94cbc7f6d
commit 23471350df

View File

@@ -75,18 +75,36 @@ def main():
)
# Check for OIDC Callback
if "code" in st.query_params and oidc_client:
if "code" in st.query_params and "state" in st.query_params and oidc_client:
code = st.query_params["code"]
state = st.query_params["state"]
# Validate state
stored_state = st.session_state.get("oidc_state")
if not stored_state or state != stored_state:
st.error("OIDC state mismatch. Please try again.")
st.stop()
try:
token = oidc_client.exchange_code_for_token(code)
user_info = oidc_client.get_user_info(token)
email = user_info.get("email")
name = user_info.get("name") or user_info.get("preferred_username")
# 1. Exchange code using PKCE verifier
verifier = st.session_state.get("oidc_verifier")
token = oidc_client.exchange_code_for_token(code, code_verifier=verifier)
# 2. Validate ID Token using Nonce
id_token = token.get("id_token")
nonce = st.session_state.get("oidc_nonce")
claims = oidc_client.validate_id_token(id_token, nonce=nonce)
email = claims.get("email")
name = claims.get("name") or claims.get("preferred_username")
if email:
user = history_manager.sync_user_from_oidc(email=email, display_name=name)
# Clear query params
# Clear query params and session data
st.query_params.clear()
for key in ["oidc_state", "oidc_nonce", "oidc_verifier"]:
if key in st.session_state:
del st.session_state[key]
login_user(user)
except Exception as e:
st.error(f"OIDC Login failed: {str(e)}")
@@ -180,8 +198,13 @@ def main():
with col_sso:
if oidc_client:
login_url = oidc_client.get_login_url()
st.link_button("Login with SSO", login_url, type="primary", use_container_width=True)
auth_data = oidc_client.get_auth_data()
# Store PKCE/Nonce in session state for callback validation
st.session_state.oidc_state = auth_data["state"]
st.session_state.oidc_nonce = auth_data["nonce"]
st.session_state.oidc_verifier = auth_data["code_verifier"]
st.link_button("Login with SSO", auth_data["url"], type="primary", use_container_width=True)
else:
st.error("OIDC is not configured.")
@@ -195,8 +218,11 @@ def main():
st.header("Single Sign-On")
st.write("Login with your organizational account.")
if st.button("Login with SSO"):
login_url = oidc_client.get_login_url()
st.link_button("Go to **YXXU**", login_url, type="primary")
auth_data = oidc_client.get_auth_data()
st.session_state.oidc_state = auth_data["state"]
st.session_state.oidc_nonce = auth_data["nonce"]
st.session_state.oidc_verifier = auth_data["code_verifier"]
st.link_button("Go to **Login Provider**", auth_data["url"], type="primary")
else:
st.info("SSO is not configured.")