diff --git a/backend/GEMINI.md b/backend/GEMINI.md new file mode 100644 index 0000000..103a117 --- /dev/null +++ b/backend/GEMINI.md @@ -0,0 +1,162 @@ +# Election Analytics Chatbot - Backend Guide + +## Overview +This document serves as a guide for the backend implementation of the Election Analytics Chatbot, specifically focusing on the transition from the "BambooAI" based system to a modern, stateful, and graph-based architecture using **LangGraph**. + +## 1. Migration Goals +- **Framework Switch**: Move from the custom linear `ChatBot` class (in `src/ea_chatbot/bambooai/core/chatbot.py`) to `LangGraph`. +- **State Management**: explicit state management using LangGraph's `StateGraph`. +- **Modularity**: Break down monolithic methods (`pd_agent_converse`, `execute_code`) into distinct Nodes. +- **Observability**: Easier debugging of the decision process (Routing -> Planning -> Coding -> Executing). + +## 2. Architecture Proposal + +### 2.1. The Graph State +The state will track the conversation and execution context. + +```python +from typing import TypedDict, Annotated, List, Dict, Any, Optional +from langchain_core.messages import BaseMessage +import operator + +class AgentState(TypedDict): + # Conversation history + messages: Annotated[List[BaseMessage], operator.add] + + # Task context + question: str + + # Query Analysis (Decomposition results) + analysis: Optional[Dict[str, Any]] + # Expected keys: "requires_dataset", "expert", "data", "unknown", "condition" + + # Step-by-step reasoning + plan: Optional[str] + + # Code execution context + code: Optional[str] + code_output: Optional[str] + error: Optional[str] + + # Artifacts (for UI display) + plots: List[Figure] # Matplotlib figures + dfs: Dict[str, DataFrame] # Pandas DataFrames + + # Control flow + iterations: int + next_action: str # Routing hint: "clarify", "plan", "research", "end" +``` + +### 2.2. Nodes (The Actors) +We will map existing logic to these nodes: + +1. **`query_analyzer_node`** (Router & Refiner): + * **Logic**: Replaces `Expert Selector` and `Analyst Selector`. + * **Function**: + 1. Decomposes the user's query into key elements (Data, Unknowns, Conditions). + 2. Determines if the query is ambiguous or missing critical information. + * **Output**: Updates `messages`. Returns routing decision: + * `clarification_node` (if ambiguous). + * `planner_node` (if clear data task). + * `researcher_node` (if general/web task). + +2. **`clarification_node`** (Human-in-the-loop): + * **Logic**: Replaces `Theorist-Clarification`. + * **Function**: Formulates a specific question to ask the user for missing details. + * **Output**: Returns a message to the user and **interrupts** the graph execution to await user input. + +3. **`researcher_node`** (Theorist): + * **Logic**: Handles general queries or web searches. + * **Function**: Uses `GoogleSearch` tool if necessary. + * **Output**: Final answer. + +4. **`planner_node`**: + * **Logic**: Replaces `Planner`. + * **Function**: Generates a step-by-step plan based on the decomposed query elements and dataframe ontology. + * **Output**: Updates `plan`. + +5. **`coder_node`**: + * **Logic**: Replaces `Code Generator` & `Error Corrector`. + * **Function**: Generates Python code. If `error` exists in state, it attempts to fix it. + * **Output**: Updates `code`. + +6. **`executor_node`**: + * **Logic**: Replaces `Code Executor`. + * **Function**: Executes the Python code in a safe(r) environment. It needs access to the `DBClient`. + * **Output**: Updates `code_output`, `plots`, `dfs`. If exception, updates `error`. + +7. **`summarizer_node`**: + * **Logic**: Replaces `Solution Summarizer`. + * **Function**: Interprets the code output and generates a natural language response. + * **Output**: Final response message. + +### 2.3. The Workflow (Graph) + +```mermaid +graph TD + Start --> QueryAnalyzer + QueryAnalyzer -->|Ambiguous| Clarification + Clarification -->|User Input| QueryAnalyzer + QueryAnalyzer -->|General/Web| Researcher + QueryAnalyzer -->|Data Analysis| Planner + Planner --> Coder + Coder --> Executor + Executor -->|Success| Summarizer + Executor -->|Error| Coder + Researcher --> End + Summarizer --> End +``` + +## 3. Implementation Steps + +### Step 1: Dependencies +Add the following packages to `pyproject.toml`: +* `langgraph` +* `langchain` +* `langchain-openai` +* `langchain-google-genai` +* `langchain-community` + +### Step 2: Directory Structure +Create a new package for the graph logic to keep it separate from the old one during migration. + +``` +src/ea_chatbot/ +├── graph/ +│ ├── __init__.py +│ ├── state.py # State definition +│ ├── nodes/ # Individual node implementations +│ │ ├── __init__.py +│ │ ├── router.py +│ │ ├── planner.py +│ │ ├── coder.py +│ │ ├── executor.py +│ │ └── ... +│ ├── workflow.py # Graph construction +│ └── tools/ # DB and Search tools wrapped for LangChain +└── ... +``` + +### Step 3: Tool Wrapping +Wrap the existing `DBClient` (from `src/ea_chatbot/bambooai/utils/db_client.py`) into a structure accessible by the `executor_node`. The `executor_node` will likely keep the existing `exec()` based approach initially for compatibility with the generated code, but structured as a graph node. + +### Step 4: Prompt Migration +Port the prompts from `data/PROMPT_TEMPLATES.json` or `src/ea_chatbot/bambooai/prompts/strings.py` into the respective nodes. Use LangChain's `ChatPromptTemplate` for better management. + +### Step 5: Integration +Update `src/ea_chatbot/app.py` to use the new `workflow.compile()` runnable. +* Instead of `chatbot.pd_agent_converse(...)`, use `app.stream(...)` (LangGraph app). +* Handle the streaming output to update the UI progressively. + +## 4. Key Considerations for Refactoring + +* **Database Connection**: Ensure `DBClient` is initialized once and passed to the `Executor` node efficiently (e.g., via `configurable` parameters or closure). +* **Prompt Templating**: The current system uses simple `format` strings. Switching to LangChain templates allows for easier model switching and partial formatting. +* **Token Management**: LangGraph provides built-in tracing (if LangSmith is enabled), but we should ensure the `OutputManager` logic (printing costs/tokens) is preserved or adapted if still needed for the CLI/Logs. +* **Vector DB**: The current system has `PineconeWrapper` for RAG. This should be integrated into the `Planner` or `Coder` node to fetch few-shot examples or context. + +## 5. Next Actions +1. **Initialize**: Create the folder structure. +2. **Define State**: Create `src/ea_chatbot/graph/state.py`. +3. **Implement Router**: Create the first node to replicate `Expert Selector` logic. +4. **Implement Executor**: Port the `exec()` logic to a node. diff --git a/backend/src/ea_chatbot/api/dependencies.py b/backend/src/ea_chatbot/api/dependencies.py index fe0f533..e3d2b9f 100644 --- a/backend/src/ea_chatbot/api/dependencies.py +++ b/backend/src/ea_chatbot/api/dependencies.py @@ -1,5 +1,5 @@ import os -from fastapi import Depends, HTTPException, status +from fastapi import Depends, HTTPException, status, Request from fastapi.security import OAuth2PasswordBearer from ea_chatbot.config import Settings from ea_chatbot.history.manager import HistoryManager @@ -21,16 +21,23 @@ if settings.oidc_client_id and settings.oidc_client_secret and settings.oidc_ser redirect_uri=os.getenv("OIDC_REDIRECT_URI", "http://localhost:3000/auth/callback") ) -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login") +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth/login", auto_error=False) + +async def get_current_user(request: Request, token: str = Depends(oauth2_scheme)) -> User: + """Dependency to get the current authenticated user from the JWT token (cookie or header).""" + # Try getting token from cookie first + if not token: + token = request.cookies.get("access_token") -async def get_current_user(token: str = Depends(oauth2_scheme)) -> User: - """Dependency to get the current authenticated user from the JWT token.""" credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"}, ) + if not token: + raise credentials_exception + payload = decode_access_token(token) if payload is None: raise credentials_exception diff --git a/backend/src/ea_chatbot/api/main.py b/backend/src/ea_chatbot/api/main.py index 4df3a58..ec045ee 100644 --- a/backend/src/ea_chatbot/api/main.py +++ b/backend/src/ea_chatbot/api/main.py @@ -1,4 +1,4 @@ -from fastapi import FastAPI +from fastapi import FastAPI, APIRouter from fastapi.middleware.cors import CORSMiddleware from ea_chatbot.api.routers import auth, history, artifacts, agent from dotenv import load_dotenv @@ -20,10 +20,16 @@ app.add_middleware( allow_headers=["*"], ) -app.include_router(auth.router) -app.include_router(history.router) -app.include_router(artifacts.router) -app.include_router(agent.router) +# API v1 Router +api_v1_router = APIRouter(prefix="/api/v1") + +api_v1_router.include_router(auth.router) +api_v1_router.include_router(history.router) +api_v1_router.include_router(artifacts.router) +api_v1_router.include_router(agent.router) + +# Include v1 router in app +app.include_router(api_v1_router) @app.get("/health") async def health_check(): diff --git a/backend/src/ea_chatbot/api/routers/auth.py b/backend/src/ea_chatbot/api/routers/auth.py index 96beb9c..d6792e7 100644 --- a/backend/src/ea_chatbot/api/routers/auth.py +++ b/backend/src/ea_chatbot/api/routers/auth.py @@ -1,14 +1,29 @@ -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, status, Response +from fastapi.responses import RedirectResponse from fastapi.security import OAuth2PasswordRequestForm from ea_chatbot.api.utils import create_access_token from ea_chatbot.api.dependencies import history_manager, oidc_client, get_current_user from ea_chatbot.history.models import User as UserDB from ea_chatbot.api.schemas import Token, UserCreate, UserResponse +import os router = APIRouter(prefix="/auth", tags=["auth"]) +FRONTEND_URL = os.getenv("FRONTEND_URL", "http://localhost:5173") + +def set_auth_cookie(response: Response, token: str): + response.set_cookie( + key="access_token", + value=token, + httponly=True, + max_age=1800, + expires=1800, + samesite="lax", + secure=False, # Set to True in production with HTTPS + ) + @router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED) -async def register(user_in: UserCreate): +async def register(user_in: UserCreate, response: Response): """Register a new user.""" user = history_manager.get_user(user_in.email) if user: @@ -22,6 +37,10 @@ async def register(user_in: UserCreate): password=user_in.password, display_name=user_in.display_name ) + + access_token = create_access_token(data={"sub": str(user.id)}) + set_auth_cookie(response, access_token) + return { "id": str(user.id), "email": user.username, @@ -29,7 +48,7 @@ async def register(user_in: UserCreate): } @router.post("/login", response_model=Token) -async def login(form_data: OAuth2PasswordRequestForm = Depends()): +async def login(response: Response, form_data: OAuth2PasswordRequestForm = Depends()): """Login with email and password to get a JWT.""" user = history_manager.authenticate_user(form_data.username, form_data.password) if not user: @@ -40,8 +59,15 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()): ) access_token = create_access_token(data={"sub": str(user.id)}) + set_auth_cookie(response, access_token) return {"access_token": access_token, "token_type": "bearer"} +@router.post("/logout") +async def logout(response: Response): + """Logout by clearing the auth cookie.""" + response.delete_cookie(key="access_token") + return {"detail": "Successfully logged out"} + @router.get("/oidc/login") async def oidc_login(): """Get the OIDC authorization URL.""" @@ -54,9 +80,9 @@ async def oidc_login(): url = oidc_client.get_login_url() return {"url": url} -@router.get("/oidc/callback", response_model=Token) +@router.get("/oidc/callback") async def oidc_callback(code: str): - """Handle the OIDC callback and issue a JWT.""" + """Handle the OIDC callback, issue a JWT, and redirect to frontend.""" if not oidc_client: raise HTTPException(status_code=status.HTTP_510_NOT_EXTENDED, detail="OIDC not configured") @@ -72,9 +98,13 @@ async def oidc_callback(code: str): user = history_manager.sync_user_from_oidc(email=email, display_name=name) access_token = create_access_token(data={"sub": str(user.id)}) - return {"access_token": access_token, "token_type": "bearer"} + + response = RedirectResponse(url=f"{FRONTEND_URL}/auth/callback") + set_auth_cookie(response, access_token) + return response except Exception as e: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"OIDC authentication failed: {str(e)}") + # Redirect to frontend with error + return RedirectResponse(url=f"{FRONTEND_URL}?error=oidc_failed") @router.get("/me", response_model=UserResponse) async def get_me(current_user: UserDB = Depends(get_current_user)): diff --git a/backend/tests/api/test_api_auth_cookie.py b/backend/tests/api/test_api_auth_cookie.py new file mode 100644 index 0000000..e200c61 --- /dev/null +++ b/backend/tests/api/test_api_auth_cookie.py @@ -0,0 +1,97 @@ +import pytest +from fastapi.testclient import TestClient +from unittest.mock import MagicMock, patch +from ea_chatbot.api.main import app +from ea_chatbot.history.models import User +from ea_chatbot.api.utils import create_access_token + +client = TestClient(app) + +@pytest.fixture +def mock_user(): + return User( + id="user-123", + username="test@example.com", + display_name="Test User", + password_hash="hashed_password" + ) + +def test_v1_prefix(): + """Test that routes are prefixed with /api/v1.""" + # This should now be 404 + response = client.get("/auth/me") + assert response.status_code == 404 + + # This should be 401 (unauthorized) instead of 404 + response = client.get("/api/v1/auth/me") + assert response.status_code == 401 + +def test_login_sets_cookie(): + """Test that login sets the access_token cookie.""" + with patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm: + mock_hm.authenticate_user.return_value = User(id="1", username="test@example.com") + + response = client.post( + "/api/v1/auth/login", + data={"username": "test@example.com", "password": "password123"} + ) + + assert response.status_code == 200 + assert "access_token" in response.cookies + + # Check for HttpOnly in Set-Cookie header + set_cookie = response.headers.get("set-cookie", "") + assert "access_token" in set_cookie + assert "HttpOnly" in set_cookie + +def test_register_sets_cookie(): + """Test that register sets the access_token cookie.""" + with patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm: + mock_hm.get_user.return_value = None + mock_hm.create_user.return_value = User(id="1", username="new@example.com", display_name="New") + + response = client.post( + "/api/v1/auth/register", + json={"email": "new@example.com", "password": "password123", "display_name": "New"} + ) + + assert response.status_code == 201 + assert "access_token" in response.cookies + +def test_auth_via_cookie(): + """Test that protected routes work with the access_token cookie.""" + token = create_access_token(data={"sub": "123"}) + + with patch("ea_chatbot.api.dependencies.history_manager") as mock_hm: + mock_hm.get_user_by_id.return_value = User(id="123", username="test@example.com", display_name="Test") + + # Pass token via cookie instead of header + client.cookies.set("access_token", token) + response = client.get("/api/v1/auth/me") + + assert response.status_code == 200 + assert response.json()["email"] == "test@example.com" + +def test_logout_clears_cookie(): + """Test that logout endpoint clears the cookie.""" + response = client.post("/api/v1/auth/logout") + assert response.status_code == 200 + # Cookie should be expired/empty + cookie = response.cookies.get("access_token") + assert not cookie or cookie == "" + +def test_oidc_callback_redirects_with_cookie(): + """Test that OIDC callback sets cookie and redirects.""" + with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc, \ + patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm: + + mock_oidc.exchange_code_for_token.return_value = {"access_token": "oidc-token"} + mock_oidc.get_user_info.return_value = {"email": "sso@example.com", "name": "SSO User"} + mock_hm.sync_user_from_oidc.return_value = User(id="sso-123", username="sso@example.com", display_name="SSO User") + + # Follow_redirects=False to catch the 307/302 + response = client.get("/api/v1/auth/oidc/callback?code=some-code", follow_redirects=False) + + assert response.status_code in [302, 303, 307] + assert "access_token" in response.cookies + assert "/auth/callback" in response.headers["location"]