feat(auth): Complete OIDC security refactor and modernize test suite

- Refactored OIDC flow to implement PKCE, state/nonce validation, and BFF pattern.
- Centralized configuration in Settings class (DEV_MODE, FRONTEND_URL, OIDC_REDIRECT_URI).
- Updated auth routers to use conditional secure cookie flags based on DEV_MODE.
- Modernized and cleaned up test suite by removing legacy Streamlit tests.
- Fixed linting errors and unused imports across the backend.
This commit is contained in:
Yunxiao Xu
2026-02-15 02:50:26 -08:00
parent 48ad0ebdd7
commit 68c0985482
50 changed files with 222 additions and 515 deletions

View File

@@ -6,7 +6,8 @@ GOOGLE_API_KEY=your_google_api_key_here
DATA_DIR=data
DATA_STATE=new_jersey
LOG_LEVEL=INFO
DEV_MODE=false
DEV_MODE=true
FRONTEND_URL=http://localhost:5173
# Security & JWT Configuration
SECRET_KEY=change-me-in-production
@@ -28,7 +29,7 @@ HISTORY_DB_URL=postgresql://user:password@localhost:5433/ea_history
OIDC_CLIENT_ID=your_client_id
OIDC_CLIENT_SECRET=your_client_secret
OIDC_SERVER_METADATA_URL=https://your-authentik.example.com/application/o/ea-chatbot/.well-known/openid-configuration
OIDC_REDIRECT_URI=http://localhost:8501
OIDC_REDIRECT_URI=http://localhost:8000/api/v1/auth/oidc/callback
# Node Configuration Overrides (Optional)
# Format: <NODE_NAME>_LLM__<PARAMETER>

View File

@@ -1,4 +1,3 @@
import os
from fastapi import Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordBearer
from ea_chatbot.config import Settings
@@ -18,7 +17,7 @@ if settings.oidc_client_id and settings.oidc_client_secret and settings.oidc_ser
client_id=settings.oidc_client_id,
client_secret=settings.oidc_client_secret,
server_metadata_url=settings.oidc_server_metadata_url,
redirect_uri=os.getenv("OIDC_REDIRECT_URI", "http://localhost:3000/auth/callback")
redirect_uri=settings.oidc_redirect_uri
)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth/login", auto_error=False)

View File

