feat(auth): Add PKCE and nonce support to OIDCClient
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import requests
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from authlib.integrations.requests_client import OAuth2Session
|
||||
from authlib.common.security import generate_token
|
||||
from authlib.oauth2.rfc7636 import create_s256_code_challenge
|
||||
from jose import JWTError, jwt
|
||||
|
||||
class AuthType(Enum):
|
||||
@@ -59,7 +61,7 @@ class OIDCSession:
|
||||
class OIDCClient:
|
||||
"""
|
||||
Client for OIDC Authentication using Authlib.
|
||||
Designed to work within a Streamlit environment.
|
||||
Designed to work within a modern BFF architecture with PKCE and state/nonce validation.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
@@ -86,8 +88,47 @@ class OIDCClient:
|
||||
self.metadata = requests.get(self.server_metadata_url).json()
|
||||
return self.metadata
|
||||
|
||||
def generate_pkce(self) -> Tuple[str, str]:
|
||||
"""Generate PKCE code_verifier and code_challenge."""
|
||||
code_verifier = generate_token(48)
|
||||
code_challenge = create_s256_code_challenge(code_verifier)
|
||||
return code_verifier, code_challenge
|
||||
|
||||
def generate_nonce(self) -> str:
|
||||
"""Generate a random nonce for ID Token validation."""
|
||||
return generate_token(16)
|
||||
|
||||
def get_auth_data(self) -> Dict[str, str]:
|
||||
"""
|
||||
Generate all required data for OIDC initiation.
|
||||
Returns a dict with 'url', 'state', 'nonce', and 'code_verifier'.
|
||||
"""
|
||||
metadata = self.fetch_metadata()
|
||||
authorization_endpoint = metadata.get("authorization_endpoint")
|
||||
if not authorization_endpoint:
|
||||
raise ValueError("authorization_endpoint not found in OIDC metadata")
|
||||
|
||||
state = generate_token(16)
|
||||
nonce = self.generate_nonce()
|
||||
code_verifier, code_challenge = self.generate_pkce()
|
||||
|
||||
uri, _ = self.oauth_session.create_authorization_url(
|
||||
authorization_endpoint,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method="S256"
|
||||
)
|
||||
|
||||
return {
|
||||
"url": uri,
|
||||
"state": state,
|
||||
"nonce": nonce,
|
||||
"code_verifier": code_verifier
|
||||
}
|
||||
|
||||
def get_login_url(self) -> str:
|
||||
"""Generate the authorization URL."""
|
||||
"""Legacy method for generating simple authorization URL."""
|
||||
metadata = self.fetch_metadata()
|
||||
authorization_endpoint = metadata.get("authorization_endpoint")
|
||||
if not authorization_endpoint:
|
||||
@@ -96,17 +137,22 @@ class OIDCClient:
|
||||
uri, state = self.oauth_session.create_authorization_url(authorization_endpoint)
|
||||
return uri
|
||||
|
||||
def exchange_code_for_token(self, code: str) -> Dict[str, Any]:
|
||||
"""Exchange the authorization code for an access token."""
|
||||
def exchange_code_for_token(self, code: str, code_verifier: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Exchange the authorization code for an access token, optionally using PKCE verifier."""
|
||||
metadata = self.fetch_metadata()
|
||||
token_endpoint = metadata.get("token_endpoint")
|
||||
if not token_endpoint:
|
||||
raise ValueError("token_endpoint not found in OIDC metadata")
|
||||
|
||||
kwargs = {}
|
||||
if code_verifier:
|
||||
kwargs["code_verifier"] = code_verifier
|
||||
|
||||
token = self.oauth_session.fetch_token(
|
||||
token_endpoint,
|
||||
code=code,
|
||||
client_secret=self.client_secret
|
||||
client_secret=self.client_secret,
|
||||
**kwargs
|
||||
)
|
||||
return token
|
||||
|
||||
@@ -119,6 +165,8 @@ class OIDCClient:
|
||||
|
||||
# Set the token on the session so it's used in the request
|
||||
self.oauth_session.token = token
|
||||
# For OIDC, the sub should be in the token (ID Token) usually.
|
||||
# But we use userinfo endpoint as well.
|
||||
resp = self.oauth_session.get(userinfo_endpoint)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
63
backend/tests/test_oidc_client_v2.py
Normal file
63
backend/tests/test_oidc_client_v2.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.auth import OIDCClient
|
||||
|
||||
@pytest.fixture
|
||||
def oidc_config():
|
||||
return {
|
||||
"client_id": "test_id",
|
||||
"client_secret": "test_secret",
|
||||
"server_metadata_url": "https://example.com/.well-known/openid-configuration",
|
||||
"redirect_uri": "http://localhost:5173/auth/callback"
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_metadata():
|
||||
return {
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
"userinfo_endpoint": "https://example.com/userinfo",
|
||||
"jwks_uri": "https://example.com/jwks"
|
||||
}
|
||||
|
||||
def test_oidc_generate_pkce(oidc_config):
|
||||
client = OIDCClient(**oidc_config)
|
||||
code_verifier, code_challenge = client.generate_pkce()
|
||||
assert len(code_verifier) >= 43
|
||||
assert len(code_challenge) > 0
|
||||
# Verifier should be high entropy
|
||||
assert code_verifier != client.generate_pkce()[0]
|
||||
|
||||
def test_oidc_get_auth_data(oidc_config, mock_metadata):
|
||||
client = OIDCClient(**oidc_config)
|
||||
client.metadata = mock_metadata
|
||||
|
||||
auth_data = client.get_auth_data()
|
||||
|
||||
assert "url" in auth_data
|
||||
assert "state" in auth_data
|
||||
assert "nonce" in auth_data
|
||||
assert "code_verifier" in auth_data
|
||||
|
||||
url = auth_data["url"]
|
||||
assert "code_challenge=" in url
|
||||
assert "code_challenge_method=S256" in url
|
||||
assert f"state={auth_data['state']}" in url
|
||||
assert f"nonce={auth_data['nonce']}" in url
|
||||
|
||||
def test_oidc_exchange_code_with_pkce(oidc_config, mock_metadata):
|
||||
client = OIDCClient(**oidc_config)
|
||||
client.metadata = mock_metadata
|
||||
|
||||
with patch.object(client.oauth_session, "fetch_token") as mock_fetch:
|
||||
mock_fetch.return_value = {"access_token": "token"}
|
||||
|
||||
token = client.exchange_code_for_token("code", code_verifier="verifier")
|
||||
|
||||
assert token == {"access_token": "token"}
|
||||
mock_fetch.assert_called_once_with(
|
||||
mock_metadata["token_endpoint"],
|
||||
code="code",
|
||||
client_secret=oidc_config["client_secret"],
|
||||
code_verifier="verifier"
|
||||
)
|
||||
Reference in New Issue
Block a user