feat: implement mvp with email-first login flow and langgraph architecture
This commit is contained in:
183
src/ea_chatbot/history/manager.py
Normal file
183
src/ea_chatbot/history/manager.py
Normal file
@@ -0,0 +1,183 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, List
|
||||
from sqlalchemy import create_engine, select, delete
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from argon2 import PasswordHasher
|
||||
from argon2.exceptions import VerifyMismatchError
|
||||
|
||||
from ea_chatbot.history.models import User, Conversation, Message, Plot
|
||||
|
||||
# Argon2 Password Hasher
|
||||
ph = PasswordHasher()
|
||||
|
||||
class HistoryManager:
|
||||
"""Manages database sessions and operations for history and user data."""
|
||||
|
||||
def __init__(self, db_url: str):
|
||||
self.engine = create_engine(db_url)
|
||||
# expire_on_commit=False is important so we can use objects after session closes
|
||||
self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine, expire_on_commit=False)
|
||||
|
||||
@contextmanager
|
||||
def get_session(self):
|
||||
"""Context manager for database sessions."""
|
||||
session = self.SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
# --- User Management ---
|
||||
|
||||
def get_user(self, email: str) -> Optional[User]:
|
||||
"""Fetch a user by their email (username)."""
|
||||
with self.get_session() as session:
|
||||
result = session.execute(select(User).where(User.username == email))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
def create_user(self, email: str, password: Optional[str] = None, display_name: Optional[str] = None) -> User:
|
||||
"""Create a new local user."""
|
||||
hashed_password = ph.hash(password) if password else None
|
||||
user = User(
|
||||
username=email,
|
||||
password_hash=hashed_password,
|
||||
display_name=display_name or email.split("@")[0]
|
||||
)
|
||||
with self.get_session() as session:
|
||||
session.add(user)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
return user
|
||||
|
||||
def authenticate_user(self, email: str, password: str) -> Optional[User]:
|
||||
"""Authenticate a user by email and password."""
|
||||
user = self.get_user(email)
|
||||
if not user or not user.password_hash:
|
||||
return None
|
||||
|
||||
try:
|
||||
ph.verify(user.password_hash, password)
|
||||
return user
|
||||
except VerifyMismatchError:
|
||||
return None
|
||||
|
||||
def sync_user_from_oidc(self, email: str, display_name: Optional[str] = None) -> User:
|
||||
"""
|
||||
Synchronize a user from an OIDC provider.
|
||||
If a user with the same email exists, update their display name.
|
||||
Otherwise, create a new user.
|
||||
"""
|
||||
user = self.get_user(email)
|
||||
if user:
|
||||
# Update existing user if needed
|
||||
if display_name and user.display_name != display_name:
|
||||
with self.get_session() as session:
|
||||
db_user = session.get(User, user.id)
|
||||
db_user.display_name = display_name
|
||||
session.commit()
|
||||
session.refresh(db_user)
|
||||
return db_user
|
||||
return user
|
||||
else:
|
||||
# Create new user (no password for OIDC users initially)
|
||||
return self.create_user(email=email, display_name=display_name)
|
||||
|
||||
# --- Conversation Management ---
|
||||
|
||||
def create_conversation(self, user_id: str, data_state: str, name: str, summary: Optional[str] = None) -> Conversation:
|
||||
"""Create a new conversation for a user."""
|
||||
conv = Conversation(
|
||||
user_id=user_id,
|
||||
data_state=data_state,
|
||||
name=name,
|
||||
summary=summary
|
||||
)
|
||||
with self.get_session() as session:
|
||||
session.add(conv)
|
||||
session.commit()
|
||||
session.refresh(conv)
|
||||
return conv
|
||||
|
||||
def get_conversations(self, user_id: str, data_state: str) -> List[Conversation]:
|
||||
"""Get all conversations for a user and data state, ordered by creation time."""
|
||||
with self.get_session() as session:
|
||||
stmt = (
|
||||
select(Conversation)
|
||||
.where(Conversation.user_id == user_id, Conversation.data_state == data_state)
|
||||
.order_by(Conversation.created_at.desc())
|
||||
)
|
||||
result = session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
def rename_conversation(self, conversation_id: str, new_name: str) -> Optional[Conversation]:
|
||||
"""Rename an existing conversation."""
|
||||
with self.get_session() as session:
|
||||
conv = session.get(Conversation, conversation_id)
|
||||
if conv:
|
||||
conv.name = new_name
|
||||
session.commit()
|
||||
session.refresh(conv)
|
||||
return conv
|
||||
|
||||
def delete_conversation(self, conversation_id: str) -> bool:
|
||||
"""Delete a conversation and its associated messages/plots (via cascade)."""
|
||||
with self.get_session() as session:
|
||||
conv = session.get(Conversation, conversation_id)
|
||||
if conv:
|
||||
session.delete(conv)
|
||||
session.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
def update_conversation_summary(self, conversation_id: str, summary: str) -> Optional[Conversation]:
|
||||
"""Update the summary of a conversation."""
|
||||
with self.get_session() as session:
|
||||
conv = session.get(Conversation, conversation_id)
|
||||
if conv:
|
||||
conv.summary = summary
|
||||
session.commit()
|
||||
session.refresh(conv)
|
||||
return conv
|
||||
|
||||
# --- Message & Plot Management ---
|
||||
|
||||
def add_message(self, conversation_id: str, role: str, content: str, plots: Optional[List[bytes]] = None) -> Message:
|
||||
"""Add a message to a conversation, optionally with plots."""
|
||||
msg = Message(
|
||||
conversation_id=conversation_id,
|
||||
role=role,
|
||||
content=content
|
||||
)
|
||||
with self.get_session() as session:
|
||||
session.add(msg)
|
||||
session.flush() # Populate msg.id for plots
|
||||
|
||||
if plots:
|
||||
for plot_data in plots:
|
||||
plot = Plot(message_id=msg.id, image_data=plot_data)
|
||||
session.add(plot)
|
||||
|
||||
session.commit()
|
||||
session.refresh(msg)
|
||||
# Ensure plots are loaded before session closes if we need them
|
||||
_ = msg.plots
|
||||
return msg
|
||||
|
||||
def get_messages(self, conversation_id: str) -> List[Message]:
|
||||
"""Get all messages for a conversation, ordered by creation time."""
|
||||
with self.get_session() as session:
|
||||
stmt = (
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conversation_id)
|
||||
.order_by(Message.created_at.asc())
|
||||
)
|
||||
result = session.execute(stmt)
|
||||
messages = list(result.scalars().all())
|
||||
# Pre-load plots for each message
|
||||
for m in messages:
|
||||
_ = m.plots
|
||||
return messages
|
||||
Reference in New Issue
Block a user