message_fetch: Move limit_query_to_range to zerver.lib.narrow.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2022-10-21 13:29:15 -04:00 committed by Tim Abbott
parent 1095efeb52
commit 0a0a70b33d
3 changed files with 159 additions and 166 deletions

View File

@ -20,7 +20,7 @@ from django.core.exceptions import ValidationError
from django.db import connection from django.db import connection
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from sqlalchemy.engine import Connection from sqlalchemy.engine import Connection, Row
from sqlalchemy.sql import ( from sqlalchemy.sql import (
ClauseElement, ClauseElement,
ColumnElement, ColumnElement,
@ -35,7 +35,9 @@ from sqlalchemy.sql import (
or_, or_,
select, select,
table, table,
union_all,
) )
from sqlalchemy.sql.selectable import SelectBase
from sqlalchemy.types import ARRAY, Boolean, Integer, Text from sqlalchemy.types import ARRAY, Boolean, Integer, Text
from zerver.lib.addressee import get_user_profiles, get_user_profiles_by_ids from zerver.lib.addressee import get_user_profiles, get_user_profiles_by_ids
@ -980,3 +982,153 @@ def parse_anchor_value(anchor_val: Optional[str], use_first_unread_anchor: bool)
return anchor return anchor
except ValueError: except ValueError:
raise JsonableError(_("Invalid anchor")) raise JsonableError(_("Invalid anchor"))
def limit_query_to_range(
    query: Select,
    num_before: int,
    num_after: int,
    anchor: int,
    anchored_to_left: bool,
    anchored_to_right: bool,
    id_col: ColumnElement[Integer],
    first_visible_message_id: int,
) -> SelectBase:
    """
    Restrict `query` to a window of rows around `anchor`, keyed on `id_col`.

    This code is actually generic enough that we could move it to a
    library, but our only caller for now is message search.

    Args:
        query: The base select to narrow.
        num_before: Number of rows with id < anchor that the caller wants.
        num_after: Number of rows with id > anchor that the caller wants.
        anchor: The id the window is centered on.
        anchored_to_left: When true, no "before" query is built and the
            "after" query gets no lower bound on `id_col`.
        anchored_to_right: When true, no "after" query is built and the
            "before" query gets no upper bound on `id_col`.
        id_col: The column used for filtering and ordering.
        first_visible_message_id: Lower bound applied to the anchor of the
            "after" query (via `max`).

    Returns:
        The "before" query (ordered descending), the "after" query (ordered
        ascending), a UNION ALL of both, or, when neither side is needed,
        `query` filtered to `id_col == anchor`.
    """
    need_before_query = (not anchored_to_left) and (num_before > 0)
    need_after_query = (not anchored_to_right) and (num_after > 0)

    need_both_sides = need_before_query and need_after_query

    # The semantics of our flags are as follows:
    #
    # num_before = number of rows < anchor
    # num_after = number of rows > anchor
    #
    # But we also want the row where id == anchor (if it exists),
    # and we don't want to union up to 3 queries.  So in some cases
    # we do things like `after_limit = num_after + 1` to grab the
    # anchor row in the "after" query.
    #
    # Note that in some cases, if the anchor row isn't found, we
    # actually may fetch an extra row at one of the extremes.
    if need_both_sides:
        # The anchor row is fetched by the "after" query, so the
        # "before" query stops just short of it.
        before_anchor = anchor - 1
        after_anchor = max(anchor, first_visible_message_id)
        before_limit = num_before
        after_limit = num_after + 1
    elif need_before_query:
        before_anchor = anchor
        before_limit = num_before
        if not anchored_to_right:
            # Leave headroom for the anchor row itself.
            before_limit += 1
    elif need_after_query:
        after_anchor = max(anchor, first_visible_message_id)
        after_limit = num_after + 1

    if need_before_query:
        before_query = query

        if not anchored_to_right:
            before_query = before_query.where(id_col <= before_anchor)

        # Descending order so that LIMIT keeps the rows closest to the anchor.
        before_query = before_query.order_by(id_col.desc())
        before_query = before_query.limit(before_limit)

    if need_after_query:
        after_query = query

        if not anchored_to_left:
            after_query = after_query.where(id_col >= after_anchor)

        after_query = after_query.order_by(id_col.asc())
        after_query = after_query.limit(after_limit)

    if need_both_sides:
        # Each side carries its own ORDER BY/LIMIT, so both are grouped
        # before being combined.
        return union_all(before_query.self_group(), after_query.self_group())
    elif need_before_query:
        return before_query
    elif need_after_query:
        return after_query
    else:
        # If we don't have either a before_query or after_query, it's because
        # some combination of num_before/num_after/anchor are zero or
        # use_first_unread_anchor logic found no unread messages.
        #
        # The most likely reason is somebody is doing an id search, so searching
        # for something like `message_id = 42` is exactly what we want.  In other
        # cases, which could possibly be buggy API clients, at least we will
        # return at most one row here.
        return query.where(id_col == anchor)
def post_process_limited_query(
    rows: Sequence[Union[Row, Sequence[Any]]],
    num_before: int,
    num_after: int,
    anchor: int,
    anchored_to_left: bool,
    anchored_to_right: bool,
    first_visible_message_id: int,
) -> Dict[str, Any]:
    """
    Trim fetched rows to the requested window and report completeness flags.

    Args:
        rows: Fetched rows; element 0 of each row is compared against
            `anchor` and `first_visible_message_id`, i.e. it is treated as
            the row's id.  The slicing below presumably relies on the rows
            being in ascending id order -- confirm against the caller.
        num_before: Maximum number of rows with id < anchor to keep.
        num_after: Maximum number of rows with id > anchor to keep.
        anchor: The id the window is centered on.
        anchored_to_left: Whether the window was anchored at the oldest end.
        anchored_to_right: Whether the window was anchored at the newest end.
        first_visible_message_id: Rows with a smaller id are dropped when
            this value is positive.

    Returns:
        A dict with the keys `rows`, `found_anchor`, `found_newest`,
        `found_oldest` and `history_limited`.
    """
    # Our queries may have fetched extra rows if they added
    # "headroom" to the limits, but we want to truncate those
    # rows.
    #
    # Also, in cases where we had non-zero values of num_before or
    # num_after, we want to know found_oldest and found_newest, so
    # that the clients will know that they got complete results.

    if first_visible_message_id > 0:
        visible_rows: Sequence[Union[Row, Sequence[Any]]] = [
            r for r in rows if r[0] >= first_visible_message_id
        ]
    else:
        visible_rows = rows

    rows_limited = len(visible_rows) != len(rows)

    if anchored_to_right:
        # Everything fetched counts as "before"; there is no anchor row
        # and nothing after it.
        num_after = 0
        before_rows = visible_rows[:]
        anchor_rows = []
        after_rows = []
    else:
        before_rows = [r for r in visible_rows if r[0] < anchor]
        anchor_rows = [r for r in visible_rows if r[0] == anchor]
        after_rows = [r for r in visible_rows if r[0] > anchor]

    # A zero limit leaves the corresponding side untouched rather than
    # slicing it (a [-0:] slice would keep everything anyway).
    if num_before:
        before_rows = before_rows[-num_before:]

    if num_after:
        after_rows = after_rows[:num_after]

    visible_rows = [*before_rows, *anchor_rows, *after_rows]

    found_anchor = len(anchor_rows) == 1
    found_oldest = anchored_to_left or (len(before_rows) < num_before)
    found_newest = anchored_to_right or (len(after_rows) < num_after)
    # BUG: history_limited is incorrectly False in the event that we had
    # to bump `anchor` up due to first_visible_message_id, and there
    # were actually older messages.  This may be a rare event in the
    # context where history_limited is relevant, because it can only
    # happen in one-sided queries with no num_before (see tests tagged
    # BUG in PostProcessTest for examples), and we don't generally do
    # those from the UI, so this might be OK for now.
    #
    # The correct fix for this probably involves e.g. making a
    # `before_query` when we increase `anchor` just to confirm whether
    # messages were hidden.
    history_limited = rows_limited and found_oldest

    return dict(
        rows=visible_rows,
        found_anchor=found_anchor,
        found_newest=found_newest,
        found_oldest=found_oldest,
        history_limited=history_limited,
    )

View File

@ -35,6 +35,7 @@ from zerver.lib.narrow import (
find_first_unread_anchor, find_first_unread_anchor,
is_spectator_compatible, is_spectator_compatible,
ok_to_include_history, ok_to_include_history,
post_process_limited_query,
) )
from zerver.lib.sqlalchemy_utils import get_sqlalchemy_connection from zerver.lib.sqlalchemy_utils import get_sqlalchemy_connection
from zerver.lib.streams import StreamDict, create_streams_if_needed, get_public_streams_queryset from zerver.lib.streams import StreamDict, create_streams_if_needed, get_public_streams_queryset
@ -56,7 +57,7 @@ from zerver.models import (
get_realm, get_realm,
get_stream, get_stream,
) )
from zerver.views.message_fetch import get_messages_backend, post_process_limited_query from zerver.views.message_fetch import get_messages_backend
if TYPE_CHECKING: if TYPE_CHECKING:
from django.test.client import _MonkeyPatchedWSGIResponse as TestHttpResponse from django.test.client import _MonkeyPatchedWSGIResponse as TestHttpResponse

