message_flags: Update 'do_update_message_flags' to send event on commit.

Earlier, we were using 'send_event' in do_update_message_flags
which can lead to a situation where we enqueue events but the
function fails at a later stage.

Events should not be sent until we know we're not rolling back.

Fixes part of #30489.
This commit is contained in:
Prakhar Pratyush 2024-08-16 18:00:41 +05:30 committed by Tim Abbott
parent 32a4a112b1
commit 9026e6ecc1
2 changed files with 38 additions and 41 deletions

View File

@ -273,7 +273,7 @@ def do_update_message_flags(
flagattr = getattr(UserMessage.flags, flag) flagattr = getattr(UserMessage.flags, flag)
flag_target = flagattr if is_adding else 0 flag_target = flagattr if is_adding else 0
with transaction.atomic(savepoint=False): with transaction.atomic(durable=True):
if flag == "read" and not is_adding: if flag == "read" and not is_adding:
# We have an invariant that all stream messages marked as # We have an invariant that all stream messages marked as
# unread must be in streams the user is subscribed to. # unread must be in streams the user is subscribed to.
@ -359,38 +359,40 @@ def do_update_message_flags(
else: else:
to_update.update(flags=F("flags").bitand(~flagattr)) to_update.update(flags=F("flags").bitand(~flagattr))
event = { event = {
"type": "update_message_flags", "type": "update_message_flags",
"op": operation, "op": operation,
"operation": operation, "operation": operation,
"flag": flag, "flag": flag,
"messages": messages, "messages": messages,
"all": False, "all": False,
} }
if flag == "read" and not is_adding: if flag == "read" and not is_adding:
# When removing the read flag (i.e. marking messages as # When removing the read flag (i.e. marking messages as
# unread), extend the event with an additional object with # unread), extend the event with an additional object with
# details on the messages required to update the client's # details on the messages required to update the client's
# `unread_msgs` data structure. # `unread_msgs` data structure.
raw_unread_data = get_raw_unread_data(user_profile, messages) raw_unread_data = get_raw_unread_data(user_profile, messages)
event["message_details"] = format_unread_message_details(user_profile.id, raw_unread_data) event["message_details"] = format_unread_message_details(
user_profile.id, raw_unread_data
)
send_event(user_profile.realm, event, [user_profile.id]) send_event_on_commit(user_profile.realm, event, [user_profile.id])
if flag == "read" and is_adding: if flag == "read" and is_adding:
event_time = timezone_now() event_time = timezone_now()
do_clear_mobile_push_notifications_for_ids([user_profile.id], messages) do_clear_mobile_push_notifications_for_ids([user_profile.id], messages)
do_increment_logging_stat( do_increment_logging_stat(
user_profile, COUNT_STATS["messages_read::hour"], None, event_time, increment=count user_profile, COUNT_STATS["messages_read::hour"], None, event_time, increment=count
) )
do_increment_logging_stat( do_increment_logging_stat(
user_profile, user_profile,
COUNT_STATS["messages_read_interactions::hour"], COUNT_STATS["messages_read_interactions::hour"],
None, None,
event_time, event_time,
increment=min(1, count), increment=min(1, count),
) )
return count return count

View File

@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any
from unittest import mock from unittest import mock
import orjson import orjson
from django.db import connection, transaction from django.db import connection
from typing_extensions import override from typing_extensions import override
from zerver.actions.message_flags import do_update_message_flags from zerver.actions.message_flags import do_update_message_flags
@ -1603,8 +1603,7 @@ class MessageAccessTests(ZulipTestCase):
# Starring private stream messages you didn't receive fails. # Starring private stream messages you didn't receive fails.
self.login("cordelia") self.login("cordelia")
with transaction.atomic(): result = self.change_star(message_ids)
result = self.change_star(message_ids)
self.assert_json_error(result, "Invalid message(s)") self.assert_json_error(result, "Invalid message(s)")
stream_name = "private_stream_2" stream_name = "private_stream_2"
@ -1619,8 +1618,7 @@ class MessageAccessTests(ZulipTestCase):
# can't see it if you didn't receive the message and are # can't see it if you didn't receive the message and are
# not subscribed. # not subscribed.
self.login("cordelia") self.login("cordelia")
with transaction.atomic(): result = self.change_star(message_ids)
result = self.change_star(message_ids)
self.assert_json_error(result, "Invalid message(s)") self.assert_json_error(result, "Invalid message(s)")
# But if you subscribe, then you can star the message # But if you subscribe, then you can star the message
@ -1661,8 +1659,7 @@ class MessageAccessTests(ZulipTestCase):
guest_user = self.example_user("polonius") guest_user = self.example_user("polonius")
self.login_user(guest_user) self.login_user(guest_user)
with transaction.atomic(): result = self.change_star(message_id)
result = self.change_star(message_id)
self.assert_json_error(result, "Invalid message(s)") self.assert_json_error(result, "Invalid message(s)")
# Subscribed guest users can access public stream messages sent before they join # Subscribed guest users can access public stream messages sent before they join
@ -1693,15 +1690,13 @@ class MessageAccessTests(ZulipTestCase):
guest_user = self.example_user("polonius") guest_user = self.example_user("polonius")
self.login_user(guest_user) self.login_user(guest_user)
with transaction.atomic(): result = self.change_star(message_id)
result = self.change_star(message_id)
self.assert_json_error(result, "Invalid message(s)") self.assert_json_error(result, "Invalid message(s)")
# Guest user can't access messages of subscribed private streams if # Guest user can't access messages of subscribed private streams if
# history is not public to subscribers # history is not public to subscribers
self.subscribe(guest_user, stream_name) self.subscribe(guest_user, stream_name)
with transaction.atomic(): result = self.change_star(message_id)
result = self.change_star(message_id)
self.assert_json_error(result, "Invalid message(s)") self.assert_json_error(result, "Invalid message(s)")
# Guest user can access messages of subscribed private streams if # Guest user can access messages of subscribed private streams if