121 lines
4.2 KiB
Python
121 lines
4.2 KiB
Python
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
from ea_chatbot.auth import OIDCClient, get_user_auth_type, AuthType
|
|
|
|
@pytest.fixture
def oidc_config():
    """Keyword arguments for constructing a test ``OIDCClient``.

    All values are placeholders: dummy credentials, a discovery URL on
    ``example.com`` and a localhost callback URI.
    """
    return dict(
        client_id="test_id",
        client_secret="test_secret",
        server_metadata_url="https://example.com/.well-known/openid-configuration",
        redirect_uri="http://localhost:5173/auth/callback",
    )
|
|
|
|
@pytest.fixture
def mock_metadata():
    """Stand-in OpenID provider discovery document.

    Tests assign this straight to ``client.metadata``; it supplies the issuer
    plus the authorization, token, userinfo and JWKS endpoints.
    """
    return dict(
        issuer="https://example.com",
        authorization_endpoint="https://example.com/auth",
        token_endpoint="https://example.com/token",
        userinfo_endpoint="https://example.com/userinfo",
        jwks_uri="https://example.com/jwks",
    )
|
|
|
|
def test_get_user_auth_type():
    """``get_user_auth_type`` classifies an e-mail by the stored user record.

    No stored user -> ``AuthType.NEW``; a user with a password hash ->
    ``AuthType.LOCAL``; a user whose ``password_hash`` is None -> ``AuthType.OIDC``.
    """
    store = MagicMock()

    def auth_type_for(email, stored_user):
        # Make the mocked lookup return `stored_user`, then classify `email`.
        store.get_user.return_value = stored_user
        return get_user_auth_type(email, store)

    # Unknown e-mail: the lookup yields nothing.
    assert auth_type_for("new@test.com", None) == AuthType.NEW

    account = MagicMock()

    # Account that has a password hash.
    account.password_hash = "hashed_pass"
    assert auth_type_for("local@test.com", account) == AuthType.LOCAL

    # Same account with the hash cleared (no password): treated as OIDC.
    account.password_hash = None
    assert auth_type_for("oidc@test.com", account) == AuthType.OIDC
|
|
|
|
def test_oidc_fetch_jwks(oidc_config, mock_metadata):
    """``fetch_jwks`` GETs ``jwks_uri`` once and serves repeat calls from its cache."""
    client = OIDCClient(**oidc_config)
    client.metadata = mock_metadata

    # NOTE(review): patching "requests.get" only intercepts the request if the
    # auth module calls it as `requests.get(...)`; a `from requests import get`
    # in that module would bypass this patch.
    with patch("requests.get") as mock_get:
        mock_response = MagicMock()
        mock_response.json.return_value = {"keys": []}
        mock_get.return_value = mock_response

        jwks = client.fetch_jwks()
        assert jwks == {"keys": []}
        mock_get.assert_called_once_with(mock_metadata["jwks_uri"])

        # Cache check, kept inside the patch so an uncached implementation is
        # counted by the mock instead of escaping to the real network: the
        # second call must return the same key set without another request.
        assert client.fetch_jwks() == jwks
        assert mock_get.call_count == 1
|
|
|
|
def test_oidc_get_login_url_legacy(oidc_config, mock_metadata):
    """Legacy ``get_login_url`` yields just the URL from ``create_authorization_url``."""
    client = OIDCClient(**oidc_config)
    client.metadata = mock_metadata

    fake_url, fake_state = "https://url", "state"
    with patch.object(
        client.oauth_session,
        "create_authorization_url",
        return_value=(fake_url, fake_state),
    ):
        login_url = client.get_login_url()
        assert login_url == fake_url
|
|
|
|
def test_oidc_get_user_info(oidc_config, mock_metadata):
    """``get_user_info`` returns the decoded userinfo payload.

    It must also leave the supplied token set on the OAuth session.
    """
    client = OIDCClient(**oidc_config)
    client.metadata = mock_metadata

    userinfo_response = MagicMock()
    userinfo_response.json.return_value = {"email": "test@test.com"}

    with patch.object(client.oauth_session, "get", return_value=userinfo_response):
        info = client.get_user_info({"access_token": "abc"})

    assert info["email"] == "test@test.com"
    assert client.oauth_session.token == {"access_token": "abc"}
|
|
|
|
def test_oidc_generate_pkce(oidc_config):
    """``generate_pkce`` returns an RFC 7636 verifier and its S256 challenge."""
    import base64
    import hashlib
    import re

    client = OIDCClient(**oidc_config)
    code_verifier, code_challenge = client.generate_pkce()

    # RFC 7636 section 4.1: 43-128 characters from the unreserved set.
    assert re.fullmatch(r"[A-Za-z0-9\-._~]{43,128}", code_verifier)

    # RFC 7636 section 4.2 (S256): challenge = BASE64URL(SHA256(verifier)),
    # without "=" padding. A length check alone would accept a wrong challenge.
    expected_challenge = (
        base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode("ascii")).digest())
        .rstrip(b"=")
        .decode("ascii")
    )
    assert code_challenge == expected_challenge

    # Verifier should be high entropy: a fresh one on every call.
    assert code_verifier != client.generate_pkce()[0]
|
|
|
|
def test_oidc_get_auth_data(oidc_config, mock_metadata):
    """``get_auth_data`` builds a PKCE + nonce authorization URL and returns its secrets."""
    import base64
    import hashlib
    from urllib.parse import parse_qs, urlsplit

    client = OIDCClient(**oidc_config)
    client.metadata = mock_metadata

    auth_data = client.get_auth_data()

    for key in ("url", "state", "nonce", "code_verifier"):
        assert key in auth_data

    # Parse the query string rather than substring-matching the URL, so a
    # check like "state=..." cannot be satisfied by another parameter and
    # percent-encoded values compare correctly.
    params = parse_qs(urlsplit(auth_data["url"]).query)
    assert params["code_challenge_method"] == ["S256"]
    assert params["state"] == [auth_data["state"]]
    assert params["nonce"] == [auth_data["nonce"]]

    # The challenge in the URL must be the S256 transform of the verifier
    # handed back to the caller (RFC 7636 section 4.2); otherwise the provider
    # would reject the later token exchange that sends this verifier.
    expected_challenge = (
        base64.urlsafe_b64encode(
            hashlib.sha256(auth_data["code_verifier"].encode("ascii")).digest()
        )
        .rstrip(b"=")
        .decode("ascii")
    )
    assert params["code_challenge"] == [expected_challenge]
|
|
|
|
def test_oidc_exchange_code_with_pkce(oidc_config, mock_metadata):
    """``exchange_code_for_token`` forwards the code, client secret and PKCE verifier.

    The token returned by the session's ``fetch_token`` is passed straight back.
    """
    client = OIDCClient(**oidc_config)
    client.metadata = mock_metadata
    issued_token = {"access_token": "token"}

    with patch.object(
        client.oauth_session, "fetch_token", return_value=issued_token
    ) as fetch:
        result = client.exchange_code_for_token("code", code_verifier="verifier")

    assert result == {"access_token": "token"}
    fetch.assert_called_once_with(
        mock_metadata["token_endpoint"],
        code="code",
        client_secret=oidc_config["client_secret"],
        code_verifier="verifier",
    )
|