@@ -1,9 +1,6 @@
from fastapi import FastAPI, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from ea_chatbot.api.routers import auth, history, artifacts, agent
from dotenv import load_dotenv
load_dotenv()
app = FastAPI(
title="Election Analytics Chatbot API",

View File

@@ -1,7 +1,6 @@
import json
import asyncio
from typing import AsyncGenerator, Optional, List
from fastapi import APIRouter, Depends, HTTPException, status
from typing import AsyncGenerator, List
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from ea_chatbot.api.dependencies import get_current_user, history_manager
from ea_chatbot.api.utils import convert_to_json_compatible

View File

@@ -1,7 +1,6 @@
from fastapi import APIRouter, Depends, HTTPException, Response, status
from ea_chatbot.api.dependencies import get_current_user, history_manager
from ea_chatbot.history.models import Plot, Message, User as UserDB
import io
router = APIRouter(prefix="/artifacts", tags=["artifacts"])

View File

@@ -1,19 +1,17 @@
from fastapi import APIRouter, Depends, HTTPException, status, Response, Request
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from ea_chatbot.api.utils import create_access_token
from ea_chatbot.api.utils import create_access_token, settings
from ea_chatbot.api.dependencies import history_manager, oidc_client, get_current_user
from ea_chatbot.history.models import User as UserDB
from ea_chatbot.api.schemas import Token, UserCreate, UserResponse
import os
from ea_chatbot.auth import OIDCSession
import logging
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/auth", tags=["auth"])
FRONTEND_URL = os.getenv("FRONTEND_URL", "http://localhost:5173")
def set_auth_cookie(response: Response, token: str):
response.set_cookie(
key="access_token",
@@ -22,7 +20,7 @@ def set_auth_cookie(response: Response, token: str):
max_age=1800,
expires=1800,
samesite="lax",
secure=False, # Set to True in production with HTTPS
secure=not settings.dev_mode,
)
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
@@ -72,59 +70,84 @@ async def logout(response: Response):
return {"detail": "Successfully logged out"}
@router.get("/oidc/login")
async def oidc_login():
"""Get the OIDC authorization URL."""
async def oidc_login(response: Response):
"""Get the OIDC authorization URL and set temporary session cookie."""
if not oidc_client:
raise HTTPException(
status_code=status.HTTP_510_NOT_EXTENDED,
detail="OIDC is not configured"
)
url = oidc_client.get_login_url()
return {"url": url}
auth_data = oidc_client.get_auth_data()
# Store state, nonce, and code_verifier in a secure short-lived cookie
session_token = OIDCSession.encrypt(
{
"state": auth_data["state"],
"nonce": auth_data["nonce"],
"code_verifier": auth_data["code_verifier"]
},
settings.secret_key
)
response.set_cookie(
key="oidc_session",
value=session_token,
httponly=True,
max_age=300, # 5 minutes
samesite="lax",
secure=not settings.dev_mode
)
return {"url": auth_data["url"]}
@router.get("/oidc/callback", response_model=Token)
async def oidc_callback(request: Request, response: Response, code: str):
"""Handle the OIDC callback, issue a JWT, and redirect or return JSON."""
@router.get("/oidc/callback")
async def oidc_callback(request: Request, response: Response, code: str, state: str):
"""Handle the OIDC callback, validate state/nonce, issue a JWT, and redirect."""
if not oidc_client:
raise HTTPException(status_code=status.HTTP_510_NOT_EXTENDED, detail="OIDC not configured")
# 1. Validate state and retrieve session data from cookie
session_token = request.cookies.get("oidc_session")
if not session_token:
logger.error("OIDC session cookie missing")
return RedirectResponse(url=f"{settings.frontend_url}?error=oidc_failed", status_code=status.HTTP_302_FOUND)
session_data = OIDCSession.decrypt(session_token, settings.secret_key)
if not session_data or session_data.get("state") != state:
logger.error("OIDC state mismatch or session expired")
return RedirectResponse(url=f"{settings.frontend_url}?error=oidc_failed", status_code=status.HTTP_302_FOUND)
try:
token = oidc_client.exchange_code_for_token(code)
user_info = oidc_client.get_user_info(token)
email = user_info.get("email")
name = user_info.get("name") or user_info.get("preferred_username")
# 2. Exchange code for token using PKCE code_verifier
token = oidc_client.exchange_code_for_token(code, code_verifier=session_data.get("code_verifier"))
# 3. Validate ID Token and Nonce
id_token = token.get("id_token")
if not id_token:
raise ValueError("ID Token missing from response")
claims = oidc_client.validate_id_token(id_token, nonce=session_data.get("nonce"))
email = claims.get("email")
name = claims.get("name") or claims.get("preferred_username")
if not email:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email not provided by OIDC")
# 4. Sync user and establish session
user = history_manager.sync_user_from_oidc(email=email, display_name=name)
access_token = create_access_token(data={"sub": str(user.id)})
# Determine if we should redirect (direct provider callback) or return JSON (frontend exchange)
is_ajax = request.headers.get("X-Requested-With") == "XMLHttpRequest" or \
"application/json" in request.headers.get("Accept", "")
if is_ajax:
set_auth_cookie(response, access_token)
return {"access_token": access_token, "token_type": "bearer"}
else:
redirect_response = RedirectResponse(url=f"{FRONTEND_URL}/auth/callback")
set_auth_cookie(redirect_response, access_token)
return redirect_response
redirect_response = RedirectResponse(url=settings.frontend_url, status_code=status.HTTP_302_FOUND)
set_auth_cookie(redirect_response, access_token)
redirect_response.delete_cookie("oidc_session") # Cleanup
return redirect_response
except HTTPException:
raise
except Exception as e:
except Exception:
logger.exception("OIDC authentication failed")
# For non-ajax, redirect with error
is_ajax = request.headers.get("X-Requested-With") == "XMLHttpRequest" or \
"application/json" in request.headers.get("Accept", "")
if is_ajax:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="OIDC authentication failed")
else:
return RedirectResponse(url=f"{FRONTEND_URL}?error=oidc_failed")
return RedirectResponse(url=f"{settings.frontend_url}?error=oidc_failed", status_code=status.HTTP_302_FOUND)
@router.get("/me", response_model=UserResponse)
async def get_me(current_user: UserDB = Depends(get_current_user)):

View File

@@ -1,5 +1,5 @@
from datetime import datetime, timedelta, timezone
from typing import Optional, Union, Any, List, Dict
from typing import Optional, Any
from jose import JWTError, jwt
from pydantic import BaseModel
from langchain_core.messages import BaseMessage

