feat(auth): Complete OIDC security refactor and modernize test suite
- Refactored OIDC flow to implement PKCE, state/nonce validation, and BFF pattern. - Centralized configuration in Settings class (DEV_MODE, FRONTEND_URL, OIDC_REDIRECT_URI). - Updated auth routers to use conditional secure cookie flags based on DEV_MODE. - Modernized and cleaned up test suite by removing legacy Streamlit tests. - Fixed linting errors and unused imports across the backend.
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
from ea_chatbot.api.main import app
|
||||
from ea_chatbot.history.models import User
|
||||
|
||||
@@ -66,31 +66,43 @@ def test_protected_route_without_token():
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_oidc_login_redirect():
|
||||
"""Test that OIDC login returns a redirect URL."""
|
||||
"""Test that OIDC login returns a redirect URL and sets session cookie."""
|
||||
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc:
|
||||
mock_oidc.get_login_url.return_value = "https://oidc-provider.com/auth"
|
||||
mock_oidc.get_auth_data.return_value = {
|
||||
"url": "https://oidc-provider.com/auth",
|
||||
"state": "test_state",
|
||||
"nonce": "test_nonce",
|
||||
"code_verifier": "test_verifier"
|
||||
}
|
||||
|
||||
response = client.get("/api/v1/auth/oidc/login")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["url"] == "https://oidc-provider.com/auth"
|
||||
assert "oidc_session" in response.cookies
|
||||
|
||||
def test_oidc_callback_success_ajax():
|
||||
"""Test successful OIDC callback and JWT issuance via AJAX."""
|
||||
def test_oidc_callback_success():
|
||||
"""Test successful OIDC callback and JWT issuance."""
|
||||
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc, \
|
||||
patch("ea_chatbot.api.routers.auth.OIDCSession.decrypt") as mock_decrypt, \
|
||||
patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm:
|
||||
|
||||
mock_oidc.exchange_code_for_token.return_value = {"access_token": "oidc-token"}
|
||||
mock_oidc.get_user_info.return_value = {"email": "sso@example.com", "name": "SSO User"}
|
||||
mock_decrypt.return_value = {
|
||||
"state": "test_state",
|
||||
"nonce": "test_nonce",
|
||||
"code_verifier": "test_verifier"
|
||||
}
|
||||
mock_oidc.exchange_code_for_token.return_value = {"id_token": "fake_id_token"}
|
||||
mock_oidc.validate_id_token.return_value = {"email": "sso@example.com", "name": "SSO User"}
|
||||
mock_hm.sync_user_from_oidc.return_value = User(id="sso-123", username="sso@example.com", display_name="SSO User")
|
||||
|
||||
client.cookies.set("oidc_session", "fake_token")
|
||||
response = client.get(
|
||||
"/api/v1/auth/oidc/callback?code=some-code",
|
||||
headers={"Accept": "application/json"}
|
||||
"/api/v1/auth/oidc/callback?code=some-code&state=test_state",
|
||||
follow_redirects=False
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "access_token" in response.json()
|
||||
assert response.json()["token_type"] == "bearer"
|
||||
assert response.status_code == 302
|
||||
assert "access_token" in response.cookies
|
||||
|
||||
def test_get_me_success():
|
||||
"""Test getting current user with a valid token."""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
from ea_chatbot.api.main import app
|
||||
from ea_chatbot.history.models import User
|
||||
from ea_chatbot.api.utils import create_access_token
|
||||
@@ -83,15 +83,25 @@ def test_logout_clears_cookie():
|
||||
def test_oidc_callback_redirects_with_cookie():
|
||||
"""Test that OIDC callback sets cookie and redirects."""
|
||||
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc, \
|
||||
patch("ea_chatbot.api.routers.auth.OIDCSession.decrypt") as mock_decrypt, \
|
||||
patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm:
|
||||
|
||||
mock_oidc.exchange_code_for_token.return_value = {"access_token": "oidc-token"}
|
||||
mock_oidc.get_user_info.return_value = {"email": "sso@example.com", "name": "SSO User"}
|
||||
mock_decrypt.return_value = {
|
||||
"state": "test_state",
|
||||
"nonce": "test_nonce",
|
||||
"code_verifier": "test_verifier"
|
||||
}
|
||||
mock_oidc.exchange_code_for_token.return_value = {"id_token": "fake_id_token"}
|
||||
mock_oidc.validate_id_token.return_value = {"email": "sso@example.com", "name": "SSO User"}
|
||||
mock_hm.sync_user_from_oidc.return_value = User(id="sso-123", username="sso@example.com", display_name="SSO User")
|
||||
|
||||
# Follow_redirects=False to catch the 307/302
|
||||
response = client.get("/api/v1/auth/oidc/callback?code=some-code", follow_redirects=False)
|
||||
# Set the session cookie
|
||||
client.cookies.set("oidc_session", "fake_token")
|
||||
|
||||
assert response.status_code in [302, 303, 307]
|
||||
# Follow_redirects=False to catch the 302
|
||||
response = client.get("/api/v1/auth/oidc/callback?code=some-code&state=test_state", follow_redirects=False)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert "access_token" in response.cookies
|
||||
assert "/auth/callback" in response.headers["location"]
|
||||
# Should redirect to FRONTEND_URL (default http://localhost:5173)
|
||||
assert "http://localhost:5173" in response.headers["location"]
|
||||
|
||||
82
backend/tests/api/test_oidc_flow.py
Normal file
82
backend/tests/api/test_oidc_flow.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import patch
|
||||
from ea_chatbot.api.main import app
|
||||
from ea_chatbot.history.models import User
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Provides a fresh TestClient for each test."""
|
||||
return TestClient(app)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_data():
|
||||
return {
|
||||
"url": "https://example.com/auth?state=test_state",
|
||||
"state": "test_state",
|
||||
"nonce": "test_nonce",
|
||||
"code_verifier": "test_verifier"
|
||||
}
|
||||
|
||||
def test_oidc_login_sets_cookie(client, mock_auth_data):
|
||||
"""Test that OIDC login initiation sets the temporary session cookie."""
|
||||
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc:
|
||||
mock_oidc.get_auth_data.return_value = mock_auth_data
|
||||
|
||||
response = client.get("/api/v1/auth/oidc/login")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["url"] == mock_auth_data["url"]
|
||||
assert "oidc_session" in response.cookies
|
||||
|
||||
def test_oidc_callback_missing_cookie(client):
|
||||
"""Test that OIDC callback fails if the temporary session cookie is missing."""
|
||||
with patch("ea_chatbot.api.routers.auth.oidc_client"):
|
||||
# Ensure no cookies are set
|
||||
client.cookies.clear()
|
||||
response = client.get("/api/v1/auth/oidc/callback?code=test_code&state=test_state", follow_redirects=False)
|
||||
|
||||
# Should redirect to frontend with error
|
||||
assert response.status_code == 302
|
||||
assert "error=oidc_failed" in response.headers["location"]
|
||||
|
||||
def test_oidc_callback_invalid_state(client, mock_auth_data):
|
||||
"""Test that OIDC callback fails if the state in the URL doesn't match the cookie."""
|
||||
with patch("ea_chatbot.api.routers.auth.oidc_client"), \
|
||||
patch("ea_chatbot.api.routers.auth.OIDCSession.decrypt") as mock_decrypt:
|
||||
|
||||
# Mock valid cookie content
|
||||
mock_decrypt.return_value = mock_auth_data
|
||||
|
||||
# Send different state in URL
|
||||
client.cookies.set("oidc_session", "fake_token")
|
||||
response = client.get("/api/v1/auth/oidc/callback?code=test_code&state=wrong_state", follow_redirects=False)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert "error=oidc_failed" in response.headers["location"]
|
||||
|
||||
def test_oidc_callback_success(client, mock_auth_data):
|
||||
"""Test successful OIDC callback and session establishment."""
|
||||
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc, \
|
||||
patch("ea_chatbot.api.routers.auth.OIDCSession.decrypt") as mock_decrypt, \
|
||||
patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm, \
|
||||
patch("ea_chatbot.api.routers.auth.create_access_token") as mock_create_token:
|
||||
|
||||
mock_decrypt.return_value = mock_auth_data
|
||||
mock_oidc.exchange_code_for_token.return_value = {"id_token": "fake_id_token"}
|
||||
mock_oidc.validate_id_token.return_value = {"email": "user@test.com", "name": "Test User"}
|
||||
mock_hm.sync_user_from_oidc.return_value = User(id="user-123", username="user@test.com")
|
||||
mock_create_token.return_value = "fake_access_token"
|
||||
|
||||
client.cookies.set("oidc_session", "fake_token")
|
||||
response = client.get("/api/v1/auth/oidc/callback?code=test_code&state=test_state", follow_redirects=False)
|
||||
|
||||
assert response.status_code == 302
|
||||
# Redirect to FRONTEND_URL (default localhost:5173)
|
||||
assert "http://localhost:5173" in response.headers["location"]
|
||||
assert "access_token" in response.cookies
|
||||
|
||||
# Verify oidc_session was cleaned up (deleted)
|
||||
# In RedirectResponse, delete_cookie works by setting the cookie with empty value and past expiry
|
||||
cookie_header = response.headers.get("set-cookie", "")
|
||||
assert "oidc_session=;" in cookie_header or 'oidc_session=""' in cookie_header
|
||||
@@ -3,10 +3,8 @@ from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
from ea_chatbot.api.main import app
|
||||
from ea_chatbot.api.dependencies import get_current_user
|
||||
from ea_chatbot.history.models import User, Conversation, Message, Plot
|
||||
from ea_chatbot.history.models import User, Conversation
|
||||
from ea_chatbot.api.utils import create_access_token
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@@ -30,7 +30,6 @@ async def test_get_checkpointer_initialization():
|
||||
checkpoint.get_pool = mock_get_pool
|
||||
|
||||
# Patch AsyncPostgresSaver where it's imported in checkpoint.py
|
||||
import langgraph.checkpoint.postgres.aio as pg_aio
|
||||
original_saver = checkpoint.AsyncPostgresSaver
|
||||
checkpoint.AsyncPostgresSaver = mock_saver_class
|
||||
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from streamlit.testing.v1 import AppTest
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
# Ensure src is in python path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_history_manager():
|
||||
"""Globally mock HistoryManager to avoid DB calls during AppTest."""
|
||||
with patch("ea_chatbot.history.manager.HistoryManager") as mock_cls:
|
||||
instance = mock_cls.return_value
|
||||
instance.create_conversation.return_value = MagicMock(id="conv_123")
|
||||
instance.get_conversations.return_value = []
|
||||
instance.get_messages.return_value = []
|
||||
instance.add_message.return_value = MagicMock()
|
||||
instance.update_conversation_summary.return_value = MagicMock()
|
||||
yield instance
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_stream():
|
||||
with patch("ea_chatbot.graph.workflow.app.stream") as mock_stream:
|
||||
# Mock events from app.stream
|
||||
mock_stream.return_value = [
|
||||
{"query_analyzer": {"next_action": "research"}},
|
||||
{"researcher": {"messages": [AIMessage(content="Research result")]}}
|
||||
]
|
||||
yield mock_stream
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
user = MagicMock()
|
||||
user.id = "test_id"
|
||||
user.username = "test@example.com"
|
||||
user.display_name = "Test User"
|
||||
return user
|
||||
|
||||
def test_app_initial_state(mock_app_stream, mock_user):
|
||||
"""Test that the app initializes with the correct title and empty history."""
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
|
||||
# Simulate logged-in user
|
||||
at.session_state["user"] = mock_user
|
||||
|
||||
at.run()
|
||||
|
||||
assert not at.exception
|
||||
assert at.title[0].value == "🗳️ Election Analytics Chatbot"
|
||||
|
||||
# Check session state initialization
|
||||
assert "messages" in at.session_state
|
||||
assert len(at.session_state["messages"]) == 0
|
||||
|
||||
def test_app_dev_mode_toggle(mock_app_stream, mock_user):
|
||||
"""Test that the dev mode toggle exists in the sidebar."""
|
||||
with patch.dict(os.environ, {"DEV_MODE": "false"}):
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
at.session_state["user"] = mock_user
|
||||
at.run()
|
||||
|
||||
# Check for sidebar toggle (checkbox)
|
||||
assert len(at.sidebar.checkbox) > 0
|
||||
dev_mode_toggle = at.sidebar.checkbox[0]
|
||||
assert dev_mode_toggle.label == "Dev Mode"
|
||||
assert dev_mode_toggle.value is False
|
||||
|
||||
def test_app_graph_execution_streaming(mock_app_stream, mock_user, mock_history_manager):
|
||||
"""Test that entering a prompt triggers the graph stream and displays response."""
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
at.session_state["user"] = mock_user
|
||||
at.run()
|
||||
|
||||
# Input a question
|
||||
at.chat_input[0].set_value("Test question").run()
|
||||
|
||||
# Verify graph stream was called
|
||||
assert mock_app_stream.called
|
||||
|
||||
# Message should be added to history
|
||||
assert len(at.session_state["messages"]) == 2
|
||||
assert at.session_state["messages"][0]["role"] == "user"
|
||||
assert at.session_state["messages"][1]["role"] == "assistant"
|
||||
assert "Research result" in at.session_state["messages"][1]["content"]
|
||||
|
||||
# Verify history manager was used
|
||||
assert mock_history_manager.create_conversation.called
|
||||
assert mock_history_manager.add_message.called
|
||||
@@ -1,83 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from streamlit.testing.v1 import AppTest
|
||||
from ea_chatbot.auth import AuthType
|
||||
|
||||
@pytest.fixture
|
||||
def mock_history_manager_instance():
|
||||
# We need to patch before the AppTest loads the module
|
||||
with patch("ea_chatbot.history.manager.HistoryManager") as mock_cls:
|
||||
instance = mock_cls.return_value
|
||||
yield instance
|
||||
|
||||
def test_auth_ui_flow_step1_to_password(mock_history_manager_instance):
|
||||
"""Test UI transition from Step 1 (email) to Step 2a (password) for LOCAL user."""
|
||||
# Patch BEFORE creating AppTest
|
||||
mock_user = MagicMock()
|
||||
mock_user.password_hash = "hashed_password"
|
||||
mock_history_manager_instance.get_user.return_value = mock_user
|
||||
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
at.run()
|
||||
|
||||
# Step 1: Identification
|
||||
assert at.session_state["login_step"] == "email"
|
||||
at.text_input[0].set_value("local@example.com")
|
||||
|
||||
at.button[0].click().run()
|
||||
|
||||
# Verify transition to password step
|
||||
assert at.session_state["login_step"] == "login_password"
|
||||
assert at.session_state["login_email"] == "local@example.com"
|
||||
assert "Welcome back" in at.info[0].value
|
||||
|
||||
def test_auth_ui_flow_step1_to_register(mock_history_manager_instance):
|
||||
"""Test UI transition from Step 1 (email) to Step 2b (registration) for NEW user."""
|
||||
mock_history_manager_instance.get_user.return_value = None
|
||||
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
at.run()
|
||||
|
||||
# Step 1: Identification
|
||||
at.text_input[0].set_value("new@example.com")
|
||||
|
||||
at.button[0].click().run()
|
||||
|
||||
# Verify transition to registration step
|
||||
assert at.session_state["login_step"] == "register_details"
|
||||
assert at.session_state["login_email"] == "new@example.com"
|
||||
assert "Create an account" in at.info[0].value
|
||||
|
||||
def test_auth_ui_flow_step1_to_oidc(mock_history_manager_instance):
|
||||
"""Test UI transition from Step 1 (email) to Step 2c (OIDC) for OIDC user."""
|
||||
# Mock history_manager.get_user to return a user WITHOUT a password
|
||||
mock_user = MagicMock()
|
||||
mock_user.password_hash = None
|
||||
mock_history_manager_instance.get_user.return_value = mock_user
|
||||
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
at.run()
|
||||
|
||||
# Step 1: Identification
|
||||
at.text_input[0].set_value("oidc@example.com")
|
||||
|
||||
at.button[0].click().run()
|
||||
|
||||
# Verify transition to OIDC step
|
||||
assert at.session_state["login_step"] == "oidc_login"
|
||||
assert at.session_state["login_email"] == "oidc@example.com"
|
||||
assert "configured for Single Sign-On" in at.info[0].value
|
||||
|
||||
def test_auth_ui_flow_back_button(mock_history_manager_instance):
|
||||
"""Test that the 'Back' button returns to Step 1."""
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
# Simulate being on Step 2a
|
||||
at.session_state["login_step"] = "login_password"
|
||||
at.session_state["login_email"] = "local@example.com"
|
||||
at.run()
|
||||
|
||||
# Click Back (index 1 in Step 2a)
|
||||
at.button[1].click().run()
|
||||
|
||||
# Verify return to email step
|
||||
assert at.session_state["login_step"] == "email"
|
||||
@@ -1,43 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.auth import OIDCClient
|
||||
|
||||
@patch("ea_chatbot.auth.OAuth2Session")
|
||||
def test_oidc_client_initialization(mock_oauth):
|
||||
client = OIDCClient(
|
||||
client_id="test_id",
|
||||
client_secret="test_secret",
|
||||
server_metadata_url="https://test.server/.well-known/openid-configuration"
|
||||
)
|
||||
assert client.oauth_session is not None
|
||||
|
||||
@patch("ea_chatbot.auth.requests")
|
||||
@patch("ea_chatbot.auth.OAuth2Session")
|
||||
def test_get_login_url(mock_oauth_cls, mock_requests):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_oauth_cls.return_value = mock_session
|
||||
|
||||
# Mock metadata response
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"authorization_endpoint": "https://test.server/auth",
|
||||
"token_endpoint": "https://test.server/token",
|
||||
"userinfo_endpoint": "https://test.server/userinfo"
|
||||
}
|
||||
mock_requests.get.return_value = mock_response
|
||||
|
||||
# Mock authorization url generation
|
||||
mock_session.create_authorization_url.return_value = ("https://test.server/auth?response_type=code", "state")
|
||||
|
||||
client = OIDCClient(
|
||||
client_id="test_id",
|
||||
client_secret="test_secret",
|
||||
server_metadata_url="https://test.server/.well-known/openid-configuration"
|
||||
)
|
||||
|
||||
url = client.get_login_url()
|
||||
|
||||
assert url == "https://test.server/auth?response_type=code"
|
||||
# Verify metadata was fetched via requests
|
||||
mock_requests.get.assert_called_with("https://test.server/.well-known/openid-configuration")
|
||||
@@ -1,49 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from ea_chatbot.history.manager import HistoryManager
|
||||
from ea_chatbot.auth import get_user_auth_type, AuthType
|
||||
|
||||
# Mocks
|
||||
@pytest.fixture
|
||||
def mock_history_manager():
|
||||
return MagicMock(spec=HistoryManager)
|
||||
|
||||
def test_auth_flow_existing_local_user(mock_history_manager):
|
||||
"""Test that an existing user with a password returns LOCAL auth type."""
|
||||
# Setup
|
||||
mock_user = MagicMock()
|
||||
mock_user.password_hash = "hashed_secret"
|
||||
mock_history_manager.get_user.return_value = mock_user
|
||||
|
||||
# Execute
|
||||
auth_type = get_user_auth_type("test@example.com", mock_history_manager)
|
||||
|
||||
# Verify
|
||||
assert auth_type == AuthType.LOCAL
|
||||
mock_history_manager.get_user.assert_called_once_with("test@example.com")
|
||||
|
||||
def test_auth_flow_existing_oidc_user(mock_history_manager):
|
||||
"""Test that an existing user WITHOUT a password returns OIDC auth type."""
|
||||
# Setup
|
||||
mock_user = MagicMock()
|
||||
mock_user.password_hash = None # No password implies OIDC
|
||||
mock_history_manager.get_user.return_value = mock_user
|
||||
|
||||
# Execute
|
||||
auth_type = get_user_auth_type("sso@example.com", mock_history_manager)
|
||||
|
||||
# Verify
|
||||
assert auth_type == AuthType.OIDC
|
||||
mock_history_manager.get_user.assert_called_once_with("sso@example.com")
|
||||
|
||||
def test_auth_flow_new_user(mock_history_manager):
|
||||
"""Test that a non-existent user returns NEW auth type."""
|
||||
# Setup
|
||||
mock_history_manager.get_user.return_value = None
|
||||
|
||||
# Execute
|
||||
auth_type = get_user_auth_type("new@example.com", mock_history_manager)
|
||||
|
||||
# Verify
|
||||
assert auth_type == AuthType.NEW
|
||||
mock_history_manager.get_user.assert_called_once_with("new@example.com")
|
||||
@@ -1,5 +1,3 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from ea_chatbot.config import Settings, LLMConfig
|
||||
|
||||
def test_default_settings():
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock
|
||||
from ea_chatbot.utils.database_inspection import get_primary_key, inspect_db_table, get_data_summary
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -3,7 +3,6 @@ import pandas as pd
|
||||
from unittest.mock import MagicMock, patch
|
||||
from matplotlib.figure import Figure
|
||||
from ea_chatbot.graph.nodes.executor import executor_node
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings():
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from ea_chatbot.utils.helpers import merge_agent_state
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker, DeclarativeBase
|
||||
|
||||
# We anticipate these imports will fail initially
|
||||
try:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import pytest
|
||||
from langchain_openai import ChatOpenAI
|
||||
from ea_chatbot.config import LLMConfig
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import os
|
||||
import json
|
||||
import pytest
|
||||
import logging
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
from ea_chatbot.graph.workflow import app
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.utils.logging import get_logger
|
||||
|
||||
@@ -2,7 +2,6 @@ import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
@pytest.fixture
|
||||
def mock_state_with_history():
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.auth import OIDCClient
|
||||
|
||||
@pytest.fixture
|
||||
def oidc_config():
|
||||
return {
|
||||
"client_id": "test_id",
|
||||
"client_secret": "test_secret",
|
||||
"server_metadata_url": "https://example.com/.well-known/openid-configuration",
|
||||
"redirect_uri": "http://localhost:8501"
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_metadata():
|
||||
return {
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
"userinfo_endpoint": "https://example.com/userinfo"
|
||||
}
|
||||
|
||||
def test_oidc_fetch_metadata(oidc_config, mock_metadata):
|
||||
client = OIDCClient(**oidc_config)
|
||||
|
||||
with patch("requests.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_metadata
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
metadata = client.fetch_metadata()
|
||||
|
||||
assert metadata == mock_metadata
|
||||
mock_get.assert_called_once_with(oidc_config["server_metadata_url"])
|
||||
|
||||
# Second call should use cache
|
||||
client.fetch_metadata()
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
def test_oidc_get_login_url(oidc_config, mock_metadata):
|
||||
client = OIDCClient(**oidc_config)
|
||||
client.metadata = mock_metadata
|
||||
|
||||
with patch.object(client.oauth_session, "create_authorization_url") as mock_create_url:
|
||||
mock_create_url.return_value = ("https://example.com/auth?state=xyz", "xyz")
|
||||
|
||||
url = client.get_login_url()
|
||||
|
||||
assert url == "https://example.com/auth?state=xyz"
|
||||
mock_create_url.assert_called_once_with(mock_metadata["authorization_endpoint"])
|
||||
|
||||
def test_oidc_get_login_url_missing_endpoint(oidc_config):
|
||||
client = OIDCClient(**oidc_config)
|
||||
client.metadata = {"some_other": "field"}
|
||||
|
||||
with pytest.raises(ValueError, match="authorization_endpoint not found"):
|
||||
client.get_login_url()
|
||||
|
||||
def test_oidc_exchange_code_for_token(oidc_config, mock_metadata):
|
||||
client = OIDCClient(**oidc_config)
|
||||
client.metadata = mock_metadata
|
||||
|
||||
with patch.object(client.oauth_session, "fetch_token") as mock_fetch_token:
|
||||
mock_fetch_token.return_value = {"access_token": "abc"}
|
||||
|
||||
token = client.exchange_code_for_token("test_code")
|
||||
|
||||
assert token == {"access_token": "abc"}
|
||||
mock_fetch_token.assert_called_once_with(
|
||||
mock_metadata["token_endpoint"],
|
||||
code="test_code",
|
||||
client_secret=oidc_config["client_secret"]
|
||||
)
|
||||
|
||||
def test_oidc_get_user_info(oidc_config, mock_metadata):
|
||||
client = OIDCClient(**oidc_config)
|
||||
client.metadata = mock_metadata
|
||||
|
||||
with patch.object(client.oauth_session, "get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"sub": "user123"}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
user_info = client.get_user_info({"access_token": "abc"})
|
||||
|
||||
assert user_info == {"sub": "user123"}
|
||||
assert client.oauth_session.token == {"access_token": "abc"}
|
||||
mock_get.assert_called_once_with(mock_metadata["userinfo_endpoint"])
|
||||
@@ -1,7 +1,6 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
from ea_chatbot.auth import OIDCClient
|
||||
from jose import jwt
|
||||
|
||||
@pytest.fixture
|
||||
def oidc_config():
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
@pytest.fixture
|
||||
def mock_state():
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import pytest
|
||||
import logging
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
@pytest.fixture
|
||||
def base_state():
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import pytest
|
||||
from typing import get_type_hints, List
|
||||
from langchain_core.messages import BaseMessage, HumanMessage
|
||||
from typing import get_type_hints
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
import operator
|
||||
|
||||
def test_agent_state_structure():
|
||||
"""Verify that AgentState has the required fields and types."""
|
||||
|
||||
@@ -2,7 +2,6 @@ import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import AIMessage
|
||||
from ea_chatbot.graph.nodes.summarizer import summarizer_node
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm():
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.graph.workflow import app
|
||||
from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import pytest
|
||||
import yaml
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import AIMessage
|
||||
from ea_chatbot.graph.workflow import app
|
||||
|
||||
Reference in New Issue
Block a user