diff --git a/zerver/lib/soft_deactivation.py b/zerver/lib/soft_deactivation.py index 44409aa2fd..8d7d40256c 100644 --- a/zerver/lib/soft_deactivation.py +++ b/zerver/lib/soft_deactivation.py @@ -5,12 +5,14 @@ from typing import Any, DefaultDict, Dict, Iterable, List, Optional, Sequence, S from django.conf import settings from django.db import transaction -from django.db.models import Exists, Max, OuterRef, QuerySet +from django.db.models import Exists, F, Max, OuterRef, QuerySet +from django.db.models.functions import Greatest from django.utils.timezone import now as timezone_now from sentry_sdk import capture_exception from zerver.lib.logging_util import log_to_file from zerver.lib.queue import queue_json_publish +from zerver.lib.user_message import bulk_insert_all_ums from zerver.lib.utils import assert_is_not_none from zerver.models import ( Message, @@ -38,15 +40,8 @@ def filter_by_subscription_history( user_profile: UserProfile, all_stream_messages: DefaultDict[int, List[MissingMessageDict]], all_stream_subscription_logs: DefaultDict[int, List[RealmAuditLog]], -) -> List[UserMessage]: - user_messages_to_insert: List[UserMessage] = [] - seen_message_ids: Set[int] = set() - - def store_user_message_to_insert(message: MissingMessageDict) -> None: - if message["id"] not in seen_message_ids: - user_message = UserMessage(user_profile=user_profile, message_id=message["id"], flags=0) - user_messages_to_insert.append(user_message) - seen_message_ids.add(message["id"]) +) -> List[int]: + message_ids: Set[int] = set() for stream_id, stream_messages_raw in all_stream_messages.items(): stream_subscription_logs = all_stream_subscription_logs[stream_id] @@ -82,7 +77,7 @@ def filter_by_subscription_history( # subscribed immediately before the event. for stream_message in stream_messages: if stream_message["id"] <= event_last_message_id: - store_user_message_to_insert(stream_message) + message_ids.add(stream_message["id"]) else: break elif log_entry.event_type in ( @@ -110,9 +105,8 @@ def filter_by_subscription_history( RealmAuditLog.SUBSCRIPTION_ACTIVATED, RealmAuditLog.SUBSCRIPTION_CREATED, ): - for stream_message in stream_messages: - store_user_message_to_insert(stream_message) - return user_messages_to_insert + message_ids.update(stream_message["id"] for stream_message in stream_messages) + return sorted(message_ids) def add_missing_messages(user_profile: UserProfile) -> None: @@ -238,19 +232,20 @@ def add_missing_messages(user_profile: UserProfile) -> None: # subscription logs and then store all UserMessage objects for bulk insert # This function does not perform any SQL related task and gets all the data # required for its operation in its params. - user_messages_to_insert = filter_by_subscription_history( + message_ids_to_insert = filter_by_subscription_history( user_profile, stream_messages, all_stream_subscription_logs ) # Doing a bulk create for all the UserMessage objects stored for creation. - while len(user_messages_to_insert) > 0: - messages, user_messages_to_insert = ( - user_messages_to_insert[0:BULK_CREATE_BATCH_SIZE], - user_messages_to_insert[BULK_CREATE_BATCH_SIZE:], + while len(message_ids_to_insert) > 0: + message_ids, message_ids_to_insert = ( + message_ids_to_insert[0:BULK_CREATE_BATCH_SIZE], + message_ids_to_insert[BULK_CREATE_BATCH_SIZE:], + ) + bulk_insert_all_ums(user_ids=[user_profile.id], message_ids=message_ids, flags=0) + UserProfile.objects.filter(id=user_profile.id).update( + last_active_message_id=Greatest(F("last_active_message_id"), message_ids[-1]) ) - UserMessage.objects.bulk_create(messages) - user_profile.last_active_message_id = messages[-1].message_id - user_profile.save(update_fields=["last_active_message_id"]) def do_soft_deactivate_user(user_profile: UserProfile) -> None: