feat(auth): Implement OIDCSession helper for secure temporary storage
This commit is contained in:
@@ -1,7 +1,9 @@
|
|||||||
import requests
|
import requests
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
from authlib.integrations.requests_client import OAuth2Session
|
from authlib.integrations.requests_client import OAuth2Session
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
|
||||||
class AuthType(Enum):
|
class AuthType(Enum):
|
||||||
LOCAL = "local"
|
LOCAL = "local"
|
||||||
@@ -29,6 +31,31 @@ def get_user_auth_type(email: str, history_manager: Any) -> AuthType:
|
|||||||
|
|
||||||
return AuthType.OIDC
|
return AuthType.OIDC
|
||||||
|
|
||||||
|
class OIDCSession:
|
||||||
|
"""Helper for managing temporary OIDC session data in a secure cookie."""
|
||||||
|
ALGORITHM = "HS256"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def encrypt(cls, data: dict, secret_key: str, expires_delta: Optional[timedelta] = None) -> str:
|
||||||
|
"""Encrypt OIDC session data into a JWT."""
|
||||||
|
to_encode = data.copy()
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
if expires_delta:
|
||||||
|
expire = now + expires_delta
|
||||||
|
else:
|
||||||
|
expire = now + timedelta(minutes=5) # Short-lived by default
|
||||||
|
|
||||||
|
to_encode.update({"exp": expire, "iat": now})
|
||||||
|
return jwt.encode(to_encode, secret_key, algorithm=cls.ALGORITHM)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def decrypt(cls, token: str, secret_key: str) -> Optional[dict]:
|
||||||
|
"""Decrypt and validate OIDC session data from a JWT."""
|
||||||
|
try:
|
||||||
|
return jwt.decode(token, secret_key, algorithms=[cls.ALGORITHM])
|
||||||
|
except JWTError:
|
||||||
|
return None
|
||||||
|
|
||||||
class OIDCClient:
|
class OIDCClient:
|
||||||
"""
|
"""
|
||||||
Client for OIDC Authentication using Authlib.
|
Client for OIDC Authentication using Authlib.
|
||||||
|
|||||||
44
backend/tests/test_oidc_session.py
Normal file
44
backend/tests/test_oidc_session.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
import pytest
|
||||||
|
from datetime import timedelta
|
||||||
|
from ea_chatbot.auth import OIDCSession
|
||||||
|
from ea_chatbot.config import Settings
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def settings():
|
||||||
|
return Settings()
|
||||||
|
|
||||||
|
def test_oidc_session_encrypt_decrypt(settings):
|
||||||
|
session_data = {
|
||||||
|
"state": "test_state",
|
||||||
|
"nonce": "test_nonce",
|
||||||
|
"code_verifier": "test_verifier"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Encrypt
|
||||||
|
token = OIDCSession.encrypt(session_data, settings.secret_key)
|
||||||
|
assert isinstance(token, str)
|
||||||
|
assert token != ""
|
||||||
|
|
||||||
|
# Decrypt
|
||||||
|
decrypted_data = OIDCSession.decrypt(token, settings.secret_key)
|
||||||
|
assert decrypted_data["state"] == "test_state"
|
||||||
|
assert decrypted_data["nonce"] == "test_nonce"
|
||||||
|
assert decrypted_data["code_verifier"] == "test_verifier"
|
||||||
|
|
||||||
|
def test_oidc_session_invalid_signature(settings):
|
||||||
|
session_data = {"state": "test_state"}
|
||||||
|
token = OIDCSession.encrypt(session_data, settings.secret_key)
|
||||||
|
|
||||||
|
# Tamper with the token
|
||||||
|
tampered_token = token[:-5] + "aaaaa"
|
||||||
|
|
||||||
|
decrypted_data = OIDCSession.decrypt(tampered_token, settings.secret_key)
|
||||||
|
assert decrypted_data is None
|
||||||
|
|
||||||
|
def test_oidc_session_expired(settings):
|
||||||
|
session_data = {"state": "test_state"}
|
||||||
|
# Encrypt with a very short expiration
|
||||||
|
token = OIDCSession.encrypt(session_data, settings.secret_key, expires_delta=timedelta(seconds=-1))
|
||||||
|
|
||||||
|
decrypted_data = OIDCSession.decrypt(token, settings.secret_key)
|
||||||
|
assert decrypted_data is None
|
||||||
Reference in New Issue
Block a user