import time from collections import defaultdict from dataclasses import asdict, dataclass, field from django.conf import settings from django.db import transaction from django.db.models import F from django.utils.timezone import now as timezone_now from django.utils.translation import gettext as _ from analytics.lib.counts import COUNT_STATS, do_increment_logging_stat from zerver.lib.exceptions import JsonableError from zerver.lib.message import ( bulk_access_messages, format_unread_message_details, get_raw_unread_data, ) from zerver.lib.queue import queue_event_on_commit from zerver.lib.stream_subscription import get_subscribed_stream_recipient_ids_for_user from zerver.lib.topic import filter_by_topic_name_via_message from zerver.lib.user_message import DEFAULT_HISTORICAL_FLAGS, create_historical_user_messages from zerver.models import Message, Recipient, UserMessage, UserProfile from zerver.tornado.django_api import send_event_on_commit, send_event_rollback_unsafe @dataclass class ReadMessagesEvent: messages: list[int] all: bool type: str = field(default="update_message_flags", init=False) op: str = field(default="add", init=False) operation: str = field(default="add", init=False) flag: str = field(default="read", init=False) def do_mark_all_as_read(user_profile: UserProfile, *, timeout: float | None = None) -> int | None: start_time = time.monotonic() # First, we clear mobile push notifications. This is safer in the # event that the below logic times out and we're killed. all_push_message_ids = ( UserMessage.objects.filter( user_profile=user_profile, ) .extra( # noqa: S610 where=[UserMessage.where_active_push_notification()], ) .values_list("message_id", flat=True)[0:10000] ) do_clear_mobile_push_notifications_for_ids([user_profile.id], all_push_message_ids) batch_size = 2000 count = 0 while True: if timeout is not None and time.monotonic() >= start_time + timeout: return None with transaction.atomic(durable=True): query = ( UserMessage.select_for_update_query() .filter(user_profile=user_profile) .extra(where=[UserMessage.where_unread()])[:batch_size] # noqa: S610 ) # This updated_count is the same as the number of UserMessage # rows selected, because due to the FOR UPDATE lock, we're guaranteed # that all the selected rows will indeed be updated. # UPDATE queries don't support LIMIT, so we have to use a subquery # to do batching. updated_count = UserMessage.objects.filter(id__in=query).update( flags=F("flags").bitor(UserMessage.flags.read), ) event_time = timezone_now() do_increment_logging_stat( user_profile, COUNT_STATS["messages_read::hour"], None, event_time, increment=updated_count, ) do_increment_logging_stat( user_profile, COUNT_STATS["messages_read_interactions::hour"], None, event_time, increment=min(1, updated_count), ) count += updated_count if updated_count < batch_size: break event = asdict( ReadMessagesEvent( messages=[], # we don't send messages, since the client reloads anyway all=True, ) ) send_event_rollback_unsafe(user_profile.realm, event, [user_profile.id]) return count @transaction.atomic(durable=True) def do_mark_stream_messages_as_read( user_profile: UserProfile, stream_recipient_id: int, topic_name: str | None = None ) -> int: query = ( UserMessage.select_for_update_query() .filter( user_profile=user_profile, message__recipient_id=stream_recipient_id, ) .extra( # noqa: S610 where=[UserMessage.where_unread()], ) ) if topic_name: query = filter_by_topic_name_via_message( query=query, topic_name=topic_name, ) message_ids = list(query.values_list("message_id", flat=True)) if len(message_ids) == 0: return 0 count = query.update( flags=F("flags").bitor(UserMessage.flags.read), ) event = asdict( ReadMessagesEvent( messages=message_ids, all=False, ) ) event_time = timezone_now() send_event_on_commit(user_profile.realm, event, [user_profile.id]) do_clear_mobile_push_notifications_for_ids([user_profile.id], message_ids) do_increment_logging_stat( user_profile, COUNT_STATS["messages_read::hour"], None, event_time, increment=count ) do_increment_logging_stat( user_profile, COUNT_STATS["messages_read_interactions::hour"], None, event_time, increment=min(1, count), ) return count @transaction.atomic(savepoint=False) def do_mark_muted_user_messages_as_read( user_profile: UserProfile, muted_user: UserProfile, ) -> int: query = ( UserMessage.select_for_update_query() .filter(user_profile=user_profile, message__sender=muted_user) .extra(where=[UserMessage.where_unread()]) # noqa: S610 ) message_ids = list(query.values_list("message_id", flat=True)) if len(message_ids) == 0: return 0 count = query.update( flags=F("flags").bitor(UserMessage.flags.read), ) event = asdict( ReadMessagesEvent( messages=message_ids, all=False, ) ) event_time = timezone_now() send_event_on_commit(user_profile.realm, event, [user_profile.id]) do_clear_mobile_push_notifications_for_ids([user_profile.id], message_ids) do_increment_logging_stat( user_profile, COUNT_STATS["messages_read::hour"], None, event_time, increment=count ) do_increment_logging_stat( user_profile, COUNT_STATS["messages_read_interactions::hour"], None, event_time, increment=min(1, count), ) return count def do_update_mobile_push_notification( message: Message, prior_mention_user_ids: set[int], mentions_user_ids: set[int], stream_push_user_ids: set[int], ) -> None: # Called during the message edit code path to remove mobile push # notifications for users who are no longer mentioned following # the edit. See #15428 for details. # # A perfect implementation would also support updating the message # in a sent notification if a message was edited to mention a # group rather than a user (or vice versa), though it is likely # not worth the effort to do such a change. if not message.is_stream_message(): return remove_notify_users = prior_mention_user_ids - mentions_user_ids - stream_push_user_ids do_clear_mobile_push_notifications_for_ids(list(remove_notify_users), [message.id]) def do_clear_mobile_push_notifications_for_ids( user_profile_ids: list[int], message_ids: list[int] ) -> None: if len(message_ids) == 0: return # This function supports clearing notifications for several users # only for the message-edit use case where we'll have a single message_id. assert len(user_profile_ids) == 1 or len(message_ids) == 1 messages_by_user = defaultdict(list) notifications_to_update = ( UserMessage.objects.filter( message_id__in=message_ids, user_profile_id__in=user_profile_ids, ) .extra( # noqa: S610 where=[UserMessage.where_active_push_notification()], ) .values_list("user_profile_id", "message_id") ) for user_id, message_id in notifications_to_update: messages_by_user[user_id].append(message_id) for user_profile_id, event_message_ids in messages_by_user.items(): notice = { "type": "remove", "user_profile_id": user_profile_id, "message_ids": event_message_ids, } if settings.MOBILE_NOTIFICATIONS_SHARDS > 1: # nocoverage shard_id = user_profile_id % settings.MOBILE_NOTIFICATIONS_SHARDS + 1 queue_event_on_commit(f"missedmessage_mobile_notifications_shard{shard_id}", notice) else: queue_event_on_commit("missedmessage_mobile_notifications", notice) def do_update_message_flags( user_profile: UserProfile, operation: str, flag: str, messages: list[int] ) -> int: valid_flags = [item for item in UserMessage.flags if item not in UserMessage.NON_API_FLAGS] if flag not in valid_flags: raise JsonableError(_("Invalid flag: '{flag}'").format(flag=flag)) if flag in UserMessage.NON_EDITABLE_FLAGS: raise JsonableError(_("Flag not editable: '{flag}'").format(flag=flag)) if operation not in ("add", "remove"): raise JsonableError( _("Invalid message flag operation: '{operation}'").format(operation=operation) ) is_adding = operation == "add" flagattr = getattr(UserMessage.flags, flag) flag_target = flagattr if is_adding else 0 with transaction.atomic(durable=True): if flag == "read" and not is_adding: # We have an invariant that all stream messages marked as # unread must be in streams the user is subscribed to. # # When marking as unread, we enforce this invariant by # ignoring any messages in streams the user is not # currently subscribed to. subscribed_recipient_ids = get_subscribed_stream_recipient_ids_for_user(user_profile) message_ids_in_unsubscribed_streams = set( # Uses index: zerver_message_pkey Message.objects.select_related("recipient") .filter(id__in=messages, recipient__type=Recipient.STREAM) .exclude(recipient_id__in=subscribed_recipient_ids) .values_list("id", flat=True) ) messages = [ message_id for message_id in messages if message_id not in message_ids_in_unsubscribed_streams ] ums = { um.message_id: um for um in UserMessage.select_for_update_query().filter( user_profile=user_profile, message_id__in=messages ) } # Filter out rows that already have the desired flag. We do # this here, rather than in the original database query, # because not all flags have database indexes and we want to # bound the cost of this operation. messages = [ message_id for message_id in messages if (int(ums[message_id].flags) if message_id in ums else DEFAULT_HISTORICAL_FLAGS) & flagattr != flag_target ] count = len(messages) if DEFAULT_HISTORICAL_FLAGS & flagattr != flag_target: # When marking messages as read, creating "historical" # UserMessage rows would be a waste of storage, because # `flags.read | flags.historical` is exactly the flags we # simulate when processing a message for which a user has # access but no UserMessage row. # # Users can mutate flags for messages that don't have a # UserMessage yet. Validate that the user is even allowed # to access these message_ids; if so, we will create # "historical" UserMessage rows for the messages in question. # # See create_historical_user_messages for a more detailed # explanation. historical_message_ids = set(messages) - set(ums.keys()) historical_messages = bulk_access_messages( user_profile, list( # Uses index: zerver_message_pkey Message.objects.filter(id__in=historical_message_ids).prefetch_related( "recipient" ) ), ) if len(historical_messages) != len(historical_message_ids): raise JsonableError(_("Invalid message(s)")) create_historical_user_messages( user_id=user_profile.id, message_ids=list(historical_message_ids), flagattr=flagattr, flag_target=flag_target, ) to_update = UserMessage.objects.filter( user_profile=user_profile, message_id__in=set(messages) & set(ums.keys()) ) if is_adding: to_update.update(flags=F("flags").bitor(flagattr)) else: to_update.update(flags=F("flags").bitand(~flagattr)) event = { "type": "update_message_flags", "op": operation, "operation": operation, "flag": flag, "messages": messages, "all": False, } if flag == "read" and not is_adding: # When removing the read flag (i.e. marking messages as # unread), extend the event with an additional object with # details on the messages required to update the client's # `unread_msgs` data structure. raw_unread_data = get_raw_unread_data(user_profile, messages) event["message_details"] = format_unread_message_details( user_profile.id, raw_unread_data ) send_event_on_commit(user_profile.realm, event, [user_profile.id]) if flag == "read" and is_adding: event_time = timezone_now() do_clear_mobile_push_notifications_for_ids([user_profile.id], messages) do_increment_logging_stat( user_profile, COUNT_STATS["messages_read::hour"], None, event_time, increment=count ) do_increment_logging_stat( user_profile, COUNT_STATS["messages_read_interactions::hour"], None, event_time, increment=min(1, count), ) return count