View File

@@ -1,5 +1,4 @@
import streamlit as st
import asyncio
import os
import io
from dotenv import load_dotenv
@@ -348,7 +347,7 @@ def main():
if node_name == "query_analyzer":
analysis = state_update.get("analysis", {})
next_action = state_update.get("next_action", "unknown")
status.write(f"🔍 **Analyzed Query:**")
status.write("🔍 **Analyzed Query:**")
for k,v in analysis.items():
status.write(f"- {k:<8}: {v}")
status.markdown(f"Next Step: {next_action.capitalize()}")

View File

@@ -19,6 +19,8 @@ class Settings(BaseSettings):
data_dir: str = "data"
data_state: str = "new_jersey"
log_level: str = Field(default="INFO", alias="LOG_LEVEL")
dev_mode: bool = Field(default=True, alias="DEV_MODE")
frontend_url: str = Field(default="http://localhost:5173", alias="FRONTEND_URL")
# Voter Database configuration
db_host: str = Field(default="localhost", alias="DB_HOST")
@@ -40,6 +42,7 @@ class Settings(BaseSettings):
oidc_client_id: Optional[str] = Field(default=None, alias="OIDC_CLIENT_ID")
oidc_client_secret: Optional[str] = Field(default=None, alias="OIDC_CLIENT_SECRET")
oidc_server_metadata_url: Optional[str] = Field(default=None, alias="OIDC_SERVER_METADATA_URL")
oidc_redirect_uri: str = Field(default="http://localhost:8000/api/v1/auth/oidc/callback", alias="OIDC_REDIRECT_URI")
# Default configurations for each node
query_analyzer_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))

View File

@@ -1,4 +1,3 @@
from langchain_core.messages import AIMessage
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model

View File

@@ -1,7 +1,7 @@
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils import helpers, database_inspection
from ea_chatbot.utils import database_inspection
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT
from ea_chatbot.schemas import CodeGenerationResponse

View File

@@ -2,7 +2,7 @@ import io
import sys
import traceback
from contextlib import redirect_stdout
from typing import Any, Dict, List, TYPE_CHECKING
from typing import TYPE_CHECKING
import pandas as pd
from matplotlib.figure import Figure

View File

@@ -1,4 +1,3 @@
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from ea_chatbot.graph.state import AgentState

View File

@@ -1,4 +1,3 @@
from langchain_core.messages import AIMessage
from ea_chatbot.graph.state import AgentState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model

View File

@@ -1,4 +1,4 @@
from typing import TypedDict, Annotated, List, Dict, Any, Optional
from typing import Annotated, List, Dict, Any, Optional
from langchain_core.messages import BaseMessage
from langchain.agents import AgentState as AS
import operator

View File

@@ -1,7 +1,7 @@
from contextlib import contextmanager
from typing import Optional, List
from sqlalchemy import create_engine, select, delete
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy import create_engine, select
from sqlalchemy.orm import sessionmaker
from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError

View File

@@ -1,6 +1,5 @@
from typing import Optional, Dict, Any, List, TYPE_CHECKING
from typing import Optional, Dict, Any, TYPE_CHECKING
import yaml
import json
import os
from ea_chatbot.utils.db_client import DBClient
if TYPE_CHECKING:

View File

@@ -20,7 +20,8 @@ def to_yaml(json_str: str, indent: int = 2) -> str:
"""
Attempts to convert a JSON string (potentially malformed from LLM) to a YAML string.
"""
if not json_str: return ""
if not json_str:
return ""
try:
# Try direct parse

View File

@@ -1,4 +1,4 @@
from typing import Optional, cast, TYPE_CHECKING, Literal, Dict, List, Tuple, Any
from typing import Optional, List
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI

View File

