From 547c8f895dfdddfe7e3004cf7f93e956c47b3a70 Mon Sep 17 00:00:00 2001 From: Alex Vandiver Date: Fri, 13 Oct 2023 01:53:42 +0000 Subject: [PATCH] message: Merge unnecessary cache_transformer step. Having a non-identity `cache_transformer` is no different from running it on every row of the query_function. Simplify understanding of the codepath used in caching by merging the pieces of code. --- zerver/lib/message.py | 5 ++--- zerver/lib/message_cache.py | 5 +++-- zerver/tests/test_message_dict.py | 15 +++++---------- zerver/tests/test_message_send.py | 3 +-- zerver/tests/test_submessage.py | 2 +- 5 files changed, 12 insertions(+), 18 deletions(-) diff --git a/zerver/lib/message.py b/zerver/lib/message.py index 71d8a83c7e..5b00943efa 100644 --- a/zerver/lib/message.py +++ b/zerver/lib/message.py @@ -218,15 +218,14 @@ def messages_for_ids( user_profile: Optional[UserProfile], realm: Realm, ) -> List[Dict[str, Any]]: - cache_transformer = MessageDict.build_dict_from_raw_db_row id_fetcher = lambda row: row["id"] message_dicts = generic_bulk_cached_fetch( to_dict_cache_key_id, - MessageDict.get_raw_db_rows, + MessageDict.ids_to_dict, message_ids, id_fetcher=id_fetcher, - cache_transformer=cache_transformer, + cache_transformer=lambda obj: obj, extractor=extract_message_dict, setter=stringify_message_dict, ) diff --git a/zerver/lib/message_cache.py b/zerver/lib/message_cache.py index 9250b227b7..c3f5d5f248 100644 --- a/zerver/lib/message_cache.py +++ b/zerver/lib/message_cache.py @@ -321,7 +321,7 @@ class MessageDict: return [MessageDict.build_dict_from_raw_db_row(row) for row in message_rows] @staticmethod - def get_raw_db_rows(needed_ids: List[int]) -> List[Dict[str, Any]]: + def ids_to_dict(needed_ids: List[int]) -> List[Dict[str, Any]]: # This is a special purpose function optimized for # callers like get_messages_backend(). fields = [ @@ -342,7 +342,8 @@ class MessageDict: ] # Uses index: zerver_message_pkey messages = Message.objects.filter(id__in=needed_ids).values(*fields) - return MessageDict.sew_submessages_and_reactions_to_msgs(messages) + MessageDict.sew_submessages_and_reactions_to_msgs(messages) + return [MessageDict.build_dict_from_raw_db_row(row) for row in messages] @staticmethod def build_dict_from_raw_db_row(row: Dict[str, Any]) -> Dict[str, Any]: diff --git a/zerver/tests/test_message_dict.py b/zerver/tests/test_message_dict.py index 9b08882015..c172b05648 100644 --- a/zerver/tests/test_message_dict.py +++ b/zerver/tests/test_message_dict.py @@ -171,14 +171,12 @@ class MessageDictTest(ZulipTestCase): self.assertTrue(num_ids >= 600) with self.assert_database_query_count(7): - rows = list(MessageDict.get_raw_db_rows(ids)) - - objs = [MessageDict.build_dict_from_raw_db_row(row) for row in rows] + objs = MessageDict.ids_to_dict(ids) MessageDict.post_process_dicts( objs, apply_markdown=False, client_gravatar=False, realm=realm ) - self.assert_length(rows, num_ids) + self.assert_length(objs, num_ids) def test_applying_markdown(self) -> None: sender = self.example_user("othello") @@ -200,8 +198,7 @@ class MessageDictTest(ZulipTestCase): # An important part of this test is to get the message through this exact code path, # because there is an ugly hack we need to cover. So don't just say "row = message". - row = MessageDict.get_raw_db_rows([message.id])[0] - dct = MessageDict.build_dict_from_raw_db_row(row) + dct = MessageDict.ids_to_dict([message.id])[0] expected_content = "

hello world

" self.assertEqual(dct["rendered_content"], expected_content) message = Message.objects.get(id=message.id) @@ -231,8 +228,7 @@ class MessageDictTest(ZulipTestCase): # An important part of this test is to get the message through this exact code path, # because there is an ugly hack we need to cover. So don't just say "row = message". - row = MessageDict.get_raw_db_rows([message.id])[0] - dct = MessageDict.build_dict_from_raw_db_row(row) + dct = MessageDict.ids_to_dict([message.id])[0] error_content = ( "

[Zulip note: Sorry, we could not understand the formatting of your message]

" ) @@ -298,8 +294,7 @@ class MessageDictTest(ZulipTestCase): reaction = Reaction.objects.create( message=message, user_profile=sender, emoji_name="simple_smile" ) - row = MessageDict.get_raw_db_rows([message.id])[0] - msg_dict = MessageDict.build_dict_from_raw_db_row(row) + msg_dict = MessageDict.ids_to_dict([message.id])[0] self.assertEqual(msg_dict["reactions"][0]["emoji_name"], reaction.emoji_name) self.assertEqual(msg_dict["reactions"][0]["user_id"], sender.id) self.assertEqual(msg_dict["reactions"][0]["user"]["id"], sender.id) diff --git a/zerver/tests/test_message_send.py b/zerver/tests/test_message_send.py index 1a9c7f58ac..639bccf94f 100644 --- a/zerver/tests/test_message_send.py +++ b/zerver/tests/test_message_send.py @@ -1699,8 +1699,7 @@ class StreamMessagesTest(ZulipTestCase): self.example_user("hamlet"), "Denmark", content="whatever", topic_name="my topic" ) message = most_recent_message(user_profile) - row = MessageDict.get_raw_db_rows([message.id])[0] - dct = MessageDict.build_dict_from_raw_db_row(row) + dct = MessageDict.ids_to_dict([message.id])[0] MessageDict.post_process_dicts( [dct], apply_markdown=True, diff --git a/zerver/tests/test_submessage.py b/zerver/tests/test_submessage.py index 9bf73bf1cc..f2fb1d7a1b 100644 --- a/zerver/tests/test_submessage.py +++ b/zerver/tests/test_submessage.py @@ -65,7 +65,7 @@ class TestBasics(ZulipTestCase): rows.sort(key=lambda r: r["id"]) self.assertEqual(rows, expected_data) - msg_rows = MessageDict.get_raw_db_rows([message_id]) + msg_rows = MessageDict.ids_to_dict([message_id]) rows = msg_rows[0]["submessages"] rows.sort(key=lambda r: r["id"]) self.assertEqual(rows, expected_data)