narrow: Remove get_base_query_for_search need_message optimization.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2024-03-04 16:46:18 -08:00
parent 98f7641d77
commit b8ff6c184c
3 changed files with 57 additions and 46 deletions

View File

@ -992,11 +992,10 @@ def exclude_muting_conditions(
def get_base_query_for_search( def get_base_query_for_search(
realm_id: int, user_profile: UserProfile | None, need_message: bool, need_user_message: bool realm_id: int, user_profile: UserProfile | None, need_user_message: bool
) -> tuple[Select, ColumnElement[Integer]]: ) -> tuple[Select, ColumnElement[Integer]]:
# Handle the simple case where user_message isn't involved first. # Handle the simple case where user_message isn't involved first.
if not need_user_message: if not need_user_message:
assert need_message
query = ( query = (
select(column("id", Integer).label("message_id")) select(column("id", Integer).label("message_id"))
.select_from(table("zerver_message")) .select_from(table("zerver_message"))
@ -1007,30 +1006,21 @@ def get_base_query_for_search(
return (query, inner_msg_id_col) return (query, inner_msg_id_col)
assert user_profile is not None assert user_profile is not None
if need_message:
query = (
select(column("message_id", Integer))
# We don't limit by realm_id despite the join to
# zerver_messages, since the user_profile_id limit in
# usermessage is more selective, and the query planner
# can't know about that cross-table correlation.
.where(column("user_profile_id", Integer) == literal(user_profile.id))
.select_from(
join(
table("zerver_usermessage"),
table("zerver_message"),
literal_column("zerver_usermessage.message_id", Integer)
== literal_column("zerver_message.id", Integer),
)
)
)
inner_msg_id_col = column("message_id", Integer)
return (query, inner_msg_id_col)
query = ( query = (
select(column("message_id", Integer)) select(column("message_id", Integer))
# We don't limit by realm_id despite the join to
# zerver_message, since the user_profile_id limit in
# usermessage is more selective, and the query planner
# can't know about that cross-table correlation.
.where(column("user_profile_id", Integer) == literal(user_profile.id)) .where(column("user_profile_id", Integer) == literal(user_profile.id))
.select_from(table("zerver_usermessage")) .select_from(
join(
table("zerver_usermessage"),
table("zerver_message"),
literal_column("zerver_usermessage.message_id", Integer)
== literal_column("zerver_message.id", Integer),
)
)
) )
inner_msg_id_col = column("message_id", Integer) inner_msg_id_col = column("message_id", Integer)
return (query, inner_msg_id_col) return (query, inner_msg_id_col)
@ -1088,17 +1078,9 @@ def find_first_unread_anchor(
# flag for the user. # flag for the user.
need_user_message = True need_user_message = True
# Because we will need to call exclude_muting_conditions, unless
# the user hasn't muted anything, we will need to include Message
# in our query. It may be worth eventually adding an optimization
# for the case of a user who hasn't muted anything to avoid the
# join in that case, but it's low priority.
need_message = True
query, inner_msg_id_col = get_base_query_for_search( query, inner_msg_id_col = get_base_query_for_search(
realm_id=user_profile.realm_id, realm_id=user_profile.realm_id,
user_profile=user_profile, user_profile=user_profile,
need_message=need_message,
need_user_message=need_user_message, need_user_message=need_user_message,
) )
query = query.add_columns(column("flags", Integer)) query = query.add_columns(column("flags", Integer))
@ -1359,22 +1341,14 @@ def fetch_messages(
# #
# Note that is_web_public_query=True goes here, since # Note that is_web_public_query=True goes here, since
# include_history is semantically correct for is_web_public_query. # include_history is semantically correct for is_web_public_query.
need_message = True
need_user_message = False need_user_message = False
elif narrow is None:
# We need to limit to messages the user has received, but we don't actually
# need any fields from Message
need_message = False
need_user_message = True
else: else:
need_message = True
need_user_message = True need_user_message = True
query: SelectBase query: SelectBase
query, inner_msg_id_col = get_base_query_for_search( query, inner_msg_id_col = get_base_query_for_search(
realm_id=realm.id, realm_id=realm.id,
user_profile=user_profile, user_profile=user_profile,
need_message=need_message,
need_user_message=need_user_message, need_user_message=need_user_message,
) )
if need_user_message: if need_user_message:

View File

@ -4417,27 +4417,64 @@ recipient_id = %(recipient_id_3)s AND upper(subject) = upper(%(param_2)s))\
def test_get_messages_queries(self) -> None: def test_get_messages_queries(self) -> None:
query_ids = self.get_query_ids() query_ids = self.get_query_ids()
sql_template = "SELECT anon_1.message_id, anon_1.flags \nFROM (SELECT message_id, flags \nFROM zerver_usermessage \nWHERE user_profile_id = {hamlet_id} AND message_id = 0) AS anon_1 ORDER BY message_id ASC" sql_template = """\
SELECT anon_1.message_id, anon_1.flags \n\
FROM (SELECT message_id, flags \n\
FROM zerver_usermessage JOIN zerver_message ON zerver_usermessage.message_id = zerver_message.id \n\
WHERE user_profile_id = {hamlet_id} AND message_id = 0) AS anon_1 ORDER BY message_id ASC\
"""
sql = sql_template.format(**query_ids) sql = sql_template.format(**query_ids)
self.common_check_get_messages_query({"anchor": 0, "num_before": 0, "num_after": 0}, sql) self.common_check_get_messages_query({"anchor": 0, "num_before": 0, "num_after": 0}, sql)
sql_template = "SELECT anon_1.message_id, anon_1.flags \nFROM (SELECT message_id, flags \nFROM zerver_usermessage \nWHERE user_profile_id = {hamlet_id} AND message_id = 0) AS anon_1 ORDER BY message_id ASC" sql_template = """\
SELECT anon_1.message_id, anon_1.flags \n\
FROM (SELECT message_id, flags \n\
FROM zerver_usermessage JOIN zerver_message ON zerver_usermessage.message_id = zerver_message.id \n\
WHERE user_profile_id = {hamlet_id} AND message_id = 0) AS anon_1 ORDER BY message_id ASC\
"""
sql = sql_template.format(**query_ids) sql = sql_template.format(**query_ids)
self.common_check_get_messages_query({"anchor": 0, "num_before": 1, "num_after": 0}, sql) self.common_check_get_messages_query({"anchor": 0, "num_before": 1, "num_after": 0}, sql)
sql_template = "SELECT anon_1.message_id, anon_1.flags \nFROM (SELECT message_id, flags \nFROM zerver_usermessage \nWHERE user_profile_id = {hamlet_id} ORDER BY message_id ASC \n LIMIT 2) AS anon_1 ORDER BY message_id ASC" sql_template = """\
SELECT anon_1.message_id, anon_1.flags \n\
FROM (SELECT message_id, flags \n\
FROM zerver_usermessage JOIN zerver_message ON zerver_usermessage.message_id = zerver_message.id \n\
WHERE user_profile_id = {hamlet_id} ORDER BY message_id ASC \n\
LIMIT 2) AS anon_1 ORDER BY message_id ASC\
"""
sql = sql_template.format(**query_ids) sql = sql_template.format(**query_ids)
self.common_check_get_messages_query({"anchor": 0, "num_before": 0, "num_after": 1}, sql) self.common_check_get_messages_query({"anchor": 0, "num_before": 0, "num_after": 1}, sql)
sql_template = "SELECT anon_1.message_id, anon_1.flags \nFROM (SELECT message_id, flags \nFROM zerver_usermessage \nWHERE user_profile_id = {hamlet_id} ORDER BY message_id ASC \n LIMIT 11) AS anon_1 ORDER BY message_id ASC" sql_template = """\
SELECT anon_1.message_id, anon_1.flags \n\
FROM (SELECT message_id, flags \n\
FROM zerver_usermessage JOIN zerver_message ON zerver_usermessage.message_id = zerver_message.id \n\
WHERE user_profile_id = {hamlet_id} ORDER BY message_id ASC \n\
LIMIT 11) AS anon_1 ORDER BY message_id ASC\
"""
sql = sql_template.format(**query_ids) sql = sql_template.format(**query_ids)
self.common_check_get_messages_query({"anchor": 0, "num_before": 0, "num_after": 10}, sql) self.common_check_get_messages_query({"anchor": 0, "num_before": 0, "num_after": 10}, sql)
sql_template = "SELECT anon_1.message_id, anon_1.flags \nFROM (SELECT message_id, flags \nFROM zerver_usermessage \nWHERE user_profile_id = {hamlet_id} AND message_id <= 100 ORDER BY message_id DESC \n LIMIT 11) AS anon_1 ORDER BY message_id ASC" sql_template = """\
SELECT anon_1.message_id, anon_1.flags \n\
FROM (SELECT message_id, flags \n\
FROM zerver_usermessage JOIN zerver_message ON zerver_usermessage.message_id = zerver_message.id \n\
WHERE user_profile_id = {hamlet_id} AND message_id <= 100 ORDER BY message_id DESC \n\
LIMIT 11) AS anon_1 ORDER BY message_id ASC\
"""
sql = sql_template.format(**query_ids) sql = sql_template.format(**query_ids)
self.common_check_get_messages_query({"anchor": 100, "num_before": 10, "num_after": 0}, sql) self.common_check_get_messages_query({"anchor": 100, "num_before": 10, "num_after": 0}, sql)
sql_template = "SELECT anon_1.message_id, anon_1.flags \nFROM ((SELECT message_id, flags \nFROM zerver_usermessage \nWHERE user_profile_id = {hamlet_id} AND message_id <= 99 ORDER BY message_id DESC \n LIMIT 10) UNION ALL (SELECT message_id, flags \nFROM zerver_usermessage \nWHERE user_profile_id = {hamlet_id} AND message_id >= 100 ORDER BY message_id ASC \n LIMIT 11)) AS anon_1 ORDER BY message_id ASC" sql_template = """\
SELECT anon_1.message_id, anon_1.flags \n\
FROM ((SELECT message_id, flags \n\
FROM zerver_usermessage JOIN zerver_message ON zerver_usermessage.message_id = zerver_message.id \n\
WHERE user_profile_id = {hamlet_id} AND message_id <= 99 ORDER BY message_id DESC \n\
LIMIT 10) UNION ALL (SELECT message_id, flags \n\
FROM zerver_usermessage JOIN zerver_message ON zerver_usermessage.message_id = zerver_message.id \n\
WHERE user_profile_id = {hamlet_id} AND message_id >= 100 ORDER BY message_id ASC \n\
LIMIT 11)) AS anon_1 ORDER BY message_id ASC\
"""
sql = sql_template.format(**query_ids) sql = sql_template.format(**query_ids)
self.common_check_get_messages_query( self.common_check_get_messages_query(
{"anchor": 100, "num_before": 10, "num_after": 10}, sql {"anchor": 100, "num_before": 10, "num_after": 10}, sql

View File

@ -301,7 +301,7 @@ def messages_in_narrow_backend(
# This query is limited to messages the user has access to because they # This query is limited to messages the user has access to because they
# actually received them, as reflected in `zerver_usermessage`. # actually received them, as reflected in `zerver_usermessage`.
query, inner_msg_id_col = get_base_query_for_search( query, inner_msg_id_col = get_base_query_for_search(
user_profile.realm_id, user_profile, need_message=True, need_user_message=True user_profile.realm_id, user_profile, need_user_message=True
) )
query = query.where(column("message_id", Integer).in_(msg_ids)) query = query.where(column("message_id", Integer).in_(msg_ids))