@@ -1,6 +1,6 @@
import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch
from unittest.mock import patch
from ea_chatbot.api.main import app
from ea_chatbot.history.models import User
@@ -66,31 +66,43 @@ def test_protected_route_without_token():
assert response.status_code == 401
def test_oidc_login_redirect():
"""Test that OIDC login returns a redirect URL."""
"""Test that OIDC login returns a redirect URL and sets session cookie."""
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc:
mock_oidc.get_login_url.return_value = "https://oidc-provider.com/auth"
mock_oidc.get_auth_data.return_value = {
"url": "https://oidc-provider.com/auth",
"state": "test_state",
"nonce": "test_nonce",
"code_verifier": "test_verifier"
}
response = client.get("/api/v1/auth/oidc/login")
assert response.status_code == 200
assert response.json()["url"] == "https://oidc-provider.com/auth"
assert "oidc_session" in response.cookies
def test_oidc_callback_success_ajax():
"""Test successful OIDC callback and JWT issuance via AJAX."""
def test_oidc_callback_success():
"""Test successful OIDC callback and JWT issuance."""
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc, \
patch("ea_chatbot.api.routers.auth.OIDCSession.decrypt") as mock_decrypt, \
patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm:
mock_oidc.exchange_code_for_token.return_value = {"access_token": "oidc-token"}
mock_oidc.get_user_info.return_value = {"email": "sso@example.com", "name": "SSO User"}
mock_decrypt.return_value = {
"state": "test_state",
"nonce": "test_nonce",
"code_verifier": "test_verifier"
}
mock_oidc.exchange_code_for_token.return_value = {"id_token": "fake_id_token"}
mock_oidc.validate_id_token.return_value = {"email": "sso@example.com", "name": "SSO User"}
mock_hm.sync_user_from_oidc.return_value = User(id="sso-123", username="sso@example.com", display_name="SSO User")
client.cookies.set("oidc_session", "fake_token")
response = client.get(
"/api/v1/auth/oidc/callback?code=some-code",
headers={"Accept": "application/json"}
"/api/v1/auth/oidc/callback?code=some-code&state=test_state",
follow_redirects=False
)
assert response.status_code == 200
assert "access_token" in response.json()
assert response.json()["token_type"] == "bearer"
assert response.status_code == 302
assert "access_token" in response.cookies
def test_get_me_success():
"""Test getting current user with a valid token."""

View File

@@ -1,6 +1,6 @@
import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch
from unittest.mock import patch
from ea_chatbot.api.main import app
from ea_chatbot.history.models import User
from ea_chatbot.api.utils import create_access_token
@@ -83,15 +83,25 @@ def test_logout_clears_cookie():
def test_oidc_callback_redirects_with_cookie():
"""Test that OIDC callback sets cookie and redirects."""
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc, \
patch("ea_chatbot.api.routers.auth.OIDCSession.decrypt") as mock_decrypt, \
patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm:
mock_oidc.exchange_code_for_token.return_value = {"access_token": "oidc-token"}
mock_oidc.get_user_info.return_value = {"email": "sso@example.com", "name": "SSO User"}
mock_decrypt.return_value = {
"state": "test_state",
"nonce": "test_nonce",
"code_verifier": "test_verifier"
}
mock_oidc.exchange_code_for_token.return_value = {"id_token": "fake_id_token"}
mock_oidc.validate_id_token.return_value = {"email": "sso@example.com", "name": "SSO User"}
mock_hm.sync_user_from_oidc.return_value = User(id="sso-123", username="sso@example.com", display_name="SSO User")
# Follow_redirects=False to catch the 307/302
response = client.get("/api/v1/auth/oidc/callback?code=some-code", follow_redirects=False)
# Set the session cookie
client.cookies.set("oidc_session", "fake_token")
assert response.status_code in [302, 303, 307]
# Follow_redirects=False to catch the 302
response = client.get("/api/v1/auth/oidc/callback?code=some-code&state=test_state", follow_redirects=False)
assert response.status_code == 302
assert "access_token" in response.cookies
assert "/auth/callback" in response.headers["location"]
# Should redirect to FRONTEND_URL (default http://localhost:5173)
assert "http://localhost:5173" in response.headers["location"]

View File

