fix(backend): Refactor OIDC callback and auth dependency to correctly handle cookies and prefix all API routes with /api/v1.

This commit is contained in:
Yunxiao Xu
2026-02-12 01:26:28 -08:00
parent 49a9da7c0c
commit 0dfdef738d
6 changed files with 51 additions and 29 deletions

View File

@@ -25,9 +25,8 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth/login", auto_error=Fa
async def get_current_user(request: Request, token: str = Depends(oauth2_scheme)) -> User: async def get_current_user(request: Request, token: str = Depends(oauth2_scheme)) -> User:
"""Dependency to get the current authenticated user from the JWT token (cookie or header).""" """Dependency to get the current authenticated user from the JWT token (cookie or header)."""
# Try getting token from cookie first # Prioritize cookie, fallback to header
if not token: token = request.cookies.get("access_token") or token
token = request.cookies.get("access_token")
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,

View File

@@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends, HTTPException, status, Response from fastapi import APIRouter, Depends, HTTPException, status, Response, Request
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from ea_chatbot.api.utils import create_access_token from ea_chatbot.api.utils import create_access_token
@@ -6,6 +6,9 @@ from ea_chatbot.api.dependencies import history_manager, oidc_client, get_curren
from ea_chatbot.history.models import User as UserDB from ea_chatbot.history.models import User as UserDB
from ea_chatbot.api.schemas import Token, UserCreate, UserResponse from ea_chatbot.api.schemas import Token, UserCreate, UserResponse
import os import os
import logging
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/auth", tags=["auth"]) router = APIRouter(prefix="/auth", tags=["auth"])
@@ -80,9 +83,9 @@ async def oidc_login():
url = oidc_client.get_login_url() url = oidc_client.get_login_url()
return {"url": url} return {"url": url}
@router.get("/oidc/callback") @router.get("/oidc/callback", response_model=Token)
async def oidc_callback(code: str): async def oidc_callback(request: Request, response: Response, code: str):
"""Handle the OIDC callback, issue a JWT, and redirect to frontend.""" """Handle the OIDC callback, issue a JWT, and redirect or return JSON."""
if not oidc_client: if not oidc_client:
raise HTTPException(status_code=status.HTTP_510_NOT_EXTENDED, detail="OIDC not configured") raise HTTPException(status_code=status.HTTP_510_NOT_EXTENDED, detail="OIDC not configured")
@@ -99,12 +102,29 @@ async def oidc_callback(code: str):
access_token = create_access_token(data={"sub": str(user.id)}) access_token = create_access_token(data={"sub": str(user.id)})
response = RedirectResponse(url=f"{FRONTEND_URL}/auth/callback") # Determine if we should redirect (direct provider callback) or return JSON (frontend exchange)
set_auth_cookie(response, access_token) is_ajax = request.headers.get("X-Requested-With") == "XMLHttpRequest" or \
return response "application/json" in request.headers.get("Accept", "")
if is_ajax:
set_auth_cookie(response, access_token)
return {"access_token": access_token, "token_type": "bearer"}
else:
redirect_response = RedirectResponse(url=f"{FRONTEND_URL}/auth/callback")
set_auth_cookie(redirect_response, access_token)
return redirect_response
except HTTPException:
raise
except Exception as e: except Exception as e:
# Redirect to frontend with error logger.exception("OIDC authentication failed")
return RedirectResponse(url=f"{FRONTEND_URL}?error=oidc_failed") # For non-ajax, redirect with error
is_ajax = request.headers.get("X-Requested-With") == "XMLHttpRequest" or \
"application/json" in request.headers.get("Accept", "")
if is_ajax:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="OIDC authentication failed")
else:
return RedirectResponse(url=f"{FRONTEND_URL}?error=oidc_failed")
@router.get("/me", response_model=UserResponse) @router.get("/me", response_model=UserResponse)
async def get_me(current_user: UserDB = Depends(get_current_user)): async def get_me(current_user: UserDB = Depends(get_current_user)):

View File

