invites: Lock the realm when determining invitation counts.

This prevents users from hammering the invitation endpoint, causing
races, and inviting more users than they should otherwise be allowed
to.

Doing this requires that we not raise InvitationError when we have
partially succeeded; that behaviour is left to the one callsite of
do_invite_users.

Reported by Lakshit Agarwal (@chiekosec).
This commit is contained in:
Alex Vandiver 2024-01-10 21:01:21 +00:00 committed by Tim Abbott
parent eef5d22944
commit d863aa56de
12 changed files with 263 additions and 206 deletions

View File

@ -1,5 +1,6 @@
from contextlib import AbstractContextManager, ExitStack, contextmanager
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple, Type from typing import Any, Dict, Iterator, List, Optional, Tuple, Type
from unittest import mock from unittest import mock
import orjson import orjson
@ -1640,6 +1641,23 @@ class TestLoggingCountStats(AnalyticsTestCase):
def test_invites_sent(self) -> None: def test_invites_sent(self) -> None:
property = "invites_sent::day" property = "invites_sent::day"
@contextmanager
def invite_context(
too_many_recent_realm_invites: bool = False, failure: bool = False
) -> Iterator[None]:
managers: List[AbstractContextManager[Any]] = [
mock.patch(
"zerver.actions.invites.too_many_recent_realm_invites", return_value=False
),
self.captureOnCommitCallbacks(execute=True),
]
if failure:
managers.append(self.assertRaises(InvitationError))
with ExitStack() as stack:
for mgr in managers:
stack.enter_context(mgr)
yield
def assertInviteCountEquals(count: int) -> None: def assertInviteCountEquals(count: int) -> None:
self.assertEqual( self.assertEqual(
count, count,
@ -1652,7 +1670,7 @@ class TestLoggingCountStats(AnalyticsTestCase):
stream, _ = self.create_stream_with_recipient() stream, _ = self.create_stream_with_recipient()
invite_expires_in_minutes = 2 * 24 * 60 invite_expires_in_minutes = 2 * 24 * 60
with mock.patch("zerver.actions.invites.too_many_recent_realm_invites", return_value=False): with invite_context():
do_invite_users( do_invite_users(
user, user,
["user1@domain.tld", "user2@domain.tld"], ["user1@domain.tld", "user2@domain.tld"],
@ -1663,7 +1681,7 @@ class TestLoggingCountStats(AnalyticsTestCase):
# We currently send emails when re-inviting users that haven't # We currently send emails when re-inviting users that haven't
# turned into accounts, so count them towards the total # turned into accounts, so count them towards the total
with mock.patch("zerver.actions.invites.too_many_recent_realm_invites", return_value=False): with invite_context():
do_invite_users( do_invite_users(
user, user,
["user1@domain.tld", "user2@domain.tld"], ["user1@domain.tld", "user2@domain.tld"],
@ -1673,9 +1691,7 @@ class TestLoggingCountStats(AnalyticsTestCase):
assertInviteCountEquals(4) assertInviteCountEquals(4)
# Test mix of good and malformed invite emails # Test mix of good and malformed invite emails
with self.assertRaises(InvitationError), mock.patch( with invite_context(failure=True):
"zerver.actions.invites.too_many_recent_realm_invites", return_value=False
):
do_invite_users( do_invite_users(
user, user,
["user3@domain.tld", "malformed"], ["user3@domain.tld", "malformed"],
@ -1685,15 +1701,14 @@ class TestLoggingCountStats(AnalyticsTestCase):
assertInviteCountEquals(4) assertInviteCountEquals(4)
# Test inviting existing users # Test inviting existing users
with self.assertRaises(InvitationError), mock.patch( with invite_context():
"zerver.actions.invites.too_many_recent_realm_invites", return_value=False skipped = do_invite_users(
):
do_invite_users(
user, user,
["first@domain.tld", "user4@domain.tld"], ["first@domain.tld", "user4@domain.tld"],
[stream], [stream],
invite_expires_in_minutes=invite_expires_in_minutes, invite_expires_in_minutes=invite_expires_in_minutes,
) )
self.assert_length(skipped, 1)
assertInviteCountEquals(5) assertInviteCountEquals(5)
# Revoking invite should not give you credit # Revoking invite should not give you credit
@ -1703,7 +1718,7 @@ class TestLoggingCountStats(AnalyticsTestCase):
assertInviteCountEquals(5) assertInviteCountEquals(5)
# Resending invite should cost you # Resending invite should cost you
with mock.patch("zerver.actions.invites.too_many_recent_realm_invites", return_value=False): with invite_context():
do_resend_user_invite_email(assert_is_not_none(PreregistrationUser.objects.first())) do_resend_user_invite_email(assert_is_not_none(PreregistrationUser.objects.first()))
assertInviteCountEquals(6) assertInviteCountEquals(6)

View File

@ -709,16 +709,17 @@ class TestSupportEndpoint(ZulipTestCase):
stream: str, invitee_email: str, realm: Optional[Realm] = None stream: str, invitee_email: str, realm: Optional[Realm] = None
) -> None: ) -> None:
invite_expires_in_minutes = 10 * 24 * 60 invite_expires_in_minutes = 10 * 24 * 60
self.client_post( with self.captureOnCommitCallbacks(execute=True):
"/json/invites", self.client_post(
{ "/json/invites",
"invitee_emails": [invitee_email], {
"stream_ids": orjson.dumps([self.get_stream_id(stream, realm)]).decode(), "invitee_emails": [invitee_email],
"invite_expires_in_minutes": invite_expires_in_minutes, "stream_ids": orjson.dumps([self.get_stream_id(stream, realm)]).decode(),
"invite_as": PreregistrationUser.INVITE_AS["MEMBER"], "invite_expires_in_minutes": invite_expires_in_minutes,
}, "invite_as": PreregistrationUser.INVITE_AS["MEMBER"],
subdomain=realm.string_id if realm is not None else "zulip", },
) subdomain=realm.string_id if realm is not None else "zulip",
)
def check_hamlet_user_query_result(result: "TestHttpResponse") -> None: def check_hamlet_user_query_result(result: "TestHttpResponse") -> None:
assert_user_details_in_html_response( assert_user_details_in_html_response(

View File

@ -21,10 +21,11 @@ from zerver.lib.email_validation import (
) )
from zerver.lib.exceptions import InvitationError from zerver.lib.exceptions import InvitationError
from zerver.lib.invites import notify_invites_changed from zerver.lib.invites import notify_invites_changed
from zerver.lib.queue import queue_json_publish from zerver.lib.queue import queue_event_on_commit
from zerver.lib.send_email import FromAddress, clear_scheduled_invitation_emails, send_email from zerver.lib.send_email import FromAddress, clear_scheduled_invitation_emails, send_email
from zerver.lib.timestamp import datetime_to_timestamp from zerver.lib.timestamp import datetime_to_timestamp
from zerver.lib.types import UnspecifiedValue from zerver.lib.types import UnspecifiedValue
from zerver.lib.utils import assert_is_not_none
from zerver.models import Message, MultiuseInvite, PreregistrationUser, Realm, Stream, UserProfile from zerver.models import Message, MultiuseInvite, PreregistrationUser, Realm, Stream, UserProfile
from zerver.models.prereg_users import filter_to_valid_prereg_users from zerver.models.prereg_users import filter_to_valid_prereg_users
@ -193,6 +194,7 @@ def check_invite_limit(realm: Realm, num_invitees: int) -> None:
) )
@transaction.atomic
def do_invite_users( def do_invite_users(
user_profile: UserProfile, user_profile: UserProfile,
invitee_emails: Collection[str], invitee_emails: Collection[str],
@ -200,23 +202,25 @@ def do_invite_users(
*, *,
invite_expires_in_minutes: Optional[int], invite_expires_in_minutes: Optional[int],
invite_as: int = PreregistrationUser.INVITE_AS["MEMBER"], invite_as: int = PreregistrationUser.INVITE_AS["MEMBER"],
) -> None: ) -> List[Tuple[str, str, bool]]:
num_invites = len(invitee_emails) num_invites = len(invitee_emails)
check_invite_limit(user_profile.realm, num_invites) # Lock the realm, since we need to not race with other invitations
realm = Realm.objects.select_for_update().get(id=user_profile.realm_id)
check_invite_limit(realm, num_invites)
if settings.BILLING_ENABLED: if settings.BILLING_ENABLED:
from corporate.lib.registration import check_spare_licenses_available_for_inviting_new_users from corporate.lib.registration import check_spare_licenses_available_for_inviting_new_users
if invite_as == PreregistrationUser.INVITE_AS["GUEST_USER"]: if invite_as == PreregistrationUser.INVITE_AS["GUEST_USER"]:
check_spare_licenses_available_for_inviting_new_users( check_spare_licenses_available_for_inviting_new_users(
user_profile.realm, extra_guests_count=num_invites realm, extra_guests_count=num_invites
) )
else: else:
check_spare_licenses_available_for_inviting_new_users( check_spare_licenses_available_for_inviting_new_users(
user_profile.realm, extra_non_guests_count=num_invites realm, extra_non_guests_count=num_invites
) )
realm = user_profile.realm
if not realm.invite_required: if not realm.invite_required:
# Inhibit joining an open realm to send spam invitations. # Inhibit joining an open realm to send spam invitations.
min_age = timedelta(days=settings.INVITES_MIN_USER_AGE_DAYS) min_age = timedelta(days=settings.INVITES_MIN_USER_AGE_DAYS)
@ -232,7 +236,7 @@ def do_invite_users(
good_emails: Set[str] = set() good_emails: Set[str] = set()
errors: List[Tuple[str, str, bool]] = [] errors: List[Tuple[str, str, bool]] = []
validate_email_allowed_in_realm = get_realm_email_validator(user_profile.realm) validate_email_allowed_in_realm = get_realm_email_validator(realm)
for email in invitee_emails: for email in invitee_emails:
if email == "": if email == "":
continue continue
@ -251,7 +255,7 @@ def do_invite_users(
but we still need to make sure they're not but we still need to make sure they're not
gonna conflict with existing users gonna conflict with existing users
""" """
error_dict = get_existing_user_errors(user_profile.realm, good_emails) error_dict = get_existing_user_errors(realm, good_emails)
skipped: List[Tuple[str, str, bool]] = [] skipped: List[Tuple[str, str, bool]] = []
for email in error_dict: for email in error_dict:
@ -278,7 +282,7 @@ def do_invite_users(
# is used for rate limiting invitations, rather than keeping track of # is used for rate limiting invitations, rather than keeping track of
# when exactly invitations were sent # when exactly invitations were sent
do_increment_logging_stat( do_increment_logging_stat(
user_profile.realm, realm,
COUNT_STATS["invites_sent::day"], COUNT_STATS["invites_sent::day"],
None, None,
timezone_now(), timezone_now(),
@ -290,7 +294,7 @@ def do_invite_users(
for email in validated_emails: for email in validated_emails:
# The logged in user is the referrer. # The logged in user is the referrer.
prereg_user = PreregistrationUser( prereg_user = PreregistrationUser(
email=email, referred_by=user_profile, invited_as=invite_as, realm=user_profile.realm email=email, referred_by=user_profile, invited_as=invite_as, realm=realm
) )
prereg_user.save() prereg_user.save()
stream_ids = [stream.id for stream in streams] stream_ids = [stream.id for stream in streams]
@ -299,22 +303,14 @@ def do_invite_users(
event = { event = {
"prereg_id": prereg_user.id, "prereg_id": prereg_user.id,
"referrer_id": user_profile.id, "referrer_id": user_profile.id,
"email_language": user_profile.realm.default_language, "email_language": realm.default_language,
"invite_expires_in_minutes": invite_expires_in_minutes, "invite_expires_in_minutes": invite_expires_in_minutes,
} }
queue_json_publish("invites", event) queue_event_on_commit("invites", event)
if skipped: notify_invites_changed(realm, changed_invite_referrer=user_profile)
raise InvitationError(
_( return skipped
"Some of those addresses are already using Zulip, "
"so we didn't send them an invitation. We did send "
"invitations to everyone else!"
),
skipped,
sent_invitations=True,
)
notify_invites_changed(user_profile.realm, changed_invite_referrer=user_profile)
def get_invitation_expiry_date(confirmation_obj: Confirmation) -> Optional[int]: def get_invitation_expiry_date(confirmation_obj: Confirmation) -> Optional[int]:
@ -441,12 +437,14 @@ def do_revoke_multi_use_invite(multiuse_invite: MultiuseInvite) -> None:
notify_invites_changed(realm, changed_invite_referrer=multiuse_invite.referred_by) notify_invites_changed(realm, changed_invite_referrer=multiuse_invite.referred_by)
@transaction.atomic
def do_resend_user_invite_email(prereg_user: PreregistrationUser) -> int: def do_resend_user_invite_email(prereg_user: PreregistrationUser) -> int:
# These are two structurally for the caller's code path. # Take a lock on the realm, so we can check for invitation limits without races
assert prereg_user.referred_by is not None realm_id = assert_is_not_none(prereg_user.realm_id)
assert prereg_user.realm is not None realm = Realm.objects.select_for_update().get(id=realm_id)
check_invite_limit(realm, 1)
check_invite_limit(prereg_user.referred_by.realm, 1) assert prereg_user.referred_by is not None
prereg_user.invited_at = timezone_now() prereg_user.invited_at = timezone_now()
prereg_user.save() prereg_user.save()
@ -460,18 +458,16 @@ def do_resend_user_invite_email(prereg_user: PreregistrationUser) -> int:
invite_expires_in_minutes = (expiry_date - prereg_user.invited_at).total_seconds() / 60 invite_expires_in_minutes = (expiry_date - prereg_user.invited_at).total_seconds() / 60
prereg_user.confirmation.clear() prereg_user.confirmation.clear()
do_increment_logging_stat( do_increment_logging_stat(realm, COUNT_STATS["invites_sent::day"], None, prereg_user.invited_at)
prereg_user.realm, COUNT_STATS["invites_sent::day"], None, prereg_user.invited_at
)
clear_scheduled_invitation_emails(prereg_user.email) clear_scheduled_invitation_emails(prereg_user.email)
# We don't store the custom email body, so just set it to None # We don't store the custom email body, so just set it to None
event = { event = {
"prereg_id": prereg_user.id, "prereg_id": prereg_user.id,
"referrer_id": prereg_user.referred_by.id, "referrer_id": prereg_user.referred_by.id,
"email_language": prereg_user.referred_by.realm.default_language, "email_language": realm.default_language,
"invite_expires_in_minutes": invite_expires_in_minutes, "invite_expires_in_minutes": invite_expires_in_minutes,
} }
queue_json_publish("invites", event) queue_event_on_commit("invites", event)
return datetime_to_timestamp(prereg_user.invited_at) return datetime_to_timestamp(prereg_user.invited_at)

View File

@ -1541,7 +1541,8 @@ class SocialAuthBase(DesktopFlowTestingLib, ZulipTestCase, ABC):
realm = get_realm("zulip") realm = get_realm("zulip")
iago = self.example_user("iago") iago = self.example_user("iago")
do_invite_users(iago, [email], [], invite_expires_in_minutes=2 * 24 * 60) with self.captureOnCommitCallbacks(execute=True):
do_invite_users(iago, [email], [], invite_expires_in_minutes=2 * 24 * 60)
account_data_dict = self.get_account_data_dict(email=email, name=name) account_data_dict = self.get_account_data_dict(email=email, name=name)
result = self.social_auth_test( result = self.social_auth_test(
@ -1883,13 +1884,14 @@ class SocialAuthBase(DesktopFlowTestingLib, ZulipTestCase, ABC):
name = "Alice Jones" name = "Alice Jones"
invite_expires_in_minutes = 2 * 24 * 60 invite_expires_in_minutes = 2 * 24 * 60
do_invite_users( with self.captureOnCommitCallbacks(execute=True):
iago, do_invite_users(
[email], iago,
[], [email],
invite_expires_in_minutes=invite_expires_in_minutes, [],
invite_as=PreregistrationUser.INVITE_AS["REALM_ADMIN"], invite_expires_in_minutes=invite_expires_in_minutes,
) invite_as=PreregistrationUser.INVITE_AS["REALM_ADMIN"],
)
now = timezone_now() + timedelta(days=3) now = timezone_now() + timedelta(days=3)
subdomain = "zulip" subdomain = "zulip"

View File

@ -12,7 +12,8 @@ class EmailLogTest(ZulipTestCase):
with self.settings(EMAIL_BACKEND="zproject.email_backends.EmailLogBackEnd"), mock.patch( with self.settings(EMAIL_BACKEND="zproject.email_backends.EmailLogBackEnd"), mock.patch(
"zproject.email_backends.EmailLogBackEnd._do_send_messages", lambda *args: 1 "zproject.email_backends.EmailLogBackEnd._do_send_messages", lambda *args: 1
), self.assertLogs(level="INFO") as m, self.settings(DEVELOPMENT_LOG_EMAILS=True): ), self.assertLogs(level="INFO") as m, self.settings(DEVELOPMENT_LOG_EMAILS=True):
result = self.client_get("/emails/generate/") with self.captureOnCommitCallbacks(execute=True):
result = self.client_get("/emails/generate/")
self.assertEqual(result.status_code, 302) self.assertEqual(result.status_code, 302)
self.assertIn("emails", result["Location"]) self.assertIn("emails", result["Location"])

View File

@ -1138,12 +1138,13 @@ class NormalActionsTest(BaseAction):
self.user_profile = self.example_user("iago") self.user_profile = self.example_user("iago")
user_profile = self.example_user("cordelia") user_profile = self.example_user("cordelia")
invite_expires_in_minutes = 2 * 24 * 60 invite_expires_in_minutes = 2 * 24 * 60
do_invite_users( with self.captureOnCommitCallbacks(execute=True):
user_profile, do_invite_users(
["foo@zulip.com"], user_profile,
[], ["foo@zulip.com"],
invite_expires_in_minutes=invite_expires_in_minutes, [],
) invite_expires_in_minutes=invite_expires_in_minutes,
)
with self.verify_action(num_events=2) as events: with self.verify_action(num_events=2) as events:
do_deactivate_user(user_profile, acting_user=None) do_deactivate_user(user_profile, acting_user=None)
@ -1159,12 +1160,13 @@ class NormalActionsTest(BaseAction):
] ]
invite_expires_in_minutes = 2 * 24 * 60 invite_expires_in_minutes = 2 * 24 * 60
do_invite_users( with self.captureOnCommitCallbacks(execute=True):
self.user_profile, do_invite_users(
["foo@zulip.com"], self.user_profile,
streams, ["foo@zulip.com"],
invite_expires_in_minutes=invite_expires_in_minutes, streams,
) invite_expires_in_minutes=invite_expires_in_minutes,
)
prereg_users = PreregistrationUser.objects.filter( prereg_users = PreregistrationUser.objects.filter(
referred_by__realm=self.user_profile.realm referred_by__realm=self.user_profile.realm
) )
@ -1202,12 +1204,13 @@ class NormalActionsTest(BaseAction):
] ]
invite_expires_in_minutes = 2 * 24 * 60 invite_expires_in_minutes = 2 * 24 * 60
do_invite_users( with self.captureOnCommitCallbacks(execute=True):
self.user_profile, do_invite_users(
["foo@zulip.com"], self.user_profile,
streams, ["foo@zulip.com"],
invite_expires_in_minutes=invite_expires_in_minutes, streams,
) invite_expires_in_minutes=invite_expires_in_minutes,
)
prereg_user = PreregistrationUser.objects.get(email="foo@zulip.com") prereg_user = PreregistrationUser.objects.get(email="foo@zulip.com")
with self.verify_action(state_change_expected=True, num_events=7) as events: with self.verify_action(state_change_expected=True, num_events=7) as events:
@ -1220,7 +1223,7 @@ class NormalActionsTest(BaseAction):
acting_user=None, acting_user=None,
) )
check_invites_changed("events[1]", events[1]) check_invites_changed("events[6]", events[6])
def test_typing_events(self) -> None: def test_typing_events(self) -> None:
with self.verify_action(state_change_expected=False) as events: with self.verify_action(state_change_expected=False) as events:

View File

@ -395,9 +395,16 @@ class TestDevelopmentEmailsLog(ZulipTestCase):
), self.assertLogs(level="INFO") as logger, mock.patch( ), self.assertLogs(level="INFO") as logger, mock.patch(
"zproject.email_backends.EmailLogBackEnd._do_send_messages", lambda *args: 1 "zproject.email_backends.EmailLogBackEnd._do_send_messages", lambda *args: 1
): ):
result = self.client_get( # Parts of this endpoint use transactions, and use
"/emails/generate/" # transaction.on_commit to run code when the transaction
) # Generates emails and redirects to /emails/ # commits. Tests are run inside one big outer
# transaction, so those never get a chance to run unless
# we explicitly make a fake boundary to run them at; that
# is what captureOnCommitCallbacks does.
with self.captureOnCommitCallbacks(execute=True):
result = self.client_get(
"/emails/generate/"
) # Generates emails and redirects to /emails/
self.assertEqual("/emails/", result["Location"]) # Make sure redirect URL is correct. self.assertEqual("/emails/", result["Location"]) # Make sure redirect URL is correct.
# The above call to /emails/generate/ creates the emails and # The above call to /emails/generate/ creates the emails and

View File

@ -22,10 +22,11 @@ if TYPE_CHECKING:
class EmailTranslationTestCase(ZulipTestCase): class EmailTranslationTestCase(ZulipTestCase):
def test_email_translation(self) -> None: def test_email_translation(self) -> None:
def check_translation(phrase: str, request_type: str, *args: Any, **kwargs: Any) -> None: def check_translation(phrase: str, request_type: str, *args: Any, **kwargs: Any) -> None:
if request_type == "post": with self.captureOnCommitCallbacks(execute=True):
self.client_post(*args, **kwargs) if request_type == "post":
elif request_type == "patch": self.client_post(*args, **kwargs)
self.client_patch(*args, **kwargs) elif request_type == "patch":
self.client_patch(*args, **kwargs)
email_message = mail.outbox[0] email_message = mail.outbox[0]
self.assertIn(phrase, email_message.body) self.assertIn(phrase, email_message.body)

View File

@ -190,16 +190,17 @@ class InviteUserBase(ZulipTestCase):
if invite_expires_in is None: if invite_expires_in is None:
invite_expires_in = orjson.dumps(None).decode() invite_expires_in = orjson.dumps(None).decode()
return self.client_post( with self.captureOnCommitCallbacks(execute=True):
"/json/invites", return self.client_post(
{ "/json/invites",
"invitee_emails": invitee_emails, {
"invite_expires_in_minutes": invite_expires_in, "invitee_emails": invitee_emails,
"stream_ids": orjson.dumps(stream_ids).decode(), "invite_expires_in_minutes": invite_expires_in,
"invite_as": invite_as, "stream_ids": orjson.dumps(stream_ids).decode(),
}, "invite_as": invite_as,
subdomain=realm.string_id if realm else "zulip", },
) subdomain=realm.string_id if realm else "zulip",
)
class InviteUserTest(InviteUserBase): class InviteUserTest(InviteUserBase):
@ -1480,32 +1481,38 @@ so we didn't send them an invitation. We did send invitations to everyone else!"
] ]
invite_expires_in_minutes = 2 * 24 * 60 invite_expires_in_minutes = 2 * 24 * 60
do_invite_users( with self.captureOnCommitCallbacks(execute=True):
self.user_profile, do_invite_users(
["foo@zulip.com"], self.user_profile,
streams, ["foo@zulip.com"],
invite_expires_in_minutes=invite_expires_in_minutes, streams,
) invite_expires_in_minutes=invite_expires_in_minutes,
)
prereg_user = PreregistrationUser.objects.get(email="foo@zulip.com") prereg_user = PreregistrationUser.objects.get(email="foo@zulip.com")
do_invite_users( with self.captureOnCommitCallbacks(execute=True):
self.user_profile, do_invite_users(
["foo@zulip.com"], self.user_profile,
streams, ["foo@zulip.com"],
invite_expires_in_minutes=invite_expires_in_minutes, streams,
) invite_expires_in_minutes=invite_expires_in_minutes,
do_invite_users( )
self.user_profile, do_invite_users(
["foo@zulip.com"], self.user_profile,
streams, ["foo@zulip.com"],
invite_expires_in_minutes=invite_expires_in_minutes, streams,
) invite_expires_in_minutes=invite_expires_in_minutes,
)
# Also send an invite from a different realm. # Also send an invite from a different realm.
lear = get_realm("lear") lear = get_realm("lear")
lear_user = self.lear_user("cordelia") lear_user = self.lear_user("cordelia")
do_invite_users( with self.captureOnCommitCallbacks(execute=True):
lear_user, ["foo@zulip.com"], [], invite_expires_in_minutes=invite_expires_in_minutes do_invite_users(
) lear_user,
["foo@zulip.com"],
[],
invite_expires_in_minutes=invite_expires_in_minutes,
)
invites = PreregistrationUser.objects.filter(email__iexact="foo@zulip.com") invites = PreregistrationUser.objects.filter(email__iexact="foo@zulip.com")
self.assert_length(invites, 4) self.assert_length(invites, 4)
@ -1710,36 +1717,37 @@ class InvitationsTestCase(InviteUserBase):
] ]
invite_expires_in_minutes = 2 * 24 * 60 invite_expires_in_minutes = 2 * 24 * 60
do_invite_users( with self.captureOnCommitCallbacks(execute=True):
user_profile, do_invite_users(
["TestOne@zulip.com"], user_profile,
streams, ["TestOne@zulip.com"],
invite_expires_in_minutes=invite_expires_in_minutes, streams,
) invite_expires_in_minutes=invite_expires_in_minutes,
do_invite_users( )
user_profile, do_invite_users(
["TestTwo@zulip.com"], user_profile,
streams, ["TestTwo@zulip.com"],
invite_expires_in_minutes=invite_expires_in_minutes, streams,
) invite_expires_in_minutes=invite_expires_in_minutes,
do_invite_users( )
hamlet, do_invite_users(
["TestThree@zulip.com"], hamlet,
streams, ["TestThree@zulip.com"],
invite_expires_in_minutes=invite_expires_in_minutes, streams,
) invite_expires_in_minutes=invite_expires_in_minutes,
do_invite_users( )
othello, do_invite_users(
["TestFour@zulip.com"], othello,
streams, ["TestFour@zulip.com"],
invite_expires_in_minutes=invite_expires_in_minutes, streams,
) invite_expires_in_minutes=invite_expires_in_minutes,
do_invite_users( )
self.mit_user("sipbtest"), do_invite_users(
["TestOne@mit.edu"], self.mit_user("sipbtest"),
[], ["TestOne@mit.edu"],
invite_expires_in_minutes=invite_expires_in_minutes, [],
) invite_expires_in_minutes=invite_expires_in_minutes,
)
do_create_multiuse_invite_link( do_create_multiuse_invite_link(
user_profile, PreregistrationUser.INVITE_AS["MEMBER"], invite_expires_in_minutes user_profile, PreregistrationUser.INVITE_AS["MEMBER"], invite_expires_in_minutes
) )
@ -1769,14 +1777,17 @@ class InvitationsTestCase(InviteUserBase):
] ]
invite_expires_in_minutes = 2 * 24 * 60 invite_expires_in_minutes = 2 * 24 * 60
do_invite_users( with self.captureOnCommitCallbacks(execute=True):
user_profile, do_invite_users(
["TestOne@zulip.com"], user_profile,
streams, ["TestOne@zulip.com"],
invite_expires_in_minutes=invite_expires_in_minutes, streams,
) invite_expires_in_minutes=invite_expires_in_minutes,
)
with time_machine.travel((timezone_now() - timedelta(days=3)), tick=False): with time_machine.travel(
(timezone_now() - timedelta(days=3)), tick=False
), self.captureOnCommitCallbacks(execute=True):
do_invite_users( do_invite_users(
user_profile, user_profile,
["TestTwo@zulip.com"], ["TestTwo@zulip.com"],
@ -1819,7 +1830,9 @@ class InvitationsTestCase(InviteUserBase):
get_stream(stream_name, user_profile.realm) for stream_name in ["Denmark", "Scotland"] get_stream(stream_name, user_profile.realm) for stream_name in ["Denmark", "Scotland"]
] ]
with time_machine.travel((timezone_now() - timedelta(days=1000)), tick=False): with time_machine.travel(
(timezone_now() - timedelta(days=1000)), tick=False
), self.captureOnCommitCallbacks(execute=True):
# Testing the invitation with expiry date set to "None" exists # Testing the invitation with expiry date set to "None" exists
# after a large amount of days. # after a large amount of days.
do_invite_users( do_invite_users(
@ -2055,7 +2068,8 @@ class InvitationsTestCase(InviteUserBase):
original_timestamp = scheduledemail_filter.values_list("scheduled_timestamp", flat=True) original_timestamp = scheduledemail_filter.values_list("scheduled_timestamp", flat=True)
# Resend invite # Resend invite
result = self.client_post("/json/invites/" + str(prereg_user.id) + "/resend") with self.captureOnCommitCallbacks(execute=True):
result = self.client_post("/json/invites/" + str(prereg_user.id) + "/resend")
self.assertEqual( self.assertEqual(
ScheduledEmail.objects.filter( ScheduledEmail.objects.filter(
address__iexact=invitee, type=ScheduledEmail.INVITATION_REMINDER address__iexact=invitee, type=ScheduledEmail.INVITATION_REMINDER
@ -2101,7 +2115,8 @@ class InvitationsTestCase(InviteUserBase):
original_timestamp = scheduledemail_filter.values_list("scheduled_timestamp", flat=True) original_timestamp = scheduledemail_filter.values_list("scheduled_timestamp", flat=True)
# Resend invite # Resend invite
result = self.client_post("/json/invites/" + str(prereg_user.id) + "/resend") with self.captureOnCommitCallbacks(execute=True):
result = self.client_post("/json/invites/" + str(prereg_user.id) + "/resend")
self.assertEqual( self.assertEqual(
ScheduledEmail.objects.filter( ScheduledEmail.objects.filter(
address__iexact=invitee, type=ScheduledEmail.INVITATION_REMINDER address__iexact=invitee, type=ScheduledEmail.INVITATION_REMINDER
@ -2153,7 +2168,8 @@ class InvitationsTestCase(InviteUserBase):
self.assert_json_error(error_result, "Must be an organization owner") self.assert_json_error(error_result, "Must be an organization owner")
self.login("desdemona") self.login("desdemona")
result = self.client_post("/json/invites/" + str(prereg_user.id) + "/resend") with self.captureOnCommitCallbacks(execute=True):
result = self.client_post("/json/invites/" + str(prereg_user.id) + "/resend")
self.assert_json_success(result) self.assert_json_success(result)
self.assertEqual( self.assertEqual(
@ -2180,7 +2196,8 @@ class InvitationsTestCase(InviteUserBase):
self.check_sent_emails([invitee]) self.check_sent_emails([invitee])
mail.outbox.pop() mail.outbox.pop()
result = self.client_post("/json/invites/" + str(prereg_user.id) + "/resend") with self.captureOnCommitCallbacks(execute=True):
result = self.client_post("/json/invites/" + str(prereg_user.id) + "/resend")
self.assert_json_success(result) self.assert_json_success(result)
self.check_sent_emails([invitee]) self.check_sent_emails([invitee])

View File

@ -3753,17 +3753,17 @@ class UserSignUpTest(ZulipTestCase):
], ],
) )
stream_ids = [self.get_stream_id(stream_name) for stream_name in streams] stream_ids = [self.get_stream_id(stream_name) for stream_name in streams]
response = self.client_post( with self.captureOnCommitCallbacks(execute=True):
"/json/invites", response = self.client_post(
{ "/json/invites",
"invitee_emails": email, {
"stream_ids": orjson.dumps(stream_ids).decode(), "invitee_emails": email,
"invite_as": invite_as, "stream_ids": orjson.dumps(stream_ids).decode(),
}, "invite_as": invite_as,
) },
)
self.assert_json_success(response) self.assert_json_success(response)
self.logout() self.logout()
result = self.submit_reg_form_for_user( result = self.submit_reg_form_for_user(
email, email,
password, password,

View File

@ -897,12 +897,13 @@ class QueryCountTest(ZulipTestCase):
streams = [get_stream(stream_name, realm) for stream_name in stream_names] streams = [get_stream(stream_name, realm) for stream_name in stream_names]
invite_expires_in_minutes = 4 * 24 * 60 invite_expires_in_minutes = 4 * 24 * 60
do_invite_users( with self.captureOnCommitCallbacks(execute=True):
user_profile=self.example_user("hamlet"), do_invite_users(
invitee_emails=["fred@zulip.com"], user_profile=self.example_user("hamlet"),
streams=streams, invitee_emails=["fred@zulip.com"],
invite_expires_in_minutes=invite_expires_in_minutes, streams=streams,
) invite_expires_in_minutes=invite_expires_in_minutes,
)
prereg_user = PreregistrationUser.objects.get(email="fred@zulip.com") prereg_user = PreregistrationUser.objects.get(email="fred@zulip.com")
@ -1701,35 +1702,36 @@ class ActivateTest(ZulipTestCase):
desdemona = self.example_user("desdemona") desdemona = self.example_user("desdemona")
invite_expires_in_minutes = 2 * 24 * 60 invite_expires_in_minutes = 2 * 24 * 60
do_invite_users( with self.captureOnCommitCallbacks(execute=True):
iago, do_invite_users(
["new1@zulip.com", "new2@zulip.com"], iago,
[], ["new1@zulip.com", "new2@zulip.com"],
invite_expires_in_minutes=invite_expires_in_minutes, [],
invite_as=PreregistrationUser.INVITE_AS["REALM_ADMIN"], invite_expires_in_minutes=invite_expires_in_minutes,
) invite_as=PreregistrationUser.INVITE_AS["REALM_ADMIN"],
do_invite_users( )
desdemona, do_invite_users(
["new3@zulip.com", "new4@zulip.com"], desdemona,
[], ["new3@zulip.com", "new4@zulip.com"],
invite_expires_in_minutes=invite_expires_in_minutes, [],
invite_as=PreregistrationUser.INVITE_AS["REALM_ADMIN"], invite_expires_in_minutes=invite_expires_in_minutes,
) invite_as=PreregistrationUser.INVITE_AS["REALM_ADMIN"],
)
do_invite_users( do_invite_users(
iago, iago,
["new5@zulip.com"], ["new5@zulip.com"],
[], [],
invite_expires_in_minutes=None, invite_expires_in_minutes=None,
invite_as=PreregistrationUser.INVITE_AS["REALM_ADMIN"], invite_as=PreregistrationUser.INVITE_AS["REALM_ADMIN"],
) )
do_invite_users( do_invite_users(
desdemona, desdemona,
["new6@zulip.com"], ["new6@zulip.com"],
[], [],
invite_expires_in_minutes=None, invite_expires_in_minutes=None,
invite_as=PreregistrationUser.INVITE_AS["REALM_ADMIN"], invite_as=PreregistrationUser.INVITE_AS["REALM_ADMIN"],
) )
iago_multiuse_key = do_create_multiuse_invite_link( iago_multiuse_key = do_create_multiuse_invite_link(
iago, PreregistrationUser.INVITE_AS["MEMBER"], invite_expires_in_minutes iago, PreregistrationUser.INVITE_AS["MEMBER"], invite_expires_in_minutes

View File

@ -15,7 +15,7 @@ from zerver.actions.invites import (
do_revoke_user_invite, do_revoke_user_invite,
) )
from zerver.decorator import require_member_or_admin from zerver.decorator import require_member_or_admin
from zerver.lib.exceptions import JsonableError, OrganizationOwnerRequiredError from zerver.lib.exceptions import InvitationError, JsonableError, OrganizationOwnerRequiredError
from zerver.lib.request import REQ, has_request_variables from zerver.lib.request import REQ, has_request_variables
from zerver.lib.response import json_success from zerver.lib.response import json_success
from zerver.lib.streams import access_stream_by_id from zerver.lib.streams import access_stream_by_id
@ -93,13 +93,25 @@ def invite_users_backend(
if len(streams) and not user_profile.can_subscribe_other_users(): if len(streams) and not user_profile.can_subscribe_other_users():
raise JsonableError(_("You do not have permission to subscribe other users to channels.")) raise JsonableError(_("You do not have permission to subscribe other users to channels."))
do_invite_users( skipped = do_invite_users(
user_profile, user_profile,
invitee_emails, invitee_emails,
streams, streams,
invite_expires_in_minutes=invite_expires_in_minutes, invite_expires_in_minutes=invite_expires_in_minutes,
invite_as=invite_as, invite_as=invite_as,
) )
if skipped:
raise InvitationError(
_(
"Some of those addresses are already using Zulip, "
"so we didn't send them an invitation. We did send "
"invitations to everyone else!"
),
skipped,
sent_invitations=True,
)
return json_success(request) return json_success(request)