From a94cbc7f6d9d2cb39c258b8185f97b8c612c7a5c Mon Sep 17 00:00:00 2001 From: Yunxiao Xu Date: Tue, 17 Feb 2026 02:34:47 -0800 Subject: [PATCH] chore: Perform codebase cleanup and refactor App state management --- backend/src/ea_chatbot/api/routers/agent.py | 5 +- backend/src/ea_chatbot/app.py | 5 +- backend/src/ea_chatbot/auth.py | 10 --- backend/src/ea_chatbot/config.py | 2 +- backend/src/ea_chatbot/utils/plots.py | 17 ++++++ backend/tests/test_oidc_client_v2.py | 9 --- frontend/src/App.tsx | 68 +++++++++++++++++---- frontend/src/components/theme-provider.tsx | 12 ++-- frontend/src/main.tsx | 8 +-- 9 files changed, 90 insertions(+), 46 deletions(-) create mode 100644 backend/src/ea_chatbot/utils/plots.py diff --git a/backend/src/ea_chatbot/api/routers/agent.py b/backend/src/ea_chatbot/api/routers/agent.py index 349dc6b..a39b868 100644 --- a/backend/src/ea_chatbot/api/routers/agent.py +++ b/backend/src/ea_chatbot/api/routers/agent.py @@ -9,6 +9,7 @@ from ea_chatbot.graph.checkpoint import get_checkpointer from ea_chatbot.history.models import User as UserDB, Conversation from ea_chatbot.history.utils import map_db_messages_to_langchain from ea_chatbot.api.schemas import ChatRequest +from ea_chatbot.utils.plots import fig_to_bytes import io import base64 from langchain_core.runnables.config import RunnableConfig @@ -90,9 +91,7 @@ async def stream_agent_events( plots = output["plots"] encoded_plots: list[str] = [] for fig in plots: - buf = io.BytesIO() - fig.savefig(buf, format="png") - plot_bytes = buf.getvalue() + plot_bytes = fig_to_bytes(fig) assistant_plots.append(plot_bytes) encoded_plots.append(base64.b64encode(plot_bytes).decode('utf-8')) output_event["data"]["encoded_plots"] = encoded_plots diff --git a/backend/src/ea_chatbot/app.py b/backend/src/ea_chatbot/app.py index 9585868..af3d282 100644 --- a/backend/src/ea_chatbot/app.py +++ b/backend/src/ea_chatbot/app.py @@ -6,6 +6,7 @@ from ea_chatbot.graph.workflow import app from ea_chatbot.graph.state import AgentState from ea_chatbot.utils.logging import get_logger from ea_chatbot.utils.helpers import merge_agent_state +from ea_chatbot.utils.plots import fig_to_bytes from ea_chatbot.config import Settings from ea_chatbot.history.manager import HistoryManager from ea_chatbot.auth import OIDCClient, AuthType, get_user_auth_type @@ -419,9 +420,7 @@ def main(): for fig in final_state["plots"]: st.pyplot(fig) # Convert fig to bytes - buf = io.BytesIO() - fig.savefig(buf, format="png") - plot_bytes_list.append(buf.getvalue()) + plot_bytes_list.append(fig_to_bytes(fig)) if final_state.get("dfs"): for df_name, df in final_state["dfs"].items(): diff --git a/backend/src/ea_chatbot/auth.py b/backend/src/ea_chatbot/auth.py index 01ea087..f35c329 100644 --- a/backend/src/ea_chatbot/auth.py +++ b/backend/src/ea_chatbot/auth.py @@ -162,16 +162,6 @@ class OIDCClient: except JWTError as e: raise ValueError(f"ID Token validation failed: {str(e)}") - def get_login_url(self) -> str: - """Legacy method for generating simple authorization URL.""" - metadata = self.fetch_metadata() - authorization_endpoint = metadata.get("authorization_endpoint") - if not authorization_endpoint: - raise ValueError("authorization_endpoint not found in OIDC metadata") - - uri, state = self.oauth_session.create_authorization_url(authorization_endpoint) - return uri - def exchange_code_for_token(self, code: str, code_verifier: Optional[str] = None) -> Dict[str, Any]: """Exchange the authorization code for an access token, optionally using PKCE verifier.""" metadata = self.fetch_metadata() diff --git a/backend/src/ea_chatbot/config.py b/backend/src/ea_chatbot/config.py index 6c79eaa..19a218f 100644 --- a/backend/src/ea_chatbot/config.py +++ b/backend/src/ea_chatbot/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): data_dir: str = "data" data_state: str = "new_jersey" log_level: str = Field(default="INFO", alias="LOG_LEVEL") - dev_mode: bool = Field(default=True, alias="DEV_MODE") + dev_mode: bool = Field(default=False, alias="DEV_MODE") frontend_url: str = Field(default="http://localhost:5173", alias="FRONTEND_URL") # Voter Database configuration diff --git a/backend/src/ea_chatbot/utils/plots.py b/backend/src/ea_chatbot/utils/plots.py new file mode 100644 index 0000000..5677a56 --- /dev/null +++ b/backend/src/ea_chatbot/utils/plots.py @@ -0,0 +1,17 @@ +import io +import matplotlib.pyplot as plt + +def fig_to_bytes(fig: plt.Figure, format: str = "png") -> bytes: + """ + Convert a Matplotlib figure to bytes. + + Args: + fig: The Matplotlib figure to convert. + format: The image format to use (default: "png"). + + Returns: + bytes: The image data. + """ + buf = io.BytesIO() + fig.savefig(buf, format=format) + return buf.getvalue() diff --git a/backend/tests/test_oidc_client_v2.py b/backend/tests/test_oidc_client_v2.py index b130319..762b954 100644 --- a/backend/tests/test_oidc_client_v2.py +++ b/backend/tests/test_oidc_client_v2.py @@ -55,15 +55,6 @@ def test_oidc_fetch_jwks(oidc_config, mock_metadata): client.fetch_jwks() assert mock_get.call_count == 1 -def test_oidc_get_login_url_legacy(oidc_config, mock_metadata): - client = OIDCClient(**oidc_config) - client.metadata = mock_metadata - - with patch.object(client.oauth_session, "create_authorization_url") as mock_create: - mock_create.return_value = ("https://url", "state") - url = client.get_login_url() - assert url == "https://url" - def test_oidc_get_user_info(oidc_config, mock_metadata): client = OIDCClient(**oidc_config) client.metadata = mock_metadata diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index ad45108..3a35911 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,4 +1,4 @@ -import { useState, useEffect } from "react" +import { useState, useEffect, createContext, useContext } from "react" import { Routes, Route } from "react-router-dom" import { MainLayout } from "./components/layout/MainLayout" import { LoginForm } from "./components/auth/LoginForm" @@ -9,12 +9,42 @@ import { ChatService, type MessageResponse } from "./services/chat" import { type Conversation } from "./components/layout/HistorySidebar" import { registerUnauthorizedCallback } from "./services/api" import { Button } from "./components/ui/button" -import { useTheme } from "./components/theme-provider" +import { ThemeProvider, useTheme } from "./components/theme-provider" -function App() { - const { setThemeLocal } = useTheme() +// --- Auth Context --- +interface AuthContextType { + isAuthenticated: boolean + setIsAuthenticated: (val: boolean) => void + user: UserResponse | null + setUser: (user: UserResponse | null) => void +} + +const AuthContext = createContext(undefined) + +function AuthProvider({ children }: { children: React.ReactNode }) { const [isAuthenticated, setIsAuthenticated] = useState(false) const [user, setUser] = useState(null) + + return ( + + {children} + + ) +} + +function useAuth() { + const context = useContext(AuthContext) + if (context === undefined) { + throw new Error("useAuth must be used within an AuthProvider") + } + return context +} + +// --- App Content --- +function AppContent() { + const { setThemeLocal } = useTheme() + const { isAuthenticated, setIsAuthenticated, user, setUser } = useAuth() + const [authMode, setAuthMode] = useState<"login" | "register">("login") const [isLoading, setIsLoading] = useState(true) const [selectedThreadId, setSelectedThreadId] = useState(null) @@ -22,13 +52,13 @@ function App() { const [threadMessages, setThreadMessages] = useState>({}) useEffect(() => { - // Register callback to handle session expiration from anywhere in the app registerUnauthorizedCallback(() => { setIsAuthenticated(false) setUser(null) setConversations([]) setSelectedThreadId(null) setThreadMessages({}) + setThemeLocal("light") }) const initAuth = async () => { @@ -39,7 +69,6 @@ function App() { if (userData.theme_preference) { setThemeLocal(userData.theme_preference) } - // Load history after successful auth loadHistory() } catch (err: unknown) { console.log("No active session found", err) @@ -50,7 +79,7 @@ function App() { } initAuth() - }, []) + }, [setIsAuthenticated, setThemeLocal, setUser]) const loadHistory = async () => { try { @@ -86,13 +115,12 @@ function App() { setSelectedThreadId(null) setConversations([]) setThreadMessages({}) + setThemeLocal("light") } } const handleSelectConversation = async (id: string) => { setSelectedThreadId(id) - // Always fetch messages to avoid stale cache issues when switching back - // or if the session was updated from elsewhere try { const msgs = await ChatService.getMessages(id) setThreadMessages(prev => ({ ...prev, [id]: msgs })) @@ -125,7 +153,6 @@ function App() { try { await ChatService.deleteConversation(id) setConversations(prev => prev.filter(c => c.id !== id)) - // Also clear from cache setThreadMessages(prev => { const next = { ...prev } delete next[id] @@ -202,7 +229,7 @@ function App() {
{selectedThreadId ? ( handleMessagesFinal(selectedThreadId, msgs)} @@ -249,4 +276,23 @@ function App() { ) } +function ThemeWrapper({ children }: { children: React.ReactNode }) { + const { isAuthenticated } = useAuth() + return ( + + {children} + + ) +} + +function App() { + return ( + + + + + + ) +} + export default App diff --git a/frontend/src/components/theme-provider.tsx b/frontend/src/components/theme-provider.tsx index 3207bce..4bf0316 100644 --- a/frontend/src/components/theme-provider.tsx +++ b/frontend/src/components/theme-provider.tsx @@ -15,9 +15,11 @@ const ThemeContext = createContext(undefined) export function ThemeProvider({ children, initialTheme = "light", + isAuthenticated = false, }: { children: React.ReactNode initialTheme?: Theme + isAuthenticated?: boolean }) { const [theme, setThemeState] = useState(initialTheme) @@ -29,10 +31,12 @@ export function ThemeProvider({ const setTheme = async (newTheme: Theme) => { setThemeState(newTheme) - try { - await AuthService.updateTheme(newTheme) - } catch (error) { - console.error("Failed to sync theme to backend:", error) + if (isAuthenticated) { + try { + await AuthService.updateTheme(newTheme) + } catch (error) { + console.error("Failed to sync theme to backend:", error) + } } } diff --git a/frontend/src/main.tsx b/frontend/src/main.tsx index cc459f8..de6b4c5 100644 --- a/frontend/src/main.tsx +++ b/frontend/src/main.tsx @@ -9,11 +9,9 @@ import { ThemeProvider } from "./components/theme-provider" createRoot(document.getElementById('root')!).render( - - - - - + + + , )