message: Add a bulk_access_stream_messages_query method.

This applies access restrictions in SQL, so that individual messages
do not need to be walked one-by-one.  It only functions for stream
messages.

Using this method significantly speeds up the check for whether we
moved "all visible messages" in a topic, since we no longer need to
walk every remaining message in the old topic to determine whether at
least one is visible to the user.  Similarly, it significantly speeds
up merging into existing topics, since we no longer need to walk every
message in the new topic to determine whether the user can see at
least one.

Finally, it unlocks the ability to bulk-update only messages the user
has access to, in a single query (see subsequent commit).
This commit is contained in:
Alex Vandiver 2023-09-26 15:34:55 +00:00
parent c118f1874e
commit 7dcc7540f9
6 changed files with 119 additions and 79 deletions

View File

@ -33,7 +33,7 @@ from zerver.lib.markdown import version as markdown_version
from zerver.lib.mention import MentionBackend, MentionData, silent_mention_syntax_for_user
from zerver.lib.message import (
access_message,
bulk_access_messages,
bulk_access_stream_messages_query,
check_user_group_mention_allowed,
normalize_body,
stream_wildcard_mention_allowed,
@ -827,27 +827,23 @@ def do_update_message(
# full-topic move.
#
# For security model reasons, we don't want to allow a
# user to take any action that would leak information
# about older messages they cannot access (E.g. the only
# remaining messages are in a stream without shared
# history). The bulk_access_messages call below addresses
# user to take any action (e.g. post a message about
# having not moved the whole topic) that would leak
# information about older messages they cannot access
# (e.g. there were earlier inaccessible messages in the
# topic, in a stream without shared history). The
# bulk_access_stream_messages_query call below addresses
# that concern.
#
# bulk_access_messages is inefficient for this task, since
# we just want to do the exists() version of this
# query. But it's nice to reuse code, and this bulk
# operation is likely cheaper than a `GET /messages`
# unless the topic has thousands of messages of history.
assert stream_being_edited.recipient_id is not None
unmoved_messages = messages_for_topic(
realm.id,
stream_being_edited.recipient_id,
orig_topic_name,
)
visible_unmoved_messages = bulk_access_messages(
user_profile, unmoved_messages, stream=stream_being_edited
visible_unmoved_messages = bulk_access_stream_messages_query(
user_profile, unmoved_messages, stream_being_edited
)
moved_all_visible_messages = len(visible_unmoved_messages) == 0
moved_all_visible_messages = not visible_unmoved_messages.exists()
# Migrate 'topic with visibility_policy' configuration in the following
# circumstances:
@ -1064,24 +1060,15 @@ def do_update_message(
# avoid leaking information about whether there are
# messages in the destination topic's deeper history that
# the acting user does not have permission to access.
#
# TODO: These queries are quite inefficient, in that we're
# fetching full copies of all the messages in the
# destination topic to answer the question of whether the
# current user has access to at least one such message.
#
# The main strength of the current implementation is that
# it reuses existing logic, which is good for keeping it
# correct as we maintain the codebase.
preexisting_topic_messages = messages_for_topic(
realm.id, stream_for_new_topic.recipient_id, new_topic_name
).exclude(id__in=[*changed_message_ids, resolved_topic_message_id])
visible_preexisting_messages = bulk_access_messages(
user_profile, preexisting_topic_messages, stream=stream_for_new_topic
visible_preexisting_messages = bulk_access_stream_messages_query(
user_profile, preexisting_topic_messages, stream_for_new_topic
)
no_visible_preexisting_messages = len(visible_preexisting_messages) == 0
no_visible_preexisting_messages = not visible_preexisting_messages.exists()
if no_visible_preexisting_messages and moved_all_visible_messages:
new_thread_notification_string = gettext_lazy(

View File

@ -21,7 +21,7 @@ import ahocorasick
import orjson
from django.conf import settings
from django.db import connection
from django.db.models import Max, QuerySet, Sum
from django.db.models import Exists, Max, OuterRef, QuerySet, Sum
from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _
from django_stubs_ext import ValuesQuerySet
@ -996,6 +996,38 @@ def bulk_access_messages(
return filtered_messages
def bulk_access_stream_messages_query(
    user_profile: UserProfile, messages: QuerySet[Message], stream: Stream
) -> QuerySet[Message]:
    """Restrict a Message QuerySet to the messages in `stream` that
    `user_profile` can access.

    This function mirrors bulk_access_messages, above, but applies the
    limits to a QuerySet and returns a new QuerySet, rather than
    walking a list of already-fetched messages.  The returned QuerySet
    is not evaluated here, so callers can follow it with .exists(),
    list(), or further filtering.

    Note that this only works with streams: `messages` is always
    narrowed to the user's realm and to the recipient of `stream`, so
    messages from any other recipient are dropped from the result.

    The result may be an empty QuerySet if the user has access to no
    messages (for instance, for a private stream which the user is
    not subscribed to).
    """
    messages = messages.filter(realm_id=user_profile.realm_id, recipient_id=stream.recipient_id)

    # A public stream, read by a user who is allowed to access public
    # streams, needs no further restriction.
    if stream.is_public() and user_profile.can_access_public_streams():
        return messages

    # Otherwise, an active subscription to the stream is required.
    # Match on the foreign-key column, exactly as the message filter
    # above does, so that this check does not depend on the related
    # Recipient object having been loaded onto `stream`.
    if not Subscription.objects.filter(
        user_profile=user_profile, active=True, recipient_id=stream.recipient_id
    ).exists():
        return Message.objects.none()

    if not stream.is_history_public_to_subscribers():
        # History is not shared with subscribers: keep only the
        # messages for which this user has a UserMessage row, checked
        # with a correlated EXISTS subquery.
        messages = messages.annotate(
            has_usermessage=Exists(
                UserMessage.objects.filter(
                    user_profile_id=user_profile.id, message_id=OuterRef("id")
                )
            )
        ).filter(has_usermessage=True)
    return messages
def get_messages_with_usermessage_rows_for_user(
user_profile_id: int, message_ids: Sequence[int]
) -> ValuesQuerySet[UserMessage, int]:

View File

@ -173,9 +173,9 @@ def update_messages_for_topic_edit(
# If we're moving the messages between streams, only move
# messages that the acting user can access, so that one cannot
# gain access to messages through moving them.
from zerver.lib.message import bulk_access_messages
from zerver.lib.message import bulk_access_stream_messages_query
messages_list = bulk_access_messages(acting_user, messages, stream=old_stream)
messages_list = list(bulk_access_stream_messages_query(acting_user, messages, old_stream))
else:
# For single-message edits or topic moves within a stream, we
# allow moving history the user may not have access in order

View File

@ -1528,7 +1528,7 @@ class EditMessageTest(EditMessageTestCase):
set_topic_visibility_policy(desdemona, muted_topics, UserTopic.VisibilityPolicy.MUTED)
set_topic_visibility_policy(cordelia, muted_topics, UserTopic.VisibilityPolicy.MUTED)
with self.assert_database_query_count(29):
with self.assert_database_query_count(28):
check_update_message(
user_profile=desdemona,
message_id=message_id,
@ -1559,7 +1559,7 @@ class EditMessageTest(EditMessageTestCase):
set_topic_visibility_policy(desdemona, muted_topics, UserTopic.VisibilityPolicy.MUTED)
set_topic_visibility_policy(cordelia, muted_topics, UserTopic.VisibilityPolicy.MUTED)
with self.assert_database_query_count(34):
with self.assert_database_query_count(33):
check_update_message(
user_profile=desdemona,
message_id=message_id,
@ -1592,7 +1592,7 @@ class EditMessageTest(EditMessageTestCase):
set_topic_visibility_policy(desdemona, muted_topics, UserTopic.VisibilityPolicy.MUTED)
set_topic_visibility_policy(cordelia, muted_topics, UserTopic.VisibilityPolicy.MUTED)
with self.assert_database_query_count(29):
with self.assert_database_query_count(28):
check_update_message(
user_profile=desdemona,
message_id=message_id,
@ -1615,7 +1615,7 @@ class EditMessageTest(EditMessageTestCase):
second_message_id = self.send_stream_message(
hamlet, stream_name, topic_name="changed topic name", content="Second message"
)
with self.assert_database_query_count(25):
with self.assert_database_query_count(23):
check_update_message(
user_profile=desdemona,
message_id=second_message_id,
@ -3783,7 +3783,7 @@ class EditMessageTest(EditMessageTestCase):
"iago", "test move stream", "new stream", "test"
)
with self.assert_database_query_count(55), self.assert_memcached_count(14):
with self.assert_database_query_count(52), self.assert_memcached_count(14):
result = self.client_patch(
f"/json/messages/{msg_id}",
{

View File

@ -19,6 +19,7 @@ from zerver.lib.message import (
aggregate_unread_data,
apply_unread_message_event,
bulk_access_messages,
bulk_access_stream_messages_query,
format_unread_message_details,
get_raw_unread_data,
)
@ -1505,6 +1506,30 @@ class MessageAccessTests(ZulipTestCase):
result = self.change_star(message_id)
self.assert_json_success(result)
def assert_bulk_access(
    self,
    user: UserProfile,
    message_ids: List[int],
    stream: Stream,
    bulk_access_messages_count: int,
    bulk_access_stream_messages_query_count: int,
) -> List[Message]:
    """Run both bulk-access implementations over the same messages and
    check that they agree.

    The messages are first fetched one at a time (in ascending id
    order) and passed through bulk_access_messages; that whole step,
    including the fetches, must take exactly
    `bulk_access_messages_count` database queries.  The same ids are
    then filtered through bulk_access_stream_messages_query, ordered
    by id, which must take exactly
    `bulk_access_stream_messages_query_count` queries once evaluated.

    Asserts that the two results are equal, and returns the result
    of bulk_access_messages.
    """
    messages_with_recipient = Message.objects.select_related("recipient")

    with self.assert_database_query_count(bulk_access_messages_count):
        fetched_messages = [
            messages_with_recipient.get(id=message_id) for message_id in sorted(message_ids)
        ]
        expected = bulk_access_messages(user, fetched_messages, stream=stream)

    with self.assert_database_query_count(bulk_access_stream_messages_query_count):
        candidates = messages_with_recipient.filter(id__in=message_ids).order_by("id")
        accessible_query = bulk_access_stream_messages_query(user, candidates, stream)
        # Evaluating the QuerySet is what actually issues the SQL.
        actual = list(accessible_query)

    self.assertEqual(actual, expected)
    return expected
def test_bulk_access_messages_private_stream(self) -> None:
user = self.example_user("hamlet")
self.login_user(user)
@ -1526,16 +1551,12 @@ class MessageAccessTests(ZulipTestCase):
message_two_id = self.send_stream_message(user, stream_name, "Message two")
message_ids = [message_one_id, message_two_id]
messages = [
Message.objects.select_related("recipient").get(id=message_id)
for message_id in message_ids
]
with self.assert_database_query_count(2):
filtered_messages = bulk_access_messages(later_subscribed_user, messages, stream=stream)
# Message sent before subscribing wouldn't be accessible by later
# subscribed user as stream has protected history
filtered_messages = self.assert_bulk_access(
later_subscribed_user, message_ids, stream, 4, 2
)
self.assert_length(filtered_messages, 1)
self.assertEqual(filtered_messages[0].id, message_two_id)
@ -1547,27 +1568,44 @@ class MessageAccessTests(ZulipTestCase):
acting_user=self.example_user("cordelia"),
)
with self.assert_database_query_count(2):
filtered_messages = bulk_access_messages(later_subscribed_user, messages, stream=stream)
# Messages sent before subscribing are accessible by the user, as the
# stream doesn't have protected history
# Messages sent before subscribing are accessible by the user, as the
# stream now doesn't have protected history
filtered_messages = self.assert_bulk_access(
later_subscribed_user, message_ids, stream, 4, 2
)
self.assert_length(filtered_messages, 2)
# Testing messages accessibility for an unsubscribed user
unsubscribed_user = self.example_user("ZOE")
with self.assert_database_query_count(2):
filtered_messages = bulk_access_messages(unsubscribed_user, messages, stream=stream)
filtered_messages = self.assert_bulk_access(unsubscribed_user, message_ids, stream, 4, 1)
self.assert_length(filtered_messages, 0)
# Adding more message ids to the list increases the query size
# for bulk_access_messages but not
# bulk_access_stream_messages_query
more_message_ids = [
*message_ids,
self.send_stream_message(user, stream_name, "Message three"),
self.send_stream_message(user, stream_name, "Message four"),
]
filtered_messages = self.assert_bulk_access(
later_subscribed_user, more_message_ids, stream, 6, 2
)
self.assert_length(filtered_messages, 4)
# Verify an exception is thrown if the passed stream does not
# match the messages.
other_stream = get_stream("Denmark", unsubscribed_user.realm)
with self.assertRaises(AssertionError):
bulk_access_messages(
unsubscribed_user, messages, stream=get_stream("Denmark", unsubscribed_user.realm)
)
messages = [Message.objects.get(id=id) for id in message_ids]
bulk_access_messages(unsubscribed_user, messages, stream=other_stream)
# Verify that bulk_access_stream_messages_query is empty with a stream mismatch
message_query = Message.objects.select_related("recipient").filter(id__in=message_ids)
filtered_query = bulk_access_stream_messages_query(
later_subscribed_user, message_query, other_stream
)
self.assert_length(filtered_query, 0)
def test_bulk_access_messages_public_stream(self) -> None:
user = self.example_user("hamlet")
@ -1585,20 +1623,15 @@ class MessageAccessTests(ZulipTestCase):
message_two_id = self.send_stream_message(user, stream_name, "Message two")
message_ids = [message_one_id, message_two_id]
messages = [
Message.objects.select_related("recipient").get(id=message_id)
for message_id in message_ids
]
# All public stream messages are always accessible
with self.assert_database_query_count(2):
filtered_messages = bulk_access_messages(later_subscribed_user, messages, stream=stream)
filtered_messages = self.assert_bulk_access(
later_subscribed_user, message_ids, stream, 4, 1
)
self.assert_length(filtered_messages, 2)
unsubscribed_user = self.example_user("ZOE")
with self.assert_database_query_count(2):
filtered_messages = bulk_access_messages(unsubscribed_user, messages, stream=stream)
filtered_messages = self.assert_bulk_access(unsubscribed_user, message_ids, stream, 4, 1)
self.assert_length(filtered_messages, 2)

View File

@ -54,6 +54,7 @@ from zerver.lib.exceptions import (
ResourceNotFoundError,
)
from zerver.lib.mention import MentionBackend, silent_mention_syntax_for_user
from zerver.lib.message import bulk_access_stream_messages_query
from zerver.lib.request import REQ, has_request_variables
from zerver.lib.response import json_success
from zerver.lib.retention import STREAM_MESSAGE_BATCH_SIZE as RETENTION_STREAM_MESSAGE_BATCH_SIZE
@ -99,7 +100,7 @@ from zerver.lib.validator import (
check_union,
to_non_negative_int,
)
from zerver.models import Realm, Stream, UserGroup, UserMessage, UserProfile
from zerver.models import Realm, Stream, UserGroup, UserProfile
from zerver.models.users import get_system_bot
@ -925,22 +926,9 @@ def delete_in_topic(
messages = messages_for_topic(
user_profile.realm_id, assert_is_not_none(stream.recipient_id), topic_name
)
# Note: It would be better to use bulk_access_messages here, which is our core function
# for obtaining the accessible messages - and it's good to use it wherever we can,
# so that we have a central place to keep up to date with our security model for
# message access.
# However, it fetches the full Message objects, which would be bad here for very large
# topics.
# The access_stream_by_id call above ensures that the acting user currently has access to the
# stream (which entails having an active Subscription in case of private streams), meaning
# that combined with the UserMessage check below, this is a sufficient replacement for
# bulk_access_messages.
if not stream.is_history_public_to_subscribers():
# Don't allow the user to delete messages that they don't have access to.
deletable_message_ids = UserMessage.objects.filter(
user_profile=user_profile, message_id__in=messages
).values_list("message_id", flat=True)
messages = messages.filter(id__in=deletable_message_ids)
# This handles applying access control, such that only messages
# the user can see are returned in the query.
messages = bulk_access_stream_messages_query(user_profile, messages, stream)
def delete_in_batches() -> Literal[True]:
# Topics can be large enough that this request will inevitably time out.