From b8ff6c184c2f3385276f9452216dbe8094fdb6ac Mon Sep 17 00:00:00 2001 From: Anders Kaseorg Date: Mon, 4 Mar 2024 16:46:18 -0800 Subject: [PATCH] narrow: Remove get_base_query_for_search need_message optimization. Signed-off-by: Anders Kaseorg --- zerver/lib/narrow.py | 52 ++++++++---------------------- zerver/tests/test_message_fetch.py | 49 ++++++++++++++++++++++++---- zerver/views/message_fetch.py | 2 +- 3 files changed, 57 insertions(+), 46 deletions(-) diff --git a/zerver/lib/narrow.py b/zerver/lib/narrow.py index 7584ebbb9b..d41d6e1a39 100644 --- a/zerver/lib/narrow.py +++ b/zerver/lib/narrow.py @@ -992,11 +992,10 @@ def exclude_muting_conditions( def get_base_query_for_search( - realm_id: int, user_profile: UserProfile | None, need_message: bool, need_user_message: bool + realm_id: int, user_profile: UserProfile | None, need_user_message: bool ) -> tuple[Select, ColumnElement[Integer]]: # Handle the simple case where user_message isn't involved first. if not need_user_message: - assert need_message query = ( select(column("id", Integer).label("message_id")) .select_from(table("zerver_message")) @@ -1007,30 +1006,21 @@ def get_base_query_for_search( return (query, inner_msg_id_col) assert user_profile is not None - if need_message: - query = ( - select(column("message_id", Integer)) - # We don't limit by realm_id despite the join to - # zerver_messages, since the user_profile_id limit in - # usermessage is more selective, and the query planner - # can't know about that cross-table correlation. - .where(column("user_profile_id", Integer) == literal(user_profile.id)) - .select_from( - join( - table("zerver_usermessage"), - table("zerver_message"), - literal_column("zerver_usermessage.message_id", Integer) - == literal_column("zerver_message.id", Integer), - ) - ) - ) - inner_msg_id_col = column("message_id", Integer) - return (query, inner_msg_id_col) - query = ( select(column("message_id", Integer)) + # We don't limit by realm_id despite the join to + # zerver_messages, since the user_profile_id limit in + # usermessage is more selective, and the query planner + # can't know about that cross-table correlation. .where(column("user_profile_id", Integer) == literal(user_profile.id)) - .select_from(table("zerver_usermessage")) + .select_from( + join( + table("zerver_usermessage"), + table("zerver_message"), + literal_column("zerver_usermessage.message_id", Integer) + == literal_column("zerver_message.id", Integer), + ) + ) ) inner_msg_id_col = column("message_id", Integer) return (query, inner_msg_id_col) @@ -1088,17 +1078,9 @@ def find_first_unread_anchor( # flag for the user. need_user_message = True - # Because we will need to call exclude_muting_conditions, unless - # the user hasn't muted anything, we will need to include Message - # in our query. It may be worth eventually adding an optimization - # for the case of a user who hasn't muted anything to avoid the - # join in that case, but it's low priority. - need_message = True - query, inner_msg_id_col = get_base_query_for_search( realm_id=user_profile.realm_id, user_profile=user_profile, - need_message=need_message, need_user_message=need_user_message, ) query = query.add_columns(column("flags", Integer)) @@ -1359,22 +1341,14 @@ def fetch_messages( # # Note that is_web_public_query=True goes here, since # include_history is semantically correct for is_web_public_query. - need_message = True need_user_message = False - elif narrow is None: - # We need to limit to messages the user has received, but we don't actually - # need any fields from Message - need_message = False - need_user_message = True else: - need_message = True need_user_message = True query: SelectBase query, inner_msg_id_col = get_base_query_for_search( realm_id=realm.id, user_profile=user_profile, - need_message=need_message, need_user_message=need_user_message, ) if need_user_message: diff --git a/zerver/tests/test_message_fetch.py b/zerver/tests/test_message_fetch.py index bd80a5dd5b..ba86ee6d70 100644 --- a/zerver/tests/test_message_fetch.py +++ b/zerver/tests/test_message_fetch.py @@ -4417,27 +4417,64 @@ recipient_id = %(recipient_id_3)s AND upper(subject) = upper(%(param_2)s))\ def test_get_messages_queries(self) -> None: query_ids = self.get_query_ids() - sql_template = "SELECT anon_1.message_id, anon_1.flags \nFROM (SELECT message_id, flags \nFROM zerver_usermessage \nWHERE user_profile_id = {hamlet_id} AND message_id = 0) AS anon_1 ORDER BY message_id ASC" + sql_template = """\ +SELECT anon_1.message_id, anon_1.flags \n\ +FROM (SELECT message_id, flags \n\ +FROM zerver_usermessage JOIN zerver_message ON zerver_usermessage.message_id = zerver_message.id \n\ +WHERE user_profile_id = {hamlet_id} AND message_id = 0) AS anon_1 ORDER BY message_id ASC\ +""" sql = sql_template.format(**query_ids) self.common_check_get_messages_query({"anchor": 0, "num_before": 0, "num_after": 0}, sql) - sql_template = "SELECT anon_1.message_id, anon_1.flags \nFROM (SELECT message_id, flags \nFROM zerver_usermessage \nWHERE user_profile_id = {hamlet_id} AND message_id = 0) AS anon_1 ORDER BY message_id ASC" + sql_template = """\ +SELECT anon_1.message_id, anon_1.flags \n\ +FROM (SELECT message_id, flags \n\ +FROM zerver_usermessage JOIN zerver_message ON zerver_usermessage.message_id = zerver_message.id \n\ +WHERE user_profile_id = {hamlet_id} AND message_id = 0) AS anon_1 ORDER BY message_id ASC\ +""" sql = sql_template.format(**query_ids) self.common_check_get_messages_query({"anchor": 0, "num_before": 1, "num_after": 0}, sql) - sql_template = "SELECT anon_1.message_id, anon_1.flags \nFROM (SELECT message_id, flags \nFROM zerver_usermessage \nWHERE user_profile_id = {hamlet_id} ORDER BY message_id ASC \n LIMIT 2) AS anon_1 ORDER BY message_id ASC" + sql_template = """\ +SELECT anon_1.message_id, anon_1.flags \n\ +FROM (SELECT message_id, flags \n\ +FROM zerver_usermessage JOIN zerver_message ON zerver_usermessage.message_id = zerver_message.id \n\ +WHERE user_profile_id = {hamlet_id} ORDER BY message_id ASC \n\ + LIMIT 2) AS anon_1 ORDER BY message_id ASC\ +""" sql = sql_template.format(**query_ids) self.common_check_get_messages_query({"anchor": 0, "num_before": 0, "num_after": 1}, sql) - sql_template = "SELECT anon_1.message_id, anon_1.flags \nFROM (SELECT message_id, flags \nFROM zerver_usermessage \nWHERE user_profile_id = {hamlet_id} ORDER BY message_id ASC \n LIMIT 11) AS anon_1 ORDER BY message_id ASC" + sql_template = """\ +SELECT anon_1.message_id, anon_1.flags \n\ +FROM (SELECT message_id, flags \n\ +FROM zerver_usermessage JOIN zerver_message ON zerver_usermessage.message_id = zerver_message.id \n\ +WHERE user_profile_id = {hamlet_id} ORDER BY message_id ASC \n\ + LIMIT 11) AS anon_1 ORDER BY message_id ASC\ +""" sql = sql_template.format(**query_ids) self.common_check_get_messages_query({"anchor": 0, "num_before": 0, "num_after": 10}, sql) - sql_template = "SELECT anon_1.message_id, anon_1.flags \nFROM (SELECT message_id, flags \nFROM zerver_usermessage \nWHERE user_profile_id = {hamlet_id} AND message_id <= 100 ORDER BY message_id DESC \n LIMIT 11) AS anon_1 ORDER BY message_id ASC" + sql_template = """\ +SELECT anon_1.message_id, anon_1.flags \n\ +FROM (SELECT message_id, flags \n\ +FROM zerver_usermessage JOIN zerver_message ON zerver_usermessage.message_id = zerver_message.id \n\ +WHERE user_profile_id = {hamlet_id} AND message_id <= 100 ORDER BY message_id DESC \n\ + LIMIT 11) AS anon_1 ORDER BY message_id ASC\ +""" sql = sql_template.format(**query_ids) self.common_check_get_messages_query({"anchor": 100, "num_before": 10, "num_after": 0}, sql) - sql_template = "SELECT anon_1.message_id, anon_1.flags \nFROM ((SELECT message_id, flags \nFROM zerver_usermessage \nWHERE user_profile_id = {hamlet_id} AND message_id <= 99 ORDER BY message_id DESC \n LIMIT 10) UNION ALL (SELECT message_id, flags \nFROM zerver_usermessage \nWHERE user_profile_id = {hamlet_id} AND message_id >= 100 ORDER BY message_id ASC \n LIMIT 11)) AS anon_1 ORDER BY message_id ASC" + sql_template = """\ +SELECT anon_1.message_id, anon_1.flags \n\ +FROM ((SELECT message_id, flags \n\ +FROM zerver_usermessage JOIN zerver_message ON zerver_usermessage.message_id = zerver_message.id \n\ +WHERE user_profile_id = {hamlet_id} AND message_id <= 99 ORDER BY message_id DESC \n\ + LIMIT 10) UNION ALL (SELECT message_id, flags \n\ +FROM zerver_usermessage JOIN zerver_message ON zerver_usermessage.message_id = zerver_message.id \n\ +WHERE user_profile_id = {hamlet_id} AND message_id >= 100 ORDER BY message_id ASC \n\ + LIMIT 11)) AS anon_1 ORDER BY message_id ASC\ +""" sql = sql_template.format(**query_ids) self.common_check_get_messages_query( {"anchor": 100, "num_before": 10, "num_after": 10}, sql diff --git a/zerver/views/message_fetch.py b/zerver/views/message_fetch.py index c7d19eb12f..1e2a31cbad 100644 --- a/zerver/views/message_fetch.py +++ b/zerver/views/message_fetch.py @@ -301,7 +301,7 @@ def messages_in_narrow_backend( # This query is limited to messages the user has access to because they # actually received them, as reflected in `zerver_usermessage`. query, inner_msg_id_col = get_base_query_for_search( - user_profile.realm_id, user_profile, need_message=True, need_user_message=True + user_profile.realm_id, user_profile, need_user_message=True ) query = query.where(column("message_id", Integer).in_(msg_ids))