chore: Perform codebase cleanup and refactor App state management
This commit is contained in:
@@ -9,6 +9,7 @@ from ea_chatbot.graph.checkpoint import get_checkpointer
|
|||||||
from ea_chatbot.history.models import User as UserDB, Conversation
|
from ea_chatbot.history.models import User as UserDB, Conversation
|
||||||
from ea_chatbot.history.utils import map_db_messages_to_langchain
|
from ea_chatbot.history.utils import map_db_messages_to_langchain
|
||||||
from ea_chatbot.api.schemas import ChatRequest
|
from ea_chatbot.api.schemas import ChatRequest
|
||||||
|
from ea_chatbot.utils.plots import fig_to_bytes
|
||||||
import io
|
import io
|
||||||
import base64
|
import base64
|
||||||
from langchain_core.runnables.config import RunnableConfig
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
@@ -90,9 +91,7 @@ async def stream_agent_events(
|
|||||||
plots = output["plots"]
|
plots = output["plots"]
|
||||||
encoded_plots: list[str] = []
|
encoded_plots: list[str] = []
|
||||||
for fig in plots:
|
for fig in plots:
|
||||||
buf = io.BytesIO()
|
plot_bytes = fig_to_bytes(fig)
|
||||||
fig.savefig(buf, format="png")
|
|
||||||
plot_bytes = buf.getvalue()
|
|
||||||
assistant_plots.append(plot_bytes)
|
assistant_plots.append(plot_bytes)
|
||||||
encoded_plots.append(base64.b64encode(plot_bytes).decode('utf-8'))
|
encoded_plots.append(base64.b64encode(plot_bytes).decode('utf-8'))
|
||||||
output_event["data"]["encoded_plots"] = encoded_plots
|
output_event["data"]["encoded_plots"] = encoded_plots
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from ea_chatbot.graph.workflow import app
|
|||||||
from ea_chatbot.graph.state import AgentState
|
from ea_chatbot.graph.state import AgentState
|
||||||
from ea_chatbot.utils.logging import get_logger
|
from ea_chatbot.utils.logging import get_logger
|
||||||
from ea_chatbot.utils.helpers import merge_agent_state
|
from ea_chatbot.utils.helpers import merge_agent_state
|
||||||
|
from ea_chatbot.utils.plots import fig_to_bytes
|
||||||
from ea_chatbot.config import Settings
|
from ea_chatbot.config import Settings
|
||||||
from ea_chatbot.history.manager import HistoryManager
|
from ea_chatbot.history.manager import HistoryManager
|
||||||
from ea_chatbot.auth import OIDCClient, AuthType, get_user_auth_type
|
from ea_chatbot.auth import OIDCClient, AuthType, get_user_auth_type
|
||||||
@@ -419,9 +420,7 @@ def main():
|
|||||||
for fig in final_state["plots"]:
|
for fig in final_state["plots"]:
|
||||||
st.pyplot(fig)
|
st.pyplot(fig)
|
||||||
# Convert fig to bytes
|
# Convert fig to bytes
|
||||||
buf = io.BytesIO()
|
plot_bytes_list.append(fig_to_bytes(fig))
|
||||||
fig.savefig(buf, format="png")
|
|
||||||
plot_bytes_list.append(buf.getvalue())
|
|
||||||
|
|
||||||
if final_state.get("dfs"):
|
if final_state.get("dfs"):
|
||||||
for df_name, df in final_state["dfs"].items():
|
for df_name, df in final_state["dfs"].items():
|
||||||
|
|||||||
@@ -162,16 +162,6 @@ class OIDCClient:
|
|||||||
except JWTError as e:
|
except JWTError as e:
|
||||||
raise ValueError(f"ID Token validation failed: {str(e)}")
|
raise ValueError(f"ID Token validation failed: {str(e)}")
|
||||||
|
|
||||||
def get_login_url(self) -> str:
|
|
||||||
"""Legacy method for generating simple authorization URL."""
|
|
||||||
metadata = self.fetch_metadata()
|
|
||||||
authorization_endpoint = metadata.get("authorization_endpoint")
|
|
||||||
if not authorization_endpoint:
|
|
||||||
raise ValueError("authorization_endpoint not found in OIDC metadata")
|
|
||||||
|
|
||||||
uri, state = self.oauth_session.create_authorization_url(authorization_endpoint)
|
|
||||||
return uri
|
|
||||||
|
|
||||||
def exchange_code_for_token(self, code: str, code_verifier: Optional[str] = None) -> Dict[str, Any]:
|
def exchange_code_for_token(self, code: str, code_verifier: Optional[str] = None) -> Dict[str, Any]:
|
||||||
"""Exchange the authorization code for an access token, optionally using PKCE verifier."""
|
"""Exchange the authorization code for an access token, optionally using PKCE verifier."""
|
||||||
metadata = self.fetch_metadata()
|
metadata = self.fetch_metadata()
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ class Settings(BaseSettings):
|
|||||||
data_dir: str = "data"
|
data_dir: str = "data"
|
||||||
data_state: str = "new_jersey"
|
data_state: str = "new_jersey"
|
||||||
log_level: str = Field(default="INFO", alias="LOG_LEVEL")
|
log_level: str = Field(default="INFO", alias="LOG_LEVEL")
|
||||||
dev_mode: bool = Field(default=True, alias="DEV_MODE")
|
dev_mode: bool = Field(default=False, alias="DEV_MODE")
|
||||||
frontend_url: str = Field(default="http://localhost:5173", alias="FRONTEND_URL")
|
frontend_url: str = Field(default="http://localhost:5173", alias="FRONTEND_URL")
|
||||||
|
|
||||||
# Voter Database configuration
|
# Voter Database configuration
|
||||||
|
|||||||
17
backend/src/ea_chatbot/utils/plots.py
Normal file
17
backend/src/ea_chatbot/utils/plots.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
import io
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
def fig_to_bytes(fig: plt.Figure, format: str = "png") -> bytes:
|
||||||
|
"""
|
||||||
|
Convert a Matplotlib figure to bytes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fig: The Matplotlib figure to convert.
|
||||||
|
format: The image format to use (default: "png").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bytes: The image data.
|
||||||
|
"""
|
||||||
|
buf = io.BytesIO()
|
||||||
|
fig.savefig(buf, format=format)
|
||||||
|
return buf.getvalue()
|
||||||
@@ -55,15 +55,6 @@ def test_oidc_fetch_jwks(oidc_config, mock_metadata):
|
|||||||
client.fetch_jwks()
|
client.fetch_jwks()
|
||||||
assert mock_get.call_count == 1
|
assert mock_get.call_count == 1
|
||||||
|
|
||||||
def test_oidc_get_login_url_legacy(oidc_config, mock_metadata):
|
|
||||||
client = OIDCClient(**oidc_config)
|
|
||||||
client.metadata = mock_metadata
|
|
||||||
|
|
||||||
with patch.object(client.oauth_session, "create_authorization_url") as mock_create:
|
|
||||||
mock_create.return_value = ("https://url", "state")
|
|
||||||
url = client.get_login_url()
|
|
||||||
assert url == "https://url"
|
|
||||||
|
|
||||||
def test_oidc_get_user_info(oidc_config, mock_metadata):
|
def test_oidc_get_user_info(oidc_config, mock_metadata):
|
||||||
client = OIDCClient(**oidc_config)
|
client = OIDCClient(**oidc_config)
|
||||||
client.metadata = mock_metadata
|
client.metadata = mock_metadata
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { useState, useEffect } from "react"
|
import { useState, useEffect, createContext, useContext } from "react"
|
||||||
import { Routes, Route } from "react-router-dom"
|
import { Routes, Route } from "react-router-dom"
|
||||||
import { MainLayout } from "./components/layout/MainLayout"
|
import { MainLayout } from "./components/layout/MainLayout"
|
||||||
import { LoginForm } from "./components/auth/LoginForm"
|
import { LoginForm } from "./components/auth/LoginForm"
|
||||||
@@ -9,12 +9,42 @@ import { ChatService, type MessageResponse } from "./services/chat"
|
|||||||
import { type Conversation } from "./components/layout/HistorySidebar"
|
import { type Conversation } from "./components/layout/HistorySidebar"
|
||||||
import { registerUnauthorizedCallback } from "./services/api"
|
import { registerUnauthorizedCallback } from "./services/api"
|
||||||
import { Button } from "./components/ui/button"
|
import { Button } from "./components/ui/button"
|
||||||
import { useTheme } from "./components/theme-provider"
|
import { ThemeProvider, useTheme } from "./components/theme-provider"
|
||||||
|
|
||||||
function App() {
|
// --- Auth Context ---
|
||||||
const { setThemeLocal } = useTheme()
|
interface AuthContextType {
|
||||||
|
isAuthenticated: boolean
|
||||||
|
setIsAuthenticated: (val: boolean) => void
|
||||||
|
user: UserResponse | null
|
||||||
|
setUser: (user: UserResponse | null) => void
|
||||||
|
}
|
||||||
|
|
||||||
|
const AuthContext = createContext<AuthContextType | undefined>(undefined)
|
||||||
|
|
||||||
|
function AuthProvider({ children }: { children: React.ReactNode }) {
|
||||||
const [isAuthenticated, setIsAuthenticated] = useState(false)
|
const [isAuthenticated, setIsAuthenticated] = useState(false)
|
||||||
const [user, setUser] = useState<UserResponse | null>(null)
|
const [user, setUser] = useState<UserResponse | null>(null)
|
||||||
|
|
||||||
|
return (
|
||||||
|
<AuthContext.Provider value={{ isAuthenticated, setIsAuthenticated, user, setUser }}>
|
||||||
|
{children}
|
||||||
|
</AuthContext.Provider>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
function useAuth() {
|
||||||
|
const context = useContext(AuthContext)
|
||||||
|
if (context === undefined) {
|
||||||
|
throw new Error("useAuth must be used within an AuthProvider")
|
||||||
|
}
|
||||||
|
return context
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- App Content ---
|
||||||
|
function AppContent() {
|
||||||
|
const { setThemeLocal } = useTheme()
|
||||||
|
const { isAuthenticated, setIsAuthenticated, user, setUser } = useAuth()
|
||||||
|
|
||||||
const [authMode, setAuthMode] = useState<"login" | "register">("login")
|
const [authMode, setAuthMode] = useState<"login" | "register">("login")
|
||||||
const [isLoading, setIsLoading] = useState(true)
|
const [isLoading, setIsLoading] = useState(true)
|
||||||
const [selectedThreadId, setSelectedThreadId] = useState<string | null>(null)
|
const [selectedThreadId, setSelectedThreadId] = useState<string | null>(null)
|
||||||
@@ -22,13 +52,13 @@ function App() {
|
|||||||
const [threadMessages, setThreadMessages] = useState<Record<string, MessageResponse[]>>({})
|
const [threadMessages, setThreadMessages] = useState<Record<string, MessageResponse[]>>({})
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
// Register callback to handle session expiration from anywhere in the app
|
|
||||||
registerUnauthorizedCallback(() => {
|
registerUnauthorizedCallback(() => {
|
||||||
setIsAuthenticated(false)
|
setIsAuthenticated(false)
|
||||||
setUser(null)
|
setUser(null)
|
||||||
setConversations([])
|
setConversations([])
|
||||||
setSelectedThreadId(null)
|
setSelectedThreadId(null)
|
||||||
setThreadMessages({})
|
setThreadMessages({})
|
||||||
|
setThemeLocal("light")
|
||||||
})
|
})
|
||||||
|
|
||||||
const initAuth = async () => {
|
const initAuth = async () => {
|
||||||
@@ -39,7 +69,6 @@ function App() {
|
|||||||
if (userData.theme_preference) {
|
if (userData.theme_preference) {
|
||||||
setThemeLocal(userData.theme_preference)
|
setThemeLocal(userData.theme_preference)
|
||||||
}
|
}
|
||||||
// Load history after successful auth
|
|
||||||
loadHistory()
|
loadHistory()
|
||||||
} catch (err: unknown) {
|
} catch (err: unknown) {
|
||||||
console.log("No active session found", err)
|
console.log("No active session found", err)
|
||||||
@@ -50,7 +79,7 @@ function App() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
initAuth()
|
initAuth()
|
||||||
}, [])
|
}, [setIsAuthenticated, setThemeLocal, setUser])
|
||||||
|
|
||||||
const loadHistory = async () => {
|
const loadHistory = async () => {
|
||||||
try {
|
try {
|
||||||
@@ -86,13 +115,12 @@ function App() {
|
|||||||
setSelectedThreadId(null)
|
setSelectedThreadId(null)
|
||||||
setConversations([])
|
setConversations([])
|
||||||
setThreadMessages({})
|
setThreadMessages({})
|
||||||
|
setThemeLocal("light")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const handleSelectConversation = async (id: string) => {
|
const handleSelectConversation = async (id: string) => {
|
||||||
setSelectedThreadId(id)
|
setSelectedThreadId(id)
|
||||||
// Always fetch messages to avoid stale cache issues when switching back
|
|
||||||
// or if the session was updated from elsewhere
|
|
||||||
try {
|
try {
|
||||||
const msgs = await ChatService.getMessages(id)
|
const msgs = await ChatService.getMessages(id)
|
||||||
setThreadMessages(prev => ({ ...prev, [id]: msgs }))
|
setThreadMessages(prev => ({ ...prev, [id]: msgs }))
|
||||||
@@ -125,7 +153,6 @@ function App() {
|
|||||||
try {
|
try {
|
||||||
await ChatService.deleteConversation(id)
|
await ChatService.deleteConversation(id)
|
||||||
setConversations(prev => prev.filter(c => c.id !== id))
|
setConversations(prev => prev.filter(c => c.id !== id))
|
||||||
// Also clear from cache
|
|
||||||
setThreadMessages(prev => {
|
setThreadMessages(prev => {
|
||||||
const next = { ...prev }
|
const next = { ...prev }
|
||||||
delete next[id]
|
delete next[id]
|
||||||
@@ -202,7 +229,7 @@ function App() {
|
|||||||
<div className="flex-1 min-h-0">
|
<div className="flex-1 min-h-0">
|
||||||
{selectedThreadId ? (
|
{selectedThreadId ? (
|
||||||
<ChatInterface
|
<ChatInterface
|
||||||
key={selectedThreadId} // Force remount on thread change
|
key={selectedThreadId}
|
||||||
threadId={selectedThreadId}
|
threadId={selectedThreadId}
|
||||||
initialMessages={threadMessages[selectedThreadId] || []}
|
initialMessages={threadMessages[selectedThreadId] || []}
|
||||||
onMessagesFinal={(msgs) => handleMessagesFinal(selectedThreadId, msgs)}
|
onMessagesFinal={(msgs) => handleMessagesFinal(selectedThreadId, msgs)}
|
||||||
@@ -249,4 +276,23 @@ function App() {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function ThemeWrapper({ children }: { children: React.ReactNode }) {
|
||||||
|
const { isAuthenticated } = useAuth()
|
||||||
|
return (
|
||||||
|
<ThemeProvider isAuthenticated={isAuthenticated}>
|
||||||
|
{children}
|
||||||
|
</ThemeProvider>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
function App() {
|
||||||
|
return (
|
||||||
|
<AuthProvider>
|
||||||
|
<ThemeWrapper>
|
||||||
|
<AppContent />
|
||||||
|
</ThemeWrapper>
|
||||||
|
</AuthProvider>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
export default App
|
export default App
|
||||||
|
|||||||
@@ -15,9 +15,11 @@ const ThemeContext = createContext<ThemeContextType | undefined>(undefined)
|
|||||||
export function ThemeProvider({
|
export function ThemeProvider({
|
||||||
children,
|
children,
|
||||||
initialTheme = "light",
|
initialTheme = "light",
|
||||||
|
isAuthenticated = false,
|
||||||
}: {
|
}: {
|
||||||
children: React.ReactNode
|
children: React.ReactNode
|
||||||
initialTheme?: Theme
|
initialTheme?: Theme
|
||||||
|
isAuthenticated?: boolean
|
||||||
}) {
|
}) {
|
||||||
const [theme, setThemeState] = useState<Theme>(initialTheme)
|
const [theme, setThemeState] = useState<Theme>(initialTheme)
|
||||||
|
|
||||||
@@ -29,12 +31,14 @@ export function ThemeProvider({
|
|||||||
|
|
||||||
const setTheme = async (newTheme: Theme) => {
|
const setTheme = async (newTheme: Theme) => {
|
||||||
setThemeState(newTheme)
|
setThemeState(newTheme)
|
||||||
|
if (isAuthenticated) {
|
||||||
try {
|
try {
|
||||||
await AuthService.updateTheme(newTheme)
|
await AuthService.updateTheme(newTheme)
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Failed to sync theme to backend:", error)
|
console.error("Failed to sync theme to backend:", error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const setThemeLocal = (newTheme: Theme) => {
|
const setThemeLocal = (newTheme: Theme) => {
|
||||||
setThemeState(newTheme)
|
setThemeState(newTheme)
|
||||||
|
|||||||
@@ -9,11 +9,9 @@ import { ThemeProvider } from "./components/theme-provider"
|
|||||||
createRoot(document.getElementById('root')!).render(
|
createRoot(document.getElementById('root')!).render(
|
||||||
<StrictMode>
|
<StrictMode>
|
||||||
<BrowserRouter>
|
<BrowserRouter>
|
||||||
<ThemeProvider>
|
|
||||||
<TooltipProvider>
|
<TooltipProvider>
|
||||||
<App />
|
<App />
|
||||||
</TooltipProvider>
|
</TooltipProvider>
|
||||||
</ThemeProvider>
|
|
||||||
</BrowserRouter>
|
</BrowserRouter>
|
||||||
</StrictMode>,
|
</StrictMode>,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user