diff --git a/zerver/lib/cache.py b/zerver/lib/cache.py index d02a7123ec..1181ee87fa 100644 --- a/zerver/lib/cache.py +++ b/zerver/lib/cache.py @@ -450,6 +450,13 @@ def flush_message(sender: Any, **kwargs: Any) -> None: message = kwargs['instance'] cache_delete(to_dict_cache_key_id(message.id)) +def flush_submessage(sender: Any, **kwargs: Any) -> None: + submessage = kwargs['instance'] + # submessages are not cached directly, they are part of their + # parent messages + message_id = submessage.message_id + cache_delete(to_dict_cache_key_id(message_id)) + DECORATOR = Callable[[Callable[..., Any]], Callable[..., Any]] def ignore_unhashable_lru_cache(maxsize: int=128, typed: bool=False) -> DECORATOR: diff --git a/zerver/lib/message.py b/zerver/lib/message.py index 7b654d877a..d859d412cb 100644 --- a/zerver/lib/message.py +++ b/zerver/lib/message.py @@ -37,6 +37,7 @@ from zerver.models import ( Realm, Recipient, Stream, + SubMessage, Subscription, UserProfile, UserMessage, @@ -121,6 +122,20 @@ def sew_messages_and_reactions(messages: List[Dict[str, Any]], return list(converted_messages.values()) +def sew_messages_and_submessages(messages: List[Dict[str, Any]], + submessages: List[Dict[str, Any]]) -> None: + # This is super similar to sew_messages_and_reactions. + for message in messages: + message['submessages'] = [] + + message_dict = {message['id']: message for message in messages} + + for submessage in submessages: + message_id = submessage['message_id'] + if message_id in message_dict: + message = message_dict[message_id] + message['submessages'].append(submessage) + def extract_message_dict(message_bytes: bytes) -> Dict[str, Any]: return ujson.loads(zlib.decompress(message_bytes).decode("utf-8")) @@ -205,7 +220,8 @@ class MessageDict: recipient_id = message.recipient.id, recipient_type = message.recipient.type, recipient_type_id = message.recipient.type_id, - reactions = Reaction.get_raw_db_rows([message.id]) + reactions = Reaction.get_raw_db_rows([message.id]), + submessages = SubMessage.get_raw_db_rows([message.id]), ) @staticmethod @@ -229,11 +245,10 @@ class MessageDict: 'sender__realm_id', ] messages = Message.objects.filter(id__in=needed_ids).values(*fields) - """Adding one-many or Many-Many relationship in values results in N X - results. - Link: https://docs.djangoproject.com/en/1.8/ref/models/querysets/#values - """ + submessages = SubMessage.get_raw_db_rows(needed_ids) + sew_messages_and_submessages(messages, submessages) + reactions = Reaction.get_raw_db_rows(needed_ids) return sew_messages_and_reactions(messages, reactions) @@ -259,7 +274,8 @@ class MessageDict: recipient_id = row['recipient_id'], recipient_type = row['recipient__type'], recipient_type_id = row['recipient__type_id'], - reactions=row['reactions'] + reactions=row['reactions'], + submessages=row['submessages'], ) @staticmethod @@ -279,7 +295,8 @@ class MessageDict: recipient_id: int, recipient_type: int, recipient_type_id: int, - reactions: List[Dict[str, Any]] + reactions: List[Dict[str, Any]], + submessages: List[Dict[str, Any]] ) -> Dict[str, Any]: obj = dict( @@ -343,6 +360,7 @@ class MessageDict: obj['reactions'] = [ReactionDict.build_dict_from_raw_db_row(reaction) for reaction in reactions] + obj['submessages'] = submessages return obj @staticmethod diff --git a/zerver/models.py b/zerver/models.py index ec94d17a0b..89a63e9293 100644 --- a/zerver/models.py +++ b/zerver/models.py @@ -22,7 +22,7 @@ from zerver.lib.cache import cache_with_key, flush_user_profile, flush_realm, \ display_recipient_cache_key, cache_delete, active_user_ids_cache_key, \ get_stream_cache_key, realm_user_dicts_cache_key, \ bot_dicts_in_realm_cache_key, realm_user_dict_fields, \ - bot_dict_fields, flush_message, bot_profile_cache_key + bot_dict_fields, flush_message, flush_submessage, bot_profile_cache_key from zerver.lib.utils import make_safe_digest, generate_random_token from django.db import transaction from django.utils.timezone import now as timezone_now @@ -1318,6 +1318,8 @@ class SubMessage(models.Model): query = query.order_by('message_id', 'id') return list(query) +post_save.connect(flush_submessage, sender=SubMessage) + class Reaction(models.Model): user_profile = models.ForeignKey(UserProfile, on_delete=CASCADE) # type: UserProfile message = models.ForeignKey(Message, on_delete=CASCADE) # type: Message diff --git a/zerver/tests/test_events.py b/zerver/tests/test_events.py index 6f29212a31..2973d392bb 100644 --- a/zerver/tests/test_events.py +++ b/zerver/tests/test_events.py @@ -637,6 +637,7 @@ class EventsRegisterTest(ZulipTestCase): ('stream_id', check_int), ('subject', check_string), ('subject_links', check_list(None)), + ('submessages', check_list(None)), ('timestamp', check_int), ('type', check_string), ])), diff --git a/zerver/tests/test_messages.py b/zerver/tests/test_messages.py index be224943b3..93a8d39bec 100644 --- a/zerver/tests/test_messages.py +++ b/zerver/tests/test_messages.py @@ -623,7 +623,7 @@ class StreamMessagesTest(ZulipTestCase): body=content, ) - self.assert_length(queries, 13) + self.assert_length(queries, 14) def test_stream_message_dict(self) -> None: user_profile = self.example_user('iago') @@ -867,7 +867,7 @@ class MessageDictTest(ZulipTestCase): # slower. error_msg = "Number of ids: {}. Time delay: {}".format(num_ids, delay) self.assertTrue(delay < 0.0015 * num_ids, error_msg) - self.assert_length(queries, 6) + self.assert_length(queries, 7) self.assertEqual(len(rows), num_ids) def test_applying_markdown(self) -> None: diff --git a/zerver/tests/test_signup.py b/zerver/tests/test_signup.py index ae91aacec0..1a12fa6645 100644 --- a/zerver/tests/test_signup.py +++ b/zerver/tests/test_signup.py @@ -364,7 +364,7 @@ class LoginTest(ZulipTestCase): with queries_captured() as queries: self.register(self.nonreg_email('test'), "test") # Ensure the number of queries we make is not O(streams) - self.assert_length(queries, 70) + self.assert_length(queries, 71) user_profile = self.nonreg_user('test') self.assertEqual(get_session_dict_user(self.client.session), user_profile.id) self.assertFalse(user_profile.enable_stream_desktop_notifications) diff --git a/zerver/tests/test_submessage.py b/zerver/tests/test_submessage.py index 8b32ebb322..6a3d1ceca4 100644 --- a/zerver/tests/test_submessage.py +++ b/zerver/tests/test_submessage.py @@ -1,6 +1,11 @@ from zerver.lib.test_classes import ZulipTestCase +from zerver.lib.message import ( + MessageDict, +) + from zerver.models import ( + Message, SubMessage, ) @@ -57,3 +62,14 @@ class TestBasics(ZulipTestCase): ] self.assertEqual(get_raw_rows(), expected_data) + + message = Message.objects.get(id=message_id) + message_json = MessageDict.wide_dict(message) + rows = message_json['submessages'] + rows.sort(key=lambda r: r['id']) + self.assertEqual(rows, expected_data) + + msg_rows = MessageDict.get_raw_db_rows([message_id]) + rows = msg_rows[0]['submessages'] + rows.sort(key=lambda r: r['id']) + self.assertEqual(rows, expected_data) diff --git a/zerver/tests/test_subs.py b/zerver/tests/test_subs.py index 53eb4798e5..cbb6151a74 100644 --- a/zerver/tests/test_subs.py +++ b/zerver/tests/test_subs.py @@ -1995,7 +1995,7 @@ class SubscriptionAPITest(ZulipTestCase): streams_to_sub, dict(principals=ujson.dumps([user1.email, user2.email])), ) - self.assert_length(queries, 42) + self.assert_length(queries, 43) self.assert_length(events, 7) for ev in [x for x in events if x['event']['type'] not in ('message', 'stream')]: