From 28173cafc86c531fe52b5403cba3e756181b4517 Mon Sep 17 00:00:00 2001 From: Christopher Chong Date: Sun, 14 Aug 2022 10:02:05 +0000 Subject: [PATCH] message_flags: Fix deadlocks when updating message flags. Previously, an active production Zulip server would experience a class of deadlocks caused by two or more concurrent bulk update operations on the UserMessage table. This is because UPDATE ... SET ... WHERE statements that execute in parallel take row-level UPDATE locks as they get results; since the query plans may result in getting rows in different orders between two queries, this can result in deadlocks. Some databases allow ORDER BY on their UPDATE ... WHERE statements; PostgreSQL does not. In PostgreSQL, the answer is to do a sub-select with an ORDER BY ... FOR UPDATE to ensure consistent ordering on row locks. We do this all code paths using bitand or bitor as part of bulk editing message flags, which should ensure that these concurrent operations obtain row level locks on the table in the same order. Fixes #19054. --- zerver/actions/message_flags.py | 108 ++++++++++++++++------------- zerver/lib/push_notifications.py | 9 +-- zerver/models.py | 14 ++++ zerver/tests/test_message_flags.py | 17 +++-- zerver/worker/queue_processors.py | 8 ++- 5 files changed, 94 insertions(+), 62 deletions(-) diff --git a/zerver/actions/message_flags.py b/zerver/actions/message_flags.py index ced7723f82..9bfc03f1ee 100644 --- a/zerver/actions/message_flags.py +++ b/zerver/actions/message_flags.py @@ -2,6 +2,7 @@ from collections import defaultdict from dataclasses import asdict, dataclass, field from typing import List, Optional, Set +from django.db import transaction from django.db.models import F from django.utils.timezone import now as timezone_now from django.utils.translation import gettext as _ @@ -43,13 +44,15 @@ def do_mark_all_as_read(user_profile: UserProfile) -> int: ) do_clear_mobile_push_notifications_for_ids([user_profile.id], all_push_message_ids) - msgs = UserMessage.objects.filter(user_profile=user_profile).extra( - where=[UserMessage.where_unread()], - ) - - count = msgs.update( - flags=F("flags").bitor(UserMessage.flags.read), - ) + with transaction.atomic(savepoint=False): + query = ( + UserMessage.select_for_update_query() + .filter(user_profile=user_profile) + .extra(where=[UserMessage.where_unread()]) + ) + count = query.update( + flags=F("flags").bitor(UserMessage.flags.read), + ) event = asdict( ReadMessagesEvent( @@ -80,30 +83,32 @@ def do_mark_stream_messages_as_read( ) -> int: log_statsd_event("mark_stream_as_read") - msgs = UserMessage.objects.filter( - user_profile=user_profile, - ) - - msgs = msgs.filter(message__recipient_id=stream_recipient_id) - - if topic_name: - msgs = filter_by_topic_name_via_message( - query=msgs, - topic_name=topic_name, + with transaction.atomic(savepoint=False): + query = ( + UserMessage.select_for_update_query() + .filter( + user_profile=user_profile, + message__recipient_id=stream_recipient_id, + ) + .extra( + where=[UserMessage.where_unread()], + ) ) - msgs = msgs.extra( - where=[UserMessage.where_unread()], - ) + if topic_name: + query = filter_by_topic_name_via_message( + query=query, + topic_name=topic_name, + ) - message_ids = list(msgs.values_list("message_id", flat=True)) + message_ids = list(query.values_list("message_id", flat=True)) - if len(message_ids) == 0: - return 0 + if len(message_ids) == 0: + return 0 - count = msgs.update( - flags=F("flags").bitor(UserMessage.flags.read), - ) + count = query.update( + flags=F("flags").bitor(UserMessage.flags.read), + ) event = asdict( ReadMessagesEvent( @@ -133,18 +138,20 @@ def do_mark_muted_user_messages_as_read( user_profile: UserProfile, muted_user: UserProfile, ) -> int: - messages = UserMessage.objects.filter( - user_profile=user_profile, message__sender=muted_user - ).extra(where=[UserMessage.where_unread()]) + with transaction.atomic(savepoint=False): + query = ( + UserMessage.select_for_update_query() + .filter(user_profile=user_profile, message__sender=muted_user) + .extra(where=[UserMessage.where_unread()]) + ) + message_ids = list(query.values_list("message_id", flat=True)) - message_ids = list(messages.values_list("message_id", flat=True)) + if len(message_ids) == 0: + return 0 - if len(message_ids) == 0: - return 0 - - count = messages.update( - flags=F("flags").bitor(UserMessage.flags.read), - ) + count = query.update( + flags=F("flags").bitor(UserMessage.flags.read), + ) event = asdict( ReadMessagesEvent( @@ -239,22 +246,25 @@ def do_update_message_flags( raise JsonableError(_("Invalid message flag operation: '{}'").format(operation)) flagattr = getattr(UserMessage.flags, flag) - msgs = UserMessage.objects.filter(user_profile=user_profile, message_id__in=messages) - um_message_ids = {um.message_id for um in msgs} - historical_message_ids = list(set(messages) - um_message_ids) + with transaction.atomic(savepoint=False): + query = UserMessage.select_for_update_query().filter( + user_profile=user_profile, message_id__in=messages + ) + um_message_ids = {um.message_id for um in query} + historical_message_ids = list(set(messages) - um_message_ids) - # Users can mutate flags for messages that don't have a UserMessage yet. - # First, validate that the user is even allowed to access these message_ids. - for message_id in historical_message_ids: - access_message(user_profile, message_id) + # Users can mutate flags for messages that don't have a UserMessage yet. + # First, validate that the user is even allowed to access these message_ids. + for message_id in historical_message_ids: + access_message(user_profile, message_id) - # And then create historical UserMessage records. See the called function for more context. - create_historical_user_messages(user_id=user_profile.id, message_ids=historical_message_ids) + # And then create historical UserMessage records. See the called function for more context. + create_historical_user_messages(user_id=user_profile.id, message_ids=historical_message_ids) - if operation == "add": - count = msgs.update(flags=F("flags").bitor(flagattr)) - elif operation == "remove": - count = msgs.update(flags=F("flags").bitand(~flagattr)) + if operation == "add": + count = query.update(flags=F("flags").bitor(flagattr)) + elif operation == "remove": + count = query.update(flags=F("flags").bitand(~flagattr)) event = { "type": "update_message_flags", diff --git a/zerver/lib/push_notifications.py b/zerver/lib/push_notifications.py index aec31a8469..382872bcb7 100644 --- a/zerver/lib/push_notifications.py +++ b/zerver/lib/push_notifications.py @@ -1025,10 +1025,11 @@ def handle_remove_push_notification(user_profile_id: int, message_ids: List[int] # assuming in this very rare case that the user has manually # dismissed these notifications on the device side, and the server # should no longer track them as outstanding notifications. - UserMessage.objects.filter( - user_profile_id=user_profile_id, - message_id__in=message_ids, - ).update(flags=F("flags").bitand(~UserMessage.flags.active_mobile_push_notification)) + with transaction.atomic(savepoint=False): + UserMessage.select_for_update_query().filter( + user_profile_id=user_profile_id, + message_id__in=message_ids, + ).update(flags=F("flags").bitand(~UserMessage.flags.active_mobile_push_notification)) @statsd_increment("push_notifications") diff --git a/zerver/models.py b/zerver/models.py index 95db1b8b00..0abd389782 100644 --- a/zerver/models.py +++ b/zerver/models.py @@ -3337,6 +3337,20 @@ class UserMessage(AbstractUserMessage): display_recipient = get_display_recipient(self.message.recipient) return f"<{self.__class__.__name__}: {display_recipient} / {self.user_profile.email} ({self.flags_list()})>" + @staticmethod + def select_for_update_query() -> QuerySet["UserMessage"]: + """This SELECT FOR UPDATE query ensures consistent ordering on + the row locks acquired by a bulk update operation to modify + message flags using bitand/bitor. + + This consistent ordering is important to prevent to prevent + deadlocks when 2 or more bulk updates to the same rows in the + UserMessage table race against each other (For example, if a + client submits simultaneous duplicate API requests to mark a + certain set of messages as read). + """ + return UserMessage.objects.select_for_update().order_by("message_id") + def get_usermessage_by_message_id( user_profile: UserProfile, message_id: int diff --git a/zerver/tests/test_message_flags.py b/zerver/tests/test_message_flags.py index 0f4e09d034..abe24bf101 100644 --- a/zerver/tests/test_message_flags.py +++ b/zerver/tests/test_message_flags.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, List, Mapping, Set from unittest import mock import orjson -from django.db import connection +from django.db import connection, transaction from zerver.actions.message_flags import do_update_message_flags from zerver.actions.streams import do_change_stream_permission @@ -1178,7 +1178,8 @@ class MessageAccessTests(ZulipTestCase): # Starring private stream messages you didn't receive fails. self.login("cordelia") - result = self.change_star(message_ids) + with transaction.atomic(): + result = self.change_star(message_ids) self.assert_json_error(result, "Invalid message(s)") stream_name = "private_stream_2" @@ -1193,7 +1194,8 @@ class MessageAccessTests(ZulipTestCase): # can't see it if you didn't receive the message and are # not subscribed. self.login("cordelia") - result = self.change_star(message_ids) + with transaction.atomic(): + result = self.change_star(message_ids) self.assert_json_error(result, "Invalid message(s)") # But if you subscribe, then you can star the message @@ -1234,7 +1236,8 @@ class MessageAccessTests(ZulipTestCase): guest_user = self.example_user("polonius") self.login_user(guest_user) - result = self.change_star(message_id) + with transaction.atomic(): + result = self.change_star(message_id) self.assert_json_error(result, "Invalid message(s)") # Subscribed guest users can access public stream messages sent before they join @@ -1265,13 +1268,15 @@ class MessageAccessTests(ZulipTestCase): guest_user = self.example_user("polonius") self.login_user(guest_user) - result = self.change_star(message_id) + with transaction.atomic(): + result = self.change_star(message_id) self.assert_json_error(result, "Invalid message(s)") # Guest user can't access messages of subscribed private streams if # history is not public to subscribers self.subscribe(guest_user, stream_name) - result = self.change_star(message_id) + with transaction.atomic(): + result = self.change_star(message_id) self.assert_json_error(result, "Invalid message(s)") # Guest user can access messages of subscribed private streams if diff --git a/zerver/worker/queue_processors.py b/zerver/worker/queue_processors.py index 975927019e..d7fa345649 100644 --- a/zerver/worker/queue_processors.py +++ b/zerver/worker/queue_processors.py @@ -995,9 +995,11 @@ class DeferredWorker(QueueProcessingWorker): messages = Message.objects.filter( recipient_id=event["stream_recipient_id"] ).order_by("id")[offset : offset + batch_size] - UserMessage.objects.filter(message__in=messages).extra( - where=[UserMessage.where_unread()] - ).update(flags=F("flags").bitor(UserMessage.flags.read)) + + with transaction.atomic(savepoint=False): + UserMessage.select_for_update_query().filter(message__in=messages).extra( + where=[UserMessage.where_unread()] + ).update(flags=F("flags").bitor(UserMessage.flags.read)) offset += len(messages) if len(messages) < batch_size: break