refactor(auth): Use user_id as JWT sub and implement get_user_by_id

Switched from username to user_id as the primary identifier in JWT tokens to better support external authentication providers. Added get_user_by_id to HistoryManager and updated API dependencies and tests to reflect these changes.
This commit is contained in:
Yunxiao Xu
2026-02-11 16:41:27 -08:00
parent ceddacf9cb
commit b23fbce8d0
9 changed files with 31 additions and 15 deletions

View File

@@ -35,11 +35,11 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
if payload is None: if payload is None:
raise credentials_exception raise credentials_exception
username: str | None = payload.get("sub") user_id: str | None = payload.get("sub")
if username is None: if user_id is None:
raise credentials_exception raise credentials_exception
user = history_manager.get_user(username) user = history_manager.get_user_by_id(user_id)
if user is None: if user is None:
raise credentials_exception raise credentials_exception

View File

@@ -39,7 +39,7 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()):
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
access_token = create_access_token(data={"sub": user.username, "user_id": str(user.id)}) access_token = create_access_token(data={"sub": str(user.id)})
return {"access_token": access_token, "token_type": "bearer"} return {"access_token": access_token, "token_type": "bearer"}
@router.get("/oidc/login") @router.get("/oidc/login")
@@ -71,7 +71,7 @@ async def oidc_callback(code: str):
user = history_manager.sync_user_from_oidc(email=email, display_name=name) user = history_manager.sync_user_from_oidc(email=email, display_name=name)
access_token = create_access_token(data={"sub": user.username, "user_id": str(user.id)}) access_token = create_access_token(data={"sub": str(user.id)})
return {"access_token": access_token, "token_type": "bearer"} return {"access_token": access_token, "token_type": "bearer"}
except Exception as e: except Exception as e:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"OIDC authentication failed: {str(e)}") raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"OIDC authentication failed: {str(e)}")

View File

@@ -19,12 +19,18 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
str: The encoded JWT token. str: The encoded JWT token.
""" """
to_encode = data.copy() to_encode = data.copy()
if expires_delta: now = datetime.now(timezone.utc)
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.access_token_expire_minutes)
to_encode.update({"exp": expire}) if expires_delta:
expire = now + expires_delta
else:
expire = now + timedelta(minutes=settings.access_token_expire_minutes)
to_encode.update({
"exp": expire,
"iat": now,
"iss": "ea-chatbot-api"
})
encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm) encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm)
return encoded_jwt return encoded_jwt

View File

@@ -39,6 +39,11 @@ class HistoryManager:
result = session.execute(select(User).where(User.username == email)) result = session.execute(select(User).where(User.username == email))
return result.scalar_one_or_none() return result.scalar_one_or_none()
def get_user_by_id(self, user_id: str) -> Optional[User]:
"""Fetch a user by their ID."""
with self.get_session() as session:
return session.get(User, user_id)
def create_user(self, email: str, password: Optional[str] = None, display_name: Optional[str] = None) -> User: def create_user(self, email: str, password: Optional[str] = None, display_name: Optional[str] = None) -> User:
"""Create a new local user.""" """Create a new local user."""
hashed_password = ph.hash(password) if password else None hashed_password = ph.hash(password) if password else None

View File

@@ -15,7 +15,7 @@ def mock_user():
@pytest.fixture @pytest.fixture
def auth_header(mock_user): def auth_header(mock_user):
app.dependency_overrides[get_current_user] = lambda: mock_user app.dependency_overrides[get_current_user] = lambda: mock_user
token = create_access_token(data={"sub": mock_user.username, "user_id": mock_user.id}) token = create_access_token(data={"sub": mock_user.id})
yield {"Authorization": f"Bearer {token}"} yield {"Authorization": f"Bearer {token}"}
app.dependency_overrides.clear() app.dependency_overrides.clear()

View File

@@ -92,10 +92,10 @@ def test_oidc_callback_success():
def test_get_me_success(): def test_get_me_success():
"""Test getting current user with a valid token.""" """Test getting current user with a valid token."""
from ea_chatbot.api.utils import create_access_token from ea_chatbot.api.utils import create_access_token
token = create_access_token(data={"sub": "test@example.com", "user_id": "123"}) token = create_access_token(data={"sub": "123"})
with patch("ea_chatbot.api.dependencies.history_manager") as mock_hm: with patch("ea_chatbot.api.dependencies.history_manager") as mock_hm:
mock_hm.get_user.return_value = User(id="123", username="test@example.com", display_name="Test") mock_hm.get_user_by_id.return_value = User(id="123", username="test@example.com", display_name="Test")
response = client.get( response = client.get(
"/auth/me", "/auth/me",

View File

@@ -18,7 +18,7 @@ def mock_user():
def auth_header(mock_user): def auth_header(mock_user):
# Override get_current_user to return our mock user # Override get_current_user to return our mock user
app.dependency_overrides[get_current_user] = lambda: mock_user app.dependency_overrides[get_current_user] = lambda: mock_user
token = create_access_token(data={"sub": mock_user.username, "user_id": mock_user.id}) token = create_access_token(data={"sub": mock_user.id})
yield {"Authorization": f"Bearer {token}"} yield {"Authorization": f"Bearer {token}"}
app.dependency_overrides.clear() app.dependency_overrides.clear()

View File

@@ -17,7 +17,7 @@ def mock_user():
@pytest.fixture @pytest.fixture
def auth_header(mock_user): def auth_header(mock_user):
app.dependency_overrides[get_current_user] = lambda: mock_user app.dependency_overrides[get_current_user] = lambda: mock_user
token = create_access_token(data={"sub": mock_user.username, "user_id": mock_user.id}) token = create_access_token(data={"sub": mock_user.id})
yield {"Authorization": f"Bearer {token}"} yield {"Authorization": f"Bearer {token}"}
app.dependency_overrides.clear() app.dependency_overrides.clear()

View File

@@ -25,6 +25,11 @@ def test_full_history_workflow(history_manager):
assert user is not None assert user is not None
assert user.display_name == "E2E User" assert user.display_name == "E2E User"
# 1.1 Verify get_user_by_id
fetched_user = history_manager.get_user_by_id(user.id)
assert fetched_user is not None
assert fetched_user.username == email
# 2. Create Conversation # 2. Create Conversation
conv = history_manager.create_conversation(user.id, "nj", "Test Analytics") conv = history_manager.create_conversation(user.id, "nj", "Test Analytics")
assert conv.id is not None assert conv.id is not None