View File

@ -1,22 +1,10 @@
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union from typing import Dict, Iterable, List, Optional, Tuple, Union
from django.contrib.auth.models import AnonymousUser from django.contrib.auth.models import AnonymousUser
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.utils.html import escape as escape_html from django.utils.html import escape as escape_html
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from sqlalchemy.engine import Row from sqlalchemy.sql import and_, column, join, literal, literal_column, select, table
from sqlalchemy.sql import (
ColumnElement,
Select,
and_,
column,
join,
literal,
literal_column,
select,
table,
union_all,
)
from sqlalchemy.sql.selectable import SelectBase from sqlalchemy.sql.selectable import SelectBase
from sqlalchemy.types import Integer, Text from sqlalchemy.types import Integer, Text
@ -32,9 +20,11 @@ from zerver.lib.narrow import (
get_base_query_for_search, get_base_query_for_search,
is_spectator_compatible, is_spectator_compatible,
is_web_public_narrow, is_web_public_narrow,
limit_query_to_range,
narrow_parameter, narrow_parameter,
ok_to_include_history, ok_to_include_history,
parse_anchor_value, parse_anchor_value,
post_process_limited_query,
) )
from zerver.lib.request import REQ, RequestNotes, has_request_variables from zerver.lib.request import REQ, RequestNotes, has_request_variables
from zerver.lib.response import json_success from zerver.lib.response import json_success
@ -327,156 +317,6 @@ def get_messages_backend(
return json_success(request, data=ret) return json_success(request, data=ret)
def limit_query_to_range(
query: Select,
num_before: int,
num_after: int,
anchor: int,
anchored_to_left: bool,
anchored_to_right: bool,
id_col: ColumnElement[Integer],
first_visible_message_id: int,
) -> SelectBase:
"""
This code is actually generic enough that we could move it to a
library, but our only caller for now is message search.
"""
# Narrows `query` to a window of rows around `anchor`, filtering and
# ordering on `id_col`.  Returns the "before" query, the "after" query,
# a UNION ALL of both, or `query` filtered to `id_col == anchor`.
need_before_query = (not anchored_to_left) and (num_before > 0)
need_after_query = (not anchored_to_right) and (num_after > 0)
need_both_sides = need_before_query and need_after_query
# The semantics of our flags are as follows:
#
# num_before = number of rows < anchor
# num_after = number of rows > anchor
#
# But we also want the row where id == anchor (if it exists),
# and we don't want to union up to 3 queries. So in some cases
# we do things like `after_limit = num_after + 1` to grab the
# anchor row in the "after" query.
#
# Note that in some cases, if the anchor row isn't found, we
# actually may fetch an extra row at one of the extremes.
if need_both_sides:
# The anchor row is fetched by the "after" query, so the "before"
# query stops just short of it.
before_anchor = anchor - 1
after_anchor = max(anchor, first_visible_message_id)
before_limit = num_before
after_limit = num_after + 1
elif need_before_query:
before_anchor = anchor
before_limit = num_before
if not anchored_to_right:
# Leave headroom for the anchor row itself.
before_limit += 1
elif need_after_query:
after_anchor = max(anchor, first_visible_message_id)
after_limit = num_after + 1
if need_before_query:
before_query = query
if not anchored_to_right:
before_query = before_query.where(id_col <= before_anchor)
# Descending order so that LIMIT keeps the rows closest to the anchor.
before_query = before_query.order_by(id_col.desc())
before_query = before_query.limit(before_limit)
if need_after_query:
after_query = query
if not anchored_to_left:
after_query = after_query.where(id_col >= after_anchor)
after_query = after_query.order_by(id_col.asc())
after_query = after_query.limit(after_limit)
if need_both_sides:
# Each side carries its own ORDER BY/LIMIT, so both are grouped
# before being combined.
return union_all(before_query.self_group(), after_query.self_group())
elif need_before_query:
return before_query
elif need_after_query:
return after_query
else:
# If we don't have either a before_query or after_query, it's because
# some combination of num_before/num_after/anchor are zero or
# use_first_unread_anchor logic found no unread messages.
#
# The most likely reason is somebody is doing an id search, so searching
# for something like `message_id = 42` is exactly what we want. In other
# cases, which could possibly be buggy API clients, at least we will
# return at most one row here.
return query.where(id_col == anchor)
def post_process_limited_query(
rows: Sequence[Union[Row, Sequence[Any]]],
num_before: int,
num_after: int,
anchor: int,
anchored_to_left: bool,
anchored_to_right: bool,
first_visible_message_id: int,
) -> Dict[str, Any]:
# Trims the fetched rows to the requested window and returns a dict
# with the keys rows, found_anchor, found_newest, found_oldest and
# history_limited.
#
# Our queries may have fetched extra rows if they added
# "headroom" to the limits, but we want to truncate those
# rows.
#
# Also, in cases where we had non-zero values of num_before or
# num_after, we want to know found_oldest and found_newest, so
# that the clients will know that they got complete results.
#
# Element 0 of each row is compared against `anchor` and
# `first_visible_message_id`, i.e. it is treated as the row's id.
# NOTE(review): the slicing below presumably relies on rows being in
# ascending id order -- confirm against the caller.
if first_visible_message_id > 0:
visible_rows: Sequence[Union[Row, Sequence[Any]]] = [
r for r in rows if r[0] >= first_visible_message_id
]
else:
visible_rows = rows
rows_limited = len(visible_rows) != len(rows)
if anchored_to_right:
# Everything fetched counts as "before"; there is no anchor row and
# nothing after it.
num_after = 0
before_rows = visible_rows[:]
anchor_rows = []
after_rows = []
else:
before_rows = [r for r in visible_rows if r[0] < anchor]
anchor_rows = [r for r in visible_rows if r[0] == anchor]
after_rows = [r for r in visible_rows if r[0] > anchor]
# A zero limit leaves the corresponding side untouched rather than
# slicing it.
if num_before:
before_rows = before_rows[-1 * num_before :]
if num_after:
after_rows = after_rows[:num_after]
visible_rows = [*before_rows, *anchor_rows, *after_rows]
found_anchor = len(anchor_rows) == 1
found_oldest = anchored_to_left or (len(before_rows) < num_before)
found_newest = anchored_to_right or (len(after_rows) < num_after)
# BUG: history_limited is incorrectly False in the event that we had
# to bump `anchor` up due to first_visible_message_id, and there
# were actually older messages. This may be a rare event in the
# context where history_limited is relevant, because it can only
# happen in one-sided queries with no num_before (see tests tagged
# BUG in PostProcessTest for examples), and we don't generally do
# those from the UI, so this might be OK for now.
#
# The correct fix for this probably involves e.g. making a
# `before_query` when we increase `anchor` just to confirm whether
# messages were hidden.
history_limited = rows_limited and found_oldest
return dict(
rows=visible_rows,
found_anchor=found_anchor,
found_newest=found_newest,
found_oldest=found_oldest,
history_limited=history_limited,
)
@has_request_variables @has_request_variables
def messages_in_narrow_backend( def messages_in_narrow_backend(
request: HttpRequest, request: HttpRequest,