diff --git a/zerver/actions/message_flags.py b/zerver/actions/message_flags.py index 81da8a4858..b97d09a471 100644 --- a/zerver/actions/message_flags.py +++ b/zerver/actions/message_flags.py @@ -338,8 +338,9 @@ def do_update_message_flags( create_historical_user_messages( user_id=user_profile.id, - message_ids=historical_message_ids, - flags=(DEFAULT_HISTORICAL_FLAGS & ~flagattr) | flag_target, + message_ids=list(historical_message_ids), + flagattr=flagattr, + flag_target=flag_target, ) to_update = UserMessage.objects.filter( diff --git a/zerver/lib/user_message.py b/zerver/lib/user_message.py index 4c856fd4e1..3ef5eadeef 100644 --- a/zerver/lib/user_message.py +++ b/zerver/lib/user_message.py @@ -1,8 +1,8 @@ -from typing import Iterable, List +from typing import List, Optional from django.db import connection from psycopg2.extras import execute_values -from psycopg2.sql import SQL +from psycopg2.sql import SQL, Composable, Literal from zerver.models import UserMessage @@ -27,7 +27,11 @@ DEFAULT_HISTORICAL_FLAGS = UserMessage.flags.historical | UserMessage.flags.read def create_historical_user_messages( - *, user_id: int, message_ids: Iterable[int], flags: int = DEFAULT_HISTORICAL_FLAGS + *, + user_id: int, + message_ids: List[int], + flagattr: Optional[int] = None, + flag_target: Optional[int] = None, ) -> None: # Users can see and interact with messages sent to streams with # public history for which they do not have a UserMessage because @@ -36,10 +40,15 @@ def create_historical_user_messages( # those messages, we create UserMessage objects for those messages; # these have the special historical flag which keeps track of the # fact that the user did not receive the message at the time it was sent. - UserMessage.objects.bulk_create( - UserMessage(user_profile_id=user_id, message_id=message_id, flags=flags) - for message_id in message_ids - ) + if flagattr is not None and flag_target is not None: + conflict = SQL( + "(user_profile_id, message_id) DO UPDATE SET flags = excluded.flags & ~ {mask} | {attr}" + ).format(mask=Literal(flagattr), attr=Literal(flag_target)) + flags = (DEFAULT_HISTORICAL_FLAGS & ~flagattr) | flag_target + else: + conflict = None + flags = DEFAULT_HISTORICAL_FLAGS + bulk_insert_all_ums([user_id], message_ids, flags, conflict) def bulk_insert_ums(ums: List[UserMessageLite]) -> None: @@ -66,7 +75,9 @@ def bulk_insert_ums(ums: List[UserMessageLite]) -> None: execute_values(cursor.cursor, query, vals) -def bulk_insert_all_ums(user_ids: List[int], message_ids: List[int], flags: int) -> None: +def bulk_insert_all_ums( + user_ids: List[int], message_ids: List[int], flags: int, conflict: Optional[Composable] = None +) -> None: if not user_ids or not message_ids: return @@ -76,9 +87,9 @@ def bulk_insert_all_ums(user_ids: List[int], message_ids: List[int], flags: int) SELECT user_profile_id, message_id, %s AS flags FROM UNNEST(%s) user_profile_id CROSS JOIN UNNEST(%s) message_id - ON CONFLICT DO NOTHING + ON CONFLICT {conflict} """ - ) + ).format(conflict=conflict if conflict is not None else SQL("DO NOTHING")) with connection.cursor() as cursor: cursor.execute(query, [flags, user_ids, message_ids]) diff --git a/zerver/tests/test_message_flags.py b/zerver/tests/test_message_flags.py index ef91ca84ab..5fdb35d421 100644 --- a/zerver/tests/test_message_flags.py +++ b/zerver/tests/test_message_flags.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, List, Set +from typing import TYPE_CHECKING, Any, List, Optional, Set from unittest import mock import orjson @@ -26,6 +26,7 @@ from zerver.lib.message_cache import MessageDict from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_helpers import get_subscription, timeout_mock from zerver.lib.timeout import TimeoutExpiredError +from zerver.lib.user_message import DEFAULT_HISTORICAL_FLAGS, create_historical_user_messages from zerver.models import ( Message, Recipient, @@ -233,6 +234,102 @@ class UnreadCountTests(ZulipTestCase): elif msg["id"] == self.unread_msg_ids[1]: check_flags(msg["flags"], set()) + def test_update_flags_race(self) -> None: + user = self.example_user("hamlet") + self.login_user(user) + self.unsubscribe(user, "Verona") + + first_message_id = self.send_stream_message( + self.example_user("cordelia"), + "Verona", + topic_name="testing", + ) + self.assertFalse( + UserMessage.objects.filter( + user_profile_id=user.id, message_id=first_message_id + ).exists() + ) + # When adjusting flags of messages that we did not receive, we + # create UserMessage rows. + with mock.patch( + "zerver.actions.message_flags.create_historical_user_messages", + wraps=create_historical_user_messages, + ) as mock_backfill: + result = self.client_post( + "/json/messages/flags", + { + "messages": orjson.dumps([first_message_id]).decode(), + "op": "add", + "flag": "starred", + }, + ) + self.assert_json_success(result) + + mock_backfill.assert_called_once_with( + user_id=user.id, + message_ids=[first_message_id], + flagattr=UserMessage.flags.starred, + flag_target=UserMessage.flags.starred, + ) + um_row = UserMessage.objects.get(user_profile_id=user.id, message_id=first_message_id) + self.assertEqual( + int(um_row.flags), + UserMessage.flags.historical | UserMessage.flags.read | UserMessage.flags.starred, + ) + + # That creation may race with other things which also create + # the UserMessage rows (e.g. reactions); ensure the end result + # is correct still. + def race_creation( + *, + user_id: int, + message_ids: List[int], + flagattr: Optional[int] = None, + flag_target: Optional[int] = None, + ) -> None: + UserMessage.objects.create( + user_profile_id=user_id, message_id=message_ids[0], flags=DEFAULT_HISTORICAL_FLAGS + ) + create_historical_user_messages( + user_id=user_id, message_ids=message_ids, flagattr=flagattr, flag_target=flag_target + ) + + second_message_id = self.send_stream_message( + self.example_user("cordelia"), + "Verona", + topic_name="testing", + ) + self.assertFalse( + UserMessage.objects.filter( + user_profile_id=user.id, message_id=second_message_id + ).exists() + ) + with mock.patch( + "zerver.actions.message_flags.create_historical_user_messages", wraps=race_creation + ) as mock_backfill: + result = self.client_post( + "/json/messages/flags", + { + "messages": orjson.dumps([second_message_id]).decode(), + "op": "add", + "flag": "starred", + }, + ) + self.assert_json_success(result) + + mock_backfill.assert_called_once_with( + user_id=user.id, + message_ids=[second_message_id], + flagattr=UserMessage.flags.starred, + flag_target=UserMessage.flags.starred, + ) + + um_row = UserMessage.objects.get(user_profile_id=user.id, message_id=second_message_id) + self.assertEqual( + int(um_row.flags), + UserMessage.flags.historical | UserMessage.flags.read | UserMessage.flags.starred, + ) + def test_update_flags_for_narrow(self) -> None: user = self.example_user("hamlet") self.login_user(user)