@@ -0,0 +1,82 @@
import pytest
from fastapi.testclient import TestClient
from unittest.mock import patch
from ea_chatbot.api.main import app
from ea_chatbot.history.models import User
@pytest.fixture
def client():
"""Provides a fresh TestClient for each test."""
return TestClient(app)
@pytest.fixture
def mock_auth_data():
return {
"url": "https://example.com/auth?state=test_state",
"state": "test_state",
"nonce": "test_nonce",
"code_verifier": "test_verifier"
}
def test_oidc_login_sets_cookie(client, mock_auth_data):
"""Test that OIDC login initiation sets the temporary session cookie."""
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc:
mock_oidc.get_auth_data.return_value = mock_auth_data
response = client.get("/api/v1/auth/oidc/login")
assert response.status_code == 200
assert response.json()["url"] == mock_auth_data["url"]
assert "oidc_session" in response.cookies
def test_oidc_callback_missing_cookie(client):
"""Test that OIDC callback fails if the temporary session cookie is missing."""
with patch("ea_chatbot.api.routers.auth.oidc_client"):
# Ensure no cookies are set
client.cookies.clear()
response = client.get("/api/v1/auth/oidc/callback?code=test_code&state=test_state", follow_redirects=False)
# Should redirect to frontend with error
assert response.status_code == 302
assert "error=oidc_failed" in response.headers["location"]
def test_oidc_callback_invalid_state(client, mock_auth_data):
"""Test that OIDC callback fails if the state in the URL doesn't match the cookie."""
with patch("ea_chatbot.api.routers.auth.oidc_client"), \
patch("ea_chatbot.api.routers.auth.OIDCSession.decrypt") as mock_decrypt:
# Mock valid cookie content
mock_decrypt.return_value = mock_auth_data
# Send different state in URL
client.cookies.set("oidc_session", "fake_token")
response = client.get("/api/v1/auth/oidc/callback?code=test_code&state=wrong_state", follow_redirects=False)
assert response.status_code == 302
assert "error=oidc_failed" in response.headers["location"]
def test_oidc_callback_success(client, mock_auth_data):
"""Test successful OIDC callback and session establishment."""
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc, \
patch("ea_chatbot.api.routers.auth.OIDCSession.decrypt") as mock_decrypt, \
patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm, \
patch("ea_chatbot.api.routers.auth.create_access_token") as mock_create_token:
mock_decrypt.return_value = mock_auth_data
mock_oidc.exchange_code_for_token.return_value = {"id_token": "fake_id_token"}
mock_oidc.validate_id_token.return_value = {"email": "user@test.com", "name": "Test User"}
mock_hm.sync_user_from_oidc.return_value = User(id="user-123", username="user@test.com")
mock_create_token.return_value = "fake_access_token"
client.cookies.set("oidc_session", "fake_token")
response = client.get("/api/v1/auth/oidc/callback?code=test_code&state=test_state", follow_redirects=False)
assert response.status_code == 302
# Redirect to FRONTEND_URL (default localhost:5173)
assert "http://localhost:5173" in response.headers["location"]
assert "access_token" in response.cookies
# Verify oidc_session was cleaned up (deleted)
# In RedirectResponse, delete_cookie works by setting the cookie with empty value and past expiry
cookie_header = response.headers.get("set-cookie", "")
assert "oidc_session=;" in cookie_header or 'oidc_session=""' in cookie_header

View File

@@ -3,10 +3,8 @@ from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch, AsyncMock
from ea_chatbot.api.main import app
from ea_chatbot.api.dependencies import get_current_user
from ea_chatbot.history.models import User, Conversation, Message, Plot
from ea_chatbot.history.models import User, Conversation
from ea_chatbot.api.utils import create_access_token
from datetime import datetime, timezone
import json
client = TestClient(app)

View File

@@ -30,7 +30,6 @@ async def test_get_checkpointer_initialization():
checkpoint.get_pool = mock_get_pool
# Patch AsyncPostgresSaver where it's imported in checkpoint.py
import langgraph.checkpoint.postgres.aio as pg_aio
original_saver = checkpoint.AsyncPostgresSaver
checkpoint.AsyncPostgresSaver = mock_saver_class

View File

