zulip/zerver/actions/message_flags.py

399 lines
14 KiB
Python
Raw Normal View History

import time
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from django.conf import settings
from django.db import transaction
from django.db.models import F
from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _
from analytics.lib.counts import COUNT_STATS, do_increment_logging_stat
from zerver.lib.exceptions import JsonableError
from zerver.lib.message import (
bulk_access_messages,
format_unread_message_details,
get_raw_unread_data,
)
from zerver.lib.queue import queue_event_on_commit
from zerver.lib.stream_subscription import get_subscribed_stream_recipient_ids_for_user
from zerver.lib.topic import filter_by_topic_name_via_message
from zerver.lib.user_message import DEFAULT_HISTORICAL_FLAGS, create_historical_user_messages
from zerver.models import Message, Recipient, UserMessage, UserProfile
from zerver.tornado.django_api import send_event_on_commit, send_event_rollback_unsafe
@dataclass
class ReadMessagesEvent:
messages: list[int]
all: bool
type: str = field(default="update_message_flags", init=False)
op: str = field(default="add", init=False)
operation: str = field(default="add", init=False)
flag: str = field(default="read", init=False)
def do_mark_all_as_read(user_profile: UserProfile, *, timeout: float | None = None) -> int | None:
start_time = time.monotonic()
# First, we clear mobile push notifications. This is safer in the
# event that the below logic times out and we're killed.
all_push_message_ids = (
UserMessage.objects.filter(
user_profile=user_profile,
)
.extra( # noqa: S610
where=[UserMessage.where_active_push_notification()],
)
.values_list("message_id", flat=True)[0:10000]
)
do_clear_mobile_push_notifications_for_ids([user_profile.id], all_push_message_ids)
batch_size = 2000
count = 0
while True:
if timeout is not None and time.monotonic() >= start_time + timeout:
return None
with transaction.atomic(durable=True):
query = (
UserMessage.select_for_update_query()
.filter(user_profile=user_profile)
.extra(where=[UserMessage.where_unread()])[:batch_size] # noqa: S610
)
# This updated_count is the same as the number of UserMessage
# rows selected, because due to the FOR UPDATE lock, we're guaranteed
# that all the selected rows will indeed be updated.
# UPDATE queries don't support LIMIT, so we have to use a subquery
# to do batching.
updated_count = UserMessage.objects.filter(id__in=query).update(
flags=F("flags").bitor(UserMessage.flags.read),
)
event_time = timezone_now()
do_increment_logging_stat(
user_profile,
COUNT_STATS["messages_read::hour"],
None,
event_time,
increment=updated_count,
)
do_increment_logging_stat(
user_profile,
COUNT_STATS["messages_read_interactions::hour"],
None,
event_time,
increment=min(1, updated_count),
)
count += updated_count
if updated_count < batch_size:
break
event = asdict(
ReadMessagesEvent(
messages=[], # we don't send messages, since the client reloads anyway
all=True,
)
)
send_event_rollback_unsafe(user_profile.realm, event, [user_profile.id])
return count
@transaction.atomic(durable=True)
def do_mark_stream_messages_as_read(
user_profile: UserProfile, stream_recipient_id: int, topic_name: str | None = None
) -> int:
query = (
UserMessage.select_for_update_query()
.filter(
user_profile=user_profile,
message__recipient_id=stream_recipient_id,
)
.extra( # noqa: S610
where=[UserMessage.where_unread()],
)
)
if topic_name:
query = filter_by_topic_name_via_message(
query=query,
topic_name=topic_name,
)
message_ids = list(query.values_list("message_id", flat=True))
if len(message_ids) == 0:
return 0
count = query.update(
flags=F("flags").bitor(UserMessage.flags.read),
)
event = asdict(
ReadMessagesEvent(
messages=message_ids,
all=False,
)
)
event_time = timezone_now()
send_event_on_commit(user_profile.realm, event, [user_profile.id])
do_clear_mobile_push_notifications_for_ids([user_profile.id], message_ids)
do_increment_logging_stat(
user_profile, COUNT_STATS["messages_read::hour"], None, event_time, increment=count
)
do_increment_logging_stat(
user_profile,
COUNT_STATS["messages_read_interactions::hour"],
None,
event_time,
increment=min(1, count),
)
return count
@transaction.atomic(savepoint=False)
def do_mark_muted_user_messages_as_read(
user_profile: UserProfile,
muted_user: UserProfile,
) -> int:
query = (
UserMessage.select_for_update_query()
.filter(user_profile=user_profile, message__sender=muted_user)
.extra(where=[UserMessage.where_unread()]) # noqa: S610
)
message_ids = list(query.values_list("message_id", flat=True))
if len(message_ids) == 0:
return 0
count = query.update(
flags=F("flags").bitor(UserMessage.flags.read),
)
event = asdict(
ReadMessagesEvent(
messages=message_ids,
all=False,
)
)
event_time = timezone_now()
send_event_on_commit(user_profile.realm, event, [user_profile.id])
do_clear_mobile_push_notifications_for_ids([user_profile.id], message_ids)
do_increment_logging_stat(
user_profile, COUNT_STATS["messages_read::hour"], None, event_time, increment=count
)
do_increment_logging_stat(
user_profile,
COUNT_STATS["messages_read_interactions::hour"],
None,
event_time,
increment=min(1, count),
)
return count
def do_update_mobile_push_notification(
message: Message,
prior_mention_user_ids: set[int],
mentions_user_ids: set[int],
stream_push_user_ids: set[int],
) -> None:
# Called during the message edit code path to remove mobile push
# notifications for users who are no longer mentioned following
# the edit. See #15428 for details.
#
# A perfect implementation would also support updating the message
# in a sent notification if a message was edited to mention a
# group rather than a user (or vice versa), though it is likely
# not worth the effort to do such a change.
if not message.is_stream_message():
return
remove_notify_users = prior_mention_user_ids - mentions_user_ids - stream_push_user_ids
do_clear_mobile_push_notifications_for_ids(list(remove_notify_users), [message.id])
def do_clear_mobile_push_notifications_for_ids(
user_profile_ids: list[int], message_ids: list[int]
) -> None:
if len(message_ids) == 0:
return
# This function supports clearing notifications for several users
# only for the message-edit use case where we'll have a single message_id.
assert len(user_profile_ids) == 1 or len(message_ids) == 1
messages_by_user = defaultdict(list)
notifications_to_update = (
UserMessage.objects.filter(
message_id__in=message_ids,
user_profile_id__in=user_profile_ids,
)
.extra( # noqa: S610
where=[UserMessage.where_active_push_notification()],
)
.values_list("user_profile_id", "message_id")
)
for user_id, message_id in notifications_to_update:
messages_by_user[user_id].append(message_id)
for user_profile_id, event_message_ids in messages_by_user.items():
notice = {
"type": "remove",
"user_profile_id": user_profile_id,
"message_ids": event_message_ids,
}
if settings.MOBILE_NOTIFICATIONS_SHARDS > 1: # nocoverage
shard_id = user_profile_id % settings.MOBILE_NOTIFICATIONS_SHARDS + 1
queue_event_on_commit(f"missedmessage_mobile_notifications_shard{shard_id}", notice)
else:
queue_event_on_commit("missedmessage_mobile_notifications", notice)
def do_update_message_flags(
user_profile: UserProfile, operation: str, flag: str, messages: list[int]
) -> int:
valid_flags = [item for item in UserMessage.flags if item not in UserMessage.NON_API_FLAGS]
if flag not in valid_flags:
raise JsonableError(_("Invalid flag: '{flag}'").format(flag=flag))
if flag in UserMessage.NON_EDITABLE_FLAGS:
raise JsonableError(_("Flag not editable: '{flag}'").format(flag=flag))
if operation not in ("add", "remove"):
raise JsonableError(
_("Invalid message flag operation: '{operation}'").format(operation=operation)
)
is_adding = operation == "add"
flagattr = getattr(UserMessage.flags, flag)
flag_target = flagattr if is_adding else 0
with transaction.atomic(durable=True):
if flag == "read" and not is_adding:
# We have an invariant that all stream messages marked as
# unread must be in streams the user is subscribed to.
#
# When marking as unread, we enforce this invariant by
# ignoring any messages in streams the user is not
# currently subscribed to.
subscribed_recipient_ids = get_subscribed_stream_recipient_ids_for_user(user_profile)
message_ids_in_unsubscribed_streams = set(
# Uses index: zerver_message_pkey
Message.objects.select_related("recipient")
.filter(id__in=messages, recipient__type=Recipient.STREAM)
.exclude(recipient_id__in=subscribed_recipient_ids)
.values_list("id", flat=True)
)
messages = [
message_id
for message_id in messages
if message_id not in message_ids_in_unsubscribed_streams
]
ums = {
um.message_id: um
for um in UserMessage.select_for_update_query().filter(
user_profile=user_profile, message_id__in=messages
)
}
# Filter out rows that already have the desired flag. We do
# this here, rather than in the original database query,
# because not all flags have database indexes and we want to
# bound the cost of this operation.
messages = [
message_id
for message_id in messages
if (int(ums[message_id].flags) if message_id in ums else DEFAULT_HISTORICAL_FLAGS)
& flagattr
!= flag_target
]
count = len(messages)
if DEFAULT_HISTORICAL_FLAGS & flagattr != flag_target:
# When marking messages as read, creating "historical"
# UserMessage rows would be a waste of storage, because
# `flags.read | flags.historical` is exactly the flags we
# simulate when processing a message for which a user has
# access but no UserMessage row.
#
# Users can mutate flags for messages that don't have a
# UserMessage yet. Validate that the user is even allowed
# to access these message_ids; if so, we will create
# "historical" UserMessage rows for the messages in question.
#
# See create_historical_user_messages for a more detailed
# explanation.
historical_message_ids = set(messages) - set(ums.keys())
historical_messages = bulk_access_messages(
user_profile,
list(
# Uses index: zerver_message_pkey
Message.objects.filter(id__in=historical_message_ids).prefetch_related(
"recipient"
)
),
)
if len(historical_messages) != len(historical_message_ids):
raise JsonableError(_("Invalid message(s)"))
create_historical_user_messages(
user_id=user_profile.id,
message_ids=list(historical_message_ids),
flagattr=flagattr,
flag_target=flag_target,
)
to_update = UserMessage.objects.filter(
user_profile=user_profile, message_id__in=set(messages) & set(ums.keys())
)
if is_adding:
to_update.update(flags=F("flags").bitor(flagattr))
else:
to_update.update(flags=F("flags").bitand(~flagattr))
event = {
"type": "update_message_flags",
"op": operation,
"operation": operation,
"flag": flag,
"messages": messages,
"all": False,
}
if flag == "read" and not is_adding:
# When removing the read flag (i.e. marking messages as
# unread), extend the event with an additional object with
# details on the messages required to update the client's
# `unread_msgs` data structure.
raw_unread_data = get_raw_unread_data(user_profile, messages)
event["message_details"] = format_unread_message_details(
user_profile.id, raw_unread_data
)
send_event_on_commit(user_profile.realm, event, [user_profile.id])
if flag == "read" and is_adding:
event_time = timezone_now()
do_clear_mobile_push_notifications_for_ids([user_profile.id], messages)
do_increment_logging_stat(
user_profile, COUNT_STATS["messages_read::hour"], None, event_time, increment=count
)
do_increment_logging_stat(
user_profile,
COUNT_STATS["messages_read_interactions::hour"],
None,
event_time,
increment=min(1, count),
)
return count