mirror of https://github.com/zulip/zulip.git
user_message: Use INSERT ... ON CONFLICT for historical UM creation.
Rather than use a bulk insert via Django, use the faster `bulk_insert_all_ums` that we already have. This also adds a `ON CONFLICT` clause, to make the insert resilient to race conditions. There are currently two callsites, with different desired `ON CONFLICT` behaviours: - For `notify_reaction_update`, if the `UserMessage` had already been created, we would have done nothing to change it. - For `do_update_message_flags`, we would have ensured a specific bit was (un)set. Extend `create_historical_user_messages` and `bulk_insert_all_ums` to support `ON CONFLICT (...) UPDATE SET flags = ...`.
This commit is contained in:
parent
52e3c8e1b2
commit
7988aad159
|
@ -338,8 +338,9 @@ def do_update_message_flags(
|
||||||
|
|
||||||
create_historical_user_messages(
|
create_historical_user_messages(
|
||||||
user_id=user_profile.id,
|
user_id=user_profile.id,
|
||||||
message_ids=historical_message_ids,
|
message_ids=list(historical_message_ids),
|
||||||
flags=(DEFAULT_HISTORICAL_FLAGS & ~flagattr) | flag_target,
|
flagattr=flagattr,
|
||||||
|
flag_target=flag_target,
|
||||||
)
|
)
|
||||||
|
|
||||||
to_update = UserMessage.objects.filter(
|
to_update = UserMessage.objects.filter(
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
from typing import Iterable, List
|
from typing import List, Optional
|
||||||
|
|
||||||
from django.db import connection
|
from django.db import connection
|
||||||
from psycopg2.extras import execute_values
|
from psycopg2.extras import execute_values
|
||||||
from psycopg2.sql import SQL
|
from psycopg2.sql import SQL, Composable, Literal
|
||||||
|
|
||||||
from zerver.models import UserMessage
|
from zerver.models import UserMessage
|
||||||
|
|
||||||
|
@ -27,7 +27,11 @@ DEFAULT_HISTORICAL_FLAGS = UserMessage.flags.historical | UserMessage.flags.read
|
||||||
|
|
||||||
|
|
||||||
def create_historical_user_messages(
|
def create_historical_user_messages(
|
||||||
*, user_id: int, message_ids: Iterable[int], flags: int = DEFAULT_HISTORICAL_FLAGS
|
*,
|
||||||
|
user_id: int,
|
||||||
|
message_ids: List[int],
|
||||||
|
flagattr: Optional[int] = None,
|
||||||
|
flag_target: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Users can see and interact with messages sent to streams with
|
# Users can see and interact with messages sent to streams with
|
||||||
# public history for which they do not have a UserMessage because
|
# public history for which they do not have a UserMessage because
|
||||||
|
@ -36,10 +40,15 @@ def create_historical_user_messages(
|
||||||
# those messages, we create UserMessage objects for those messages;
|
# those messages, we create UserMessage objects for those messages;
|
||||||
# these have the special historical flag which keeps track of the
|
# these have the special historical flag which keeps track of the
|
||||||
# fact that the user did not receive the message at the time it was sent.
|
# fact that the user did not receive the message at the time it was sent.
|
||||||
UserMessage.objects.bulk_create(
|
if flagattr is not None and flag_target is not None:
|
||||||
UserMessage(user_profile_id=user_id, message_id=message_id, flags=flags)
|
conflict = SQL(
|
||||||
for message_id in message_ids
|
"(user_profile_id, message_id) DO UPDATE SET flags = excluded.flags & ~ {mask} | {attr}"
|
||||||
)
|
).format(mask=Literal(flagattr), attr=Literal(flag_target))
|
||||||
|
flags = (DEFAULT_HISTORICAL_FLAGS & ~flagattr) | flag_target
|
||||||
|
else:
|
||||||
|
conflict = None
|
||||||
|
flags = DEFAULT_HISTORICAL_FLAGS
|
||||||
|
bulk_insert_all_ums([user_id], message_ids, flags, conflict)
|
||||||
|
|
||||||
|
|
||||||
def bulk_insert_ums(ums: List[UserMessageLite]) -> None:
|
def bulk_insert_ums(ums: List[UserMessageLite]) -> None:
|
||||||
|
@ -66,7 +75,9 @@ def bulk_insert_ums(ums: List[UserMessageLite]) -> None:
|
||||||
execute_values(cursor.cursor, query, vals)
|
execute_values(cursor.cursor, query, vals)
|
||||||
|
|
||||||
|
|
||||||
def bulk_insert_all_ums(user_ids: List[int], message_ids: List[int], flags: int) -> None:
|
def bulk_insert_all_ums(
|
||||||
|
user_ids: List[int], message_ids: List[int], flags: int, conflict: Optional[Composable] = None
|
||||||
|
) -> None:
|
||||||
if not user_ids or not message_ids:
|
if not user_ids or not message_ids:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -76,9 +87,9 @@ def bulk_insert_all_ums(user_ids: List[int], message_ids: List[int], flags: int)
|
||||||
SELECT user_profile_id, message_id, %s AS flags
|
SELECT user_profile_id, message_id, %s AS flags
|
||||||
FROM UNNEST(%s) user_profile_id
|
FROM UNNEST(%s) user_profile_id
|
||||||
CROSS JOIN UNNEST(%s) message_id
|
CROSS JOIN UNNEST(%s) message_id
|
||||||
ON CONFLICT DO NOTHING
|
ON CONFLICT {conflict}
|
||||||
"""
|
"""
|
||||||
)
|
).format(conflict=conflict if conflict is not None else SQL("DO NOTHING"))
|
||||||
|
|
||||||
with connection.cursor() as cursor:
|
with connection.cursor() as cursor:
|
||||||
cursor.execute(query, [flags, user_ids, message_ids])
|
cursor.execute(query, [flags, user_ids, message_ids])
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import TYPE_CHECKING, Any, List, Set
|
from typing import TYPE_CHECKING, Any, List, Optional, Set
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
|
@ -26,6 +26,7 @@ from zerver.lib.message_cache import MessageDict
|
||||||
from zerver.lib.test_classes import ZulipTestCase
|
from zerver.lib.test_classes import ZulipTestCase
|
||||||
from zerver.lib.test_helpers import get_subscription, timeout_mock
|
from zerver.lib.test_helpers import get_subscription, timeout_mock
|
||||||
from zerver.lib.timeout import TimeoutExpiredError
|
from zerver.lib.timeout import TimeoutExpiredError
|
||||||
|
from zerver.lib.user_message import DEFAULT_HISTORICAL_FLAGS, create_historical_user_messages
|
||||||
from zerver.models import (
|
from zerver.models import (
|
||||||
Message,
|
Message,
|
||||||
Recipient,
|
Recipient,
|
||||||
|
@ -233,6 +234,102 @@ class UnreadCountTests(ZulipTestCase):
|
||||||
elif msg["id"] == self.unread_msg_ids[1]:
|
elif msg["id"] == self.unread_msg_ids[1]:
|
||||||
check_flags(msg["flags"], set())
|
check_flags(msg["flags"], set())
|
||||||
|
|
||||||
|
def test_update_flags_race(self) -> None:
|
||||||
|
user = self.example_user("hamlet")
|
||||||
|
self.login_user(user)
|
||||||
|
self.unsubscribe(user, "Verona")
|
||||||
|
|
||||||
|
first_message_id = self.send_stream_message(
|
||||||
|
self.example_user("cordelia"),
|
||||||
|
"Verona",
|
||||||
|
topic_name="testing",
|
||||||
|
)
|
||||||
|
self.assertFalse(
|
||||||
|
UserMessage.objects.filter(
|
||||||
|
user_profile_id=user.id, message_id=first_message_id
|
||||||
|
).exists()
|
||||||
|
)
|
||||||
|
# When adjusting flags of messages that we did not receive, we
|
||||||
|
# create UserMessage rows.
|
||||||
|
with mock.patch(
|
||||||
|
"zerver.actions.message_flags.create_historical_user_messages",
|
||||||
|
wraps=create_historical_user_messages,
|
||||||
|
) as mock_backfill:
|
||||||
|
result = self.client_post(
|
||||||
|
"/json/messages/flags",
|
||||||
|
{
|
||||||
|
"messages": orjson.dumps([first_message_id]).decode(),
|
||||||
|
"op": "add",
|
||||||
|
"flag": "starred",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assert_json_success(result)
|
||||||
|
|
||||||
|
mock_backfill.assert_called_once_with(
|
||||||
|
user_id=user.id,
|
||||||
|
message_ids=[first_message_id],
|
||||||
|
flagattr=UserMessage.flags.starred,
|
||||||
|
flag_target=UserMessage.flags.starred,
|
||||||
|
)
|
||||||
|
um_row = UserMessage.objects.get(user_profile_id=user.id, message_id=first_message_id)
|
||||||
|
self.assertEqual(
|
||||||
|
int(um_row.flags),
|
||||||
|
UserMessage.flags.historical | UserMessage.flags.read | UserMessage.flags.starred,
|
||||||
|
)
|
||||||
|
|
||||||
|
# That creation may race with other things which also create
|
||||||
|
# the UserMessage rows (e.g. reactions); ensure the end result
|
||||||
|
# is correct still.
|
||||||
|
def race_creation(
|
||||||
|
*,
|
||||||
|
user_id: int,
|
||||||
|
message_ids: List[int],
|
||||||
|
flagattr: Optional[int] = None,
|
||||||
|
flag_target: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
UserMessage.objects.create(
|
||||||
|
user_profile_id=user_id, message_id=message_ids[0], flags=DEFAULT_HISTORICAL_FLAGS
|
||||||
|
)
|
||||||
|
create_historical_user_messages(
|
||||||
|
user_id=user_id, message_ids=message_ids, flagattr=flagattr, flag_target=flag_target
|
||||||
|
)
|
||||||
|
|
||||||
|
second_message_id = self.send_stream_message(
|
||||||
|
self.example_user("cordelia"),
|
||||||
|
"Verona",
|
||||||
|
topic_name="testing",
|
||||||
|
)
|
||||||
|
self.assertFalse(
|
||||||
|
UserMessage.objects.filter(
|
||||||
|
user_profile_id=user.id, message_id=second_message_id
|
||||||
|
).exists()
|
||||||
|
)
|
||||||
|
with mock.patch(
|
||||||
|
"zerver.actions.message_flags.create_historical_user_messages", wraps=race_creation
|
||||||
|
) as mock_backfill:
|
||||||
|
result = self.client_post(
|
||||||
|
"/json/messages/flags",
|
||||||
|
{
|
||||||
|
"messages": orjson.dumps([second_message_id]).decode(),
|
||||||
|
"op": "add",
|
||||||
|
"flag": "starred",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assert_json_success(result)
|
||||||
|
|
||||||
|
mock_backfill.assert_called_once_with(
|
||||||
|
user_id=user.id,
|
||||||
|
message_ids=[second_message_id],
|
||||||
|
flagattr=UserMessage.flags.starred,
|
||||||
|
flag_target=UserMessage.flags.starred,
|
||||||
|
)
|
||||||
|
|
||||||
|
um_row = UserMessage.objects.get(user_profile_id=user.id, message_id=second_message_id)
|
||||||
|
self.assertEqual(
|
||||||
|
int(um_row.flags),
|
||||||
|
UserMessage.flags.historical | UserMessage.flags.read | UserMessage.flags.starred,
|
||||||
|
)
|
||||||
|
|
||||||
def test_update_flags_for_narrow(self) -> None:
|
def test_update_flags_for_narrow(self) -> None:
|
||||||
user = self.example_user("hamlet")
|
user = self.example_user("hamlet")
|
||||||
self.login_user(user)
|
self.login_user(user)
|
||||||
|
|
Loading…
Reference in New Issue