feat(auth): Complete OIDC security refactor and modernize test suite

- Refactored OIDC flow to implement PKCE, state/nonce validation, and BFF pattern.
- Centralized configuration in Settings class (DEV_MODE, FRONTEND_URL, OIDC_REDIRECT_URI).
- Updated auth routers to use conditional secure cookie flags based on DEV_MODE.
- Modernized and cleaned up test suite by removing legacy Streamlit tests.
- Fixed linting errors and unused imports across the backend.
This commit is contained in:
Yunxiao Xu
2026-02-15 02:50:26 -08:00
parent 48ad0ebdd7
commit 68c0985482
50 changed files with 222 additions and 515 deletions

View File

@@ -1,4 +1,3 @@
import os
from fastapi import Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordBearer
from ea_chatbot.config import Settings
@@ -18,7 +17,7 @@ if settings.oidc_client_id and settings.oidc_client_secret and settings.oidc_ser
client_id=settings.oidc_client_id,
client_secret=settings.oidc_client_secret,
server_metadata_url=settings.oidc_server_metadata_url,
redirect_uri=os.getenv("OIDC_REDIRECT_URI", "http://localhost:3000/auth/callback")
redirect_uri=settings.oidc_redirect_uri
)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth/login", auto_error=False)

View File

@@ -1,9 +1,6 @@
from fastapi import FastAPI, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from ea_chatbot.api.routers import auth, history, artifacts, agent
from dotenv import load_dotenv
load_dotenv()
app = FastAPI(
title="Election Analytics Chatbot API",

View File

@@ -1,7 +1,6 @@
import json
import asyncio
from typing import AsyncGenerator, Optional, List
from fastapi import APIRouter, Depends, HTTPException, status
from typing import AsyncGenerator, List
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from ea_chatbot.api.dependencies import get_current_user, history_manager
from ea_chatbot.api.utils import convert_to_json_compatible

View File

@@ -1,7 +1,6 @@
from fastapi import APIRouter, Depends, HTTPException, Response, status
from ea_chatbot.api.dependencies import get_current_user, history_manager
from ea_chatbot.history.models import Plot, Message, User as UserDB
import io
router = APIRouter(prefix="/artifacts", tags=["artifacts"])

View File

@@ -1,19 +1,17 @@
from fastapi import APIRouter, Depends, HTTPException, status, Response, Request
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from ea_chatbot.api.utils import create_access_token
from ea_chatbot.api.utils import create_access_token, settings
from ea_chatbot.api.dependencies import history_manager, oidc_client, get_current_user
from ea_chatbot.history.models import User as UserDB
from ea_chatbot.api.schemas import Token, UserCreate, UserResponse
import os
from ea_chatbot.auth import OIDCSession
import logging
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/auth", tags=["auth"])
FRONTEND_URL = os.getenv("FRONTEND_URL", "http://localhost:5173")
def set_auth_cookie(response: Response, token: str):
response.set_cookie(
key="access_token",
@@ -22,7 +20,7 @@ def set_auth_cookie(response: Response, token: str):
max_age=1800,
expires=1800,
samesite="lax",
secure=False, # Set to True in production with HTTPS
secure=not settings.dev_mode,
)
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
@@ -72,59 +70,84 @@ async def logout(response: Response):
return {"detail": "Successfully logged out"}
@router.get("/oidc/login")
async def oidc_login():
"""Get the OIDC authorization URL."""
async def oidc_login(response: Response):
"""Get the OIDC authorization URL and set temporary session cookie."""
if not oidc_client:
raise HTTPException(
status_code=status.HTTP_510_NOT_EXTENDED,
detail="OIDC is not configured"
)
url = oidc_client.get_login_url()
return {"url": url}
auth_data = oidc_client.get_auth_data()
# Store state, nonce, and code_verifier in a secure short-lived cookie
session_token = OIDCSession.encrypt(
{
"state": auth_data["state"],
"nonce": auth_data["nonce"],
"code_verifier": auth_data["code_verifier"]
},
settings.secret_key
)
response.set_cookie(
key="oidc_session",
value=session_token,
httponly=True,
max_age=300, # 5 minutes
samesite="lax",
secure=not settings.dev_mode
)
return {"url": auth_data["url"]}
@router.get("/oidc/callback", response_model=Token)
async def oidc_callback(request: Request, response: Response, code: str):
"""Handle the OIDC callback, issue a JWT, and redirect or return JSON."""
@router.get("/oidc/callback")
async def oidc_callback(request: Request, response: Response, code: str, state: str):
"""Handle the OIDC callback, validate state/nonce, issue a JWT, and redirect."""
if not oidc_client:
raise HTTPException(status_code=status.HTTP_510_NOT_EXTENDED, detail="OIDC not configured")
# 1. Validate state and retrieve session data from cookie
session_token = request.cookies.get("oidc_session")
if not session_token:
logger.error("OIDC session cookie missing")
return RedirectResponse(url=f"{settings.frontend_url}?error=oidc_failed", status_code=status.HTTP_302_FOUND)
session_data = OIDCSession.decrypt(session_token, settings.secret_key)
if not session_data or session_data.get("state") != state:
logger.error("OIDC state mismatch or session expired")
return RedirectResponse(url=f"{settings.frontend_url}?error=oidc_failed", status_code=status.HTTP_302_FOUND)
try:
token = oidc_client.exchange_code_for_token(code)
user_info = oidc_client.get_user_info(token)
email = user_info.get("email")
name = user_info.get("name") or user_info.get("preferred_username")
# 2. Exchange code for token using PKCE code_verifier
token = oidc_client.exchange_code_for_token(code, code_verifier=session_data.get("code_verifier"))
# 3. Validate ID Token and Nonce
id_token = token.get("id_token")
if not id_token:
raise ValueError("ID Token missing from response")
claims = oidc_client.validate_id_token(id_token, nonce=session_data.get("nonce"))
email = claims.get("email")
name = claims.get("name") or claims.get("preferred_username")
if not email:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email not provided by OIDC")
# 4. Sync user and establish session
user = history_manager.sync_user_from_oidc(email=email, display_name=name)
access_token = create_access_token(data={"sub": str(user.id)})
# Determine if we should redirect (direct provider callback) or return JSON (frontend exchange)
is_ajax = request.headers.get("X-Requested-With") == "XMLHttpRequest" or \
"application/json" in request.headers.get("Accept", "")
if is_ajax:
set_auth_cookie(response, access_token)
return {"access_token": access_token, "token_type": "bearer"}
else:
redirect_response = RedirectResponse(url=f"{FRONTEND_URL}/auth/callback")
set_auth_cookie(redirect_response, access_token)
return redirect_response
redirect_response = RedirectResponse(url=settings.frontend_url, status_code=status.HTTP_302_FOUND)
set_auth_cookie(redirect_response, access_token)
redirect_response.delete_cookie("oidc_session") # Cleanup
return redirect_response
except HTTPException:
raise
except Exception as e:
except Exception:
logger.exception("OIDC authentication failed")
# For non-ajax, redirect with error
is_ajax = request.headers.get("X-Requested-With") == "XMLHttpRequest" or \
"application/json" in request.headers.get("Accept", "")
if is_ajax:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="OIDC authentication failed")
else:
return RedirectResponse(url=f"{FRONTEND_URL}?error=oidc_failed")
return RedirectResponse(url=f"{settings.frontend_url}?error=oidc_failed", status_code=status.HTTP_302_FOUND)
@router.get("/me", response_model=UserResponse)
async def get_me(current_user: UserDB = Depends(get_current_user)):

