diff --git a/src/ea_chatbot/api/dependencies.py b/src/ea_chatbot/api/dependencies.py index 84fdddd..fe0f533 100644 --- a/src/ea_chatbot/api/dependencies.py +++ b/src/ea_chatbot/api/dependencies.py @@ -35,11 +35,11 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> User: if payload is None: raise credentials_exception - username: str | None = payload.get("sub") - if username is None: + user_id: str | None = payload.get("sub") + if user_id is None: raise credentials_exception - user = history_manager.get_user(username) + user = history_manager.get_user_by_id(user_id) if user is None: raise credentials_exception diff --git a/src/ea_chatbot/api/routers/auth.py b/src/ea_chatbot/api/routers/auth.py index 28d4953..96beb9c 100644 --- a/src/ea_chatbot/api/routers/auth.py +++ b/src/ea_chatbot/api/routers/auth.py @@ -39,7 +39,7 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()): headers={"WWW-Authenticate": "Bearer"}, ) - access_token = create_access_token(data={"sub": user.username, "user_id": str(user.id)}) + access_token = create_access_token(data={"sub": str(user.id)}) return {"access_token": access_token, "token_type": "bearer"} @router.get("/oidc/login") @@ -71,7 +71,7 @@ async def oidc_callback(code: str): user = history_manager.sync_user_from_oidc(email=email, display_name=name) - access_token = create_access_token(data={"sub": user.username, "user_id": str(user.id)}) + access_token = create_access_token(data={"sub": str(user.id)}) return {"access_token": access_token, "token_type": "bearer"} except Exception as e: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"OIDC authentication failed: {str(e)}") diff --git a/src/ea_chatbot/api/utils.py b/src/ea_chatbot/api/utils.py index a775095..4019dcd 100644 --- a/src/ea_chatbot/api/utils.py +++ b/src/ea_chatbot/api/utils.py @@ -19,12 +19,18 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) - str: The encoded JWT token. """ to_encode = data.copy() - if expires_delta: - expire = datetime.now(timezone.utc) + expires_delta - else: - expire = datetime.now(timezone.utc) + timedelta(minutes=settings.access_token_expire_minutes) + now = datetime.now(timezone.utc) - to_encode.update({"exp": expire}) + if expires_delta: + expire = now + expires_delta + else: + expire = now + timedelta(minutes=settings.access_token_expire_minutes) + + to_encode.update({ + "exp": expire, + "iat": now, + "iss": "ea-chatbot-api" + }) encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm) return encoded_jwt diff --git a/src/ea_chatbot/history/manager.py b/src/ea_chatbot/history/manager.py index aa3c06e..6adc8a1 100644 --- a/src/ea_chatbot/history/manager.py +++ b/src/ea_chatbot/history/manager.py @@ -39,6 +39,11 @@ class HistoryManager: result = session.execute(select(User).where(User.username == email)) return result.scalar_one_or_none() + def get_user_by_id(self, user_id: str) -> Optional[User]: + """Fetch a user by their ID.""" + with self.get_session() as session: + return session.get(User, user_id) + def create_user(self, email: str, password: Optional[str] = None, display_name: Optional[str] = None) -> User: """Create a new local user.""" hashed_password = ph.hash(password) if password else None diff --git a/tests/api/test_agent.py b/tests/api/test_agent.py index ae4f21b..09fc007 100644 --- a/tests/api/test_agent.py +++ b/tests/api/test_agent.py @@ -15,7 +15,7 @@ def mock_user(): @pytest.fixture def auth_header(mock_user): app.dependency_overrides[get_current_user] = lambda: mock_user - token = create_access_token(data={"sub": mock_user.username, "user_id": mock_user.id}) + token = create_access_token(data={"sub": mock_user.id}) yield {"Authorization": f"Bearer {token}"} app.dependency_overrides.clear() diff --git a/tests/api/test_api_auth.py b/tests/api/test_api_auth.py index a97dac7..9b8d94c 100644 --- a/tests/api/test_api_auth.py +++ b/tests/api/test_api_auth.py @@ -92,10 +92,10 @@ def test_oidc_callback_success(): def test_get_me_success(): """Test getting current user with a valid token.""" from ea_chatbot.api.utils import create_access_token - token = create_access_token(data={"sub": "test@example.com", "user_id": "123"}) + token = create_access_token(data={"sub": "123"}) with patch("ea_chatbot.api.dependencies.history_manager") as mock_hm: - mock_hm.get_user.return_value = User(id="123", username="test@example.com", display_name="Test") + mock_hm.get_user_by_id.return_value = User(id="123", username="test@example.com", display_name="Test") response = client.get( "/auth/me", diff --git a/tests/api/test_api_history.py b/tests/api/test_api_history.py index 5e0a76b..d322e44 100644 --- a/tests/api/test_api_history.py +++ b/tests/api/test_api_history.py @@ -18,7 +18,7 @@ def mock_user(): def auth_header(mock_user): # Override get_current_user to return our mock user app.dependency_overrides[get_current_user] = lambda: mock_user - token = create_access_token(data={"sub": mock_user.username, "user_id": mock_user.id}) + token = create_access_token(data={"sub": mock_user.id}) yield {"Authorization": f"Bearer {token}"} app.dependency_overrides.clear() diff --git a/tests/api/test_persistence.py b/tests/api/test_persistence.py index fac2640..d1fb265 100644 --- a/tests/api/test_persistence.py +++ b/tests/api/test_persistence.py @@ -17,7 +17,7 @@ def mock_user(): @pytest.fixture def auth_header(mock_user): app.dependency_overrides[get_current_user] = lambda: mock_user - token = create_access_token(data={"sub": mock_user.username, "user_id": mock_user.id}) + token = create_access_token(data={"sub": mock_user.id}) yield {"Authorization": f"Bearer {token}"} app.dependency_overrides.clear() diff --git a/tests/test_workflow_history.py b/tests/test_workflow_history.py index ccc826a..f774f83 100644 --- a/tests/test_workflow_history.py +++ b/tests/test_workflow_history.py @@ -25,6 +25,11 @@ def test_full_history_workflow(history_manager): assert user is not None assert user.display_name == "E2E User" + # 1.1 Verify get_user_by_id + fetched_user = history_manager.get_user_by_id(user.id) + assert fetched_user is not None + assert fetched_user.username == email + # 2. Create Conversation conv = history_manager.create_conversation(user.id, "nj", "Test Analytics") assert conv.id is not None