message: Merge unnecessary cache_transformer step.

Having a non-identity `cache_transformer` is no different from running
it on every row of the query_function.  Simplify understanding of the
codepath used in caching by merging the pieces of code.
This commit is contained in:
Alex Vandiver 2023-10-13 01:53:42 +00:00
parent 548bb5362e
commit 547c8f895d
5 changed files with 12 additions and 18 deletions

View File

@ -218,15 +218,14 @@ def messages_for_ids(
user_profile: Optional[UserProfile], user_profile: Optional[UserProfile],
realm: Realm, realm: Realm,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
cache_transformer = MessageDict.build_dict_from_raw_db_row
id_fetcher = lambda row: row["id"] id_fetcher = lambda row: row["id"]
message_dicts = generic_bulk_cached_fetch( message_dicts = generic_bulk_cached_fetch(
to_dict_cache_key_id, to_dict_cache_key_id,
MessageDict.get_raw_db_rows, MessageDict.ids_to_dict,
message_ids, message_ids,
id_fetcher=id_fetcher, id_fetcher=id_fetcher,
cache_transformer=cache_transformer, cache_transformer=lambda obj: obj,
extractor=extract_message_dict, extractor=extract_message_dict,
setter=stringify_message_dict, setter=stringify_message_dict,
) )

View File

@ -321,7 +321,7 @@ class MessageDict:
return [MessageDict.build_dict_from_raw_db_row(row) for row in message_rows] return [MessageDict.build_dict_from_raw_db_row(row) for row in message_rows]
@staticmethod @staticmethod
def get_raw_db_rows(needed_ids: List[int]) -> List[Dict[str, Any]]: def ids_to_dict(needed_ids: List[int]) -> List[Dict[str, Any]]:
# This is a special purpose function optimized for # This is a special purpose function optimized for
# callers like get_messages_backend(). # callers like get_messages_backend().
fields = [ fields = [
@ -342,7 +342,8 @@ class MessageDict:
] ]
# Uses index: zerver_message_pkey # Uses index: zerver_message_pkey
messages = Message.objects.filter(id__in=needed_ids).values(*fields) messages = Message.objects.filter(id__in=needed_ids).values(*fields)
return MessageDict.sew_submessages_and_reactions_to_msgs(messages) MessageDict.sew_submessages_and_reactions_to_msgs(messages)
return [MessageDict.build_dict_from_raw_db_row(row) for row in messages]
@staticmethod @staticmethod
def build_dict_from_raw_db_row(row: Dict[str, Any]) -> Dict[str, Any]: def build_dict_from_raw_db_row(row: Dict[str, Any]) -> Dict[str, Any]:

View File

@ -171,14 +171,12 @@ class MessageDictTest(ZulipTestCase):
self.assertTrue(num_ids >= 600) self.assertTrue(num_ids >= 600)
with self.assert_database_query_count(7): with self.assert_database_query_count(7):
rows = list(MessageDict.get_raw_db_rows(ids)) objs = MessageDict.ids_to_dict(ids)
objs = [MessageDict.build_dict_from_raw_db_row(row) for row in rows]
MessageDict.post_process_dicts( MessageDict.post_process_dicts(
objs, apply_markdown=False, client_gravatar=False, realm=realm objs, apply_markdown=False, client_gravatar=False, realm=realm
) )
self.assert_length(rows, num_ids) self.assert_length(objs, num_ids)
def test_applying_markdown(self) -> None: def test_applying_markdown(self) -> None:
sender = self.example_user("othello") sender = self.example_user("othello")
@ -200,8 +198,7 @@ class MessageDictTest(ZulipTestCase):
# An important part of this test is to get the message through this exact code path, # An important part of this test is to get the message through this exact code path,
# because there is an ugly hack we need to cover. So don't just say "row = message". # because there is an ugly hack we need to cover. So don't just say "row = message".
row = MessageDict.get_raw_db_rows([message.id])[0] dct = MessageDict.ids_to_dict([message.id])[0]
dct = MessageDict.build_dict_from_raw_db_row(row)
expected_content = "<p>hello <strong>world</strong></p>" expected_content = "<p>hello <strong>world</strong></p>"
self.assertEqual(dct["rendered_content"], expected_content) self.assertEqual(dct["rendered_content"], expected_content)
message = Message.objects.get(id=message.id) message = Message.objects.get(id=message.id)
@ -231,8 +228,7 @@ class MessageDictTest(ZulipTestCase):
# An important part of this test is to get the message through this exact code path, # An important part of this test is to get the message through this exact code path,
# because there is an ugly hack we need to cover. So don't just say "row = message". # because there is an ugly hack we need to cover. So don't just say "row = message".
row = MessageDict.get_raw_db_rows([message.id])[0] dct = MessageDict.ids_to_dict([message.id])[0]
dct = MessageDict.build_dict_from_raw_db_row(row)
error_content = ( error_content = (
"<p>[Zulip note: Sorry, we could not understand the formatting of your message]</p>" "<p>[Zulip note: Sorry, we could not understand the formatting of your message]</p>"
) )
@ -298,8 +294,7 @@ class MessageDictTest(ZulipTestCase):
reaction = Reaction.objects.create( reaction = Reaction.objects.create(
message=message, user_profile=sender, emoji_name="simple_smile" message=message, user_profile=sender, emoji_name="simple_smile"
) )
row = MessageDict.get_raw_db_rows([message.id])[0] msg_dict = MessageDict.ids_to_dict([message.id])[0]
msg_dict = MessageDict.build_dict_from_raw_db_row(row)
self.assertEqual(msg_dict["reactions"][0]["emoji_name"], reaction.emoji_name) self.assertEqual(msg_dict["reactions"][0]["emoji_name"], reaction.emoji_name)
self.assertEqual(msg_dict["reactions"][0]["user_id"], sender.id) self.assertEqual(msg_dict["reactions"][0]["user_id"], sender.id)
self.assertEqual(msg_dict["reactions"][0]["user"]["id"], sender.id) self.assertEqual(msg_dict["reactions"][0]["user"]["id"], sender.id)

View File

@ -1699,8 +1699,7 @@ class StreamMessagesTest(ZulipTestCase):
self.example_user("hamlet"), "Denmark", content="whatever", topic_name="my topic" self.example_user("hamlet"), "Denmark", content="whatever", topic_name="my topic"
) )
message = most_recent_message(user_profile) message = most_recent_message(user_profile)
row = MessageDict.get_raw_db_rows([message.id])[0] dct = MessageDict.ids_to_dict([message.id])[0]
dct = MessageDict.build_dict_from_raw_db_row(row)
MessageDict.post_process_dicts( MessageDict.post_process_dicts(
[dct], [dct],
apply_markdown=True, apply_markdown=True,

View File

@ -65,7 +65,7 @@ class TestBasics(ZulipTestCase):
rows.sort(key=lambda r: r["id"]) rows.sort(key=lambda r: r["id"])
self.assertEqual(rows, expected_data) self.assertEqual(rows, expected_data)
msg_rows = MessageDict.get_raw_db_rows([message_id]) msg_rows = MessageDict.ids_to_dict([message_id])
rows = msg_rows[0]["submessages"] rows = msg_rows[0]["submessages"]
rows.sort(key=lambda r: r["id"]) rows.sort(key=lambda r: r["id"])
self.assertEqual(rows, expected_data) self.assertEqual(rows, expected_data)