79 lines
2.5 KiB
Python
79 lines
2.5 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from fastapi.security import OAuth2PasswordRequestForm
|
|
from pydantic import BaseModel, EmailStr
|
|
from typing import Optional
|
|
from ea_chatbot.api.utils import create_access_token
|
|
from ea_chatbot.api.dependencies import history_manager, oidc_client, get_current_user
|
|
from ea_chatbot.history.models import User as UserDB
|
|
|
|
# Every authentication endpoint below is mounted under the "/auth" path
# prefix and grouped under the "auth" tag in the generated OpenAPI schema.
router = APIRouter(tags=["auth"], prefix="/auth")


class Token(BaseModel):
    """Response body returned by ``POST /auth/login``.

    Attributes:
        access_token: Encoded access token produced by ``create_access_token``
            (described as a JWT by the login endpoint).
        token_type: Authorization scheme of the token; the login endpoint
            always sets this to ``"bearer"``.
    """

    access_token: str
    token_type: str


class UserCreate(BaseModel):
    """Request body accepted by ``POST /auth/register``.

    Attributes:
        email: Account e-mail; format is validated by Pydantic's ``EmailStr``.
        password: Plain-text password forwarded to ``create_user``.
        display_name: Optional human-readable name; omitted means ``None``.
    """

    email: EmailStr
    # NOTE(review): no length/strength constraint is enforced, so an empty
    # string passes validation; consider Field(min_length=...) if intended.
    password: str
    display_name: Optional[str] = None


class UserResponse(BaseModel):
    """Public user profile returned by ``/auth/register`` and ``/auth/me``.

    Attributes:
        id: The user's primary key, stringified by the endpoints.
        email: Populated from the stored user's ``username`` attribute.
        display_name: The stored display name, or ``None`` when absent.
    """

    id: str
    email: str
    # Nullable because registration accepts display_name=None (see UserCreate)
    # and the endpoints return user.display_name verbatim. A required ``str``
    # here would make response validation fail (HTTP 500) for such users.
    # Presumably create_user may persist None unchanged — TODO confirm.
    display_name: Optional[str] = None


@router.post(
    "/register",
    response_model=UserResponse,
    status_code=status.HTTP_201_CREATED,
)
async def register(user_in: UserCreate):
    """Create a new account and return its public profile (HTTP 201).

    Raises:
        HTTPException: 400 when ``history_manager.get_user`` already finds a
            record for the submitted e-mail.
    """
    # NOTE(review): this check-then-create is not atomic. Two concurrent
    # requests for the same e-mail could both pass the check unless
    # create_user or the database enforces uniqueness — confirm.
    if history_manager.get_user(user_in.email):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="User already exists",
        )

    created = history_manager.create_user(
        email=user_in.email,
        password=user_in.password,
        display_name=user_in.display_name,
    )
    # The stored model keeps the e-mail in its ``username`` attribute.
    return dict(
        id=str(created.id),
        email=created.username,
        display_name=created.display_name,
    )


@router.post("/login", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange e-mail + password for a bearer JWT (OAuth2 password flow).

    The OAuth2 form's ``username`` field carries the user's e-mail.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            ``authenticate_user`` rejects the credentials.
    """
    account = history_manager.authenticate_user(
        form_data.username, form_data.password
    )
    if not account:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Token claims: the subject is the stored username; the id is stringified.
    claims = {"sub": account.username, "user_id": str(account.id)}
    token = create_access_token(data=claims)
    return {"access_token": token, "token_type": "bearer"}


@router.get("/oidc/login")
async def oidc_login():
    """Return the OIDC authorization URL as ``{"url": ...}``.

    Raises:
        HTTPException: 510 when no ``oidc_client`` is configured (falsy).
    """
    if oidc_client:
        return {"url": oidc_client.get_login_url()}

    # NOTE(review): 510 Not Extended (RFC 2774) is an unusual choice for
    # "feature not configured"; 501/503 would be more conventional. It is
    # kept unchanged because clients may already depend on this status.
    raise HTTPException(
        status_code=status.HTTP_510_NOT_EXTENDED,
        detail="OIDC is not configured",
    )


@router.get("/me", response_model=UserResponse)
async def get_me(current_user: UserDB = Depends(get_current_user)):
    """Return the profile of the user resolved by ``get_current_user``."""
    # Same shape as the /register response; the e-mail is stored as username.
    profile = {
        "id": str(current_user.id),
        "email": current_user.username,
        "display_name": current_user.display_name,
    }
    return profile