mirror of https://github.com/zulip/zulip.git
message_flags: Fix deadlocks when updating message flags.
Previously, an active production Zulip server would experience a class of deadlocks caused by two or more concurrent bulk update operations on the UserMessage table.

This is because UPDATE ... SET ... WHERE statements that execute in parallel take row-level UPDATE locks as they process rows; since the two queries' plans may visit rows in different orders, each statement can end up waiting on a lock the other already holds, which is a deadlock.

Some databases allow an ORDER BY on their UPDATE ... WHERE statements; PostgreSQL does not. In PostgreSQL, the answer is to do a sub-select with ORDER BY ... FOR UPDATE to ensure a consistent ordering of the row locks. We do this in all code paths that use bitand or bitor as part of bulk-editing message flags, which should ensure that these concurrent operations acquire their row-level locks on the table in the same order.

Fixes #19054.
This commit is contained in:
parent e1023f45cf
commit 28173cafc8
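
To make the fix concrete, here is a minimal sketch (illustration only, not part of this commit) of the locking pattern described above: acquire the row locks through a SELECT ... FOR UPDATE ordered by message_id, then run the bitor update against the locked rows. The UserMessage model, its flags bitfield, and where_unread() are assumed to behave as in the diff below; mark_unread_as_read_with_ordered_locks is a hypothetical name used only for this example.

from django.db import transaction
from django.db.models import F

def mark_unread_as_read_with_ordered_locks(user_profile) -> int:
    with transaction.atomic(savepoint=False):
        # SELECT ... WHERE ... ORDER BY message_id FOR UPDATE: every
        # concurrent caller acquires its row locks in the same order.
        locked = (
            UserMessage.objects.select_for_update()
            .order_by("message_id")
            .filter(user_profile=user_profile)
            .extra(where=[UserMessage.where_unread()])
        )
        # Evaluating the queryset issues the locking SELECT.
        message_ids = list(locked.values_list("message_id", flat=True))
        # The UPDATE then only touches rows this transaction already locked,
        # so two racing callers serialize instead of deadlocking.
        return UserMessage.objects.filter(
            user_profile=user_profile, message_id__in=message_ids
        ).update(flags=F("flags").bitor(UserMessage.flags.read))
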
@@ -2,6 +2,7 @@ from collections import defaultdict
 from dataclasses import asdict, dataclass, field
 from typing import List, Optional, Set
 
+from django.db import transaction
 from django.db.models import F
 from django.utils.timezone import now as timezone_now
 from django.utils.translation import gettext as _
@@ -43,13 +44,15 @@ def do_mark_all_as_read(user_profile: UserProfile) -> int:
     )
     do_clear_mobile_push_notifications_for_ids([user_profile.id], all_push_message_ids)
 
-    msgs = UserMessage.objects.filter(user_profile=user_profile).extra(
-        where=[UserMessage.where_unread()],
-    )
-
-    count = msgs.update(
-        flags=F("flags").bitor(UserMessage.flags.read),
-    )
+    with transaction.atomic(savepoint=False):
+        query = (
+            UserMessage.select_for_update_query()
+            .filter(user_profile=user_profile)
+            .extra(where=[UserMessage.where_unread()])
+        )
+        count = query.update(
+            flags=F("flags").bitor(UserMessage.flags.read),
+        )
 
     event = asdict(
         ReadMessagesEvent(
@@ -80,30 +83,32 @@ def do_mark_stream_messages_as_read(
 ) -> int:
     log_statsd_event("mark_stream_as_read")
 
-    msgs = UserMessage.objects.filter(
-        user_profile=user_profile,
-    )
-
-    msgs = msgs.filter(message__recipient_id=stream_recipient_id)
-
-    if topic_name:
-        msgs = filter_by_topic_name_via_message(
-            query=msgs,
-            topic_name=topic_name,
-        )
-
-    msgs = msgs.extra(
-        where=[UserMessage.where_unread()],
-    )
-
-    message_ids = list(msgs.values_list("message_id", flat=True))
-
-    if len(message_ids) == 0:
-        return 0
-
-    count = msgs.update(
-        flags=F("flags").bitor(UserMessage.flags.read),
-    )
+    with transaction.atomic(savepoint=False):
+        query = (
+            UserMessage.select_for_update_query()
+            .filter(
+                user_profile=user_profile,
+                message__recipient_id=stream_recipient_id,
+            )
+            .extra(
+                where=[UserMessage.where_unread()],
+            )
+        )
+
+        if topic_name:
+            query = filter_by_topic_name_via_message(
+                query=query,
+                topic_name=topic_name,
+            )
+
+        message_ids = list(query.values_list("message_id", flat=True))
+
+        if len(message_ids) == 0:
+            return 0
+
+        count = query.update(
+            flags=F("flags").bitor(UserMessage.flags.read),
+        )
 
     event = asdict(
         ReadMessagesEvent(
@@ -133,18 +138,20 @@ def do_mark_muted_user_messages_as_read(
     user_profile: UserProfile,
     muted_user: UserProfile,
 ) -> int:
-    messages = UserMessage.objects.filter(
-        user_profile=user_profile, message__sender=muted_user
-    ).extra(where=[UserMessage.where_unread()])
-
-    message_ids = list(messages.values_list("message_id", flat=True))
-
-    if len(message_ids) == 0:
-        return 0
-
-    count = messages.update(
-        flags=F("flags").bitor(UserMessage.flags.read),
-    )
+    with transaction.atomic(savepoint=False):
+        query = (
+            UserMessage.select_for_update_query()
+            .filter(user_profile=user_profile, message__sender=muted_user)
+            .extra(where=[UserMessage.where_unread()])
+        )
+        message_ids = list(query.values_list("message_id", flat=True))
+
+        if len(message_ids) == 0:
+            return 0
+
+        count = query.update(
+            flags=F("flags").bitor(UserMessage.flags.read),
+        )
 
     event = asdict(
         ReadMessagesEvent(
@@ -239,22 +246,25 @@ def do_update_message_flags(
         raise JsonableError(_("Invalid message flag operation: '{}'").format(operation))
     flagattr = getattr(UserMessage.flags, flag)
 
-    msgs = UserMessage.objects.filter(user_profile=user_profile, message_id__in=messages)
-    um_message_ids = {um.message_id for um in msgs}
-    historical_message_ids = list(set(messages) - um_message_ids)
-
-    # Users can mutate flags for messages that don't have a UserMessage yet.
-    # First, validate that the user is even allowed to access these message_ids.
-    for message_id in historical_message_ids:
-        access_message(user_profile, message_id)
-
-    # And then create historical UserMessage records. See the called function for more context.
-    create_historical_user_messages(user_id=user_profile.id, message_ids=historical_message_ids)
-
-    if operation == "add":
-        count = msgs.update(flags=F("flags").bitor(flagattr))
-    elif operation == "remove":
-        count = msgs.update(flags=F("flags").bitand(~flagattr))
+    with transaction.atomic(savepoint=False):
+        query = UserMessage.select_for_update_query().filter(
+            user_profile=user_profile, message_id__in=messages
+        )
+        um_message_ids = {um.message_id for um in query}
+        historical_message_ids = list(set(messages) - um_message_ids)
+
+        # Users can mutate flags for messages that don't have a UserMessage yet.
+        # First, validate that the user is even allowed to access these message_ids.
+        for message_id in historical_message_ids:
+            access_message(user_profile, message_id)
+
+        # And then create historical UserMessage records. See the called function for more context.
+        create_historical_user_messages(user_id=user_profile.id, message_ids=historical_message_ids)
+
+        if operation == "add":
+            count = query.update(flags=F("flags").bitor(flagattr))
+        elif operation == "remove":
+            count = query.update(flags=F("flags").bitand(~flagattr))
 
     event = {
         "type": "update_message_flags",
@@ -1025,10 +1025,11 @@ def handle_remove_push_notification(user_profile_id: int, message_ids: List[int]
     # assuming in this very rare case that the user has manually
     # dismissed these notifications on the device side, and the server
     # should no longer track them as outstanding notifications.
-    UserMessage.objects.filter(
-        user_profile_id=user_profile_id,
-        message_id__in=message_ids,
-    ).update(flags=F("flags").bitand(~UserMessage.flags.active_mobile_push_notification))
+    with transaction.atomic(savepoint=False):
+        UserMessage.select_for_update_query().filter(
+            user_profile_id=user_profile_id,
+            message_id__in=message_ids,
+        ).update(flags=F("flags").bitand(~UserMessage.flags.active_mobile_push_notification))
 
 
 @statsd_increment("push_notifications")
@@ -3337,6 +3337,20 @@ class UserMessage(AbstractUserMessage):
         display_recipient = get_display_recipient(self.message.recipient)
         return f"<{self.__class__.__name__}: {display_recipient} / {self.user_profile.email} ({self.flags_list()})>"
 
+    @staticmethod
+    def select_for_update_query() -> QuerySet["UserMessage"]:
+        """This SELECT FOR UPDATE query ensures consistent ordering on
+        the row locks acquired by a bulk update operation to modify
+        message flags using bitand/bitor.
+
+        This consistent ordering is important to prevent deadlocks
+        when 2 or more bulk updates to the same rows in the UserMessage
+        table race against each other (for example, if a client submits
+        simultaneous duplicate API requests to mark a certain set of
+        messages as read).
+        """
+        return UserMessage.objects.select_for_update().order_by("message_id")
+
 
 def get_usermessage_by_message_id(
     user_profile: UserProfile, message_id: int
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, List, Mapping, Set
 from unittest import mock
 
 import orjson
-from django.db import connection
+from django.db import connection, transaction
 
 from zerver.actions.message_flags import do_update_message_flags
 from zerver.actions.streams import do_change_stream_permission
@@ -1178,7 +1178,8 @@ class MessageAccessTests(ZulipTestCase):
 
         # Starring private stream messages you didn't receive fails.
         self.login("cordelia")
-        result = self.change_star(message_ids)
+        with transaction.atomic():
+            result = self.change_star(message_ids)
         self.assert_json_error(result, "Invalid message(s)")
 
         stream_name = "private_stream_2"
@@ -1193,7 +1194,8 @@ class MessageAccessTests(ZulipTestCase):
         # can't see it if you didn't receive the message and are
         # not subscribed.
         self.login("cordelia")
-        result = self.change_star(message_ids)
+        with transaction.atomic():
+            result = self.change_star(message_ids)
         self.assert_json_error(result, "Invalid message(s)")
 
         # But if you subscribe, then you can star the message
@@ -1234,7 +1236,8 @@ class MessageAccessTests(ZulipTestCase):
 
         guest_user = self.example_user("polonius")
         self.login_user(guest_user)
-        result = self.change_star(message_id)
+        with transaction.atomic():
+            result = self.change_star(message_id)
         self.assert_json_error(result, "Invalid message(s)")
 
         # Subscribed guest users can access public stream messages sent before they join
@@ -1265,13 +1268,15 @@ class MessageAccessTests(ZulipTestCase):
 
         guest_user = self.example_user("polonius")
         self.login_user(guest_user)
-        result = self.change_star(message_id)
+        with transaction.atomic():
+            result = self.change_star(message_id)
         self.assert_json_error(result, "Invalid message(s)")
 
         # Guest user can't access messages of subscribed private streams if
         # history is not public to subscribers
         self.subscribe(guest_user, stream_name)
-        result = self.change_star(message_id)
+        with transaction.atomic():
+            result = self.change_star(message_id)
         self.assert_json_error(result, "Invalid message(s)")
 
         # Guest user can access messages of subscribed private streams if
@@ -995,9 +995,11 @@ class DeferredWorker(QueueProcessingWorker):
                 messages = Message.objects.filter(
                     recipient_id=event["stream_recipient_id"]
                 ).order_by("id")[offset : offset + batch_size]
-                UserMessage.objects.filter(message__in=messages).extra(
-                    where=[UserMessage.where_unread()]
-                ).update(flags=F("flags").bitor(UserMessage.flags.read))
+                with transaction.atomic(savepoint=False):
+                    UserMessage.select_for_update_query().filter(message__in=messages).extra(
+                        where=[UserMessage.where_unread()]
+                    ).update(flags=F("flags").bitor(UserMessage.flags.read))
+
                 offset += len(messages)
                 if len(messages) < batch_size:
                     break