mirror of https://github.com/zulip/zulip.git
user_message: Use INSERT ... ON CONFLICT for historical UM creation.
Rather than use a bulk insert via Django, use the faster `bulk_insert_all_ums` that we already have. This also adds a `ON CONFLICT` clause, to make the insert resilient to race conditions. There are currently two callsites, with different desired `ON CONFLICT` behaviours: - For `notify_reaction_update`, if the `UserMessage` had already been created, we would have done nothing to change it. - For `do_update_message_flags`, we would have ensured a specific bit was (un)set. Extend `create_historical_user_messages` and `bulk_insert_all_ums` to support `ON CONFLICT (...) UPDATE SET flags = ...`.
This commit is contained in:
parent
52e3c8e1b2
commit
7988aad159
|
@ -338,8 +338,9 @@ def do_update_message_flags(
|
|||
|
||||
create_historical_user_messages(
|
||||
user_id=user_profile.id,
|
||||
message_ids=historical_message_ids,
|
||||
flags=(DEFAULT_HISTORICAL_FLAGS & ~flagattr) | flag_target,
|
||||
message_ids=list(historical_message_ids),
|
||||
flagattr=flagattr,
|
||||
flag_target=flag_target,
|
||||
)
|
||||
|
||||
to_update = UserMessage.objects.filter(
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from typing import Iterable, List
|
||||
from typing import List, Optional
|
||||
|
||||
from django.db import connection
|
||||
from psycopg2.extras import execute_values
|
||||
from psycopg2.sql import SQL
|
||||
from psycopg2.sql import SQL, Composable, Literal
|
||||
|
||||
from zerver.models import UserMessage
|
||||
|
||||
|
@ -27,7 +27,11 @@ DEFAULT_HISTORICAL_FLAGS = UserMessage.flags.historical | UserMessage.flags.read
|
|||
|
||||
|
||||
def create_historical_user_messages(
|
||||
*, user_id: int, message_ids: Iterable[int], flags: int = DEFAULT_HISTORICAL_FLAGS
|
||||
*,
|
||||
user_id: int,
|
||||
message_ids: List[int],
|
||||
flagattr: Optional[int] = None,
|
||||
flag_target: Optional[int] = None,
|
||||
) -> None:
|
||||
# Users can see and interact with messages sent to streams with
|
||||
# public history for which they do not have a UserMessage because
|
||||
|
@ -36,10 +40,15 @@ def create_historical_user_messages(
|
|||
# those messages, we create UserMessage objects for those messages;
|
||||
# these have the special historical flag which keeps track of the
|
||||
# fact that the user did not receive the message at the time it was sent.
|
||||
UserMessage.objects.bulk_create(
|
||||
UserMessage(user_profile_id=user_id, message_id=message_id, flags=flags)
|
||||
for message_id in message_ids
|
||||
)
|
||||
if flagattr is not None and flag_target is not None:
|
||||
conflict = SQL(
|
||||
"(user_profile_id, message_id) DO UPDATE SET flags = excluded.flags & ~ {mask} | {attr}"
|
||||
).format(mask=Literal(flagattr), attr=Literal(flag_target))
|
||||
flags = (DEFAULT_HISTORICAL_FLAGS & ~flagattr) | flag_target
|
||||
else:
|
||||
conflict = None
|
||||
flags = DEFAULT_HISTORICAL_FLAGS
|
||||
bulk_insert_all_ums([user_id], message_ids, flags, conflict)
|
||||
|
||||
|
||||
def bulk_insert_ums(ums: List[UserMessageLite]) -> None:
|
||||
|
@ -66,7 +75,9 @@ def bulk_insert_ums(ums: List[UserMessageLite]) -> None:
|
|||
execute_values(cursor.cursor, query, vals)
|
||||
|
||||
|
||||
def bulk_insert_all_ums(user_ids: List[int], message_ids: List[int], flags: int) -> None:
|
||||
def bulk_insert_all_ums(
|
||||
user_ids: List[int], message_ids: List[int], flags: int, conflict: Optional[Composable] = None
|
||||
) -> None:
|
||||
if not user_ids or not message_ids:
|
||||
return
|
||||
|
||||
|
@ -76,9 +87,9 @@ def bulk_insert_all_ums(user_ids: List[int], message_ids: List[int], flags: int)
|
|||
SELECT user_profile_id, message_id, %s AS flags
|
||||
FROM UNNEST(%s) user_profile_id
|
||||
CROSS JOIN UNNEST(%s) message_id
|
||||
ON CONFLICT DO NOTHING
|
||||
ON CONFLICT {conflict}
|
||||
"""
|
||||
)
|
||||
).format(conflict=conflict if conflict is not None else SQL("DO NOTHING"))
|
||||
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(query, [flags, user_ids, message_ids])
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import TYPE_CHECKING, Any, List, Set
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Set
|
||||
from unittest import mock
|
||||
|
||||
import orjson
|
||||
|
@ -26,6 +26,7 @@ from zerver.lib.message_cache import MessageDict
|
|||
from zerver.lib.test_classes import ZulipTestCase
|
||||
from zerver.lib.test_helpers import get_subscription, timeout_mock
|
||||
from zerver.lib.timeout import TimeoutExpiredError
|
||||
from zerver.lib.user_message import DEFAULT_HISTORICAL_FLAGS, create_historical_user_messages
|
||||
from zerver.models import (
|
||||
Message,
|
||||
Recipient,
|
||||
|
@ -233,6 +234,102 @@ class UnreadCountTests(ZulipTestCase):
|
|||
elif msg["id"] == self.unread_msg_ids[1]:
|
||||
check_flags(msg["flags"], set())
|
||||
|
||||
def test_update_flags_race(self) -> None:
|
||||
user = self.example_user("hamlet")
|
||||
self.login_user(user)
|
||||
self.unsubscribe(user, "Verona")
|
||||
|
||||
first_message_id = self.send_stream_message(
|
||||
self.example_user("cordelia"),
|
||||
"Verona",
|
||||
topic_name="testing",
|
||||
)
|
||||
self.assertFalse(
|
||||
UserMessage.objects.filter(
|
||||
user_profile_id=user.id, message_id=first_message_id
|
||||
).exists()
|
||||
)
|
||||
# When adjusting flags of messages that we did not receive, we
|
||||
# create UserMessage rows.
|
||||
with mock.patch(
|
||||
"zerver.actions.message_flags.create_historical_user_messages",
|
||||
wraps=create_historical_user_messages,
|
||||
) as mock_backfill:
|
||||
result = self.client_post(
|
||||
"/json/messages/flags",
|
||||
{
|
||||
"messages": orjson.dumps([first_message_id]).decode(),
|
||||
"op": "add",
|
||||
"flag": "starred",
|
||||
},
|
||||
)
|
||||
self.assert_json_success(result)
|
||||
|
||||
mock_backfill.assert_called_once_with(
|
||||
user_id=user.id,
|
||||
message_ids=[first_message_id],
|
||||
flagattr=UserMessage.flags.starred,
|
||||
flag_target=UserMessage.flags.starred,
|
||||
)
|
||||
um_row = UserMessage.objects.get(user_profile_id=user.id, message_id=first_message_id)
|
||||
self.assertEqual(
|
||||
int(um_row.flags),
|
||||
UserMessage.flags.historical | UserMessage.flags.read | UserMessage.flags.starred,
|
||||
)
|
||||
|
||||
# That creation may race with other things which also create
|
||||
# the UserMessage rows (e.g. reactions); ensure the end result
|
||||
# is correct still.
|
||||
def race_creation(
|
||||
*,
|
||||
user_id: int,
|
||||
message_ids: List[int],
|
||||
flagattr: Optional[int] = None,
|
||||
flag_target: Optional[int] = None,
|
||||
) -> None:
|
||||
UserMessage.objects.create(
|
||||
user_profile_id=user_id, message_id=message_ids[0], flags=DEFAULT_HISTORICAL_FLAGS
|
||||
)
|
||||
create_historical_user_messages(
|
||||
user_id=user_id, message_ids=message_ids, flagattr=flagattr, flag_target=flag_target
|
||||
)
|
||||
|
||||
second_message_id = self.send_stream_message(
|
||||
self.example_user("cordelia"),
|
||||
"Verona",
|
||||
topic_name="testing",
|
||||
)
|
||||
self.assertFalse(
|
||||
UserMessage.objects.filter(
|
||||
user_profile_id=user.id, message_id=second_message_id
|
||||
).exists()
|
||||
)
|
||||
with mock.patch(
|
||||
"zerver.actions.message_flags.create_historical_user_messages", wraps=race_creation
|
||||
) as mock_backfill:
|
||||
result = self.client_post(
|
||||
"/json/messages/flags",
|
||||
{
|
||||
"messages": orjson.dumps([second_message_id]).decode(),
|
||||
"op": "add",
|
||||
"flag": "starred",
|
||||
},
|
||||
)
|
||||
self.assert_json_success(result)
|
||||
|
||||
mock_backfill.assert_called_once_with(
|
||||
user_id=user.id,
|
||||
message_ids=[second_message_id],
|
||||
flagattr=UserMessage.flags.starred,
|
||||
flag_target=UserMessage.flags.starred,
|
||||
)
|
||||
|
||||
um_row = UserMessage.objects.get(user_profile_id=user.id, message_id=second_message_id)
|
||||
self.assertEqual(
|
||||
int(um_row.flags),
|
||||
UserMessage.flags.historical | UserMessage.flags.read | UserMessage.flags.starred,
|
||||
)
|
||||
|
||||
def test_update_flags_for_narrow(self) -> None:
|
||||
user = self.example_user("hamlet")
|
||||
self.login_user(user)
|
||||
|
|
Loading…
Reference in New Issue