user_message: Use INSERT ... ON CONFLICT for historical UM creation.

Rather than use a bulk insert via Django, use the faster
`bulk_insert_all_ums` that we already have.  This also adds a `ON
CONFLICT` clause, to make the insert resilient to race conditions.

There are currently two callsites, with different desired `ON
CONFLICT` behaviours:
 - For `notify_reaction_update`, if the `UserMessage` had already been
   created, we would have done nothing to change it.
 - For `do_update_message_flags`, we would have ensured a specific bit
   was (un)set.

Extend `create_historical_user_messages` and `bulk_insert_all_ums` to
support `ON CONFLICT (...) UPDATE SET flags = ...`.
This commit is contained in:
Alex Vandiver 2024-03-26 14:24:45 +00:00 committed by Tim Abbott
parent 52e3c8e1b2
commit 7988aad159
3 changed files with 122 additions and 13 deletions

View File

@ -338,8 +338,9 @@ def do_update_message_flags(
create_historical_user_messages( create_historical_user_messages(
user_id=user_profile.id, user_id=user_profile.id,
message_ids=historical_message_ids, message_ids=list(historical_message_ids),
flags=(DEFAULT_HISTORICAL_FLAGS & ~flagattr) | flag_target, flagattr=flagattr,
flag_target=flag_target,
) )
to_update = UserMessage.objects.filter( to_update = UserMessage.objects.filter(

View File

@ -1,8 +1,8 @@
from typing import Iterable, List from typing import List, Optional
from django.db import connection from django.db import connection
from psycopg2.extras import execute_values from psycopg2.extras import execute_values
from psycopg2.sql import SQL from psycopg2.sql import SQL, Composable, Literal
from zerver.models import UserMessage from zerver.models import UserMessage
@ -27,7 +27,11 @@ DEFAULT_HISTORICAL_FLAGS = UserMessage.flags.historical | UserMessage.flags.read
def create_historical_user_messages( def create_historical_user_messages(
*, user_id: int, message_ids: Iterable[int], flags: int = DEFAULT_HISTORICAL_FLAGS *,
user_id: int,
message_ids: List[int],
flagattr: Optional[int] = None,
flag_target: Optional[int] = None,
) -> None: ) -> None:
# Users can see and interact with messages sent to streams with # Users can see and interact with messages sent to streams with
# public history for which they do not have a UserMessage because # public history for which they do not have a UserMessage because
@ -36,10 +40,15 @@ def create_historical_user_messages(
# those messages, we create UserMessage objects for those messages; # those messages, we create UserMessage objects for those messages;
# these have the special historical flag which keeps track of the # these have the special historical flag which keeps track of the
# fact that the user did not receive the message at the time it was sent. # fact that the user did not receive the message at the time it was sent.
UserMessage.objects.bulk_create( if flagattr is not None and flag_target is not None:
UserMessage(user_profile_id=user_id, message_id=message_id, flags=flags) conflict = SQL(
for message_id in message_ids "(user_profile_id, message_id) DO UPDATE SET flags = excluded.flags & ~ {mask} | {attr}"
) ).format(mask=Literal(flagattr), attr=Literal(flag_target))
flags = (DEFAULT_HISTORICAL_FLAGS & ~flagattr) | flag_target
else:
conflict = None
flags = DEFAULT_HISTORICAL_FLAGS
bulk_insert_all_ums([user_id], message_ids, flags, conflict)
def bulk_insert_ums(ums: List[UserMessageLite]) -> None: def bulk_insert_ums(ums: List[UserMessageLite]) -> None:
@ -66,7 +75,9 @@ def bulk_insert_ums(ums: List[UserMessageLite]) -> None:
execute_values(cursor.cursor, query, vals) execute_values(cursor.cursor, query, vals)
def bulk_insert_all_ums(user_ids: List[int], message_ids: List[int], flags: int) -> None: def bulk_insert_all_ums(
user_ids: List[int], message_ids: List[int], flags: int, conflict: Optional[Composable] = None
) -> None:
if not user_ids or not message_ids: if not user_ids or not message_ids:
return return
@ -76,9 +87,9 @@ def bulk_insert_all_ums(user_ids: List[int], message_ids: List[int], flags: int)
SELECT user_profile_id, message_id, %s AS flags SELECT user_profile_id, message_id, %s AS flags
FROM UNNEST(%s) user_profile_id FROM UNNEST(%s) user_profile_id
CROSS JOIN UNNEST(%s) message_id CROSS JOIN UNNEST(%s) message_id
ON CONFLICT DO NOTHING ON CONFLICT {conflict}
""" """
) ).format(conflict=conflict if conflict is not None else SQL("DO NOTHING"))
with connection.cursor() as cursor: with connection.cursor() as cursor:
cursor.execute(query, [flags, user_ids, message_ids]) cursor.execute(query, [flags, user_ids, message_ids])

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, List, Set from typing import TYPE_CHECKING, Any, List, Optional, Set
from unittest import mock from unittest import mock
import orjson import orjson
@ -26,6 +26,7 @@ from zerver.lib.message_cache import MessageDict
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.test_helpers import get_subscription, timeout_mock from zerver.lib.test_helpers import get_subscription, timeout_mock
from zerver.lib.timeout import TimeoutExpiredError from zerver.lib.timeout import TimeoutExpiredError
from zerver.lib.user_message import DEFAULT_HISTORICAL_FLAGS, create_historical_user_messages
from zerver.models import ( from zerver.models import (
Message, Message,
Recipient, Recipient,
@ -233,6 +234,102 @@ class UnreadCountTests(ZulipTestCase):
elif msg["id"] == self.unread_msg_ids[1]: elif msg["id"] == self.unread_msg_ids[1]:
check_flags(msg["flags"], set()) check_flags(msg["flags"], set())
def test_update_flags_race(self) -> None:
user = self.example_user("hamlet")
self.login_user(user)
self.unsubscribe(user, "Verona")
first_message_id = self.send_stream_message(
self.example_user("cordelia"),
"Verona",
topic_name="testing",
)
self.assertFalse(
UserMessage.objects.filter(
user_profile_id=user.id, message_id=first_message_id
).exists()
)
# When adjusting flags of messages that we did not receive, we
# create UserMessage rows.
with mock.patch(
"zerver.actions.message_flags.create_historical_user_messages",
wraps=create_historical_user_messages,
) as mock_backfill:
result = self.client_post(
"/json/messages/flags",
{
"messages": orjson.dumps([first_message_id]).decode(),
"op": "add",
"flag": "starred",
},
)
self.assert_json_success(result)
mock_backfill.assert_called_once_with(
user_id=user.id,
message_ids=[first_message_id],
flagattr=UserMessage.flags.starred,
flag_target=UserMessage.flags.starred,
)
um_row = UserMessage.objects.get(user_profile_id=user.id, message_id=first_message_id)
self.assertEqual(
int(um_row.flags),
UserMessage.flags.historical | UserMessage.flags.read | UserMessage.flags.starred,
)
# That creation may race with other things which also create
# the UserMessage rows (e.g. reactions); ensure the end result
# is correct still.
def race_creation(
*,
user_id: int,
message_ids: List[int],
flagattr: Optional[int] = None,
flag_target: Optional[int] = None,
) -> None:
UserMessage.objects.create(
user_profile_id=user_id, message_id=message_ids[0], flags=DEFAULT_HISTORICAL_FLAGS
)
create_historical_user_messages(
user_id=user_id, message_ids=message_ids, flagattr=flagattr, flag_target=flag_target
)
second_message_id = self.send_stream_message(
self.example_user("cordelia"),
"Verona",
topic_name="testing",
)
self.assertFalse(
UserMessage.objects.filter(
user_profile_id=user.id, message_id=second_message_id
).exists()
)
with mock.patch(
"zerver.actions.message_flags.create_historical_user_messages", wraps=race_creation
) as mock_backfill:
result = self.client_post(
"/json/messages/flags",
{
"messages": orjson.dumps([second_message_id]).decode(),
"op": "add",
"flag": "starred",
},
)
self.assert_json_success(result)
mock_backfill.assert_called_once_with(
user_id=user.id,
message_ids=[second_message_id],
flagattr=UserMessage.flags.starred,
flag_target=UserMessage.flags.starred,
)
um_row = UserMessage.objects.get(user_profile_id=user.id, message_id=second_message_id)
self.assertEqual(
int(um_row.flags),
UserMessage.flags.historical | UserMessage.flags.read | UserMessage.flags.starred,
)
def test_update_flags_for_narrow(self) -> None: def test_update_flags_for_narrow(self) -> None:
user = self.example_user("hamlet") user = self.example_user("hamlet")
self.login_user(user) self.login_user(user)