fix(backend): Refactor OIDC callback and auth dependency to correctly handle cookies and prefix all API routes with /api/v1.
This commit is contained in:
@@ -25,9 +25,8 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth/login", auto_error=Fa
|
|||||||
|
|
||||||
async def get_current_user(request: Request, token: str = Depends(oauth2_scheme)) -> User:
|
async def get_current_user(request: Request, token: str = Depends(oauth2_scheme)) -> User:
|
||||||
"""Dependency to get the current authenticated user from the JWT token (cookie or header)."""
|
"""Dependency to get the current authenticated user from the JWT token (cookie or header)."""
|
||||||
# Try getting token from cookie first
|
# Prioritize cookie, fallback to header
|
||||||
if not token:
|
token = request.cookies.get("access_token") or token
|
||||||
token = request.cookies.get("access_token")
|
|
||||||
|
|
||||||
credentials_exception = HTTPException(
|
credentials_exception = HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, status, Response
|
from fastapi import APIRouter, Depends, HTTPException, status, Response, Request
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from ea_chatbot.api.utils import create_access_token
|
from ea_chatbot.api.utils import create_access_token
|
||||||
@@ -6,6 +6,9 @@ from ea_chatbot.api.dependencies import history_manager, oidc_client, get_curren
|
|||||||
from ea_chatbot.history.models import User as UserDB
|
from ea_chatbot.history.models import User as UserDB
|
||||||
from ea_chatbot.api.schemas import Token, UserCreate, UserResponse
|
from ea_chatbot.api.schemas import Token, UserCreate, UserResponse
|
||||||
import os
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
@@ -80,9 +83,9 @@ async def oidc_login():
|
|||||||
url = oidc_client.get_login_url()
|
url = oidc_client.get_login_url()
|
||||||
return {"url": url}
|
return {"url": url}
|
||||||
|
|
||||||
@router.get("/oidc/callback")
|
@router.get("/oidc/callback", response_model=Token)
|
||||||
async def oidc_callback(code: str):
|
async def oidc_callback(request: Request, response: Response, code: str):
|
||||||
"""Handle the OIDC callback, issue a JWT, and redirect to frontend."""
|
"""Handle the OIDC callback, issue a JWT, and redirect or return JSON."""
|
||||||
if not oidc_client:
|
if not oidc_client:
|
||||||
raise HTTPException(status_code=status.HTTP_510_NOT_EXTENDED, detail="OIDC not configured")
|
raise HTTPException(status_code=status.HTTP_510_NOT_EXTENDED, detail="OIDC not configured")
|
||||||
|
|
||||||
@@ -99,12 +102,29 @@ async def oidc_callback(code: str):
|
|||||||
|
|
||||||
access_token = create_access_token(data={"sub": str(user.id)})
|
access_token = create_access_token(data={"sub": str(user.id)})
|
||||||
|
|
||||||
response = RedirectResponse(url=f"{FRONTEND_URL}/auth/callback")
|
# Determine if we should redirect (direct provider callback) or return JSON (frontend exchange)
|
||||||
set_auth_cookie(response, access_token)
|
is_ajax = request.headers.get("X-Requested-With") == "XMLHttpRequest" or \
|
||||||
return response
|
"application/json" in request.headers.get("Accept", "")
|
||||||
|
|
||||||
|
if is_ajax:
|
||||||
|
set_auth_cookie(response, access_token)
|
||||||
|
return {"access_token": access_token, "token_type": "bearer"}
|
||||||
|
else:
|
||||||
|
redirect_response = RedirectResponse(url=f"{FRONTEND_URL}/auth/callback")
|
||||||
|
set_auth_cookie(redirect_response, access_token)
|
||||||
|
return redirect_response
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Redirect to frontend with error
|
logger.exception("OIDC authentication failed")
|
||||||
return RedirectResponse(url=f"{FRONTEND_URL}?error=oidc_failed")
|
# For non-ajax, redirect with error
|
||||||
|
is_ajax = request.headers.get("X-Requested-With") == "XMLHttpRequest" or \
|
||||||
|
"application/json" in request.headers.get("Accept", "")
|
||||||
|
if is_ajax:
|
||||||
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="OIDC authentication failed")
|
||||||
|
else:
|
||||||
|
return RedirectResponse(url=f"{FRONTEND_URL}?error=oidc_failed")
|
||||||
|
|
||||||
@router.get("/me", response_model=UserResponse)
|
@router.get("/me", response_model=UserResponse)
|
||||||
async def get_me(current_user: UserDB = Depends(get_current_user)):
|
async def get_me(current_user: UserDB = Depends(get_current_user)):
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ def auth_header(mock_user):
|
|||||||
|
|
||||||
def test_stream_agent_unauthorized():
|
def test_stream_agent_unauthorized():
|
||||||
"""Test that streaming requires authentication."""
|
"""Test that streaming requires authentication."""
|
||||||
response = client.post("/chat/stream", json={"message": "hello"})
|
response = client.post("/api/v1/chat/stream", json={"message": "hello"})
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -52,7 +52,7 @@ async def test_stream_agent_success(auth_header, mock_user):
|
|||||||
mock_session.get.return_value = mock_conv
|
mock_session.get.return_value = mock_conv
|
||||||
|
|
||||||
# Using TestClient with a stream context
|
# Using TestClient with a stream context
|
||||||
with client.stream("POST", "/chat/stream",
|
with client.stream("POST", "/api/v1/chat/stream",
|
||||||
json={"message": "hello", "thread_id": "t1"},
|
json={"message": "hello", "thread_id": "t1"},
|
||||||
headers=auth_header) as response:
|
headers=auth_header) as response:
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ def test_register_user_success():
|
|||||||
mock_hm.create_user.return_value = User(id="1", username="new@example.com", display_name="New")
|
mock_hm.create_user.return_value = User(id="1", username="new@example.com", display_name="New")
|
||||||
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/auth/register",
|
"/api/v1/auth/register",
|
||||||
json={"email": "new@example.com", "password": "password123", "display_name": "New"}
|
json={"email": "new@example.com", "password": "password123", "display_name": "New"}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -39,7 +39,7 @@ def test_login_success():
|
|||||||
mock_hm.authenticate_user.return_value = User(id="1", username="test@example.com")
|
mock_hm.authenticate_user.return_value = User(id="1", username="test@example.com")
|
||||||
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/auth/login",
|
"/api/v1/auth/login",
|
||||||
data={"username": "test@example.com", "password": "password123"}
|
data={"username": "test@example.com", "password": "password123"}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -53,7 +53,7 @@ def test_login_invalid_credentials():
|
|||||||
mock_hm.authenticate_user.return_value = None
|
mock_hm.authenticate_user.return_value = None
|
||||||
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/auth/login",
|
"/api/v1/auth/login",
|
||||||
data={"username": "test@example.com", "password": "wrongpassword"}
|
data={"username": "test@example.com", "password": "wrongpassword"}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -62,7 +62,7 @@ def test_login_invalid_credentials():
|
|||||||
|
|
||||||
def test_protected_route_without_token():
|
def test_protected_route_without_token():
|
||||||
"""Test that protected routes require a token."""
|
"""Test that protected routes require a token."""
|
||||||
response = client.get("/auth/me")
|
response = client.get("/api/v1/auth/me")
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
|
|
||||||
def test_oidc_login_redirect():
|
def test_oidc_login_redirect():
|
||||||
@@ -70,12 +70,12 @@ def test_oidc_login_redirect():
|
|||||||
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc:
|
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc:
|
||||||
mock_oidc.get_login_url.return_value = "https://oidc-provider.com/auth"
|
mock_oidc.get_login_url.return_value = "https://oidc-provider.com/auth"
|
||||||
|
|
||||||
response = client.get("/auth/oidc/login")
|
response = client.get("/api/v1/auth/oidc/login")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["url"] == "https://oidc-provider.com/auth"
|
assert response.json()["url"] == "https://oidc-provider.com/auth"
|
||||||
|
|
||||||
def test_oidc_callback_success():
|
def test_oidc_callback_success_ajax():
|
||||||
"""Test successful OIDC callback and JWT issuance."""
|
"""Test successful OIDC callback and JWT issuance via AJAX."""
|
||||||
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc, \
|
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc, \
|
||||||
patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm:
|
patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm:
|
||||||
|
|
||||||
@@ -83,7 +83,10 @@ def test_oidc_callback_success():
|
|||||||
mock_oidc.get_user_info.return_value = {"email": "sso@example.com", "name": "SSO User"}
|
mock_oidc.get_user_info.return_value = {"email": "sso@example.com", "name": "SSO User"}
|
||||||
mock_hm.sync_user_from_oidc.return_value = User(id="sso-123", username="sso@example.com", display_name="SSO User")
|
mock_hm.sync_user_from_oidc.return_value = User(id="sso-123", username="sso@example.com", display_name="SSO User")
|
||||||
|
|
||||||
response = client.get("/auth/oidc/callback?code=some-code")
|
response = client.get(
|
||||||
|
"/api/v1/auth/oidc/callback?code=some-code",
|
||||||
|
headers={"Accept": "application/json"}
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert "access_token" in response.json()
|
assert "access_token" in response.json()
|
||||||
@@ -98,7 +101,7 @@ def test_get_me_success():
|
|||||||
mock_hm.get_user_by_id.return_value = User(id="123", username="test@example.com", display_name="Test")
|
mock_hm.get_user_by_id.return_value = User(id="123", username="test@example.com", display_name="Test")
|
||||||
|
|
||||||
response = client.get(
|
response = client.get(
|
||||||
"/auth/me",
|
"/api/v1/auth/me",
|
||||||
headers={"Authorization": f"Bearer {token}"}
|
headers={"Authorization": f"Bearer {token}"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ def test_get_conversations_success(auth_header, mock_user):
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
response = client.get("/conversations", headers=auth_header)
|
response = client.get("/api/v1/conversations", headers=auth_header)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert len(response.json()) == 1
|
assert len(response.json()) == 1
|
||||||
@@ -53,7 +53,7 @@ def test_create_conversation_success(auth_header, mock_user):
|
|||||||
)
|
)
|
||||||
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/conversations",
|
"/api/v1/conversations",
|
||||||
json={"name": "New Conv"},
|
json={"name": "New Conv"},
|
||||||
headers=auth_header
|
headers=auth_header
|
||||||
)
|
)
|
||||||
@@ -75,7 +75,7 @@ def test_get_messages_success(auth_header):
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
response = client.get("/conversations/c1/messages", headers=auth_header)
|
response = client.get("/api/v1/conversations/c1/messages", headers=auth_header)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert len(response.json()) == 1
|
assert len(response.json()) == 1
|
||||||
@@ -86,7 +86,7 @@ def test_delete_conversation_success(auth_header):
|
|||||||
with patch("ea_chatbot.api.routers.history.history_manager") as mock_hm:
|
with patch("ea_chatbot.api.routers.history.history_manager") as mock_hm:
|
||||||
mock_hm.delete_conversation.return_value = True
|
mock_hm.delete_conversation.return_value = True
|
||||||
|
|
||||||
response = client.delete("/conversations/c1", headers=auth_header)
|
response = client.delete("/api/v1/conversations/c1", headers=auth_header)
|
||||||
assert response.status_code == 204
|
assert response.status_code == 204
|
||||||
|
|
||||||
def test_get_plot_success(auth_header, mock_user):
|
def test_get_plot_success(auth_header, mock_user):
|
||||||
@@ -109,7 +109,7 @@ def test_get_plot_success(auth_header, mock_user):
|
|||||||
|
|
||||||
mock_session.get.side_effect = mock_get
|
mock_session.get.side_effect = mock_get
|
||||||
|
|
||||||
response = client.get("/artifacts/plots/p1", headers=auth_header)
|
response = client.get("/api/v1/artifacts/plots/p1", headers=auth_header)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.content == b"fake-image-data"
|
assert response.content == b"fake-image-data"
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ def test_persistence_integration_success(auth_header, mock_user):
|
|||||||
mock_session.get.return_value = mock_conv
|
mock_session.get.return_value = mock_conv
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
with client.stream("POST", "/chat/stream",
|
with client.stream("POST", "/api/v1/chat/stream",
|
||||||
json={"message": "persistence test", "thread_id": "t1"},
|
json={"message": "persistence test", "thread_id": "t1"},
|
||||||
headers=auth_header) as response:
|
headers=auth_header) as response:
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|||||||
Reference in New Issue
Block a user