diff --git a/backend/src/ea_chatbot/auth.py b/backend/src/ea_chatbot/auth.py index ae88fa9..01ea087 100644 --- a/backend/src/ea_chatbot/auth.py +++ b/backend/src/ea_chatbot/auth.py @@ -81,6 +81,7 @@ class OIDCClient: scope="openid email profile" ) self.metadata: Dict[str, Any] = {} + self.jwks: Optional[Dict[str, Any]] = None def fetch_metadata(self) -> Dict[str, Any]: """Fetch OIDC provider metadata if not already fetched.""" @@ -88,6 +89,16 @@ class OIDCClient: self.metadata = requests.get(self.server_metadata_url).json() return self.metadata + def fetch_jwks(self) -> Dict[str, Any]: + """Fetch OIDC provider JWKS if not already fetched.""" + if not self.jwks: + metadata = self.fetch_metadata() + jwks_uri = metadata.get("jwks_uri") + if not jwks_uri: + raise ValueError("jwks_uri not found in OIDC metadata") + self.jwks = requests.get(jwks_uri).json() + return self.jwks + def generate_pkce(self) -> Tuple[str, str]: """Generate PKCE code_verifier and code_challenge.""" code_verifier = generate_token(48) @@ -127,6 +138,30 @@ class OIDCClient: "code_verifier": code_verifier } + def validate_id_token(self, token: str, nonce: Optional[str] = None) -> Dict[str, Any]: + """ + Validate the ID Token and return its claims. + Verifies signature (using JWKS), issuer, audience, and nonce. + """ + metadata = self.fetch_metadata() + jwks = self.fetch_jwks() + + try: + claims = jwt.decode( + token, + jwks, + algorithms=metadata.get("id_token_signing_alg_values_supported", ["RS256"]), + audience=self.client_id, + issuer=metadata.get("issuer") + ) + + if nonce and claims.get("nonce") != nonce: + raise ValueError("Invalid nonce") + + return claims + except JWTError as e: + raise ValueError(f"ID Token validation failed: {str(e)}") + def get_login_url(self) -> str: """Legacy method for generating simple authorization URL.""" metadata = self.fetch_metadata() diff --git a/backend/tests/test_oidc_validation.py b/backend/tests/test_oidc_validation.py new file mode 100644 index 0000000..b61e4fc --- /dev/null +++ b/backend/tests/test_oidc_validation.py @@ -0,0 +1,68 @@ +import pytest +from unittest.mock import MagicMock, patch +from ea_chatbot.auth import OIDCClient +from jose import jwt + +@pytest.fixture +def oidc_config(): + return { + "client_id": "test_id", + "client_secret": "test_secret", + "server_metadata_url": "https://example.com/.well-known/openid-configuration", + "redirect_uri": "http://localhost:5173/auth/callback" + } + +@pytest.fixture +def mock_metadata(): + return { + "issuer": "https://example.com", + "jwks_uri": "https://example.com/jwks", + "id_token_signing_alg_values_supported": ["RS256"] + } + +def test_oidc_validate_id_token_success(oidc_config, mock_metadata): + client = OIDCClient(**oidc_config) + + id_token_payload = { + "iss": "https://example.com", + "sub": "user123", + "aud": "test_id", + "nonce": "test_nonce", + "exp": 9999999999, + "iat": 1000000000 + } + + # Mock JWT decoding, JWKS fetching, and metadata fetching + with patch("ea_chatbot.auth.jwt.decode") as mock_decode, \ + patch.object(client, "fetch_jwks") as mock_fetch_jwks, \ + patch.object(client, "fetch_metadata") as mock_fetch_metadata: + + mock_decode.return_value = id_token_payload + mock_fetch_metadata.return_value = mock_metadata + mock_fetch_jwks.return_value = {"keys": []} + + claims = client.validate_id_token("fake_token", nonce="test_nonce") + + assert claims == id_token_payload + mock_decode.assert_called_once() + +def test_oidc_validate_id_token_invalid_nonce(oidc_config, mock_metadata): + client = OIDCClient(**oidc_config) + + id_token_payload = { + "iss": "https://example.com", + "aud": "test_id", + "nonce": "wrong_nonce", + "exp": 9999999999 + } + + with patch("ea_chatbot.auth.jwt.decode") as mock_decode, \ + patch.object(client, "fetch_jwks") as mock_fetch_jwks, \ + patch.object(client, "fetch_metadata") as mock_fetch_metadata: + + mock_decode.return_value = id_token_payload + mock_fetch_metadata.return_value = mock_metadata + mock_fetch_jwks.return_value = {"keys": []} + + with pytest.raises(ValueError, match="Invalid nonce"): + client.validate_id_token("fake_token", nonce="test_nonce")