invites: Lock the realm when determining invitation counts.

This prevents users from hammering the invitation endpoint, causing
races, and inviting more users than they should otherwise be allowed
to.

Doing this requires that we not raise InvitationError when we have
partially succeeded; that behaviour is left to the one callsite of
do_invite_users.

Reported by Lakshit Agarwal (@chiekosec).
This commit is contained in:
Alex Vandiver 2024-01-10 21:01:21 +00:00 committed by Tim Abbott
parent eef5d22944
commit d863aa56de
12 changed files with 263 additions and 206 deletions

View File

@ -1,5 +1,6 @@
from contextlib import AbstractContextManager, ExitStack, contextmanager
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type
from unittest import mock
import orjson
@ -1640,6 +1641,23 @@ class TestLoggingCountStats(AnalyticsTestCase):
def test_invites_sent(self) -> None:
property = "invites_sent::day"
@contextmanager
def invite_context(
too_many_recent_realm_invites: bool = False, failure: bool = False
) -> Iterator[None]:
managers: List[AbstractContextManager[Any]] = [
mock.patch(
"zerver.actions.invites.too_many_recent_realm_invites", return_value=False
),
self.captureOnCommitCallbacks(execute=True),
]
if failure:
managers.append(self.assertRaises(InvitationError))
with ExitStack() as stack:
for mgr in managers:
stack.enter_context(mgr)
yield
def assertInviteCountEquals(count: int) -> None:
self.assertEqual(
count,
@ -1652,7 +1670,7 @@ class TestLoggingCountStats(AnalyticsTestCase):
stream, _ = self.create_stream_with_recipient()
invite_expires_in_minutes = 2 * 24 * 60
with mock.patch("zerver.actions.invites.too_many_recent_realm_invites", return_value=False):
with invite_context():
do_invite_users(
user,
["user1@domain.tld", "user2@domain.tld"],
@ -1663,7 +1681,7 @@ class TestLoggingCountStats(AnalyticsTestCase):
# We currently send emails when re-inviting users that haven't
# turned into accounts, so count them towards the total
with mock.patch("zerver.actions.invites.too_many_recent_realm_invites", return_value=False):
with invite_context():
do_invite_users(
user,
["user1@domain.tld", "user2@domain.tld"],
@ -1673,9 +1691,7 @@ class TestLoggingCountStats(AnalyticsTestCase):
assertInviteCountEquals(4)
# Test mix of good and malformed invite emails
with self.assertRaises(InvitationError), mock.patch(
"zerver.actions.invites.too_many_recent_realm_invites", return_value=False
):
with invite_context(failure=True):
do_invite_users(
user,
["user3@domain.tld", "malformed"],
@ -1685,15 +1701,14 @@ class TestLoggingCountStats(AnalyticsTestCase):
assertInviteCountEquals(4)
# Test inviting existing users
with self.assertRaises(InvitationError), mock.patch(
"zerver.actions.invites.too_many_recent_realm_invites", return_value=False
):
do_invite_users(
with invite_context():
skipped = do_invite_users(
user,
["first@domain.tld", "user4@domain.tld"],
[stream],
invite_expires_in_minutes=invite_expires_in_minutes,
)
self.assert_length(skipped, 1)
assertInviteCountEquals(5)
# Revoking invite should not give you credit
@ -1703,7 +1718,7 @@ class TestLoggingCountStats(AnalyticsTestCase):
assertInviteCountEquals(5)
# Resending invite should cost you
with mock.patch("zerver.actions.invites.too_many_recent_realm_invites", return_value=False):
with invite_context():
do_resend_user_invite_email(assert_is_not_none(PreregistrationUser.objects.first()))
assertInviteCountEquals(6)

View File

@ -709,6 +709,7 @@ class TestSupportEndpoint(ZulipTestCase):
stream: str, invitee_email: str, realm: Optional[Realm] = None
) -> None:
invite_expires_in_minutes = 10 * 24 * 60
with self.captureOnCommitCallbacks(execute=True):
self.client_post(
"/json/invites",
{

View File

@ -21,10 +21,11 @@ from zerver.lib.email_validation import (
)
from zerver.lib.exceptions import InvitationError
from zerver.lib.invites import notify_invites_changed
from zerver.lib.queue import queue_json_publish
from zerver.lib.queue import queue_event_on_commit
from zerver.lib.send_email import FromAddress, clear_scheduled_invitation_emails, send_email
from zerver.lib.timestamp import datetime_to_timestamp
from zerver.lib.types import UnspecifiedValue
from zerver.lib.utils import assert_is_not_none
from zerver.models import Message, MultiuseInvite, PreregistrationUser, Realm, Stream, UserProfile
from zerver.models.prereg_users import filter_to_valid_prereg_users
@ -193,6 +194,7 @@ def check_invite_limit(realm: Realm, num_invitees: int) -> None:
)
@transaction.atomic
def do_invite_users(
user_profile: UserProfile,
invitee_emails: Collection[str],
@ -200,23 +202,25 @@ def do_invite_users(
*,
invite_expires_in_minutes: Optional[int],
invite_as: int = PreregistrationUser.INVITE_AS["MEMBER"],
) -> None:
) -> List[Tuple[str, str, bool]]:
num_invites = len(invitee_emails)
check_invite_limit(user_profile.realm, num_invites)
# Lock the realm, since we need to not race with other invitations
realm = Realm.objects.select_for_update().get(id=user_profile.realm_id)
check_invite_limit(realm, num_invites)
if settings.BILLING_ENABLED:
from corporate.lib.registration import check_spare_licenses_available_for_inviting_new_users
if invite_as == PreregistrationUser.INVITE_AS["GUEST_USER"]:
check_spare_licenses_available_for_inviting_new_users(
user_profile.realm, extra_guests_count=num_invites
realm, extra_guests_count=num_invites
)
else:
check_spare_licenses_available_for_inviting_new_users(
user_profile.realm, extra_non_guests_count=num_invites
realm, extra_non_guests_count=num_invites
)
realm = user_profile.realm
if not realm.invite_required:
# Inhibit joining an open realm to send spam invitations.
min_age = timedelta(days=settings.INVITES_MIN_USER_AGE_DAYS)
@ -232,7 +236,7 @@ def do_invite_users(
good_emails: Set[str] = set()
errors: List[Tuple[str, str, bool]] = []
validate_email_allowed_in_realm = get_realm_email_validator(user_profile.realm)
validate_email_allowed_in_realm = get_realm_email_validator(realm)
for email in invitee_emails:
if email == "":
continue
@ -251,7 +255,7 @@ def do_invite_users(
but we still need to make sure they're not
gonna conflict with existing users
"""
error_dict = get_existing_user_errors(user_profile.realm, good_emails)
error_dict = get_existing_user_errors(realm, good_emails)
skipped: List[Tuple[str, str, bool]] = []
for email in error_dict:
@ -278,7 +282,7 @@ def do_invite_users(
# is used for rate limiting invitations, rather than keeping track of
# when exactly invitations were sent
do_increment_logging_stat(
user_profile.realm,
realm,
COUNT_STATS["invites_sent::day"],
None,
timezone_now(),
@ -290,7 +294,7 @@ def do_invite_users(
for email in validated_emails:
# The logged in user is the referrer.
prereg_user = PreregistrationUser(
email=email, referred_by=user_profile, invited_as=invite_as, realm=user_profile.realm
email=email, referred_by=user_profile, invited_as=invite_as, realm=realm
)
prereg_user.save()
stream_ids = [stream.id for stream in streams]
@ -299,22 +303,14 @@ def do_invite_users(
event = {
"prereg_id": prereg_user.id,
"referrer_id": user_profile.id,
"email_language": user_profile.realm.default_language,
"email_language": realm.default_language,
"invite_expires_in_minutes": invite_expires_in_minutes,
}
queue_json_publish("invites", event)
queue_event_on_commit("invites", event)
if skipped:
raise InvitationError(
_(
"Some of those addresses are already using Zulip, "
"so we didn't send them an invitation. We did send "
"invitations to everyone else!"
),
skipped,
sent_invitations=True,
)
notify_invites_changed(user_profile.realm, changed_invite_referrer=user_profile)
notify_invites_changed(realm, changed_invite_referrer=user_profile)
return skipped
def get_invitation_expiry_date(confirmation_obj: Confirmation) -> Optional[int]:
@ -441,12 +437,14 @@ def do_revoke_multi_use_invite(multiuse_invite: MultiuseInvite) -> None:
notify_invites_changed(realm, changed_invite_referrer=multiuse_invite.referred_by)
@transaction.atomic
def do_resend_user_invite_email(prereg_user: PreregistrationUser) -> int:
# These are two structurally for the caller's code path.
assert prereg_user.referred_by is not None
assert prereg_user.realm is not None
# Take a lock on the realm, so we can check for invitation limits without races
realm_id = assert_is_not_none(prereg_user.realm_id)
realm = Realm.objects.select_for_update().get(id=realm_id)
check_invite_limit(realm, 1)
check_invite_limit(prereg_user.referred_by.realm, 1)
assert prereg_user.referred_by is not None
prereg_user.invited_at = timezone_now()
prereg_user.save()
@ -460,18 +458,16 @@ def do_resend_user_invite_email(prereg_user: PreregistrationUser) -> int:
invite_expires_in_minutes = (expiry_date - prereg_user.invited_at).total_seconds() / 60
prereg_user.confirmation.clear()
do_increment_logging_stat(
prereg_user.realm, COUNT_STATS["invites_sent::day"], None, prereg_user.invited_at
)
do_increment_logging_stat(realm, COUNT_STATS["invites_sent::day"], None, prereg_user.invited_at)
clear_scheduled_invitation_emails(prereg_user.email)
# We don't store the custom email body, so just set it to None
event = {
"prereg_id": prereg_user.id,
"referrer_id": prereg_user.referred_by.id,
"email_language": prereg_user.referred_by.realm.default_language,
"email_language": realm.default_language,
"invite_expires_in_minutes": invite_expires_in_minutes,
}
queue_json_publish("invites", event)
queue_event_on_commit("invites", event)
return datetime_to_timestamp(prereg_user.invited_at)

View File

@ -1541,6 +1541,7 @@ class SocialAuthBase(DesktopFlowTestingLib, ZulipTestCase, ABC):
realm = get_realm("zulip")
iago = self.example_user("iago")
with self.captureOnCommitCallbacks(execute=True):
do_invite_users(iago, [email], [], invite_expires_in_minutes=2 * 24 * 60)
account_data_dict = self.get_account_data_dict(email=email, name=name)
@ -1883,6 +1884,7 @@ class SocialAuthBase(DesktopFlowTestingLib, ZulipTestCase, ABC):
name = "Alice Jones"
invite_expires_in_minutes = 2 * 24 * 60
with self.captureOnCommitCallbacks(execute=True):
do_invite_users(
iago,
[email],

View File

@ -12,6 +12,7 @@ class EmailLogTest(ZulipTestCase):
with self.settings(EMAIL_BACKEND="zproject.email_backends.EmailLogBackEnd"), mock.patch(
"zproject.email_backends.EmailLogBackEnd._do_send_messages", lambda *args: 1
), self.assertLogs(level="INFO") as m, self.settings(DEVELOPMENT_LOG_EMAILS=True):
with self.captureOnCommitCallbacks(execute=True):
result = self.client_get("/emails/generate/")
self.assertEqual(result.status_code, 302)
self.assertIn("emails", result["Location"])

View File

@ -1138,6 +1138,7 @@ class NormalActionsTest(BaseAction):
self.user_profile = self.example_user("iago")
user_profile = self.example_user("cordelia")
invite_expires_in_minutes = 2 * 24 * 60
with self.captureOnCommitCallbacks(execute=True):
do_invite_users(
user_profile,
["foo@zulip.com"],
@ -1159,6 +1160,7 @@ class NormalActionsTest(BaseAction):
]
invite_expires_in_minutes = 2 * 24 * 60
with self.captureOnCommitCallbacks(execute=True):
do_invite_users(
self.user_profile,
["foo@zulip.com"],
@ -1202,6 +1204,7 @@ class NormalActionsTest(BaseAction):
]
invite_expires_in_minutes = 2 * 24 * 60
with self.captureOnCommitCallbacks(execute=True):
do_invite_users(
self.user_profile,
["foo@zulip.com"],
@ -1220,7 +1223,7 @@ class NormalActionsTest(BaseAction):
acting_user=None,
)
check_invites_changed("events[1]", events[1])
check_invites_changed("events[6]", events[6])
def test_typing_events(self) -> None:
with self.verify_action(state_change_expected=False) as events:

View File

@ -395,6 +395,13 @@ class TestDevelopmentEmailsLog(ZulipTestCase):
), self.assertLogs(level="INFO") as logger, mock.patch(
"zproject.email_backends.EmailLogBackEnd._do_send_messages", lambda *args: 1
):
# Parts of this endpoint use transactions, and use
# transaction.on_commit to run code when the transaction
# commits. Tests are run inside one big outer
# transaction, so those never get a chance to run unless
# we explicitly make a fake boundary to run them at; that
# is what captureOnCommitCallbacks does.
with self.captureOnCommitCallbacks(execute=True):
result = self.client_get(
"/emails/generate/"
) # Generates emails and redirects to /emails/

View File

@ -22,6 +22,7 @@ if TYPE_CHECKING:
class EmailTranslationTestCase(ZulipTestCase):
def test_email_translation(self) -> None:
def check_translation(phrase: str, request_type: str, *args: Any, **kwargs: Any) -> None:
with self.captureOnCommitCallbacks(execute=True):
if request_type == "post":
self.client_post(*args, **kwargs)
elif request_type == "patch":

View File

@ -190,6 +190,7 @@ class InviteUserBase(ZulipTestCase):
if invite_expires_in is None:
invite_expires_in = orjson.dumps(None).decode()
with self.captureOnCommitCallbacks(execute=True):
return self.client_post(
"/json/invites",
{
@ -1480,6 +1481,7 @@ so we didn't send them an invitation. We did send invitations to everyone else!"
]
invite_expires_in_minutes = 2 * 24 * 60
with self.captureOnCommitCallbacks(execute=True):
do_invite_users(
self.user_profile,
["foo@zulip.com"],
@ -1487,6 +1489,7 @@ so we didn't send them an invitation. We did send invitations to everyone else!"
invite_expires_in_minutes=invite_expires_in_minutes,
)
prereg_user = PreregistrationUser.objects.get(email="foo@zulip.com")
with self.captureOnCommitCallbacks(execute=True):
do_invite_users(
self.user_profile,
["foo@zulip.com"],
@ -1503,8 +1506,12 @@ so we didn't send them an invitation. We did send invitations to everyone else!"
# Also send an invite from a different realm.
lear = get_realm("lear")
lear_user = self.lear_user("cordelia")
with self.captureOnCommitCallbacks(execute=True):
do_invite_users(
lear_user, ["foo@zulip.com"], [], invite_expires_in_minutes=invite_expires_in_minutes
lear_user,
["foo@zulip.com"],
[],
invite_expires_in_minutes=invite_expires_in_minutes,
)
invites = PreregistrationUser.objects.filter(email__iexact="foo@zulip.com")
@ -1710,6 +1717,7 @@ class InvitationsTestCase(InviteUserBase):
]
invite_expires_in_minutes = 2 * 24 * 60
with self.captureOnCommitCallbacks(execute=True):
do_invite_users(
user_profile,
["TestOne@zulip.com"],
@ -1769,6 +1777,7 @@ class InvitationsTestCase(InviteUserBase):
]
invite_expires_in_minutes = 2 * 24 * 60
with self.captureOnCommitCallbacks(execute=True):
do_invite_users(
user_profile,
["TestOne@zulip.com"],
@ -1776,7 +1785,9 @@ class InvitationsTestCase(InviteUserBase):
invite_expires_in_minutes=invite_expires_in_minutes,
)
with time_machine.travel((timezone_now() - timedelta(days=3)), tick=False):
with time_machine.travel(
(timezone_now() - timedelta(days=3)), tick=False
), self.captureOnCommitCallbacks(execute=True):
do_invite_users(
user_profile,
["TestTwo@zulip.com"],
@ -1819,7 +1830,9 @@ class InvitationsTestCase(InviteUserBase):
get_stream(stream_name, user_profile.realm) for stream_name in ["Denmark", "Scotland"]
]
with time_machine.travel((timezone_now() - timedelta(days=1000)), tick=False):
with time_machine.travel(
(timezone_now() - timedelta(days=1000)), tick=False
), self.captureOnCommitCallbacks(execute=True):
# Testing the invitation with expiry date set to "None" exists
# after a large amount of days.
do_invite_users(
@ -2055,6 +2068,7 @@ class InvitationsTestCase(InviteUserBase):
original_timestamp = scheduledemail_filter.values_list("scheduled_timestamp", flat=True)
# Resend invite
with self.captureOnCommitCallbacks(execute=True):
result = self.client_post("/json/invites/" + str(prereg_user.id) + "/resend")
self.assertEqual(
ScheduledEmail.objects.filter(
@ -2101,6 +2115,7 @@ class InvitationsTestCase(InviteUserBase):
original_timestamp = scheduledemail_filter.values_list("scheduled_timestamp", flat=True)
# Resend invite
with self.captureOnCommitCallbacks(execute=True):
result = self.client_post("/json/invites/" + str(prereg_user.id) + "/resend")
self.assertEqual(
ScheduledEmail.objects.filter(
@ -2153,6 +2168,7 @@ class InvitationsTestCase(InviteUserBase):
self.assert_json_error(error_result, "Must be an organization owner")
self.login("desdemona")
with self.captureOnCommitCallbacks(execute=True):
result = self.client_post("/json/invites/" + str(prereg_user.id) + "/resend")
self.assert_json_success(result)
@ -2180,6 +2196,7 @@ class InvitationsTestCase(InviteUserBase):
self.check_sent_emails([invitee])
mail.outbox.pop()
with self.captureOnCommitCallbacks(execute=True):
result = self.client_post("/json/invites/" + str(prereg_user.id) + "/resend")
self.assert_json_success(result)
self.check_sent_emails([invitee])

View File

@ -3753,6 +3753,7 @@ class UserSignUpTest(ZulipTestCase):
],
)
stream_ids = [self.get_stream_id(stream_name) for stream_name in streams]
with self.captureOnCommitCallbacks(execute=True):
response = self.client_post(
"/json/invites",
{
@ -3763,7 +3764,6 @@ class UserSignUpTest(ZulipTestCase):
)
self.assert_json_success(response)
self.logout()
result = self.submit_reg_form_for_user(
email,
password,

View File

@ -897,6 +897,7 @@ class QueryCountTest(ZulipTestCase):
streams = [get_stream(stream_name, realm) for stream_name in stream_names]
invite_expires_in_minutes = 4 * 24 * 60
with self.captureOnCommitCallbacks(execute=True):
do_invite_users(
user_profile=self.example_user("hamlet"),
invitee_emails=["fred@zulip.com"],
@ -1701,6 +1702,7 @@ class ActivateTest(ZulipTestCase):
desdemona = self.example_user("desdemona")
invite_expires_in_minutes = 2 * 24 * 60
with self.captureOnCommitCallbacks(execute=True):
do_invite_users(
iago,
["new1@zulip.com", "new2@zulip.com"],

View File

@ -15,7 +15,7 @@ from zerver.actions.invites import (
do_revoke_user_invite,
)
from zerver.decorator import require_member_or_admin
from zerver.lib.exceptions import JsonableError, OrganizationOwnerRequiredError
from zerver.lib.exceptions import InvitationError, JsonableError, OrganizationOwnerRequiredError
from zerver.lib.request import REQ, has_request_variables
from zerver.lib.response import json_success
from zerver.lib.streams import access_stream_by_id
@ -93,13 +93,25 @@ def invite_users_backend(
if len(streams) and not user_profile.can_subscribe_other_users():
raise JsonableError(_("You do not have permission to subscribe other users to channels."))
do_invite_users(
skipped = do_invite_users(
user_profile,
invitee_emails,
streams,
invite_expires_in_minutes=invite_expires_in_minutes,
invite_as=invite_as,
)
if skipped:
raise InvitationError(
_(
"Some of those addresses are already using Zulip, "
"so we didn't send them an invitation. We did send "
"invitations to everyone else!"
),
skipped,
sent_invitations=True,
)
return json_success(request)