feat(auth): Add PKCE and nonce support to OIDCClient

This commit is contained in:
Yunxiao Xu
2026-02-14 04:23:52 -08:00
parent 05261e3cda
commit ff0189a69b
2 changed files with 117 additions and 6 deletions

View File

@@ -1,8 +1,10 @@
import requests import requests
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from enum import Enum from enum import Enum
from typing import Dict, Any, Optional from typing import Dict, Any, Optional, Tuple
from authlib.integrations.requests_client import OAuth2Session from authlib.integrations.requests_client import OAuth2Session
from authlib.common.security import generate_token
from authlib.oauth2.rfc7636 import create_s256_code_challenge
from jose import JWTError, jwt from jose import JWTError, jwt
class AuthType(Enum): class AuthType(Enum):
@@ -59,7 +61,7 @@ class OIDCSession:
class OIDCClient: class OIDCClient:
""" """
Client for OIDC Authentication using Authlib. Client for OIDC Authentication using Authlib.
Designed to work within a Streamlit environment. Designed to work within a modern BFF architecture with PKCE and state/nonce validation.
""" """
def __init__( def __init__(
self, self,
@@ -86,8 +88,47 @@ class OIDCClient:
self.metadata = requests.get(self.server_metadata_url).json() self.metadata = requests.get(self.server_metadata_url).json()
return self.metadata return self.metadata
def generate_pkce(self) -> Tuple[str, str]:
"""Generate PKCE code_verifier and code_challenge."""
code_verifier = generate_token(48)
code_challenge = create_s256_code_challenge(code_verifier)
return code_verifier, code_challenge
def generate_nonce(self) -> str:
"""Generate a random nonce for ID Token validation."""
return generate_token(16)
def get_auth_data(self) -> Dict[str, str]:
"""
Generate all required data for OIDC initiation.
Returns a dict with 'url', 'state', 'nonce', and 'code_verifier'.
"""
metadata = self.fetch_metadata()
authorization_endpoint = metadata.get("authorization_endpoint")
if not authorization_endpoint:
raise ValueError("authorization_endpoint not found in OIDC metadata")
state = generate_token(16)
nonce = self.generate_nonce()
code_verifier, code_challenge = self.generate_pkce()
uri, _ = self.oauth_session.create_authorization_url(
authorization_endpoint,
state=state,
nonce=nonce,
code_challenge=code_challenge,
code_challenge_method="S256"
)
return {
"url": uri,
"state": state,
"nonce": nonce,
"code_verifier": code_verifier
}
def get_login_url(self) -> str: def get_login_url(self) -> str:
"""Generate the authorization URL.""" """Legacy method for generating simple authorization URL."""
metadata = self.fetch_metadata() metadata = self.fetch_metadata()
authorization_endpoint = metadata.get("authorization_endpoint") authorization_endpoint = metadata.get("authorization_endpoint")
if not authorization_endpoint: if not authorization_endpoint:
@@ -96,17 +137,22 @@ class OIDCClient:
uri, state = self.oauth_session.create_authorization_url(authorization_endpoint) uri, state = self.oauth_session.create_authorization_url(authorization_endpoint)
return uri return uri
def exchange_code_for_token(self, code: str) -> Dict[str, Any]: def exchange_code_for_token(self, code: str, code_verifier: Optional[str] = None) -> Dict[str, Any]:
"""Exchange the authorization code for an access token.""" """Exchange the authorization code for an access token, optionally using PKCE verifier."""
metadata = self.fetch_metadata() metadata = self.fetch_metadata()
token_endpoint = metadata.get("token_endpoint") token_endpoint = metadata.get("token_endpoint")
if not token_endpoint: if not token_endpoint:
raise ValueError("token_endpoint not found in OIDC metadata") raise ValueError("token_endpoint not found in OIDC metadata")
kwargs = {}
if code_verifier:
kwargs["code_verifier"] = code_verifier
token = self.oauth_session.fetch_token( token = self.oauth_session.fetch_token(
token_endpoint, token_endpoint,
code=code, code=code,
client_secret=self.client_secret client_secret=self.client_secret,
**kwargs
) )
return token return token
@@ -119,6 +165,8 @@ class OIDCClient:
# Set the token on the session so it's used in the request # Set the token on the session so it's used in the request
self.oauth_session.token = token self.oauth_session.token = token
# For OIDC, the sub should be in the token (ID Token) usually.
# But we use userinfo endpoint as well.
resp = self.oauth_session.get(userinfo_endpoint) resp = self.oauth_session.get(userinfo_endpoint)
resp.raise_for_status() resp.raise_for_status()
return resp.json() return resp.json()

View File

@@ -0,0 +1,63 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.auth import OIDCClient
@pytest.fixture
def oidc_config():
return {
"client_id": "test_id",
"client_secret": "test_secret",
"server_metadata_url": "https://example.com/.well-known/openid-configuration",
"redirect_uri": "http://localhost:5173/auth/callback"
}
@pytest.fixture
def mock_metadata():
return {
"authorization_endpoint": "https://example.com/auth",
"token_endpoint": "https://example.com/token",
"userinfo_endpoint": "https://example.com/userinfo",
"jwks_uri": "https://example.com/jwks"
}
def test_oidc_generate_pkce(oidc_config):
client = OIDCClient(**oidc_config)
code_verifier, code_challenge = client.generate_pkce()
assert len(code_verifier) >= 43
assert len(code_challenge) > 0
# Verifier should be high entropy
assert code_verifier != client.generate_pkce()[0]
def test_oidc_get_auth_data(oidc_config, mock_metadata):
client = OIDCClient(**oidc_config)
client.metadata = mock_metadata
auth_data = client.get_auth_data()
assert "url" in auth_data
assert "state" in auth_data
assert "nonce" in auth_data
assert "code_verifier" in auth_data
url = auth_data["url"]
assert "code_challenge=" in url
assert "code_challenge_method=S256" in url
assert f"state={auth_data['state']}" in url
assert f"nonce={auth_data['nonce']}" in url
def test_oidc_exchange_code_with_pkce(oidc_config, mock_metadata):
client = OIDCClient(**oidc_config)
client.metadata = mock_metadata
with patch.object(client.oauth_session, "fetch_token") as mock_fetch:
mock_fetch.return_value = {"access_token": "token"}
token = client.exchange_code_for_token("code", code_verifier="verifier")
assert token == {"access_token": "token"}
mock_fetch.assert_called_once_with(
mock_metadata["token_endpoint"],
code="code",
client_secret=oidc_config["client_secret"],
code_verifier="verifier"
)