Switched from username to user_id as the primary identifier in JWT tokens to better support external authentication providers. Added get_user_by_id to HistoryManager and updated API dependencies and tests to reflect these changes.
188 lines · 7.1 KiB · Python
from contextlib import contextmanager
from typing import Iterator, List, Optional

from argon2 import PasswordHasher
from argon2.exceptions import VerificationError, VerifyMismatchError
from sqlalchemy import create_engine, delete, select
from sqlalchemy.orm import Session, selectinload, sessionmaker

from ea_chatbot.history.models import Conversation, Message, Plot, User
|
|
|
|
# Module-wide Argon2 hasher, constructed with the library's default parameters;
# HistoryManager uses it to hash new local passwords and verify logins.
ph: PasswordHasher = PasswordHasher()
|
|
|
|
class HistoryManager:
    """Manages database sessions and operations for history and user data."""

    def __init__(self, db_url: str):
        """Create the engine and session factory.

        Args:
            db_url: SQLAlchemy database URL passed to ``create_engine``.
        """
        self.engine = create_engine(db_url)
        # expire_on_commit=False is important so we can use objects after the
        # session closes: every public method below returns ORM instances that
        # callers read outside of the session that loaded them.
        self.SessionLocal = sessionmaker(
            autocommit=False, autoflush=False, bind=self.engine, expire_on_commit=False
        )

    @contextmanager
    def get_session(self) -> Iterator[Session]:
        """Yield a session that commits on success and rolls back on error.

        Any exception raised inside the ``with`` block triggers a rollback and
        is re-raised. The session is closed in every case.
        """
        session = self.SessionLocal()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    # --- User Management ---

    def get_user(self, email: str) -> Optional[User]:
        """Fetch a user by their email (stored in the ``username`` column).

        Returns:
            The matching ``User``, or ``None`` if no row matches.
        """
        with self.get_session() as session:
            result = session.execute(select(User).where(User.username == email))
            return result.scalar_one_or_none()

    def get_user_by_id(self, user_id: str) -> Optional[User]:
        """Fetch a user by primary key.

        Returns:
            The ``User`` with this ID, or ``None`` if it does not exist.
        """
        with self.get_session() as session:
            return session.get(User, user_id)

    def create_user(self, email: str, password: Optional[str] = None, display_name: Optional[str] = None) -> User:
        """Create and persist a new local user.

        Args:
            email: Stored as the user's ``username``.
            password: Plain-text password; hashed with Argon2 before storage.
                When falsy, no password hash is stored (e.g. OIDC-only users).
            display_name: Defaults to the part of ``email`` before the first "@".

        Returns:
            The persisted ``User``, refreshed after commit.
        """
        hashed_password = ph.hash(password) if password else None
        user = User(
            username=email,
            password_hash=hashed_password,
            display_name=display_name or email.split("@")[0],
        )
        with self.get_session() as session:
            session.add(user)
            session.commit()
            session.refresh(user)
            return user

    def authenticate_user(self, email: str, password: str) -> Optional[User]:
        """Authenticate a user by email and password.

        Returns:
            The ``User`` if the password verifies against the stored hash.
            Returns ``None`` if the user does not exist, has no password hash,
            the password does not match, or the stored hash cannot be verified.
        """
        user = self.get_user(email)
        if not user or not user.password_hash:
            return None

        try:
            ph.verify(user.password_hash, password)
        except VerificationError:
            # Base class of VerifyMismatchError: a wrong password or any other
            # verification failure is reported as "not authenticated".
            return None
        except ValueError:
            # argon2's InvalidHash(Error) subclasses ValueError. A malformed
            # stored hash must not surface as an unhandled error.
            return None
        return user

    def sync_user_from_oidc(self, email: str, display_name: Optional[str] = None) -> User:
        """Synchronize a user from an OIDC provider.

        If a user with the same email exists, update their display name when a
        different, non-empty one is supplied. Otherwise, create a new user
        without a password.

        Returns:
            The existing (possibly updated) or newly created ``User``.
        """
        with self.get_session() as session:
            # Look up and update in one session, so the row cannot disappear
            # between the lookup and the update.
            user = session.execute(
                select(User).where(User.username == email)
            ).scalar_one_or_none()
            if user is not None:
                if display_name and user.display_name != display_name:
                    user.display_name = display_name
                    session.commit()
                    session.refresh(user)
                return user

        # Create new user (no password for OIDC users initially).
        return self.create_user(email=email, display_name=display_name)

    # --- Conversation Management ---

    def create_conversation(self, user_id: str, data_state: str, name: str, summary: Optional[str] = None) -> Conversation:
        """Create and persist a new conversation for a user.

        Returns:
            The persisted ``Conversation``, refreshed after commit.
        """
        conv = Conversation(
            user_id=user_id,
            data_state=data_state,
            name=name,
            summary=summary,
        )
        with self.get_session() as session:
            session.add(conv)
            session.commit()
            session.refresh(conv)
            return conv

    def get_conversations(self, user_id: str, data_state: str) -> List[Conversation]:
        """Get all conversations for a user and data state, newest first."""
        stmt = (
            select(Conversation)
            .where(Conversation.user_id == user_id, Conversation.data_state == data_state)
            .order_by(Conversation.created_at.desc())
        )
        with self.get_session() as session:
            return list(session.execute(stmt).scalars().all())

    def rename_conversation(self, conversation_id: str, new_name: str) -> Optional[Conversation]:
        """Rename an existing conversation.

        Returns:
            The updated ``Conversation``, or ``None`` if it does not exist.
        """
        with self.get_session() as session:
            conv = session.get(Conversation, conversation_id)
            if conv:
                conv.name = new_name
                session.commit()
                session.refresh(conv)
            return conv

    def delete_conversation(self, conversation_id: str) -> bool:
        """Delete a conversation and its associated messages/plots (via cascade).

        Returns:
            ``True`` if a conversation was deleted, ``False`` if none matched.
        """
        with self.get_session() as session:
            conv = session.get(Conversation, conversation_id)
            if conv:
                session.delete(conv)
                session.commit()
                return True
            return False

    def update_conversation_summary(self, conversation_id: str, summary: str) -> Optional[Conversation]:
        """Update the summary of a conversation.

        Returns:
            The updated ``Conversation``, or ``None`` if it does not exist.
        """
        with self.get_session() as session:
            conv = session.get(Conversation, conversation_id)
            if conv:
                conv.summary = summary
                session.commit()
                session.refresh(conv)
            return conv

    # --- Message & Plot Management ---

    def add_message(self, conversation_id: str, role: str, content: str, plots: Optional[List[bytes]] = None) -> Message:
        """Add a message to a conversation, optionally with plots.

        Args:
            conversation_id: ID of the conversation the message belongs to.
            role: Role label stored on the message.
            content: Message text.
            plots: Optional image payloads; one ``Plot`` row is created per item.

        Returns:
            The persisted ``Message`` with its ``plots`` relationship loaded, so
            it can be read after the session closes.
        """
        msg = Message(
            conversation_id=conversation_id,
            role=role,
            content=content,
        )
        with self.get_session() as session:
            session.add(msg)
            session.flush()  # Populate msg.id so plots can reference it.

            for plot_data in plots or ():
                session.add(Plot(message_id=msg.id, image_data=plot_data))

            session.commit()
            session.refresh(msg)
            # Touch the relationship while the session is still open so the
            # plots are loaded before the object becomes detached.
            _ = msg.plots
            return msg

    def get_messages(self, conversation_id: str) -> List[Message]:
        """Get all messages for a conversation, oldest first, with plots loaded."""
        stmt = (
            select(Message)
            .where(Message.conversation_id == conversation_id)
            .order_by(Message.created_at.asc())
            # Load every message's plots in one extra query instead of one lazy
            # load per message (N+1).
            .options(selectinload(Message.plots))
        )
        with self.get_session() as session:
            return list(session.execute(stmt).scalars().all())