diff --git a/backend/src/ea_chatbot/api/dependencies.py b/backend/src/ea_chatbot/api/dependencies.py index e3d2b9f..69d93c8 100644 --- a/backend/src/ea_chatbot/api/dependencies.py +++ b/backend/src/ea_chatbot/api/dependencies.py @@ -25,9 +25,8 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth/login", auto_error=Fa async def get_current_user(request: Request, token: str = Depends(oauth2_scheme)) -> User: """Dependency to get the current authenticated user from the JWT token (cookie or header).""" - # Try getting token from cookie first - if not token: - token = request.cookies.get("access_token") + # Prioritize cookie, fallback to header + token = request.cookies.get("access_token") or token credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/backend/src/ea_chatbot/api/routers/auth.py b/backend/src/ea_chatbot/api/routers/auth.py index d6792e7..65ca3e6 100644 --- a/backend/src/ea_chatbot/api/routers/auth.py +++ b/backend/src/ea_chatbot/api/routers/auth.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, Depends, HTTPException, status, Response +from fastapi import APIRouter, Depends, HTTPException, status, Response, Request from fastapi.responses import RedirectResponse from fastapi.security import OAuth2PasswordRequestForm from ea_chatbot.api.utils import create_access_token @@ -6,6 +6,9 @@ from ea_chatbot.api.dependencies import history_manager, oidc_client, get_curren from ea_chatbot.history.models import User as UserDB from ea_chatbot.api.schemas import Token, UserCreate, UserResponse import os +import logging + +logger = logging.getLogger(__name__) router = APIRouter(prefix="/auth", tags=["auth"]) @@ -80,9 +83,9 @@ async def oidc_login(): url = oidc_client.get_login_url() return {"url": url} -@router.get("/oidc/callback") -async def oidc_callback(code: str): - """Handle the OIDC callback, issue a JWT, and redirect to frontend.""" +@router.get("/oidc/callback", response_model=Token) +async def oidc_callback(request: Request, response: Response, code: str): + """Handle the OIDC callback, issue a JWT, and redirect or return JSON.""" if not oidc_client: raise HTTPException(status_code=status.HTTP_510_NOT_EXTENDED, detail="OIDC not configured") @@ -99,12 +102,29 @@ async def oidc_callback(code: str): access_token = create_access_token(data={"sub": str(user.id)}) - response = RedirectResponse(url=f"{FRONTEND_URL}/auth/callback") - set_auth_cookie(response, access_token) - return response + # Determine if we should redirect (direct provider callback) or return JSON (frontend exchange) + is_ajax = request.headers.get("X-Requested-With") == "XMLHttpRequest" or \ + "application/json" in request.headers.get("Accept", "") + + if is_ajax: + set_auth_cookie(response, access_token) + return {"access_token": access_token, "token_type": "bearer"} + else: + redirect_response = RedirectResponse(url=f"{FRONTEND_URL}/auth/callback") + set_auth_cookie(redirect_response, access_token) + return redirect_response + + except HTTPException: + raise except Exception as e: - # Redirect to frontend with error - return RedirectResponse(url=f"{FRONTEND_URL}?error=oidc_failed") + logger.exception("OIDC authentication failed") + # For non-ajax, redirect with error + is_ajax = request.headers.get("X-Requested-With") == "XMLHttpRequest" or \ + "application/json" in request.headers.get("Accept", "") + if is_ajax: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="OIDC authentication failed") + else: + return RedirectResponse(url=f"{FRONTEND_URL}?error=oidc_failed") @router.get("/me", response_model=UserResponse) async def get_me(current_user: UserDB = Depends(get_current_user)): diff --git a/backend/tests/api/test_agent.py b/backend/tests/api/test_agent.py index 09fc007..aef9e79 100644 --- a/backend/tests/api/test_agent.py +++ b/backend/tests/api/test_agent.py @@ -21,7 +21,7 @@ def auth_header(mock_user): def test_stream_agent_unauthorized(): """Test that streaming requires authentication.""" - response = client.post("/chat/stream", json={"message": "hello"}) + response = client.post("/api/v1/chat/stream", json={"message": "hello"}) assert response.status_code == 401 @pytest.mark.asyncio @@ -52,7 +52,7 @@ async def test_stream_agent_success(auth_header, mock_user): mock_session.get.return_value = mock_conv # Using TestClient with a stream context - with client.stream("POST", "/chat/stream", + with client.stream("POST", "/api/v1/chat/stream", json={"message": "hello", "thread_id": "t1"}, headers=auth_header) as response: assert response.status_code == 200 diff --git a/backend/tests/api/test_api_auth.py b/backend/tests/api/test_api_auth.py index 9b8d94c..799da97 100644 --- a/backend/tests/api/test_api_auth.py +++ b/backend/tests/api/test_api_auth.py @@ -26,7 +26,7 @@ def test_register_user_success(): mock_hm.create_user.return_value = User(id="1", username="new@example.com", display_name="New") response = client.post( - "/auth/register", + "/api/v1/auth/register", json={"email": "new@example.com", "password": "password123", "display_name": "New"} ) @@ -39,7 +39,7 @@ def test_login_success(): mock_hm.authenticate_user.return_value = User(id="1", username="test@example.com") response = client.post( - "/auth/login", + "/api/v1/auth/login", data={"username": "test@example.com", "password": "password123"} ) @@ -53,7 +53,7 @@ def test_login_invalid_credentials(): mock_hm.authenticate_user.return_value = None response = client.post( - "/auth/login", + "/api/v1/auth/login", data={"username": "test@example.com", "password": "wrongpassword"} ) @@ -62,7 +62,7 @@ def test_login_invalid_credentials(): def test_protected_route_without_token(): """Test that protected routes require a token.""" - response = client.get("/auth/me") + response = client.get("/api/v1/auth/me") assert response.status_code == 401 def test_oidc_login_redirect(): @@ -70,12 +70,12 @@ def test_oidc_login_redirect(): with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc: mock_oidc.get_login_url.return_value = "https://oidc-provider.com/auth" - response = client.get("/auth/oidc/login") + response = client.get("/api/v1/auth/oidc/login") assert response.status_code == 200 assert response.json()["url"] == "https://oidc-provider.com/auth" -def test_oidc_callback_success(): - """Test successful OIDC callback and JWT issuance.""" +def test_oidc_callback_success_ajax(): + """Test successful OIDC callback and JWT issuance via AJAX.""" with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc, \ patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm: @@ -83,7 +83,10 @@ def test_oidc_callback_success(): mock_oidc.get_user_info.return_value = {"email": "sso@example.com", "name": "SSO User"} mock_hm.sync_user_from_oidc.return_value = User(id="sso-123", username="sso@example.com", display_name="SSO User") - response = client.get("/auth/oidc/callback?code=some-code") + response = client.get( + "/api/v1/auth/oidc/callback?code=some-code", + headers={"Accept": "application/json"} + ) assert response.status_code == 200 assert "access_token" in response.json() @@ -98,7 +101,7 @@ def test_get_me_success(): mock_hm.get_user_by_id.return_value = User(id="123", username="test@example.com", display_name="Test") response = client.get( - "/auth/me", + "/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"} ) diff --git a/backend/tests/api/test_api_history.py b/backend/tests/api/test_api_history.py index d322e44..e0b3117 100644 --- a/backend/tests/api/test_api_history.py +++ b/backend/tests/api/test_api_history.py @@ -35,7 +35,7 @@ def test_get_conversations_success(auth_header, mock_user): ) ] - response = client.get("/conversations", headers=auth_header) + response = client.get("/api/v1/conversations", headers=auth_header) assert response.status_code == 200 assert len(response.json()) == 1 @@ -53,7 +53,7 @@ def test_create_conversation_success(auth_header, mock_user): ) response = client.post( - "/conversations", + "/api/v1/conversations", json={"name": "New Conv"}, headers=auth_header ) @@ -75,7 +75,7 @@ def test_get_messages_success(auth_header): ) ] - response = client.get("/conversations/c1/messages", headers=auth_header) + response = client.get("/api/v1/conversations/c1/messages", headers=auth_header) assert response.status_code == 200 assert len(response.json()) == 1 @@ -86,7 +86,7 @@ def test_delete_conversation_success(auth_header): with patch("ea_chatbot.api.routers.history.history_manager") as mock_hm: mock_hm.delete_conversation.return_value = True - response = client.delete("/conversations/c1", headers=auth_header) + response = client.delete("/api/v1/conversations/c1", headers=auth_header) assert response.status_code == 204 def test_get_plot_success(auth_header, mock_user): @@ -109,7 +109,7 @@ def test_get_plot_success(auth_header, mock_user): mock_session.get.side_effect = mock_get - response = client.get("/artifacts/plots/p1", headers=auth_header) + response = client.get("/api/v1/artifacts/plots/p1", headers=auth_header) assert response.status_code == 200 assert response.content == b"fake-image-data" diff --git a/backend/tests/api/test_persistence.py b/backend/tests/api/test_persistence.py index d1fb265..0b9089e 100644 --- a/backend/tests/api/test_persistence.py +++ b/backend/tests/api/test_persistence.py @@ -46,7 +46,7 @@ def test_persistence_integration_success(auth_header, mock_user): mock_session.get.return_value = mock_conv # Act - with client.stream("POST", "/chat/stream", + with client.stream("POST", "/api/v1/chat/stream", json={"message": "persistence test", "thread_id": "t1"}, headers=auth_header) as response: assert response.status_code == 200