From dae4633745404f03bcdc44a7513400ea3793996a Mon Sep 17 00:00:00 2001 From: Anders Kaseorg Date: Wed, 9 Nov 2022 15:35:52 -0800 Subject: [PATCH] message_fetch: Extract fetch_messages helper to zerver.lib.narrow. Signed-off-by: Anders Kaseorg --- zerver/lib/narrow.py | 121 ++++++++++++++++++++++++++++- zerver/tests/test_message_fetch.py | 2 +- zerver/tests/test_message_flags.py | 6 +- zerver/views/message_fetch.py | 102 +++--------------------- 4 files changed, 134 insertions(+), 97 deletions(-) diff --git a/zerver/lib/narrow.py b/zerver/lib/narrow.py index a1d2467af4..6fd3c64a09 100644 --- a/zerver/lib/narrow.py +++ b/zerver/lib/narrow.py @@ -23,7 +23,7 @@ from django.core.exceptions import ValidationError from django.db import connection from django.utils.translation import gettext as _ from sqlalchemy.dialects import postgresql -from sqlalchemy.engine import Connection +from sqlalchemy.engine import Connection, Row from sqlalchemy.sql import ( ClauseElement, ColumnElement, @@ -45,7 +45,9 @@ from sqlalchemy.types import ARRAY, Boolean, Integer, Text from zerver.lib.addressee import get_user_profiles, get_user_profiles_by_ids from zerver.lib.exceptions import ErrorCode, JsonableError +from zerver.lib.message import get_first_visible_message_id from zerver.lib.recipient_users import recipient_for_user_profiles +from zerver.lib.sqlalchemy_utils import get_sqlalchemy_connection from zerver.lib.streams import ( can_access_stream_history_by_id, can_access_stream_history_by_name, @@ -1145,3 +1147,120 @@ def post_process_limited_query( found_oldest=found_oldest, history_limited=history_limited, ) + + +@dataclass +class FetchedMessages(LimitedMessages[Row]): + anchor: int + include_history: bool + is_search: bool + + +def fetch_messages( + *, + narrow: OptionalNarrowListT, + user_profile: Optional[UserProfile], + realm: Realm, + is_web_public_query: bool, + anchor: Optional[int], + num_before: int, + num_after: int, +) -> FetchedMessages: + include_history = ok_to_include_history(narrow, user_profile, is_web_public_query) + if include_history: + # The initial query in this case doesn't use `zerver_usermessage`, + # and isn't yet limited to messages the user is entitled to see! + # + # This is OK only because we've made sure this is a narrow that + # will cause us to limit the query appropriately elsewhere. + # See `ok_to_include_history` for details. + # + # Note that is_web_public_query=True goes here, since + # include_history is semantically correct for is_web_public_query. + need_message = True + need_user_message = False + elif narrow is None: + # We need to limit to messages the user has received, but we don't actually + # need any fields from Message + need_message = False + need_user_message = True + else: + need_message = True + need_user_message = True + + query: SelectBase + query, inner_msg_id_col = get_base_query_for_search( + user_profile=user_profile, + need_message=need_message, + need_user_message=need_user_message, + ) + + query, is_search = add_narrow_conditions( + user_profile=user_profile, + inner_msg_id_col=inner_msg_id_col, + query=query, + narrow=narrow, + realm=realm, + is_web_public_query=is_web_public_query, + ) + + with get_sqlalchemy_connection() as sa_conn: + if anchor is None: + # `anchor=None` corresponds to the anchor="first_unread" parameter. + anchor = find_first_unread_anchor( + sa_conn, + user_profile, + narrow, + ) + + anchored_to_left = anchor == 0 + + # Set value that will be used to short circuit the after_query + # altogether and avoid needless conditions in the before_query. + anchored_to_right = anchor >= LARGER_THAN_MAX_MESSAGE_ID + if anchored_to_right: + num_after = 0 + + first_visible_message_id = get_first_visible_message_id(realm) + + query = limit_query_to_range( + query=query, + num_before=num_before, + num_after=num_after, + anchor=anchor, + anchored_to_left=anchored_to_left, + anchored_to_right=anchored_to_right, + id_col=inner_msg_id_col, + first_visible_message_id=first_visible_message_id, + ) + + main_query = query.subquery() + query = ( + select(*main_query.c) + .select_from(main_query) + .order_by(column("message_id", Integer).asc()) + ) + # This is a hack to tag the query we use for testing + query = query.prefix_with("/* get_messages */") + rows = list(sa_conn.execute(query).fetchall()) + + query_info = post_process_limited_query( + rows=rows, + num_before=num_before, + num_after=num_after, + anchor=anchor, + anchored_to_left=anchored_to_left, + anchored_to_right=anchored_to_right, + first_visible_message_id=first_visible_message_id, + ) + + return FetchedMessages( + rows=query_info.rows, + found_anchor=query_info.found_anchor, + found_newest=query_info.found_newest, + found_oldest=query_info.found_oldest, + history_limited=query_info.history_limited, + anchor=anchor, + include_history=include_history, + is_search=is_search, + ) diff --git a/zerver/tests/test_message_fetch.py b/zerver/tests/test_message_fetch.py index 0e8ca7bbbc..96209979c5 100644 --- a/zerver/tests/test_message_fetch.py +++ b/zerver/tests/test_message_fetch.py @@ -92,7 +92,7 @@ def mute_stream(realm: Realm, user_profile: UserProfile, stream_name: str) -> No def first_visible_id_as(message_id: int) -> Any: return mock.patch( - "zerver.views.message_fetch.get_first_visible_message_id", + "zerver.lib.narrow.get_first_visible_message_id", return_value=message_id, ) diff --git a/zerver/tests/test_message_flags.py b/zerver/tests/test_message_flags.py index f7a6b5129f..3abe21718a 100644 --- a/zerver/tests/test_message_flags.py +++ b/zerver/tests/test_message_flags.py @@ -136,7 +136,7 @@ class FirstUnreadAnchorTests(ZulipTestCase): self.assertEqual(messages_response["anchor"], new_message_id) with mock.patch( - "zerver.views.message_fetch.get_first_visible_message_id", return_value=new_message_id + "zerver.lib.narrow.get_first_visible_message_id", return_value=new_message_id ): messages_response = self.get_messages_response( anchor="first_unread", num_before=0, num_after=1 @@ -145,7 +145,7 @@ class FirstUnreadAnchorTests(ZulipTestCase): self.assertEqual(messages_response["anchor"], new_message_id) with mock.patch( - "zerver.views.message_fetch.get_first_visible_message_id", + "zerver.lib.narrow.get_first_visible_message_id", return_value=new_message_id + 1, ): messages_reponse = self.get_messages_response( @@ -155,7 +155,7 @@ class FirstUnreadAnchorTests(ZulipTestCase): self.assertIn("anchor", messages_reponse) with mock.patch( - "zerver.views.message_fetch.get_first_visible_message_id", + "zerver.lib.narrow.get_first_visible_message_id", return_value=new_message_id - 1, ): messages = self.get_messages(anchor="first_unread", num_before=0, num_after=1) diff --git a/zerver/views/message_fetch.py b/zerver/views/message_fetch.py index 751aee8ad8..3826361048 100644 --- a/zerver/views/message_fetch.py +++ b/zerver/views/message_fetch.py @@ -5,26 +5,19 @@ from django.http import HttpRequest, HttpResponse from django.utils.html import escape as escape_html from django.utils.translation import gettext as _ from sqlalchemy.sql import and_, column, join, literal, literal_column, select, table -from sqlalchemy.sql.selectable import SelectBase from sqlalchemy.types import Integer, Text from zerver.context_processors import get_valid_realm_from_request from zerver.lib.exceptions import JsonableError, MissingAuthenticationError from zerver.lib.message import get_first_visible_message_id, messages_for_ids from zerver.lib.narrow import ( - LARGER_THAN_MAX_MESSAGE_ID, NarrowBuilder, OptionalNarrowListT, - add_narrow_conditions, - find_first_unread_anchor, - get_base_query_for_search, + fetch_messages, is_spectator_compatible, is_web_public_narrow, - limit_query_to_range, narrow_parameter, - ok_to_include_history, parse_anchor_value, - post_process_limited_query, ) from zerver.lib.request import REQ, RequestNotes, has_request_variables from zerver.lib.response import json_success @@ -149,44 +142,6 @@ def get_messages_backend( # clients cannot compute gravatars, so we force-set it to false. client_gravatar = False - include_history = ok_to_include_history(narrow, user_profile, is_web_public_query) - if include_history: - # The initial query in this case doesn't use `zerver_usermessage`, - # and isn't yet limited to messages the user is entitled to see! - # - # This is OK only because we've made sure this is a narrow that - # will cause us to limit the query appropriately elsewhere. - # See `ok_to_include_history` for details. - # - # Note that is_web_public_query=True goes here, since - # include_history is semantically correct for is_web_public_query. - need_message = True - need_user_message = False - elif narrow is None: - # We need to limit to messages the user has received, but we don't actually - # need any fields from Message - need_message = False - need_user_message = True - else: - need_message = True - need_user_message = True - - query: SelectBase - query, inner_msg_id_col = get_base_query_for_search( - user_profile=user_profile, - need_message=need_message, - need_user_message=need_user_message, - ) - - query, is_search = add_narrow_conditions( - user_profile=user_profile, - inner_msg_id_col=inner_msg_id_col, - query=query, - narrow=narrow, - realm=realm, - is_web_public_query=is_web_public_query, - ) - if narrow is not None: # Add some metadata to our logging data for narrows verbose_operators = [] @@ -199,56 +154,19 @@ def get_messages_backend( assert log_data is not None log_data["extra"] = "[{}]".format(",".join(verbose_operators)) - with get_sqlalchemy_connection() as sa_conn: - if anchor is None: - # `anchor=None` corresponds to the anchor="first_unread" parameter. - anchor = find_first_unread_anchor( - sa_conn, - user_profile, - narrow, - ) - - anchored_to_left = anchor == 0 - - # Set value that will be used to short circuit the after_query - # altogether and avoid needless conditions in the before_query. - anchored_to_right = anchor >= LARGER_THAN_MAX_MESSAGE_ID - if anchored_to_right: - num_after = 0 - - first_visible_message_id = get_first_visible_message_id(realm) - - query = limit_query_to_range( - query=query, - num_before=num_before, - num_after=num_after, - anchor=anchor, - anchored_to_left=anchored_to_left, - anchored_to_right=anchored_to_right, - id_col=inner_msg_id_col, - first_visible_message_id=first_visible_message_id, - ) - - main_query = query.subquery() - query = ( - select(*main_query.c) - .select_from(main_query) - .order_by(column("message_id", Integer).asc()) - ) - # This is a hack to tag the query we use for testing - query = query.prefix_with("/* get_messages */") - rows = list(sa_conn.execute(query).fetchall()) - - query_info = post_process_limited_query( - rows=rows, + query_info = fetch_messages( + narrow=narrow, + user_profile=user_profile, + realm=realm, + is_web_public_query=is_web_public_query, + anchor=anchor, num_before=num_before, num_after=num_after, - anchor=anchor, - anchored_to_left=anchored_to_left, - anchored_to_right=anchored_to_right, - first_visible_message_id=first_visible_message_id, ) + anchor = query_info.anchor + include_history = query_info.include_history + is_search = query_info.is_search rows = query_info.rows # The following is a little messy, but ensures that the code paths