message_flags: Fix deadlocks when updating message flags.

Previously, an active production Zulip server would experience a class
of deadlocks caused by two or more concurrent bulk update operations
on the UserMessage table.

This is because UPDATE ... SET ... WHERE statements that execute in
parallel take row-level UPDATE locks as they produce results; since the
two queries' plans may visit the rows in different orders, each
statement can end up waiting on a row lock the other already holds,
which is a deadlock.
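
As an illustration (this sketch is not part of the patch, and the
function name is made up), the pre-fix call sites boiled down to a flat
bulk UPDATE like the following, whose row locks are acquired in
whatever order the executor happens to reach the rows:

from typing import List

from django.db.models import F

from zerver.models import UserMessage


def mark_messages_read_unordered(user_profile_id: int, message_ids: List[int]) -> int:
    # Hypothetical pre-fix pattern: two workers running this concurrently over
    # overlapping message_ids can lock the same rows in opposite orders; each
    # then waits on a lock the other holds, and PostgreSQL resolves the cycle
    # by aborting one transaction with a deadlock error.
    return UserMessage.objects.filter(
        user_profile_id=user_profile_id,
        message_id__in=message_ids,
    ).update(flags=F("flags").bitor(UserMessage.flags.read))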

Some databases allow ORDER BY on their UPDATE ... WHERE statements;
PostgreSQL does not. In PostgreSQL, the answer is to use a sub-select
with ORDER BY ... FOR UPDATE to ensure a consistent ordering of the row
locks.
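
In Django terms, a minimal sketch of that pattern (condensed and
hypothetical, not a verbatim excerpt of this patch) looks like:

from typing import List

from django.db import transaction
from django.db.models import F

from zerver.models import UserMessage


def mark_messages_read_ordered(user_profile_id: int, message_ids: List[int]) -> int:
    # select_for_update() must run inside a transaction; savepoint=False just
    # avoids creating a redundant savepoint when the caller is already atomic.
    with transaction.atomic(savepoint=False):
        query = (
            UserMessage.objects.select_for_update()
            # Locking rows in a fixed (message_id) order is intended to make
            # concurrent bulk updates queue up behind each other instead of
            # deadlocking partway through.
            .order_by("message_id")
            .filter(
                user_profile_id=user_profile_id,
                message_id__in=message_ids,
            )
        )
        return query.update(flags=F("flags").bitor(UserMessage.flags.read))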

We do this for all code paths that use bitand or bitor as part of bulk
editing message flags, which should ensure that these concurrent
operations obtain row-level locks on the table in the same order.
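
Any new bulk flag mutation can follow the same recipe through the
helper added in this commit; for example, a hypothetical path that
clears the starred flag in bulk (function name and flag choice are
illustrative only):

from typing import List

from django.db import transaction
from django.db.models import F

from zerver.models import UserMessage, UserProfile


def bulk_unstar_messages(user_profile: UserProfile, message_ids: List[int]) -> int:
    with transaction.atomic(savepoint=False):
        # select_for_update_query() is the new helper added below; it orders
        # the row locks by message_id so that concurrent bulk updates take
        # their locks in the same order.
        query = UserMessage.select_for_update_query().filter(
            user_profile=user_profile,
            message_id__in=message_ids,
        )
        return query.update(flags=F("flags").bitand(~UserMessage.flags.starred))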

Fixes #19054.
Christopher Chong authored on 2022-08-14 10:02:05 +00:00; committed by Tim Abbott
parent e1023f45cf
commit 28173cafc8
5 changed files with 94 additions and 62 deletions

zerver/actions/message_flags.py

@@ -2,6 +2,7 @@ from collections import defaultdict
 from dataclasses import asdict, dataclass, field
 from typing import List, Optional, Set
 
+from django.db import transaction
 from django.db.models import F
 from django.utils.timezone import now as timezone_now
 from django.utils.translation import gettext as _
@@ -43,13 +44,15 @@ def do_mark_all_as_read(user_profile: UserProfile) -> int:
     )
     do_clear_mobile_push_notifications_for_ids([user_profile.id], all_push_message_ids)
 
-    msgs = UserMessage.objects.filter(user_profile=user_profile).extra(
-        where=[UserMessage.where_unread()],
-    )
-
-    count = msgs.update(
-        flags=F("flags").bitor(UserMessage.flags.read),
-    )
+    with transaction.atomic(savepoint=False):
+        query = (
+            UserMessage.select_for_update_query()
+            .filter(user_profile=user_profile)
+            .extra(where=[UserMessage.where_unread()])
+        )
+        count = query.update(
+            flags=F("flags").bitor(UserMessage.flags.read),
+        )
 
     event = asdict(
         ReadMessagesEvent(
@@ -80,30 +83,32 @@ def do_mark_stream_messages_as_read(
 ) -> int:
     log_statsd_event("mark_stream_as_read")
 
-    msgs = UserMessage.objects.filter(
-        user_profile=user_profile,
-    )
-
-    msgs = msgs.filter(message__recipient_id=stream_recipient_id)
-
-    if topic_name:
-        msgs = filter_by_topic_name_via_message(
-            query=msgs,
-            topic_name=topic_name,
-        )
-
-    msgs = msgs.extra(
-        where=[UserMessage.where_unread()],
-    )
-
-    message_ids = list(msgs.values_list("message_id", flat=True))
-
-    if len(message_ids) == 0:
-        return 0
-
-    count = msgs.update(
-        flags=F("flags").bitor(UserMessage.flags.read),
-    )
+    with transaction.atomic(savepoint=False):
+        query = (
+            UserMessage.select_for_update_query()
+            .filter(
+                user_profile=user_profile,
+                message__recipient_id=stream_recipient_id,
+            )
+            .extra(
+                where=[UserMessage.where_unread()],
+            )
+        )
+
+        if topic_name:
+            query = filter_by_topic_name_via_message(
+                query=query,
+                topic_name=topic_name,
+            )
+
+        message_ids = list(query.values_list("message_id", flat=True))
+
+        if len(message_ids) == 0:
+            return 0
+
+        count = query.update(
+            flags=F("flags").bitor(UserMessage.flags.read),
+        )
 
     event = asdict(
         ReadMessagesEvent(
@@ -133,18 +138,20 @@ def do_mark_muted_user_messages_as_read(
     user_profile: UserProfile,
     muted_user: UserProfile,
 ) -> int:
-    messages = UserMessage.objects.filter(
-        user_profile=user_profile, message__sender=muted_user
-    ).extra(where=[UserMessage.where_unread()])
-
-    message_ids = list(messages.values_list("message_id", flat=True))
-
-    if len(message_ids) == 0:
-        return 0
-
-    count = messages.update(
-        flags=F("flags").bitor(UserMessage.flags.read),
-    )
+    with transaction.atomic(savepoint=False):
+        query = (
+            UserMessage.select_for_update_query()
+            .filter(user_profile=user_profile, message__sender=muted_user)
+            .extra(where=[UserMessage.where_unread()])
+        )
+        message_ids = list(query.values_list("message_id", flat=True))
+
+        if len(message_ids) == 0:
+            return 0
+
+        count = query.update(
+            flags=F("flags").bitor(UserMessage.flags.read),
+        )
 
     event = asdict(
         ReadMessagesEvent(
@@ -239,22 +246,25 @@ def do_update_message_flags(
         raise JsonableError(_("Invalid message flag operation: '{}'").format(operation))
     flagattr = getattr(UserMessage.flags, flag)
 
-    msgs = UserMessage.objects.filter(user_profile=user_profile, message_id__in=messages)
-    um_message_ids = {um.message_id for um in msgs}
-    historical_message_ids = list(set(messages) - um_message_ids)
-
-    # Users can mutate flags for messages that don't have a UserMessage yet.
-    # First, validate that the user is even allowed to access these message_ids.
-    for message_id in historical_message_ids:
-        access_message(user_profile, message_id)
-
-    # And then create historical UserMessage records. See the called function for more context.
-    create_historical_user_messages(user_id=user_profile.id, message_ids=historical_message_ids)
-
-    if operation == "add":
-        count = msgs.update(flags=F("flags").bitor(flagattr))
-    elif operation == "remove":
-        count = msgs.update(flags=F("flags").bitand(~flagattr))
+    with transaction.atomic(savepoint=False):
+        query = UserMessage.select_for_update_query().filter(
+            user_profile=user_profile, message_id__in=messages
+        )
+        um_message_ids = {um.message_id for um in query}
+        historical_message_ids = list(set(messages) - um_message_ids)
+
+        # Users can mutate flags for messages that don't have a UserMessage yet.
+        # First, validate that the user is even allowed to access these message_ids.
+        for message_id in historical_message_ids:
+            access_message(user_profile, message_id)
+
+        # And then create historical UserMessage records. See the called function for more context.
+        create_historical_user_messages(user_id=user_profile.id, message_ids=historical_message_ids)
+
+        if operation == "add":
+            count = query.update(flags=F("flags").bitor(flagattr))
+        elif operation == "remove":
+            count = query.update(flags=F("flags").bitand(~flagattr))
 
     event = {
         "type": "update_message_flags",

zerver/lib/push_notifications.py

@@ -1025,10 +1025,11 @@ def handle_remove_push_notification(user_profile_id: int, message_ids: List[int]
     # assuming in this very rare case that the user has manually
     # dismissed these notifications on the device side, and the server
     # should no longer track them as outstanding notifications.
-    UserMessage.objects.filter(
-        user_profile_id=user_profile_id,
-        message_id__in=message_ids,
-    ).update(flags=F("flags").bitand(~UserMessage.flags.active_mobile_push_notification))
+    with transaction.atomic(savepoint=False):
+        UserMessage.select_for_update_query().filter(
+            user_profile_id=user_profile_id,
+            message_id__in=message_ids,
+        ).update(flags=F("flags").bitand(~UserMessage.flags.active_mobile_push_notification))
 
 
 @statsd_increment("push_notifications")

zerver/models.py

@@ -3337,6 +3337,20 @@ class UserMessage(AbstractUserMessage):
         display_recipient = get_display_recipient(self.message.recipient)
         return f"<{self.__class__.__name__}: {display_recipient} / {self.user_profile.email} ({self.flags_list()})>"
 
+    @staticmethod
+    def select_for_update_query() -> QuerySet["UserMessage"]:
+        """This SELECT FOR UPDATE query ensures consistent ordering on
+        the row locks acquired by a bulk update operation to modify
+        message flags using bitand/bitor.
+
+        This consistent ordering is important to prevent deadlocks
+        when 2 or more bulk updates to the same rows in the
+        UserMessage table race against each other (for example, if a
+        client submits simultaneous duplicate API requests to mark a
+        certain set of messages as read).
+        """
+        return UserMessage.objects.select_for_update().order_by("message_id")
+
 
 def get_usermessage_by_message_id(
     user_profile: UserProfile, message_id: int

zerver/tests/test_message_flags.py

@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, List, Mapping, Set
 from unittest import mock
 
 import orjson
-from django.db import connection
+from django.db import connection, transaction
 
 from zerver.actions.message_flags import do_update_message_flags
 from zerver.actions.streams import do_change_stream_permission
@@ -1178,7 +1178,8 @@ class MessageAccessTests(ZulipTestCase):
 
         # Starring private stream messages you didn't receive fails.
         self.login("cordelia")
-        result = self.change_star(message_ids)
+        with transaction.atomic():
+            result = self.change_star(message_ids)
         self.assert_json_error(result, "Invalid message(s)")
 
         stream_name = "private_stream_2"
@@ -1193,7 +1194,8 @@ class MessageAccessTests(ZulipTestCase):
         # can't see it if you didn't receive the message and are
         # not subscribed.
         self.login("cordelia")
-        result = self.change_star(message_ids)
+        with transaction.atomic():
+            result = self.change_star(message_ids)
         self.assert_json_error(result, "Invalid message(s)")
 
         # But if you subscribe, then you can star the message
@@ -1234,7 +1236,8 @@ class MessageAccessTests(ZulipTestCase):
 
         guest_user = self.example_user("polonius")
         self.login_user(guest_user)
-        result = self.change_star(message_id)
+        with transaction.atomic():
+            result = self.change_star(message_id)
         self.assert_json_error(result, "Invalid message(s)")
 
         # Subscribed guest users can access public stream messages sent before they join
@@ -1265,13 +1268,15 @@ class MessageAccessTests(ZulipTestCase):
 
         guest_user = self.example_user("polonius")
         self.login_user(guest_user)
-        result = self.change_star(message_id)
+        with transaction.atomic():
+            result = self.change_star(message_id)
         self.assert_json_error(result, "Invalid message(s)")
 
         # Guest user can't access messages of subscribed private streams if
         # history is not public to subscribers
         self.subscribe(guest_user, stream_name)
-        result = self.change_star(message_id)
+        with transaction.atomic():
+            result = self.change_star(message_id)
         self.assert_json_error(result, "Invalid message(s)")
 
         # Guest user can access messages of subscribed private streams if

zerver/worker/queue_processors.py

@@ -995,9 +995,11 @@ class DeferredWorker(QueueProcessingWorker):
                 messages = Message.objects.filter(
                     recipient_id=event["stream_recipient_id"]
                 ).order_by("id")[offset : offset + batch_size]
-                UserMessage.objects.filter(message__in=messages).extra(
-                    where=[UserMessage.where_unread()]
-                ).update(flags=F("flags").bitor(UserMessage.flags.read))
+
+                with transaction.atomic(savepoint=False):
+                    UserMessage.select_for_update_query().filter(message__in=messages).extra(
+                        where=[UserMessage.where_unread()]
+                    ).update(flags=F("flags").bitor(UserMessage.flags.read))
                 offset += len(messages)
                 if len(messages) < batch_size:
                     break