message_flags: Update 'do_mark_stream...read' to send event on commit.

Earlier, we were using 'send_event' in do_mark_stream_messages_as_read
codepath which can lead to a situation where we enqueue events but the
function fails at a later stage.

Events should not be sent until we know we're not rolling back.

Fixes part of #30489.
This commit is contained in:
Prakhar Pratyush 2024-08-16 14:51:26 +05:30 committed by Tim Abbott
parent ed512f06bb
commit 64beea2765
1 changed files with 24 additions and 24 deletions

View File

@ -20,7 +20,7 @@ from zerver.lib.stream_subscription import get_subscribed_stream_recipient_ids_f
from zerver.lib.topic import filter_by_topic_name_via_message
from zerver.lib.user_message import DEFAULT_HISTORICAL_FLAGS, create_historical_user_messages
from zerver.models import Message, Recipient, UserMessage, UserProfile
from zerver.tornado.django_api import send_event
from zerver.tornado.django_api import send_event, send_event_on_commit
@dataclass
@ -101,35 +101,35 @@ def do_mark_all_as_read(user_profile: UserProfile, *, timeout: float | None = No
return count
@transaction.atomic(durable=True)
def do_mark_stream_messages_as_read(
user_profile: UserProfile, stream_recipient_id: int, topic_name: str | None = None
) -> int:
with transaction.atomic(savepoint=False):
query = (
UserMessage.select_for_update_query()
.filter(
user_profile=user_profile,
message__recipient_id=stream_recipient_id,
)
.extra( # noqa: S610
where=[UserMessage.where_unread()],
)
query = (
UserMessage.select_for_update_query()
.filter(
user_profile=user_profile,
message__recipient_id=stream_recipient_id,
)
.extra( # noqa: S610
where=[UserMessage.where_unread()],
)
)
if topic_name:
query = filter_by_topic_name_via_message(
query=query,
topic_name=topic_name,
)
if topic_name:
query = filter_by_topic_name_via_message(
query=query,
topic_name=topic_name,
)
message_ids = list(query.values_list("message_id", flat=True))
message_ids = list(query.values_list("message_id", flat=True))
if len(message_ids) == 0:
return 0
if len(message_ids) == 0:
return 0
count = query.update(
flags=F("flags").bitor(UserMessage.flags.read),
)
count = query.update(
flags=F("flags").bitor(UserMessage.flags.read),
)
event = asdict(
ReadMessagesEvent(
@ -139,7 +139,7 @@ def do_mark_stream_messages_as_read(
)
event_time = timezone_now()
send_event(user_profile.realm, event, [user_profile.id])
send_event_on_commit(user_profile.realm, event, [user_profile.id])
do_clear_mobile_push_notifications_for_ids([user_profile.id], message_ids)
do_increment_logging_stat(