user_message: Use INSERT ... ON CONFLICT for historical UM creation.

Rather than use a bulk insert via Django, use the faster
`bulk_insert_all_ums` that we already have.  This also adds a `ON
CONFLICT` clause, to make the insert resilient to race conditions.

There are currently two callsites, with different desired `ON
CONFLICT` behaviours:
 - For `notify_reaction_update`, if the `UserMessage` had already been
   created, we would have done nothing to change it.
 - For `do_update_message_flags`, we would have ensured a specific bit
   was (un)set.

Extend `create_historical_user_messages` and `bulk_insert_all_ums` to
support `ON CONFLICT (...) UPDATE SET flags = ...`.
This commit is contained in:
Alex Vandiver 2024-03-26 14:24:45 +00:00 committed by Tim Abbott
parent 52e3c8e1b2
commit 7988aad159
3 changed files with 122 additions and 13 deletions

View File

@ -338,8 +338,9 @@ def do_update_message_flags(
create_historical_user_messages(
user_id=user_profile.id,
message_ids=historical_message_ids,
flags=(DEFAULT_HISTORICAL_FLAGS & ~flagattr) | flag_target,
message_ids=list(historical_message_ids),
flagattr=flagattr,
flag_target=flag_target,
)
to_update = UserMessage.objects.filter(

View File

@ -1,8 +1,8 @@
from typing import Iterable, List
from typing import List, Optional
from django.db import connection
from psycopg2.extras import execute_values
from psycopg2.sql import SQL
from psycopg2.sql import SQL, Composable, Literal
from zerver.models import UserMessage
@ -27,7 +27,11 @@ DEFAULT_HISTORICAL_FLAGS = UserMessage.flags.historical | UserMessage.flags.read
def create_historical_user_messages(
*, user_id: int, message_ids: Iterable[int], flags: int = DEFAULT_HISTORICAL_FLAGS
*,
user_id: int,
message_ids: List[int],
flagattr: Optional[int] = None,
flag_target: Optional[int] = None,
) -> None:
# Users can see and interact with messages sent to streams with
# public history for which they do not have a UserMessage because
@ -36,10 +40,15 @@ def create_historical_user_messages(
# those messages, we create UserMessage objects for those messages;
# these have the special historical flag which keeps track of the
# fact that the user did not receive the message at the time it was sent.
UserMessage.objects.bulk_create(
UserMessage(user_profile_id=user_id, message_id=message_id, flags=flags)
for message_id in message_ids
)
if flagattr is not None and flag_target is not None:
conflict = SQL(
"(user_profile_id, message_id) DO UPDATE SET flags = excluded.flags & ~ {mask} | {attr}"
).format(mask=Literal(flagattr), attr=Literal(flag_target))
flags = (DEFAULT_HISTORICAL_FLAGS & ~flagattr) | flag_target
else:
conflict = None
flags = DEFAULT_HISTORICAL_FLAGS
bulk_insert_all_ums([user_id], message_ids, flags, conflict)
def bulk_insert_ums(ums: List[UserMessageLite]) -> None:
@ -66,7 +75,9 @@ def bulk_insert_ums(ums: List[UserMessageLite]) -> None:
execute_values(cursor.cursor, query, vals)
def bulk_insert_all_ums(user_ids: List[int], message_ids: List[int], flags: int) -> None:
def bulk_insert_all_ums(
user_ids: List[int], message_ids: List[int], flags: int, conflict: Optional[Composable] = None
) -> None:
if not user_ids or not message_ids:
return
@ -76,9 +87,9 @@ def bulk_insert_all_ums(user_ids: List[int], message_ids: List[int], flags: int)
SELECT user_profile_id, message_id, %s AS flags
FROM UNNEST(%s) user_profile_id
CROSS JOIN UNNEST(%s) message_id
ON CONFLICT DO NOTHING
ON CONFLICT {conflict}
"""
)
).format(conflict=conflict if conflict is not None else SQL("DO NOTHING"))
with connection.cursor() as cursor:
cursor.execute(query, [flags, user_ids, message_ids])

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, List, Set
from typing import TYPE_CHECKING, Any, List, Optional, Set
from unittest import mock
import orjson
@ -26,6 +26,7 @@ from zerver.lib.message_cache import MessageDict
from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.test_helpers import get_subscription, timeout_mock
from zerver.lib.timeout import TimeoutExpiredError
from zerver.lib.user_message import DEFAULT_HISTORICAL_FLAGS, create_historical_user_messages
from zerver.models import (
Message,
Recipient,
@ -233,6 +234,102 @@ class UnreadCountTests(ZulipTestCase):
elif msg["id"] == self.unread_msg_ids[1]:
check_flags(msg["flags"], set())
def test_update_flags_race(self) -> None:
user = self.example_user("hamlet")
self.login_user(user)
self.unsubscribe(user, "Verona")
first_message_id = self.send_stream_message(
self.example_user("cordelia"),
"Verona",
topic_name="testing",
)
self.assertFalse(
UserMessage.objects.filter(
user_profile_id=user.id, message_id=first_message_id
).exists()
)
# When adjusting flags of messages that we did not receive, we
# create UserMessage rows.
with mock.patch(
"zerver.actions.message_flags.create_historical_user_messages",
wraps=create_historical_user_messages,
) as mock_backfill:
result = self.client_post(
"/json/messages/flags",
{
"messages": orjson.dumps([first_message_id]).decode(),
"op": "add",
"flag": "starred",
},
)
self.assert_json_success(result)
mock_backfill.assert_called_once_with(
user_id=user.id,
message_ids=[first_message_id],
flagattr=UserMessage.flags.starred,
flag_target=UserMessage.flags.starred,
)
um_row = UserMessage.objects.get(user_profile_id=user.id, message_id=first_message_id)
self.assertEqual(
int(um_row.flags),
UserMessage.flags.historical | UserMessage.flags.read | UserMessage.flags.starred,
)
# That creation may race with other things which also create
# the UserMessage rows (e.g. reactions); ensure the end result
# is correct still.
def race_creation(
*,
user_id: int,
message_ids: List[int],
flagattr: Optional[int] = None,
flag_target: Optional[int] = None,
) -> None:
UserMessage.objects.create(
user_profile_id=user_id, message_id=message_ids[0], flags=DEFAULT_HISTORICAL_FLAGS
)
create_historical_user_messages(
user_id=user_id, message_ids=message_ids, flagattr=flagattr, flag_target=flag_target
)
second_message_id = self.send_stream_message(
self.example_user("cordelia"),
"Verona",
topic_name="testing",
)
self.assertFalse(
UserMessage.objects.filter(
user_profile_id=user.id, message_id=second_message_id
).exists()
)
with mock.patch(
"zerver.actions.message_flags.create_historical_user_messages", wraps=race_creation
) as mock_backfill:
result = self.client_post(
"/json/messages/flags",
{
"messages": orjson.dumps([second_message_id]).decode(),
"op": "add",
"flag": "starred",
},
)
self.assert_json_success(result)
mock_backfill.assert_called_once_with(
user_id=user.id,
message_ids=[second_message_id],
flagattr=UserMessage.flags.starred,
flag_target=UserMessage.flags.starred,
)
um_row = UserMessage.objects.get(user_profile_id=user.id, message_id=second_message_id)
self.assertEqual(
int(um_row.flags),
UserMessage.flags.historical | UserMessage.flags.read | UserMessage.flags.starred,
)
def test_update_flags_for_narrow(self) -> None:
user = self.example_user("hamlet")
self.login_user(user)