feat(auth): Complete OIDC security refactor and modernize test suite
- Refactored OIDC flow to implement PKCE, state/nonce validation, and BFF pattern. - Centralized configuration in Settings class (DEV_MODE, FRONTEND_URL, OIDC_REDIRECT_URI). - Updated auth routers to use conditional secure cookie flags based on DEV_MODE. - Modernized and cleaned up test suite by removing legacy Streamlit tests. - Fixed linting errors and unused imports across the backend.
This commit is contained in:
11
GEMINI.md
11
GEMINI.md
@@ -48,8 +48,9 @@ The frontend is a modern SPA (Single Page Application) designed for data-heavy i
|
||||
- **LangChain Docs**: See the `langchain-docs/` folder for local LangChain and LangGraph documentation.
|
||||
|
||||
## Git Operations
|
||||
- Branches should be used for specific features or bug fixes.
|
||||
- New branches should be created from the `main` branch and `conductor` branch.
|
||||
- The conductor should always use the `conductor` branch and derived branches.
|
||||
- When a feature or fix is complete, use rebase to keep the commit history clean before merging.
|
||||
- The conductor related changes should never be merged into the `main` branch.
|
||||
- All new feature and bug-fix branches must be created from the `conductor` branch except hot-fix.
|
||||
- The `conductor` branch serves as the primary development branch where integration occurs.
|
||||
- The `main` branch is reserved for stable, production-ready code.
|
||||
- Merges from `conductor` to `main` should only occur when significant milestones are reached and stability is verified.
|
||||
- Conductor-specific configuration or meta-files should remain on the `conductor` branch or its derivatives and never be merged into the `main` branch.
|
||||
- Use rebase to keep commit history clean before merging feature branches back into `conductor`.
|
||||
|
||||
@@ -6,7 +6,8 @@ GOOGLE_API_KEY=your_google_api_key_here
|
||||
DATA_DIR=data
|
||||
DATA_STATE=new_jersey
|
||||
LOG_LEVEL=INFO
|
||||
DEV_MODE=false
|
||||
DEV_MODE=true
|
||||
FRONTEND_URL=http://localhost:5173
|
||||
|
||||
# Security & JWT Configuration
|
||||
SECRET_KEY=change-me-in-production
|
||||
@@ -28,7 +29,7 @@ HISTORY_DB_URL=postgresql://user:password@localhost:5433/ea_history
|
||||
OIDC_CLIENT_ID=your_client_id
|
||||
OIDC_CLIENT_SECRET=your_client_secret
|
||||
OIDC_SERVER_METADATA_URL=https://your-authentik.example.com/application/o/ea-chatbot/.well-known/openid-configuration
|
||||
OIDC_REDIRECT_URI=http://localhost:8501
|
||||
OIDC_REDIRECT_URI=http://localhost:8000/api/v1/auth/oidc/callback
|
||||
|
||||
# Node Configuration Overrides (Optional)
|
||||
# Format: <NODE_NAME>_LLM__<PARAMETER>
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
from fastapi import Depends, HTTPException, status, Request
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from ea_chatbot.config import Settings
|
||||
@@ -18,7 +17,7 @@ if settings.oidc_client_id and settings.oidc_client_secret and settings.oidc_ser
|
||||
client_id=settings.oidc_client_id,
|
||||
client_secret=settings.oidc_client_secret,
|
||||
server_metadata_url=settings.oidc_server_metadata_url,
|
||||
redirect_uri=os.getenv("OIDC_REDIRECT_URI", "http://localhost:3000/auth/callback")
|
||||
redirect_uri=settings.oidc_redirect_uri
|
||||
)
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth/login", auto_error=False)
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
from fastapi import FastAPI, APIRouter
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from ea_chatbot.api.routers import auth, history, artifacts, agent
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
app = FastAPI(
|
||||
title="Election Analytics Chatbot API",
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Optional, List
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from typing import AsyncGenerator, List
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from ea_chatbot.api.dependencies import get_current_user, history_manager
|
||||
from ea_chatbot.api.utils import convert_to_json_compatible
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||
from ea_chatbot.api.dependencies import get_current_user, history_manager
|
||||
from ea_chatbot.history.models import Plot, Message, User as UserDB
|
||||
import io
|
||||
|
||||
router = APIRouter(prefix="/artifacts", tags=["artifacts"])
|
||||
|
||||
|
||||
@@ -1,19 +1,17 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Response, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from ea_chatbot.api.utils import create_access_token
|
||||
from ea_chatbot.api.utils import create_access_token, settings
|
||||
from ea_chatbot.api.dependencies import history_manager, oidc_client, get_current_user
|
||||
from ea_chatbot.history.models import User as UserDB
|
||||
from ea_chatbot.api.schemas import Token, UserCreate, UserResponse
|
||||
import os
|
||||
from ea_chatbot.auth import OIDCSession
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
FRONTEND_URL = os.getenv("FRONTEND_URL", "http://localhost:5173")
|
||||
|
||||
def set_auth_cookie(response: Response, token: str):
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
@@ -22,7 +20,7 @@ def set_auth_cookie(response: Response, token: str):
|
||||
max_age=1800,
|
||||
expires=1800,
|
||||
samesite="lax",
|
||||
secure=False, # Set to True in production with HTTPS
|
||||
secure=not settings.dev_mode,
|
||||
)
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
@@ -72,59 +70,84 @@ async def logout(response: Response):
|
||||
return {"detail": "Successfully logged out"}
|
||||
|
||||
@router.get("/oidc/login")
|
||||
async def oidc_login():
|
||||
"""Get the OIDC authorization URL."""
|
||||
async def oidc_login(response: Response):
|
||||
"""Get the OIDC authorization URL and set temporary session cookie."""
|
||||
if not oidc_client:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_510_NOT_EXTENDED,
|
||||
detail="OIDC is not configured"
|
||||
)
|
||||
|
||||
url = oidc_client.get_login_url()
|
||||
return {"url": url}
|
||||
auth_data = oidc_client.get_auth_data()
|
||||
|
||||
@router.get("/oidc/callback", response_model=Token)
|
||||
async def oidc_callback(request: Request, response: Response, code: str):
|
||||
"""Handle the OIDC callback, issue a JWT, and redirect or return JSON."""
|
||||
# Store state, nonce, and code_verifier in a secure short-lived cookie
|
||||
session_token = OIDCSession.encrypt(
|
||||
{
|
||||
"state": auth_data["state"],
|
||||
"nonce": auth_data["nonce"],
|
||||
"code_verifier": auth_data["code_verifier"]
|
||||
},
|
||||
settings.secret_key
|
||||
)
|
||||
|
||||
response.set_cookie(
|
||||
key="oidc_session",
|
||||
value=session_token,
|
||||
httponly=True,
|
||||
max_age=300, # 5 minutes
|
||||
samesite="lax",
|
||||
secure=not settings.dev_mode
|
||||
)
|
||||
|
||||
return {"url": auth_data["url"]}
|
||||
|
||||
@router.get("/oidc/callback")
|
||||
async def oidc_callback(request: Request, response: Response, code: str, state: str):
|
||||
"""Handle the OIDC callback, validate state/nonce, issue a JWT, and redirect."""
|
||||
if not oidc_client:
|
||||
raise HTTPException(status_code=status.HTTP_510_NOT_EXTENDED, detail="OIDC not configured")
|
||||
|
||||
# 1. Validate state and retrieve session data from cookie
|
||||
session_token = request.cookies.get("oidc_session")
|
||||
if not session_token:
|
||||
logger.error("OIDC session cookie missing")
|
||||
return RedirectResponse(url=f"{settings.frontend_url}?error=oidc_failed", status_code=status.HTTP_302_FOUND)
|
||||
|
||||
session_data = OIDCSession.decrypt(session_token, settings.secret_key)
|
||||
if not session_data or session_data.get("state") != state:
|
||||
logger.error("OIDC state mismatch or session expired")
|
||||
return RedirectResponse(url=f"{settings.frontend_url}?error=oidc_failed", status_code=status.HTTP_302_FOUND)
|
||||
|
||||
try:
|
||||
token = oidc_client.exchange_code_for_token(code)
|
||||
user_info = oidc_client.get_user_info(token)
|
||||
email = user_info.get("email")
|
||||
name = user_info.get("name") or user_info.get("preferred_username")
|
||||
# 2. Exchange code for token using PKCE code_verifier
|
||||
token = oidc_client.exchange_code_for_token(code, code_verifier=session_data.get("code_verifier"))
|
||||
|
||||
# 3. Validate ID Token and Nonce
|
||||
id_token = token.get("id_token")
|
||||
if not id_token:
|
||||
raise ValueError("ID Token missing from response")
|
||||
|
||||
claims = oidc_client.validate_id_token(id_token, nonce=session_data.get("nonce"))
|
||||
|
||||
email = claims.get("email")
|
||||
name = claims.get("name") or claims.get("preferred_username")
|
||||
|
||||
if not email:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email not provided by OIDC")
|
||||
|
||||
# 4. Sync user and establish session
|
||||
user = history_manager.sync_user_from_oidc(email=email, display_name=name)
|
||||
|
||||
access_token = create_access_token(data={"sub": str(user.id)})
|
||||
|
||||
# Determine if we should redirect (direct provider callback) or return JSON (frontend exchange)
|
||||
is_ajax = request.headers.get("X-Requested-With") == "XMLHttpRequest" or \
|
||||
"application/json" in request.headers.get("Accept", "")
|
||||
redirect_response = RedirectResponse(url=settings.frontend_url, status_code=status.HTTP_302_FOUND)
|
||||
set_auth_cookie(redirect_response, access_token)
|
||||
redirect_response.delete_cookie("oidc_session") # Cleanup
|
||||
|
||||
if is_ajax:
|
||||
set_auth_cookie(response, access_token)
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
else:
|
||||
redirect_response = RedirectResponse(url=f"{FRONTEND_URL}/auth/callback")
|
||||
set_auth_cookie(redirect_response, access_token)
|
||||
return redirect_response
|
||||
return redirect_response
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("OIDC authentication failed")
|
||||
# For non-ajax, redirect with error
|
||||
is_ajax = request.headers.get("X-Requested-With") == "XMLHttpRequest" or \
|
||||
"application/json" in request.headers.get("Accept", "")
|
||||
if is_ajax:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="OIDC authentication failed")
|
||||
else:
|
||||
return RedirectResponse(url=f"{FRONTEND_URL}?error=oidc_failed")
|
||||
return RedirectResponse(url=f"{settings.frontend_url}?error=oidc_failed", status_code=status.HTTP_302_FOUND)
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_me(current_user: UserDB = Depends(get_current_user)):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Union, Any, List, Dict
|
||||
from typing import Optional, Any
|
||||
from jose import JWTError, jwt
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import streamlit as st
|
||||
import asyncio
|
||||
import os
|
||||
import io
|
||||
from dotenv import load_dotenv
|
||||
@@ -348,7 +347,7 @@ def main():
|
||||
if node_name == "query_analyzer":
|
||||
analysis = state_update.get("analysis", {})
|
||||
next_action = state_update.get("next_action", "unknown")
|
||||
status.write(f"🔍 **Analyzed Query:**")
|
||||
status.write("🔍 **Analyzed Query:**")
|
||||
for k,v in analysis.items():
|
||||
status.write(f"- {k:<8}: {v}")
|
||||
status.markdown(f"Next Step: {next_action.capitalize()}")
|
||||
|
||||
@@ -19,6 +19,8 @@ class Settings(BaseSettings):
|
||||
data_dir: str = "data"
|
||||
data_state: str = "new_jersey"
|
||||
log_level: str = Field(default="INFO", alias="LOG_LEVEL")
|
||||
dev_mode: bool = Field(default=True, alias="DEV_MODE")
|
||||
frontend_url: str = Field(default="http://localhost:5173", alias="FRONTEND_URL")
|
||||
|
||||
# Voter Database configuration
|
||||
db_host: str = Field(default="localhost", alias="DB_HOST")
|
||||
@@ -40,6 +42,7 @@ class Settings(BaseSettings):
|
||||
oidc_client_id: Optional[str] = Field(default=None, alias="OIDC_CLIENT_ID")
|
||||
oidc_client_secret: Optional[str] = Field(default=None, alias="OIDC_CLIENT_SECRET")
|
||||
oidc_server_metadata_url: Optional[str] = Field(default=None, alias="OIDC_SERVER_METADATA_URL")
|
||||
oidc_redirect_uri: str = Field(default="http://localhost:8000/api/v1/auth/oidc/callback", alias="OIDC_REDIRECT_URI")
|
||||
|
||||
# Default configurations for each node
|
||||
query_analyzer_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from langchain_core.messages import AIMessage
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils import helpers, database_inspection
|
||||
from ea_chatbot.utils import database_inspection
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT
|
||||
from ea_chatbot.schemas import CodeGenerationResponse
|
||||
|
||||
@@ -2,7 +2,7 @@ import io
|
||||
import sys
|
||||
import traceback
|
||||
from contextlib import redirect_stdout
|
||||
from typing import Any, Dict, List, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
import pandas as pd
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from langchain_core.messages import AIMessage
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import TypedDict, Annotated, List, Dict, Any, Optional
|
||||
from typing import Annotated, List, Dict, Any, Optional
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain.agents import AgentState as AS
|
||||
import operator
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, List
|
||||
from sqlalchemy import create_engine, select, delete
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy import create_engine, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from argon2 import PasswordHasher
|
||||
from argon2.exceptions import VerifyMismatchError
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Optional, Dict, Any, List, TYPE_CHECKING
|
||||
from typing import Optional, Dict, Any, TYPE_CHECKING
|
||||
import yaml
|
||||
import json
|
||||
import os
|
||||
from ea_chatbot.utils.db_client import DBClient
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -20,7 +20,8 @@ def to_yaml(json_str: str, indent: int = 2) -> str:
|
||||
"""
|
||||
Attempts to convert a JSON string (potentially malformed from LLM) to a YAML string.
|
||||
"""
|
||||
if not json_str: return ""
|
||||
if not json_str:
|
||||
return ""
|
||||
|
||||
try:
|
||||
# Try direct parse
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional, cast, TYPE_CHECKING, Literal, Dict, List, Tuple, Any
|
||||
from typing import Optional, List
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
from ea_chatbot.api.main import app
|
||||
from ea_chatbot.history.models import User
|
||||
|
||||
@@ -66,31 +66,43 @@ def test_protected_route_without_token():
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_oidc_login_redirect():
|
||||
"""Test that OIDC login returns a redirect URL."""
|
||||
"""Test that OIDC login returns a redirect URL and sets session cookie."""
|
||||
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc:
|
||||
mock_oidc.get_login_url.return_value = "https://oidc-provider.com/auth"
|
||||
mock_oidc.get_auth_data.return_value = {
|
||||
"url": "https://oidc-provider.com/auth",
|
||||
"state": "test_state",
|
||||
"nonce": "test_nonce",
|
||||
"code_verifier": "test_verifier"
|
||||
}
|
||||
|
||||
response = client.get("/api/v1/auth/oidc/login")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["url"] == "https://oidc-provider.com/auth"
|
||||
assert "oidc_session" in response.cookies
|
||||
|
||||
def test_oidc_callback_success_ajax():
|
||||
"""Test successful OIDC callback and JWT issuance via AJAX."""
|
||||
def test_oidc_callback_success():
|
||||
"""Test successful OIDC callback and JWT issuance."""
|
||||
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc, \
|
||||
patch("ea_chatbot.api.routers.auth.OIDCSession.decrypt") as mock_decrypt, \
|
||||
patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm:
|
||||
|
||||
mock_oidc.exchange_code_for_token.return_value = {"access_token": "oidc-token"}
|
||||
mock_oidc.get_user_info.return_value = {"email": "sso@example.com", "name": "SSO User"}
|
||||
mock_decrypt.return_value = {
|
||||
"state": "test_state",
|
||||
"nonce": "test_nonce",
|
||||
"code_verifier": "test_verifier"
|
||||
}
|
||||
mock_oidc.exchange_code_for_token.return_value = {"id_token": "fake_id_token"}
|
||||
mock_oidc.validate_id_token.return_value = {"email": "sso@example.com", "name": "SSO User"}
|
||||
mock_hm.sync_user_from_oidc.return_value = User(id="sso-123", username="sso@example.com", display_name="SSO User")
|
||||
|
||||
client.cookies.set("oidc_session", "fake_token")
|
||||
response = client.get(
|
||||
"/api/v1/auth/oidc/callback?code=some-code",
|
||||
headers={"Accept": "application/json"}
|
||||
"/api/v1/auth/oidc/callback?code=some-code&state=test_state",
|
||||
follow_redirects=False
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "access_token" in response.json()
|
||||
assert response.json()["token_type"] == "bearer"
|
||||
assert response.status_code == 302
|
||||
assert "access_token" in response.cookies
|
||||
|
||||
def test_get_me_success():
|
||||
"""Test getting current user with a valid token."""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
from ea_chatbot.api.main import app
|
||||
from ea_chatbot.history.models import User
|
||||
from ea_chatbot.api.utils import create_access_token
|
||||
@@ -83,15 +83,25 @@ def test_logout_clears_cookie():
|
||||
def test_oidc_callback_redirects_with_cookie():
|
||||
"""Test that OIDC callback sets cookie and redirects."""
|
||||
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc, \
|
||||
patch("ea_chatbot.api.routers.auth.OIDCSession.decrypt") as mock_decrypt, \
|
||||
patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm:
|
||||
|
||||
mock_oidc.exchange_code_for_token.return_value = {"access_token": "oidc-token"}
|
||||
mock_oidc.get_user_info.return_value = {"email": "sso@example.com", "name": "SSO User"}
|
||||
mock_decrypt.return_value = {
|
||||
"state": "test_state",
|
||||
"nonce": "test_nonce",
|
||||
"code_verifier": "test_verifier"
|
||||
}
|
||||
mock_oidc.exchange_code_for_token.return_value = {"id_token": "fake_id_token"}
|
||||
mock_oidc.validate_id_token.return_value = {"email": "sso@example.com", "name": "SSO User"}
|
||||
mock_hm.sync_user_from_oidc.return_value = User(id="sso-123", username="sso@example.com", display_name="SSO User")
|
||||
|
||||
# Follow_redirects=False to catch the 307/302
|
||||
response = client.get("/api/v1/auth/oidc/callback?code=some-code", follow_redirects=False)
|
||||
# Set the session cookie
|
||||
client.cookies.set("oidc_session", "fake_token")
|
||||
|
||||
assert response.status_code in [302, 303, 307]
|
||||
# Follow_redirects=False to catch the 302
|
||||
response = client.get("/api/v1/auth/oidc/callback?code=some-code&state=test_state", follow_redirects=False)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert "access_token" in response.cookies
|
||||
assert "/auth/callback" in response.headers["location"]
|
||||
# Should redirect to FRONTEND_URL (default http://localhost:5173)
|
||||
assert "http://localhost:5173" in response.headers["location"]
|
||||
|
||||
82
backend/tests/api/test_oidc_flow.py
Normal file
82
backend/tests/api/test_oidc_flow.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import patch
|
||||
from ea_chatbot.api.main import app
|
||||
from ea_chatbot.history.models import User
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Provides a fresh TestClient for each test."""
|
||||
return TestClient(app)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_data():
|
||||
return {
|
||||
"url": "https://example.com/auth?state=test_state",
|
||||
"state": "test_state",
|
||||
"nonce": "test_nonce",
|
||||
"code_verifier": "test_verifier"
|
||||
}
|
||||
|
||||
def test_oidc_login_sets_cookie(client, mock_auth_data):
|
||||
"""Test that OIDC login initiation sets the temporary session cookie."""
|
||||
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc:
|
||||
mock_oidc.get_auth_data.return_value = mock_auth_data
|
||||
|
||||
response = client.get("/api/v1/auth/oidc/login")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["url"] == mock_auth_data["url"]
|
||||
assert "oidc_session" in response.cookies
|
||||
|
||||
def test_oidc_callback_missing_cookie(client):
|
||||
"""Test that OIDC callback fails if the temporary session cookie is missing."""
|
||||
with patch("ea_chatbot.api.routers.auth.oidc_client"):
|
||||
# Ensure no cookies are set
|
||||
client.cookies.clear()
|
||||
response = client.get("/api/v1/auth/oidc/callback?code=test_code&state=test_state", follow_redirects=False)
|
||||
|
||||
# Should redirect to frontend with error
|
||||
assert response.status_code == 302
|
||||
assert "error=oidc_failed" in response.headers["location"]
|
||||
|
||||
def test_oidc_callback_invalid_state(client, mock_auth_data):
|
||||
"""Test that OIDC callback fails if the state in the URL doesn't match the cookie."""
|
||||
with patch("ea_chatbot.api.routers.auth.oidc_client"), \
|
||||
patch("ea_chatbot.api.routers.auth.OIDCSession.decrypt") as mock_decrypt:
|
||||
|
||||
# Mock valid cookie content
|
||||
mock_decrypt.return_value = mock_auth_data
|
||||
|
||||
# Send different state in URL
|
||||
client.cookies.set("oidc_session", "fake_token")
|
||||
response = client.get("/api/v1/auth/oidc/callback?code=test_code&state=wrong_state", follow_redirects=False)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert "error=oidc_failed" in response.headers["location"]
|
||||
|
||||
def test_oidc_callback_success(client, mock_auth_data):
|
||||
"""Test successful OIDC callback and session establishment."""
|
||||
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc, \
|
||||
patch("ea_chatbot.api.routers.auth.OIDCSession.decrypt") as mock_decrypt, \
|
||||
patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm, \
|
||||
patch("ea_chatbot.api.routers.auth.create_access_token") as mock_create_token:
|
||||
|
||||
mock_decrypt.return_value = mock_auth_data
|
||||
mock_oidc.exchange_code_for_token.return_value = {"id_token": "fake_id_token"}
|
||||
mock_oidc.validate_id_token.return_value = {"email": "user@test.com", "name": "Test User"}
|
||||
mock_hm.sync_user_from_oidc.return_value = User(id="user-123", username="user@test.com")
|
||||
mock_create_token.return_value = "fake_access_token"
|
||||
|
||||
client.cookies.set("oidc_session", "fake_token")
|
||||
response = client.get("/api/v1/auth/oidc/callback?code=test_code&state=test_state", follow_redirects=False)
|
||||
|
||||
assert response.status_code == 302
|
||||
# Redirect to FRONTEND_URL (default localhost:5173)
|
||||
assert "http://localhost:5173" in response.headers["location"]
|
||||
assert "access_token" in response.cookies
|
||||
|
||||
# Verify oidc_session was cleaned up (deleted)
|
||||
# In RedirectResponse, delete_cookie works by setting the cookie with empty value and past expiry
|
||||
cookie_header = response.headers.get("set-cookie", "")
|
||||
assert "oidc_session=;" in cookie_header or 'oidc_session=""' in cookie_header
|
||||
@@ -3,10 +3,8 @@ from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
from ea_chatbot.api.main import app
|
||||
from ea_chatbot.api.dependencies import get_current_user
|
||||
from ea_chatbot.history.models import User, Conversation, Message, Plot
|
||||
from ea_chatbot.history.models import User, Conversation
|
||||
from ea_chatbot.api.utils import create_access_token
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@@ -30,7 +30,6 @@ async def test_get_checkpointer_initialization():
|
||||
checkpoint.get_pool = mock_get_pool
|
||||
|
||||
# Patch AsyncPostgresSaver where it's imported in checkpoint.py
|
||||
import langgraph.checkpoint.postgres.aio as pg_aio
|
||||
original_saver = checkpoint.AsyncPostgresSaver
|
||||
checkpoint.AsyncPostgresSaver = mock_saver_class
|
||||
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from streamlit.testing.v1 import AppTest
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
# Ensure src is in python path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_history_manager():
|
||||
"""Globally mock HistoryManager to avoid DB calls during AppTest."""
|
||||
with patch("ea_chatbot.history.manager.HistoryManager") as mock_cls:
|
||||
instance = mock_cls.return_value
|
||||
instance.create_conversation.return_value = MagicMock(id="conv_123")
|
||||
instance.get_conversations.return_value = []
|
||||
instance.get_messages.return_value = []
|
||||
instance.add_message.return_value = MagicMock()
|
||||
instance.update_conversation_summary.return_value = MagicMock()
|
||||
yield instance
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_stream():
|
||||
with patch("ea_chatbot.graph.workflow.app.stream") as mock_stream:
|
||||
# Mock events from app.stream
|
||||
mock_stream.return_value = [
|
||||
{"query_analyzer": {"next_action": "research"}},
|
||||
{"researcher": {"messages": [AIMessage(content="Research result")]}}
|
||||
]
|
||||
yield mock_stream
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
user = MagicMock()
|
||||
user.id = "test_id"
|
||||
user.username = "test@example.com"
|
||||
user.display_name = "Test User"
|
||||
return user
|
||||
|
||||
def test_app_initial_state(mock_app_stream, mock_user):
|
||||
"""Test that the app initializes with the correct title and empty history."""
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
|
||||
# Simulate logged-in user
|
||||
at.session_state["user"] = mock_user
|
||||
|
||||
at.run()
|
||||
|
||||
assert not at.exception
|
||||
assert at.title[0].value == "🗳️ Election Analytics Chatbot"
|
||||
|
||||
# Check session state initialization
|
||||
assert "messages" in at.session_state
|
||||
assert len(at.session_state["messages"]) == 0
|
||||
|
||||
def test_app_dev_mode_toggle(mock_app_stream, mock_user):
|
||||
"""Test that the dev mode toggle exists in the sidebar."""
|
||||
with patch.dict(os.environ, {"DEV_MODE": "false"}):
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
at.session_state["user"] = mock_user
|
||||
at.run()
|
||||
|
||||
# Check for sidebar toggle (checkbox)
|
||||
assert len(at.sidebar.checkbox) > 0
|
||||
dev_mode_toggle = at.sidebar.checkbox[0]
|
||||
assert dev_mode_toggle.label == "Dev Mode"
|
||||
assert dev_mode_toggle.value is False
|
||||
|
||||
def test_app_graph_execution_streaming(mock_app_stream, mock_user, mock_history_manager):
|
||||
"""Test that entering a prompt triggers the graph stream and displays response."""
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
at.session_state["user"] = mock_user
|
||||
at.run()
|
||||
|
||||
# Input a question
|
||||
at.chat_input[0].set_value("Test question").run()
|
||||
|
||||
# Verify graph stream was called
|
||||
assert mock_app_stream.called
|
||||
|
||||
# Message should be added to history
|
||||
assert len(at.session_state["messages"]) == 2
|
||||
assert at.session_state["messages"][0]["role"] == "user"
|
||||
assert at.session_state["messages"][1]["role"] == "assistant"
|
||||
assert "Research result" in at.session_state["messages"][1]["content"]
|
||||
|
||||
# Verify history manager was used
|
||||
assert mock_history_manager.create_conversation.called
|
||||
assert mock_history_manager.add_message.called
|
||||
@@ -1,83 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from streamlit.testing.v1 import AppTest
|
||||
from ea_chatbot.auth import AuthType
|
||||
|
||||
@pytest.fixture
|
||||
def mock_history_manager_instance():
|
||||
# We need to patch before the AppTest loads the module
|
||||
with patch("ea_chatbot.history.manager.HistoryManager") as mock_cls:
|
||||
instance = mock_cls.return_value
|
||||
yield instance
|
||||
|
||||
def test_auth_ui_flow_step1_to_password(mock_history_manager_instance):
|
||||
"""Test UI transition from Step 1 (email) to Step 2a (password) for LOCAL user."""
|
||||
# Patch BEFORE creating AppTest
|
||||
mock_user = MagicMock()
|
||||
mock_user.password_hash = "hashed_password"
|
||||
mock_history_manager_instance.get_user.return_value = mock_user
|
||||
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
at.run()
|
||||
|
||||
# Step 1: Identification
|
||||
assert at.session_state["login_step"] == "email"
|
||||
at.text_input[0].set_value("local@example.com")
|
||||
|
||||
at.button[0].click().run()
|
||||
|
||||
# Verify transition to password step
|
||||
assert at.session_state["login_step"] == "login_password"
|
||||
assert at.session_state["login_email"] == "local@example.com"
|
||||
assert "Welcome back" in at.info[0].value
|
||||
|
||||
def test_auth_ui_flow_step1_to_register(mock_history_manager_instance):
|
||||
"""Test UI transition from Step 1 (email) to Step 2b (registration) for NEW user."""
|
||||
mock_history_manager_instance.get_user.return_value = None
|
||||
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
at.run()
|
||||
|
||||
# Step 1: Identification
|
||||
at.text_input[0].set_value("new@example.com")
|
||||
|
||||
at.button[0].click().run()
|
||||
|
||||
# Verify transition to registration step
|
||||
assert at.session_state["login_step"] == "register_details"
|
||||
assert at.session_state["login_email"] == "new@example.com"
|
||||
assert "Create an account" in at.info[0].value
|
||||
|
||||
def test_auth_ui_flow_step1_to_oidc(mock_history_manager_instance):
|
||||
"""Test UI transition from Step 1 (email) to Step 2c (OIDC) for OIDC user."""
|
||||
# Mock history_manager.get_user to return a user WITHOUT a password
|
||||
mock_user = MagicMock()
|
||||
mock_user.password_hash = None
|
||||
mock_history_manager_instance.get_user.return_value = mock_user
|
||||
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
at.run()
|
||||
|
||||
# Step 1: Identification
|
||||
at.text_input[0].set_value("oidc@example.com")
|
||||
|
||||
at.button[0].click().run()
|
||||
|
||||
# Verify transition to OIDC step
|
||||
assert at.session_state["login_step"] == "oidc_login"
|
||||
assert at.session_state["login_email"] == "oidc@example.com"
|
||||
assert "configured for Single Sign-On" in at.info[0].value
|
||||
|
||||
def test_auth_ui_flow_back_button(mock_history_manager_instance):
|
||||
"""Test that the 'Back' button returns to Step 1."""
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
# Simulate being on Step 2a
|
||||
at.session_state["login_step"] = "login_password"
|
||||
at.session_state["login_email"] = "local@example.com"
|
||||
at.run()
|
||||
|
||||
# Click Back (index 1 in Step 2a)
|
||||
at.button[1].click().run()
|
||||
|
||||
# Verify return to email step
|
||||
assert at.session_state["login_step"] == "email"
|
||||
@@ -1,43 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.auth import OIDCClient
|
||||
|
||||
@patch("ea_chatbot.auth.OAuth2Session")
|
||||
def test_oidc_client_initialization(mock_oauth):
|
||||
client = OIDCClient(
|
||||
client_id="test_id",
|
||||
client_secret="test_secret",
|
||||
server_metadata_url="https://test.server/.well-known/openid-configuration"
|
||||
)
|
||||
assert client.oauth_session is not None
|
||||
|
||||
@patch("ea_chatbot.auth.requests")
|
||||
@patch("ea_chatbot.auth.OAuth2Session")
|
||||
def test_get_login_url(mock_oauth_cls, mock_requests):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_oauth_cls.return_value = mock_session
|
||||
|
||||
# Mock metadata response
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"authorization_endpoint": "https://test.server/auth",
|
||||
"token_endpoint": "https://test.server/token",
|
||||
"userinfo_endpoint": "https://test.server/userinfo"
|
||||
}
|
||||
mock_requests.get.return_value = mock_response
|
||||
|
||||
# Mock authorization url generation
|
||||
mock_session.create_authorization_url.return_value = ("https://test.server/auth?response_type=code", "state")
|
||||
|
||||
client = OIDCClient(
|
||||
client_id="test_id",
|
||||
client_secret="test_secret",
|
||||
server_metadata_url="https://test.server/.well-known/openid-configuration"
|
||||
)
|
||||
|
||||
url = client.get_login_url()
|
||||
|
||||
assert url == "https://test.server/auth?response_type=code"
|
||||
# Verify metadata was fetched via requests
|
||||
mock_requests.get.assert_called_with("https://test.server/.well-known/openid-configuration")
|
||||
@@ -1,49 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from ea_chatbot.history.manager import HistoryManager
|
||||
from ea_chatbot.auth import get_user_auth_type, AuthType
|
||||
|
||||
# Mocks
|
||||
@pytest.fixture
|
||||
def mock_history_manager():
|
||||
return MagicMock(spec=HistoryManager)
|
||||
|
||||
def test_auth_flow_existing_local_user(mock_history_manager):
|
||||
"""Test that an existing user with a password returns LOCAL auth type."""
|
||||
# Setup
|
||||
mock_user = MagicMock()
|
||||
mock_user.password_hash = "hashed_secret"
|
||||
mock_history_manager.get_user.return_value = mock_user
|
||||
|
||||
# Execute
|
||||
auth_type = get_user_auth_type("test@example.com", mock_history_manager)
|
||||
|
||||
# Verify
|
||||
assert auth_type == AuthType.LOCAL
|
||||
mock_history_manager.get_user.assert_called_once_with("test@example.com")
|
||||
|
||||
def test_auth_flow_existing_oidc_user(mock_history_manager):
|
||||
"""Test that an existing user WITHOUT a password returns OIDC auth type."""
|
||||
# Setup
|
||||
mock_user = MagicMock()
|
||||
mock_user.password_hash = None # No password implies OIDC
|
||||
mock_history_manager.get_user.return_value = mock_user
|
||||
|
||||
# Execute
|
||||
auth_type = get_user_auth_type("sso@example.com", mock_history_manager)
|
||||
|
||||
# Verify
|
||||
assert auth_type == AuthType.OIDC
|
||||
mock_history_manager.get_user.assert_called_once_with("sso@example.com")
|
||||
|
||||
def test_auth_flow_new_user(mock_history_manager):
|
||||
"""Test that a non-existent user returns NEW auth type."""
|
||||
# Setup
|
||||
mock_history_manager.get_user.return_value = None
|
||||
|
||||
# Execute
|
||||
auth_type = get_user_auth_type("new@example.com", mock_history_manager)
|
||||
|
||||
# Verify
|
||||
assert auth_type == AuthType.NEW
|
||||
mock_history_manager.get_user.assert_called_once_with("new@example.com")
|
||||
@@ -1,5 +1,3 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from ea_chatbot.config import Settings, LLMConfig
|
||||
|
||||
def test_default_settings():
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock
|
||||
from ea_chatbot.utils.database_inspection import get_primary_key, inspect_db_table, get_data_summary
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -3,7 +3,6 @@ import pandas as pd
|
||||
from unittest.mock import MagicMock, patch
|
||||
from matplotlib.figure import Figure
|
||||
from ea_chatbot.graph.nodes.executor import executor_node
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings():
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from ea_chatbot.utils.helpers import merge_agent_state
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker, DeclarativeBase
|
||||
|
||||
# We anticipate these imports will fail initially
|
||||
try:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import pytest
|
||||
from langchain_openai import ChatOpenAI
|
||||
from ea_chatbot.config import LLMConfig
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import os
|
||||
import json
|
||||
import pytest
|
||||
import logging
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
from ea_chatbot.graph.workflow import app
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.utils.logging import get_logger
|
||||
|
||||
@@ -2,7 +2,6 @@ import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
@pytest.fixture
|
||||
def mock_state_with_history():
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.auth import OIDCClient
|
||||
|
||||
@pytest.fixture
|
||||
def oidc_config():
|
||||
return {
|
||||
"client_id": "test_id",
|
||||
"client_secret": "test_secret",
|
||||
"server_metadata_url": "https://example.com/.well-known/openid-configuration",
|
||||
"redirect_uri": "http://localhost:8501"
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_metadata():
|
||||
return {
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
"userinfo_endpoint": "https://example.com/userinfo"
|
||||
}
|
||||
|
||||
def test_oidc_fetch_metadata(oidc_config, mock_metadata):
|
||||
client = OIDCClient(**oidc_config)
|
||||
|
||||
with patch("requests.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_metadata
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
metadata = client.fetch_metadata()
|
||||
|
||||
assert metadata == mock_metadata
|
||||
mock_get.assert_called_once_with(oidc_config["server_metadata_url"])
|
||||
|
||||
# Second call should use cache
|
||||
client.fetch_metadata()
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
def test_oidc_get_login_url(oidc_config, mock_metadata):
|
||||
client = OIDCClient(**oidc_config)
|
||||
client.metadata = mock_metadata
|
||||
|
||||
with patch.object(client.oauth_session, "create_authorization_url") as mock_create_url:
|
||||
mock_create_url.return_value = ("https://example.com/auth?state=xyz", "xyz")
|
||||
|
||||
url = client.get_login_url()
|
||||
|
||||
assert url == "https://example.com/auth?state=xyz"
|
||||
mock_create_url.assert_called_once_with(mock_metadata["authorization_endpoint"])
|
||||
|
||||
def test_oidc_get_login_url_missing_endpoint(oidc_config):
|
||||
client = OIDCClient(**oidc_config)
|
||||
client.metadata = {"some_other": "field"}
|
||||
|
||||
with pytest.raises(ValueError, match="authorization_endpoint not found"):
|
||||
client.get_login_url()
|
||||
|
||||
def test_oidc_exchange_code_for_token(oidc_config, mock_metadata):
|
||||
client = OIDCClient(**oidc_config)
|
||||
client.metadata = mock_metadata
|
||||
|
||||
with patch.object(client.oauth_session, "fetch_token") as mock_fetch_token:
|
||||
mock_fetch_token.return_value = {"access_token": "abc"}
|
||||
|
||||
token = client.exchange_code_for_token("test_code")
|
||||
|
||||
assert token == {"access_token": "abc"}
|
||||
mock_fetch_token.assert_called_once_with(
|
||||
mock_metadata["token_endpoint"],
|
||||
code="test_code",
|
||||
client_secret=oidc_config["client_secret"]
|
||||
)
|
||||
|
||||
def test_oidc_get_user_info(oidc_config, mock_metadata):
|
||||
client = OIDCClient(**oidc_config)
|
||||
client.metadata = mock_metadata
|
||||
|
||||
with patch.object(client.oauth_session, "get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"sub": "user123"}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
user_info = client.get_user_info({"access_token": "abc"})
|
||||
|
||||
assert user_info == {"sub": "user123"}
|
||||
assert client.oauth_session.token == {"access_token": "abc"}
|
||||
mock_get.assert_called_once_with(mock_metadata["userinfo_endpoint"])
|
||||
@@ -1,7 +1,6 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
from ea_chatbot.auth import OIDCClient
|
||||
from jose import jwt
|
||||
|
||||
@pytest.fixture
|
||||
def oidc_config():
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
@pytest.fixture
|
||||
def mock_state():
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import pytest
|
||||
import logging
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
@pytest.fixture
|
||||
def base_state():
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import pytest
|
||||
from typing import get_type_hints, List
|
||||
from langchain_core.messages import BaseMessage, HumanMessage
|
||||
from typing import get_type_hints
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
import operator
|
||||
|
||||
def test_agent_state_structure():
|
||||
"""Verify that AgentState has the required fields and types."""
|
||||
|
||||
@@ -2,7 +2,6 @@ import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import AIMessage
|
||||
from ea_chatbot.graph.nodes.summarizer import summarizer_node
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm():
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.graph.workflow import app
|
||||
from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import pytest
|
||||
import yaml
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import AIMessage
|
||||
from ea_chatbot.graph.workflow import app
|
||||
|
||||
@@ -3,7 +3,6 @@ import { Routes, Route } from "react-router-dom"
|
||||
import { MainLayout } from "./components/layout/MainLayout"
|
||||
import { LoginForm } from "./components/auth/LoginForm"
|
||||
import { RegisterForm } from "./components/auth/RegisterForm"
|
||||
import { AuthCallback } from "./components/auth/AuthCallback"
|
||||
import { ChatInterface } from "./components/chat/ChatInterface"
|
||||
import { AuthService, type UserResponse } from "./services/auth"
|
||||
import { ChatService, type MessageResponse } from "./services/chat"
|
||||
@@ -136,6 +135,9 @@ function App() {
|
||||
setThreadMessages(prev => ({ ...prev, [id]: messages }))
|
||||
}
|
||||
|
||||
const queryParams = new URLSearchParams(window.location.search)
|
||||
const externalError = queryParams.get("error")
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="min-h-screen flex items-center justify-center bg-background">
|
||||
@@ -146,7 +148,6 @@ function App() {
|
||||
|
||||
return (
|
||||
<Routes>
|
||||
<Route path="/auth/callback" element={<AuthCallback />} />
|
||||
<Route
|
||||
path="*"
|
||||
element={
|
||||
@@ -156,6 +157,7 @@ function App() {
|
||||
<LoginForm
|
||||
onSuccess={handleAuthSuccess}
|
||||
onToggleMode={() => setAuthMode("register")}
|
||||
externalError={externalError === "oidc_failed" ? "SSO authentication failed. Please try again." : null}
|
||||
/>
|
||||
) : (
|
||||
<RegisterForm
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
import { useEffect } from "react"
|
||||
import { useNavigate } from "react-router-dom"
|
||||
import { AuthService } from "@/services/auth"
|
||||
|
||||
export function AuthCallback() {
|
||||
const navigate = useNavigate()
|
||||
|
||||
useEffect(() => {
|
||||
const verifyAuth = async () => {
|
||||
const urlParams = new URLSearchParams(window.location.search)
|
||||
const code = urlParams.get("code")
|
||||
|
||||
try {
|
||||
if (code) {
|
||||
// If we have a code, exchange it for a cookie
|
||||
await AuthService.exchangeOIDCCode(code)
|
||||
} else {
|
||||
// If no code, just verify existing cookie (backend-driven redirect)
|
||||
await AuthService.getMe()
|
||||
}
|
||||
|
||||
// Success - go to home. We use window.location.href to ensure a clean reload of App state
|
||||
window.location.href = "/"
|
||||
} catch (err) {
|
||||
console.error("Auth callback verification failed:", err)
|
||||
navigate("/?error=auth_failed", { replace: true })
|
||||
}
|
||||
}
|
||||
|
||||
verifyAuth()
|
||||
}, [navigate])
|
||||
|
||||
return (
|
||||
<div className="min-h-screen flex flex-col items-center justify-center bg-background">
|
||||
<div className="animate-spin rounded-full h-8 w-8 border-b-2 border-primary mb-4"></div>
|
||||
<p className="text-muted-foreground">Completing login...</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -22,10 +22,11 @@ import axios from "axios"
|
||||
interface LoginFormProps {
|
||||
onSuccess: () => void
|
||||
onToggleMode: () => void
|
||||
externalError?: string | null
|
||||
}
|
||||
|
||||
export function LoginForm({ onSuccess, onToggleMode }: LoginFormProps) {
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
export function LoginForm({ onSuccess, onToggleMode, externalError }: LoginFormProps) {
|
||||
const [error, setError] = useState<string | null>(externalError || null)
|
||||
const [isLoading, setIsLoading] = useState(false)
|
||||
|
||||
const form = useForm<LoginInput>({
|
||||
|
||||
@@ -28,11 +28,6 @@ export const AuthService = {
|
||||
}
|
||||
},
|
||||
|
||||
async exchangeOIDCCode(code: string): Promise<AuthResponse> {
|
||||
const response = await api.get<AuthResponse>(`/auth/oidc/callback?code=${code}`)
|
||||
return response.data
|
||||
},
|
||||
|
||||
async register(email: string, password: string): Promise<UserResponse> {
|
||||
const response = await api.post<UserResponse>("/auth/register", {
|
||||
email,
|
||||
|
||||
Reference in New Issue
Block a user