2024-03-26 15:24:45 +01:00
|
|
|
from typing import List, Optional
|
2022-04-14 23:28:01 +02:00
|
|
|
|
|
|
|
from django.db import connection
|
|
|
|
from psycopg2.extras import execute_values
|
2024-03-26 15:24:45 +01:00
|
|
|
from psycopg2.sql import SQL, Composable, Literal
|
2022-04-14 23:28:01 +02:00
|
|
|
|
|
|
|
from zerver.models import UserMessage
|
|
|
|
|
|
|
|
|
|
|
|
class UserMessageLite:
    """
    Lightweight stand-in for a zerver_usermessage row.

    Instantiating full Django model objects is too slow for bulk work,
    so this plain class covers the simple case of inserting many rows
    into zerver_usermessage at once.
    """

    def __init__(self, user_profile_id: int, message_id: int, flags: int) -> None:
        # One attribute per column written by the raw bulk INSERT.
        self.user_profile_id = user_profile_id
        self.message_id = message_id
        self.flags = flags

    def flags_list(self) -> List[str]:
        """Return the flag names for self.flags, via UserMessage.flags_list_for_flags."""
        flags_to_names = UserMessage.flags_list_for_flags
        return flags_to_names(self.flags)
|
|
|
|
|
|
|
|
|
2024-03-26 15:18:32 +01:00
|
|
|
# Starting flags for a UserMessage row created after the fact for a
# message the user did not receive when it was sent: the `historical`
# bit plus the `read` bit.
DEFAULT_HISTORICAL_FLAGS = UserMessage.flags.read | UserMessage.flags.historical
|
|
|
|
|
|
|
|
|
|
|
|
def create_historical_user_messages(
    *,
    user_id: int,
    message_ids: List[int],
    flagattr: Optional[int] = None,
    flag_target: Optional[int] = None,
) -> None:
    """
    Create "historical" UserMessage rows for user_id on each of message_ids.

    Users can see and interact with messages sent to streams with
    public history for which they do not have a UserMessage because
    they were not a subscriber at the time the message was sent.  In
    order to add emoji reactions or mutate message flags for those
    messages, we create UserMessage objects for those messages; these
    have the special historical flag which keeps track of the fact
    that the user did not receive the message at the time it was sent.

    flagattr / flag_target: optional flag change, used only when both
    are given.  flagattr is the bitmask of flags to change and
    flag_target holds the new values for those bits.  New rows start
    from DEFAULT_HISTORICAL_FLAGS with that change applied.  If a row
    for a (user_id, message_id) pair already exists, the same change is
    applied to that row's current flags, which are otherwise preserved.

    Without a flag change, conflict=None is passed along, and existing
    rows are left untouched.
    """
    if flagattr is not None and flag_target is not None:
        # In ON CONFLICT DO UPDATE, `excluded` is the row we proposed to
        # insert, whose flags already equal the computed `flags` below.
        # Reading excluded.flags would therefore overwrite the existing
        # row's flags with the historical defaults.  Apply the mask
        # change to the existing row's flags (zerver_usermessage.flags)
        # instead.  The spaces after `~` and `|` matter: PostgreSQL would
        # otherwise lex e.g. `~-5` as a single operator.
        conflict: Optional[Composable] = SQL(
            "(user_profile_id, message_id) DO UPDATE SET"
            " flags = (zerver_usermessage.flags & ~ {mask}) | {attr}"
        ).format(mask=Literal(flagattr), attr=Literal(flag_target))
        flags = (DEFAULT_HISTORICAL_FLAGS & ~flagattr) | flag_target
    else:
        conflict = None
        flags = DEFAULT_HISTORICAL_FLAGS
    bulk_insert_all_ums([user_id], message_ids, flags, conflict)
|
2024-03-26 15:18:32 +01:00
|
|
|
|
|
|
|
|
2022-04-14 23:28:01 +02:00
|
|
|
def bulk_insert_ums(ums: List[UserMessageLite]) -> None:
    """
    Insert the given UserMessageLite rows into zerver_usermessage.

    Bypassing Django avoids all ORM overhead: profiling with 1000 users
    measured 0.436 -> 0.027 seconds, roughly a 15x speedup.  Rows whose
    insertion would conflict with an existing row are skipped
    (ON CONFLICT DO NOTHING).  An empty list is a no-op.
    """
    if not ums:
        return

    # execute_values expands the single VALUES %s placeholder into one
    # tuple per row.
    insert_sql = SQL(
        """
        INSERT into
            zerver_usermessage (user_profile_id, message_id, flags)
        VALUES %s
        ON CONFLICT DO NOTHING
        """
    )
    rows = [(row.user_profile_id, row.message_id, row.flags) for row in ums]

    with connection.cursor() as cursor:
        # Hand psycopg2 the raw DB-API cursor wrapped by Django's cursor.
        execute_values(cursor.cursor, insert_sql, rows)
|
2024-03-14 17:44:57 +01:00
|
|
|
|
|
|
|
|
2024-03-26 15:24:45 +01:00
|
|
|
def bulk_insert_all_ums(
    user_ids: List[int], message_ids: List[int], flags: int, conflict: Optional[Composable] = None
) -> None:
    """
    Insert a zerver_usermessage row for every (user_id, message_id) pair.

    The rows come from the cross product of user_ids and message_ids, and
    every row gets the same flags value.  `conflict` is the SQL placed
    after ON CONFLICT; when it is None, conflicting rows are skipped
    with DO NOTHING.  Nothing is executed if either list is empty.
    """
    if not (user_ids and message_ids):
        return

    on_conflict = SQL("DO NOTHING") if conflict is None else conflict
    # The `{conflict}` slot is filled by SQL.format here.  The three `%s`
    # placeholders are left for cursor.execute, which binds flags and the
    # two lists (psycopg2 adapts Python lists to SQL arrays for UNNEST).
    query = SQL(
        """
        INSERT INTO zerver_usermessage (user_profile_id, message_id, flags)
        SELECT user_profile_id, message_id, %s AS flags
        FROM UNNEST(%s) user_profile_id
        CROSS JOIN UNNEST(%s) message_id
        ON CONFLICT {conflict}
        """
    ).format(conflict=on_conflict)

    params = [flags, user_ids, message_ids]
    with connection.cursor() as cursor:
        cursor.execute(query, params)
|