From 68c09854823fed85db2e7ad7078d5232c2f2d24f Mon Sep 17 00:00:00 2001 From: Yunxiao Xu Date: Sun, 15 Feb 2026 02:50:26 -0800 Subject: [PATCH] feat(auth): Complete OIDC security refactor and modernize test suite - Refactored OIDC flow to implement PKCE, state/nonce validation, and BFF pattern. - Centralized configuration in Settings class (DEV_MODE, FRONTEND_URL, OIDC_REDIRECT_URI). - Updated auth routers to use conditional secure cookie flags based on DEV_MODE. - Modernized and cleaned up test suite by removing legacy Streamlit tests. - Fixed linting errors and unused imports across the backend. --- GEMINI.md | 11 ++- backend/.env.example | 5 +- backend/src/ea_chatbot/api/dependencies.py | 3 +- backend/src/ea_chatbot/api/main.py | 3 - backend/src/ea_chatbot/api/routers/agent.py | 5 +- .../src/ea_chatbot/api/routers/artifacts.py | 1 - backend/src/ea_chatbot/api/routers/auth.py | 99 ++++++++++++------- backend/src/ea_chatbot/api/utils.py | 2 +- backend/src/ea_chatbot/app.py | 3 +- backend/src/ea_chatbot/config.py | 3 + .../ea_chatbot/graph/nodes/clarification.py | 1 - backend/src/ea_chatbot/graph/nodes/coder.py | 2 +- .../src/ea_chatbot/graph/nodes/executor.py | 2 +- .../src/ea_chatbot/graph/nodes/researcher.py | 1 - .../src/ea_chatbot/graph/nodes/summarizer.py | 1 - backend/src/ea_chatbot/graph/state.py | 2 +- backend/src/ea_chatbot/history/manager.py | 4 +- .../ea_chatbot/utils/database_inspection.py | 3 +- backend/src/ea_chatbot/utils/helpers.py | 3 +- backend/src/ea_chatbot/utils/llm_factory.py | 2 +- backend/tests/api/test_api_auth.py | 36 ++++--- backend/tests/api/test_api_auth_cookie.py | 24 +++-- backend/tests/api/test_oidc_flow.py | 82 +++++++++++++++ backend/tests/api/test_persistence.py | 4 +- backend/tests/graph/test_checkpoint.py | 1 - backend/tests/test_app.py | 90 ----------------- backend/tests/test_app_auth.py | 83 ---------------- backend/tests/test_auth.py | 43 -------- backend/tests/test_auth_flow.py | 49 --------- backend/tests/test_config.py | 2 - backend/tests/test_database_inspection.py | 2 +- backend/tests/test_executor.py | 1 - backend/tests/test_helpers.py | 1 - backend/tests/test_history_models.py | 2 - backend/tests/test_llm_factory_callbacks.py | 1 - backend/tests/test_logging_e2e.py | 3 +- .../tests/test_multi_turn_query_analyzer.py | 1 - backend/tests/test_oidc_client.py | 87 ---------------- backend/tests/test_oidc_validation.py | 3 +- backend/tests/test_query_analyzer.py | 1 - backend/tests/test_query_analyzer_logging.py | 1 - .../tests/test_query_analyzer_refinement.py | 1 - backend/tests/test_state.py | 5 +- backend/tests/test_summarizer.py | 1 - backend/tests/test_workflow.py | 1 - backend/tests/test_workflow_e2e.py | 1 - frontend/src/App.tsx | 6 +- frontend/src/components/auth/AuthCallback.tsx | 39 -------- frontend/src/components/auth/LoginForm.tsx | 5 +- frontend/src/services/auth.ts | 5 - 50 files changed, 222 insertions(+), 515 deletions(-) create mode 100644 backend/tests/api/test_oidc_flow.py delete mode 100644 backend/tests/test_app.py delete mode 100644 backend/tests/test_app_auth.py delete mode 100644 backend/tests/test_auth.py delete mode 100644 backend/tests/test_auth_flow.py delete mode 100644 backend/tests/test_oidc_client.py delete mode 100644 frontend/src/components/auth/AuthCallback.tsx diff --git a/GEMINI.md b/GEMINI.md index 0f36b8c..de05bf5 100644 --- a/GEMINI.md +++ b/GEMINI.md @@ -48,8 +48,9 @@ The frontend is a modern SPA (Single Page Application) designed for data-heavy i - **LangChain Docs**: See the `langchain-docs/` folder for local LangChain and LangGraph documentation. ## Git Operations -- Branches should be used for specific features or bug fixes. -- New branches should be created from the `main` branch and `conductor` branch. -- The conductor should always use the `conductor` branch and derived branches. -- When a feature or fix is complete, use rebase to keep the commit history clean before merging. -- The conductor related changes should never be merged into the `main` branch. +- All new feature and bug-fix branches must be created from the `conductor` branch except hot-fix. +- The `conductor` branch serves as the primary development branch where integration occurs. +- The `main` branch is reserved for stable, production-ready code. +- Merges from `conductor` to `main` should only occur when significant milestones are reached and stability is verified. +- Conductor-specific configuration or meta-files should remain on the `conductor` branch or its derivatives and never be merged into the `main` branch. +- Use rebase to keep commit history clean before merging feature branches back into `conductor`. diff --git a/backend/.env.example b/backend/.env.example index ac71a98..71351a6 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -6,7 +6,8 @@ GOOGLE_API_KEY=your_google_api_key_here DATA_DIR=data DATA_STATE=new_jersey LOG_LEVEL=INFO -DEV_MODE=false +DEV_MODE=true +FRONTEND_URL=http://localhost:5173 # Security & JWT Configuration SECRET_KEY=change-me-in-production @@ -28,7 +29,7 @@ HISTORY_DB_URL=postgresql://user:password@localhost:5433/ea_history OIDC_CLIENT_ID=your_client_id OIDC_CLIENT_SECRET=your_client_secret OIDC_SERVER_METADATA_URL=https://your-authentik.example.com/application/o/ea-chatbot/.well-known/openid-configuration -OIDC_REDIRECT_URI=http://localhost:8501 +OIDC_REDIRECT_URI=http://localhost:8000/api/v1/auth/oidc/callback # Node Configuration Overrides (Optional) # Format: _LLM__ diff --git a/backend/src/ea_chatbot/api/dependencies.py b/backend/src/ea_chatbot/api/dependencies.py index 69d93c8..72b5796 100644 --- a/backend/src/ea_chatbot/api/dependencies.py +++ b/backend/src/ea_chatbot/api/dependencies.py @@ -1,4 +1,3 @@ -import os from fastapi import Depends, HTTPException, status, Request from fastapi.security import OAuth2PasswordBearer from ea_chatbot.config import Settings @@ -18,7 +17,7 @@ if settings.oidc_client_id and settings.oidc_client_secret and settings.oidc_ser client_id=settings.oidc_client_id, client_secret=settings.oidc_client_secret, server_metadata_url=settings.oidc_server_metadata_url, - redirect_uri=os.getenv("OIDC_REDIRECT_URI", "http://localhost:3000/auth/callback") + redirect_uri=settings.oidc_redirect_uri ) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth/login", auto_error=False) diff --git a/backend/src/ea_chatbot/api/main.py b/backend/src/ea_chatbot/api/main.py index ec045ee..52f9f0a 100644 --- a/backend/src/ea_chatbot/api/main.py +++ b/backend/src/ea_chatbot/api/main.py @@ -1,9 +1,6 @@ from fastapi import FastAPI, APIRouter from fastapi.middleware.cors import CORSMiddleware from ea_chatbot.api.routers import auth, history, artifacts, agent -from dotenv import load_dotenv - -load_dotenv() app = FastAPI( title="Election Analytics Chatbot API", diff --git a/backend/src/ea_chatbot/api/routers/agent.py b/backend/src/ea_chatbot/api/routers/agent.py index f4c2f1a..fb55af0 100644 --- a/backend/src/ea_chatbot/api/routers/agent.py +++ b/backend/src/ea_chatbot/api/routers/agent.py @@ -1,7 +1,6 @@ import json -import asyncio -from typing import AsyncGenerator, Optional, List -from fastapi import APIRouter, Depends, HTTPException, status +from typing import AsyncGenerator, List +from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse from ea_chatbot.api.dependencies import get_current_user, history_manager from ea_chatbot.api.utils import convert_to_json_compatible diff --git a/backend/src/ea_chatbot/api/routers/artifacts.py b/backend/src/ea_chatbot/api/routers/artifacts.py index 70fd3d6..bc760ff 100644 --- a/backend/src/ea_chatbot/api/routers/artifacts.py +++ b/backend/src/ea_chatbot/api/routers/artifacts.py @@ -1,7 +1,6 @@ from fastapi import APIRouter, Depends, HTTPException, Response, status from ea_chatbot.api.dependencies import get_current_user, history_manager from ea_chatbot.history.models import Plot, Message, User as UserDB -import io router = APIRouter(prefix="/artifacts", tags=["artifacts"]) diff --git a/backend/src/ea_chatbot/api/routers/auth.py b/backend/src/ea_chatbot/api/routers/auth.py index 65ca3e6..1ea5907 100644 --- a/backend/src/ea_chatbot/api/routers/auth.py +++ b/backend/src/ea_chatbot/api/routers/auth.py @@ -1,19 +1,17 @@ from fastapi import APIRouter, Depends, HTTPException, status, Response, Request from fastapi.responses import RedirectResponse from fastapi.security import OAuth2PasswordRequestForm -from ea_chatbot.api.utils import create_access_token +from ea_chatbot.api.utils import create_access_token, settings from ea_chatbot.api.dependencies import history_manager, oidc_client, get_current_user from ea_chatbot.history.models import User as UserDB from ea_chatbot.api.schemas import Token, UserCreate, UserResponse -import os +from ea_chatbot.auth import OIDCSession import logging logger = logging.getLogger(__name__) router = APIRouter(prefix="/auth", tags=["auth"]) -FRONTEND_URL = os.getenv("FRONTEND_URL", "http://localhost:5173") - def set_auth_cookie(response: Response, token: str): response.set_cookie( key="access_token", @@ -22,7 +20,7 @@ def set_auth_cookie(response: Response, token: str): max_age=1800, expires=1800, samesite="lax", - secure=False, # Set to True in production with HTTPS + secure=not settings.dev_mode, ) @router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED) @@ -72,59 +70,84 @@ async def logout(response: Response): return {"detail": "Successfully logged out"} @router.get("/oidc/login") -async def oidc_login(): - """Get the OIDC authorization URL.""" +async def oidc_login(response: Response): + """Get the OIDC authorization URL and set temporary session cookie.""" if not oidc_client: raise HTTPException( status_code=status.HTTP_510_NOT_EXTENDED, detail="OIDC is not configured" ) - url = oidc_client.get_login_url() - return {"url": url} + auth_data = oidc_client.get_auth_data() + + # Store state, nonce, and code_verifier in a secure short-lived cookie + session_token = OIDCSession.encrypt( + { + "state": auth_data["state"], + "nonce": auth_data["nonce"], + "code_verifier": auth_data["code_verifier"] + }, + settings.secret_key + ) + + response.set_cookie( + key="oidc_session", + value=session_token, + httponly=True, + max_age=300, # 5 minutes + samesite="lax", + secure=not settings.dev_mode + ) + + return {"url": auth_data["url"]} -@router.get("/oidc/callback", response_model=Token) -async def oidc_callback(request: Request, response: Response, code: str): - """Handle the OIDC callback, issue a JWT, and redirect or return JSON.""" +@router.get("/oidc/callback") +async def oidc_callback(request: Request, response: Response, code: str, state: str): + """Handle the OIDC callback, validate state/nonce, issue a JWT, and redirect.""" if not oidc_client: raise HTTPException(status_code=status.HTTP_510_NOT_EXTENDED, detail="OIDC not configured") + # 1. Validate state and retrieve session data from cookie + session_token = request.cookies.get("oidc_session") + if not session_token: + logger.error("OIDC session cookie missing") + return RedirectResponse(url=f"{settings.frontend_url}?error=oidc_failed", status_code=status.HTTP_302_FOUND) + + session_data = OIDCSession.decrypt(session_token, settings.secret_key) + if not session_data or session_data.get("state") != state: + logger.error("OIDC state mismatch or session expired") + return RedirectResponse(url=f"{settings.frontend_url}?error=oidc_failed", status_code=status.HTTP_302_FOUND) + try: - token = oidc_client.exchange_code_for_token(code) - user_info = oidc_client.get_user_info(token) - email = user_info.get("email") - name = user_info.get("name") or user_info.get("preferred_username") + # 2. Exchange code for token using PKCE code_verifier + token = oidc_client.exchange_code_for_token(code, code_verifier=session_data.get("code_verifier")) + + # 3. Validate ID Token and Nonce + id_token = token.get("id_token") + if not id_token: + raise ValueError("ID Token missing from response") + + claims = oidc_client.validate_id_token(id_token, nonce=session_data.get("nonce")) + + email = claims.get("email") + name = claims.get("name") or claims.get("preferred_username") if not email: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email not provided by OIDC") + # 4. Sync user and establish session user = history_manager.sync_user_from_oidc(email=email, display_name=name) - access_token = create_access_token(data={"sub": str(user.id)}) - # Determine if we should redirect (direct provider callback) or return JSON (frontend exchange) - is_ajax = request.headers.get("X-Requested-With") == "XMLHttpRequest" or \ - "application/json" in request.headers.get("Accept", "") - - if is_ajax: - set_auth_cookie(response, access_token) - return {"access_token": access_token, "token_type": "bearer"} - else: - redirect_response = RedirectResponse(url=f"{FRONTEND_URL}/auth/callback") - set_auth_cookie(redirect_response, access_token) - return redirect_response + redirect_response = RedirectResponse(url=settings.frontend_url, status_code=status.HTTP_302_FOUND) + set_auth_cookie(redirect_response, access_token) + redirect_response.delete_cookie("oidc_session") # Cleanup + + return redirect_response - except HTTPException: - raise - except Exception as e: + except Exception: logger.exception("OIDC authentication failed") - # For non-ajax, redirect with error - is_ajax = request.headers.get("X-Requested-With") == "XMLHttpRequest" or \ - "application/json" in request.headers.get("Accept", "") - if is_ajax: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="OIDC authentication failed") - else: - return RedirectResponse(url=f"{FRONTEND_URL}?error=oidc_failed") + return RedirectResponse(url=f"{settings.frontend_url}?error=oidc_failed", status_code=status.HTTP_302_FOUND) @router.get("/me", response_model=UserResponse) async def get_me(current_user: UserDB = Depends(get_current_user)): diff --git a/backend/src/ea_chatbot/api/utils.py b/backend/src/ea_chatbot/api/utils.py index 96ab678..d96801c 100644 --- a/backend/src/ea_chatbot/api/utils.py +++ b/backend/src/ea_chatbot/api/utils.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta, timezone -from typing import Optional, Union, Any, List, Dict +from typing import Optional, Any from jose import JWTError, jwt from pydantic import BaseModel from langchain_core.messages import BaseMessage diff --git a/backend/src/ea_chatbot/app.py b/backend/src/ea_chatbot/app.py index acd2e26..9585868 100644 --- a/backend/src/ea_chatbot/app.py +++ b/backend/src/ea_chatbot/app.py @@ -1,5 +1,4 @@ import streamlit as st -import asyncio import os import io from dotenv import load_dotenv @@ -348,7 +347,7 @@ def main(): if node_name == "query_analyzer": analysis = state_update.get("analysis", {}) next_action = state_update.get("next_action", "unknown") - status.write(f"🔍 **Analyzed Query:**") + status.write("🔍 **Analyzed Query:**") for k,v in analysis.items(): status.write(f"- {k:<8}: {v}") status.markdown(f"Next Step: {next_action.capitalize()}") diff --git a/backend/src/ea_chatbot/config.py b/backend/src/ea_chatbot/config.py index 54a9097..6c79eaa 100644 --- a/backend/src/ea_chatbot/config.py +++ b/backend/src/ea_chatbot/config.py @@ -19,6 +19,8 @@ class Settings(BaseSettings): data_dir: str = "data" data_state: str = "new_jersey" log_level: str = Field(default="INFO", alias="LOG_LEVEL") + dev_mode: bool = Field(default=True, alias="DEV_MODE") + frontend_url: str = Field(default="http://localhost:5173", alias="FRONTEND_URL") # Voter Database configuration db_host: str = Field(default="localhost", alias="DB_HOST") @@ -40,6 +42,7 @@ class Settings(BaseSettings): oidc_client_id: Optional[str] = Field(default=None, alias="OIDC_CLIENT_ID") oidc_client_secret: Optional[str] = Field(default=None, alias="OIDC_CLIENT_SECRET") oidc_server_metadata_url: Optional[str] = Field(default=None, alias="OIDC_SERVER_METADATA_URL") + oidc_redirect_uri: str = Field(default="http://localhost:8000/api/v1/auth/oidc/callback", alias="OIDC_REDIRECT_URI") # Default configurations for each node query_analyzer_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0)) diff --git a/backend/src/ea_chatbot/graph/nodes/clarification.py b/backend/src/ea_chatbot/graph/nodes/clarification.py index 49198e9..58e887d 100644 --- a/backend/src/ea_chatbot/graph/nodes/clarification.py +++ b/backend/src/ea_chatbot/graph/nodes/clarification.py @@ -1,4 +1,3 @@ -from langchain_core.messages import AIMessage from ea_chatbot.graph.state import AgentState from ea_chatbot.config import Settings from ea_chatbot.utils.llm_factory import get_llm_model diff --git a/backend/src/ea_chatbot/graph/nodes/coder.py b/backend/src/ea_chatbot/graph/nodes/coder.py index f9864e5..8a3bd84 100644 --- a/backend/src/ea_chatbot/graph/nodes/coder.py +++ b/backend/src/ea_chatbot/graph/nodes/coder.py @@ -1,7 +1,7 @@ from ea_chatbot.graph.state import AgentState from ea_chatbot.config import Settings from ea_chatbot.utils.llm_factory import get_llm_model -from ea_chatbot.utils import helpers, database_inspection +from ea_chatbot.utils import database_inspection from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT from ea_chatbot.schemas import CodeGenerationResponse diff --git a/backend/src/ea_chatbot/graph/nodes/executor.py b/backend/src/ea_chatbot/graph/nodes/executor.py index 8b10dce..beeefa1 100644 --- a/backend/src/ea_chatbot/graph/nodes/executor.py +++ b/backend/src/ea_chatbot/graph/nodes/executor.py @@ -2,7 +2,7 @@ import io import sys import traceback from contextlib import redirect_stdout -from typing import Any, Dict, List, TYPE_CHECKING +from typing import TYPE_CHECKING import pandas as pd from matplotlib.figure import Figure diff --git a/backend/src/ea_chatbot/graph/nodes/researcher.py b/backend/src/ea_chatbot/graph/nodes/researcher.py index 26a0f6f..84aa53b 100644 --- a/backend/src/ea_chatbot/graph/nodes/researcher.py +++ b/backend/src/ea_chatbot/graph/nodes/researcher.py @@ -1,4 +1,3 @@ -from langchain_core.messages import AIMessage from langchain_openai import ChatOpenAI from langchain_google_genai import ChatGoogleGenerativeAI from ea_chatbot.graph.state import AgentState diff --git a/backend/src/ea_chatbot/graph/nodes/summarizer.py b/backend/src/ea_chatbot/graph/nodes/summarizer.py index 57195ac..3796e9d 100644 --- a/backend/src/ea_chatbot/graph/nodes/summarizer.py +++ b/backend/src/ea_chatbot/graph/nodes/summarizer.py @@ -1,4 +1,3 @@ -from langchain_core.messages import AIMessage from ea_chatbot.graph.state import AgentState from ea_chatbot.config import Settings from ea_chatbot.utils.llm_factory import get_llm_model diff --git a/backend/src/ea_chatbot/graph/state.py b/backend/src/ea_chatbot/graph/state.py index 65f2863..0885782 100644 --- a/backend/src/ea_chatbot/graph/state.py +++ b/backend/src/ea_chatbot/graph/state.py @@ -1,4 +1,4 @@ -from typing import TypedDict, Annotated, List, Dict, Any, Optional +from typing import Annotated, List, Dict, Any, Optional from langchain_core.messages import BaseMessage from langchain.agents import AgentState as AS import operator diff --git a/backend/src/ea_chatbot/history/manager.py b/backend/src/ea_chatbot/history/manager.py index 6adc8a1..a11b25d 100644 --- a/backend/src/ea_chatbot/history/manager.py +++ b/backend/src/ea_chatbot/history/manager.py @@ -1,7 +1,7 @@ from contextlib import contextmanager from typing import Optional, List -from sqlalchemy import create_engine, select, delete -from sqlalchemy.orm import sessionmaker, Session +from sqlalchemy import create_engine, select +from sqlalchemy.orm import sessionmaker from argon2 import PasswordHasher from argon2.exceptions import VerifyMismatchError diff --git a/backend/src/ea_chatbot/utils/database_inspection.py b/backend/src/ea_chatbot/utils/database_inspection.py index ec9e4ec..446ea58 100644 --- a/backend/src/ea_chatbot/utils/database_inspection.py +++ b/backend/src/ea_chatbot/utils/database_inspection.py @@ -1,6 +1,5 @@ -from typing import Optional, Dict, Any, List, TYPE_CHECKING +from typing import Optional, Dict, Any, TYPE_CHECKING import yaml -import json import os from ea_chatbot.utils.db_client import DBClient if TYPE_CHECKING: diff --git a/backend/src/ea_chatbot/utils/helpers.py b/backend/src/ea_chatbot/utils/helpers.py index e9b2d16..990c28d 100644 --- a/backend/src/ea_chatbot/utils/helpers.py +++ b/backend/src/ea_chatbot/utils/helpers.py @@ -20,7 +20,8 @@ def to_yaml(json_str: str, indent: int = 2) -> str: """ Attempts to convert a JSON string (potentially malformed from LLM) to a YAML string. """ - if not json_str: return "" + if not json_str: + return "" try: # Try direct parse diff --git a/backend/src/ea_chatbot/utils/llm_factory.py b/backend/src/ea_chatbot/utils/llm_factory.py index c66c067..78c5764 100644 --- a/backend/src/ea_chatbot/utils/llm_factory.py +++ b/backend/src/ea_chatbot/utils/llm_factory.py @@ -1,4 +1,4 @@ -from typing import Optional, cast, TYPE_CHECKING, Literal, Dict, List, Tuple, Any +from typing import Optional, List from langchain_core.language_models.chat_models import BaseChatModel from langchain_openai import ChatOpenAI from langchain_google_genai import ChatGoogleGenerativeAI diff --git a/backend/tests/api/test_api_auth.py b/backend/tests/api/test_api_auth.py index 799da97..33a9c5d 100644 --- a/backend/tests/api/test_api_auth.py +++ b/backend/tests/api/test_api_auth.py @@ -1,6 +1,6 @@ import pytest from fastapi.testclient import TestClient -from unittest.mock import MagicMock, patch +from unittest.mock import patch from ea_chatbot.api.main import app from ea_chatbot.history.models import User @@ -66,31 +66,43 @@ def test_protected_route_without_token(): assert response.status_code == 401 def test_oidc_login_redirect(): - """Test that OIDC login returns a redirect URL.""" + """Test that OIDC login returns a redirect URL and sets session cookie.""" with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc: - mock_oidc.get_login_url.return_value = "https://oidc-provider.com/auth" + mock_oidc.get_auth_data.return_value = { + "url": "https://oidc-provider.com/auth", + "state": "test_state", + "nonce": "test_nonce", + "code_verifier": "test_verifier" + } response = client.get("/api/v1/auth/oidc/login") assert response.status_code == 200 assert response.json()["url"] == "https://oidc-provider.com/auth" + assert "oidc_session" in response.cookies -def test_oidc_callback_success_ajax(): - """Test successful OIDC callback and JWT issuance via AJAX.""" +def test_oidc_callback_success(): + """Test successful OIDC callback and JWT issuance.""" with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc, \ + patch("ea_chatbot.api.routers.auth.OIDCSession.decrypt") as mock_decrypt, \ patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm: - mock_oidc.exchange_code_for_token.return_value = {"access_token": "oidc-token"} - mock_oidc.get_user_info.return_value = {"email": "sso@example.com", "name": "SSO User"} + mock_decrypt.return_value = { + "state": "test_state", + "nonce": "test_nonce", + "code_verifier": "test_verifier" + } + mock_oidc.exchange_code_for_token.return_value = {"id_token": "fake_id_token"} + mock_oidc.validate_id_token.return_value = {"email": "sso@example.com", "name": "SSO User"} mock_hm.sync_user_from_oidc.return_value = User(id="sso-123", username="sso@example.com", display_name="SSO User") + client.cookies.set("oidc_session", "fake_token") response = client.get( - "/api/v1/auth/oidc/callback?code=some-code", - headers={"Accept": "application/json"} + "/api/v1/auth/oidc/callback?code=some-code&state=test_state", + follow_redirects=False ) - assert response.status_code == 200 - assert "access_token" in response.json() - assert response.json()["token_type"] == "bearer" + assert response.status_code == 302 + assert "access_token" in response.cookies def test_get_me_success(): """Test getting current user with a valid token.""" diff --git a/backend/tests/api/test_api_auth_cookie.py b/backend/tests/api/test_api_auth_cookie.py index e200c61..e3c7849 100644 --- a/backend/tests/api/test_api_auth_cookie.py +++ b/backend/tests/api/test_api_auth_cookie.py @@ -1,6 +1,6 @@ import pytest from fastapi.testclient import TestClient -from unittest.mock import MagicMock, patch +from unittest.mock import patch from ea_chatbot.api.main import app from ea_chatbot.history.models import User from ea_chatbot.api.utils import create_access_token @@ -83,15 +83,25 @@ def test_logout_clears_cookie(): def test_oidc_callback_redirects_with_cookie(): """Test that OIDC callback sets cookie and redirects.""" with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc, \ + patch("ea_chatbot.api.routers.auth.OIDCSession.decrypt") as mock_decrypt, \ patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm: - mock_oidc.exchange_code_for_token.return_value = {"access_token": "oidc-token"} - mock_oidc.get_user_info.return_value = {"email": "sso@example.com", "name": "SSO User"} + mock_decrypt.return_value = { + "state": "test_state", + "nonce": "test_nonce", + "code_verifier": "test_verifier" + } + mock_oidc.exchange_code_for_token.return_value = {"id_token": "fake_id_token"} + mock_oidc.validate_id_token.return_value = {"email": "sso@example.com", "name": "SSO User"} mock_hm.sync_user_from_oidc.return_value = User(id="sso-123", username="sso@example.com", display_name="SSO User") - # Follow_redirects=False to catch the 307/302 - response = client.get("/api/v1/auth/oidc/callback?code=some-code", follow_redirects=False) + # Set the session cookie + client.cookies.set("oidc_session", "fake_token") - assert response.status_code in [302, 303, 307] + # Follow_redirects=False to catch the 302 + response = client.get("/api/v1/auth/oidc/callback?code=some-code&state=test_state", follow_redirects=False) + + assert response.status_code == 302 assert "access_token" in response.cookies - assert "/auth/callback" in response.headers["location"] + # Should redirect to FRONTEND_URL (default http://localhost:5173) + assert "http://localhost:5173" in response.headers["location"] diff --git a/backend/tests/api/test_oidc_flow.py b/backend/tests/api/test_oidc_flow.py new file mode 100644 index 0000000..f777aad --- /dev/null +++ b/backend/tests/api/test_oidc_flow.py @@ -0,0 +1,82 @@ +import pytest +from fastapi.testclient import TestClient +from unittest.mock import patch +from ea_chatbot.api.main import app +from ea_chatbot.history.models import User + +@pytest.fixture +def client(): + """Provides a fresh TestClient for each test.""" + return TestClient(app) + +@pytest.fixture +def mock_auth_data(): + return { + "url": "https://example.com/auth?state=test_state", + "state": "test_state", + "nonce": "test_nonce", + "code_verifier": "test_verifier" + } + +def test_oidc_login_sets_cookie(client, mock_auth_data): + """Test that OIDC login initiation sets the temporary session cookie.""" + with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc: + mock_oidc.get_auth_data.return_value = mock_auth_data + + response = client.get("/api/v1/auth/oidc/login") + + assert response.status_code == 200 + assert response.json()["url"] == mock_auth_data["url"] + assert "oidc_session" in response.cookies + +def test_oidc_callback_missing_cookie(client): + """Test that OIDC callback fails if the temporary session cookie is missing.""" + with patch("ea_chatbot.api.routers.auth.oidc_client"): + # Ensure no cookies are set + client.cookies.clear() + response = client.get("/api/v1/auth/oidc/callback?code=test_code&state=test_state", follow_redirects=False) + + # Should redirect to frontend with error + assert response.status_code == 302 + assert "error=oidc_failed" in response.headers["location"] + +def test_oidc_callback_invalid_state(client, mock_auth_data): + """Test that OIDC callback fails if the state in the URL doesn't match the cookie.""" + with patch("ea_chatbot.api.routers.auth.oidc_client"), \ + patch("ea_chatbot.api.routers.auth.OIDCSession.decrypt") as mock_decrypt: + + # Mock valid cookie content + mock_decrypt.return_value = mock_auth_data + + # Send different state in URL + client.cookies.set("oidc_session", "fake_token") + response = client.get("/api/v1/auth/oidc/callback?code=test_code&state=wrong_state", follow_redirects=False) + + assert response.status_code == 302 + assert "error=oidc_failed" in response.headers["location"] + +def test_oidc_callback_success(client, mock_auth_data): + """Test successful OIDC callback and session establishment.""" + with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc, \ + patch("ea_chatbot.api.routers.auth.OIDCSession.decrypt") as mock_decrypt, \ + patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm, \ + patch("ea_chatbot.api.routers.auth.create_access_token") as mock_create_token: + + mock_decrypt.return_value = mock_auth_data + mock_oidc.exchange_code_for_token.return_value = {"id_token": "fake_id_token"} + mock_oidc.validate_id_token.return_value = {"email": "user@test.com", "name": "Test User"} + mock_hm.sync_user_from_oidc.return_value = User(id="user-123", username="user@test.com") + mock_create_token.return_value = "fake_access_token" + + client.cookies.set("oidc_session", "fake_token") + response = client.get("/api/v1/auth/oidc/callback?code=test_code&state=test_state", follow_redirects=False) + + assert response.status_code == 302 + # Redirect to FRONTEND_URL (default localhost:5173) + assert "http://localhost:5173" in response.headers["location"] + assert "access_token" in response.cookies + + # Verify oidc_session was cleaned up (deleted) + # In RedirectResponse, delete_cookie works by setting the cookie with empty value and past expiry + cookie_header = response.headers.get("set-cookie", "") + assert "oidc_session=;" in cookie_header or 'oidc_session=""' in cookie_header diff --git a/backend/tests/api/test_persistence.py b/backend/tests/api/test_persistence.py index 0b9089e..8ed8167 100644 --- a/backend/tests/api/test_persistence.py +++ b/backend/tests/api/test_persistence.py @@ -3,10 +3,8 @@ from fastapi.testclient import TestClient from unittest.mock import MagicMock, patch, AsyncMock from ea_chatbot.api.main import app from ea_chatbot.api.dependencies import get_current_user -from ea_chatbot.history.models import User, Conversation, Message, Plot +from ea_chatbot.history.models import User, Conversation from ea_chatbot.api.utils import create_access_token -from datetime import datetime, timezone -import json client = TestClient(app) diff --git a/backend/tests/graph/test_checkpoint.py b/backend/tests/graph/test_checkpoint.py index 4a23128..06a0354 100644 --- a/backend/tests/graph/test_checkpoint.py +++ b/backend/tests/graph/test_checkpoint.py @@ -30,7 +30,6 @@ async def test_get_checkpointer_initialization(): checkpoint.get_pool = mock_get_pool # Patch AsyncPostgresSaver where it's imported in checkpoint.py - import langgraph.checkpoint.postgres.aio as pg_aio original_saver = checkpoint.AsyncPostgresSaver checkpoint.AsyncPostgresSaver = mock_saver_class diff --git a/backend/tests/test_app.py b/backend/tests/test_app.py deleted file mode 100644 index ef2d038..0000000 --- a/backend/tests/test_app.py +++ /dev/null @@ -1,90 +0,0 @@ -import os -import sys -import pytest -from unittest.mock import MagicMock, patch -from streamlit.testing.v1 import AppTest -from langchain_core.messages import AIMessage - -# Ensure src is in python path -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../src'))) - -@pytest.fixture(autouse=True) -def mock_history_manager(): - """Globally mock HistoryManager to avoid DB calls during AppTest.""" - with patch("ea_chatbot.history.manager.HistoryManager") as mock_cls: - instance = mock_cls.return_value - instance.create_conversation.return_value = MagicMock(id="conv_123") - instance.get_conversations.return_value = [] - instance.get_messages.return_value = [] - instance.add_message.return_value = MagicMock() - instance.update_conversation_summary.return_value = MagicMock() - yield instance - -@pytest.fixture -def mock_app_stream(): - with patch("ea_chatbot.graph.workflow.app.stream") as mock_stream: - # Mock events from app.stream - mock_stream.return_value = [ - {"query_analyzer": {"next_action": "research"}}, - {"researcher": {"messages": [AIMessage(content="Research result")]}} - ] - yield mock_stream - -@pytest.fixture -def mock_user(): - user = MagicMock() - user.id = "test_id" - user.username = "test@example.com" - user.display_name = "Test User" - return user - -def test_app_initial_state(mock_app_stream, mock_user): - """Test that the app initializes with the correct title and empty history.""" - at = AppTest.from_file("src/ea_chatbot/app.py") - - # Simulate logged-in user - at.session_state["user"] = mock_user - - at.run() - - assert not at.exception - assert at.title[0].value == "🗳️ Election Analytics Chatbot" - - # Check session state initialization - assert "messages" in at.session_state - assert len(at.session_state["messages"]) == 0 - -def test_app_dev_mode_toggle(mock_app_stream, mock_user): - """Test that the dev mode toggle exists in the sidebar.""" - with patch.dict(os.environ, {"DEV_MODE": "false"}): - at = AppTest.from_file("src/ea_chatbot/app.py") - at.session_state["user"] = mock_user - at.run() - - # Check for sidebar toggle (checkbox) - assert len(at.sidebar.checkbox) > 0 - dev_mode_toggle = at.sidebar.checkbox[0] - assert dev_mode_toggle.label == "Dev Mode" - assert dev_mode_toggle.value is False - -def test_app_graph_execution_streaming(mock_app_stream, mock_user, mock_history_manager): - """Test that entering a prompt triggers the graph stream and displays response.""" - at = AppTest.from_file("src/ea_chatbot/app.py") - at.session_state["user"] = mock_user - at.run() - - # Input a question - at.chat_input[0].set_value("Test question").run() - - # Verify graph stream was called - assert mock_app_stream.called - - # Message should be added to history - assert len(at.session_state["messages"]) == 2 - assert at.session_state["messages"][0]["role"] == "user" - assert at.session_state["messages"][1]["role"] == "assistant" - assert "Research result" in at.session_state["messages"][1]["content"] - - # Verify history manager was used - assert mock_history_manager.create_conversation.called - assert mock_history_manager.add_message.called diff --git a/backend/tests/test_app_auth.py b/backend/tests/test_app_auth.py deleted file mode 100644 index e88d23f..0000000 --- a/backend/tests/test_app_auth.py +++ /dev/null @@ -1,83 +0,0 @@ -import pytest -from unittest.mock import MagicMock, patch -from streamlit.testing.v1 import AppTest -from ea_chatbot.auth import AuthType - -@pytest.fixture -def mock_history_manager_instance(): - # We need to patch before the AppTest loads the module - with patch("ea_chatbot.history.manager.HistoryManager") as mock_cls: - instance = mock_cls.return_value - yield instance - -def test_auth_ui_flow_step1_to_password(mock_history_manager_instance): - """Test UI transition from Step 1 (email) to Step 2a (password) for LOCAL user.""" - # Patch BEFORE creating AppTest - mock_user = MagicMock() - mock_user.password_hash = "hashed_password" - mock_history_manager_instance.get_user.return_value = mock_user - - at = AppTest.from_file("src/ea_chatbot/app.py") - at.run() - - # Step 1: Identification - assert at.session_state["login_step"] == "email" - at.text_input[0].set_value("local@example.com") - - at.button[0].click().run() - - # Verify transition to password step - assert at.session_state["login_step"] == "login_password" - assert at.session_state["login_email"] == "local@example.com" - assert "Welcome back" in at.info[0].value - -def test_auth_ui_flow_step1_to_register(mock_history_manager_instance): - """Test UI transition from Step 1 (email) to Step 2b (registration) for NEW user.""" - mock_history_manager_instance.get_user.return_value = None - - at = AppTest.from_file("src/ea_chatbot/app.py") - at.run() - - # Step 1: Identification - at.text_input[0].set_value("new@example.com") - - at.button[0].click().run() - - # Verify transition to registration step - assert at.session_state["login_step"] == "register_details" - assert at.session_state["login_email"] == "new@example.com" - assert "Create an account" in at.info[0].value - -def test_auth_ui_flow_step1_to_oidc(mock_history_manager_instance): - """Test UI transition from Step 1 (email) to Step 2c (OIDC) for OIDC user.""" - # Mock history_manager.get_user to return a user WITHOUT a password - mock_user = MagicMock() - mock_user.password_hash = None - mock_history_manager_instance.get_user.return_value = mock_user - - at = AppTest.from_file("src/ea_chatbot/app.py") - at.run() - - # Step 1: Identification - at.text_input[0].set_value("oidc@example.com") - - at.button[0].click().run() - - # Verify transition to OIDC step - assert at.session_state["login_step"] == "oidc_login" - assert at.session_state["login_email"] == "oidc@example.com" - assert "configured for Single Sign-On" in at.info[0].value - -def test_auth_ui_flow_back_button(mock_history_manager_instance): - """Test that the 'Back' button returns to Step 1.""" - at = AppTest.from_file("src/ea_chatbot/app.py") - # Simulate being on Step 2a - at.session_state["login_step"] = "login_password" - at.session_state["login_email"] = "local@example.com" - at.run() - - # Click Back (index 1 in Step 2a) - at.button[1].click().run() - - # Verify return to email step - assert at.session_state["login_step"] == "email" \ No newline at end of file diff --git a/backend/tests/test_auth.py b/backend/tests/test_auth.py deleted file mode 100644 index 77694e9..0000000 --- a/backend/tests/test_auth.py +++ /dev/null @@ -1,43 +0,0 @@ -import pytest -from unittest.mock import MagicMock, patch -from ea_chatbot.auth import OIDCClient - -@patch("ea_chatbot.auth.OAuth2Session") -def test_oidc_client_initialization(mock_oauth): - client = OIDCClient( - client_id="test_id", - client_secret="test_secret", - server_metadata_url="https://test.server/.well-known/openid-configuration" - ) - assert client.oauth_session is not None - -@patch("ea_chatbot.auth.requests") -@patch("ea_chatbot.auth.OAuth2Session") -def test_get_login_url(mock_oauth_cls, mock_requests): - # Setup mock session - mock_session = MagicMock() - mock_oauth_cls.return_value = mock_session - - # Mock metadata response - mock_response = MagicMock() - mock_response.json.return_value = { - "authorization_endpoint": "https://test.server/auth", - "token_endpoint": "https://test.server/token", - "userinfo_endpoint": "https://test.server/userinfo" - } - mock_requests.get.return_value = mock_response - - # Mock authorization url generation - mock_session.create_authorization_url.return_value = ("https://test.server/auth?response_type=code", "state") - - client = OIDCClient( - client_id="test_id", - client_secret="test_secret", - server_metadata_url="https://test.server/.well-known/openid-configuration" - ) - - url = client.get_login_url() - - assert url == "https://test.server/auth?response_type=code" - # Verify metadata was fetched via requests - mock_requests.get.assert_called_with("https://test.server/.well-known/openid-configuration") \ No newline at end of file diff --git a/backend/tests/test_auth_flow.py b/backend/tests/test_auth_flow.py deleted file mode 100644 index a62650c..0000000 --- a/backend/tests/test_auth_flow.py +++ /dev/null @@ -1,49 +0,0 @@ -import pytest -from unittest.mock import MagicMock -from ea_chatbot.history.manager import HistoryManager -from ea_chatbot.auth import get_user_auth_type, AuthType - -# Mocks -@pytest.fixture -def mock_history_manager(): - return MagicMock(spec=HistoryManager) - -def test_auth_flow_existing_local_user(mock_history_manager): - """Test that an existing user with a password returns LOCAL auth type.""" - # Setup - mock_user = MagicMock() - mock_user.password_hash = "hashed_secret" - mock_history_manager.get_user.return_value = mock_user - - # Execute - auth_type = get_user_auth_type("test@example.com", mock_history_manager) - - # Verify - assert auth_type == AuthType.LOCAL - mock_history_manager.get_user.assert_called_once_with("test@example.com") - -def test_auth_flow_existing_oidc_user(mock_history_manager): - """Test that an existing user WITHOUT a password returns OIDC auth type.""" - # Setup - mock_user = MagicMock() - mock_user.password_hash = None # No password implies OIDC - mock_history_manager.get_user.return_value = mock_user - - # Execute - auth_type = get_user_auth_type("sso@example.com", mock_history_manager) - - # Verify - assert auth_type == AuthType.OIDC - mock_history_manager.get_user.assert_called_once_with("sso@example.com") - -def test_auth_flow_new_user(mock_history_manager): - """Test that a non-existent user returns NEW auth type.""" - # Setup - mock_history_manager.get_user.return_value = None - - # Execute - auth_type = get_user_auth_type("new@example.com", mock_history_manager) - - # Verify - assert auth_type == AuthType.NEW - mock_history_manager.get_user.assert_called_once_with("new@example.com") diff --git a/backend/tests/test_config.py b/backend/tests/test_config.py index 79f3a2c..96076e6 100644 --- a/backend/tests/test_config.py +++ b/backend/tests/test_config.py @@ -1,5 +1,3 @@ -import pytest -from pydantic import ValidationError from ea_chatbot.config import Settings, LLMConfig def test_default_settings(): diff --git a/backend/tests/test_database_inspection.py b/backend/tests/test_database_inspection.py index 7773ce5..79348fa 100644 --- a/backend/tests/test_database_inspection.py +++ b/backend/tests/test_database_inspection.py @@ -1,7 +1,7 @@ import pytest import pandas as pd import os -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock from ea_chatbot.utils.database_inspection import get_primary_key, inspect_db_table, get_data_summary @pytest.fixture diff --git a/backend/tests/test_executor.py b/backend/tests/test_executor.py index 80ceeae..3f9d68b 100644 --- a/backend/tests/test_executor.py +++ b/backend/tests/test_executor.py @@ -3,7 +3,6 @@ import pandas as pd from unittest.mock import MagicMock, patch from matplotlib.figure import Figure from ea_chatbot.graph.nodes.executor import executor_node -from ea_chatbot.graph.state import AgentState @pytest.fixture def mock_settings(): diff --git a/backend/tests/test_helpers.py b/backend/tests/test_helpers.py index 2897c39..fc28eb1 100644 --- a/backend/tests/test_helpers.py +++ b/backend/tests/test_helpers.py @@ -1,4 +1,3 @@ -import pytest from langchain_core.messages import HumanMessage, AIMessage from ea_chatbot.utils.helpers import merge_agent_state diff --git a/backend/tests/test_history_models.py b/backend/tests/test_history_models.py index 00888ca..3784f54 100644 --- a/backend/tests/test_history_models.py +++ b/backend/tests/test_history_models.py @@ -1,6 +1,4 @@ import pytest -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, DeclarativeBase # We anticipate these imports will fail initially try: diff --git a/backend/tests/test_llm_factory_callbacks.py b/backend/tests/test_llm_factory_callbacks.py index 4dee27b..bab90aa 100644 --- a/backend/tests/test_llm_factory_callbacks.py +++ b/backend/tests/test_llm_factory_callbacks.py @@ -1,4 +1,3 @@ -import pytest from langchain_openai import ChatOpenAI from ea_chatbot.config import LLMConfig from ea_chatbot.utils.llm_factory import get_llm_model diff --git a/backend/tests/test_logging_e2e.py b/backend/tests/test_logging_e2e.py index 8e81ee7..028eb34 100644 --- a/backend/tests/test_logging_e2e.py +++ b/backend/tests/test_logging_e2e.py @@ -1,8 +1,7 @@ -import os import json import pytest import logging -from unittest.mock import MagicMock, patch +from unittest.mock import patch from ea_chatbot.graph.workflow import app from ea_chatbot.graph.state import AgentState from ea_chatbot.utils.logging import get_logger diff --git a/backend/tests/test_multi_turn_query_analyzer.py b/backend/tests/test_multi_turn_query_analyzer.py index 9f242cd..1baff50 100644 --- a/backend/tests/test_multi_turn_query_analyzer.py +++ b/backend/tests/test_multi_turn_query_analyzer.py @@ -2,7 +2,6 @@ import pytest from unittest.mock import MagicMock, patch from langchain_core.messages import HumanMessage, AIMessage from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis -from ea_chatbot.graph.state import AgentState @pytest.fixture def mock_state_with_history(): diff --git a/backend/tests/test_oidc_client.py b/backend/tests/test_oidc_client.py deleted file mode 100644 index 075f22c..0000000 --- a/backend/tests/test_oidc_client.py +++ /dev/null @@ -1,87 +0,0 @@ -import pytest -from unittest.mock import MagicMock, patch -from ea_chatbot.auth import OIDCClient - -@pytest.fixture -def oidc_config(): - return { - "client_id": "test_id", - "client_secret": "test_secret", - "server_metadata_url": "https://example.com/.well-known/openid-configuration", - "redirect_uri": "http://localhost:8501" - } - -@pytest.fixture -def mock_metadata(): - return { - "authorization_endpoint": "https://example.com/auth", - "token_endpoint": "https://example.com/token", - "userinfo_endpoint": "https://example.com/userinfo" - } - -def test_oidc_fetch_metadata(oidc_config, mock_metadata): - client = OIDCClient(**oidc_config) - - with patch("requests.get") as mock_get: - mock_response = MagicMock() - mock_response.json.return_value = mock_metadata - mock_get.return_value = mock_response - - metadata = client.fetch_metadata() - - assert metadata == mock_metadata - mock_get.assert_called_once_with(oidc_config["server_metadata_url"]) - - # Second call should use cache - client.fetch_metadata() - assert mock_get.call_count == 1 - -def test_oidc_get_login_url(oidc_config, mock_metadata): - client = OIDCClient(**oidc_config) - client.metadata = mock_metadata - - with patch.object(client.oauth_session, "create_authorization_url") as mock_create_url: - mock_create_url.return_value = ("https://example.com/auth?state=xyz", "xyz") - - url = client.get_login_url() - - assert url == "https://example.com/auth?state=xyz" - mock_create_url.assert_called_once_with(mock_metadata["authorization_endpoint"]) - -def test_oidc_get_login_url_missing_endpoint(oidc_config): - client = OIDCClient(**oidc_config) - client.metadata = {"some_other": "field"} - - with pytest.raises(ValueError, match="authorization_endpoint not found"): - client.get_login_url() - -def test_oidc_exchange_code_for_token(oidc_config, mock_metadata): - client = OIDCClient(**oidc_config) - client.metadata = mock_metadata - - with patch.object(client.oauth_session, "fetch_token") as mock_fetch_token: - mock_fetch_token.return_value = {"access_token": "abc"} - - token = client.exchange_code_for_token("test_code") - - assert token == {"access_token": "abc"} - mock_fetch_token.assert_called_once_with( - mock_metadata["token_endpoint"], - code="test_code", - client_secret=oidc_config["client_secret"] - ) - -def test_oidc_get_user_info(oidc_config, mock_metadata): - client = OIDCClient(**oidc_config) - client.metadata = mock_metadata - - with patch.object(client.oauth_session, "get") as mock_get: - mock_response = MagicMock() - mock_response.json.return_value = {"sub": "user123"} - mock_get.return_value = mock_response - - user_info = client.get_user_info({"access_token": "abc"}) - - assert user_info == {"sub": "user123"} - assert client.oauth_session.token == {"access_token": "abc"} - mock_get.assert_called_once_with(mock_metadata["userinfo_endpoint"]) diff --git a/backend/tests/test_oidc_validation.py b/backend/tests/test_oidc_validation.py index b61e4fc..a2bafcc 100644 --- a/backend/tests/test_oidc_validation.py +++ b/backend/tests/test_oidc_validation.py @@ -1,7 +1,6 @@ import pytest -from unittest.mock import MagicMock, patch +from unittest.mock import patch from ea_chatbot.auth import OIDCClient -from jose import jwt @pytest.fixture def oidc_config(): diff --git a/backend/tests/test_query_analyzer.py b/backend/tests/test_query_analyzer.py index 103727f..ccf7a1f 100644 --- a/backend/tests/test_query_analyzer.py +++ b/backend/tests/test_query_analyzer.py @@ -1,7 +1,6 @@ import pytest from unittest.mock import MagicMock, patch from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis -from ea_chatbot.graph.state import AgentState @pytest.fixture def mock_state(): diff --git a/backend/tests/test_query_analyzer_logging.py b/backend/tests/test_query_analyzer_logging.py index 29b493a..7f6e26c 100644 --- a/backend/tests/test_query_analyzer_logging.py +++ b/backend/tests/test_query_analyzer_logging.py @@ -1,5 +1,4 @@ import pytest -import logging from unittest.mock import MagicMock, patch from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis diff --git a/backend/tests/test_query_analyzer_refinement.py b/backend/tests/test_query_analyzer_refinement.py index fd0f431..36e0f07 100644 --- a/backend/tests/test_query_analyzer_refinement.py +++ b/backend/tests/test_query_analyzer_refinement.py @@ -2,7 +2,6 @@ import pytest from unittest.mock import MagicMock, patch from langchain_core.messages import HumanMessage, AIMessage from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis -from ea_chatbot.graph.state import AgentState @pytest.fixture def base_state(): diff --git a/backend/tests/test_state.py b/backend/tests/test_state.py index da38f8e..e79d59d 100644 --- a/backend/tests/test_state.py +++ b/backend/tests/test_state.py @@ -1,8 +1,5 @@ -import pytest -from typing import get_type_hints, List -from langchain_core.messages import BaseMessage, HumanMessage +from typing import get_type_hints from ea_chatbot.graph.state import AgentState -import operator def test_agent_state_structure(): """Verify that AgentState has the required fields and types.""" diff --git a/backend/tests/test_summarizer.py b/backend/tests/test_summarizer.py index 6db03d7..b152fdb 100644 --- a/backend/tests/test_summarizer.py +++ b/backend/tests/test_summarizer.py @@ -2,7 +2,6 @@ import pytest from unittest.mock import MagicMock, patch from langchain_core.messages import AIMessage from ea_chatbot.graph.nodes.summarizer import summarizer_node -from ea_chatbot.graph.state import AgentState @pytest.fixture def mock_llm(): diff --git a/backend/tests/test_workflow.py b/backend/tests/test_workflow.py index e66e591..c3ea36e 100644 --- a/backend/tests/test_workflow.py +++ b/backend/tests/test_workflow.py @@ -1,4 +1,3 @@ -import pytest from unittest.mock import MagicMock, patch from ea_chatbot.graph.workflow import app from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis diff --git a/backend/tests/test_workflow_e2e.py b/backend/tests/test_workflow_e2e.py index 342fac2..ef76d1f 100644 --- a/backend/tests/test_workflow_e2e.py +++ b/backend/tests/test_workflow_e2e.py @@ -1,5 +1,4 @@ import pytest -import yaml from unittest.mock import MagicMock, patch from langchain_core.messages import AIMessage from ea_chatbot.graph.workflow import app diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 09b7d61..4f3d9e8 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -3,7 +3,6 @@ import { Routes, Route } from "react-router-dom" import { MainLayout } from "./components/layout/MainLayout" import { LoginForm } from "./components/auth/LoginForm" import { RegisterForm } from "./components/auth/RegisterForm" -import { AuthCallback } from "./components/auth/AuthCallback" import { ChatInterface } from "./components/chat/ChatInterface" import { AuthService, type UserResponse } from "./services/auth" import { ChatService, type MessageResponse } from "./services/chat" @@ -136,6 +135,9 @@ function App() { setThreadMessages(prev => ({ ...prev, [id]: messages })) } + const queryParams = new URLSearchParams(window.location.search) + const externalError = queryParams.get("error") + if (isLoading) { return (
@@ -146,7 +148,6 @@ function App() { return ( - } /> setAuthMode("register")} + externalError={externalError === "oidc_failed" ? "SSO authentication failed. Please try again." : null} /> ) : ( { - const verifyAuth = async () => { - const urlParams = new URLSearchParams(window.location.search) - const code = urlParams.get("code") - - try { - if (code) { - // If we have a code, exchange it for a cookie - await AuthService.exchangeOIDCCode(code) - } else { - // If no code, just verify existing cookie (backend-driven redirect) - await AuthService.getMe() - } - - // Success - go to home. We use window.location.href to ensure a clean reload of App state - window.location.href = "/" - } catch (err) { - console.error("Auth callback verification failed:", err) - navigate("/?error=auth_failed", { replace: true }) - } - } - - verifyAuth() - }, [navigate]) - - return ( -
-
-

Completing login...

-
- ) -} diff --git a/frontend/src/components/auth/LoginForm.tsx b/frontend/src/components/auth/LoginForm.tsx index 4e602dd..29458f8 100644 --- a/frontend/src/components/auth/LoginForm.tsx +++ b/frontend/src/components/auth/LoginForm.tsx @@ -22,10 +22,11 @@ import axios from "axios" interface LoginFormProps { onSuccess: () => void onToggleMode: () => void + externalError?: string | null } -export function LoginForm({ onSuccess, onToggleMode }: LoginFormProps) { - const [error, setError] = useState(null) +export function LoginForm({ onSuccess, onToggleMode, externalError }: LoginFormProps) { + const [error, setError] = useState(externalError || null) const [isLoading, setIsLoading] = useState(false) const form = useForm({ diff --git a/frontend/src/services/auth.ts b/frontend/src/services/auth.ts index e2a033f..1b322a8 100644 --- a/frontend/src/services/auth.ts +++ b/frontend/src/services/auth.ts @@ -28,11 +28,6 @@ export const AuthService = { } }, - async exchangeOIDCCode(code: string): Promise { - const response = await api.get(`/auth/oidc/callback?code=${code}`) - return response.data - }, - async register(email: string, password: string): Promise { const response = await api.post("/auth/register", { email,