From b94402152ddc419412638bcaeb403bb4f19db667 Mon Sep 17 00:00:00 2001 From: Alex Vandiver Date: Wed, 30 Aug 2023 19:19:37 +0000 Subject: [PATCH] models: Always search Messages with a realm_id or id limit. Unless there is a limit on `id`, always provide a `realm_id` limit as well. We also notate which index is expected to be used in each query. --- analytics/lib/counts.py | 21 +++++++++-- analytics/views/installation_activity.py | 1 + analytics/views/realm_activity.py | 3 ++ corporate/tests/test_stripe.py | 2 +- tools/generate-integration-docs-screenshot | 6 ++-- tools/semgrep.yml | 17 +++++++++ zerver/actions/create_user.py | 7 +++- zerver/actions/invites.py | 6 +++- zerver/actions/message_delete.py | 5 ++- zerver/actions/message_edit.py | 8 +++-- zerver/actions/message_flags.py | 2 ++ zerver/actions/message_send.py | 4 ++- zerver/actions/realm_settings.py | 7 ++-- zerver/actions/streams.py | 26 ++++++++++---- zerver/actions/users.py | 10 ++++-- zerver/lib/digest.py | 5 ++- zerver/lib/email_notifications.py | 1 + zerver/lib/export.py | 21 +++++++++-- zerver/lib/home.py | 8 ++++- zerver/lib/message.py | 2 ++ zerver/lib/retention.py | 9 ++++- zerver/lib/scheduled_messages.py | 1 + zerver/lib/soft_deactivation.py | 5 ++- zerver/lib/topic.py | 23 +++++++++--- zerver/models.py | 6 ++++ zerver/tests/test_gitter_importer.py | 2 +- zerver/tests/test_import_export.py | 10 +++--- zerver/tests/test_message_fetch.py | 2 +- zerver/tests/test_realm.py | 42 ++++++++++++++++++---- zerver/tests/test_retention.py | 10 +++--- zerver/tests/test_signup.py | 16 ++++++--- zerver/tests/test_users.py | 16 +++++---- zerver/views/streams.py | 4 ++- zerver/worker/queue_processors.py | 6 +++- 34 files changed, 249 insertions(+), 65 deletions(-) diff --git a/analytics/lib/counts.py b/analytics/lib/counts.py index 382813f479..b0b3e06722 100644 --- a/analytics/lib/counts.py +++ b/analytics/lib/counts.py @@ -447,7 +447,13 @@ def count_message_by_user_query(realm: Optional[Realm]) -> QueryFn: if realm is None: realm_clause: Composable = SQL("") else: - realm_clause = SQL("zerver_userprofile.realm_id = {} AND").format(Literal(realm.id)) + # We limit both userprofile and message so that we only see + # users from this realm, but also get the performance speedup + # of limiting messages by realm. + realm_clause = SQL( + "zerver_userprofile.realm_id = {} AND zerver_message.realm_id = {} AND" + ).format(Literal(realm.id), Literal(realm.id)) + # Uses index: zerver_message_realm_date_sent (or the only-date index) return lambda kwargs: SQL( """ INSERT INTO analytics_usercount @@ -474,7 +480,13 @@ def count_message_type_by_user_query(realm: Optional[Realm]) -> QueryFn: if realm is None: realm_clause: Composable = SQL("") else: - realm_clause = SQL("zerver_userprofile.realm_id = {} AND").format(Literal(realm.id)) + # We limit both userprofile and message so that we only see + # users from this realm, but also get the performance speedup + # of limiting messages by realm. + realm_clause = SQL( + "zerver_userprofile.realm_id = {} AND zerver_message.realm_id = {} AND" + ).format(Literal(realm.id), Literal(realm.id)) + # Uses index: zerver_message_realm_date_sent (or the only-date index) return lambda kwargs: SQL( """ INSERT INTO analytics_usercount @@ -523,7 +535,10 @@ def count_message_by_stream_query(realm: Optional[Realm]) -> QueryFn: if realm is None: realm_clause: Composable = SQL("") else: - realm_clause = SQL("zerver_stream.realm_id = {} AND").format(Literal(realm.id)) + realm_clause = SQL( + "zerver_stream.realm_id = {} AND zerver_message.realm_id = {} AND" + ).format(Literal(realm.id), Literal(realm.id)) + # Uses index: zerver_message_realm_date_sent (or the only-date index) return lambda kwargs: SQL( """ INSERT INTO analytics_streamcount diff --git a/analytics/views/installation_activity.py b/analytics/views/installation_activity.py index e713658b5d..4bfd97e369 100644 --- a/analytics/views/installation_activity.py +++ b/analytics/views/installation_activity.py @@ -39,6 +39,7 @@ if settings.BILLING_ENABLED: def get_realm_day_counts() -> Dict[str, Dict[str, Markup]]: + # Uses index: zerver_message_date_sent_3b5b05d8 query = SQL( """ select diff --git a/analytics/views/realm_activity.py b/analytics/views/realm_activity.py index ffcdbf6769..d5246bda09 100644 --- a/analytics/views/realm_activity.py +++ b/analytics/views/realm_activity.py @@ -163,6 +163,7 @@ def sent_messages_report(realm: str) -> str: "Bots", ] + # Uses index: zerver_message_realm_date_sent query = SQL( """ select @@ -188,6 +189,8 @@ def sent_messages_report(realm: str) -> str: r.string_id = %s and date_sent > now() - interval '2 week' + and + m.realm_id = r.id group by date_sent::date order by diff --git a/corporate/tests/test_stripe.py b/corporate/tests/test_stripe.py index b17388eaa5..321d9ef4f5 100644 --- a/corporate/tests/test_stripe.py +++ b/corporate/tests/test_stripe.py @@ -2550,7 +2550,7 @@ class StripeTest(StripeTestCase): ) sender = get_system_bot(settings.NOTIFICATION_BOT, user.realm_id) recipient_id = self.example_user("desdemona").recipient_id - message = Message.objects.filter(sender=sender.id).first() + message = Message.objects.filter(realm_id=realm.id, sender=sender.id).first() assert message is not None self.assertEqual(message.content, expected_message) self.assertEqual(message.recipient.type, Recipient.PERSONAL) diff --git a/tools/generate-integration-docs-screenshot b/tools/generate-integration-docs-screenshot index 52e761fd7e..b8678eda44 100755 --- a/tools/generate-integration-docs-screenshot +++ b/tools/generate-integration-docs-screenshot @@ -139,7 +139,7 @@ def send_bot_mock_message( bot: UserProfile, integration: Integration, fixture_path: str, config: BaseScreenshotConfig ) -> None: # Delete all messages, so new message is the only one it's message group - Message.objects.filter(sender=bot).delete() + Message.objects.filter(realm_id=bot.realm_id, sender=bot).delete() data, _, _ = get_fixture_info(fixture_path) assert bot.bot_owner is not None @@ -166,7 +166,7 @@ def send_bot_payload_message( bot: UserProfile, integration: WebhookIntegration, fixture_path: str, config: ScreenshotConfig ) -> bool: # Delete all messages, so new message is the only one it's message group - Message.objects.filter(sender=bot).delete() + Message.objects.filter(realm_id=bot.realm_id, sender=bot).delete() data, json_fixture, fixture_name = get_fixture_info(fixture_path) headers = get_requests_headers(integration.name, fixture_name) @@ -217,7 +217,7 @@ def send_bot_payload_message( def capture_last_message_screenshot(bot: UserProfile, image_path: str) -> None: - message = Message.objects.filter(sender=bot).last() + message = Message.objects.filter(realm_id=bot.realm_id, sender=bot).last() realm = get_realm("zulip") if message is None: print(f"No message found for {bot.full_name}") diff --git a/tools/semgrep.yml b/tools/semgrep.yml index 00fda02b5b..d1ee2c6285 100644 --- a/tools/semgrep.yml +++ b/tools/semgrep.yml @@ -17,6 +17,23 @@ rules: include: - zerver/views/ + - id: limit-message-filter + patterns: + - pattern: Message.objects.filter(...) + - pattern-not: Message.objects.filter(..., realm=..., ...) + - pattern-not: Message.objects.filter(..., realm_id=..., ...) + - pattern-not: Message.objects.filter(..., realm_id__in=..., ...) + - pattern-not: Message.objects.filter(..., id=..., ...) + - pattern-not: Message.objects.filter(..., id__in=..., ...) + - pattern-not: Message.objects.filter(..., id__lt=..., ...) + - pattern-not: Message.objects.filter(..., id__gt=..., ...) + message: "Set either a realm limit or an id limit on Message queries" + languages: [python] + severity: ERROR + paths: + exclude: + - "**/migrations/" + - id: dont-import-models-in-migrations patterns: - pattern-not: from zerver.lib.redis_utils import get_redis_client diff --git a/zerver/actions/create_user.py b/zerver/actions/create_user.py index 4da6872adc..6031f9bb62 100644 --- a/zerver/actions/create_user.py +++ b/zerver/actions/create_user.py @@ -181,7 +181,12 @@ def add_new_user_history(user_profile: UserProfile, streams: Iterable[Stream]) - # Start by finding recent messages matching those recipients. cutoff_date = timezone_now() - ONBOARDING_RECENT_TIMEDELTA recent_message_ids = set( - Message.objects.filter(recipient_id__in=recipient_ids, date_sent__gt=cutoff_date) + Message.objects.filter( + # Uses index: zerver_message_realm_recipient_id + realm_id=user_profile.realm_id, + recipient_id__in=recipient_ids, + date_sent__gt=cutoff_date, + ) .order_by("-id") .values_list("id", flat=True)[0:MAX_NUM_ONBOARDING_MESSAGES] ) diff --git a/zerver/actions/invites.py b/zerver/actions/invites.py index 24a0649693..b0fd6d29ae 100644 --- a/zerver/actions/invites.py +++ b/zerver/actions/invites.py @@ -134,7 +134,11 @@ def too_many_recent_realm_invites(realm: Realm, num_invitees: int) -> bool: not estimated_sent["messages"] # Only after we've done the rough-estimate check, take the # time to do the exact check: - and not Message.objects.filter(realm=realm, sender__is_bot=False).exists() + and not Message.objects.filter( + # Uses index: zerver_message_realm_sender_recipient (prefix) + realm=realm, + sender__is_bot=False, + ).exists() ): warning_flags.append("no-messages-sent") diff --git a/zerver/actions/message_delete.py b/zerver/actions/message_delete.py index 7777793bba..6aca536b0a 100644 --- a/zerver/actions/message_delete.py +++ b/zerver/actions/message_delete.py @@ -59,7 +59,10 @@ def do_delete_messages(realm: Realm, messages: Iterable[Message]) -> None: def do_delete_messages_by_sender(user: UserProfile) -> None: message_ids = list( - Message.objects.filter(sender=user).values_list("id", flat=True).order_by("id") + # Uses index: zerver_message_realm_sender_recipient (prefix) + Message.objects.filter(realm_id=user.realm_id, sender=user) + .values_list("id", flat=True) + .order_by("id") ) if message_ids: move_messages_to_archive(message_ids, chunk_size=retention.STREAM_MESSAGE_BATCH_SIZE) diff --git a/zerver/actions/message_edit.py b/zerver/actions/message_edit.py index c62a17f671..ff64b1a9b2 100644 --- a/zerver/actions/message_edit.py +++ b/zerver/actions/message_edit.py @@ -611,7 +611,7 @@ def do_update_message( assert target_stream.recipient_id is not None target_topic_has_messages = messages_for_topic( - target_stream.recipient_id, target_topic + realm.id, target_stream.recipient_id, target_topic ).exists() if propagate_mode in ["change_later", "change_all"]: @@ -804,6 +804,7 @@ def do_update_message( # unless the topic has thousands of messages of history. assert stream_being_edited.recipient_id is not None unmoved_messages = messages_for_topic( + realm.id, stream_being_edited.recipient_id, orig_topic_name, ) @@ -1041,7 +1042,7 @@ def do_update_message( # it reuses existing logic, which is good for keeping it # correct as we maintain the codebase. preexisting_topic_messages = messages_for_topic( - stream_for_new_topic.recipient_id, new_topic + realm.id, stream_for_new_topic.recipient_id, new_topic ).exclude(id__in=[*changed_message_ids, resolved_topic_message_id]) visible_preexisting_messages = bulk_access_messages( @@ -1136,6 +1137,7 @@ def check_time_limit_for_change_all_propagate_mode( ).values_list("message_id", flat=True) messages_allowed_to_move: List[int] = list( Message.objects.filter( + # Uses index: zerver_message_pkey id__in=accessible_messages_in_topic, date_sent__gt=timezone_now() - datetime.timedelta(seconds=message_move_deadline_seconds), @@ -1146,7 +1148,7 @@ def check_time_limit_for_change_all_propagate_mode( total_messages_requested_to_move = len(accessible_messages_in_topic) else: all_messages_in_topic = ( - messages_for_topic(message.recipient_id, message.topic_name()) + messages_for_topic(message.realm_id, message.recipient_id, message.topic_name()) .order_by("id") .values_list("id", "date_sent") ) diff --git a/zerver/actions/message_flags.py b/zerver/actions/message_flags.py index f10f86c1f6..1b4da063bf 100644 --- a/zerver/actions/message_flags.py +++ b/zerver/actions/message_flags.py @@ -276,6 +276,7 @@ def do_update_message_flags( subscribed_recipient_ids = get_subscribed_stream_recipient_ids_for_user(user_profile) message_ids_in_unsubscribed_streams = set( + # Uses index: zerver_message_pkey Message.objects.select_related("recipient") .filter(id__in=messages, recipient__type=Recipient.STREAM) .exclude(recipient_id__in=subscribed_recipient_ids) @@ -326,6 +327,7 @@ def do_update_message_flags( historical_messages = bulk_access_messages( user_profile, list( + # Uses index: zerver_message_pkey Message.objects.filter(id__in=historical_message_ids).prefetch_related( "recipient" ) diff --git a/zerver/actions/message_send.py b/zerver/actions/message_send.py index c25ac74d8e..efb2ec4d19 100644 --- a/zerver/actions/message_send.py +++ b/zerver/actions/message_send.py @@ -232,7 +232,7 @@ def get_recipient_info( # has syntax that might be a @topic mention without having confirmed the syntax isn't, say, # in a code block. topic_participant_user_ids = participants_for_topic( - recipient.id, stream_topic.topic_name + realm_id, recipient.id, stream_topic.topic_name ) subscription_rows = ( get_subscriptions_for_send_message( @@ -1085,6 +1085,8 @@ def already_sent_mirrored_message_id(message: Message) -> Optional[int]: time_window = datetime.timedelta(seconds=0) messages = Message.objects.filter( + # Uses index: zerver_message_realm_recipient_subject + realm_id=message.realm_id, sender=message.sender, recipient=message.recipient, subject=message.topic_name(), diff --git a/zerver/actions/realm_settings.py b/zerver/actions/realm_settings.py index 620432754f..ec453339c6 100644 --- a/zerver/actions/realm_settings.py +++ b/zerver/actions/realm_settings.py @@ -420,8 +420,11 @@ def do_scrub_realm(realm: Realm, *, acting_user: Optional[UserProfile]) -> None: ) cross_realm_bot_message_ids = list( Message.objects.filter( - # Filtering by both message.recipient and message.realm is more robust for ensuring - # no messages belonging to another realm will be deleted due to some bugs. + # Filtering by both message.recipient and message.realm is + # more robust for ensuring no messages belonging to + # another realm will be deleted due to some bugs. + # + # Uses index: zerver_message_realm_sender_recipient sender__realm=internal_realm, recipient_id__in=all_recipient_ids_in_realm, realm=realm, diff --git a/zerver/actions/streams.py b/zerver/actions/streams.py index e205e2b493..8ccf1c16ea 100644 --- a/zerver/actions/streams.py +++ b/zerver/actions/streams.py @@ -193,7 +193,11 @@ def do_reactivate_stream( # Update caches cache_set(display_recipient_cache_key(stream.recipient_id), new_name) - messages = Message.objects.filter(recipient_id=stream.recipient_id).only("id") + messages = Message.objects.filter( + # Uses index: zerver_message_realm_recipient_id + realm_id=realm.id, + recipient_id=stream.recipient_id, + ).only("id") cache_delete_many(to_dict_cache_key_id(message.id) for message in messages) # Unset the is_web_public cache on attachments, since the stream is now private. @@ -284,11 +288,17 @@ def merge_streams( # this before removing the subscription objects, to avoid messages # "disappearing" if an error interrupts this function. message_ids_to_clear = list( - Message.objects.filter(recipient=recipient_to_destroy).values_list("id", flat=True) - ) - count = Message.objects.filter(recipient=recipient_to_destroy).update( - recipient=recipient_to_keep + Message.objects.filter( + # Uses index: zerver_message_realm_recipient_id + realm_id=realm.id, + recipient=recipient_to_destroy, + ).values_list("id", flat=True) ) + count = Message.objects.filter( + # Uses index: zerver_message_realm_recipient_id (prefix) + realm_id=realm.id, + recipient=recipient_to_destroy, + ).update(recipient=recipient_to_keep) bulk_delete_cache_keys(message_ids_to_clear) # Remove subscriptions to the old stream. @@ -1185,7 +1195,11 @@ def do_rename_stream(stream: Stream, new_name: str, user_profile: UserProfile) - assert stream.recipient_id is not None recipient_id: int = stream.recipient_id - messages = Message.objects.filter(recipient_id=recipient_id).only("id") + messages = Message.objects.filter( + # Uses index: zerver_message_realm_recipient_id + realm_id=stream.realm_id, + recipient_id=recipient_id, + ).only("id") cache_set(display_recipient_cache_key(recipient_id), stream.name) diff --git a/zerver/actions/users.py b/zerver/actions/users.py index 750f16dafe..1fe2a703ff 100644 --- a/zerver/actions/users.py +++ b/zerver/actions/users.py @@ -189,7 +189,10 @@ def do_delete_user_preserving_messages(user_profile: UserProfile) -> None: force_date_joined=date_joined, create_personal_recipient=False, ) - Message.objects.filter(sender=user_profile).update(sender=temp_replacement_user) + # Uses index: zerver_message_realm_sender_recipient (prefix) + Message.objects.filter(realm_id=realm.id, sender=user_profile).update( + sender=temp_replacement_user + ) Subscription.objects.filter( user_profile=user_profile, recipient__type=Recipient.HUDDLE ).update(user_profile=temp_replacement_user) @@ -212,7 +215,10 @@ def do_delete_user_preserving_messages(user_profile: UserProfile) -> None: replacement_user.recipient = personal_recipient replacement_user.save(update_fields=["recipient"]) - Message.objects.filter(sender=temp_replacement_user).update(sender=replacement_user) + # Uses index: zerver_message_realm_sender_recipient (prefix) + Message.objects.filter(realm_id=realm.id, sender=temp_replacement_user).update( + sender=replacement_user + ) Subscription.objects.filter( user_profile=temp_replacement_user, recipient__type=Recipient.HUDDLE ).update(user_profile=replacement_user, is_user_active=replacement_user.is_active) diff --git a/zerver/lib/digest.py b/zerver/lib/digest.py index 0c6a824c1c..e184947df3 100644 --- a/zerver/lib/digest.py +++ b/zerver/lib/digest.py @@ -162,6 +162,7 @@ def _enqueue_emails_for_realm(realm: Realm, cutoff: datetime.datetime) -> None: def get_recent_topics( + realm_id: int, stream_ids: List[int], cutoff_date: datetime.datetime, ) -> List[DigestTopic]: @@ -171,7 +172,9 @@ def get_recent_topics( # * number of senders messages = ( + # Uses index: zerver_message_realm_recipient_date_sent Message.objects.filter( + realm_id=realm_id, recipient__type=Recipient.STREAM, recipient__type_id__in=stream_ids, date_sent__gt=cutoff_date, @@ -307,7 +310,7 @@ def bulk_get_digest_context( # Get all the recent topics for all the users. This does the heavy # lifting of making an expensive query to the Message table. Then # for each user, we filter to just the streams they care about. - recent_topics = get_recent_topics(sorted(all_stream_ids), cutoff_date) + recent_topics = get_recent_topics(realm.id, sorted(all_stream_ids), cutoff_date) stream_map = get_slim_stream_map(all_stream_ids) diff --git a/zerver/lib/email_notifications.py b/zerver/lib/email_notifications.py index 1c05a7d753..d7d445ae5b 100644 --- a/zerver/lib/email_notifications.py +++ b/zerver/lib/email_notifications.py @@ -629,6 +629,7 @@ def handle_missedmessage_emails( # messages that were permanently deleted, since those would now be # in the ArchivedMessage table, not the Message table. messages = Message.objects.filter( + # Uses index: zerver_message_pkey usermessage__user_profile_id=user_profile, id__in=message_ids, usermessage__flags=~UserMessage.flags.read, diff --git a/zerver/lib/export.py b/zerver/lib/export.py index 509d63c473..c1d44189f5 100644 --- a/zerver/lib/export.py +++ b/zerver/lib/export.py @@ -1314,6 +1314,8 @@ def export_partial_message_files( if public_only: messages_we_received = Message.objects.filter( + # Uses index: zerver_message_realm_sender_recipient + realm_id=realm.id, sender__in=ids_of_our_possible_senders, recipient__in=recipient_ids_for_us, ) @@ -1329,6 +1331,8 @@ def export_partial_message_files( # anyone in the export and received by any of the users who we # have consent to export. messages_we_received = Message.objects.filter( + # Uses index: zerver_message_realm_sender_recipient + realm_id=realm.id, sender__in=ids_of_our_possible_senders, recipient__in=recipient_ids_for_us, ) @@ -1345,6 +1349,8 @@ def export_partial_message_files( messages_we_received_in_protected_history_streams = Message.objects.annotate( has_usermessage=has_usermessage_expression ).filter( + # Uses index: zerver_message_realm_sender_recipient + realm_id=realm.id, sender__in=ids_of_our_possible_senders, recipient_id__in=( set(consented_recipient_ids) & set(streams_with_protected_history_recipient_ids) @@ -1370,6 +1376,8 @@ def export_partial_message_files( recipient_ids_for_them = get_ids(recipients_for_them) messages_we_sent_to_them = Message.objects.filter( + # Uses index: zerver_message_realm_sender_recipient + realm_id=realm.id, sender__in=consented_user_ids, recipient__in=recipient_ids_for_them, ) @@ -1410,6 +1418,7 @@ def write_message_partials( dump_file_id = 1 for message_id_chunk in message_id_chunks: + # Uses index: zerver_message_pkey actual_query = Message.objects.filter(id__in=message_id_chunk).order_by("id") message_chunk = make_raw(actual_query) @@ -2253,13 +2262,21 @@ def export_messages_single_user( return ", ".join(user_names) - messages_from_me = Message.objects.filter(sender=user_profile) + messages_from_me = Message.objects.filter( + # Uses index: zerver_message_realm_sender_recipient (prefix) + realm_id=user_profile.realm_id, + sender=user_profile, + ) my_subscriptions = Subscription.objects.filter( user_profile=user_profile, recipient__type__in=[Recipient.PERSONAL, Recipient.HUDDLE] ) my_recipient_ids = [sub.recipient_id for sub in my_subscriptions] - messages_to_me = Message.objects.filter(recipient_id__in=my_recipient_ids) + messages_to_me = Message.objects.filter( + # Uses index: zerver_message_realm_recipient_id (prefix) + realm_id=user_profile.realm_id, + recipient_id__in=my_recipient_ids, + ) # Find all message ids that pertain to us. all_message_ids: Set[int] = set() diff --git a/zerver/lib/home.py b/zerver/lib/home.py index 86de33e8db..eb1741bf41 100644 --- a/zerver/lib/home.py +++ b/zerver/lib/home.py @@ -228,7 +228,13 @@ def build_page_params_for_home_page_load( # In narrow_stream context, initial pointer is just latest message recipient = narrow_stream.recipient page_params["max_message_id"] = -1 - max_message = Message.objects.filter(recipient=recipient).order_by("-id").only("id").first() + max_message = ( + # Uses index: zerver_message_realm_recipient_id + Message.objects.filter(realm_id=realm.id, recipient=recipient) + .order_by("-id") + .only("id") + .first() + ) if max_message: page_params["max_message_id"] = max_message.id page_params["narrow_stream"] = narrow_stream.name diff --git a/zerver/lib/message.py b/zerver/lib/message.py index 5e049d290c..0d2286257a 100644 --- a/zerver/lib/message.py +++ b/zerver/lib/message.py @@ -511,6 +511,7 @@ class MessageDict: "sending_client__name", "sender__realm_id", ] + # Uses index: zerver_message_pkey messages = Message.objects.filter(id__in=needed_ids).values(*fields) return MessageDict.sew_submessages_and_reactions_to_msgs(messages) @@ -1476,6 +1477,7 @@ def update_first_visible_message_id(realm: Realm) -> None: else: try: first_visible_message_id = ( + # Uses index: zerver_message_realm_id Message.objects.filter(realm=realm) .values("id") .order_by("-id")[realm.message_visibility_limit - 1]["id"] diff --git a/zerver/lib/retention.py b/zerver/lib/retention.py index 1510aef7ce..2156c36446 100644 --- a/zerver/lib/retention.py +++ b/zerver/lib/retention.py @@ -189,12 +189,14 @@ def move_expired_messages_to_archive_by_recipient( ) -> int: assert message_retention_days != -1 + # Uses index: zerver_message_realm_recipient_date_sent query = SQL( """ INSERT INTO zerver_archivedmessage ({dst_fields}, archive_transaction_id) SELECT {src_fields}, {archive_transaction_id} FROM zerver_message - WHERE zerver_message.recipient_id = {recipient_id} + WHERE zerver_message.realm_id = {realm_id} + AND zerver_message.recipient_id = {recipient_id} AND zerver_message.date_sent < {check_date} LIMIT {chunk_size} ON CONFLICT (id) DO UPDATE SET archive_transaction_id = {archive_transaction_id} @@ -207,6 +209,7 @@ def move_expired_messages_to_archive_by_recipient( query, type=ArchiveTransaction.RETENTION_POLICY_BASED, realm=realm, + realm_id=Literal(realm.id), recipient_id=Literal(recipient.id), check_date=Literal(check_date.isoformat()), chunk_size=chunk_size, @@ -224,6 +227,7 @@ def move_expired_personal_and_huddle_messages_to_archive( recipient_types = (Recipient.PERSONAL, Recipient.HUDDLE) # Archive expired personal and huddle Messages in the realm, including cross-realm messages. + # Uses index: zerver_message_realm_recipient_date_sent query = SQL( """ INSERT INTO zerver_archivedmessage ({dst_fields}, archive_transaction_id) @@ -318,6 +322,8 @@ def delete_messages(msg_ids: List[int]) -> None: # key to Message (due to `on_delete=CASCADE` in our models # configuration), so we need to be sure we've taken care of # archiving the messages before doing this step. + # + # Uses index: zerver_message_pkey Message.objects.filter(id__in=msg_ids).delete() @@ -453,6 +459,7 @@ def get_realms_and_streams_for_archiving() -> List[Tuple[Realm, List[Stream]]]: def move_messages_to_archive( message_ids: List[int], realm: Optional[Realm] = None, chunk_size: int = MESSAGE_BATCH_SIZE ) -> None: + # Uses index: zerver_message_pkey query = SQL( """ INSERT INTO zerver_archivedmessage ({dst_fields}, archive_transaction_id) diff --git a/zerver/lib/scheduled_messages.py b/zerver/lib/scheduled_messages.py index 1cba74856c..39fbc625c8 100644 --- a/zerver/lib/scheduled_messages.py +++ b/zerver/lib/scheduled_messages.py @@ -24,6 +24,7 @@ def get_undelivered_scheduled_messages( user_profile: UserProfile, ) -> List[Union[APIScheduledDirectMessageDict, APIScheduledStreamMessageDict]]: scheduled_messages = ScheduledMessage.objects.filter( + realm_id=user_profile.realm_id, sender=user_profile, # Notably, we don't require failed=False, since we will want # to display those to users. diff --git a/zerver/lib/soft_deactivation.py b/zerver/lib/soft_deactivation.py index f1eee37f42..7a25f8f3c4 100644 --- a/zerver/lib/soft_deactivation.py +++ b/zerver/lib/soft_deactivation.py @@ -210,7 +210,10 @@ def add_missing_messages(user_profile: UserProfile) -> None: all_stream_msgs = list( Message.objects.filter( - recipient_id__in=recipient_ids, id__gt=user_profile.last_active_message_id + # Uses index: zerver_message_realm_recipient_id + realm_id=user_profile.realm_id, + recipient_id__in=recipient_ids, + id__gt=user_profile.last_active_message_id, ) .order_by("id") .values("id", "recipient__type_id") diff --git a/zerver/lib/topic.py b/zerver/lib/topic.py index a2353a63ed..66e6675858 100644 --- a/zerver/lib/topic.py +++ b/zerver/lib/topic.py @@ -92,8 +92,12 @@ def filter_by_topic_name_via_message( return query.filter(message__subject__iexact=topic_name) -def messages_for_topic(stream_recipient_id: int, topic_name: str) -> QuerySet[Message]: +def messages_for_topic( + realm_id: int, stream_recipient_id: int, topic_name: str +) -> QuerySet[Message]: return Message.objects.filter( + # Uses index: zerver_message_realm_recipient_upper_subject + realm_id=realm_id, recipient_id=stream_recipient_id, subject__iexact=topic_name, ) @@ -149,13 +153,17 @@ def update_messages_for_topic_edit( edit_history_event: EditHistoryEvent, last_edit_time: datetime, ) -> List[Message]: - propagate_query = Q(recipient_id=old_stream.recipient_id, subject__iexact=orig_topic_name) + propagate_query = Q( + recipient_id=old_stream.recipient_id, + subject__iexact=orig_topic_name, + ) if propagate_mode == "change_all": propagate_query = propagate_query & ~Q(id=edited_message.id) if propagate_mode == "change_later": propagate_query = propagate_query & Q(id__gt=edited_message.id) - messages = Message.objects.filter(propagate_query).select_related( + # Uses index: zerver_message_realm_recipient_upper_subject + messages = Message.objects.filter(propagate_query, realm_id=old_stream.realm_id).select_related( *Message.DEFAULT_SELECT_RELATED ) @@ -283,12 +291,17 @@ def get_topic_resolution_and_bare_name(stored_name: str) -> Tuple[bool, str]: return (False, stored_name) -def participants_for_topic(recipient_id: int, topic_name: str) -> Set[int]: +def participants_for_topic(realm_id: int, recipient_id: int, topic_name: str) -> Set[int]: """ Users who either sent or reacted to the messages in the topic. The function is expensive for large numbers of messages in the topic. """ - messages = Message.objects.filter(recipient_id=recipient_id, subject__iexact=topic_name) + messages = Message.objects.filter( + # Uses index: zerver_message_realm_recipient_upper_subject + realm_id=realm_id, + recipient_id=recipient_id, + subject__iexact=topic_name, + ) participants = set( UserProfile.objects.filter( Q(id__in=Subquery(messages.values("sender_id"))) diff --git a/zerver/models.py b/zerver/models.py index 4045be6a4b..27a15d4c43 100644 --- a/zerver/models.py +++ b/zerver/models.py @@ -3189,6 +3189,8 @@ class Message(AbstractMessage): def get_context_for_message(message: Message) -> QuerySet[Message]: return Message.objects.filter( + # Uses index: zerver_message_realm_recipient_upper_subject + realm_id=message.realm_id, recipient_id=message.recipient_id, subject__iexact=message.subject, id__lt=message.id, @@ -3676,6 +3678,8 @@ def validate_attachment_request_for_spectator_access( Attachment.objects.filter(id=attachment.id, is_web_public__isnull=True).update( is_web_public=Exists( Message.objects.filter( + # Uses index: zerver_attachment_messages_attachment_id_message_id_key + realm_id=realm.id, attachment=OuterRef("id"), recipient__stream__invite_only=False, recipient__stream__is_web_public=True, @@ -3723,6 +3727,8 @@ def validate_attachment_request( Attachment.objects.filter(id=attachment.id, is_realm_public__isnull=True).update( is_realm_public=Exists( Message.objects.filter( + # Uses index: zerver_attachment_messages_attachment_id_message_id_key + realm_id=user_profile.realm_id, attachment=OuterRef("id"), recipient__stream__invite_only=False, ), diff --git a/zerver/tests/test_gitter_importer.py b/zerver/tests/test_gitter_importer.py index e09f997617..cdb4c2ed6a 100644 --- a/zerver/tests/test_gitter_importer.py +++ b/zerver/tests/test_gitter_importer.py @@ -131,7 +131,7 @@ class GitterImporter(ZulipTestCase): # test rendered_messages realm_users = UserProfile.objects.filter(realm=realm) - messages = Message.objects.filter(sender__in=realm_users) + messages = Message.objects.filter(realm_id=realm.id, sender__in=realm_users) for message in messages: self.assertIsNotNone(message.rendered_content, None) diff --git a/zerver/tests/test_import_export.py b/zerver/tests/test_import_export.py index e82dc074ca..9ae63c0df8 100644 --- a/zerver/tests/test_import_export.py +++ b/zerver/tests/test_import_export.py @@ -643,7 +643,7 @@ class RealmImportExportTest(ExportFile): type_id__in=public_stream_ids, type=Recipient.STREAM ) public_stream_message_ids = Message.objects.filter( - recipient__in=public_stream_recipients + realm_id=realm.id, recipient__in=public_stream_recipients ).values_list("id", flat=True) # Messages from Private stream C are not exported since no member gave consent @@ -656,7 +656,7 @@ class RealmImportExportTest(ExportFile): type_id__in=private_stream_ids, type=Recipient.STREAM ) private_stream_message_ids = Message.objects.filter( - recipient__in=private_stream_recipients + realm_id=realm.id, recipient__in=private_stream_recipients ).values_list("id", flat=True) pm_recipients = Recipient.objects.filter( @@ -664,7 +664,7 @@ class RealmImportExportTest(ExportFile): ) pm_query = Q(recipient__in=pm_recipients) | Q(sender__in=consented_user_ids) exported_pm_ids = ( - Message.objects.filter(pm_query) + Message.objects.filter(pm_query, realm=realm.id) .values_list("id", flat=True) .values_list("id", flat=True) ) @@ -676,7 +676,7 @@ class RealmImportExportTest(ExportFile): ) pm_query = Q(recipient__in=huddle_recipients) | Q(sender__in=consented_user_ids) exported_huddle_ids = ( - Message.objects.filter(pm_query) + Message.objects.filter(pm_query, realm=realm.id) .values_list("id", flat=True) .values_list("id", flat=True) ) @@ -1260,7 +1260,7 @@ class RealmImportExportTest(ExportFile): # test messages def get_stream_messages(r: Realm) -> QuerySet[Message]: recipient = get_recipient_stream(r) - messages = Message.objects.filter(recipient=recipient) + messages = Message.objects.filter(realm_id=r.id, recipient=recipient) return messages @getter diff --git a/zerver/tests/test_message_fetch.py b/zerver/tests/test_message_fetch.py index bd4e8adf42..78aac7dcdb 100644 --- a/zerver/tests/test_message_fetch.py +++ b/zerver/tests/test_message_fetch.py @@ -2289,7 +2289,7 @@ class GetOldMessagesTest(ZulipTestCase): stream_names = ["Scotland", "Verona", "Venice"] def send_messages_to_all_streams() -> None: - Message.objects.filter(recipient__type=Recipient.STREAM).delete() + Message.objects.filter(realm_id=realm.id, recipient__type=Recipient.STREAM).delete() for stream_name in stream_names: self.subscribe(hamlet, stream_name) for i in range(num_messages_per_stream): diff --git a/zerver/tests/test_realm.py b/zerver/tests/test_realm.py index 5ebc2eecf2..75ffe83966 100644 --- a/zerver/tests/test_realm.py +++ b/zerver/tests/test_realm.py @@ -1573,9 +1573,24 @@ class ScrubRealmTest(ZulipTestCase): CustomProfileField.objects.create(realm=lear) - self.assertEqual(Message.objects.filter(sender__in=[iago, othello]).count(), 10) - self.assertEqual(Message.objects.filter(sender__in=[cordelia, king]).count(), 10) - self.assertEqual(Message.objects.filter(sender=notification_bot).count(), 6) + self.assertEqual( + Message.objects.filter( + realm_id__in=(zulip.id, lear.id), sender__in=[iago, othello] + ).count(), + 10, + ) + self.assertEqual( + Message.objects.filter( + realm_id__in=(zulip.id, lear.id), sender__in=[cordelia, king] + ).count(), + 10, + ) + self.assertEqual( + Message.objects.filter( + realm_id__in=(zulip.id, lear.id), sender=notification_bot + ).count(), + 6, + ) self.assertEqual(UserMessage.objects.filter(user_profile__in=[iago, othello]).count(), 25) self.assertEqual(UserMessage.objects.filter(user_profile__in=[cordelia, king]).count(), 25) @@ -1584,9 +1599,24 @@ class ScrubRealmTest(ZulipTestCase): with self.assertLogs(level="WARNING"): do_scrub_realm(zulip, acting_user=None) - self.assertEqual(Message.objects.filter(sender__in=[iago, othello]).count(), 0) - self.assertEqual(Message.objects.filter(sender__in=[cordelia, king]).count(), 10) - self.assertEqual(Message.objects.filter(sender=notification_bot).count(), 3) + self.assertEqual( + Message.objects.filter( + realm_id__in=(zulip.id, lear.id), sender__in=[iago, othello] + ).count(), + 0, + ) + self.assertEqual( + Message.objects.filter( + realm_id__in=(zulip.id, lear.id), sender__in=[cordelia, king] + ).count(), + 10, + ) + self.assertEqual( + Message.objects.filter( + realm_id__in=(zulip.id, lear.id), sender=notification_bot + ).count(), + 3, + ) self.assertEqual(UserMessage.objects.filter(user_profile__in=[iago, othello]).count(), 0) self.assertEqual(UserMessage.objects.filter(user_profile__in=[cordelia, king]).count(), 25) diff --git a/zerver/tests/test_retention.py b/zerver/tests/test_retention.py index 74a3767330..8f6d7b61e8 100644 --- a/zerver/tests/test_retention.py +++ b/zerver/tests/test_retention.py @@ -697,7 +697,9 @@ class MoveMessageToArchiveGeneral(MoveMessageToArchiveBase): ) for attachment_id in attachment_ids: attachment_id_to_message_ids[attachment_id] = list( - Message.objects.filter(attachment__id=attachment_id).values_list("id", flat=True), + Message.objects.filter(realm_id=realm_id, attachment__id=attachment_id).values_list( + "id", flat=True + ), ) usermsg_ids = self._get_usermessage_ids(msg_ids) @@ -736,9 +738,9 @@ class MoveMessageToArchiveGeneral(MoveMessageToArchiveBase): self.assertEqual( set(attachment_id_to_message_ids[attachment_id]), set( - Message.objects.filter(attachment__id=attachment_id).values_list( - "id", flat=True - ) + Message.objects.filter( + realm_id=realm_id, attachment__id=attachment_id + ).values_list("id", flat=True) ), ) diff --git a/zerver/tests/test_signup.py b/zerver/tests/test_signup.py index 0aeb94665c..0c9cc12b12 100644 --- a/zerver/tests/test_signup.py +++ b/zerver/tests/test_signup.py @@ -1306,13 +1306,17 @@ class RealmCreationTest(ZulipTestCase): ]: stream = get_stream(stream_name, realm) recipient = stream.recipient - messages = Message.objects.filter(recipient=recipient).order_by("date_sent") + messages = Message.objects.filter(realm_id=realm.id, recipient=recipient).order_by( + "date_sent" + ) self.assert_length(messages, message_count) self.assertIn(text, messages[0].content) # Check admin organization's signups stream messages recipient = signups_stream.recipient - messages = Message.objects.filter(recipient=recipient).order_by("id") + messages = Message.objects.filter(realm_id=internal_realm.id, recipient=recipient).order_by( + "id" + ) self.assert_length(messages, 1) # Check organization name, subdomain and organization type are in message content self.assertIn("Zulip Test", messages[0].content) @@ -1610,7 +1614,9 @@ class RealmCreationTest(ZulipTestCase): # Make sure the correct Welcome Bot direct message is sent. welcome_msg = Message.objects.filter( - sender__email="welcome-bot@zulip.com", recipient__type=Recipient.PERSONAL + realm_id=get_realm(string_id).id, + sender__email="welcome-bot@zulip.com", + recipient__type=Recipient.PERSONAL, ).latest("id") self.assertTrue(welcome_msg.content.startswith("Hello, and welcome to Zulip!")) @@ -1661,7 +1667,9 @@ class RealmCreationTest(ZulipTestCase): # Make sure the correct Welcome Bot direct message is sent. welcome_msg = Message.objects.filter( - sender__email="welcome-bot@zulip.com", recipient__type=Recipient.PERSONAL + realm_id=get_realm(string_id).id, + sender__email="welcome-bot@zulip.com", + recipient__type=Recipient.PERSONAL, ).latest("id") self.assertTrue(welcome_msg.content.startswith("Hello, and welcome to Zulip!")) diff --git a/zerver/tests/test_users.py b/zerver/tests/test_users.py index d015ab52c9..c5a6b38d03 100644 --- a/zerver/tests/test_users.py +++ b/zerver/tests/test_users.py @@ -2499,10 +2499,10 @@ class DeleteUserTest(ZulipTestCase): self.send_personal_message(hamlet, cordelia) personal_message_ids_to_hamlet = Message.objects.filter( - recipient=hamlet_personal_recipient + realm_id=realm.id, recipient=hamlet_personal_recipient ).values_list("id", flat=True) self.assertGreater(len(personal_message_ids_to_hamlet), 0) - self.assertTrue(Message.objects.filter(sender=hamlet).exists()) + self.assertTrue(Message.objects.filter(realm_id=realm.id, sender=hamlet).exists()) huddle_message_ids_from_cordelia = [ self.send_huddle_message(cordelia, [hamlet, othello]) for i in range(3) @@ -2535,7 +2535,9 @@ class DeleteUserTest(ZulipTestCase): self.assertEqual(Message.objects.filter(id__in=huddle_message_ids_from_hamlet).count(), 0) self.assertEqual(Message.objects.filter(id__in=huddle_message_ids_from_cordelia).count(), 3) - self.assertEqual(Message.objects.filter(sender_id=hamlet_user_id).count(), 0) + self.assertEqual( + Message.objects.filter(realm_id=realm.id, sender_id=hamlet_user_id).count(), 0 + ) # Verify that the dummy user is subscribed to the deleted user's huddles, to keep huddle data # in a correct state. @@ -2564,10 +2566,10 @@ class DeleteUserTest(ZulipTestCase): self.send_personal_message(hamlet, cordelia) personal_message_ids_to_hamlet = Message.objects.filter( - recipient=hamlet_personal_recipient + realm_id=realm.id, recipient=hamlet_personal_recipient ).values_list("id", flat=True) self.assertGreater(len(personal_message_ids_to_hamlet), 0) - self.assertTrue(Message.objects.filter(sender=hamlet).exists()) + self.assertTrue(Message.objects.filter(realm_id=realm.id, sender=hamlet).exists()) huddle_message_ids_from_cordelia = [ self.send_huddle_message(cordelia, [hamlet, othello]) for i in range(3) @@ -2584,7 +2586,7 @@ class DeleteUserTest(ZulipTestCase): self.assertGreater(len(huddle_with_hamlet_recipient_ids), 0) original_messages_from_hamlet_count = Message.objects.filter( - sender_id=hamlet_user_id + realm_id=realm.id, sender_id=hamlet_user_id ).count() self.assertGreater(original_messages_from_hamlet_count, 0) @@ -2614,7 +2616,7 @@ class DeleteUserTest(ZulipTestCase): ) self.assertEqual( - Message.objects.filter(sender_id=hamlet_user_id).count(), + Message.objects.filter(realm_id=realm.id, sender_id=hamlet_user_id).count(), original_messages_from_hamlet_count, ) diff --git a/zerver/views/streams.py b/zerver/views/streams.py index c43d6cef99..5c0cf8dc08 100644 --- a/zerver/views/streams.py +++ b/zerver/views/streams.py @@ -923,7 +923,9 @@ def delete_in_topic( ) -> HttpResponse: stream, ignored_sub = access_stream_by_id(user_profile, stream_id) - messages = messages_for_topic(assert_is_not_none(stream.recipient_id), topic_name) + messages = messages_for_topic( + user_profile.realm_id, assert_is_not_none(stream.recipient_id), topic_name + ) # Note: It would be better to use bulk_access_messages here, which is our core function # for obtaining the accessible messages - and it's good to use it wherever we can, # so that we have a central place to keep up to date with our security model for diff --git a/zerver/worker/queue_processors.py b/zerver/worker/queue_processors.py index 4aac565da2..170541b50f 100644 --- a/zerver/worker/queue_processors.py +++ b/zerver/worker/queue_processors.py @@ -97,6 +97,7 @@ from zerver.models import ( Realm, RealmAuditLog, ScheduledMessageNotificationEmail, + Stream, UserMessage, UserProfile, filter_to_valid_prereg_users, @@ -1013,12 +1014,15 @@ class DeferredWorker(QueueProcessingWorker): "Marking messages as read for all users, stream_recipient_id %s", event["stream_recipient_id"], ) + stream = Stream.objects.get(recipient_id=event["stream_recipient_id"]) # This event is generated by the stream deactivation code path. batch_size = 100 offset = 0 while True: messages = Message.objects.filter( - recipient_id=event["stream_recipient_id"] + # Uses index: zerver_message_realm_recipient_id + realm_id=stream.realm_id, + recipient_id=event["stream_recipient_id"], ).order_by("id")[offset : offset + batch_size] with transaction.atomic(savepoint=False):