chore: Perform codebase cleanup and refactor App state management

This commit is contained in:
Yunxiao Xu
2026-02-17 02:34:47 -08:00
parent ec6760b5a7
commit a94cbc7f6d
9 changed files with 90 additions and 46 deletions

View File

@@ -9,6 +9,7 @@ from ea_chatbot.graph.checkpoint import get_checkpointer
from ea_chatbot.history.models import User as UserDB, Conversation from ea_chatbot.history.models import User as UserDB, Conversation
from ea_chatbot.history.utils import map_db_messages_to_langchain from ea_chatbot.history.utils import map_db_messages_to_langchain
from ea_chatbot.api.schemas import ChatRequest from ea_chatbot.api.schemas import ChatRequest
from ea_chatbot.utils.plots import fig_to_bytes
import io import io
import base64 import base64
from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.config import RunnableConfig
@@ -90,9 +91,7 @@ async def stream_agent_events(
plots = output["plots"] plots = output["plots"]
encoded_plots: list[str] = [] encoded_plots: list[str] = []
for fig in plots: for fig in plots:
buf = io.BytesIO() plot_bytes = fig_to_bytes(fig)
fig.savefig(buf, format="png")
plot_bytes = buf.getvalue()
assistant_plots.append(plot_bytes) assistant_plots.append(plot_bytes)
encoded_plots.append(base64.b64encode(plot_bytes).decode('utf-8')) encoded_plots.append(base64.b64encode(plot_bytes).decode('utf-8'))
output_event["data"]["encoded_plots"] = encoded_plots output_event["data"]["encoded_plots"] = encoded_plots

View File

@@ -6,6 +6,7 @@ from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.state import AgentState from ea_chatbot.graph.state import AgentState
from ea_chatbot.utils.logging import get_logger from ea_chatbot.utils.logging import get_logger
from ea_chatbot.utils.helpers import merge_agent_state from ea_chatbot.utils.helpers import merge_agent_state
from ea_chatbot.utils.plots import fig_to_bytes
from ea_chatbot.config import Settings from ea_chatbot.config import Settings
from ea_chatbot.history.manager import HistoryManager from ea_chatbot.history.manager import HistoryManager
from ea_chatbot.auth import OIDCClient, AuthType, get_user_auth_type from ea_chatbot.auth import OIDCClient, AuthType, get_user_auth_type
@@ -419,9 +420,7 @@ def main():
for fig in final_state["plots"]: for fig in final_state["plots"]:
st.pyplot(fig) st.pyplot(fig)
# Convert fig to bytes # Convert fig to bytes
buf = io.BytesIO() plot_bytes_list.append(fig_to_bytes(fig))
fig.savefig(buf, format="png")
plot_bytes_list.append(buf.getvalue())
if final_state.get("dfs"): if final_state.get("dfs"):
for df_name, df in final_state["dfs"].items(): for df_name, df in final_state["dfs"].items():

View File

@@ -162,16 +162,6 @@ class OIDCClient:
except JWTError as e: except JWTError as e:
raise ValueError(f"ID Token validation failed: {str(e)}") raise ValueError(f"ID Token validation failed: {str(e)}")
def get_login_url(self) -> str:
"""Legacy method for generating simple authorization URL."""
metadata = self.fetch_metadata()
authorization_endpoint = metadata.get("authorization_endpoint")
if not authorization_endpoint:
raise ValueError("authorization_endpoint not found in OIDC metadata")
uri, state = self.oauth_session.create_authorization_url(authorization_endpoint)
return uri
def exchange_code_for_token(self, code: str, code_verifier: Optional[str] = None) -> Dict[str, Any]: def exchange_code_for_token(self, code: str, code_verifier: Optional[str] = None) -> Dict[str, Any]:
"""Exchange the authorization code for an access token, optionally using PKCE verifier.""" """Exchange the authorization code for an access token, optionally using PKCE verifier."""
metadata = self.fetch_metadata() metadata = self.fetch_metadata()

View File

@@ -19,7 +19,7 @@ class Settings(BaseSettings):
data_dir: str = "data" data_dir: str = "data"
data_state: str = "new_jersey" data_state: str = "new_jersey"
log_level: str = Field(default="INFO", alias="LOG_LEVEL") log_level: str = Field(default="INFO", alias="LOG_LEVEL")
dev_mode: bool = Field(default=True, alias="DEV_MODE") dev_mode: bool = Field(default=False, alias="DEV_MODE")
frontend_url: str = Field(default="http://localhost:5173", alias="FRONTEND_URL") frontend_url: str = Field(default="http://localhost:5173", alias="FRONTEND_URL")
# Voter Database configuration # Voter Database configuration

View File

@@ -0,0 +1,17 @@
import io
import matplotlib.pyplot as plt
def fig_to_bytes(fig: plt.Figure, format: str = "png") -> bytes:
"""
Convert a Matplotlib figure to bytes.
Args:
fig: The Matplotlib figure to convert.
format: The image format to use (default: "png").
Returns:
bytes: The image data.
"""
buf = io.BytesIO()
fig.savefig(buf, format=format)
return buf.getvalue()

View File

@@ -55,15 +55,6 @@ def test_oidc_fetch_jwks(oidc_config, mock_metadata):
client.fetch_jwks() client.fetch_jwks()
assert mock_get.call_count == 1 assert mock_get.call_count == 1
def test_oidc_get_login_url_legacy(oidc_config, mock_metadata):
client = OIDCClient(**oidc_config)
client.metadata = mock_metadata
with patch.object(client.oauth_session, "create_authorization_url") as mock_create:
mock_create.return_value = ("https://url", "state")
url = client.get_login_url()
assert url == "https://url"
def test_oidc_get_user_info(oidc_config, mock_metadata): def test_oidc_get_user_info(oidc_config, mock_metadata):
client = OIDCClient(**oidc_config) client = OIDCClient(**oidc_config)
client.metadata = mock_metadata client.metadata = mock_metadata

View File

@@ -1,4 +1,4 @@
import { useState, useEffect } from "react" import { useState, useEffect, createContext, useContext } from "react"
import { Routes, Route } from "react-router-dom" import { Routes, Route } from "react-router-dom"
import { MainLayout } from "./components/layout/MainLayout" import { MainLayout } from "./components/layout/MainLayout"
import { LoginForm } from "./components/auth/LoginForm" import { LoginForm } from "./components/auth/LoginForm"
@@ -9,12 +9,42 @@ import { ChatService, type MessageResponse } from "./services/chat"
import { type Conversation } from "./components/layout/HistorySidebar" import { type Conversation } from "./components/layout/HistorySidebar"
import { registerUnauthorizedCallback } from "./services/api" import { registerUnauthorizedCallback } from "./services/api"
import { Button } from "./components/ui/button" import { Button } from "./components/ui/button"
import { useTheme } from "./components/theme-provider" import { ThemeProvider, useTheme } from "./components/theme-provider"
function App() { // --- Auth Context ---
const { setThemeLocal } = useTheme() interface AuthContextType {
isAuthenticated: boolean
setIsAuthenticated: (val: boolean) => void
user: UserResponse | null
setUser: (user: UserResponse | null) => void
}
const AuthContext = createContext<AuthContextType | undefined>(undefined)
function AuthProvider({ children }: { children: React.ReactNode }) {
const [isAuthenticated, setIsAuthenticated] = useState(false) const [isAuthenticated, setIsAuthenticated] = useState(false)
const [user, setUser] = useState<UserResponse | null>(null) const [user, setUser] = useState<UserResponse | null>(null)
return (
<AuthContext.Provider value={{ isAuthenticated, setIsAuthenticated, user, setUser }}>
{children}
</AuthContext.Provider>
)
}
function useAuth() {
const context = useContext(AuthContext)
if (context === undefined) {
throw new Error("useAuth must be used within an AuthProvider")
}
return context
}
// --- App Content ---
function AppContent() {
const { setThemeLocal } = useTheme()
const { isAuthenticated, setIsAuthenticated, user, setUser } = useAuth()
const [authMode, setAuthMode] = useState<"login" | "register">("login") const [authMode, setAuthMode] = useState<"login" | "register">("login")
const [isLoading, setIsLoading] = useState(true) const [isLoading, setIsLoading] = useState(true)
const [selectedThreadId, setSelectedThreadId] = useState<string | null>(null) const [selectedThreadId, setSelectedThreadId] = useState<string | null>(null)
@@ -22,13 +52,13 @@ function App() {
const [threadMessages, setThreadMessages] = useState<Record<string, MessageResponse[]>>({}) const [threadMessages, setThreadMessages] = useState<Record<string, MessageResponse[]>>({})
useEffect(() => { useEffect(() => {
// Register callback to handle session expiration from anywhere in the app
registerUnauthorizedCallback(() => { registerUnauthorizedCallback(() => {
setIsAuthenticated(false) setIsAuthenticated(false)
setUser(null) setUser(null)
setConversations([]) setConversations([])
setSelectedThreadId(null) setSelectedThreadId(null)
setThreadMessages({}) setThreadMessages({})
setThemeLocal("light")
}) })
const initAuth = async () => { const initAuth = async () => {
@@ -39,7 +69,6 @@ function App() {
if (userData.theme_preference) { if (userData.theme_preference) {
setThemeLocal(userData.theme_preference) setThemeLocal(userData.theme_preference)
} }
// Load history after successful auth
loadHistory() loadHistory()
} catch (err: unknown) { } catch (err: unknown) {
console.log("No active session found", err) console.log("No active session found", err)
@@ -50,7 +79,7 @@ function App() {
} }
initAuth() initAuth()
}, []) }, [setIsAuthenticated, setThemeLocal, setUser])
const loadHistory = async () => { const loadHistory = async () => {
try { try {
@@ -86,13 +115,12 @@ function App() {
setSelectedThreadId(null) setSelectedThreadId(null)
setConversations([]) setConversations([])
setThreadMessages({}) setThreadMessages({})
setThemeLocal("light")
} }
} }
const handleSelectConversation = async (id: string) => { const handleSelectConversation = async (id: string) => {
setSelectedThreadId(id) setSelectedThreadId(id)
// Always fetch messages to avoid stale cache issues when switching back
// or if the session was updated from elsewhere
try { try {
const msgs = await ChatService.getMessages(id) const msgs = await ChatService.getMessages(id)
setThreadMessages(prev => ({ ...prev, [id]: msgs })) setThreadMessages(prev => ({ ...prev, [id]: msgs }))
@@ -125,7 +153,6 @@ function App() {
try { try {
await ChatService.deleteConversation(id) await ChatService.deleteConversation(id)
setConversations(prev => prev.filter(c => c.id !== id)) setConversations(prev => prev.filter(c => c.id !== id))
// Also clear from cache
setThreadMessages(prev => { setThreadMessages(prev => {
const next = { ...prev } const next = { ...prev }
delete next[id] delete next[id]
@@ -202,7 +229,7 @@ function App() {
<div className="flex-1 min-h-0"> <div className="flex-1 min-h-0">
{selectedThreadId ? ( {selectedThreadId ? (
<ChatInterface <ChatInterface
key={selectedThreadId} // Force remount on thread change key={selectedThreadId}
threadId={selectedThreadId} threadId={selectedThreadId}
initialMessages={threadMessages[selectedThreadId] || []} initialMessages={threadMessages[selectedThreadId] || []}
onMessagesFinal={(msgs) => handleMessagesFinal(selectedThreadId, msgs)} onMessagesFinal={(msgs) => handleMessagesFinal(selectedThreadId, msgs)}
@@ -249,4 +276,23 @@ function App() {
) )
} }
function ThemeWrapper({ children }: { children: React.ReactNode }) {
const { isAuthenticated } = useAuth()
return (
<ThemeProvider isAuthenticated={isAuthenticated}>
{children}
</ThemeProvider>
)
}
function App() {
return (
<AuthProvider>
<ThemeWrapper>
<AppContent />
</ThemeWrapper>
</AuthProvider>
)
}
export default App export default App

View File

@@ -15,9 +15,11 @@ const ThemeContext = createContext<ThemeContextType | undefined>(undefined)
export function ThemeProvider({ export function ThemeProvider({
children, children,
initialTheme = "light", initialTheme = "light",
isAuthenticated = false,
}: { }: {
children: React.ReactNode children: React.ReactNode
initialTheme?: Theme initialTheme?: Theme
isAuthenticated?: boolean
}) { }) {
const [theme, setThemeState] = useState<Theme>(initialTheme) const [theme, setThemeState] = useState<Theme>(initialTheme)
@@ -29,10 +31,12 @@ export function ThemeProvider({
const setTheme = async (newTheme: Theme) => { const setTheme = async (newTheme: Theme) => {
setThemeState(newTheme) setThemeState(newTheme)
try { if (isAuthenticated) {
await AuthService.updateTheme(newTheme) try {
} catch (error) { await AuthService.updateTheme(newTheme)
console.error("Failed to sync theme to backend:", error) } catch (error) {
console.error("Failed to sync theme to backend:", error)
}
} }
} }

View File

@@ -9,11 +9,9 @@ import { ThemeProvider } from "./components/theme-provider"
createRoot(document.getElementById('root')!).render( createRoot(document.getElementById('root')!).render(
<StrictMode> <StrictMode>
<BrowserRouter> <BrowserRouter>
<ThemeProvider> <TooltipProvider>
<TooltipProvider> <App />
<App /> </TooltipProvider>
</TooltipProvider>
</ThemeProvider>
</BrowserRouter> </BrowserRouter>
</StrictMode>, </StrictMode>,
) )