feat(auth): Complete OIDC security refactor and modernize test suite
- Refactored OIDC flow to implement PKCE, state/nonce validation, and BFF pattern. - Centralized configuration in Settings class (DEV_MODE, FRONTEND_URL, OIDC_REDIRECT_URI). - Updated auth routers to use conditional secure cookie flags based on DEV_MODE. - Modernized and cleaned up test suite by removing legacy Streamlit tests. - Fixed linting errors and unused imports across the backend.
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
from fastapi import Depends, HTTPException, status, Request
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from ea_chatbot.config import Settings
|
||||
@@ -18,7 +17,7 @@ if settings.oidc_client_id and settings.oidc_client_secret and settings.oidc_ser
|
||||
client_id=settings.oidc_client_id,
|
||||
client_secret=settings.oidc_client_secret,
|
||||
server_metadata_url=settings.oidc_server_metadata_url,
|
||||
redirect_uri=os.getenv("OIDC_REDIRECT_URI", "http://localhost:3000/auth/callback")
|
||||
redirect_uri=settings.oidc_redirect_uri
|
||||
)
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth/login", auto_error=False)
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
from fastapi import FastAPI, APIRouter
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from ea_chatbot.api.routers import auth, history, artifacts, agent
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
app = FastAPI(
|
||||
title="Election Analytics Chatbot API",
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Optional, List
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from typing import AsyncGenerator, List
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from ea_chatbot.api.dependencies import get_current_user, history_manager
|
||||
from ea_chatbot.api.utils import convert_to_json_compatible
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||
from ea_chatbot.api.dependencies import get_current_user, history_manager
|
||||
from ea_chatbot.history.models import Plot, Message, User as UserDB
|
||||
import io
|
||||
|
||||
router = APIRouter(prefix="/artifacts", tags=["artifacts"])
|
||||
|
||||
|
||||
@@ -1,19 +1,17 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Response, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from ea_chatbot.api.utils import create_access_token
|
||||
from ea_chatbot.api.utils import create_access_token, settings
|
||||
from ea_chatbot.api.dependencies import history_manager, oidc_client, get_current_user
|
||||
from ea_chatbot.history.models import User as UserDB
|
||||
from ea_chatbot.api.schemas import Token, UserCreate, UserResponse
|
||||
import os
|
||||
from ea_chatbot.auth import OIDCSession
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
FRONTEND_URL = os.getenv("FRONTEND_URL", "http://localhost:5173")
|
||||
|
||||
def set_auth_cookie(response: Response, token: str):
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
@@ -22,7 +20,7 @@ def set_auth_cookie(response: Response, token: str):
|
||||
max_age=1800,
|
||||
expires=1800,
|
||||
samesite="lax",
|
||||
secure=False, # Set to True in production with HTTPS
|
||||
secure=not settings.dev_mode,
|
||||
)
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
@@ -72,59 +70,84 @@ async def logout(response: Response):
|
||||
return {"detail": "Successfully logged out"}
|
||||
|
||||
@router.get("/oidc/login")
|
||||
async def oidc_login():
|
||||
"""Get the OIDC authorization URL."""
|
||||
async def oidc_login(response: Response):
|
||||
"""Get the OIDC authorization URL and set temporary session cookie."""
|
||||
if not oidc_client:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_510_NOT_EXTENDED,
|
||||
detail="OIDC is not configured"
|
||||
)
|
||||
|
||||
url = oidc_client.get_login_url()
|
||||
return {"url": url}
|
||||
auth_data = oidc_client.get_auth_data()
|
||||
|
||||
# Store state, nonce, and code_verifier in a secure short-lived cookie
|
||||
session_token = OIDCSession.encrypt(
|
||||
{
|
||||
"state": auth_data["state"],
|
||||
"nonce": auth_data["nonce"],
|
||||
"code_verifier": auth_data["code_verifier"]
|
||||
},
|
||||
settings.secret_key
|
||||
)
|
||||
|
||||
response.set_cookie(
|
||||
key="oidc_session",
|
||||
value=session_token,
|
||||
httponly=True,
|
||||
max_age=300, # 5 minutes
|
||||
samesite="lax",
|
||||
secure=not settings.dev_mode
|
||||
)
|
||||
|
||||
return {"url": auth_data["url"]}
|
||||
|
||||
@router.get("/oidc/callback", response_model=Token)
|
||||
async def oidc_callback(request: Request, response: Response, code: str):
|
||||
"""Handle the OIDC callback, issue a JWT, and redirect or return JSON."""
|
||||
@router.get("/oidc/callback")
|
||||
async def oidc_callback(request: Request, response: Response, code: str, state: str):
|
||||
"""Handle the OIDC callback, validate state/nonce, issue a JWT, and redirect."""
|
||||
if not oidc_client:
|
||||
raise HTTPException(status_code=status.HTTP_510_NOT_EXTENDED, detail="OIDC not configured")
|
||||
|
||||
# 1. Validate state and retrieve session data from cookie
|
||||
session_token = request.cookies.get("oidc_session")
|
||||
if not session_token:
|
||||
logger.error("OIDC session cookie missing")
|
||||
return RedirectResponse(url=f"{settings.frontend_url}?error=oidc_failed", status_code=status.HTTP_302_FOUND)
|
||||
|
||||
session_data = OIDCSession.decrypt(session_token, settings.secret_key)
|
||||
if not session_data or session_data.get("state") != state:
|
||||
logger.error("OIDC state mismatch or session expired")
|
||||
return RedirectResponse(url=f"{settings.frontend_url}?error=oidc_failed", status_code=status.HTTP_302_FOUND)
|
||||
|
||||
try:
|
||||
token = oidc_client.exchange_code_for_token(code)
|
||||
user_info = oidc_client.get_user_info(token)
|
||||
email = user_info.get("email")
|
||||
name = user_info.get("name") or user_info.get("preferred_username")
|
||||
# 2. Exchange code for token using PKCE code_verifier
|
||||
token = oidc_client.exchange_code_for_token(code, code_verifier=session_data.get("code_verifier"))
|
||||
|
||||
# 3. Validate ID Token and Nonce
|
||||
id_token = token.get("id_token")
|
||||
if not id_token:
|
||||
raise ValueError("ID Token missing from response")
|
||||
|
||||
claims = oidc_client.validate_id_token(id_token, nonce=session_data.get("nonce"))
|
||||
|
||||
email = claims.get("email")
|
||||
name = claims.get("name") or claims.get("preferred_username")
|
||||
|
||||
if not email:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email not provided by OIDC")
|
||||
|
||||
# 4. Sync user and establish session
|
||||
user = history_manager.sync_user_from_oidc(email=email, display_name=name)
|
||||
|
||||
access_token = create_access_token(data={"sub": str(user.id)})
|
||||
|
||||
# Determine if we should redirect (direct provider callback) or return JSON (frontend exchange)
|
||||
is_ajax = request.headers.get("X-Requested-With") == "XMLHttpRequest" or \
|
||||
"application/json" in request.headers.get("Accept", "")
|
||||
|
||||
if is_ajax:
|
||||
set_auth_cookie(response, access_token)
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
else:
|
||||
redirect_response = RedirectResponse(url=f"{FRONTEND_URL}/auth/callback")
|
||||
set_auth_cookie(redirect_response, access_token)
|
||||
return redirect_response
|
||||
redirect_response = RedirectResponse(url=settings.frontend_url, status_code=status.HTTP_302_FOUND)
|
||||
set_auth_cookie(redirect_response, access_token)
|
||||
redirect_response.delete_cookie("oidc_session") # Cleanup
|
||||
|
||||
return redirect_response
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("OIDC authentication failed")
|
||||
# For non-ajax, redirect with error
|
||||
is_ajax = request.headers.get("X-Requested-With") == "XMLHttpRequest" or \
|
||||
"application/json" in request.headers.get("Accept", "")
|
||||
if is_ajax:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="OIDC authentication failed")
|
||||
else:
|
||||
return RedirectResponse(url=f"{FRONTEND_URL}?error=oidc_failed")
|
||||
return RedirectResponse(url=f"{settings.frontend_url}?error=oidc_failed", status_code=status.HTTP_302_FOUND)
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_me(current_user: UserDB = Depends(get_current_user)):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Union, Any, List, Dict
|
||||
from typing import Optional, Any
|
||||
from jose import JWTError, jwt
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import streamlit as st
|
||||
import asyncio
|
||||
import os
|
||||
import io
|
||||
from dotenv import load_dotenv
|
||||
@@ -348,7 +347,7 @@ def main():
|
||||
if node_name == "query_analyzer":
|
||||
analysis = state_update.get("analysis", {})
|
||||
next_action = state_update.get("next_action", "unknown")
|
||||
status.write(f"🔍 **Analyzed Query:**")
|
||||
status.write("🔍 **Analyzed Query:**")
|
||||
for k,v in analysis.items():
|
||||
status.write(f"- {k:<8}: {v}")
|
||||
status.markdown(f"Next Step: {next_action.capitalize()}")
|
||||
|
||||
@@ -19,6 +19,8 @@ class Settings(BaseSettings):
|
||||
data_dir: str = "data"
|
||||
data_state: str = "new_jersey"
|
||||
log_level: str = Field(default="INFO", alias="LOG_LEVEL")
|
||||
dev_mode: bool = Field(default=True, alias="DEV_MODE")
|
||||
frontend_url: str = Field(default="http://localhost:5173", alias="FRONTEND_URL")
|
||||
|
||||
# Voter Database configuration
|
||||
db_host: str = Field(default="localhost", alias="DB_HOST")
|
||||
@@ -40,6 +42,7 @@ class Settings(BaseSettings):
|
||||
oidc_client_id: Optional[str] = Field(default=None, alias="OIDC_CLIENT_ID")
|
||||
oidc_client_secret: Optional[str] = Field(default=None, alias="OIDC_CLIENT_SECRET")
|
||||
oidc_server_metadata_url: Optional[str] = Field(default=None, alias="OIDC_SERVER_METADATA_URL")
|
||||
oidc_redirect_uri: str = Field(default="http://localhost:8000/api/v1/auth/oidc/callback", alias="OIDC_REDIRECT_URI")
|
||||
|
||||
# Default configurations for each node
|
||||
query_analyzer_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from langchain_core.messages import AIMessage
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils import helpers, database_inspection
|
||||
from ea_chatbot.utils import database_inspection
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT
|
||||
from ea_chatbot.schemas import CodeGenerationResponse
|
||||
|
||||
@@ -2,7 +2,7 @@ import io
|
||||
import sys
|
||||
import traceback
|
||||
from contextlib import redirect_stdout
|
||||
from typing import Any, Dict, List, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
import pandas as pd
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from langchain_core.messages import AIMessage
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import TypedDict, Annotated, List, Dict, Any, Optional
|
||||
from typing import Annotated, List, Dict, Any, Optional
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain.agents import AgentState as AS
|
||||
import operator
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, List
|
||||
from sqlalchemy import create_engine, select, delete
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy import create_engine, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from argon2 import PasswordHasher
|
||||
from argon2.exceptions import VerifyMismatchError
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Optional, Dict, Any, List, TYPE_CHECKING
|
||||
from typing import Optional, Dict, Any, TYPE_CHECKING
|
||||
import yaml
|
||||
import json
|
||||
import os
|
||||
from ea_chatbot.utils.db_client import DBClient
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -20,7 +20,8 @@ def to_yaml(json_str: str, indent: int = 2) -> str:
|
||||
"""
|
||||
Attempts to convert a JSON string (potentially malformed from LLM) to a YAML string.
|
||||
"""
|
||||
if not json_str: return ""
|
||||
if not json_str:
|
||||
return ""
|
||||
|
||||
try:
|
||||
# Try direct parse
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional, cast, TYPE_CHECKING, Literal, Dict, List, Tuple, Any
|
||||
from typing import Optional, List
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
Reference in New Issue
Block a user