mirror of https://github.com/zulip/zulip.git
populate_db: Simplify how we create reactions.
For 3000 messages and 400 users, this saved about 30 seconds. We only do two queries per batch of messages now, and the algorithm is easier to analyze, as it's just three nested loops.
This commit is contained in:
parent
5eb63ddb7a
commit
99e725cbde
|
@ -1,23 +1,11 @@
|
|||
import random
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
from django.db.models import Model
|
||||
|
||||
from zerver.lib.create_user import create_user_profile, get_display_email_address
|
||||
from zerver.lib.initial_password import initial_password
|
||||
from zerver.lib.streams import render_stream_description
|
||||
from zerver.models import (
|
||||
Message,
|
||||
Reaction,
|
||||
Realm,
|
||||
RealmAuditLog,
|
||||
Recipient,
|
||||
Stream,
|
||||
Subscription,
|
||||
UserMessage,
|
||||
UserProfile,
|
||||
)
|
||||
from zerver.models import Realm, RealmAuditLog, Recipient, Stream, Subscription, UserProfile
|
||||
|
||||
|
||||
def bulk_create_users(realm: Realm,
|
||||
|
@ -157,111 +145,3 @@ def bulk_create_streams(realm: Realm,
|
|||
Recipient.objects.bulk_create(recipients_to_create)
|
||||
|
||||
bulk_set_users_or_streams_recipient_fields(Stream, streams_to_create, recipients_to_create)
|
||||
|
||||
DEFAULT_EMOJIS = [
|
||||
('+1', '1f44d'),
|
||||
('smiley', '1f603'),
|
||||
('eyes', '1f440'),
|
||||
('crying_cat_face', '1f63f'),
|
||||
('arrow_up', '2b06'),
|
||||
('confetti_ball', '1f38a'),
|
||||
('hundred_points', '1f4af'),
|
||||
]
|
||||
|
||||
def bulk_create_reactions(
|
||||
messages: Iterable[Message],
|
||||
users: Optional[List[UserProfile]] = None,
|
||||
emojis: Optional[List[Tuple[str, str]]] = None
|
||||
) -> None:
|
||||
messages = list(messages)
|
||||
if not emojis:
|
||||
emojis = DEFAULT_EMOJIS
|
||||
emojis = list(emojis)
|
||||
|
||||
reactions: List[Reaction] = []
|
||||
for message in messages:
|
||||
reactions.extend(_add_random_reactions_to_message(
|
||||
message, emojis, users))
|
||||
Reaction.objects.bulk_create(reactions)
|
||||
|
||||
def _add_random_reactions_to_message(
|
||||
message: Message,
|
||||
emojis: List[Tuple[str, str]],
|
||||
users: Optional[List[UserProfile]] = None,
|
||||
prob_reaction: float = 0.075,
|
||||
prob_upvote: float = 0.5,
|
||||
prob_repeat: float = 0.5) -> List[Reaction]:
|
||||
'''Randomly add emoji reactions to each message from a list.
|
||||
|
||||
Algorithm:
|
||||
|
||||
Give the message at least one reaction with probability `prob_reaction`.
|
||||
Once the first reaction is added, have another user upvote it with probability
|
||||
`prob_upvote`, provided there is another recipient of the message left to upvote.
|
||||
Repeat the process for a different emoji with probability `prob_repeat`.
|
||||
|
||||
If the number of emojis or users is small, there is a chance the above process
|
||||
will produce multiple reactions with the same user and emoji, so group the
|
||||
reactions by emoji code and user profile and then return one reaction from
|
||||
each group.
|
||||
'''
|
||||
for p in (prob_reaction, prob_repeat, prob_upvote):
|
||||
# Prevent p=1 since for prob_repeat and prob_upvote, this will
|
||||
# lead to an infinite loop.
|
||||
if p >= 1 or p < 0:
|
||||
raise ValueError('Probability argument must be between 0 and 1.')
|
||||
|
||||
# Avoid performing database queries if there will be no reactions.
|
||||
compute_next_reaction: bool = random.random() < prob_reaction
|
||||
if not compute_next_reaction:
|
||||
return []
|
||||
|
||||
if users is None:
|
||||
users = []
|
||||
user_ids: Sequence[int] = [user.id for user in users]
|
||||
if not user_ids:
|
||||
user_ids = UserMessage.objects.filter(message=message) \
|
||||
.values_list("user_profile_id", flat=True)
|
||||
if not user_ids:
|
||||
return []
|
||||
|
||||
emojis = list(emojis)
|
||||
|
||||
reactions = []
|
||||
while compute_next_reaction:
|
||||
# We do this O(users) operation only if we've decided to do a
|
||||
# reaction, to avoid performance issues with large numbers of
|
||||
# users.
|
||||
users_available = set(user_ids)
|
||||
|
||||
(emoji_name, emoji_code) = random.choice(emojis)
|
||||
while True:
|
||||
# Handle corner case where all the users have reacted.
|
||||
if not users_available:
|
||||
break
|
||||
|
||||
user_id = random.choice(list(users_available))
|
||||
reactions.append(Reaction(
|
||||
user_profile_id=user_id,
|
||||
message=message,
|
||||
emoji_name=emoji_name,
|
||||
emoji_code=emoji_code,
|
||||
reaction_type=Reaction.UNICODE_EMOJI
|
||||
))
|
||||
users_available.remove(user_id)
|
||||
|
||||
# Add an upvote with the defined probability.
|
||||
if not random.random() < prob_upvote:
|
||||
break
|
||||
|
||||
# Repeat with a possibly different random emoji with the
|
||||
# defined probability.
|
||||
compute_next_reaction = random.random() < prob_repeat
|
||||
|
||||
# Avoid returning duplicate reactions by deduplicating on
|
||||
# (user_profile_id, emoji_code).
|
||||
grouped_reactions = defaultdict(list)
|
||||
for reaction in reactions:
|
||||
k = (str(reaction.user_profile_id), str(reaction.emoji_code))
|
||||
grouped_reactions[k].append(reaction)
|
||||
return [reactions[0] for reactions in grouped_reactions.values()]
|
||||
|
|
|
@ -1,293 +0,0 @@
|
|||
import itertools
|
||||
import random
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from django.test import TestCase
|
||||
from django.utils.timezone import now as timezone_now
|
||||
|
||||
from zerver.lib.bulk_create import (
|
||||
DEFAULT_EMOJIS,
|
||||
_add_random_reactions_to_message,
|
||||
bulk_create_reactions,
|
||||
)
|
||||
from zerver.models import (
|
||||
Client,
|
||||
Huddle,
|
||||
Message,
|
||||
Realm,
|
||||
Recipient,
|
||||
Stream,
|
||||
Subscription,
|
||||
UserMessage,
|
||||
UserProfile,
|
||||
)
|
||||
|
||||
|
||||
class TestBulkCreateReactions(TestCase):
|
||||
"""This test class is somewhat low value and uses extensive mocking of
|
||||
random; it's possible we should delete it rather than doing a
|
||||
great deal of work to preserve it; this test mostly exists to
|
||||
achieve coverage goals."""
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
random.seed(42)
|
||||
self.realm = Realm.objects.create(
|
||||
name="test_realm",
|
||||
string_id="test_realm"
|
||||
)
|
||||
self.message_client = Client.objects.create(
|
||||
name='test_client'
|
||||
)
|
||||
self.alice = UserProfile.objects.create(
|
||||
delivery_email='alice@gmail.com',
|
||||
email='alice@gmail.com',
|
||||
realm=self.realm,
|
||||
full_name='Alice'
|
||||
)
|
||||
self.bob = UserProfile.objects.create(
|
||||
delivery_email='bob@gmail.com',
|
||||
email='bob@gmail.com',
|
||||
realm=self.realm,
|
||||
full_name='Bob'
|
||||
)
|
||||
self.charlie = UserProfile.objects.create(
|
||||
delivery_email='charlie@gmail.com',
|
||||
email='charlie@gmail.com',
|
||||
realm=self.realm,
|
||||
full_name='Charlie'
|
||||
)
|
||||
|
||||
self.users = [self.alice, self.bob, self.charlie]
|
||||
type_ids = Recipient \
|
||||
.objects.filter(type=Recipient.PERSONAL).values_list('type_id')
|
||||
max_type_id = max(x[0] for x in type_ids)
|
||||
self.recipients = []
|
||||
for i, user in enumerate(self.users):
|
||||
recipient = Recipient.objects.create(
|
||||
type=Recipient.PERSONAL,
|
||||
type_id=max_type_id + i + 1
|
||||
)
|
||||
user.recipient = recipient
|
||||
user.save()
|
||||
self.recipients.append(recipient)
|
||||
self.personal_message = Message.objects.create(
|
||||
sender=self.alice,
|
||||
recipient=self.bob.recipient,
|
||||
content='It is I, Alice.',
|
||||
sending_client=self.message_client,
|
||||
date_sent=timezone_now()
|
||||
)
|
||||
|
||||
self.stream = Stream.objects.create(
|
||||
name="test_stream",
|
||||
realm=self.realm,
|
||||
)
|
||||
self.stream.recipient = Recipient.objects.create(
|
||||
type=Recipient.STREAM,
|
||||
type_id=1 + max(
|
||||
x[0] for x in Recipient.objects.filter(type=Recipient.STREAM).values_list('type_id'))
|
||||
)
|
||||
self.stream.save()
|
||||
for user in self.users:
|
||||
Subscription.objects.create(
|
||||
user_profile=user,
|
||||
recipient=self.stream.recipient
|
||||
)
|
||||
self.stream_message = Message.objects.create(
|
||||
sender=self.alice,
|
||||
recipient=self.stream.recipient,
|
||||
content='This is Alice.',
|
||||
sending_client=self.message_client,
|
||||
date_sent=timezone_now()
|
||||
)
|
||||
|
||||
self.huddle = Huddle.objects.create(
|
||||
huddle_hash="bad-hash",
|
||||
)
|
||||
self.huddle.recipient = Recipient.objects.create(
|
||||
type=Recipient.HUDDLE,
|
||||
type_id=1 + max(
|
||||
itertools.chain(
|
||||
(x[0] for x in Recipient.objects.filter(type=Recipient.HUDDLE).values_list('type_id')),
|
||||
[0])))
|
||||
self.huddle.save()
|
||||
for user in self.users:
|
||||
Subscription.objects.create(
|
||||
user_profile=user,
|
||||
recipient=self.huddle.recipient
|
||||
)
|
||||
self.huddle_message = Message.objects.create(
|
||||
sender=self.alice,
|
||||
recipient=self.huddle.recipient,
|
||||
content='Alice my name is.',
|
||||
sending_client=self.message_client,
|
||||
date_sent=timezone_now()
|
||||
)
|
||||
|
||||
def test_invalid_probabilities(self) -> None:
|
||||
message = self.personal_message
|
||||
emojis = DEFAULT_EMOJIS
|
||||
users = self.users
|
||||
prob_keys = ['prob_reaction', 'prob_upvote', 'prob_repeat']
|
||||
for probs in [
|
||||
(1, .5, .5),
|
||||
(.5, 1, .5),
|
||||
(.5, .5, 1),
|
||||
(-0.01, .5, .5),
|
||||
(.5, -.01, .5),
|
||||
(.5, .5, -.01),
|
||||
]:
|
||||
kwargs = dict(zip(prob_keys, probs))
|
||||
with self.assertRaises(ValueError):
|
||||
_add_random_reactions_to_message(message, emojis, users, **kwargs)
|
||||
|
||||
@patch('zerver.lib.bulk_create.random')
|
||||
@patch('zerver.lib.bulk_create.UserProfile')
|
||||
@patch('zerver.lib.bulk_create.Subscription')
|
||||
def test_early_exit_if_no_reactions(
|
||||
self,
|
||||
MockSubscription: MagicMock,
|
||||
MockUserProfile: MagicMock,
|
||||
mock_random: MagicMock) -> None:
|
||||
message = self.personal_message
|
||||
emojis = DEFAULT_EMOJIS
|
||||
users = None
|
||||
mock_random.random.return_value = 1
|
||||
reactions = _add_random_reactions_to_message(message, emojis, users)
|
||||
self.assertEqual(reactions, [])
|
||||
self.assertFalse(MockUserProfile.objects.get.called)
|
||||
self.assertFalse(MockSubscription.objects.filter.called)
|
||||
|
||||
@patch('zerver.lib.bulk_create.random')
|
||||
@patch('zerver.lib.bulk_create.UserMessage')
|
||||
def test_query_for_personal_message_users(
|
||||
self,
|
||||
MockUserProfile: MagicMock,
|
||||
mock_random: MagicMock) -> None:
|
||||
message = self.personal_message
|
||||
emojis = DEFAULT_EMOJIS
|
||||
users = None
|
||||
mock_random.choice = random.choice
|
||||
mock_random.random.side_effect = [0, 1, 1, 1, 1, 1]
|
||||
_add_random_reactions_to_message(message, emojis, users)
|
||||
self.assertTrue(MockUserProfile.objects.filter.called)
|
||||
|
||||
@patch('zerver.lib.bulk_create.random')
|
||||
@patch('zerver.lib.bulk_create.UserMessage')
|
||||
def test_query_for_stream_message_users(
|
||||
self,
|
||||
MockUserMessage: MagicMock,
|
||||
mock_random: MagicMock) -> None:
|
||||
message = self.stream_message
|
||||
emojis = DEFAULT_EMOJIS
|
||||
users = None
|
||||
mock_random.choice = random.choice
|
||||
mock_random.random.side_effect = [0, 1, 1, 1, 1, 1]
|
||||
_add_random_reactions_to_message(message, emojis, users)
|
||||
self.assertTrue(MockUserMessage.objects.filter.called)
|
||||
|
||||
@patch('zerver.lib.bulk_create.random')
|
||||
@patch('zerver.lib.bulk_create.UserMessage')
|
||||
def test_query_for_huddle_message_users(
|
||||
self,
|
||||
MockUserMessage: MagicMock,
|
||||
mock_random: MagicMock) -> None:
|
||||
message = self.huddle_message
|
||||
emojis = DEFAULT_EMOJIS
|
||||
users = None
|
||||
mock_random.choice = random.choice
|
||||
mock_random.random.side_effect = [0, 1, 1, 1, 1, 1]
|
||||
_add_random_reactions_to_message(message, emojis, users)
|
||||
self.assertTrue(MockUserMessage.objects.filter.called)
|
||||
|
||||
@patch('zerver.lib.bulk_create.random')
|
||||
@patch('zerver.lib.bulk_create.UserMessage')
|
||||
def test_early_exit_if_no_users(
|
||||
self,
|
||||
MockUserMessage: MagicMock,
|
||||
mock_random: MagicMock) -> None:
|
||||
message = self.stream_message
|
||||
emojis = DEFAULT_EMOJIS
|
||||
users = None
|
||||
mock_random.choice = random.choice
|
||||
mock_random.random.side_effect = [0, 1, 1, 1, 1, 1]
|
||||
MockUserMessage.objects.filter.return_value = UserMessage.objects.none()
|
||||
reactions = _add_random_reactions_to_message(message, emojis, users)
|
||||
self.assertTrue(MockUserMessage.objects.filter.called)
|
||||
self.assertEqual(reactions, [])
|
||||
|
||||
@patch('zerver.lib.bulk_create.random')
|
||||
def test_single_reaction(
|
||||
self,
|
||||
mock_random: MagicMock) -> None:
|
||||
message = self.stream_message
|
||||
emojis = DEFAULT_EMOJIS
|
||||
users = self.users
|
||||
mock_random.choice = random.choice
|
||||
mock_random.random.side_effect = [0, 1, 1]
|
||||
reactions = _add_random_reactions_to_message(message, emojis, users)
|
||||
self.assertEqual(len(reactions), 1)
|
||||
|
||||
@patch('zerver.lib.bulk_create.random')
|
||||
def test_single_reaction_with_upvote(
|
||||
self,
|
||||
mock_random: MagicMock) -> None:
|
||||
message = self.stream_message
|
||||
emojis = DEFAULT_EMOJIS
|
||||
users = self.users
|
||||
mock_random.choice = random.choice
|
||||
mock_random.random.side_effect = [0, 0, 1, 1]
|
||||
reactions = _add_random_reactions_to_message(message, emojis, users)
|
||||
self.assertEqual(len(reactions), 2)
|
||||
assert reactions[0].emoji_name == reactions[1].emoji_name
|
||||
assert reactions[0].user_profile_id != reactions[1].user_profile_id
|
||||
|
||||
@patch('zerver.lib.bulk_create.random')
|
||||
def test_two_reactions_with_different_emojis(
|
||||
self, mock_random: MagicMock) -> None:
|
||||
message = self.stream_message
|
||||
emojis = DEFAULT_EMOJIS
|
||||
users = self.users
|
||||
mock_random.choice.side_effect = [emojis[0], users[0].id, emojis[1], users[1].id]
|
||||
mock_random.random.side_effect = [0, 1, 0, 1, 1]
|
||||
reactions = _add_random_reactions_to_message(message, emojis, users)
|
||||
self.assertEqual(len(reactions), 2)
|
||||
assert reactions[0].emoji_name != reactions[1].emoji_name
|
||||
assert reactions[0].user_profile_id != reactions[1].user_profile_id
|
||||
|
||||
@patch('zerver.lib.bulk_create.random')
|
||||
def test_deduplicated_reactions(
|
||||
self, mock_random: MagicMock) -> None:
|
||||
message = self.stream_message
|
||||
emojis = DEFAULT_EMOJIS[:1]
|
||||
users = self.users[:1]
|
||||
mock_random.choice = random.choice
|
||||
mock_random.random.side_effect = [0, 1, 0, 1, 1]
|
||||
reactions = _add_random_reactions_to_message(message, emojis, users)
|
||||
self.assertEqual(len(reactions), 1)
|
||||
|
||||
@patch('zerver.lib.bulk_create.random')
|
||||
def test_no_available_users(
|
||||
self, mock_random: MagicMock) -> None:
|
||||
message = self.stream_message
|
||||
emojis = DEFAULT_EMOJIS
|
||||
users = self.users[:1]
|
||||
mock_random.choice = random.choice
|
||||
mock_random.random.side_effect = [0, 0, 1, 1]
|
||||
reactions = _add_random_reactions_to_message(message, emojis, users)
|
||||
self.assertEqual(len(reactions), 1)
|
||||
|
||||
@patch('zerver.lib.bulk_create.Reaction')
|
||||
@patch('zerver.lib.bulk_create._add_random_reactions_to_message')
|
||||
def test_default_emojis(
|
||||
self,
|
||||
mock_add_random_reactions_to_message: MagicMock,
|
||||
MockReaction: MagicMock) -> None:
|
||||
messages = [self.personal_message]
|
||||
users = [self.users[0]]
|
||||
emojis = None
|
||||
bulk_create_reactions(messages, users, emojis)
|
||||
self.assertTrue(mock_add_random_reactions_to_message.called)
|
||||
mock_add_random_reactions_to_message.assert_called_with(
|
||||
messages[0], DEFAULT_EMOJIS, users)
|
|
@ -1,6 +1,7 @@
|
|||
import itertools
|
||||
import os
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple
|
||||
|
||||
|
@ -25,7 +26,7 @@ from zerver.lib.actions import (
|
|||
try_add_realm_custom_profile_field,
|
||||
try_add_realm_default_custom_profile_field,
|
||||
)
|
||||
from zerver.lib.bulk_create import bulk_create_reactions, bulk_create_streams
|
||||
from zerver.lib.bulk_create import bulk_create_streams
|
||||
from zerver.lib.cache import cache_set
|
||||
from zerver.lib.generate_test_data import create_test_data, generate_topics
|
||||
from zerver.lib.onboarding import create_if_missing_realm_internal_bots
|
||||
|
@ -44,6 +45,7 @@ from zerver.models import (
|
|||
DefaultStream,
|
||||
Huddle,
|
||||
Message,
|
||||
Reaction,
|
||||
Realm,
|
||||
RealmAuditLog,
|
||||
RealmDomain,
|
||||
|
@ -72,6 +74,16 @@ settings.CACHES['default'] = {
|
|||
'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
|
||||
}
|
||||
|
||||
DEFAULT_EMOJIS = [
|
||||
('+1', '1f44d'),
|
||||
('smiley', '1f603'),
|
||||
('eyes', '1f440'),
|
||||
('crying_cat_face', '1f63f'),
|
||||
('arrow_up', '2b06'),
|
||||
('confetti_ball', '1f38a'),
|
||||
('hundred_points', '1f4af'),
|
||||
]
|
||||
|
||||
def clear_database() -> None:
|
||||
# Hacky function only for use inside populate_db. Designed to
|
||||
# allow running populate_db repeatedly in series to work without
|
||||
|
@ -807,6 +819,55 @@ def send_messages(messages: List[Message]) -> None:
|
|||
bulk_create_reactions(messages)
|
||||
settings.USING_RABBITMQ = True
|
||||
|
||||
def get_message_to_users(message_ids: List[int]) -> Dict[int, List[int]]:
|
||||
rows = UserMessage.objects.filter(
|
||||
message_id__in=message_ids,
|
||||
).values("message_id", "user_profile_id")
|
||||
|
||||
result: Dict[int, List[int]] = defaultdict(list)
|
||||
|
||||
for row in rows:
|
||||
result[row["message_id"]].append(row["user_profile_id"])
|
||||
|
||||
return result
|
||||
|
||||
def bulk_create_reactions(all_messages: List[Message]) -> None:
|
||||
reactions: List[Reaction] = []
|
||||
|
||||
num_messages = int(0.2 * len(all_messages))
|
||||
messages = random.sample(all_messages, num_messages)
|
||||
message_ids = [message.id for message in messages]
|
||||
|
||||
message_to_users = get_message_to_users(message_ids)
|
||||
|
||||
for message_id in message_ids:
|
||||
msg_user_ids = message_to_users[message_id]
|
||||
|
||||
if msg_user_ids:
|
||||
# Now let between 1 and 7 users react.
|
||||
#
|
||||
# Ideally, we'd make exactly 1 reaction more common than
|
||||
# this algorithm generates.
|
||||
max_num_users = min(7, len(msg_user_ids))
|
||||
num_users = random.randrange(1, max_num_users + 1)
|
||||
user_ids = random.sample(msg_user_ids, num_users)
|
||||
|
||||
for user_id in user_ids:
|
||||
# each user does between 1 and 3 emojis
|
||||
num_emojis = random.choice([1, 2, 3])
|
||||
emojis = random.sample(DEFAULT_EMOJIS, num_emojis)
|
||||
|
||||
for emoji_name, emoji_code in emojis:
|
||||
reaction = Reaction(
|
||||
user_profile_id=user_id,
|
||||
message_id=message_id,
|
||||
emoji_name=emoji_name,
|
||||
emoji_code=emoji_code,
|
||||
reaction_type=Reaction.UNICODE_EMOJI
|
||||
)
|
||||
reactions.append(reaction)
|
||||
|
||||
Reaction.objects.bulk_create(reactions)
|
||||
def choose_date_sent(num_messages: int, tot_messages: int, threads: int) -> datetime:
|
||||
# Spoofing time not supported with threading
|
||||
if threads != 1:
|
||||
|
|
Loading…
Reference in New Issue