diff --git a/zerver/lib/test_helpers.py b/zerver/lib/test_helpers.py index ac655fa260..87a982753b 100644 --- a/zerver/lib/test_helpers.py +++ b/zerver/lib/test_helpers.py @@ -551,7 +551,9 @@ def write_instrumentation_reports(full_suite: bool, include_webhooks: bool) -> N def load_subdomain_token(response: Union["TestHttpResponse", HttpResponse]) -> ExternalAuthDataDict: assert isinstance(response, HttpResponseRedirect) token = response.url.rsplit("/", 1)[1] - data = ExternalAuthResult(login_token=token, delete_stored_data=False).data_dict + data = ExternalAuthResult( + request=mock.MagicMock(), login_token=token, delete_stored_data=False + ).data_dict assert data is not None return data diff --git a/zerver/tests/test_auth_backends.py b/zerver/tests/test_auth_backends.py index f69887dd78..ea5ac0a010 100644 --- a/zerver/tests/test_auth_backends.py +++ b/zerver/tests/test_auth_backends.py @@ -124,6 +124,7 @@ from zerver.views.auth import log_into_subdomain, maybe_send_to_registration from zproject.backends import ( AUTH_BACKEND_NAME_MAP, AppleAuthBackend, + AuthFuncT, AzureADAuthBackend, DevAuthBackend, EmailAuthBackend, @@ -796,6 +797,7 @@ class SocialAuthBase(DesktopFlowTestingLib, ZulipTestCase, ABC): alternative_start_url: Optional[str] = None, *, user_agent: Optional[str] = None, + extra_headers: Optional[Dict[str, Any]] = None, ) -> Tuple[str, Dict[str, Any]]: url = self.LOGIN_URL if alternative_start_url is not None: @@ -823,6 +825,9 @@ class SocialAuthBase(DesktopFlowTestingLib, ZulipTestCase, ABC): if user_agent is not None: headers["HTTP_USER_AGENT"] = user_agent + if extra_headers is not None: + headers.update(extra_headers) + return url, headers def social_auth_test_finish( @@ -859,6 +864,7 @@ class SocialAuthBase(DesktopFlowTestingLib, ZulipTestCase, ABC): expect_choose_email_screen: bool = False, alternative_start_url: Optional[str] = None, user_agent: Optional[str] = None, + extra_headers: Optional[Dict[str, Any]] = None, **extra_data: Any, ) -> "TestHttpResponse": """Main entry point for all social authentication tests. @@ -896,6 +902,7 @@ class SocialAuthBase(DesktopFlowTestingLib, ZulipTestCase, ABC): multiuse_object_key, alternative_start_url, user_agent=user_agent, + extra_headers=extra_headers, ) result = self.client_get(url, **headers) @@ -1048,6 +1055,84 @@ class SocialAuthBase(DesktopFlowTestingLib, ZulipTestCase, ABC): m.output[0], ) + def test_social_auth_custom_auth_decorator(self) -> None: + account_data_dict = self.get_account_data_dict(email=self.email, name=self.name) + + backends_with_restriction = [] + + def custom_auth_wrapper( + auth_func: AuthFuncT, *args: Any, **kwargs: Any + ) -> Optional[UserProfile]: + nonlocal backends_with_restriction + + backend = args[0] + backend_name = backend.name + request = args[1] + test_header_value = request.headers.get("X-Test-Auth-Header") + + user_profile = auth_func(*args, **kwargs) + if backend_name in backends_with_restriction and test_header_value != "allowed": + raise JsonableError("Forbidden header value") + + return user_profile + + # It's the ZulipDummyBackend that runs in the social auth codepaths, causing + # the custom_auth_wrapper logic to be executed. + backends_with_restriction = ["dummy"] + with self.settings(CUSTOM_AUTHENTICATION_WRAPPER_FUNCTION=custom_auth_wrapper): + result = self.social_auth_test( + account_data_dict, + expect_choose_email_screen=False, + subdomain="zulip", + ) + self.assert_json_error(result, "Forbidden header value") + + with self.assertLogs(self.logger_string, level="INFO") as m: + result = self.social_auth_test( + account_data_dict, + expect_choose_email_screen=False, + subdomain="zulip", + next="/user_uploads/image", + extra_headers={"HTTP_X_TEST_AUTH_HEADER": "allowed"}, + ) + + self.assertEqual(result.status_code, 302) + url = result["Location"] + self.assertTrue(url.startswith("http://zulip.testserver/accounts/login/subdomain/")) + + url = url.replace("http://zulip.testserver", "") + + result = self.client_get(url, subdomain="zulip", HTTP_X_TEST_AUTH_HEADER="allowed") + self.assertEqual(result.status_code, 302) + self.assert_logged_in_user_id(self.user_profile.id) + + # Test with a silly custom_auth_wrapper that always returns None, to verify + # logging of such failures (which doesn't run if the wrapper is throwing an exception early + # like above) + + def custom_auth_wrapper_none( + auth_func: AuthFuncT, *args: Any, **kwargs: Any + ) -> Optional[UserProfile]: + return None + + with self.settings(CUSTOM_AUTHENTICATION_WRAPPER_FUNCTION=custom_auth_wrapper_none): + with self.assertLogs(self.logger_string, level="INFO") as m: + result = self.social_auth_test( + account_data_dict, + expect_choose_email_screen=False, + subdomain="zulip", + ) + self.assertEqual(result.status_code, 302) + self.assertEqual(result["Location"], "http://zulip.testserver/login/") + + self.assertEqual( + m.output, + [ + f"INFO:{self.logger_string}:Authentication attempt from 127.0.0.1: " + f"subdomain=zulip;username={self.email};outcome=failed;return_data={{}}" + ], + ) + @override_settings(SOCIAL_AUTH_SUBDOMAIN=None) def test_when_social_auth_subdomain_is_not_set(self) -> None: account_data_dict = self.get_account_data_dict(email=self.email, name=self.name) @@ -1845,6 +1930,7 @@ class SAMLAuthBackendTest(SocialAuthBase): multiuse_object_key: str = "", user_agent: Optional[str] = None, extra_attributes: Mapping[str, List[str]] = {}, + extra_headers: Optional[Dict[str, Any]] = None, **extra_data: Any, ) -> "TestHttpResponse": url, headers = self.prepare_login_url_and_headers( @@ -1855,6 +1941,7 @@ class SAMLAuthBackendTest(SocialAuthBase): next, multiuse_object_key, user_agent=user_agent, + extra_headers=extra_headers, ) result = self.client_get(url, **headers) @@ -3358,6 +3445,7 @@ class AppleAuthBackendNativeFlowTest(AppleAuthMixin, SocialAuthBase): account_data_dict: Mapping[str, str] = {}, *, user_agent: Optional[str] = None, + extra_headers: Optional[Dict[str, Any]] = None, ) -> Tuple[str, Dict[str, Any]]: url, headers = super().prepare_login_url_and_headers( subdomain, @@ -3368,6 +3456,7 @@ class AppleAuthBackendNativeFlowTest(AppleAuthMixin, SocialAuthBase): multiuse_object_key, alternative_start_url=alternative_start_url, user_agent=user_agent, + extra_headers=extra_headers, ) params = {"native_flow": "true"} @@ -3400,6 +3489,7 @@ class AppleAuthBackendNativeFlowTest(AppleAuthMixin, SocialAuthBase): alternative_start_url: Optional[str] = None, skip_id_token: bool = False, user_agent: Optional[str] = None, + extra_headers: Optional[Dict[str, Any]] = None, **extra_data: Any, ) -> "TestHttpResponse": """In Apple's native authentication flow, the client app authenticates @@ -3432,6 +3522,7 @@ class AppleAuthBackendNativeFlowTest(AppleAuthMixin, SocialAuthBase): user_agent=user_agent, id_token=id_token, account_data_dict=account_data_dict, + extra_headers=extra_headers, ) with self.apple_jwk_url_mock(): @@ -7459,3 +7550,91 @@ class LDAPGroupSyncTest(ZulipTestCase): # Don't load the base class as a test: https://bugs.python.org/issue17519. del SocialAuthBase + + +class TestCustomAuthDecorator(ZulipTestCase): + def test_custom_auth_decorator(self) -> None: + call_count = 0 + backends_with_restriction = [] + + def custom_auth_wrapper( + auth_func: AuthFuncT, *args: Any, **kwargs: Any + ) -> Optional[UserProfile]: + nonlocal call_count + nonlocal backends_with_restriction + call_count += 1 + + backend = args[0] + backend_name = backend.name + request = args[1] + test_header_value = request.headers.get("X-Test-Auth-Header") + + user_profile = auth_func(*args, **kwargs) + if backend_name in backends_with_restriction and test_header_value != "allowed": + raise JsonableError("Forbidden header value") + + return user_profile + + with self.settings(CUSTOM_AUTHENTICATION_WRAPPER_FUNCTION=custom_auth_wrapper): + self.login("hamlet") + self.assertEqual(call_count, 1) + + backends_with_restriction = ["email", "dummy"] + + realm = get_realm("zulip") + hamlet = self.example_user("hamlet") + password = "testpassword" + + request = mock.MagicMock() + request.headers = {"X-Test-Auth-Header": "allowed"} + + # The wrapper structurally gets executed whenever .authenticate() for a backend + # is called, so it doesn't matter whether e.g. auth credentials are correct or not. + result = EmailAuthBackend().authenticate( + request, username=hamlet.delivery_email, password="wrong", realm=realm + ) + self.assertEqual(result, None) + self.assertEqual(call_count, 2) + + hamlet.set_password(password) + hamlet.save() + result = EmailAuthBackend().authenticate( + request, username=hamlet.delivery_email, password=password, realm=realm + ) + self.assertEqual(result, hamlet) + self.assertEqual(call_count, 3) + + # But without the appropriate header value, this fails. + request.headers = {} + with self.assertRaisesRegex(JsonableError, "Forbidden header value"): + EmailAuthBackend().authenticate( + request, username=hamlet.delivery_email, password=password, realm=realm + ) + + self.assertEqual(call_count, 4) + + # Now try the registration codepath. + alice_email = self.nonreg_email("alice") + password = "password" + inviter = self.example_user("iago") + prereg_user = PreregistrationUser.objects.create( + email=alice_email, referred_by=inviter, realm=realm + ) + + confirmation_link = create_confirmation_link(prereg_user, Confirmation.USER_REGISTRATION) + registration_key = confirmation_link.split("/")[-1] + + url = "/accounts/register/" + with self.settings(CUSTOM_AUTHENTICATION_WRAPPER_FUNCTION=custom_auth_wrapper): + self.client_post( + url, {"key": registration_key, "from_confirmation": 1, "full_name": "alice"} + ) + result = self.submit_reg_form_for_user(alice_email, password, key=registration_key) + + # The account gets created, because it's the authentication layer that's wrapped + # with custom logic, so it doesn't affect the registration process itself - just + # the signing in of the user at the end. Ultimately, the user cannot acquire an + # authenticated session, so the objective of the functionality is accomplished. + self.assert_json_error(result, "Forbidden header value") + self.assertEqual(UserProfile.objects.latest("id").delivery_email, alice_email) + self.assertEqual(call_count, 5) diff --git a/zerver/views/auth.py b/zerver/views/auth.py index 965ca96d54..16e98a82a9 100644 --- a/zerver/views/auth.py +++ b/zerver/views/auth.py @@ -733,7 +733,7 @@ def log_into_subdomain(request: HttpRequest, token: str) -> HttpResponse: return HttpResponse(status=400) try: - result = ExternalAuthResult(login_token=token) + result = ExternalAuthResult(request=request, login_token=token) except ExternalAuthResult.InvalidTokenError: logging.warning("log_into_subdomain: Invalid token given: %s", token) return render(request, "zerver/log_into_subdomain_token_invalid.html", status=400) diff --git a/zerver/views/registration.py b/zerver/views/registration.py index 5a35f436bd..04c0b827f6 100644 --- a/zerver/views/registration.py +++ b/zerver/views/registration.py @@ -636,6 +636,7 @@ def registration_helper( # This dummy_backend check below confirms the user is # authenticating to the correct subdomain. auth_result = authenticate( + request=request, username=user_profile.delivery_email, realm=realm, return_data=return_data, diff --git a/zproject/backends.py b/zproject/backends.py index f9a8ebe606..3d7383da7c 100644 --- a/zproject/backends.py +++ b/zproject/backends.py @@ -337,6 +337,15 @@ def auth_rate_limiting_already_applied(request: HttpRequest) -> bool: # defined by backends, so we need a decorator that doesn't break function signatures. # @decorator does this for us. # The usual @wraps from functools breaks signatures, so it can't be used here. +@decorator +def custom_auth_decorator(auth_func: AuthFuncT, *args: Any, **kwargs: Any) -> Optional[UserProfile]: + custom_auth_wrapper_func = settings.CUSTOM_AUTHENTICATION_WRAPPER_FUNCTION + if custom_auth_wrapper_func is None: + return auth_func(*args, **kwargs) + else: + return custom_auth_wrapper_func(auth_func, *args, **kwargs) + + @decorator def rate_limit_auth(auth_func: AuthFuncT, *args: Any, **kwargs: Any) -> Optional[UserProfile]: if not settings.RATE_LIMITING_AUTHENTICATE: @@ -443,6 +452,9 @@ class ZulipDummyBackend(ZulipAuthMixin): when explicitly requested by including the use_dummy_backend kwarg. """ + name = "dummy" + + @custom_auth_decorator def authenticate( self, request: Optional[HttpRequest] = None, @@ -487,6 +499,7 @@ class EmailAuthBackend(ZulipAuthMixin): @rate_limit_auth @log_auth_attempts + @custom_auth_decorator def authenticate( self, request: HttpRequest, @@ -1002,6 +1015,7 @@ class ZulipLDAPAuthBackend(ZulipLDAPAuthBackendBase): @rate_limit_auth @log_auth_attempts + @custom_auth_decorator def authenticate( self, request: Optional[HttpRequest] = None, @@ -1408,6 +1422,7 @@ class ExternalAuthResult: *, user_profile: Optional[UserProfile] = None, data_dict: Optional[ExternalAuthDataDict] = None, + request: Optional[HttpRequest] = None, login_token: Optional[str] = None, delete_stored_data: bool = True, ) -> None: @@ -1418,7 +1433,8 @@ class ExternalAuthResult: assert (not data_dict) and ( user_profile is None ), "Passing in data_dict or user_profile with login_token is disallowed." - self.instantiate_with_token(login_token, delete_stored_data) + assert request is not None, "Passing in request with login_token is required." + self.instantiate_with_token(request, login_token, delete_stored_data) else: self.data_dict = data_dict.copy() self.user_profile = user_profile @@ -1457,7 +1473,9 @@ class ExternalAuthResult: token = key.split(self.LOGIN_KEY_PREFIX, 1)[1] # remove the prefix return token - def instantiate_with_token(self, token: str, delete_stored_data: bool = True) -> None: + def instantiate_with_token( + self, request: HttpRequest, token: str, delete_stored_data: bool = True + ) -> None: key = self.LOGIN_KEY_FORMAT.format(token=token) data = get_dict_from_redis(redis_client, self.LOGIN_KEY_FORMAT, key) if data is None or None in [data.get("email"), data.get("subdomain")]: @@ -1477,7 +1495,9 @@ class ExternalAuthResult: # more customized error messages for those unlikely races, but # it's likely not worth implementing. realm = get_realm(data["subdomain"]) - auth_result = authenticate(username=data["email"], realm=realm, use_dummy_backend=True) + auth_result = authenticate( + request=request, username=data["email"], realm=realm, use_dummy_backend=True + ) if auth_result is not None: assert isinstance(auth_result, UserProfile) self.user_profile = auth_result @@ -1873,6 +1893,35 @@ def social_auth_finish( str(e), ) + if user_profile: + # This call to authenticate() is just to get to invoke the custom_auth_decorator logic. + # Social auth backends don't work via authenticate() in the same way as normal backends, + # so we can't just wrap their authenticate() methods. But the decorator is applied on + # ZulipDummyBackend.authenticate(), so we can invoke it here to trigger the custom logic. + # + # Note: We're only doing in the case where we already have a user_profile, meaning the + # account already exists and the user is just logging in. The new account registration case + # is handled in the registration codepath. + validated_user_profile = authenticate( + request=strategy.request, + username=user_profile.delivery_email, + realm=realm, + use_dummy_backend=True, + ) + if validated_user_profile is None or validated_user_profile != user_profile: + # Log this as as a failure to authenticate via the social backend, since that's + # the correct way to think about this. ZulipDummyBackend is just an implementation + # tool, not an actual backend a user could be authenticating through. + log_auth_attempt( + backend.logger, + strategy.request, + realm, + username=email_address, + succeeded=False, + return_data={}, + ) + return redirect_to_login(realm) + # At this point, we have now confirmed that the user has # demonstrated control over the target email address. # diff --git a/zproject/default_settings.py b/zproject/default_settings.py index 91d0ef9c59..151ba99f65 100644 --- a/zproject/default_settings.py +++ b/zproject/default_settings.py @@ -1,6 +1,6 @@ import os from email.headerregistry import Address -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple from django_auth_ldap.config import GroupOfUniqueNamesType, LDAPGroupType @@ -12,6 +12,8 @@ from .config import DEVELOPMENT, PRODUCTION, get_secret if TYPE_CHECKING: from django_auth_ldap.config import LDAPSearch + from zerver.models.users import UserProfile + if PRODUCTION: # nocoverage from .prod_settings import EXTERNAL_HOST, ZULIP_ADMINISTRATOR else: @@ -619,3 +621,5 @@ CAN_ACCESS_ALL_USERS_GROUP_LIMITS_PRESENCE = False # General expiry time for signed tokens we may generate # in some places through the codebase. SIGNED_ACCESS_TOKEN_VALIDITY_IN_SECONDS = 60 + +CUSTOM_AUTHENTICATION_WRAPPER_FUNCTION: Optional[Callable[..., Optional["UserProfile"]]] = None