@@ -1,90 +0,0 @@
import os
import sys
import pytest
from unittest.mock import MagicMock, patch
from streamlit.testing.v1 import AppTest
from langchain_core.messages import AIMessage
# Ensure src is in python path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
@pytest.fixture(autouse=True)
def mock_history_manager():
"""Globally mock HistoryManager to avoid DB calls during AppTest."""
with patch("ea_chatbot.history.manager.HistoryManager") as mock_cls:
instance = mock_cls.return_value
instance.create_conversation.return_value = MagicMock(id="conv_123")
instance.get_conversations.return_value = []
instance.get_messages.return_value = []
instance.add_message.return_value = MagicMock()
instance.update_conversation_summary.return_value = MagicMock()
yield instance
@pytest.fixture
def mock_app_stream():
with patch("ea_chatbot.graph.workflow.app.stream") as mock_stream:
# Mock events from app.stream
mock_stream.return_value = [
{"query_analyzer": {"next_action": "research"}},
{"researcher": {"messages": [AIMessage(content="Research result")]}}
]
yield mock_stream
@pytest.fixture
def mock_user():
user = MagicMock()
user.id = "test_id"
user.username = "test@example.com"
user.display_name = "Test User"
return user
def test_app_initial_state(mock_app_stream, mock_user):
"""Test that the app initializes with the correct title and empty history."""
at = AppTest.from_file("src/ea_chatbot/app.py")
# Simulate logged-in user
at.session_state["user"] = mock_user
at.run()
assert not at.exception
assert at.title[0].value == "🗳️ Election Analytics Chatbot"
# Check session state initialization
assert "messages" in at.session_state
assert len(at.session_state["messages"]) == 0
def test_app_dev_mode_toggle(mock_app_stream, mock_user):
"""Test that the dev mode toggle exists in the sidebar."""
with patch.dict(os.environ, {"DEV_MODE": "false"}):
at = AppTest.from_file("src/ea_chatbot/app.py")
at.session_state["user"] = mock_user
at.run()
# Check for sidebar toggle (checkbox)
assert len(at.sidebar.checkbox) > 0
dev_mode_toggle = at.sidebar.checkbox[0]
assert dev_mode_toggle.label == "Dev Mode"
assert dev_mode_toggle.value is False
def test_app_graph_execution_streaming(mock_app_stream, mock_user, mock_history_manager):
"""Test that entering a prompt triggers the graph stream and displays response."""
at = AppTest.from_file("src/ea_chatbot/app.py")
at.session_state["user"] = mock_user
at.run()
# Input a question
at.chat_input[0].set_value("Test question").run()
# Verify graph stream was called
assert mock_app_stream.called
# Message should be added to history
assert len(at.session_state["messages"]) == 2
assert at.session_state["messages"][0]["role"] == "user"
assert at.session_state["messages"][1]["role"] == "assistant"
assert "Research result" in at.session_state["messages"][1]["content"]
# Verify history manager was used
assert mock_history_manager.create_conversation.called
assert mock_history_manager.add_message.called

View File

@@ -1,83 +0,0 @@
import pytest
from unittest.mock import MagicMock, patch
from streamlit.testing.v1 import AppTest
from ea_chatbot.auth import AuthType
@pytest.fixture
def mock_history_manager_instance():
# We need to patch before the AppTest loads the module
with patch("ea_chatbot.history.manager.HistoryManager") as mock_cls:
instance = mock_cls.return_value
yield instance
def test_auth_ui_flow_step1_to_password(mock_history_manager_instance):
"""Test UI transition from Step 1 (email) to Step 2a (password) for LOCAL user."""
# Patch BEFORE creating AppTest
mock_user = MagicMock()
mock_user.password_hash = "hashed_password"
mock_history_manager_instance.get_user.return_value = mock_user
at = AppTest.from_file("src/ea_chatbot/app.py")
at.run()
# Step 1: Identification
assert at.session_state["login_step"] == "email"
at.text_input[0].set_value("local@example.com")
at.button[0].click().run()
# Verify transition to password step
assert at.session_state["login_step"] == "login_password"
assert at.session_state["login_email"] == "local@example.com"
assert "Welcome back" in at.info[0].value
def test_auth_ui_flow_step1_to_register(mock_history_manager_instance):
"""Test UI transition from Step 1 (email) to Step 2b (registration) for NEW user."""
mock_history_manager_instance.get_user.return_value = None
at = AppTest.from_file("src/ea_chatbot/app.py")
at.run()
# Step 1: Identification
at.text_input[0].set_value("new@example.com")
at.button[0].click().run()
# Verify transition to registration step
assert at.session_state["login_step"] == "register_details"
assert at.session_state["login_email"] == "new@example.com"
assert "Create an account" in at.info[0].value
def test_auth_ui_flow_step1_to_oidc(mock_history_manager_instance):
"""Test UI transition from Step 1 (email) to Step 2c (OIDC) for OIDC user."""
# Mock history_manager.get_user to return a user WITHOUT a password
mock_user = MagicMock()
mock_user.password_hash = None
mock_history_manager_instance.get_user.return_value = mock_user
at = AppTest.from_file("src/ea_chatbot/app.py")
at.run()
# Step 1: Identification
at.text_input[0].set_value("oidc@example.com")
at.button[0].click().run()
# Verify transition to OIDC step
assert at.session_state["login_step"] == "oidc_login"
assert at.session_state["login_email"] == "oidc@example.com"
assert "configured for Single Sign-On" in at.info[0].value
def test_auth_ui_flow_back_button(mock_history_manager_instance):
"""Test that the 'Back' button returns to Step 1."""
at = AppTest.from_file("src/ea_chatbot/app.py")
# Simulate being on Step 2a
at.session_state["login_step"] = "login_password"
at.session_state["login_email"] = "local@example.com"
at.run()
# Click Back (index 1 in Step 2a)
at.button[1].click().run()
# Verify return to email step
assert at.session_state["login_step"] == "email"

