feat(auth): Implement ID Token validation in OIDCClient
This commit is contained in:
@@ -81,6 +81,7 @@ class OIDCClient:
|
|||||||
scope="openid email profile"
|
scope="openid email profile"
|
||||||
)
|
)
|
||||||
self.metadata: Dict[str, Any] = {}
|
self.metadata: Dict[str, Any] = {}
|
||||||
|
self.jwks: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
def fetch_metadata(self) -> Dict[str, Any]:
|
def fetch_metadata(self) -> Dict[str, Any]:
|
||||||
"""Fetch OIDC provider metadata if not already fetched."""
|
"""Fetch OIDC provider metadata if not already fetched."""
|
||||||
@@ -88,6 +89,16 @@ class OIDCClient:
|
|||||||
self.metadata = requests.get(self.server_metadata_url).json()
|
self.metadata = requests.get(self.server_metadata_url).json()
|
||||||
return self.metadata
|
return self.metadata
|
||||||
|
|
||||||
|
def fetch_jwks(self) -> Dict[str, Any]:
|
||||||
|
"""Fetch OIDC provider JWKS if not already fetched."""
|
||||||
|
if not self.jwks:
|
||||||
|
metadata = self.fetch_metadata()
|
||||||
|
jwks_uri = metadata.get("jwks_uri")
|
||||||
|
if not jwks_uri:
|
||||||
|
raise ValueError("jwks_uri not found in OIDC metadata")
|
||||||
|
self.jwks = requests.get(jwks_uri).json()
|
||||||
|
return self.jwks
|
||||||
|
|
||||||
def generate_pkce(self) -> Tuple[str, str]:
|
def generate_pkce(self) -> Tuple[str, str]:
|
||||||
"""Generate PKCE code_verifier and code_challenge."""
|
"""Generate PKCE code_verifier and code_challenge."""
|
||||||
code_verifier = generate_token(48)
|
code_verifier = generate_token(48)
|
||||||
@@ -127,6 +138,30 @@ class OIDCClient:
|
|||||||
"code_verifier": code_verifier
|
"code_verifier": code_verifier
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def validate_id_token(self, token: str, nonce: Optional[str] = None) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Validate the ID Token and return its claims.
|
||||||
|
Verifies signature (using JWKS), issuer, audience, and nonce.
|
||||||
|
"""
|
||||||
|
metadata = self.fetch_metadata()
|
||||||
|
jwks = self.fetch_jwks()
|
||||||
|
|
||||||
|
try:
|
||||||
|
claims = jwt.decode(
|
||||||
|
token,
|
||||||
|
jwks,
|
||||||
|
algorithms=metadata.get("id_token_signing_alg_values_supported", ["RS256"]),
|
||||||
|
audience=self.client_id,
|
||||||
|
issuer=metadata.get("issuer")
|
||||||
|
)
|
||||||
|
|
||||||
|
if nonce and claims.get("nonce") != nonce:
|
||||||
|
raise ValueError("Invalid nonce")
|
||||||
|
|
||||||
|
return claims
|
||||||
|
except JWTError as e:
|
||||||
|
raise ValueError(f"ID Token validation failed: {str(e)}")
|
||||||
|
|
||||||
def get_login_url(self) -> str:
|
def get_login_url(self) -> str:
|
||||||
"""Legacy method for generating simple authorization URL."""
|
"""Legacy method for generating simple authorization URL."""
|
||||||
metadata = self.fetch_metadata()
|
metadata = self.fetch_metadata()
|
||||||
|
|||||||
68
backend/tests/test_oidc_validation.py
Normal file
68
backend/tests/test_oidc_validation.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from ea_chatbot.auth import OIDCClient
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def oidc_config():
|
||||||
|
return {
|
||||||
|
"client_id": "test_id",
|
||||||
|
"client_secret": "test_secret",
|
||||||
|
"server_metadata_url": "https://example.com/.well-known/openid-configuration",
|
||||||
|
"redirect_uri": "http://localhost:5173/auth/callback"
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_metadata():
|
||||||
|
return {
|
||||||
|
"issuer": "https://example.com",
|
||||||
|
"jwks_uri": "https://example.com/jwks",
|
||||||
|
"id_token_signing_alg_values_supported": ["RS256"]
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_oidc_validate_id_token_success(oidc_config, mock_metadata):
|
||||||
|
client = OIDCClient(**oidc_config)
|
||||||
|
|
||||||
|
id_token_payload = {
|
||||||
|
"iss": "https://example.com",
|
||||||
|
"sub": "user123",
|
||||||
|
"aud": "test_id",
|
||||||
|
"nonce": "test_nonce",
|
||||||
|
"exp": 9999999999,
|
||||||
|
"iat": 1000000000
|
||||||
|
}
|
||||||
|
|
||||||
|
# Mock JWT decoding, JWKS fetching, and metadata fetching
|
||||||
|
with patch("ea_chatbot.auth.jwt.decode") as mock_decode, \
|
||||||
|
patch.object(client, "fetch_jwks") as mock_fetch_jwks, \
|
||||||
|
patch.object(client, "fetch_metadata") as mock_fetch_metadata:
|
||||||
|
|
||||||
|
mock_decode.return_value = id_token_payload
|
||||||
|
mock_fetch_metadata.return_value = mock_metadata
|
||||||
|
mock_fetch_jwks.return_value = {"keys": []}
|
||||||
|
|
||||||
|
claims = client.validate_id_token("fake_token", nonce="test_nonce")
|
||||||
|
|
||||||
|
assert claims == id_token_payload
|
||||||
|
mock_decode.assert_called_once()
|
||||||
|
|
||||||
|
def test_oidc_validate_id_token_invalid_nonce(oidc_config, mock_metadata):
|
||||||
|
client = OIDCClient(**oidc_config)
|
||||||
|
|
||||||
|
id_token_payload = {
|
||||||
|
"iss": "https://example.com",
|
||||||
|
"aud": "test_id",
|
||||||
|
"nonce": "wrong_nonce",
|
||||||
|
"exp": 9999999999
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("ea_chatbot.auth.jwt.decode") as mock_decode, \
|
||||||
|
patch.object(client, "fetch_jwks") as mock_fetch_jwks, \
|
||||||
|
patch.object(client, "fetch_metadata") as mock_fetch_metadata:
|
||||||
|
|
||||||
|
mock_decode.return_value = id_token_payload
|
||||||
|
mock_fetch_metadata.return_value = mock_metadata
|
||||||
|
mock_fetch_jwks.return_value = {"keys": []}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Invalid nonce"):
|
||||||
|
client.validate_id_token("fake_token", nonce="test_nonce")
|
||||||
Reference in New Issue
Block a user