from contextlib import contextmanager
from typing import Iterator, Optional, List

from sqlalchemy import create_engine, select, delete
from sqlalchemy.orm import sessionmaker, Session, selectinload
from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError, VerificationError

from ea_chatbot.history.models import User, Conversation, Message, Plot

# Argon2 Password Hasher (module-wide instance, library default parameters).
ph = PasswordHasher()


class HistoryManager:
    """Manages database sessions and operations for history and user data.

    Every public method runs inside its own session opened by ``get_session``,
    which is the single transaction boundary: it commits when the ``with``
    block exits normally and rolls back if it raises. Returned ORM objects are
    detached once the session closes, but keep their loaded attributes because
    the session factory uses ``expire_on_commit=False``.
    """

    def __init__(self, db_url: str):
        """Create the engine and session factory for *db_url*."""
        self.engine = create_engine(db_url)
        # expire_on_commit=False is important so we can use objects after session closes
        self.SessionLocal = sessionmaker(
            autocommit=False,
            autoflush=False,
            bind=self.engine,
            expire_on_commit=False,
        )

    @contextmanager
    def get_session(self) -> Iterator[Session]:
        """Yield a session; commit on success, roll back on error, always close."""
        session = self.SessionLocal()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    # --- User Management ---

    @staticmethod
    def _find_user(session: Session, email: str) -> Optional[User]:
        """Look up a user by email (stored in ``User.username``) within *session*."""
        return session.execute(
            select(User).where(User.username == email)
        ).scalar_one_or_none()

    @staticmethod
    def _new_user(
        email: str, password_hash: Optional[str], display_name: Optional[str]
    ) -> User:
        """Build (but do not persist) a User; display name defaults to the email's local part."""
        return User(
            username=email,
            password_hash=password_hash,
            display_name=display_name or email.split("@")[0],
        )

    def get_user(self, email: str) -> Optional[User]:
        """Fetch a user by their email (username), or ``None`` if absent."""
        with self.get_session() as session:
            return self._find_user(session, email)

    def get_user_by_id(self, user_id: str) -> Optional[User]:
        """Fetch a user by their ID, or ``None`` if absent."""
        with self.get_session() as session:
            return session.get(User, user_id)

    def create_user(
        self,
        email: str,
        password: Optional[str] = None,
        display_name: Optional[str] = None,
    ) -> User:
        """Create and persist a new local user.

        Args:
            email: Stored as the user's ``username``.
            password: Plain-text password to hash with Argon2. ``None`` or an
                empty string stores no hash, so the user cannot log in through
                ``authenticate_user``.
            display_name: Defaults to the part of *email* before ``@``.

        Returns:
            The persisted, refreshed ``User``.
        """
        hashed_password = ph.hash(password) if password else None
        user = self._new_user(email, hashed_password, display_name)
        with self.get_session() as session:
            session.add(user)
            # autoflush is disabled: flush explicitly so the INSERT runs and
            # refresh can load any DB-generated columns. get_session commits.
            session.flush()
            session.refresh(user)
            return user

    def authenticate_user(self, email: str, password: str) -> Optional[User]:
        """Authenticate a user by email and password.

        Returns the ``User`` on success. Returns ``None`` if the user does not
        exist, has no password hash (e.g. an OIDC-only user), the password does
        not match, or the stored hash cannot be verified. If the stored hash
        was made with outdated Argon2 parameters, it is upgraded in place.
        """
        user = self.get_user(email)
        if not user or not user.password_hash:
            return None
        try:
            ph.verify(user.password_hash, password)
        except VerifyMismatchError:
            return None
        except (VerificationError, ValueError):
            # Corrupt/unsupported stored hash. argon2's InvalidHashError (older
            # name: InvalidHash) subclasses ValueError. Treat as a failed login
            # instead of crashing the caller.
            return None
        if ph.check_needs_rehash(user.password_hash):
            # Recommended by argon2-cffi: re-hash with current parameters while
            # the plain-text password is available.
            user = self._store_password_hash(user.id, ph.hash(password)) or user
        return user

    def _store_password_hash(self, user_id: str, password_hash: str) -> Optional[User]:
        """Replace a user's stored password hash; ``None`` if the user is gone."""
        with self.get_session() as session:
            db_user = session.get(User, user_id)
            if db_user is None:
                return None
            db_user.password_hash = password_hash
            session.flush()
            session.refresh(db_user)
            return db_user

    def sync_user_from_oidc(self, email: str, display_name: Optional[str] = None) -> User:
        """
        Synchronize a user from an OIDC provider.

        If a user with the same email exists, update their display name when a
        different non-empty one is given. Otherwise, create a new user with no
        password. Lookup and write happen in a single session, so a row that
        disappears between two separate lookups can no longer cause an
        ``AttributeError``.
        """
        with self.get_session() as session:
            user = self._find_user(session, email)
            if user is None:
                # Create new user (no password for OIDC users initially)
                user = self._new_user(email, None, display_name)
                session.add(user)
            elif display_name and user.display_name != display_name:
                user.display_name = display_name
            else:
                # Existing user, nothing to change.
                return user
            session.flush()
            session.refresh(user)
            return user

    # --- Conversation Management ---

    def create_conversation(
        self,
        user_id: str,
        data_state: str,
        name: str,
        summary: Optional[str] = None,
    ) -> Conversation:
        """Create and persist a new conversation for a user."""
        conv = Conversation(
            user_id=user_id,
            data_state=data_state,
            name=name,
            summary=summary,
        )
        with self.get_session() as session:
            session.add(conv)
            session.flush()
            session.refresh(conv)
            return conv

    def get_conversations(self, user_id: str, data_state: str) -> List[Conversation]:
        """Get all conversations for a user and data state, newest first (by ``created_at``)."""
        stmt = (
            select(Conversation)
            .where(Conversation.user_id == user_id, Conversation.data_state == data_state)
            .order_by(Conversation.created_at.desc())
        )
        with self.get_session() as session:
            return list(session.execute(stmt).scalars().all())

    def _update_conversation(self, conversation_id: str, **changes) -> Optional[Conversation]:
        """Set the given attributes on a conversation; ``None`` if it does not exist."""
        with self.get_session() as session:
            conv = session.get(Conversation, conversation_id)
            if conv is None:
                return None
            for attr, value in changes.items():
                setattr(conv, attr, value)
            session.flush()
            session.refresh(conv)
            return conv

    def rename_conversation(self, conversation_id: str, new_name: str) -> Optional[Conversation]:
        """Rename an existing conversation; ``None`` if it does not exist."""
        return self._update_conversation(conversation_id, name=new_name)

    def delete_conversation(self, conversation_id: str) -> bool:
        """Delete a conversation; return ``True`` if it existed.

        Removal of associated messages/plots relies on the cascade configured
        in the models (not visible here).
        """
        with self.get_session() as session:
            conv = session.get(Conversation, conversation_id)
            if conv is None:
                return False
            session.delete(conv)
            return True

    def update_conversation_summary(self, conversation_id: str, summary: str) -> Optional[Conversation]:
        """Update the summary of a conversation; ``None`` if it does not exist."""
        return self._update_conversation(conversation_id, summary=summary)

    # --- Message & Plot Management ---

    def add_message(
        self,
        conversation_id: str,
        role: str,
        content: str,
        plots: Optional[List[bytes]] = None,
    ) -> Message:
        """Add a message to a conversation, optionally with plot images.

        Each element of *plots* becomes one ``Plot`` row linked to the message.
        The returned message has its ``plots`` collection loaded.
        """
        msg = Message(
            conversation_id=conversation_id,
            role=role,
            content=content,
        )
        with self.get_session() as session:
            session.add(msg)
            session.flush()  # Populate msg.id for plots
            session.add_all(
                Plot(message_id=msg.id, image_data=plot_data) for plot_data in plots or ()
            )
            # autoflush is disabled, so write the plots before reloading msg.
            session.flush()
            session.refresh(msg)
            # Load plots while the session is open; the returned object is detached.
            _ = msg.plots
            return msg

    def get_messages(self, conversation_id: str) -> List[Message]:
        """Get all messages for a conversation, oldest first, with plots loaded."""
        stmt = (
            select(Message)
            .where(Message.conversation_id == conversation_id)
            .order_by(Message.created_at.asc())
            # Eager-load plots with one extra query instead of one lazy load per message (N+1).
            .options(selectinload(Message.plots))
        )
        with self.get_session() as session:
            return list(session.execute(stmt).scalars().all())