diff --git a/zerver/tests/test_auth_backends.py b/zerver/tests/test_auth_backends.py index acb845b6e1..6aaa8a43d0 100644 --- a/zerver/tests/test_auth_backends.py +++ b/zerver/tests/test_auth_backends.py @@ -2176,6 +2176,55 @@ class SAMLAuthBackendTest(SocialAuthBase): settings.SOCIAL_AUTH_SAML_ENABLED_IDPS["test_idp"]["slo_url"], result["Location"] ) + @override_settings(TERMS_OF_SERVICE_VERSION=None) + def test_social_auth_sp_initiated_logout_after_desktop_registration(self) -> None: + """ + SAML SP-initiated logout relies on certain necessary information being saved + in the authenticated session that was established during SAML authentication. + The mechanism of plumbing the information to the final session through the signup process + is a bit different than the one for the simpler case of direct login to an already existing + account - thus a separate test is needed for the registration codepath. + """ + email = "newuser@zulip.com" + name = "Full Name" + subdomain = "zulip" + realm = get_realm("zulip") + desktop_flow_otp = "1234abcd" * 8 + account_data_dict = self.get_account_data_dict(email=email, name=name) + + result = self.social_auth_test( + account_data_dict, + subdomain="zulip", + expect_choose_email_screen=True, + is_signup=True, + desktop_flow_otp=desktop_flow_otp, + ) + self.stage_two_of_registration( + result, + realm, + subdomain, + email, + name, + name, + self.BACKEND_CLASS.full_name_validated, + desktop_flow_otp=desktop_flow_otp, + ) + + # Check that the SessionIndex got plumbed through to the final session + # acquired in the desktop application after signup. + session_index = self.client.session["saml_session_index"] + self.assertNotEqual(session_index, None) + + # Verify that the logout request will trigger the SAML SLO flow, + # just like in the regular case where the user simply logged in + # without needing to go through signup. + result = self.client_post("/accounts/logout/") + # A redirect to the IdP is returned. + self.assertEqual(result.status_code, 302) + self.assertIn( + settings.SOCIAL_AUTH_SAML_ENABLED_IDPS["test_idp"]["slo_url"], result["Location"] + ) + def test_saml_sp_initiated_logout_invalid_logoutresponse(self) -> None: hamlet = self.example_user("hamlet") self.login("hamlet") diff --git a/zerver/views/auth.py b/zerver/views/auth.py index 635ecc4c34..10bb941455 100644 --- a/zerver/views/auth.py +++ b/zerver/views/auth.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, from urllib.parse import urlencode import jwt +import orjson from cryptography.hazmat.primitives.ciphers.aead import AESGCM from django.conf import settings from django.contrib.auth import authenticate @@ -159,6 +160,7 @@ def maybe_send_to_registration( is_signup: bool = False, multiuse_object_key: str = "", full_name_validated: bool = False, + params_to_store_in_authenticated_session: Optional[Dict[str, str]] = None, ) -> HttpResponse: """Given a successful authentication for an email address (i.e. we've confirmed the user controls the email address) that does not @@ -199,6 +201,13 @@ def maybe_send_to_registration( desktop_flow_otp, expiry_seconds=EXPIRABLE_SESSION_VAR_DEFAULT_EXPIRY_SECS, ) + if params_to_store_in_authenticated_session: + set_expirable_session_var( + request.session, + "registration_desktop_flow_params_to_store_in_authenticated_session", + orjson.dumps(params_to_store_in_authenticated_session).decode(), + expiry_seconds=EXPIRABLE_SESSION_VAR_DEFAULT_EXPIRY_SECS, + ) try: # TODO: This should use get_realm_from_request, but a bunch of tests @@ -319,6 +328,7 @@ def register_remote_user(request: HttpRequest, result: ExternalAuthResult) -> Ht "is_signup", "multiuse_object_key", "full_name_validated", + "params_to_store_in_authenticated_session", ] for key in dict(kwargs): if key not in kwargs_to_pass: diff --git a/zerver/views/registration.py b/zerver/views/registration.py index d808c450fb..3bffdcbeba 100644 --- a/zerver/views/registration.py +++ b/zerver/views/registration.py @@ -4,6 +4,7 @@ from contextlib import suppress from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from urllib.parse import urlencode, urljoin +import orjson from django.conf import settings from django.contrib.auth import REDIRECT_FIELD_NAME, authenticate, get_backends from django.contrib.sessions.backends.base import SessionBase @@ -647,7 +648,17 @@ def login_and_go_to_home(request: HttpRequest, user_profile: UserProfile) -> Htt if mobile_flow_otp is not None: return finish_mobile_flow(request, user_profile, mobile_flow_otp) elif desktop_flow_otp is not None: - return finish_desktop_flow(request, user_profile, desktop_flow_otp) + params_to_store_in_authenticated_session = orjson.loads( + get_expirable_session_var( + request.session, + "registration_desktop_flow_params_to_store_in_authenticated_session", + default_value="{}", + delete=True, + ) + ) + return finish_desktop_flow( + request, user_profile, desktop_flow_otp, params_to_store_in_authenticated_session + ) do_login(request, user_profile) # Using 'mark_sanitized' to work around false positive where Pysa thinks