delete_message: Fix recipients of "delete_message" event.

Earlier, we were sending 'delete_message' event to all active
subscribers of the stream.

We shouldn't send event to those users who don't have access
to the deleted message in a private stream with protected history.

This commit fixes that bug.

Also, now we use 'event_recipient_ids_for_action_on_messages'.
It helps to add hardening such that if the invariant "no usermessage
row corresponding to a message exists if the user loses access to the
message" is violated due to some bug, it has minimal user impact.
This commit is contained in:
Prakhar Pratyush 2024-09-26 18:44:27 +05:30 committed by Tim Abbott
parent 388464fcf4
commit d6c48b7185
4 changed files with 92 additions and 13 deletions

View File

@ -2,9 +2,9 @@ from collections.abc import Iterable
from typing import TypedDict from typing import TypedDict
from zerver.lib import retention from zerver.lib import retention
from zerver.lib.message import event_recipient_ids_for_action_on_messages
from zerver.lib.retention import move_messages_to_archive from zerver.lib.retention import move_messages_to_archive
from zerver.lib.stream_subscription import get_active_subscriptions_for_stream_id from zerver.models import Message, Realm, Stream, UserProfile
from zerver.models import Message, Realm, Stream, UserMessage, UserProfile
from zerver.tornado.django_api import send_event_on_commit from zerver.tornado.django_api import send_event_on_commit
@ -67,29 +67,27 @@ def do_delete_messages(
if not sample_message.is_stream_message(): if not sample_message.is_stream_message():
assert len(messages) == 1 assert len(messages) == 1
message_type = "private" message_type = "private"
ums = UserMessage.objects.filter(message_id__in=message_ids)
users_to_notify = set(ums.values_list("user_profile_id", flat=True))
archiving_chunk_size = retention.MESSAGE_BATCH_SIZE archiving_chunk_size = retention.MESSAGE_BATCH_SIZE
if message_type == "stream": if message_type == "stream":
stream_id = sample_message.recipient.type_id stream_id = sample_message.recipient.type_id
event["stream_id"] = stream_id event["stream_id"] = stream_id
event["topic"] = sample_message.topic_name() event["topic"] = sample_message.topic_name()
subscriptions = get_active_subscriptions_for_stream_id( stream = Stream.objects.get(id=stream_id)
stream_id, include_deactivated_users=False
)
# We exclude long-term idle users, since they by definition have no active clients.
subscriptions = subscriptions.exclude(user_profile__long_term_idle=True)
users_to_notify = set(subscriptions.values_list("user_profile_id", flat=True))
archiving_chunk_size = retention.STREAM_MESSAGE_BATCH_SIZE archiving_chunk_size = retention.STREAM_MESSAGE_BATCH_SIZE
# We exclude long-term idle users, since they by definition have no active clients.
users_to_notify = event_recipient_ids_for_action_on_messages(
messages,
channel=stream if message_type == "stream" else None,
)
if acting_user is not None: if acting_user is not None:
# Always send event to the user who deleted the message. # Always send event to the user who deleted the message.
users_to_notify.add(acting_user.id) users_to_notify.add(acting_user.id)
move_messages_to_archive(message_ids, realm=realm, chunk_size=archiving_chunk_size) move_messages_to_archive(message_ids, realm=realm, chunk_size=archiving_chunk_size)
if message_type == "stream": if message_type == "stream":
stream = Stream.objects.get(id=sample_message.recipient.type_id)
check_update_first_message_id(realm, stream, message_ids, users_to_notify) check_update_first_message_id(realm, stream, message_ids, users_to_notify)
event["message_type"] = message_type event["message_type"] = message_type

View File

@ -22,6 +22,7 @@ from zerver.lib.message_cache import MessageDict, extract_message_dict, stringif
from zerver.lib.partial import partial from zerver.lib.partial import partial
from zerver.lib.request import RequestVariableConversionError from zerver.lib.request import RequestVariableConversionError
from zerver.lib.stream_subscription import ( from zerver.lib.stream_subscription import (
get_active_subscriptions_for_stream_id,
get_stream_subscriptions_for_user, get_stream_subscriptions_for_user,
get_subscribed_stream_recipient_ids_for_user, get_subscribed_stream_recipient_ids_for_user,
num_subscribers_for_stream_id, num_subscribers_for_stream_id,
@ -407,6 +408,86 @@ def has_message_access(
return is_subscribed_helper() return is_subscribed_helper()
def event_recipient_ids_for_action_on_messages(
messages: list[Message],
*,
channel: Stream | None = None,
exclude_long_term_idle_users: bool = True,
) -> set[int]:
"""Returns IDs of users who should receive events when an action
(delete, react, etc) is performed on given set of messages, which
are expected to all be in a single conversation.
This function aligns with the 'has_message_access' above to ensure
that events reach only those users who have access to the messages.
Notably, for performance reasons, we do not send live-update
events to everyone who could potentially have a cached copy of a
message because they fetched messages in a public channel to which
they are not subscribed. Such events are limited to those messages
where the user has a UserMessage row (including `historical` rows).
"""
assert len(messages) > 0
message_ids = [message.id for message in messages]
def get_user_ids_having_usermessage_row_for_messages(message_ids: list[int]) -> set[int]:
"""Returns the IDs of users who actually received the messages."""
usermessages = UserMessage.objects.filter(message_id__in=message_ids)
if exclude_long_term_idle_users:
usermessages = usermessages.exclude(user_profile__long_term_idle=True)
return set(usermessages.values_list("user_profile_id", flat=True))
sample_message = messages[0]
if not sample_message.is_stream_message():
# For DM, event is sent to users who actually received the message.
return get_user_ids_having_usermessage_row_for_messages(message_ids)
channel_id = sample_message.recipient.type_id
if channel is None:
channel = Stream.objects.get(id=channel_id)
subscriptions = get_active_subscriptions_for_stream_id(
channel_id, include_deactivated_users=False
)
if exclude_long_term_idle_users:
subscriptions = subscriptions.exclude(user_profile__long_term_idle=True)
subscriber_ids = set(subscriptions.values_list("user_profile_id", flat=True))
if not channel.is_history_public_to_subscribers():
# For protected history, only users who are subscribed and
# received the original message are notified.
assert not channel.is_public()
user_ids_with_usermessage_row = get_user_ids_having_usermessage_row_for_messages(
message_ids
)
return user_ids_with_usermessage_row & subscriber_ids
if not channel.is_public():
# For private channel with shared history, the set of
# users with access is exactly the subscribers.
return subscriber_ids
# The remaining case is public channels with public history. Events are sent to:
# 1. Current channel subscribers
# 2. Unsubscribed users having usermessage row & channel access.
# * Users who never subscribed but starred or reacted on messages
# (usermessages with historical flag exists for such cases).
# * Users who were initially subscribed and later unsubscribed
# (usermessages exist for messages they received while subscribed).
usermessage_rows = UserMessage.objects.filter(message_id__in=message_ids).exclude(
# Excluding guests here implements can_access_public_channels,
# since we already know realm.is_zephyr_mirror_realm is false,
# based on the value of is_history_public_to_subscribers.
user_profile__role=UserProfile.ROLE_GUEST
)
if exclude_long_term_idle_users:
usermessage_rows = usermessage_rows.exclude(user_profile__long_term_idle=True)
user_ids_with_usermessage_row_and_channel_access = set(
usermessage_rows.values_list("user_profile_id", flat=True)
)
return user_ids_with_usermessage_row_and_channel_access | subscriber_ids
def bulk_access_messages( def bulk_access_messages(
user_profile: UserProfile, user_profile: UserProfile,
messages: Collection[Message] | QuerySet[Message], messages: Collection[Message] | QuerySet[Message],

View File

@ -554,7 +554,7 @@ class DeleteMessageTest(ZulipTestCase):
self.assertEqual(stream.first_message_id, message_ids[1]) self.assertEqual(stream.first_message_id, message_ids[1])
all_messages = Message.objects.filter(id__in=message_ids) all_messages = Message.objects.filter(id__in=message_ids)
with self.assert_database_query_count(24): with self.assert_database_query_count(25):
do_delete_messages(realm, all_messages, acting_user=None) do_delete_messages(realm, all_messages, acting_user=None)
stream = get_stream(stream_name, realm) stream = get_stream(stream_name, realm)
self.assertEqual(stream.first_message_id, None) self.assertEqual(stream.first_message_id, None)

View File

@ -1142,7 +1142,7 @@ class TestDoDeleteMessages(ZulipTestCase):
message_ids = [self.send_stream_message(cordelia, "Verona", str(i)) for i in range(10)] message_ids = [self.send_stream_message(cordelia, "Verona", str(i)) for i in range(10)]
messages = Message.objects.filter(id__in=message_ids) messages = Message.objects.filter(id__in=message_ids)
with self.assert_database_query_count(22): with self.assert_database_query_count(23):
do_delete_messages(realm, messages, acting_user=None) do_delete_messages(realm, messages, acting_user=None)
self.assertFalse(Message.objects.filter(id__in=message_ids).exists()) self.assertFalse(Message.objects.filter(id__in=message_ids).exists())