diff --git a/src/ea_chatbot/api/utils.py b/src/ea_chatbot/api/utils.py new file mode 100644 index 0000000..d30ba9c --- /dev/null +++ b/src/ea_chatbot/api/utils.py @@ -0,0 +1,43 @@ +from datetime import datetime, timedelta, timezone +from typing import Optional, Union +from jose import JWTError, jwt +from ea_chatbot.config import Settings + +settings = Settings() + +def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: + """ + Create a JWT access token. + + Args: + data: The payload data to encode. + expires_delta: Optional expiration time delta. + + Returns: + str: The encoded JWT token. + """ + to_encode = data.copy() + if expires_delta: + expire = datetime.now(timezone.utc) + expires_delta + else: + expire = datetime.now(timezone.utc) + timedelta(minutes=settings.access_token_expire_minutes) + + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm) + return encoded_jwt + +def decode_access_token(token: str) -> Optional[dict]: + """ + Decode a JWT access token. + + Args: + token: The token to decode. + + Returns: + Optional[dict]: The decoded payload if valid, None otherwise. + """ + try: + payload = jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm]) + return payload + except JWTError: + return None diff --git a/src/ea_chatbot/config.py b/src/ea_chatbot/config.py index 10caffd..e440938 100644 --- a/src/ea_chatbot/config.py +++ b/src/ea_chatbot/config.py @@ -28,6 +28,11 @@ class Settings(BaseSettings): # Application/History Database history_db_url: str = Field(default="postgresql://user:password@localhost:5433/ea_history", alias="HISTORY_DB_URL") + # JWT Configuration + secret_key: str = Field(default="change-me-in-production", alias="SECRET_KEY") + algorithm: str = Field(default="HS256", alias="ALGORITHM") + access_token_expire_minutes: int = Field(default=30, alias="ACCESS_TOKEN_EXPIRE_MINUTES") + # OIDC Configuration oidc_client_id: Optional[str] = Field(default=None, alias="OIDC_CLIENT_ID") oidc_client_secret: Optional[str] = Field(default=None, alias="OIDC_CLIENT_SECRET") diff --git a/tests/api/test_utils.py b/tests/api/test_utils.py new file mode 100644 index 0000000..9fb1d26 --- /dev/null +++ b/tests/api/test_utils.py @@ -0,0 +1,24 @@ +from datetime import timedelta +from ea_chatbot.api.utils import create_access_token, decode_access_token + +def test_create_and_decode_access_token(): + """Test that a token can be created and then decoded.""" + data = {"sub": "test@example.com", "user_id": "123"} + token = create_access_token(data) + + decoded = decode_access_token(token) + assert decoded["sub"] == data["sub"] + assert decoded["user_id"] == data["user_id"] + assert "exp" in decoded + +def test_decode_invalid_token(): + """Test that an invalid token returns None.""" + assert decode_access_token("invalid-token") is None + +def test_expired_token(): + """Test that an expired token returns None.""" + data = {"sub": "test@example.com"} + # Create a token that expired 1 minute ago + token = create_access_token(data, expires_delta=timedelta(minutes=-1)) + + assert decode_access_token(token) is None