From ff0189a69bc5f33fec548642f7709e47d6327ff3 Mon Sep 17 00:00:00 2001 From: Yunxiao Xu Date: Sat, 14 Feb 2026 04:23:52 -0800 Subject: [PATCH] feat(auth): Add PKCE and nonce support to OIDCClient --- backend/src/ea_chatbot/auth.py | 60 +++++++++++++++++++++++--- backend/tests/test_oidc_client_v2.py | 63 ++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 6 deletions(-) create mode 100644 backend/tests/test_oidc_client_v2.py diff --git a/backend/src/ea_chatbot/auth.py b/backend/src/ea_chatbot/auth.py index 29a7796..ae88fa9 100644 --- a/backend/src/ea_chatbot/auth.py +++ b/backend/src/ea_chatbot/auth.py @@ -1,8 +1,10 @@ import requests from datetime import datetime, timedelta, timezone from enum import Enum -from typing import Dict, Any, Optional +from typing import Dict, Any, Optional, Tuple from authlib.integrations.requests_client import OAuth2Session +from authlib.common.security import generate_token +from authlib.oauth2.rfc7636 import create_s256_code_challenge from jose import JWTError, jwt class AuthType(Enum): @@ -59,7 +61,7 @@ class OIDCSession: class OIDCClient: """ Client for OIDC Authentication using Authlib. - Designed to work within a Streamlit environment. + Designed to work within a modern BFF architecture with PKCE and state/nonce validation. """ def __init__( self, @@ -86,8 +88,47 @@ class OIDCClient: self.metadata = requests.get(self.server_metadata_url).json() return self.metadata + def generate_pkce(self) -> Tuple[str, str]: + """Generate PKCE code_verifier and code_challenge.""" + code_verifier = generate_token(48) + code_challenge = create_s256_code_challenge(code_verifier) + return code_verifier, code_challenge + + def generate_nonce(self) -> str: + """Generate a random nonce for ID Token validation.""" + return generate_token(16) + + def get_auth_data(self) -> Dict[str, str]: + """ + Generate all required data for OIDC initiation. + Returns a dict with 'url', 'state', 'nonce', and 'code_verifier'. + """ + metadata = self.fetch_metadata() + authorization_endpoint = metadata.get("authorization_endpoint") + if not authorization_endpoint: + raise ValueError("authorization_endpoint not found in OIDC metadata") + + state = generate_token(16) + nonce = self.generate_nonce() + code_verifier, code_challenge = self.generate_pkce() + + uri, _ = self.oauth_session.create_authorization_url( + authorization_endpoint, + state=state, + nonce=nonce, + code_challenge=code_challenge, + code_challenge_method="S256" + ) + + return { + "url": uri, + "state": state, + "nonce": nonce, + "code_verifier": code_verifier + } + def get_login_url(self) -> str: - """Generate the authorization URL.""" + """Legacy method for generating simple authorization URL.""" metadata = self.fetch_metadata() authorization_endpoint = metadata.get("authorization_endpoint") if not authorization_endpoint: @@ -96,17 +137,22 @@ class OIDCClient: uri, state = self.oauth_session.create_authorization_url(authorization_endpoint) return uri - def exchange_code_for_token(self, code: str) -> Dict[str, Any]: - """Exchange the authorization code for an access token.""" + def exchange_code_for_token(self, code: str, code_verifier: Optional[str] = None) -> Dict[str, Any]: + """Exchange the authorization code for an access token, optionally using PKCE verifier.""" metadata = self.fetch_metadata() token_endpoint = metadata.get("token_endpoint") if not token_endpoint: raise ValueError("token_endpoint not found in OIDC metadata") + kwargs = {} + if code_verifier: + kwargs["code_verifier"] = code_verifier + token = self.oauth_session.fetch_token( token_endpoint, code=code, - client_secret=self.client_secret + client_secret=self.client_secret, + **kwargs ) return token @@ -119,6 +165,8 @@ class OIDCClient: # Set the token on the session so it's used in the request self.oauth_session.token = token + # For OIDC, the sub should be in the token (ID Token) usually. + # But we use userinfo endpoint as well. resp = self.oauth_session.get(userinfo_endpoint) resp.raise_for_status() return resp.json() diff --git a/backend/tests/test_oidc_client_v2.py b/backend/tests/test_oidc_client_v2.py new file mode 100644 index 0000000..6d7beed --- /dev/null +++ b/backend/tests/test_oidc_client_v2.py @@ -0,0 +1,63 @@ +import pytest +from unittest.mock import MagicMock, patch +from ea_chatbot.auth import OIDCClient + +@pytest.fixture +def oidc_config(): + return { + "client_id": "test_id", + "client_secret": "test_secret", + "server_metadata_url": "https://example.com/.well-known/openid-configuration", + "redirect_uri": "http://localhost:5173/auth/callback" + } + +@pytest.fixture +def mock_metadata(): + return { + "authorization_endpoint": "https://example.com/auth", + "token_endpoint": "https://example.com/token", + "userinfo_endpoint": "https://example.com/userinfo", + "jwks_uri": "https://example.com/jwks" + } + +def test_oidc_generate_pkce(oidc_config): + client = OIDCClient(**oidc_config) + code_verifier, code_challenge = client.generate_pkce() + assert len(code_verifier) >= 43 + assert len(code_challenge) > 0 + # Verifier should be high entropy + assert code_verifier != client.generate_pkce()[0] + +def test_oidc_get_auth_data(oidc_config, mock_metadata): + client = OIDCClient(**oidc_config) + client.metadata = mock_metadata + + auth_data = client.get_auth_data() + + assert "url" in auth_data + assert "state" in auth_data + assert "nonce" in auth_data + assert "code_verifier" in auth_data + + url = auth_data["url"] + assert "code_challenge=" in url + assert "code_challenge_method=S256" in url + assert f"state={auth_data['state']}" in url + assert f"nonce={auth_data['nonce']}" in url + +def test_oidc_exchange_code_with_pkce(oidc_config, mock_metadata): + client = OIDCClient(**oidc_config) + client.metadata = mock_metadata + + with patch.object(client.oauth_session, "fetch_token") as mock_fetch: + mock_fetch.return_value = {"access_token": "token"} + + token = client.exchange_code_for_token("code", code_verifier="verifier") + + assert token == {"access_token": "token"} + mock_fetch.assert_called_once_with( + mock_metadata["token_endpoint"], + code="code", + client_secret=oidc_config["client_secret"], + code_verifier="verifier" + )