diff --git a/backend/tests/test_oidc_client_v2.py b/backend/tests/test_oidc_client_v2.py index 6d7beed..b130319 100644 --- a/backend/tests/test_oidc_client_v2.py +++ b/backend/tests/test_oidc_client_v2.py @@ -1,6 +1,6 @@ import pytest from unittest.mock import MagicMock, patch -from ea_chatbot.auth import OIDCClient +from ea_chatbot.auth import OIDCClient, get_user_auth_type, AuthType @pytest.fixture def oidc_config(): @@ -14,12 +14,69 @@ def oidc_config(): @pytest.fixture def mock_metadata(): return { + "issuer": "https://example.com", "authorization_endpoint": "https://example.com/auth", "token_endpoint": "https://example.com/token", "userinfo_endpoint": "https://example.com/userinfo", "jwks_uri": "https://example.com/jwks" } +def test_get_user_auth_type(): + history_manager = MagicMock() + + # Case: New user + history_manager.get_user.return_value = None + assert get_user_auth_type("new@test.com", history_manager) == AuthType.NEW + + # Case: Local user (has password) + user = MagicMock() + user.password_hash = "hashed_pass" + history_manager.get_user.return_value = user + assert get_user_auth_type("local@test.com", history_manager) == AuthType.LOCAL + + # Case: OIDC user (no password) + user.password_hash = None + assert get_user_auth_type("oidc@test.com", history_manager) == AuthType.OIDC + +def test_oidc_fetch_jwks(oidc_config, mock_metadata): + client = OIDCClient(**oidc_config) + client.metadata = mock_metadata + + with patch("requests.get") as mock_get: + mock_response = MagicMock() + mock_response.json.return_value = {"keys": []} + mock_get.return_value = mock_response + + jwks = client.fetch_jwks() + assert jwks == {"keys": []} + mock_get.assert_called_once_with(mock_metadata["jwks_uri"]) + + # Cache check + client.fetch_jwks() + assert mock_get.call_count == 1 + +def test_oidc_get_login_url_legacy(oidc_config, mock_metadata): + client = OIDCClient(**oidc_config) + client.metadata = mock_metadata + + with patch.object(client.oauth_session, "create_authorization_url") as mock_create: + mock_create.return_value = ("https://url", "state") + url = client.get_login_url() + assert url == "https://url" + +def test_oidc_get_user_info(oidc_config, mock_metadata): + client = OIDCClient(**oidc_config) + client.metadata = mock_metadata + + with patch.object(client.oauth_session, "get") as mock_get: + mock_resp = MagicMock() + mock_resp.json.return_value = {"email": "test@test.com"} + mock_get.return_value = mock_resp + + info = client.get_user_info({"access_token": "abc"}) + assert info["email"] == "test@test.com" + assert client.oauth_session.token == {"access_token": "abc"} + def test_oidc_generate_pkce(oidc_config): client = OIDCClient(**oidc_config) code_verifier, code_challenge = client.generate_pkce()