feat(auth): Implement OIDCSession helper for secure temporary storage

This commit is contained in:
Yunxiao Xu
2026-02-14 04:12:14 -08:00
parent 2220714962
commit 05261e3cda
2 changed files with 71 additions and 0 deletions

View File

@@ -1,7 +1,9 @@
import requests
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Dict, Any, Optional
from authlib.integrations.requests_client import OAuth2Session
from jose import JWTError, jwt
class AuthType(Enum):
LOCAL = "local"
@@ -29,6 +31,31 @@ def get_user_auth_type(email: str, history_manager: Any) -> AuthType:
return AuthType.OIDC
class OIDCSession:
"""Helper for managing temporary OIDC session data in a secure cookie."""
ALGORITHM = "HS256"
@classmethod
def encrypt(cls, data: dict, secret_key: str, expires_delta: Optional[timedelta] = None) -> str:
"""Encrypt OIDC session data into a JWT."""
to_encode = data.copy()
now = datetime.now(timezone.utc)
if expires_delta:
expire = now + expires_delta
else:
expire = now + timedelta(minutes=5) # Short-lived by default
to_encode.update({"exp": expire, "iat": now})
return jwt.encode(to_encode, secret_key, algorithm=cls.ALGORITHM)
@classmethod
def decrypt(cls, token: str, secret_key: str) -> Optional[dict]:
"""Decrypt and validate OIDC session data from a JWT."""
try:
return jwt.decode(token, secret_key, algorithms=[cls.ALGORITHM])
except JWTError:
return None
class OIDCClient:
"""
Client for OIDC Authentication using Authlib.

View File

@@ -0,0 +1,44 @@
import pytest
from datetime import timedelta
from ea_chatbot.auth import OIDCSession
from ea_chatbot.config import Settings
@pytest.fixture
def settings():
return Settings()
def test_oidc_session_encrypt_decrypt(settings):
session_data = {
"state": "test_state",
"nonce": "test_nonce",
"code_verifier": "test_verifier"
}
# Encrypt
token = OIDCSession.encrypt(session_data, settings.secret_key)
assert isinstance(token, str)
assert token != ""
# Decrypt
decrypted_data = OIDCSession.decrypt(token, settings.secret_key)
assert decrypted_data["state"] == "test_state"
assert decrypted_data["nonce"] == "test_nonce"
assert decrypted_data["code_verifier"] == "test_verifier"
def test_oidc_session_invalid_signature(settings):
session_data = {"state": "test_state"}
token = OIDCSession.encrypt(session_data, settings.secret_key)
# Tamper with the token
tampered_token = token[:-5] + "aaaaa"
decrypted_data = OIDCSession.decrypt(tampered_token, settings.secret_key)
assert decrypted_data is None
def test_oidc_session_expired(settings):
session_data = {"state": "test_state"}
# Encrypt with a very short expiration
token = OIDCSession.encrypt(session_data, settings.secret_key, expires_delta=timedelta(seconds=-1))
decrypted_data = OIDCSession.decrypt(token, settings.secret_key)
assert decrypted_data is None