diff --git a/zerver/actions/message_edit.py b/zerver/actions/message_edit.py index a36cf82429..43b0a6fd05 100644 --- a/zerver/actions/message_edit.py +++ b/zerver/actions/message_edit.py @@ -33,7 +33,7 @@ from zerver.lib.markdown import version as markdown_version from zerver.lib.mention import MentionBackend, MentionData, silent_mention_syntax_for_user from zerver.lib.message import ( access_message, - bulk_access_messages, + bulk_access_stream_messages_query, check_user_group_mention_allowed, normalize_body, stream_wildcard_mention_allowed, @@ -827,27 +827,23 @@ def do_update_message( # full-topic move. # # For security model reasons, we don't want to allow a - # user to take any action that would leak information - # about older messages they cannot access (E.g. the only - # remaining messages are in a stream without shared - # history). The bulk_access_messages call below addresses + # user to take any action (e.g. post a message about + # having not moved the whole topic) that would leak + # information about older messages they cannot access + # (e.g. there were earlier inaccessible messages in the + # topic, in a stream without shared history). The + # bulk_access_stream_messages_query call below addresses # that concern. - # - # bulk_access_messages is inefficient for this task, since - # we just want to do the exists() version of this - # query. But it's nice to reuse code, and this bulk - # operation is likely cheaper than a `GET /messages` - # unless the topic has thousands of messages of history. 
assert stream_being_edited.recipient_id is not None unmoved_messages = messages_for_topic( realm.id, stream_being_edited.recipient_id, orig_topic_name, ) - visible_unmoved_messages = bulk_access_messages( - user_profile, unmoved_messages, stream=stream_being_edited + visible_unmoved_messages = bulk_access_stream_messages_query( + user_profile, unmoved_messages, stream_being_edited ) - moved_all_visible_messages = len(visible_unmoved_messages) == 0 + moved_all_visible_messages = not visible_unmoved_messages.exists() # Migrate 'topic with visibility_policy' configuration in the following # circumstances: @@ -1064,24 +1060,15 @@ def do_update_message( # avoid leaking information about whether there are # messages in the destination topic's deeper history that # the acting user does not have permission to access. - # - # TODO: These queries are quite inefficient, in that we're - # fetching full copies of all the messages in the - # destination topic to answer the question of whether the - # current user has access to at least one such message. - # - # The main strength of the current implementation is that - # it reuses existing logic, which is good for keeping it - # correct as we maintain the codebase. 
preexisting_topic_messages = messages_for_topic( realm.id, stream_for_new_topic.recipient_id, new_topic_name ).exclude(id__in=[*changed_message_ids, resolved_topic_message_id]) - visible_preexisting_messages = bulk_access_messages( - user_profile, preexisting_topic_messages, stream=stream_for_new_topic + visible_preexisting_messages = bulk_access_stream_messages_query( + user_profile, preexisting_topic_messages, stream_for_new_topic ) - no_visible_preexisting_messages = len(visible_preexisting_messages) == 0 + no_visible_preexisting_messages = not visible_preexisting_messages.exists() if no_visible_preexisting_messages and moved_all_visible_messages: new_thread_notification_string = gettext_lazy( diff --git a/zerver/lib/message.py b/zerver/lib/message.py index f9872fbb67..7654373ad3 100644 --- a/zerver/lib/message.py +++ b/zerver/lib/message.py @@ -21,7 +21,7 @@ import ahocorasick import orjson from django.conf import settings from django.db import connection -from django.db.models import Max, QuerySet, Sum +from django.db.models import Exists, Max, OuterRef, QuerySet, Sum from django.utils.timezone import now as timezone_now from django.utils.translation import gettext as _ from django_stubs_ext import ValuesQuerySet @@ -996,6 +996,51 @@ def bulk_access_messages( return filtered_messages
+def bulk_access_stream_messages_query(
+    user_profile: UserProfile, messages: QuerySet[Message], stream: Stream
+) -> QuerySet[Message]:
+    """Restrict a QuerySet of stream messages to those user_profile can access.
+
+    This function mirrors bulk_access_messages, above, but applies the
+    limits to a QuerySet and returns a new QuerySet, rather than
+    fetching full Message objects; callers can thus cheaply chain
+    operations like .exists() onto the result.
+
+    Note that this only works with streams: the result is restricted
+    to messages in user_profile's realm that were sent to the given
+    stream.  It may return an empty QuerySet if the user has access
+    to no messages (for instance, for a private stream which the
+    user is not subscribed to).
+    """
+
+    messages = messages.filter(realm_id=user_profile.realm_id, recipient_id=stream.recipient_id)
+
+    # Users who can access public streams can access everything in a
+    # public stream, whether or not they are subscribed to it.
+    if stream.is_public() and user_profile.can_access_public_streams():
+        return messages
+
+    # Otherwise, an active subscription is required.  Filter on
+    # recipient_id, as above, since reading stream.recipient could
+    # cost an extra query to load the Recipient row.
+    if not Subscription.objects.filter(
+        user_profile=user_profile, active=True, recipient_id=stream.recipient_id
+    ).exists():
+        return Message.objects.none()
+
+    # In a stream with protected history, subscribers can only access
+    # the messages that they have a UserMessage row for.
+    if not stream.is_history_public_to_subscribers():
+        messages = messages.annotate(
+            has_usermessage=Exists(
+                UserMessage.objects.filter(
+                    user_profile_id=user_profile.id, message_id=OuterRef("id")
+                )
+            )
+        ).filter(has_usermessage=True)
+    return messages
+
+
 def get_messages_with_usermessage_rows_for_user( user_profile_id: int, message_ids: Sequence[int] ) -> ValuesQuerySet[UserMessage, int]: diff --git a/zerver/lib/topic.py b/zerver/lib/topic.py index 4801cbff9e..0843528ca3 100644 --- a/zerver/lib/topic.py +++ b/zerver/lib/topic.py @@ -173,9 +173,9 @@ def update_messages_for_topic_edit( # If we're moving the messages between streams, only move # messages that the acting user can access, so that one cannot # gain access to messages through moving them.
- from zerver.lib.message import bulk_access_messages + from zerver.lib.message import bulk_access_stream_messages_query - messages_list = bulk_access_messages(acting_user, messages, stream=old_stream) + messages_list = list(bulk_access_stream_messages_query(acting_user, messages, old_stream)) else: # For single-message edits or topic moves within a stream, we # allow moving history the user may not have access in order diff --git a/zerver/tests/test_message_edit.py b/zerver/tests/test_message_edit.py index e4ebd32065..200f3032d6 100644 --- a/zerver/tests/test_message_edit.py +++ b/zerver/tests/test_message_edit.py @@ -1528,7 +1528,7 @@ class EditMessageTest(EditMessageTestCase): set_topic_visibility_policy(desdemona, muted_topics, UserTopic.VisibilityPolicy.MUTED) set_topic_visibility_policy(cordelia, muted_topics, UserTopic.VisibilityPolicy.MUTED) - with self.assert_database_query_count(29): + with self.assert_database_query_count(28): check_update_message( user_profile=desdemona, message_id=message_id, @@ -1559,7 +1559,7 @@ class EditMessageTest(EditMessageTestCase): set_topic_visibility_policy(desdemona, muted_topics, UserTopic.VisibilityPolicy.MUTED) set_topic_visibility_policy(cordelia, muted_topics, UserTopic.VisibilityPolicy.MUTED) - with self.assert_database_query_count(34): + with self.assert_database_query_count(33): check_update_message( user_profile=desdemona, message_id=message_id, @@ -1592,7 +1592,7 @@ class EditMessageTest(EditMessageTestCase): set_topic_visibility_policy(desdemona, muted_topics, UserTopic.VisibilityPolicy.MUTED) set_topic_visibility_policy(cordelia, muted_topics, UserTopic.VisibilityPolicy.MUTED) - with self.assert_database_query_count(29): + with self.assert_database_query_count(28): check_update_message( user_profile=desdemona, message_id=message_id, @@ -1615,7 +1615,7 @@ class EditMessageTest(EditMessageTestCase): second_message_id = self.send_stream_message( hamlet, stream_name, topic_name="changed topic name", content="Second 
message" ) - with self.assert_database_query_count(25): + with self.assert_database_query_count(23): check_update_message( user_profile=desdemona, message_id=second_message_id, @@ -3783,7 +3783,7 @@ class EditMessageTest(EditMessageTestCase): "iago", "test move stream", "new stream", "test" ) - with self.assert_database_query_count(55), self.assert_memcached_count(14): + with self.assert_database_query_count(52), self.assert_memcached_count(14): result = self.client_patch( f"/json/messages/{msg_id}", { diff --git a/zerver/tests/test_message_flags.py b/zerver/tests/test_message_flags.py index bc906e2c07..f952463c59 100644 --- a/zerver/tests/test_message_flags.py +++ b/zerver/tests/test_message_flags.py @@ -19,6 +19,7 @@ from zerver.lib.message import ( aggregate_unread_data, apply_unread_message_event, bulk_access_messages, + bulk_access_stream_messages_query, format_unread_message_details, get_raw_unread_data, ) @@ -1505,6 +1506,30 @@ class MessageAccessTests(ZulipTestCase): result = self.change_star(message_id) self.assert_json_success(result)
+    def assert_bulk_access(
+        self,
+        user: UserProfile,
+        message_ids: List[int],
+        stream: Stream,
+        bulk_access_messages_count: int,
+        bulk_access_stream_messages_query_count: int,
+    ) -> List[Message]:
+        """Check that both bulk access implementations agree.
+
+        Runs each over message_ids under its expected database query
+        count (the bulk_access_messages count also covers fetching
+        each message individually), and returns the shared result.
+        """
+        message_lookup = Message.objects.select_related("recipient")
+        with self.assert_database_query_count(bulk_access_messages_count):
+            fetched = [message_lookup.get(id=message_id) for message_id in sorted(message_ids)]
+            expected = bulk_access_messages(user, fetched, stream=stream)
+        with self.assert_database_query_count(bulk_access_stream_messages_query_count):
+            candidates = message_lookup.filter(id__in=message_ids).order_by("id")
+            actual = list(bulk_access_stream_messages_query(user, candidates, stream))
+        self.assertEqual(actual, expected)
+        return expected
+
 def test_bulk_access_messages_private_stream(self) -> None: user = self.example_user("hamlet") self.login_user(user)
@@ -1526,16 +1551,12 @@ class MessageAccessTests(ZulipTestCase): message_two_id = self.send_stream_message(user, stream_name, "Message two") message_ids = [message_one_id, message_two_id] - messages = [ - Message.objects.select_related("recipient").get(id=message_id) - for message_id in message_ids - ] - - with self.assert_database_query_count(2): - filtered_messages = bulk_access_messages(later_subscribed_user, messages, stream=stream) # Message sent before subscribing wouldn't be accessible by later # subscribed user as stream has protected history + filtered_messages = self.assert_bulk_access( + later_subscribed_user, message_ids, stream, 4, 2 + ) self.assert_length(filtered_messages, 1) self.assertEqual(filtered_messages[0].id, message_two_id) @@ -1547,27 +1568,44 @@ class MessageAccessTests(ZulipTestCase): acting_user=self.example_user("cordelia"), ) - with self.assert_database_query_count(2): - filtered_messages = bulk_access_messages(later_subscribed_user, messages, stream=stream) - - # Message sent before subscribing are accessible by 8user as stream - # don't have protected history + # Message sent before subscribing are accessible by user as stream + # now don't have protected history + filtered_messages = self.assert_bulk_access( + later_subscribed_user, message_ids, stream, 4, 2 + ) self.assert_length(filtered_messages, 2) # Testing messages accessibility for an unsubscribed user unsubscribed_user = self.example_user("ZOE") - - with self.assert_database_query_count(2): - filtered_messages = bulk_access_messages(unsubscribed_user, messages, stream=stream) - + filtered_messages = self.assert_bulk_access(unsubscribed_user, message_ids, stream, 4, 1) self.assert_length(filtered_messages, 0) + # Adding more message ids to the list increases the query size + # for bulk_access_messages but not + # bulk_access_stream_messages_query + more_message_ids = [ + *message_ids, + self.send_stream_message(user, stream_name, "Message three"), + 
self.send_stream_message(user, stream_name, "Message four"), + ] + filtered_messages = self.assert_bulk_access( + later_subscribed_user, more_message_ids, stream, 6, 2 + ) + self.assert_length(filtered_messages, 4) + # Verify an exception is thrown if called where the passed # stream not matching the messages. + other_stream = get_stream("Denmark", unsubscribed_user.realm) with self.assertRaises(AssertionError): - bulk_access_messages( - unsubscribed_user, messages, stream=get_stream("Denmark", unsubscribed_user.realm) - ) + messages = [Message.objects.get(id=id) for id in message_ids] + bulk_access_messages(unsubscribed_user, messages, stream=other_stream) + + # Verify that bulk_access_stream_messages_query is empty with a stream mismatch + message_query = Message.objects.select_related("recipient").filter(id__in=message_ids) + filtered_query = bulk_access_stream_messages_query( + later_subscribed_user, message_query, other_stream + ) + self.assert_length(filtered_query, 0) def test_bulk_access_messages_public_stream(self) -> None: user = self.example_user("hamlet") @@ -1585,20 +1623,15 @@ class MessageAccessTests(ZulipTestCase): message_two_id = self.send_stream_message(user, stream_name, "Message two") message_ids = [message_one_id, message_two_id] - messages = [ - Message.objects.select_related("recipient").get(id=message_id) - for message_id in message_ids - ] # All public stream messages are always accessible - with self.assert_database_query_count(2): - filtered_messages = bulk_access_messages(later_subscribed_user, messages, stream=stream) + filtered_messages = self.assert_bulk_access( + later_subscribed_user, message_ids, stream, 4, 1 + ) self.assert_length(filtered_messages, 2) unsubscribed_user = self.example_user("ZOE") - with self.assert_database_query_count(2): - filtered_messages = bulk_access_messages(unsubscribed_user, messages, stream=stream) - + filtered_messages = self.assert_bulk_access(unsubscribed_user, message_ids, stream, 4, 1) 
self.assert_length(filtered_messages, 2) diff --git a/zerver/views/streams.py b/zerver/views/streams.py index f3a1d3c66d..f085582288 100644 --- a/zerver/views/streams.py +++ b/zerver/views/streams.py @@ -54,6 +54,7 @@ from zerver.lib.exceptions import ( ResourceNotFoundError, ) from zerver.lib.mention import MentionBackend, silent_mention_syntax_for_user +from zerver.lib.message import bulk_access_stream_messages_query from zerver.lib.request import REQ, has_request_variables from zerver.lib.response import json_success from zerver.lib.retention import STREAM_MESSAGE_BATCH_SIZE as RETENTION_STREAM_MESSAGE_BATCH_SIZE @@ -99,7 +100,7 @@ from zerver.lib.validator import ( check_union, to_non_negative_int, ) -from zerver.models import Realm, Stream, UserGroup, UserMessage, UserProfile +from zerver.models import Realm, Stream, UserGroup, UserProfile from zerver.models.users import get_system_bot @@ -925,22 +926,9 @@ def delete_in_topic( messages = messages_for_topic( user_profile.realm_id, assert_is_not_none(stream.recipient_id), topic_name ) - # Note: It would be better to use bulk_access_messages here, which is our core function - # for obtaining the accessible messages - and it's good to use it wherever we can, - # so that we have a central place to keep up to date with our security model for - # message access. - # However, it fetches the full Message objects, which would be bad here for very large - # topics. - # The access_stream_by_id call above ensures that the acting user currently has access to the - # stream (which entails having an active Subscription in case of private streams), meaning - # that combined with the UserMessage check below, this is a sufficient replacement for - # bulk_access_messages. - if not stream.is_history_public_to_subscribers(): - # Don't allow the user to delete messages that they don't have access to. 
- deletable_message_ids = UserMessage.objects.filter( - user_profile=user_profile, message_id__in=messages - ).values_list("message_id", flat=True) - messages = messages.filter(id__in=deletable_message_ids) + # This handles applying access control, such that only messages + # the user can see are returned in the query. + messages = bulk_access_stream_messages_query(user_profile, messages, stream) def delete_in_batches() -> Literal[True]: # Topics can be large enough that this request will inevitably time out.