diff --git a/zerver/lib/narrow.py b/zerver/lib/narrow.py index f7edc5d097..4605f9f47f 100644 --- a/zerver/lib/narrow.py +++ b/zerver/lib/narrow.py @@ -20,7 +20,7 @@ from django.core.exceptions import ValidationError from django.db import connection from django.utils.translation import gettext as _ from sqlalchemy.dialects import postgresql -from sqlalchemy.engine import Connection +from sqlalchemy.engine import Connection, Row from sqlalchemy.sql import ( ClauseElement, ColumnElement, @@ -35,7 +35,9 @@ from sqlalchemy.sql import ( or_, select, table, + union_all, ) +from sqlalchemy.sql.selectable import SelectBase from sqlalchemy.types import ARRAY, Boolean, Integer, Text from zerver.lib.addressee import get_user_profiles, get_user_profiles_by_ids @@ -980,3 +982,153 @@ def parse_anchor_value(anchor_val: Optional[str], use_first_unread_anchor: bool) return anchor except ValueError: raise JsonableError(_("Invalid anchor")) + + +def limit_query_to_range( + query: Select, + num_before: int, + num_after: int, + anchor: int, + anchored_to_left: bool, + anchored_to_right: bool, + id_col: ColumnElement[Integer], + first_visible_message_id: int, +) -> SelectBase: + """ + This code is actually generic enough that we could move it to a + library, but our only caller for now is message search. + """ + need_before_query = (not anchored_to_left) and (num_before > 0) + need_after_query = (not anchored_to_right) and (num_after > 0) + + need_both_sides = need_before_query and need_after_query + + # The semantics of our flags are as follows: + # + # num_after = number of rows < anchor + # num_after = number of rows > anchor + # + # But we also want the row where id == anchor (if it exists), + # and we don't want to union up to 3 queries. So in some cases + # we do things like `after_limit = num_after + 1` to grab the + # anchor row in the "after" query. + # + # Note that in some cases, if the anchor row isn't found, we + # actually may fetch an extra row at one of the extremes. + if need_both_sides: + before_anchor = anchor - 1 + after_anchor = max(anchor, first_visible_message_id) + before_limit = num_before + after_limit = num_after + 1 + elif need_before_query: + before_anchor = anchor + before_limit = num_before + if not anchored_to_right: + before_limit += 1 + elif need_after_query: + after_anchor = max(anchor, first_visible_message_id) + after_limit = num_after + 1 + + if need_before_query: + before_query = query + + if not anchored_to_right: + before_query = before_query.where(id_col <= before_anchor) + + before_query = before_query.order_by(id_col.desc()) + before_query = before_query.limit(before_limit) + + if need_after_query: + after_query = query + + if not anchored_to_left: + after_query = after_query.where(id_col >= after_anchor) + + after_query = after_query.order_by(id_col.asc()) + after_query = after_query.limit(after_limit) + + if need_both_sides: + return union_all(before_query.self_group(), after_query.self_group()) + elif need_before_query: + return before_query + elif need_after_query: + return after_query + else: + # If we don't have either a before_query or after_query, it's because + # some combination of num_before/num_after/anchor are zero or + # use_first_unread_anchor logic found no unread messages. + # + # The most likely reason is somebody is doing an id search, so searching + # for something like `message_id = 42` is exactly what we want. In other + # cases, which could possibly be buggy API clients, at least we will + # return at most one row here. + return query.where(id_col == anchor) + + +def post_process_limited_query( + rows: Sequence[Union[Row, Sequence[Any]]], + num_before: int, + num_after: int, + anchor: int, + anchored_to_left: bool, + anchored_to_right: bool, + first_visible_message_id: int, +) -> Dict[str, Any]: + # Our queries may have fetched extra rows if they added + # "headroom" to the limits, but we want to truncate those + # rows. + # + # Also, in cases where we had non-zero values of num_before or + # num_after, we want to know found_oldest and found_newest, so + # that the clients will know that they got complete results. + + if first_visible_message_id > 0: + visible_rows: Sequence[Union[Row, Sequence[Any]]] = [ + r for r in rows if r[0] >= first_visible_message_id + ] + else: + visible_rows = rows + + rows_limited = len(visible_rows) != len(rows) + + if anchored_to_right: + num_after = 0 + before_rows = visible_rows[:] + anchor_rows = [] + after_rows = [] + else: + before_rows = [r for r in visible_rows if r[0] < anchor] + anchor_rows = [r for r in visible_rows if r[0] == anchor] + after_rows = [r for r in visible_rows if r[0] > anchor] + + if num_before: + before_rows = before_rows[-1 * num_before :] + + if num_after: + after_rows = after_rows[:num_after] + + visible_rows = [*before_rows, *anchor_rows, *after_rows] + + found_anchor = len(anchor_rows) == 1 + found_oldest = anchored_to_left or (len(before_rows) < num_before) + found_newest = anchored_to_right or (len(after_rows) < num_after) + # BUG: history_limited is incorrect False in the event that we had + # to bump `anchor` up due to first_visible_message_id, and there + # were actually older messages. This may be a rare event in the + # context where history_limited is relevant, because it can only + # happen in one-sided queries with no num_before (see tests tagged + # BUG in PostProcessTest for examples), and we don't generally do + # those from the UI, so this might be OK for now. + # + # The correct fix for this probably involves e.g. making a + # `before_query` when we increase `anchor` just to confirm whether + # messages were hidden. + history_limited = rows_limited and found_oldest + + return dict( + rows=visible_rows, + found_anchor=found_anchor, + found_newest=found_newest, + found_oldest=found_oldest, + history_limited=history_limited, + ) diff --git a/zerver/tests/test_message_fetch.py b/zerver/tests/test_message_fetch.py index 7281a40223..60c8146a7b 100644 --- a/zerver/tests/test_message_fetch.py +++ b/zerver/tests/test_message_fetch.py @@ -35,6 +35,7 @@ from zerver.lib.narrow import ( find_first_unread_anchor, is_spectator_compatible, ok_to_include_history, + post_process_limited_query, ) from zerver.lib.sqlalchemy_utils import get_sqlalchemy_connection from zerver.lib.streams import StreamDict, create_streams_if_needed, get_public_streams_queryset @@ -56,7 +57,7 @@ from zerver.models import ( get_realm, get_stream, ) -from zerver.views.message_fetch import get_messages_backend, post_process_limited_query +from zerver.views.message_fetch import get_messages_backend if TYPE_CHECKING: from django.test.client import _MonkeyPatchedWSGIResponse as TestHttpResponse diff --git a/zerver/views/message_fetch.py b/zerver/views/message_fetch.py index 35abfa69da..887c488598 100644 --- a/zerver/views/message_fetch.py +++ b/zerver/views/message_fetch.py @@ -1,22 +1,10 @@ -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union from django.contrib.auth.models import AnonymousUser from django.http import HttpRequest, HttpResponse from django.utils.html import escape as escape_html from django.utils.translation import gettext as _ -from sqlalchemy.engine import Row -from sqlalchemy.sql import ( - ColumnElement, - Select, - and_, - column, - join, - literal, - literal_column, - select, - table, - union_all, -) +from sqlalchemy.sql import and_, column, join, literal, literal_column, select, table from sqlalchemy.sql.selectable import SelectBase from sqlalchemy.types import Integer, Text @@ -32,9 +20,11 @@ from zerver.lib.narrow import ( get_base_query_for_search, is_spectator_compatible, is_web_public_narrow, + limit_query_to_range, narrow_parameter, ok_to_include_history, parse_anchor_value, + post_process_limited_query, ) from zerver.lib.request import REQ, RequestNotes, has_request_variables from zerver.lib.response import json_success @@ -327,156 +317,6 @@ def get_messages_backend( return json_success(request, data=ret) -def limit_query_to_range( - query: Select, - num_before: int, - num_after: int, - anchor: int, - anchored_to_left: bool, - anchored_to_right: bool, - id_col: ColumnElement[Integer], - first_visible_message_id: int, -) -> SelectBase: - """ - This code is actually generic enough that we could move it to a - library, but our only caller for now is message search. - """ - need_before_query = (not anchored_to_left) and (num_before > 0) - need_after_query = (not anchored_to_right) and (num_after > 0) - - need_both_sides = need_before_query and need_after_query - - # The semantics of our flags are as follows: - # - # num_after = number of rows < anchor - # num_after = number of rows > anchor - # - # But we also want the row where id == anchor (if it exists), - # and we don't want to union up to 3 queries. So in some cases - # we do things like `after_limit = num_after + 1` to grab the - # anchor row in the "after" query. - # - # Note that in some cases, if the anchor row isn't found, we - # actually may fetch an extra row at one of the extremes. - if need_both_sides: - before_anchor = anchor - 1 - after_anchor = max(anchor, first_visible_message_id) - before_limit = num_before - after_limit = num_after + 1 - elif need_before_query: - before_anchor = anchor - before_limit = num_before - if not anchored_to_right: - before_limit += 1 - elif need_after_query: - after_anchor = max(anchor, first_visible_message_id) - after_limit = num_after + 1 - - if need_before_query: - before_query = query - - if not anchored_to_right: - before_query = before_query.where(id_col <= before_anchor) - - before_query = before_query.order_by(id_col.desc()) - before_query = before_query.limit(before_limit) - - if need_after_query: - after_query = query - - if not anchored_to_left: - after_query = after_query.where(id_col >= after_anchor) - - after_query = after_query.order_by(id_col.asc()) - after_query = after_query.limit(after_limit) - - if need_both_sides: - return union_all(before_query.self_group(), after_query.self_group()) - elif need_before_query: - return before_query - elif need_after_query: - return after_query - else: - # If we don't have either a before_query or after_query, it's because - # some combination of num_before/num_after/anchor are zero or - # use_first_unread_anchor logic found no unread messages. - # - # The most likely reason is somebody is doing an id search, so searching - # for something like `message_id = 42` is exactly what we want. In other - # cases, which could possibly be buggy API clients, at least we will - # return at most one row here. - return query.where(id_col == anchor) - - -def post_process_limited_query( - rows: Sequence[Union[Row, Sequence[Any]]], - num_before: int, - num_after: int, - anchor: int, - anchored_to_left: bool, - anchored_to_right: bool, - first_visible_message_id: int, -) -> Dict[str, Any]: - # Our queries may have fetched extra rows if they added - # "headroom" to the limits, but we want to truncate those - # rows. - # - # Also, in cases where we had non-zero values of num_before or - # num_after, we want to know found_oldest and found_newest, so - # that the clients will know that they got complete results. - - if first_visible_message_id > 0: - visible_rows: Sequence[Union[Row, Sequence[Any]]] = [ - r for r in rows if r[0] >= first_visible_message_id - ] - else: - visible_rows = rows - - rows_limited = len(visible_rows) != len(rows) - - if anchored_to_right: - num_after = 0 - before_rows = visible_rows[:] - anchor_rows = [] - after_rows = [] - else: - before_rows = [r for r in visible_rows if r[0] < anchor] - anchor_rows = [r for r in visible_rows if r[0] == anchor] - after_rows = [r for r in visible_rows if r[0] > anchor] - - if num_before: - before_rows = before_rows[-1 * num_before :] - - if num_after: - after_rows = after_rows[:num_after] - - visible_rows = [*before_rows, *anchor_rows, *after_rows] - - found_anchor = len(anchor_rows) == 1 - found_oldest = anchored_to_left or (len(before_rows) < num_before) - found_newest = anchored_to_right or (len(after_rows) < num_after) - # BUG: history_limited is incorrect False in the event that we had - # to bump `anchor` up due to first_visible_message_id, and there - # were actually older messages. This may be a rare event in the - # context where history_limited is relevant, because it can only - # happen in one-sided queries with no num_before (see tests tagged - # BUG in PostProcessTest for examples), and we don't generally do - # those from the UI, so this might be OK for now. - # - # The correct fix for this probably involves e.g. making a - # `before_query` when we increase `anchor` just to confirm whether - # messages were hidden. - history_limited = rows_limited and found_oldest - - return dict( - rows=visible_rows, - found_anchor=found_anchor, - found_newest=found_newest, - found_oldest=found_oldest, - history_limited=history_limited, - ) - - @has_request_variables def messages_in_narrow_backend( request: HttpRequest,