Files
ea-chatbot-lg/src/ea_chatbot/history/manager.py
Yunxiao Xu b23fbce8d0 refactor(auth): Use user_id as JWT sub and implement get_user_by_id
Switched from username to user_id as the primary identifier in JWT tokens to better support external authentication providers. Added get_user_by_id to HistoryManager and updated API dependencies and tests to reflect these changes.
2026-02-11 16:41:27 -08:00

188 lines
7.1 KiB
Python

from contextlib import contextmanager
from typing import Iterator, List, Optional

from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError
from sqlalchemy import create_engine, select, delete
from sqlalchemy.orm import Session, selectinload, sessionmaker

from ea_chatbot.history.models import User, Conversation, Message, Plot
# Module-level Argon2 password hasher, constructed with the library's default
# parameters. HistoryManager uses it both to hash new passwords and to verify
# login attempts.
ph = PasswordHasher()
class HistoryManager:
    """Manage database sessions and persistence for users and chat history.

    Every public method opens its own short-lived session through
    :meth:`get_session`, so the ORM objects handed back to callers are
    detached from any session by the time they are used.
    """

    def __init__(self, db_url: str):
        """Create the engine and session factory.

        Args:
            db_url: SQLAlchemy database URL passed to ``create_engine``.
        """
        self.engine = create_engine(db_url)
        # expire_on_commit=False keeps already-loaded attributes readable on
        # objects after the session that produced them has committed/closed.
        self.SessionLocal = sessionmaker(
            autocommit=False,
            autoflush=False,
            bind=self.engine,
            expire_on_commit=False,
        )

    @contextmanager
    def get_session(self) -> Iterator[Session]:
        """Yield a session that commits on success and rolls back on error.

        The session is committed when the ``with`` body finishes normally.
        If the body (or that commit) raises, the session is rolled back and
        the exception is re-raised. The session is closed in every case.
        """
        session = self.SessionLocal()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    # --- User Management ---

    def get_user(self, email: str) -> Optional[User]:
        """Fetch a user by email, which is stored in ``User.username``.

        Args:
            email: Email address to look up.

        Returns:
            The matching user, or ``None`` when no row matches.
        """
        with self.get_session() as session:
            stmt = select(User).where(User.username == email)
            return session.execute(stmt).scalar_one_or_none()

    def get_user_by_id(self, user_id: str) -> Optional[User]:
        """Fetch a user by primary key.

        Args:
            user_id: Primary-key value of the user.

        Returns:
            The user, or ``None`` when no such row exists.
        """
        with self.get_session() as session:
            return session.get(User, user_id)

    def create_user(
        self,
        email: str,
        password: Optional[str] = None,
        display_name: Optional[str] = None,
    ) -> User:
        """Create and persist a new user.

        Args:
            email: Email address, stored as the username.
            password: Plain-text password. When falsy (``None`` or empty) no
                password hash is stored.
            display_name: Name to show for the user. Defaults to the part of
                ``email`` before the first ``@``.

        Returns:
            The newly created user, refreshed from the database.
        """
        hashed_password = ph.hash(password) if password else None
        user = User(
            username=email,
            password_hash=hashed_password,
            display_name=display_name or email.split("@")[0],
        )
        with self.get_session() as session:
            session.add(user)
            session.commit()
            # Reload so database-generated values are present on the object.
            session.refresh(user)
            return user

    def authenticate_user(self, email: str, password: str) -> Optional[User]:
        """Authenticate a user by email and password.

        Args:
            email: Email address (username) of the account.
            password: Plain-text password to check.

        Returns:
            The user when the password matches the stored hash. ``None`` when
            the user does not exist, has no stored password hash, or the
            password does not match.

        Note:
            Only ``VerifyMismatchError`` is treated as a failed login; any
            other exception raised by the hasher propagates to the caller.
        """
        user = self.get_user(email)
        if not user or not user.password_hash:
            return None

        try:
            ph.verify(user.password_hash, password)
        except VerifyMismatchError:
            return None
        return user

    def sync_user_from_oidc(self, email: str, display_name: Optional[str] = None) -> User:
        """Synchronize a user coming from an OIDC provider.

        If a user with this email already exists, their display name is
        updated when a different, non-empty ``display_name`` is supplied.
        Otherwise a new password-less user is created.

        Args:
            email: Email address reported by the provider.
            display_name: Optional display name reported by the provider.

        Returns:
            The existing (possibly updated) or newly created user.
        """
        # Look up and update inside ONE session so the row that was found is
        # the row that is modified; re-fetching it in a second session could
        # return None if it was deleted in between.
        with self.get_session() as session:
            stmt = select(User).where(User.username == email)
            user = session.execute(stmt).scalar_one_or_none()
            if user is not None:
                if display_name and user.display_name != display_name:
                    user.display_name = display_name
                    session.commit()
                    session.refresh(user)
                return user

        # No existing user: create one without a password.
        return self.create_user(email=email, display_name=display_name)

    # --- Conversation Management ---

    def create_conversation(
        self,
        user_id: str,
        data_state: str,
        name: str,
        summary: Optional[str] = None,
    ) -> Conversation:
        """Create a new conversation for a user.

        Args:
            user_id: Owner of the conversation.
            data_state: Data-state value stored on the conversation.
            name: Conversation name.
            summary: Optional summary text.

        Returns:
            The newly created conversation, refreshed from the database.
        """
        conv = Conversation(
            user_id=user_id,
            data_state=data_state,
            name=name,
            summary=summary,
        )
        with self.get_session() as session:
            session.add(conv)
            session.commit()
            session.refresh(conv)
            return conv

    def get_conversations(self, user_id: str, data_state: str) -> List[Conversation]:
        """List a user's conversations for one data state, newest first.

        Args:
            user_id: Owner whose conversations are returned.
            data_state: Only conversations with this data state are returned.

        Returns:
            Conversations ordered by ``created_at`` descending.
        """
        stmt = (
            select(Conversation)
            .where(
                Conversation.user_id == user_id,
                Conversation.data_state == data_state,
            )
            .order_by(Conversation.created_at.desc())
        )
        with self.get_session() as session:
            return list(session.execute(stmt).scalars().all())

    def rename_conversation(self, conversation_id: str, new_name: str) -> Optional[Conversation]:
        """Rename an existing conversation.

        Args:
            conversation_id: Primary key of the conversation.
            new_name: New name to store.

        Returns:
            The updated conversation, or ``None`` when it does not exist.
        """
        with self.get_session() as session:
            conv = session.get(Conversation, conversation_id)
            if conv:
                conv.name = new_name
                session.commit()
                session.refresh(conv)
            return conv

    def delete_conversation(self, conversation_id: str) -> bool:
        """Delete a conversation.

        Associated messages and plots are expected to be removed by the
        cascade configured on the models (not visible in this module).

        Args:
            conversation_id: Primary key of the conversation.

        Returns:
            ``True`` when a conversation was deleted, ``False`` when no
            conversation with that id exists.
        """
        with self.get_session() as session:
            conv = session.get(Conversation, conversation_id)
            if not conv:
                return False
            session.delete(conv)
            session.commit()
            return True

    def update_conversation_summary(self, conversation_id: str, summary: str) -> Optional[Conversation]:
        """Replace the summary of a conversation.

        Args:
            conversation_id: Primary key of the conversation.
            summary: New summary text.

        Returns:
            The updated conversation, or ``None`` when it does not exist.
        """
        with self.get_session() as session:
            conv = session.get(Conversation, conversation_id)
            if conv:
                conv.summary = summary
                session.commit()
                session.refresh(conv)
            return conv

    # --- Message & Plot Management ---

    def add_message(
        self,
        conversation_id: str,
        role: str,
        content: str,
        plots: Optional[List[bytes]] = None,
    ) -> Message:
        """Add a message to a conversation, optionally with plots.

        Args:
            conversation_id: Conversation the message belongs to.
            role: Role value stored on the message.
            content: Message text.
            plots: Optional image payloads; one ``Plot`` row is created per
                item and linked to the new message.

        Returns:
            The persisted message with its ``plots`` attribute already loaded.
        """
        msg = Message(
            conversation_id=conversation_id,
            role=role,
            content=content,
        )
        with self.get_session() as session:
            session.add(msg)
            # Flush first so msg.id is populated before plots reference it.
            session.flush()
            if plots:
                session.add_all(
                    [Plot(message_id=msg.id, image_data=plot_data) for plot_data in plots]
                )
            session.commit()
            session.refresh(msg)
            # Touch the attribute while the session is open so the plots stay
            # available on the detached object after the session closes.
            _ = msg.plots
            return msg

    def get_messages(self, conversation_id: str) -> List[Message]:
        """Get all messages of a conversation, oldest first, with plots loaded.

        Args:
            conversation_id: Conversation whose messages are returned.

        Returns:
            Messages ordered by ``created_at`` ascending. Each message's
            ``plots`` is loaded before the session closes.
        """
        # Eager-load plots for all messages at once instead of touching
        # m.plots per message (one extra query per message).
        # Assumes Message.plots is a mapped relationship, as its lazy-load
        # use in add_message suggests -- confirm against the models.
        stmt = (
            select(Message)
            .where(Message.conversation_id == conversation_id)
            .order_by(Message.created_at.asc())
            .options(selectinload(Message.plots))
        )
        with self.get_session() as session:
            return list(session.execute(stmt).scalars().all())