message_flags: Fix deadlocks when updating message flags.

Previously, an active production Zulip server would experience a class
of deadlocks caused by two or more concurrent bulk update operations
on the UserMessage table.

This is because UPDATE ... SET ... WHERE statements that execute in
parallel take row-level UPDATE locks as they get results; since the
query plans may result in getting rows in different orders between two
queries, this can result in deadlocks.

Some databases allow ORDER BY on their UPDATE ... WHERE statements;
PostgreSQL does not. In PostgreSQL, the answer is to do a sub-select
with an ORDER BY ... FOR UPDATE to ensure consistent ordering on row
locks.

We do this all code paths using bitand or bitor as part of bulk
editing message flags, which should ensure that these concurrent
operations obtain row level locks on the table in the same order.

Fixes #19054.
This commit is contained in:
Christopher Chong 2022-08-14 10:02:05 +00:00 committed by Tim Abbott
parent e1023f45cf
commit 28173cafc8
5 changed files with 94 additions and 62 deletions

View File

@ -2,6 +2,7 @@ from collections import defaultdict
from dataclasses import asdict, dataclass, field
from typing import List, Optional, Set
from django.db import transaction
from django.db.models import F
from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _
@ -43,11 +44,13 @@ def do_mark_all_as_read(user_profile: UserProfile) -> int:
)
do_clear_mobile_push_notifications_for_ids([user_profile.id], all_push_message_ids)
msgs = UserMessage.objects.filter(user_profile=user_profile).extra(
where=[UserMessage.where_unread()],
with transaction.atomic(savepoint=False):
query = (
UserMessage.select_for_update_query()
.filter(user_profile=user_profile)
.extra(where=[UserMessage.where_unread()])
)
count = msgs.update(
count = query.update(
flags=F("flags").bitor(UserMessage.flags.read),
)
@ -80,28 +83,30 @@ def do_mark_stream_messages_as_read(
) -> int:
log_statsd_event("mark_stream_as_read")
msgs = UserMessage.objects.filter(
with transaction.atomic(savepoint=False):
query = (
UserMessage.select_for_update_query()
.filter(
user_profile=user_profile,
message__recipient_id=stream_recipient_id,
)
.extra(
where=[UserMessage.where_unread()],
)
)
msgs = msgs.filter(message__recipient_id=stream_recipient_id)
if topic_name:
msgs = filter_by_topic_name_via_message(
query=msgs,
query = filter_by_topic_name_via_message(
query=query,
topic_name=topic_name,
)
msgs = msgs.extra(
where=[UserMessage.where_unread()],
)
message_ids = list(msgs.values_list("message_id", flat=True))
message_ids = list(query.values_list("message_id", flat=True))
if len(message_ids) == 0:
return 0
count = msgs.update(
count = query.update(
flags=F("flags").bitor(UserMessage.flags.read),
)
@ -133,16 +138,18 @@ def do_mark_muted_user_messages_as_read(
user_profile: UserProfile,
muted_user: UserProfile,
) -> int:
messages = UserMessage.objects.filter(
user_profile=user_profile, message__sender=muted_user
).extra(where=[UserMessage.where_unread()])
message_ids = list(messages.values_list("message_id", flat=True))
with transaction.atomic(savepoint=False):
query = (
UserMessage.select_for_update_query()
.filter(user_profile=user_profile, message__sender=muted_user)
.extra(where=[UserMessage.where_unread()])
)
message_ids = list(query.values_list("message_id", flat=True))
if len(message_ids) == 0:
return 0
count = messages.update(
count = query.update(
flags=F("flags").bitor(UserMessage.flags.read),
)
@ -239,8 +246,11 @@ def do_update_message_flags(
raise JsonableError(_("Invalid message flag operation: '{}'").format(operation))
flagattr = getattr(UserMessage.flags, flag)
msgs = UserMessage.objects.filter(user_profile=user_profile, message_id__in=messages)
um_message_ids = {um.message_id for um in msgs}
with transaction.atomic(savepoint=False):
query = UserMessage.select_for_update_query().filter(
user_profile=user_profile, message_id__in=messages
)
um_message_ids = {um.message_id for um in query}
historical_message_ids = list(set(messages) - um_message_ids)
# Users can mutate flags for messages that don't have a UserMessage yet.
@ -252,9 +262,9 @@ def do_update_message_flags(
create_historical_user_messages(user_id=user_profile.id, message_ids=historical_message_ids)
if operation == "add":
count = msgs.update(flags=F("flags").bitor(flagattr))
count = query.update(flags=F("flags").bitor(flagattr))
elif operation == "remove":
count = msgs.update(flags=F("flags").bitand(~flagattr))
count = query.update(flags=F("flags").bitand(~flagattr))
event = {
"type": "update_message_flags",

View File

@ -1025,7 +1025,8 @@ def handle_remove_push_notification(user_profile_id: int, message_ids: List[int]
# assuming in this very rare case that the user has manually
# dismissed these notifications on the device side, and the server
# should no longer track them as outstanding notifications.
UserMessage.objects.filter(
with transaction.atomic(savepoint=False):
UserMessage.select_for_update_query().filter(
user_profile_id=user_profile_id,
message_id__in=message_ids,
).update(flags=F("flags").bitand(~UserMessage.flags.active_mobile_push_notification))

View File

@ -3337,6 +3337,20 @@ class UserMessage(AbstractUserMessage):
display_recipient = get_display_recipient(self.message.recipient)
return f"<{self.__class__.__name__}: {display_recipient} / {self.user_profile.email} ({self.flags_list()})>"
@staticmethod
def select_for_update_query() -> QuerySet["UserMessage"]:
"""This SELECT FOR UPDATE query ensures consistent ordering on
the row locks acquired by a bulk update operation to modify
message flags using bitand/bitor.
This consistent ordering is important to prevent to prevent
deadlocks when 2 or more bulk updates to the same rows in the
UserMessage table race against each other (For example, if a
client submits simultaneous duplicate API requests to mark a
certain set of messages as read).
"""
return UserMessage.objects.select_for_update().order_by("message_id")
def get_usermessage_by_message_id(
user_profile: UserProfile, message_id: int

View File

@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, List, Mapping, Set
from unittest import mock
import orjson
from django.db import connection
from django.db import connection, transaction
from zerver.actions.message_flags import do_update_message_flags
from zerver.actions.streams import do_change_stream_permission
@ -1178,6 +1178,7 @@ class MessageAccessTests(ZulipTestCase):
# Starring private stream messages you didn't receive fails.
self.login("cordelia")
with transaction.atomic():
result = self.change_star(message_ids)
self.assert_json_error(result, "Invalid message(s)")
@ -1193,6 +1194,7 @@ class MessageAccessTests(ZulipTestCase):
# can't see it if you didn't receive the message and are
# not subscribed.
self.login("cordelia")
with transaction.atomic():
result = self.change_star(message_ids)
self.assert_json_error(result, "Invalid message(s)")
@ -1234,6 +1236,7 @@ class MessageAccessTests(ZulipTestCase):
guest_user = self.example_user("polonius")
self.login_user(guest_user)
with transaction.atomic():
result = self.change_star(message_id)
self.assert_json_error(result, "Invalid message(s)")
@ -1265,12 +1268,14 @@ class MessageAccessTests(ZulipTestCase):
guest_user = self.example_user("polonius")
self.login_user(guest_user)
with transaction.atomic():
result = self.change_star(message_id)
self.assert_json_error(result, "Invalid message(s)")
# Guest user can't access messages of subscribed private streams if
# history is not public to subscribers
self.subscribe(guest_user, stream_name)
with transaction.atomic():
result = self.change_star(message_id)
self.assert_json_error(result, "Invalid message(s)")

View File

@ -995,7 +995,9 @@ class DeferredWorker(QueueProcessingWorker):
messages = Message.objects.filter(
recipient_id=event["stream_recipient_id"]
).order_by("id")[offset : offset + batch_size]
UserMessage.objects.filter(message__in=messages).extra(
with transaction.atomic(savepoint=False):
UserMessage.select_for_update_query().filter(message__in=messages).extra(
where=[UserMessage.where_unread()]
).update(flags=F("flags").bitor(UserMessage.flags.read))
offset += len(messages)