View File

@@ -1,43 +0,0 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.auth import OIDCClient
@patch("ea_chatbot.auth.OAuth2Session")
def test_oidc_client_initialization(mock_oauth):
client = OIDCClient(
client_id="test_id",
client_secret="test_secret",
server_metadata_url="https://test.server/.well-known/openid-configuration"
)
assert client.oauth_session is not None
@patch("ea_chatbot.auth.requests")
@patch("ea_chatbot.auth.OAuth2Session")
def test_get_login_url(mock_oauth_cls, mock_requests):
# Setup mock session
mock_session = MagicMock()
mock_oauth_cls.return_value = mock_session
# Mock metadata response
mock_response = MagicMock()
mock_response.json.return_value = {
"authorization_endpoint": "https://test.server/auth",
"token_endpoint": "https://test.server/token",
"userinfo_endpoint": "https://test.server/userinfo"
}
mock_requests.get.return_value = mock_response
# Mock authorization url generation
mock_session.create_authorization_url.return_value = ("https://test.server/auth?response_type=code", "state")
client = OIDCClient(
client_id="test_id",
client_secret="test_secret",
server_metadata_url="https://test.server/.well-known/openid-configuration"
)
url = client.get_login_url()
assert url == "https://test.server/auth?response_type=code"
# Verify metadata was fetched via requests
mock_requests.get.assert_called_with("https://test.server/.well-known/openid-configuration")

View File

@@ -1,49 +0,0 @@
import pytest
from unittest.mock import MagicMock
from ea_chatbot.history.manager import HistoryManager
from ea_chatbot.auth import get_user_auth_type, AuthType
# Mocks
@pytest.fixture
def mock_history_manager():
return MagicMock(spec=HistoryManager)
def test_auth_flow_existing_local_user(mock_history_manager):
"""Test that an existing user with a password returns LOCAL auth type."""
# Setup
mock_user = MagicMock()
mock_user.password_hash = "hashed_secret"
mock_history_manager.get_user.return_value = mock_user
# Execute
auth_type = get_user_auth_type("test@example.com", mock_history_manager)
# Verify
assert auth_type == AuthType.LOCAL
mock_history_manager.get_user.assert_called_once_with("test@example.com")
def test_auth_flow_existing_oidc_user(mock_history_manager):
"""Test that an existing user WITHOUT a password returns OIDC auth type."""
# Setup
mock_user = MagicMock()
mock_user.password_hash = None # No password implies OIDC
mock_history_manager.get_user.return_value = mock_user
# Execute
auth_type = get_user_auth_type("sso@example.com", mock_history_manager)
# Verify
assert auth_type == AuthType.OIDC
mock_history_manager.get_user.assert_called_once_with("sso@example.com")
def test_auth_flow_new_user(mock_history_manager):
"""Test that a non-existent user returns NEW auth type."""
# Setup
mock_history_manager.get_user.return_value = None
# Execute
auth_type = get_user_auth_type("new@example.com", mock_history_manager)
# Verify
assert auth_type == AuthType.NEW
mock_history_manager.get_user.assert_called_once_with("new@example.com")

View File

@@ -1,5 +1,3 @@
import pytest
from pydantic import ValidationError
from ea_chatbot.config import Settings, LLMConfig
def test_default_settings():

View File

@@ -1,7 +1,7 @@
import pytest
import pandas as pd
import os
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
from ea_chatbot.utils.database_inspection import get_primary_key, inspect_db_table, get_data_summary
@pytest.fixture

View File

@@ -3,7 +3,6 @@ import pandas as pd
from unittest.mock import MagicMock, patch
from matplotlib.figure import Figure
from ea_chatbot.graph.nodes.executor import executor_node
from ea_chatbot.graph.state import AgentState
@pytest.fixture
def mock_settings():

View File

@@ -1,4 +1,3 @@
import pytest
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.utils.helpers import merge_agent_state

View File

@@ -1,6 +1,4 @@
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, DeclarativeBase
# We anticipate these imports will fail initially
try:

View File

@@ -1,4 +1,3 @@
import pytest
from langchain_openai import ChatOpenAI
from ea_chatbot.config import LLMConfig
from ea_chatbot.utils.llm_factory import get_llm_model

