feat(auth): Implement OIDCSession helper for secure temporary storage

This commit is contained in:
Yunxiao Xu
2026-02-14 04:12:14 -08:00
parent 2220714962
commit 05261e3cda
2 changed files with 71 additions and 0 deletions

View File

@@ -1,7 +1,9 @@
import requests import requests
from datetime import datetime, timedelta, timezone
from enum import Enum from enum import Enum
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
from authlib.integrations.requests_client import OAuth2Session from authlib.integrations.requests_client import OAuth2Session
from jose import JWTError, jwt
class AuthType(Enum): class AuthType(Enum):
LOCAL = "local" LOCAL = "local"
@@ -29,6 +31,31 @@ def get_user_auth_type(email: str, history_manager: Any) -> AuthType:
return AuthType.OIDC return AuthType.OIDC
class OIDCSession:
"""Helper for managing temporary OIDC session data in a secure cookie."""
ALGORITHM = "HS256"
@classmethod
def encrypt(cls, data: dict, secret_key: str, expires_delta: Optional[timedelta] = None) -> str:
"""Encrypt OIDC session data into a JWT."""
to_encode = data.copy()
now = datetime.now(timezone.utc)
if expires_delta:
expire = now + expires_delta
else:
expire = now + timedelta(minutes=5) # Short-lived by default
to_encode.update({"exp": expire, "iat": now})
return jwt.encode(to_encode, secret_key, algorithm=cls.ALGORITHM)
@classmethod
def decrypt(cls, token: str, secret_key: str) -> Optional[dict]:
"""Decrypt and validate OIDC session data from a JWT."""
try:
return jwt.decode(token, secret_key, algorithms=[cls.ALGORITHM])
except JWTError:
return None
class OIDCClient: class OIDCClient:
""" """
Client for OIDC Authentication using Authlib. Client for OIDC Authentication using Authlib.

View File

@@ -0,0 +1,44 @@
import pytest
from datetime import timedelta
from ea_chatbot.auth import OIDCSession
from ea_chatbot.config import Settings
@pytest.fixture
def settings():
return Settings()
def test_oidc_session_encrypt_decrypt(settings):
session_data = {
"state": "test_state",
"nonce": "test_nonce",
"code_verifier": "test_verifier"
}
# Encrypt
token = OIDCSession.encrypt(session_data, settings.secret_key)
assert isinstance(token, str)
assert token != ""
# Decrypt
decrypted_data = OIDCSession.decrypt(token, settings.secret_key)
assert decrypted_data["state"] == "test_state"
assert decrypted_data["nonce"] == "test_nonce"
assert decrypted_data["code_verifier"] == "test_verifier"
def test_oidc_session_invalid_signature(settings):
session_data = {"state": "test_state"}
token = OIDCSession.encrypt(session_data, settings.secret_key)
# Tamper with the token
tampered_token = token[:-5] + "aaaaa"
decrypted_data = OIDCSession.decrypt(tampered_token, settings.secret_key)
assert decrypted_data is None
def test_oidc_session_expired(settings):
session_data = {"state": "test_state"}
# Encrypt with a very short expiration
token = OIDCSession.encrypt(session_data, settings.secret_key, expires_delta=timedelta(seconds=-1))
decrypted_data = OIDCSession.decrypt(token, settings.secret_key)
assert decrypted_data is None