View File

@@ -1,5 +1,5 @@
from datetime import datetime, timedelta, timezone
from typing import Optional, Union, Any, List, Dict
from typing import Optional, Any
from jose import JWTError, jwt
from pydantic import BaseModel
from langchain_core.messages import BaseMessage

View File

@@ -1,5 +1,4 @@
import streamlit as st
import asyncio
import os
import io
from dotenv import load_dotenv
@@ -348,7 +347,7 @@ def main():
if node_name == "query_analyzer":
analysis = state_update.get("analysis", {})
next_action = state_update.get("next_action", "unknown")
status.write(f"🔍 **Analyzed Query:**")
status.write("🔍 **Analyzed Query:**")
for k,v in analysis.items():
status.write(f"- {k:<8}: {v}")
status.markdown(f"Next Step: {next_action.capitalize()}")

View File

@@ -19,6 +19,8 @@ class Settings(BaseSettings):
data_dir: str = "data"
data_state: str = "new_jersey"
log_level: str = Field(default="INFO", alias="LOG_LEVEL")
dev_mode: bool = Field(default=True, alias="DEV_MODE")
frontend_url: str = Field(default="http://localhost:5173", alias="FRONTEND_URL")
# Voter Database configuration
db_host: str = Field(default="localhost", alias="DB_HOST")
@@ -40,6 +42,7 @@ class Settings(BaseSettings):
oidc_client_id: Optional[str] = Field(default=None, alias="OIDC_CLIENT_ID")
oidc_client_secret: Optional[str] = Field(default=None, alias="OIDC_CLIENT_SECRET")
oidc_server_metadata_url: Optional[str] = Field(default=None, alias="OIDC_SERVER_METADATA_URL")
oidc_redirect_uri: str = Field(default="http://localhost:8000/api/v1/auth/oidc/callback", alias="OIDC_REDIRECT_URI")
# Default configurations for each node
query_analyzer_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))

View File

@@ -1,4 +1,3 @@
from langchain_core.messages import AIMessage
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model

View File

@@ -1,7 +1,7 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils import helpers, database_inspection
from ea_chatbot.utils import database_inspection
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT
from ea_chatbot.schemas import CodeGenerationResponse

View File

@@ -2,7 +2,7 @@ import io
import sys
import traceback
from contextlib import redirect_stdout
from typing import Any, Dict, List, TYPE_CHECKING
from typing import TYPE_CHECKING
import pandas as pd
from matplotlib.figure import Figure

View File

@@ -1,4 +1,3 @@
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from ea_chatbot.graph.state import AgentState

View File

@@ -1,4 +1,3 @@
from langchain_core.messages import AIMessage
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model

View File

@@ -1,4 +1,4 @@
from typing import TypedDict, Annotated, List, Dict, Any, Optional
from typing import Annotated, List, Dict, Any, Optional
from langchain_core.messages import BaseMessage
from langchain.agents import AgentState as AS
import operator

View File

@@ -1,7 +1,7 @@
from contextlib import contextmanager
from typing import Optional, List
from sqlalchemy import create_engine, select, delete
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy import create_engine, select
from sqlalchemy.orm import sessionmaker
from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError

View File

@@ -1,6 +1,5 @@
from typing import Optional, Dict, Any, List, TYPE_CHECKING
from typing import Optional, Dict, Any, TYPE_CHECKING
import yaml
import json
import os
from ea_chatbot.utils.db_client import DBClient
if TYPE_CHECKING:

View File

@@ -20,7 +20,8 @@ def to_yaml(json_str: str, indent: int = 2) -> str:
"""
Attempts to convert a JSON string (potentially malformed from LLM) to a YAML string.
"""
if not json_str: return ""
if not json_str:
return ""
try:
# Try direct parse

View File

@@ -1,4 +1,4 @@
from typing import Optional, cast, TYPE_CHECKING, Literal, Dict, List, Tuple, Any
from typing import Optional, List
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI