feat: Add light/dark mode support with backend persistence
This commit is contained in:
@@ -0,0 +1,28 @@
|
||||
"""Add theme preference to user
|
||||
|
||||
Revision ID: 2473a00afb70
|
||||
Revises: 63886baa1255
|
||||
Create Date: 2026-02-16 17:00:25.537643
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '2473a00afb70'
|
||||
down_revision: Union[str, Sequence[str], None] = '63886baa1255'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
op.add_column('users', sa.Column('theme_preference', sa.String(), nullable=False, server_default='light'))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
op.drop_column('users', 'theme_preference')
|
||||
@@ -91,7 +91,7 @@ async def stream_agent_events(
|
||||
encoded_plots: list[str] = []
|
||||
for fig in plots:
|
||||
buf = io.BytesIO()
|
||||
fig.savefig(buf, format="png")
|
||||
fig.savefig(buf, format="png", transparent=True)
|
||||
plot_bytes = buf.getvalue()
|
||||
assistant_plots.append(plot_bytes)
|
||||
encoded_plots.append(base64.b64encode(plot_bytes).decode('utf-8'))
|
||||
|
||||
@@ -4,7 +4,7 @@ from fastapi.security import OAuth2PasswordRequestForm
|
||||
from ea_chatbot.api.utils import create_access_token, settings
|
||||
from ea_chatbot.api.dependencies import history_manager, oidc_client, get_current_user
|
||||
from ea_chatbot.history.models import User as UserDB
|
||||
from ea_chatbot.api.schemas import Token, UserCreate, UserResponse
|
||||
from ea_chatbot.api.schemas import Token, UserCreate, UserResponse, ThemeUpdate
|
||||
from ea_chatbot.auth import OIDCSession
|
||||
import logging
|
||||
|
||||
@@ -45,7 +45,8 @@ async def register(user_in: UserCreate, response: Response):
|
||||
return {
|
||||
"id": str(user.id),
|
||||
"email": user.username,
|
||||
"display_name": user.display_name
|
||||
"display_name": user.display_name,
|
||||
"theme_preference": user.theme_preference
|
||||
}
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
@@ -155,5 +156,25 @@ async def get_me(current_user: UserDB = Depends(get_current_user)):
|
||||
return {
|
||||
"id": str(current_user.id),
|
||||
"email": current_user.username,
|
||||
"display_name": current_user.display_name
|
||||
"display_name": current_user.display_name,
|
||||
"theme_preference": current_user.theme_preference
|
||||
}
|
||||
|
||||
@router.patch("/theme", response_model=UserResponse)
|
||||
async def update_theme(
|
||||
theme_in: ThemeUpdate,
|
||||
current_user: UserDB = Depends(get_current_user)
|
||||
):
|
||||
"""Update the current user's theme preference."""
|
||||
user = history_manager.update_user_theme(current_user.id, theme_in.theme)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
return {
|
||||
"id": str(user.id),
|
||||
"email": user.username,
|
||||
"display_name": user.display_name,
|
||||
"theme_preference": user.theme_preference
|
||||
}
|
||||
|
||||
@@ -17,6 +17,10 @@ class UserResponse(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
display_name: str
|
||||
theme_preference: str
|
||||
|
||||
class ThemeUpdate(BaseModel):
|
||||
theme: str
|
||||
|
||||
# --- History Schemas ---
|
||||
|
||||
|
||||
@@ -420,7 +420,7 @@ def main():
|
||||
st.pyplot(fig)
|
||||
# Convert fig to bytes
|
||||
buf = io.BytesIO()
|
||||
fig.savefig(buf, format="png")
|
||||
fig.savefig(buf, format="png", transparent=True)
|
||||
plot_bytes_list.append(buf.getvalue())
|
||||
|
||||
if final_state.get("dfs"):
|
||||
|
||||
@@ -70,6 +70,16 @@ class HistoryManager:
|
||||
except VerifyMismatchError:
|
||||
return None
|
||||
|
||||
def update_user_theme(self, user_id: str, theme: str) -> Optional[User]:
|
||||
"""Update the user's theme preference."""
|
||||
with self.get_session() as session:
|
||||
user = session.get(User, user_id)
|
||||
if user:
|
||||
user.theme_preference = theme
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
return user
|
||||
|
||||
def sync_user_from_oidc(self, email: str, display_name: Optional[str] = None) -> User:
|
||||
"""
|
||||
Synchronize a user from an OIDC provider.
|
||||
|
||||
@@ -14,6 +14,7 @@ class User(Base):
|
||||
username: Mapped[str] = mapped_column(String, unique=True, index=True)
|
||||
password_hash: Mapped[Optional[str]] = mapped_column(String, nullable=True)
|
||||
display_name: Mapped[Optional[str]] = mapped_column(String, nullable=True)
|
||||
theme_preference: Mapped[str] = mapped_column(String, default="light")
|
||||
|
||||
conversations: Mapped[List["Conversation"]] = relationship(back_populates="user", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
@@ -15,7 +15,8 @@ def mock_user():
|
||||
id="user-123",
|
||||
username="test@example.com",
|
||||
display_name="Test User",
|
||||
password_hash="hashed_password"
|
||||
password_hash="hashed_password",
|
||||
theme_preference="light"
|
||||
)
|
||||
|
||||
def test_register_user_success():
|
||||
@@ -23,7 +24,7 @@ def test_register_user_success():
|
||||
# We mock it where it is used in the router
|
||||
with patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm:
|
||||
mock_hm.get_user.return_value = None
|
||||
mock_hm.create_user.return_value = User(id="1", username="new@example.com", display_name="New")
|
||||
mock_hm.create_user.return_value = User(id="1", username="new@example.com", display_name="New", theme_preference="light")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/register",
|
||||
@@ -93,7 +94,7 @@ def test_oidc_callback_success():
|
||||
}
|
||||
mock_oidc.exchange_code_for_token.return_value = {"id_token": "fake_id_token"}
|
||||
mock_oidc.validate_id_token.return_value = {"email": "sso@example.com", "name": "SSO User"}
|
||||
mock_hm.sync_user_from_oidc.return_value = User(id="sso-123", username="sso@example.com", display_name="SSO User")
|
||||
mock_hm.sync_user_from_oidc.return_value = User(id="sso-123", username="sso@example.com", display_name="SSO User", theme_preference="light")
|
||||
|
||||
client.cookies.set("oidc_session", "fake_token")
|
||||
response = client.get(
|
||||
@@ -110,7 +111,7 @@ def test_get_me_success():
|
||||
token = create_access_token(data={"sub": "123"})
|
||||
|
||||
with patch("ea_chatbot.api.dependencies.history_manager") as mock_hm:
|
||||
mock_hm.get_user_by_id.return_value = User(id="123", username="test@example.com", display_name="Test")
|
||||
mock_hm.get_user_by_id.return_value = User(id="123", username="test@example.com", display_name="Test", theme_preference="light")
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/auth/me",
|
||||
|
||||
73
backend/tests/api/test_theme.py
Normal file
73
backend/tests/api/test_theme.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import patch
|
||||
from ea_chatbot.api.main import app
|
||||
from ea_chatbot.history.models import User
|
||||
from ea_chatbot.api.utils import create_access_token
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
@pytest.fixture
|
||||
def test_user():
|
||||
return User(
|
||||
id="user-123",
|
||||
username="test@example.com",
|
||||
display_name="Test User",
|
||||
theme_preference="light"
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def auth_token():
|
||||
return create_access_token(data={"sub": "user-123"})
|
||||
|
||||
def test_get_me_includes_theme(test_user, auth_token):
|
||||
"""Test that /auth/me returns the theme_preference."""
|
||||
with patch("ea_chatbot.api.dependencies.history_manager") as mock_hm:
|
||||
mock_hm.get_user_by_id.return_value = test_user
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/auth/me",
|
||||
cookies={"access_token": auth_token}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "theme_preference" in data
|
||||
assert data["theme_preference"] == "light"
|
||||
|
||||
def test_update_theme_success(test_user, auth_token):
|
||||
"""Test successful theme update via PATCH /auth/theme."""
|
||||
updated_user = User(
|
||||
id="user-123",
|
||||
username="test@example.com",
|
||||
display_name="Test User",
|
||||
theme_preference="dark"
|
||||
)
|
||||
|
||||
with patch("ea_chatbot.api.dependencies.history_manager") as mock_hm_dep, \
|
||||
patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm_router:
|
||||
|
||||
# Dependency injection uses the one from dependencies
|
||||
mock_hm_dep.get_user_by_id.return_value = test_user
|
||||
|
||||
# The router uses its own reference to history_manager
|
||||
mock_hm_router.update_user_theme.return_value = updated_user
|
||||
|
||||
response = client.patch(
|
||||
"/api/v1/auth/theme",
|
||||
json={"theme": "dark"},
|
||||
cookies={"access_token": auth_token}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["theme_preference"] == "dark"
|
||||
mock_hm_router.update_user_theme.assert_called_once_with("user-123", "dark")
|
||||
|
||||
def test_update_theme_unauthorized():
|
||||
"""Test that theme update requires authentication."""
|
||||
response = client.patch(
|
||||
"/api/v1/auth/theme",
|
||||
json={"theme": "dark"}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
@@ -9,8 +9,10 @@ import { ChatService, type MessageResponse } from "./services/chat"
|
||||
import { type Conversation } from "./components/layout/HistorySidebar"
|
||||
import { registerUnauthorizedCallback } from "./services/api"
|
||||
import { Button } from "./components/ui/button"
|
||||
import { useTheme } from "./components/theme-provider"
|
||||
|
||||
function App() {
|
||||
const { setTheme } = useTheme()
|
||||
const [isAuthenticated, setIsAuthenticated] = useState(false)
|
||||
const [user, setUser] = useState<UserResponse | null>(null)
|
||||
const [authMode, setAuthMode] = useState<"login" | "register">("login")
|
||||
@@ -34,6 +36,9 @@ function App() {
|
||||
const userData = await AuthService.getMe()
|
||||
setUser(userData)
|
||||
setIsAuthenticated(true)
|
||||
if (userData.theme_preference) {
|
||||
setTheme(userData.theme_preference as "light" | "dark")
|
||||
}
|
||||
// Load history after successful auth
|
||||
loadHistory()
|
||||
} catch (err: unknown) {
|
||||
@@ -61,6 +66,9 @@ function App() {
|
||||
const userData = await AuthService.getMe()
|
||||
setUser(userData)
|
||||
setIsAuthenticated(true)
|
||||
if (userData.theme_preference) {
|
||||
setTheme(userData.theme_preference as "light" | "dark")
|
||||
}
|
||||
loadHistory()
|
||||
} catch (err: unknown) {
|
||||
console.error("Failed to fetch user profile after login:", err)
|
||||
|
||||
@@ -2,6 +2,7 @@ import ReactMarkdown, { type Components } from "react-markdown"
|
||||
import remarkGfm from "remark-gfm"
|
||||
import { Prism as SyntaxHighlighter } from "react-syntax-highlighter"
|
||||
import { oneDark } from "react-syntax-highlighter/dist/esm/styles/prism"
|
||||
import { cn } from "@/lib/utils"
|
||||
|
||||
interface MarkdownContentProps {
|
||||
content: string
|
||||
@@ -27,7 +28,7 @@ export function MarkdownContent({ content }: MarkdownContentProps) {
|
||||
</SyntaxHighlighter>
|
||||
)
|
||||
}
|
||||
return <code className={className}>{children}</code>
|
||||
return <code className={cn("bg-muted px-1.5 py-0.5 rounded font-mono text-[0.8em]", className)}>{children}</code>
|
||||
},
|
||||
table({ children }) {
|
||||
return (
|
||||
|
||||
@@ -55,7 +55,7 @@ export function MessageBubble({ message }: MessageBubbleProps) {
|
||||
<Button
|
||||
key={`stream-${index}`}
|
||||
variant="ghost"
|
||||
className="relative group p-0 h-auto w-full max-w-2xl mx-auto overflow-hidden hover:bg-transparent flex justify-center border bg-white"
|
||||
className="relative group p-0 h-auto w-full max-w-2xl mx-auto overflow-hidden hover:bg-transparent flex justify-center border bg-white dark:bg-muted/20"
|
||||
onClick={() => setSelectedPlot(src)}
|
||||
>
|
||||
<img
|
||||
@@ -75,7 +75,7 @@ export function MessageBubble({ message }: MessageBubbleProps) {
|
||||
<Button
|
||||
key={`history-${index}`}
|
||||
variant="ghost"
|
||||
className="relative group p-0 h-auto w-full max-w-2xl mx-auto overflow-hidden hover:bg-transparent flex justify-center border bg-white"
|
||||
className="relative group p-0 h-auto w-full max-w-2xl mx-auto overflow-hidden hover:bg-transparent flex justify-center border bg-white dark:bg-muted/20"
|
||||
onClick={() => setSelectedPlot(src)}
|
||||
>
|
||||
<img
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import * as React from "react"
|
||||
import { SidebarProvider, SidebarInset, SidebarTrigger } from "@/components/ui/sidebar"
|
||||
import { HistorySidebar, type Conversation } from "./HistorySidebar"
|
||||
import { ThemeToggle } from "./ThemeToggle"
|
||||
|
||||
interface MainLayoutProps {
|
||||
children: React.ReactNode
|
||||
@@ -33,9 +34,12 @@ export function MainLayout({
|
||||
onDelete={onDelete}
|
||||
/>
|
||||
<SidebarInset className="flex flex-col flex-1 h-full overflow-hidden">
|
||||
<header className="flex h-16 shrink-0 items-center gap-2 border-b px-4" role="navigation">
|
||||
<header className="flex h-16 shrink-0 items-center justify-between border-b px-4" role="navigation">
|
||||
<div className="flex items-center gap-2">
|
||||
<SidebarTrigger />
|
||||
<div className="font-semibold">Chat</div>
|
||||
</div>
|
||||
<ThemeToggle />
|
||||
</header>
|
||||
<main className="flex-1 flex flex-col p-6 overflow-hidden bg-muted/10">
|
||||
{children}
|
||||
|
||||
20
frontend/src/components/layout/ThemeToggle.tsx
Normal file
20
frontend/src/components/layout/ThemeToggle.tsx
Normal file
@@ -0,0 +1,20 @@
|
||||
import { Moon, Sun } from "lucide-react"
|
||||
import { Button } from "@/components/ui/button"
|
||||
import { useTheme } from "@/components/theme-provider"
|
||||
|
||||
export function ThemeToggle() {
|
||||
const { theme, toggleTheme } = useTheme()
|
||||
|
||||
return (
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={toggleTheme}
|
||||
title={`Switch to ${theme === "light" ? "dark" : "light"} mode`}
|
||||
>
|
||||
<Sun className="h-[1.2rem] w-[1.2rem] rotate-0 scale-100 transition-all dark:-rotate-90 dark:scale-0" />
|
||||
<Moon className="absolute h-[1.2rem] w-[1.2rem] rotate-90 scale-0 transition-all dark:rotate-0 dark:scale-100" />
|
||||
<span className="sr-only">Toggle theme</span>
|
||||
</Button>
|
||||
)
|
||||
}
|
||||
55
frontend/src/components/theme-provider.tsx
Normal file
55
frontend/src/components/theme-provider.tsx
Normal file
@@ -0,0 +1,55 @@
|
||||
import { createContext, useContext, useEffect, useState } from "react"
|
||||
import { AuthService } from "@/services/auth"
|
||||
|
||||
type Theme = "light" | "dark"
|
||||
|
||||
interface ThemeContextType {
|
||||
theme: Theme
|
||||
setTheme: (theme: Theme) => void
|
||||
toggleTheme: () => void
|
||||
}
|
||||
|
||||
const ThemeContext = createContext<ThemeContextType | undefined>(undefined)
|
||||
|
||||
export function ThemeProvider({
|
||||
children,
|
||||
initialTheme = "light",
|
||||
}: {
|
||||
children: React.ReactNode
|
||||
initialTheme?: Theme
|
||||
}) {
|
||||
const [theme, setThemeState] = useState<Theme>(initialTheme)
|
||||
|
||||
useEffect(() => {
|
||||
const root = window.document.documentElement
|
||||
root.classList.remove("light", "dark")
|
||||
root.classList.add(theme)
|
||||
}, [theme])
|
||||
|
||||
const setTheme = async (newTheme: Theme) => {
|
||||
setThemeState(newTheme)
|
||||
try {
|
||||
await AuthService.updateTheme(newTheme)
|
||||
} catch (error) {
|
||||
console.error("Failed to sync theme to backend:", error)
|
||||
}
|
||||
}
|
||||
|
||||
const toggleTheme = () => {
|
||||
setTheme(theme === "light" ? "dark" : "light")
|
||||
}
|
||||
|
||||
return (
|
||||
<ThemeContext.Provider value={{ theme, setTheme, toggleTheme }}>
|
||||
{children}
|
||||
</ThemeContext.Provider>
|
||||
)
|
||||
}
|
||||
|
||||
export const useTheme = () => {
|
||||
const context = useContext(ThemeContext)
|
||||
if (context === undefined) {
|
||||
throw new Error("useTheme must be used within a ThemeProvider")
|
||||
}
|
||||
return context
|
||||
}
|
||||
@@ -4,13 +4,16 @@ import './index.css'
|
||||
import App from './App.tsx'
|
||||
import { TooltipProvider } from "@/components/ui/tooltip"
|
||||
import { BrowserRouter } from "react-router-dom"
|
||||
import { ThemeProvider } from "./components/theme-provider"
|
||||
|
||||
createRoot(document.getElementById('root')!).render(
|
||||
<StrictMode>
|
||||
<BrowserRouter>
|
||||
<ThemeProvider>
|
||||
<TooltipProvider>
|
||||
<App />
|
||||
</TooltipProvider>
|
||||
</ThemeProvider>
|
||||
</BrowserRouter>
|
||||
</StrictMode>,
|
||||
)
|
||||
|
||||
@@ -9,6 +9,7 @@ export interface UserResponse {
|
||||
id: string
|
||||
email: string
|
||||
display_name?: string
|
||||
theme_preference: string
|
||||
}
|
||||
|
||||
export const AuthService = {
|
||||
@@ -49,4 +50,9 @@ export const AuthService = {
|
||||
async logout() {
|
||||
await api.post("/auth/logout")
|
||||
},
|
||||
|
||||
async updateTheme(theme: string): Promise<UserResponse> {
|
||||
const response = await api.patch<UserResponse>("/auth/theme", { theme })
|
||||
return response.data
|
||||
},
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user