social_auth: Fix handling of user errors in the authentication process.

The code didn't account for existence of SOCIAL_AUTH_SUBDOMAIN. So the
redirects would happen to endpoints on the SOCIAL_AUTH_SUBDOMAIN, which
is incorrect. The redirects should happen to the realm from which the
user came.
This commit is contained in:
Mateusz Mandera 2021-06-26 18:51:43 +02:00 committed by Tim Abbott
parent 388932bcc4
commit 86c330b752
2 changed files with 77 additions and 32 deletions

View File

@ -168,7 +168,7 @@ class AuthBackendTest(ZulipTestCase):
if isinstance(backend, SocialAuthMixin): if isinstance(backend, SocialAuthMixin):
# Returns a redirect to login page with an error. # Returns a redirect to login page with an error.
self.assertEqual(result.status_code, 302) self.assertEqual(result.status_code, 302)
self.assertEqual(result.url, "/login/?is_deactivated=true") self.assertEqual(result.url, user_profile.realm.uri + "/login/?is_deactivated=true")
else: else:
# Just takes you back to the login page treating as # Just takes you back to the login page treating as
# invalid auth; this is correct because the form will # invalid auth; this is correct because the form will
@ -183,7 +183,12 @@ class AuthBackendTest(ZulipTestCase):
# Verify auth fails with a deactivated realm # Verify auth fails with a deactivated realm
do_deactivate_realm(user_profile.realm, acting_user=None) do_deactivate_realm(user_profile.realm, acting_user=None)
self.assertIsNone(backend.authenticate(**good_kwargs)) result = backend.authenticate(**good_kwargs)
if isinstance(backend, SocialAuthMixin):
self.assertEqual(result.status_code, 302)
self.assertEqual(result.url, user_profile.realm.uri + "/login/")
else:
self.assertIsNone(result)
# Verify auth works again after reactivating the realm # Verify auth works again after reactivating the realm
do_reactivate_realm(user_profile.realm) do_reactivate_realm(user_profile.realm)
@ -198,7 +203,12 @@ class AuthBackendTest(ZulipTestCase):
# Verify auth fails if the auth backend is disabled on server # Verify auth fails if the auth backend is disabled on server
with self.settings(AUTHENTICATION_BACKENDS=("zproject.backends.ZulipDummyBackend",)): with self.settings(AUTHENTICATION_BACKENDS=("zproject.backends.ZulipDummyBackend",)):
clear_supported_auth_backends_cache() clear_supported_auth_backends_cache()
self.assertIsNone(backend.authenticate(**good_kwargs)) result = backend.authenticate(**good_kwargs)
if isinstance(backend, SocialAuthMixin):
self.assertEqual(result.status_code, 302)
self.assertEqual(result.url, user_profile.realm.uri + "/login/")
else:
self.assertIsNone(result)
clear_supported_auth_backends_cache() clear_supported_auth_backends_cache()
# Verify auth fails if the auth backend is disabled for the realm # Verify auth fails if the auth backend is disabled for the realm
@ -217,7 +227,12 @@ class AuthBackendTest(ZulipTestCase):
# propagate the changes we just made to the actual realm # propagate the changes we just made to the actual realm
# object in good_kwargs. # object in good_kwargs.
good_kwargs["realm"] = user_profile.realm good_kwargs["realm"] = user_profile.realm
self.assertIsNone(backend.authenticate(**good_kwargs))
if isinstance(backend, SocialAuthMixin):
self.assertEqual(result.status_code, 302)
self.assertEqual(result.url, user_profile.realm.uri + "/login/")
else:
self.assertIsNone(result)
user_profile.realm.authentication_methods.set_bit(index, True) user_profile.realm.authentication_methods.set_bit(index, True)
user_profile.realm.save() user_profile.realm.save()
@ -1084,7 +1099,7 @@ class SocialAuthBase(DesktopFlowTestingLib, ZulipTestCase):
account_data_dict, expect_choose_email_screen=True, subdomain="zulip" account_data_dict, expect_choose_email_screen=True, subdomain="zulip"
) )
self.assertEqual(result.status_code, 302) self.assertEqual(result.status_code, 302)
self.assertEqual(result.url, "/login/?is_deactivated=true") self.assertEqual(result.url, user_profile.realm.uri + "/login/?is_deactivated=true")
self.assertEqual( self.assertEqual(
m.output, m.output,
[ [
@ -1108,15 +1123,15 @@ class SocialAuthBase(DesktopFlowTestingLib, ZulipTestCase):
def test_social_auth_invalid_email(self) -> None: def test_social_auth_invalid_email(self) -> None:
account_data_dict = self.get_account_data_dict(email="invalid", name=self.name) account_data_dict = self.get_account_data_dict(email="invalid", name=self.name)
subdomain = "zulip"
realm = get_realm(subdomain)
with self.assertLogs(self.logger_string, level="INFO") as m: with self.assertLogs(self.logger_string, level="INFO") as m:
result = self.social_auth_test( result = self.social_auth_test(
account_data_dict, account_data_dict,
expect_choose_email_screen=True, expect_choose_email_screen=True,
subdomain="zulip", subdomain=subdomain,
next="/user_uploads/image", next="/user_uploads/image",
) )
self.assertEqual(result.status_code, 302)
self.assertEqual(result.url, "/login/?next=/user_uploads/image")
self.assertEqual( self.assertEqual(
m.output, m.output,
[ [
@ -1126,6 +1141,8 @@ class SocialAuthBase(DesktopFlowTestingLib, ZulipTestCase):
) )
], ],
) )
self.assertEqual(result.status_code, 302)
self.assertEqual(result.url, realm.uri + "/register/")
def test_user_cannot_log_into_nonexisting_realm(self) -> None: def test_user_cannot_log_into_nonexisting_realm(self) -> None:
account_data_dict = self.get_account_data_dict(email=self.email, name=self.name) account_data_dict = self.get_account_data_dict(email=self.email, name=self.name)
@ -2204,18 +2221,20 @@ class SAMLAuthBackendTest(SocialAuthBase):
if the authentication attempt failed. See SAMLAuthBackend.auth_complete for details. if the authentication attempt failed. See SAMLAuthBackend.auth_complete for details.
""" """
account_data_dict = self.get_account_data_dict(email="invalid", name=self.name) account_data_dict = self.get_account_data_dict(email="invalid", name=self.name)
subdomain = "zulip"
realm = get_realm(subdomain)
with self.assertLogs(self.logger_string, "WARNING") as warn_log: with self.assertLogs(self.logger_string, "WARNING") as warn_log:
result = self.social_auth_test( result = self.social_auth_test(
account_data_dict, account_data_dict,
expect_choose_email_screen=True, expect_choose_email_screen=True,
subdomain="zulip", subdomain=subdomain,
next="/user_uploads/image", next="/user_uploads/image",
) )
self.assertEqual( self.assertEqual(
warn_log.output, [self.logger_output("SAML got invalid email argument.", "warning")] warn_log.output, [self.logger_output("SAML got invalid email argument.", "warning")]
) )
self.assertEqual(result.status_code, 302) self.assertEqual(result.status_code, 302)
self.assertEqual(result.url, "/login/") self.assertEqual(result.url, realm.uri + "/register/")
def test_social_auth_saml_multiple_idps_configured(self) -> None: def test_social_auth_saml_multiple_idps_configured(self) -> None:
# Setup a new SOCIAL_AUTH_SAML_ENABLED_IDPS dict with two idps. # Setup a new SOCIAL_AUTH_SAML_ENABLED_IDPS dict with two idps.
@ -3292,12 +3311,14 @@ class GitHubAuthBackendTest(SocialAuthBase):
email_data = [ email_data = [
dict(email=account_data_dict["email"], verified=False, primary=True), dict(email=account_data_dict["email"], verified=False, primary=True),
] ]
subdomain = "zulip"
realm = get_realm(subdomain)
with self.assertLogs(self.logger_string, level="WARNING") as m: with self.assertLogs(self.logger_string, level="WARNING") as m:
result = self.social_auth_test( result = self.social_auth_test(
account_data_dict, subdomain="zulip", email_data=email_data account_data_dict, subdomain=subdomain, email_data=email_data
) )
self.assertEqual(result.status_code, 302) self.assertEqual(result.status_code, 302)
self.assertEqual(result.url, "/login/") self.assertEqual(result.url, realm.uri + "/login/")
self.assertEqual( self.assertEqual(
m.output, m.output,
[ [
@ -3311,13 +3332,15 @@ class GitHubAuthBackendTest(SocialAuthBase):
@override_settings(SOCIAL_AUTH_GITHUB_TEAM_ID="zulip-webapp") @override_settings(SOCIAL_AUTH_GITHUB_TEAM_ID="zulip-webapp")
def test_social_auth_github_team_not_member_failed(self) -> None: def test_social_auth_github_team_not_member_failed(self) -> None:
account_data_dict = self.get_account_data_dict(email=self.email, name=self.name) account_data_dict = self.get_account_data_dict(email=self.email, name=self.name)
subdomain = "zulip"
realm = get_realm(subdomain)
with mock.patch( with mock.patch(
"social_core.backends.github.GithubTeamOAuth2.user_data", "social_core.backends.github.GithubTeamOAuth2.user_data",
side_effect=AuthFailed("Not found"), side_effect=AuthFailed("Not found"),
), self.assertLogs(self.logger_string, level="INFO") as mock_info: ), self.assertLogs(self.logger_string, level="INFO") as mock_info:
result = self.social_auth_test(account_data_dict, subdomain="zulip") result = self.social_auth_test(account_data_dict, subdomain=subdomain)
self.assertEqual(result.status_code, 302) self.assertEqual(result.status_code, 302)
self.assertEqual(result.url, "/login/") self.assertEqual(result.url, realm.uri + "/login/")
self.assertEqual( self.assertEqual(
mock_info.output, mock_info.output,
[ [
@ -3345,13 +3368,15 @@ class GitHubAuthBackendTest(SocialAuthBase):
@override_settings(SOCIAL_AUTH_GITHUB_ORG_NAME="Zulip") @override_settings(SOCIAL_AUTH_GITHUB_ORG_NAME="Zulip")
def test_social_auth_github_organization_not_member_failed(self) -> None: def test_social_auth_github_organization_not_member_failed(self) -> None:
account_data_dict = self.get_account_data_dict(email=self.email, name=self.name) account_data_dict = self.get_account_data_dict(email=self.email, name=self.name)
subdomain = "zulip"
realm = get_realm(subdomain)
with mock.patch( with mock.patch(
"social_core.backends.github.GithubOrganizationOAuth2.user_data", "social_core.backends.github.GithubOrganizationOAuth2.user_data",
side_effect=AuthFailed("Not found"), side_effect=AuthFailed("Not found"),
), self.assertLogs(self.logger_string, level="INFO") as mock_info: ), self.assertLogs(self.logger_string, level="INFO") as mock_info:
result = self.social_auth_test(account_data_dict, subdomain="zulip") result = self.social_auth_test(account_data_dict, subdomain=subdomain)
self.assertEqual(result.status_code, 302) self.assertEqual(result.status_code, 302)
self.assertEqual(result.url, "/login/") self.assertEqual(result.url, realm.uri + "/login/")
self.assertEqual( self.assertEqual(
mock_info.output, mock_info.output,
[ [
@ -3591,15 +3616,17 @@ class GitHubAuthBackendTest(SocialAuthBase):
dict(email="aaron@zulip.com", verified=True), dict(email="aaron@zulip.com", verified=True),
dict(email=account_data_dict["email"], verified=True), dict(email=account_data_dict["email"], verified=True),
] ]
subdomain = "zulip"
realm = get_realm(subdomain)
with self.assertLogs(self.logger_string, level="WARNING") as m: with self.assertLogs(self.logger_string, level="WARNING") as m:
result = self.social_auth_test( result = self.social_auth_test(
account_data_dict, account_data_dict,
subdomain="zulip", subdomain=subdomain,
expect_choose_email_screen=True, expect_choose_email_screen=True,
email_data=email_data, email_data=email_data,
) )
self.assertEqual(result.status_code, 302) self.assertEqual(result.status_code, 302)
self.assertEqual(result.url, "/login/") self.assertEqual(result.url, realm.uri + "/login/")
self.assertEqual( self.assertEqual(
m.output, m.output,
[ [
@ -3642,15 +3669,17 @@ class GitHubAuthBackendTest(SocialAuthBase):
dict(email="hamlet@zulip.com", verified=True, primary=True), dict(email="hamlet@zulip.com", verified=True, primary=True),
dict(email="aaron@zulip.com", verified=True), dict(email="aaron@zulip.com", verified=True),
] ]
subdomain = "zulip"
realm = get_realm(subdomain)
with self.assertLogs(self.logger_string, level="WARNING") as m: with self.assertLogs(self.logger_string, level="WARNING") as m:
result = self.social_auth_test( result = self.social_auth_test(
account_data_dict, account_data_dict,
subdomain="zulip", subdomain=subdomain,
expect_choose_email_screen=True, expect_choose_email_screen=True,
email_data=email_data, email_data=email_data,
) )
self.assertEqual(result.status_code, 302) self.assertEqual(result.status_code, 302)
self.assertEqual(result.url, "/login/") self.assertEqual(result.url, realm.uri + "/login/")
self.assertEqual( self.assertEqual(
m.output, m.output,
[ [
@ -3666,7 +3695,9 @@ class GitHubAuthBackendTest(SocialAuthBase):
# check if a user is denied to log in if the user manages to # check if a user is denied to log in if the user manages to
# send an unverified email that has an existing account in # send an unverified email that has an existing account in
# organisation through `email` GET parameter. # organisation through `email` GET parameter.
account_data_dict = self.get_account_data_dict(email="hamlet@zulip.com", name=self.name) subdomain = "zulip"
realm = get_realm(subdomain)
account_data_dict = dict(email="hamlet@zulip.com", name=self.name)
email_data = [ email_data = [
dict(email="iago@zulip.com", verified=True), dict(email="iago@zulip.com", verified=True),
dict(email="hamlet@zulip.com", verified=False), dict(email="hamlet@zulip.com", verified=False),
@ -3675,12 +3706,12 @@ class GitHubAuthBackendTest(SocialAuthBase):
with self.assertLogs(self.logger_string, level="WARNING") as m: with self.assertLogs(self.logger_string, level="WARNING") as m:
result = self.social_auth_test( result = self.social_auth_test(
account_data_dict, account_data_dict,
subdomain="zulip", subdomain=subdomain,
expect_choose_email_screen=True, expect_choose_email_screen=True,
email_data=email_data, email_data=email_data,
) )
self.assertEqual(result.status_code, 302) self.assertEqual(result.status_code, 302)
self.assertEqual(result.url, "/login/") self.assertEqual(result.url, realm.uri + "/login/")
self.assertEqual( self.assertEqual(
m.output, m.output,
[ [
@ -3735,10 +3766,12 @@ class GoogleAuthBackendTest(SocialAuthBase):
def test_social_auth_email_not_verified(self) -> None: def test_social_auth_email_not_verified(self) -> None:
account_data_dict = dict(email=self.email, name=self.name) account_data_dict = dict(email=self.email, name=self.name)
subdomain = "zulip"
realm = get_realm(subdomain)
with self.assertLogs(self.logger_string, level="WARNING") as m: with self.assertLogs(self.logger_string, level="WARNING") as m:
result = self.social_auth_test(account_data_dict, subdomain="zulip") result = self.social_auth_test(account_data_dict, subdomain=subdomain)
self.assertEqual(result.status_code, 302) self.assertEqual(result.status_code, 302)
self.assertEqual(result.url, "/login/") self.assertEqual(result.url, realm.uri + "/login/")
self.assertEqual( self.assertEqual(
m.output, m.output,
[ [

View File

@ -1296,11 +1296,23 @@ class ZulipRemoteUserBackend(RemoteUserBackend, ExternalAuthMethod):
] ]
def redirect_deactivated_user_to_login() -> HttpResponseRedirect: def redirect_to_signup(realm: Realm) -> HttpResponseRedirect:
signup_url = reverse("register")
redirect_url = realm.uri + signup_url
return HttpResponseRedirect(redirect_url)
def redirect_to_login(realm: Realm) -> HttpResponseRedirect:
login_url = reverse("login_page", kwargs={"template_name": "zerver/login.html"})
redirect_url = realm.uri + login_url
return HttpResponseRedirect(redirect_url)
def redirect_deactivated_user_to_login(realm: Realm) -> HttpResponseRedirect:
# Specifying the template name makes sure that the user is not redirected to dev_login in case of # Specifying the template name makes sure that the user is not redirected to dev_login in case of
# a deactivated account on a test server. # a deactivated account on a test server.
login_url = reverse("login_page", kwargs={"template_name": "zerver/login.html"}) login_url = reverse("login_page", kwargs={"template_name": "zerver/login.html"})
redirect_url = login_url + "?is_deactivated=true" redirect_url = realm.uri + login_url + "?is_deactivated=true"
return HttpResponseRedirect(redirect_url) return HttpResponseRedirect(redirect_url)
@ -1508,18 +1520,19 @@ def social_auth_finish(
# form on. # form on.
return HttpResponseRedirect(reverse("find_account")) return HttpResponseRedirect(reverse("find_account"))
realm = Realm.objects.get(id=return_data["realm_id"])
if inactive_user: if inactive_user:
backend.logger.info( backend.logger.info(
"Failed login attempt for deactivated account: %s@%s", "Failed login attempt for deactivated account: %s@%s",
return_data["inactive_user_id"], return_data["inactive_user_id"],
return_data["realm_string_id"], return_data["realm_string_id"],
) )
return redirect_deactivated_user_to_login() return redirect_deactivated_user_to_login(realm)
if auth_backend_disabled or inactive_realm or no_verified_email or email_not_associated: if auth_backend_disabled or inactive_realm or no_verified_email or email_not_associated:
# Redirect to login page. We can't send to registration # Redirect to login page. We can't send to registration
# workflow with these errors. We will redirect to login page. # workflow with these errors. We will redirect to login page.
return None return redirect_to_login(realm)
if invalid_email: if invalid_email:
# In case of invalid email, we will end up on registration page. # In case of invalid email, we will end up on registration page.
@ -1528,11 +1541,11 @@ def social_auth_finish(
"%s got invalid email argument.", "%s got invalid email argument.",
backend.auth_backend_name, backend.auth_backend_name,
) )
return None return redirect_to_signup(realm)
if auth_failed_reason: if auth_failed_reason:
backend.logger.info(auth_failed_reason) backend.logger.info(auth_failed_reason)
return None return redirect_to_login(realm)
# Structurally, all the cases where we don't have an authenticated # Structurally, all the cases where we don't have an authenticated
# email for the user should be handled above; this assertion helps # email for the user should be handled above; this assertion helps
@ -1545,7 +1558,6 @@ def social_auth_finish(
email_address = return_data["validated_email"] email_address = return_data["validated_email"]
full_name = return_data["full_name"] full_name = return_data["full_name"]
redirect_to = strategy.session_get("next") redirect_to = strategy.session_get("next")
realm = Realm.objects.get(id=return_data["realm_id"])
multiuse_object_key = strategy.session_get("multiuse_object_key", "") multiuse_object_key = strategy.session_get("multiuse_object_key", "")
mobile_flow_otp = strategy.session_get("mobile_flow_otp") mobile_flow_otp = strategy.session_get("mobile_flow_otp")