mirror of https://github.com/zulip/zulip.git
message: Add a bulk_access_stream_messages_query method.
This applies access restrictions in SQL, so that individual messages do not need to be walked one-by-one. It only functions for stream messages. Use of this method significantly speeds up checks if we moved "all visible messages" in a topic, since we no longer need to walk every remaining message in the old topic to determine that at least one was visible to the user. Similarly, it significantly speeds up merging into existing topics, since it no longer must walk every message in the new topic to determine if the user could see at least one. Finally, it unlocks the ability to bulk-update only messages the user has access to, in a single query (see subsequent commit).
This commit is contained in:
parent
628be8d433
commit
822131fef4
|
@ -33,7 +33,7 @@ from zerver.lib.markdown import version as markdown_version
|
|||
from zerver.lib.mention import MentionBackend, MentionData, silent_mention_syntax_for_user
|
||||
from zerver.lib.message import (
|
||||
access_message,
|
||||
bulk_access_messages,
|
||||
bulk_access_stream_messages_query,
|
||||
check_user_group_mention_allowed,
|
||||
normalize_body,
|
||||
stream_wildcard_mention_allowed,
|
||||
|
@ -827,27 +827,23 @@ def do_update_message(
|
|||
# full-topic move.
|
||||
#
|
||||
# For security model reasons, we don't want to allow a
|
||||
# user to take any action that would leak information
|
||||
# about older messages they cannot access (E.g. the only
|
||||
# remaining messages are in a stream without shared
|
||||
# history). The bulk_access_messages call below addresses
|
||||
# user to take any action (e.g. post a message about
|
||||
# having not moved the whole topic) that would leak
|
||||
# information about older messages they cannot access
|
||||
# (e.g. there were earlier inaccessible messages in the
|
||||
# topic, in a stream without shared history). The
|
||||
# bulk_access_stream_messages_query call below addresses
|
||||
# that concern.
|
||||
#
|
||||
# bulk_access_messages is inefficient for this task, since
|
||||
# we just want to do the exists() version of this
|
||||
# query. But it's nice to reuse code, and this bulk
|
||||
# operation is likely cheaper than a `GET /messages`
|
||||
# unless the topic has thousands of messages of history.
|
||||
assert stream_being_edited.recipient_id is not None
|
||||
unmoved_messages = messages_for_topic(
|
||||
realm.id,
|
||||
stream_being_edited.recipient_id,
|
||||
orig_topic_name,
|
||||
)
|
||||
visible_unmoved_messages = bulk_access_messages(
|
||||
user_profile, unmoved_messages, stream=stream_being_edited
|
||||
visible_unmoved_messages = bulk_access_stream_messages_query(
|
||||
user_profile, unmoved_messages, stream_being_edited
|
||||
)
|
||||
moved_all_visible_messages = len(visible_unmoved_messages) == 0
|
||||
moved_all_visible_messages = not visible_unmoved_messages.exists()
|
||||
|
||||
# Migrate 'topic with visibility_policy' configuration in the following
|
||||
# circumstances:
|
||||
|
@ -1064,24 +1060,15 @@ def do_update_message(
|
|||
# avoid leaking information about whether there are
|
||||
# messages in the destination topic's deeper history that
|
||||
# the acting user does not have permission to access.
|
||||
#
|
||||
# TODO: These queries are quite inefficient, in that we're
|
||||
# fetching full copies of all the messages in the
|
||||
# destination topic to answer the question of whether the
|
||||
# current user has access to at least one such message.
|
||||
#
|
||||
# The main strength of the current implementation is that
|
||||
# it reuses existing logic, which is good for keeping it
|
||||
# correct as we maintain the codebase.
|
||||
preexisting_topic_messages = messages_for_topic(
|
||||
realm.id, stream_for_new_topic.recipient_id, new_topic_name
|
||||
).exclude(id__in=[*changed_message_ids, resolved_topic_message_id])
|
||||
|
||||
visible_preexisting_messages = bulk_access_messages(
|
||||
user_profile, preexisting_topic_messages, stream=stream_for_new_topic
|
||||
visible_preexisting_messages = bulk_access_stream_messages_query(
|
||||
user_profile, preexisting_topic_messages, stream_for_new_topic
|
||||
)
|
||||
|
||||
no_visible_preexisting_messages = len(visible_preexisting_messages) == 0
|
||||
no_visible_preexisting_messages = not visible_preexisting_messages.exists()
|
||||
|
||||
if no_visible_preexisting_messages and moved_all_visible_messages:
|
||||
new_thread_notification_string = gettext_lazy(
|
||||
|
|
|
@ -21,7 +21,7 @@ import ahocorasick
|
|||
import orjson
|
||||
from django.conf import settings
|
||||
from django.db import connection
|
||||
from django.db.models import Max, QuerySet, Sum
|
||||
from django.db.models import Exists, Max, OuterRef, QuerySet, Sum
|
||||
from django.utils.timezone import now as timezone_now
|
||||
from django.utils.translation import gettext as _
|
||||
from django_stubs_ext import ValuesQuerySet
|
||||
|
@ -996,6 +996,38 @@ def bulk_access_messages(
|
|||
return filtered_messages
|
||||
|
||||
|
||||
def bulk_access_stream_messages_query(
|
||||
user_profile: UserProfile, messages: QuerySet[Message], stream: Stream
|
||||
) -> QuerySet[Message]:
|
||||
"""This function mirrors bulk_access_messages, above, but applies the
|
||||
limits to a QuerySet and returns a new QuerySet which only
|
||||
contains messages in the given stream which the user can access.
|
||||
Note that this only works with streams. It may return an empty
|
||||
QuerySet if the user has access to no messages (for instance, for
|
||||
a private stream which the user is not subscribed to).
|
||||
|
||||
"""
|
||||
|
||||
messages = messages.filter(realm_id=user_profile.realm_id, recipient_id=stream.recipient_id)
|
||||
|
||||
if stream.is_public() and user_profile.can_access_public_streams():
|
||||
return messages
|
||||
|
||||
if not Subscription.objects.filter(
|
||||
user_profile=user_profile, active=True, recipient=stream.recipient
|
||||
).exists():
|
||||
return Message.objects.none()
|
||||
if not stream.is_history_public_to_subscribers():
|
||||
messages = messages.annotate(
|
||||
has_usermessage=Exists(
|
||||
UserMessage.objects.filter(
|
||||
user_profile_id=user_profile.id, message_id=OuterRef("id")
|
||||
)
|
||||
)
|
||||
).filter(has_usermessage=1)
|
||||
return messages
|
||||
|
||||
|
||||
def get_messages_with_usermessage_rows_for_user(
|
||||
user_profile_id: int, message_ids: Sequence[int]
|
||||
) -> ValuesQuerySet[UserMessage, int]:
|
||||
|
|
|
@ -173,9 +173,9 @@ def update_messages_for_topic_edit(
|
|||
# If we're moving the messages between streams, only move
|
||||
# messages that the acting user can access, so that one cannot
|
||||
# gain access to messages through moving them.
|
||||
from zerver.lib.message import bulk_access_messages
|
||||
from zerver.lib.message import bulk_access_stream_messages_query
|
||||
|
||||
messages_list = bulk_access_messages(acting_user, messages, stream=old_stream)
|
||||
messages_list = list(bulk_access_stream_messages_query(acting_user, messages, old_stream))
|
||||
else:
|
||||
# For single-message edits or topic moves within a stream, we
|
||||
# allow moving history the user may not have access in order
|
||||
|
|
|
@ -1528,7 +1528,7 @@ class EditMessageTest(EditMessageTestCase):
|
|||
set_topic_visibility_policy(desdemona, muted_topics, UserTopic.VisibilityPolicy.MUTED)
|
||||
set_topic_visibility_policy(cordelia, muted_topics, UserTopic.VisibilityPolicy.MUTED)
|
||||
|
||||
with self.assert_database_query_count(29):
|
||||
with self.assert_database_query_count(28):
|
||||
check_update_message(
|
||||
user_profile=desdemona,
|
||||
message_id=message_id,
|
||||
|
@ -1559,7 +1559,7 @@ class EditMessageTest(EditMessageTestCase):
|
|||
set_topic_visibility_policy(desdemona, muted_topics, UserTopic.VisibilityPolicy.MUTED)
|
||||
set_topic_visibility_policy(cordelia, muted_topics, UserTopic.VisibilityPolicy.MUTED)
|
||||
|
||||
with self.assert_database_query_count(34):
|
||||
with self.assert_database_query_count(33):
|
||||
check_update_message(
|
||||
user_profile=desdemona,
|
||||
message_id=message_id,
|
||||
|
@ -1592,7 +1592,7 @@ class EditMessageTest(EditMessageTestCase):
|
|||
set_topic_visibility_policy(desdemona, muted_topics, UserTopic.VisibilityPolicy.MUTED)
|
||||
set_topic_visibility_policy(cordelia, muted_topics, UserTopic.VisibilityPolicy.MUTED)
|
||||
|
||||
with self.assert_database_query_count(29):
|
||||
with self.assert_database_query_count(28):
|
||||
check_update_message(
|
||||
user_profile=desdemona,
|
||||
message_id=message_id,
|
||||
|
@ -1615,7 +1615,7 @@ class EditMessageTest(EditMessageTestCase):
|
|||
second_message_id = self.send_stream_message(
|
||||
hamlet, stream_name, topic_name="changed topic name", content="Second message"
|
||||
)
|
||||
with self.assert_database_query_count(25):
|
||||
with self.assert_database_query_count(23):
|
||||
check_update_message(
|
||||
user_profile=desdemona,
|
||||
message_id=second_message_id,
|
||||
|
@ -3783,7 +3783,7 @@ class EditMessageTest(EditMessageTestCase):
|
|||
"iago", "test move stream", "new stream", "test"
|
||||
)
|
||||
|
||||
with self.assert_database_query_count(55), self.assert_memcached_count(14):
|
||||
with self.assert_database_query_count(52), self.assert_memcached_count(14):
|
||||
result = self.client_patch(
|
||||
f"/json/messages/{msg_id}",
|
||||
{
|
||||
|
|
|
@ -19,6 +19,7 @@ from zerver.lib.message import (
|
|||
aggregate_unread_data,
|
||||
apply_unread_message_event,
|
||||
bulk_access_messages,
|
||||
bulk_access_stream_messages_query,
|
||||
format_unread_message_details,
|
||||
get_raw_unread_data,
|
||||
)
|
||||
|
@ -1505,6 +1506,30 @@ class MessageAccessTests(ZulipTestCase):
|
|||
result = self.change_star(message_id)
|
||||
self.assert_json_success(result)
|
||||
|
||||
def assert_bulk_access(
|
||||
self,
|
||||
user: UserProfile,
|
||||
message_ids: List[int],
|
||||
stream: Stream,
|
||||
bulk_access_messages_count: int,
|
||||
bulk_access_stream_messages_query_count: int,
|
||||
) -> List[Message]:
|
||||
with self.assert_database_query_count(bulk_access_messages_count):
|
||||
messages = [
|
||||
Message.objects.select_related("recipient").get(id=message_id)
|
||||
for message_id in sorted(message_ids)
|
||||
]
|
||||
list_result = bulk_access_messages(user, messages, stream=stream)
|
||||
with self.assert_database_query_count(bulk_access_stream_messages_query_count):
|
||||
message_query = (
|
||||
Message.objects.select_related("recipient")
|
||||
.filter(id__in=message_ids)
|
||||
.order_by("id")
|
||||
)
|
||||
query_result = list(bulk_access_stream_messages_query(user, message_query, stream))
|
||||
self.assertEqual(query_result, list_result)
|
||||
return list_result
|
||||
|
||||
def test_bulk_access_messages_private_stream(self) -> None:
|
||||
user = self.example_user("hamlet")
|
||||
self.login_user(user)
|
||||
|
@ -1526,16 +1551,12 @@ class MessageAccessTests(ZulipTestCase):
|
|||
message_two_id = self.send_stream_message(user, stream_name, "Message two")
|
||||
|
||||
message_ids = [message_one_id, message_two_id]
|
||||
messages = [
|
||||
Message.objects.select_related("recipient").get(id=message_id)
|
||||
for message_id in message_ids
|
||||
]
|
||||
|
||||
with self.assert_database_query_count(2):
|
||||
filtered_messages = bulk_access_messages(later_subscribed_user, messages, stream=stream)
|
||||
|
||||
# Message sent before subscribing wouldn't be accessible by later
|
||||
# subscribed user as stream has protected history
|
||||
filtered_messages = self.assert_bulk_access(
|
||||
later_subscribed_user, message_ids, stream, 4, 2
|
||||
)
|
||||
self.assert_length(filtered_messages, 1)
|
||||
self.assertEqual(filtered_messages[0].id, message_two_id)
|
||||
|
||||
|
@ -1547,27 +1568,44 @@ class MessageAccessTests(ZulipTestCase):
|
|||
acting_user=self.example_user("cordelia"),
|
||||
)
|
||||
|
||||
with self.assert_database_query_count(2):
|
||||
filtered_messages = bulk_access_messages(later_subscribed_user, messages, stream=stream)
|
||||
|
||||
# Message sent before subscribing are accessible by 8user as stream
|
||||
# don't have protected history
|
||||
# Message sent before subscribing are accessible by user as stream
|
||||
# now don't have protected history
|
||||
filtered_messages = self.assert_bulk_access(
|
||||
later_subscribed_user, message_ids, stream, 4, 2
|
||||
)
|
||||
self.assert_length(filtered_messages, 2)
|
||||
|
||||
# Testing messages accessibility for an unsubscribed user
|
||||
unsubscribed_user = self.example_user("ZOE")
|
||||
|
||||
with self.assert_database_query_count(2):
|
||||
filtered_messages = bulk_access_messages(unsubscribed_user, messages, stream=stream)
|
||||
|
||||
filtered_messages = self.assert_bulk_access(unsubscribed_user, message_ids, stream, 4, 1)
|
||||
self.assert_length(filtered_messages, 0)
|
||||
|
||||
# Adding more message ids to the list increases the query size
|
||||
# for bulk_access_messages but not
|
||||
# bulk_access_stream_messages_query
|
||||
more_message_ids = [
|
||||
*message_ids,
|
||||
self.send_stream_message(user, stream_name, "Message three"),
|
||||
self.send_stream_message(user, stream_name, "Message four"),
|
||||
]
|
||||
filtered_messages = self.assert_bulk_access(
|
||||
later_subscribed_user, more_message_ids, stream, 6, 2
|
||||
)
|
||||
self.assert_length(filtered_messages, 4)
|
||||
|
||||
# Verify an exception is thrown if called where the passed
|
||||
# stream not matching the messages.
|
||||
other_stream = get_stream("Denmark", unsubscribed_user.realm)
|
||||
with self.assertRaises(AssertionError):
|
||||
bulk_access_messages(
|
||||
unsubscribed_user, messages, stream=get_stream("Denmark", unsubscribed_user.realm)
|
||||
)
|
||||
messages = [Message.objects.get(id=id) for id in message_ids]
|
||||
bulk_access_messages(unsubscribed_user, messages, stream=other_stream)
|
||||
|
||||
# Verify that bulk_access_stream_messages_query is empty with a stream mismatch
|
||||
message_query = Message.objects.select_related("recipient").filter(id__in=message_ids)
|
||||
filtered_query = bulk_access_stream_messages_query(
|
||||
later_subscribed_user, message_query, other_stream
|
||||
)
|
||||
self.assert_length(filtered_query, 0)
|
||||
|
||||
def test_bulk_access_messages_public_stream(self) -> None:
|
||||
user = self.example_user("hamlet")
|
||||
|
@ -1585,20 +1623,15 @@ class MessageAccessTests(ZulipTestCase):
|
|||
message_two_id = self.send_stream_message(user, stream_name, "Message two")
|
||||
|
||||
message_ids = [message_one_id, message_two_id]
|
||||
messages = [
|
||||
Message.objects.select_related("recipient").get(id=message_id)
|
||||
for message_id in message_ids
|
||||
]
|
||||
|
||||
# All public stream messages are always accessible
|
||||
with self.assert_database_query_count(2):
|
||||
filtered_messages = bulk_access_messages(later_subscribed_user, messages, stream=stream)
|
||||
filtered_messages = self.assert_bulk_access(
|
||||
later_subscribed_user, message_ids, stream, 4, 1
|
||||
)
|
||||
self.assert_length(filtered_messages, 2)
|
||||
|
||||
unsubscribed_user = self.example_user("ZOE")
|
||||
with self.assert_database_query_count(2):
|
||||
filtered_messages = bulk_access_messages(unsubscribed_user, messages, stream=stream)
|
||||
|
||||
filtered_messages = self.assert_bulk_access(unsubscribed_user, message_ids, stream, 4, 1)
|
||||
self.assert_length(filtered_messages, 2)
|
||||
|
||||
|
||||
|
|
|
@ -54,6 +54,7 @@ from zerver.lib.exceptions import (
|
|||
ResourceNotFoundError,
|
||||
)
|
||||
from zerver.lib.mention import MentionBackend, silent_mention_syntax_for_user
|
||||
from zerver.lib.message import bulk_access_stream_messages_query
|
||||
from zerver.lib.request import REQ, has_request_variables
|
||||
from zerver.lib.response import json_success
|
||||
from zerver.lib.retention import STREAM_MESSAGE_BATCH_SIZE as RETENTION_STREAM_MESSAGE_BATCH_SIZE
|
||||
|
@ -99,7 +100,7 @@ from zerver.lib.validator import (
|
|||
check_union,
|
||||
to_non_negative_int,
|
||||
)
|
||||
from zerver.models import Realm, Stream, UserGroup, UserMessage, UserProfile
|
||||
from zerver.models import Realm, Stream, UserGroup, UserProfile
|
||||
from zerver.models.users import get_system_bot
|
||||
|
||||
|
||||
|
@ -925,22 +926,9 @@ def delete_in_topic(
|
|||
messages = messages_for_topic(
|
||||
user_profile.realm_id, assert_is_not_none(stream.recipient_id), topic_name
|
||||
)
|
||||
# Note: It would be better to use bulk_access_messages here, which is our core function
|
||||
# for obtaining the accessible messages - and it's good to use it wherever we can,
|
||||
# so that we have a central place to keep up to date with our security model for
|
||||
# message access.
|
||||
# However, it fetches the full Message objects, which would be bad here for very large
|
||||
# topics.
|
||||
# The access_stream_by_id call above ensures that the acting user currently has access to the
|
||||
# stream (which entails having an active Subscription in case of private streams), meaning
|
||||
# that combined with the UserMessage check below, this is a sufficient replacement for
|
||||
# bulk_access_messages.
|
||||
if not stream.is_history_public_to_subscribers():
|
||||
# Don't allow the user to delete messages that they don't have access to.
|
||||
deletable_message_ids = UserMessage.objects.filter(
|
||||
user_profile=user_profile, message_id__in=messages
|
||||
).values_list("message_id", flat=True)
|
||||
messages = messages.filter(id__in=deletable_message_ids)
|
||||
# This handles applying access control, such that only messages
|
||||
# the user can see are returned in the query.
|
||||
messages = bulk_access_stream_messages_query(user_profile, messages, stream)
|
||||
|
||||
def delete_in_batches() -> Literal[True]:
|
||||
# Topics can be large enough that this request will inevitably time out.
|
||||
|
|
Loading…
Reference in New Issue