@@ -21,7 +21,7 @@ def auth_header(mock_user):
def test_stream_agent_unauthorized(): def test_stream_agent_unauthorized():
"""Test that streaming requires authentication.""" """Test that streaming requires authentication."""
response = client.post("/chat/stream", json={"message": "hello"}) response = client.post("/api/v1/chat/stream", json={"message": "hello"})
assert response.status_code == 401 assert response.status_code == 401
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -52,7 +52,7 @@ async def test_stream_agent_success(auth_header, mock_user):
mock_session.get.return_value = mock_conv mock_session.get.return_value = mock_conv
# Using TestClient with a stream context # Using TestClient with a stream context
with client.stream("POST", "/chat/stream", with client.stream("POST", "/api/v1/chat/stream",
json={"message": "hello", "thread_id": "t1"}, json={"message": "hello", "thread_id": "t1"},
headers=auth_header) as response: headers=auth_header) as response:
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -26,7 +26,7 @@ def test_register_user_success():
mock_hm.create_user.return_value = User(id="1", username="new@example.com", display_name="New") mock_hm.create_user.return_value = User(id="1", username="new@example.com", display_name="New")
response = client.post( response = client.post(
"/auth/register", "/api/v1/auth/register",
json={"email": "new@example.com", "password": "password123", "display_name": "New"} json={"email": "new@example.com", "password": "password123", "display_name": "New"}
) )
@@ -39,7 +39,7 @@ def test_login_success():
mock_hm.authenticate_user.return_value = User(id="1", username="test@example.com") mock_hm.authenticate_user.return_value = User(id="1", username="test@example.com")
response = client.post( response = client.post(
"/auth/login", "/api/v1/auth/login",
data={"username": "test@example.com", "password": "password123"} data={"username": "test@example.com", "password": "password123"}
) )
@@ -53,7 +53,7 @@ def test_login_invalid_credentials():
mock_hm.authenticate_user.return_value = None mock_hm.authenticate_user.return_value = None
response = client.post( response = client.post(
"/auth/login", "/api/v1/auth/login",
data={"username": "test@example.com", "password": "wrongpassword"} data={"username": "test@example.com", "password": "wrongpassword"}
) )
@@ -62,7 +62,7 @@ def test_login_invalid_credentials():
def test_protected_route_without_token(): def test_protected_route_without_token():
"""Test that protected routes require a token.""" """Test that protected routes require a token."""
response = client.get("/auth/me") response = client.get("/api/v1/auth/me")
assert response.status_code == 401 assert response.status_code == 401
def test_oidc_login_redirect(): def test_oidc_login_redirect():
@@ -70,12 +70,12 @@ def test_oidc_login_redirect():
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc: with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc:
mock_oidc.get_login_url.return_value = "https://oidc-provider.com/auth" mock_oidc.get_login_url.return_value = "https://oidc-provider.com/auth"
response = client.get("/auth/oidc/login") response = client.get("/api/v1/auth/oidc/login")
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["url"] == "https://oidc-provider.com/auth" assert response.json()["url"] == "https://oidc-provider.com/auth"
def test_oidc_callback_success(): def test_oidc_callback_success_ajax():
"""Test successful OIDC callback and JWT issuance.""" """Test successful OIDC callback and JWT issuance via AJAX."""
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc, \ with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc, \
patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm: patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm:
@@ -83,7 +83,10 @@ def test_oidc_callback_success():
mock_oidc.get_user_info.return_value = {"email": "sso@example.com", "name": "SSO User"} mock_oidc.get_user_info.return_value = {"email": "sso@example.com", "name": "SSO User"}
mock_hm.sync_user_from_oidc.return_value = User(id="sso-123", username="sso@example.com", display_name="SSO User") mock_hm.sync_user_from_oidc.return_value = User(id="sso-123", username="sso@example.com", display_name="SSO User")
response = client.get("/auth/oidc/callback?code=some-code") response = client.get(
"/api/v1/auth/oidc/callback?code=some-code",
headers={"Accept": "application/json"}
)
assert response.status_code == 200 assert response.status_code == 200
assert "access_token" in response.json() assert "access_token" in response.json()
@@ -98,7 +101,7 @@ def test_get_me_success():
mock_hm.get_user_by_id.return_value = User(id="123", username="test@example.com", display_name="Test") mock_hm.get_user_by_id.return_value = User(id="123", username="test@example.com", display_name="Test")
response = client.get( response = client.get(
"/auth/me", "/api/v1/auth/me",
headers={"Authorization": f"Bearer {token}"} headers={"Authorization": f"Bearer {token}"}
) )

View File

@@ -35,7 +35,7 @@ def test_get_conversations_success(auth_header, mock_user):
) )
] ]
response = client.get("/conversations", headers=auth_header) response = client.get("/api/v1/conversations", headers=auth_header)
assert response.status_code == 200 assert response.status_code == 200
assert len(response.json()) == 1 assert len(response.json()) == 1
@@ -53,7 +53,7 @@ def test_create_conversation_success(auth_header, mock_user):
) )
response = client.post( response = client.post(
"/conversations", "/api/v1/conversations",
json={"name": "New Conv"}, json={"name": "New Conv"},
headers=auth_header headers=auth_header
) )
@@ -75,7 +75,7 @@ def test_get_messages_success(auth_header):
) )
] ]
response = client.get("/conversations/c1/messages", headers=auth_header) response = client.get("/api/v1/conversations/c1/messages", headers=auth_header)
assert response.status_code == 200 assert response.status_code == 200
assert len(response.json()) == 1 assert len(response.json()) == 1
@@ -86,7 +86,7 @@ def test_delete_conversation_success(auth_header):
with patch("ea_chatbot.api.routers.history.history_manager") as mock_hm: with patch("ea_chatbot.api.routers.history.history_manager") as mock_hm:
mock_hm.delete_conversation.return_value = True mock_hm.delete_conversation.return_value = True
response = client.delete("/conversations/c1", headers=auth_header) response = client.delete("/api/v1/conversations/c1", headers=auth_header)
assert response.status_code == 204 assert response.status_code == 204
def test_get_plot_success(auth_header, mock_user): def test_get_plot_success(auth_header, mock_user):
@@ -109,7 +109,7 @@ def test_get_plot_success(auth_header, mock_user):
mock_session.get.side_effect = mock_get mock_session.get.side_effect = mock_get
response = client.get("/artifacts/plots/p1", headers=auth_header) response = client.get("/api/v1/artifacts/plots/p1", headers=auth_header)
assert response.status_code == 200 assert response.status_code == 200
assert response.content == b"fake-image-data" assert response.content == b"fake-image-data"

View File

@@ -46,7 +46,7 @@ def test_persistence_integration_success(auth_header, mock_user):
mock_session.get.return_value = mock_conv mock_session.get.return_value = mock_conv
# Act # Act
with client.stream("POST", "/chat/stream", with client.stream("POST", "/api/v1/chat/stream",
json={"message": "persistence test", "thread_id": "t1"}, json={"message": "persistence test", "thread_id": "t1"},
headers=auth_header) as response: headers=auth_header) as response:
assert response.status_code == 200 assert response.status_code == 200