diff --git a/templates/zerver/api/changelog.md b/templates/zerver/api/changelog.md index e439877981..73a3cf57d3 100644 --- a/templates/zerver/api/changelog.md +++ b/templates/zerver/api/changelog.md @@ -25,6 +25,10 @@ format used by the Zulip server that they are interacting with. * [`GET /messages`](/api/get-messages): The new `include_anchor` parameter controls whether a message with ID matching the specified `anchor` should be included. +* The `update_message_flags` event sent by [`POST + /messages/flags`](/api/update-message-flags) no longer redundantly + lists messages where the flag was set to the same state it was + already in. **Feature level 154** diff --git a/zerver/actions/create_user.py b/zerver/actions/create_user.py index 2bbd08b5f6..fac4db6f66 100644 --- a/zerver/actions/create_user.py +++ b/zerver/actions/create_user.py @@ -62,7 +62,9 @@ ONBOARDING_RECENT_TIMEDELTA = datetime.timedelta(weeks=1) DEFAULT_HISTORICAL_FLAGS = UserMessage.flags.historical | UserMessage.flags.read -def create_historical_user_messages(*, user_id: int, message_ids: Iterable[int]) -> None: +def create_historical_user_messages( + *, user_id: int, message_ids: Iterable[int], flags: int = DEFAULT_HISTORICAL_FLAGS +) -> None: # Users can see and interact with messages sent to streams with # public history for which they do not have a UserMessage because # they were not a subscriber at the time the message was sent. @@ -71,7 +73,7 @@ def create_historical_user_messages(*, user_id: int, message_ids: Iterable[int]) # these have the special historical flag which keeps track of the # fact that the user did not receive the message at the time it was sent. UserMessage.objects.bulk_create( - UserMessage(user_profile_id=user_id, message_id=message_id, flags=DEFAULT_HISTORICAL_FLAGS) + UserMessage(user_profile_id=user_id, message_id=message_id, flags=flags) for message_id in message_ids ) diff --git a/zerver/actions/message_flags.py b/zerver/actions/message_flags.py index 6a0e4f3bc7..937cbde8f1 100644 --- a/zerver/actions/message_flags.py +++ b/zerver/actions/message_flags.py @@ -8,7 +8,7 @@ from django.utils.timezone import now as timezone_now from django.utils.translation import gettext as _ from analytics.lib.counts import COUNT_STATS, do_increment_logging_stat -from zerver.actions.create_user import create_historical_user_messages +from zerver.actions.create_user import DEFAULT_HISTORICAL_FLAGS, create_historical_user_messages from zerver.lib.exceptions import JsonableError from zerver.lib.message import ( bulk_access_messages, @@ -266,6 +266,7 @@ def do_update_message_flags( raise JsonableError(_("Invalid message flag operation: '{}'").format(operation)) is_adding = operation == "add" flagattr = getattr(UserMessage.flags, flag) + flag_target = flagattr if is_adding else 0 with transaction.atomic(savepoint=False): if flag == "read" and not is_adding: @@ -290,19 +291,33 @@ def do_update_message_flags( if message_id not in message_ids_in_unsubscribed_streams ] - query = UserMessage.select_for_update_query().filter( - user_profile=user_profile, message_id__in=messages - ) + ums = { + um.message_id: um + for um in UserMessage.select_for_update_query().filter( + user_profile=user_profile, message_id__in=messages + ) + } - um_message_ids = {um.message_id for um in query} - if flag == "read" and is_adding: + # Filter out rows that already have the desired flag. We do + # this here, rather than in the original database query, + # because not all flags have database indexes and we want to + # bound the cost of this operation. + messages = [ + message_id + for message_id in messages + if (int(ums[message_id].flags) if message_id in ums else DEFAULT_HISTORICAL_FLAGS) + & flagattr + != flag_target + ] + count = len(messages) + + if DEFAULT_HISTORICAL_FLAGS & flagattr != flag_target: # When marking messages as read, creating "historical" # UserMessage rows would be a waste of storage, because # `flags.read | flags.historical` is exactly the flags we # simulate when processing a message for which a user has # access but no UserMessage row. - messages = [message_id for message_id in messages if message_id in um_message_ids] - else: + # # Users can mutate flags for messages that don't have a # UserMessage yet. Validate that the user is even allowed # to access these message_ids; if so, we will create @@ -310,7 +325,7 @@ def do_update_message_flags( # # See create_historical_user_messages for a more detailed # explanation. - historical_message_ids = set(messages) - um_message_ids + historical_message_ids = set(messages) - set(ums.keys()) historical_messages = bulk_access_messages( user_profile, list( @@ -323,13 +338,18 @@ def do_update_message_flags( raise JsonableError(_("Invalid message(s)")) create_historical_user_messages( - user_id=user_profile.id, message_ids=historical_message_ids + user_id=user_profile.id, + message_ids=historical_message_ids, + flags=(DEFAULT_HISTORICAL_FLAGS & ~flagattr) | flag_target, ) + to_update = UserMessage.objects.filter( + user_profile=user_profile, message_id__in=set(messages) & set(ums.keys()) + ) if is_adding: - count = query.update(flags=F("flags").bitor(flagattr)) + to_update.update(flags=F("flags").bitor(flagattr)) else: - count = query.update(flags=F("flags").bitand(~flagattr)) + to_update.update(flags=F("flags").bitand(~flagattr)) event = { "type": "update_message_flags", diff --git a/zerver/tests/test_message_flags.py b/zerver/tests/test_message_flags.py index 3abe21718a..4d5f1e0521 100644 --- a/zerver/tests/test_message_flags.py +++ b/zerver/tests/test_message_flags.py @@ -1908,9 +1908,7 @@ class MarkUnreadTest(ZulipTestCase): ) self.assert_json_success(result) event = events[0]["event"] - self.assertEqual( - event["messages"], subscribed_stream_message_ids + unsubscribed_stream_message_ids - ) + self.assertEqual(event["messages"], subscribed_stream_message_ids) for message_id in subscribed_stream_message_ids + unsubscribed_stream_message_ids: um = UserMessage.objects.get(