feat(auth): Add PKCE and nonce support to OIDCClient

This commit is contained in:
Yunxiao Xu
2026-02-14 04:23:52 -08:00
parent 05261e3cda
commit ff0189a69b
2 changed files with 117 additions and 6 deletions

View File

@@ -1,8 +1,10 @@
import requests
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Dict, Any, Optional
from typing import Dict, Any, Optional, Tuple
from authlib.integrations.requests_client import OAuth2Session
from authlib.common.security import generate_token
from authlib.oauth2.rfc7636 import create_s256_code_challenge
from jose import JWTError, jwt
class AuthType(Enum):
@@ -59,7 +61,7 @@ class OIDCSession:
class OIDCClient:
"""
Client for OIDC Authentication using Authlib.
Designed to work within a Streamlit environment.
Designed to work within a modern BFF architecture with PKCE and state/nonce validation.
"""
def __init__(
self,
@@ -86,8 +88,47 @@ class OIDCClient:
self.metadata = requests.get(self.server_metadata_url).json()
return self.metadata
def generate_pkce(self) -> Tuple[str, str]:
"""Generate PKCE code_verifier and code_challenge."""
code_verifier = generate_token(48)
code_challenge = create_s256_code_challenge(code_verifier)
return code_verifier, code_challenge
def generate_nonce(self) -> str:
"""Generate a random nonce for ID Token validation."""
return generate_token(16)
def get_auth_data(self) -> Dict[str, str]:
"""
Generate all required data for OIDC initiation.
Returns a dict with 'url', 'state', 'nonce', and 'code_verifier'.
"""
metadata = self.fetch_metadata()
authorization_endpoint = metadata.get("authorization_endpoint")
if not authorization_endpoint:
raise ValueError("authorization_endpoint not found in OIDC metadata")
state = generate_token(16)
nonce = self.generate_nonce()
code_verifier, code_challenge = self.generate_pkce()
uri, _ = self.oauth_session.create_authorization_url(
authorization_endpoint,
state=state,
nonce=nonce,
code_challenge=code_challenge,
code_challenge_method="S256"
)
return {
"url": uri,
"state": state,
"nonce": nonce,
"code_verifier": code_verifier
}
def get_login_url(self) -> str:
"""Generate the authorization URL."""
"""Legacy method for generating simple authorization URL."""
metadata = self.fetch_metadata()
authorization_endpoint = metadata.get("authorization_endpoint")
if not authorization_endpoint:
@@ -96,17 +137,22 @@ class OIDCClient:
uri, state = self.oauth_session.create_authorization_url(authorization_endpoint)
return uri
def exchange_code_for_token(self, code: str) -> Dict[str, Any]:
"""Exchange the authorization code for an access token."""
def exchange_code_for_token(self, code: str, code_verifier: Optional[str] = None) -> Dict[str, Any]:
"""Exchange the authorization code for an access token, optionally using PKCE verifier."""
metadata = self.fetch_metadata()
token_endpoint = metadata.get("token_endpoint")
if not token_endpoint:
raise ValueError("token_endpoint not found in OIDC metadata")
kwargs = {}
if code_verifier:
kwargs["code_verifier"] = code_verifier
token = self.oauth_session.fetch_token(
token_endpoint,
code=code,
client_secret=self.client_secret
client_secret=self.client_secret,
**kwargs
)
return token
@@ -119,6 +165,8 @@ class OIDCClient:
# Set the token on the session so it's used in the request
self.oauth_session.token = token
# For OIDC, the sub should be in the token (ID Token) usually.
# But we use userinfo endpoint as well.
resp = self.oauth_session.get(userinfo_endpoint)
resp.raise_for_status()
return resp.json()