fix: Restore OIDC login in Streamlit app using PKCE/Nonce flow
This commit is contained in:
@@ -75,18 +75,36 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check for OIDC Callback
|
# Check for OIDC Callback
|
||||||
if "code" in st.query_params and oidc_client:
|
if "code" in st.query_params and "state" in st.query_params and oidc_client:
|
||||||
code = st.query_params["code"]
|
code = st.query_params["code"]
|
||||||
|
state = st.query_params["state"]
|
||||||
|
|
||||||
|
# Validate state
|
||||||
|
stored_state = st.session_state.get("oidc_state")
|
||||||
|
if not stored_state or state != stored_state:
|
||||||
|
st.error("OIDC state mismatch. Please try again.")
|
||||||
|
st.stop()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
token = oidc_client.exchange_code_for_token(code)
|
# 1. Exchange code using PKCE verifier
|
||||||
user_info = oidc_client.get_user_info(token)
|
verifier = st.session_state.get("oidc_verifier")
|
||||||
email = user_info.get("email")
|
token = oidc_client.exchange_code_for_token(code, code_verifier=verifier)
|
||||||
name = user_info.get("name") or user_info.get("preferred_username")
|
|
||||||
|
# 2. Validate ID Token using Nonce
|
||||||
|
id_token = token.get("id_token")
|
||||||
|
nonce = st.session_state.get("oidc_nonce")
|
||||||
|
claims = oidc_client.validate_id_token(id_token, nonce=nonce)
|
||||||
|
|
||||||
|
email = claims.get("email")
|
||||||
|
name = claims.get("name") or claims.get("preferred_username")
|
||||||
|
|
||||||
if email:
|
if email:
|
||||||
user = history_manager.sync_user_from_oidc(email=email, display_name=name)
|
user = history_manager.sync_user_from_oidc(email=email, display_name=name)
|
||||||
# Clear query params
|
# Clear query params and session data
|
||||||
st.query_params.clear()
|
st.query_params.clear()
|
||||||
|
for key in ["oidc_state", "oidc_nonce", "oidc_verifier"]:
|
||||||
|
if key in st.session_state:
|
||||||
|
del st.session_state[key]
|
||||||
login_user(user)
|
login_user(user)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
st.error(f"OIDC Login failed: {str(e)}")
|
st.error(f"OIDC Login failed: {str(e)}")
|
||||||
@@ -180,8 +198,13 @@ def main():
|
|||||||
|
|
||||||
with col_sso:
|
with col_sso:
|
||||||
if oidc_client:
|
if oidc_client:
|
||||||
login_url = oidc_client.get_login_url()
|
auth_data = oidc_client.get_auth_data()
|
||||||
st.link_button("Login with SSO", login_url, type="primary", use_container_width=True)
|
# Store PKCE/Nonce in session state for callback validation
|
||||||
|
st.session_state.oidc_state = auth_data["state"]
|
||||||
|
st.session_state.oidc_nonce = auth_data["nonce"]
|
||||||
|
st.session_state.oidc_verifier = auth_data["code_verifier"]
|
||||||
|
|
||||||
|
st.link_button("Login with SSO", auth_data["url"], type="primary", use_container_width=True)
|
||||||
else:
|
else:
|
||||||
st.error("OIDC is not configured.")
|
st.error("OIDC is not configured.")
|
||||||
|
|
||||||
@@ -195,8 +218,11 @@ def main():
|
|||||||
st.header("Single Sign-On")
|
st.header("Single Sign-On")
|
||||||
st.write("Login with your organizational account.")
|
st.write("Login with your organizational account.")
|
||||||
if st.button("Login with SSO"):
|
if st.button("Login with SSO"):
|
||||||
login_url = oidc_client.get_login_url()
|
auth_data = oidc_client.get_auth_data()
|
||||||
st.link_button("Go to **YXXU**", login_url, type="primary")
|
st.session_state.oidc_state = auth_data["state"]
|
||||||
|
st.session_state.oidc_nonce = auth_data["nonce"]
|
||||||
|
st.session_state.oidc_verifier = auth_data["code_verifier"]
|
||||||
|
st.link_button("Go to **Login Provider**", auth_data["url"], type="primary")
|
||||||
else:
|
else:
|
||||||
st.info("SSO is not configured.")
|
st.info("SSO is not configured.")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user