View File

@@ -1,8 +1,7 @@
import os
import json
import pytest
import logging
from unittest.mock import MagicMock, patch
from unittest.mock import patch
from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.state import AgentState
from ea_chatbot.utils.logging import get_logger

View File

@@ -2,7 +2,6 @@ import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
from ea_chatbot.graph.state import AgentState
@pytest.fixture
def mock_state_with_history():

View File

@@ -1,87 +0,0 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.auth import OIDCClient
@pytest.fixture
def oidc_config():
return {
"client_id": "test_id",
"client_secret": "test_secret",
"server_metadata_url": "https://example.com/.well-known/openid-configuration",
"redirect_uri": "http://localhost:8501"
}
@pytest.fixture
def mock_metadata():
return {
"authorization_endpoint": "https://example.com/auth",
"token_endpoint": "https://example.com/token",
"userinfo_endpoint": "https://example.com/userinfo"
}
def test_oidc_fetch_metadata(oidc_config, mock_metadata):
client = OIDCClient(**oidc_config)
with patch("requests.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = mock_metadata
mock_get.return_value = mock_response
metadata = client.fetch_metadata()
assert metadata == mock_metadata
mock_get.assert_called_once_with(oidc_config["server_metadata_url"])
# Second call should use cache
client.fetch_metadata()
assert mock_get.call_count == 1
def test_oidc_get_login_url(oidc_config, mock_metadata):
client = OIDCClient(**oidc_config)
client.metadata = mock_metadata
with patch.object(client.oauth_session, "create_authorization_url") as mock_create_url:
mock_create_url.return_value = ("https://example.com/auth?state=xyz", "xyz")
url = client.get_login_url()
assert url == "https://example.com/auth?state=xyz"
mock_create_url.assert_called_once_with(mock_metadata["authorization_endpoint"])
def test_oidc_get_login_url_missing_endpoint(oidc_config):
client = OIDCClient(**oidc_config)
client.metadata = {"some_other": "field"}
with pytest.raises(ValueError, match="authorization_endpoint not found"):
client.get_login_url()
def test_oidc_exchange_code_for_token(oidc_config, mock_metadata):
client = OIDCClient(**oidc_config)
client.metadata = mock_metadata
with patch.object(client.oauth_session, "fetch_token") as mock_fetch_token:
mock_fetch_token.return_value = {"access_token": "abc"}
token = client.exchange_code_for_token("test_code")
assert token == {"access_token": "abc"}
mock_fetch_token.assert_called_once_with(
mock_metadata["token_endpoint"],
code="test_code",
client_secret=oidc_config["client_secret"]
)
def test_oidc_get_user_info(oidc_config, mock_metadata):
client = OIDCClient(**oidc_config)
client.metadata = mock_metadata
with patch.object(client.oauth_session, "get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = {"sub": "user123"}
mock_get.return_value = mock_response
user_info = client.get_user_info({"access_token": "abc"})
assert user_info == {"sub": "user123"}
assert client.oauth_session.token == {"access_token": "abc"}
mock_get.assert_called_once_with(mock_metadata["userinfo_endpoint"])

View File

@@ -1,7 +1,6 @@
import pytest
from unittest.mock import MagicMock, patch
from unittest.mock import patch
from ea_chatbot.auth import OIDCClient
from jose import jwt
@pytest.fixture
def oidc_config():

View File

@@ -1,7 +1,6 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
from ea_chatbot.graph.state import AgentState
@pytest.fixture
def mock_state():

View File

@@ -1,5 +1,4 @@
import pytest
import logging
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis

View File

@@ -2,7 +2,6 @@ import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
from ea_chatbot.graph.state import AgentState
@pytest.fixture
def base_state():

View File

@@ -1,8 +1,5 @@
import pytest
from typing import get_type_hints, List
from langchain_core.messages import BaseMessage, HumanMessage
from typing import get_type_hints
from ea_chatbot.graph.state import AgentState
import operator
def test_agent_state_structure():
"""Verify that AgentState has the required fields and types."""

View File

@@ -2,7 +2,6 @@ import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from ea_chatbot.graph.nodes.summarizer import summarizer_node
from ea_chatbot.graph.state import AgentState
@pytest.fixture
def mock_llm():

View File

@@ -1,4 +1,3 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis

View File

@@ -1,5 +1,4 @@
import pytest
import yaml
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from ea_chatbot.graph.workflow import app