mypy: Use sqlalchemy-stubs.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2020-11-16 13:52:27 -08:00 committed by Anders Kaseorg
parent 8e0240300a
commit 13e35bfa94
9 changed files with 132 additions and 114 deletions

View File

@ -680,7 +680,7 @@ mypy==0.790 \
--hash=sha256:e86bdace26c5fe9cf8cb735e7cedfe7850ad92b327ac5d797c656717d2ca66de \
--hash=sha256:e97e9c13d67fbe524be17e4d8025d51a7dca38f90de2e462243ab8ed8a9178d1 \
--hash=sha256:eea260feb1830a627fb526d22fbb426b750d9f5a47b624e8d5e7e004359b219c \
# via -r requirements/mypy.in
# via -r requirements/mypy.in, sqlalchemy-stubs
networkx==2.5 \
--hash=sha256:7978955423fbc9639c10498878be59caf99b44dc304c2286162fd24b458c1602 \
--hash=sha256:8c5812e9f798d37c50570d15c4a69d5710a18d77bafc903ee9c5fba7454c616c \
@ -1143,6 +1143,10 @@ sphinxcontrib-serializinghtml==1.1.4 \
--hash=sha256:eaa0eccc86e982a9b939b2b82d12cc5d013385ba5eadcc7e4fed23f4405f77bc \
--hash=sha256:f242a81d423f59617a8e5cf16f5d4d74e28ee9a66f9e5b637a18082991db5a9a \
# via sphinx
sqlalchemy-stubs==0.3 \
--hash=sha256:a3318c810697164e8c818aa2d90bac570c1a0e752ced3ec25455b309c0bee8fd \
--hash=sha256:ca1250605a39648cc433f5c70cb1a6f9fe0b60bdda4c51e1f9a2ab3651daadc8 \
# via -r requirements/mypy.in
sqlalchemy==1.3.20 \
--hash=sha256:009e8388d4d551a2107632921320886650b46332f61dc935e70c8bcf37d8e0d6 \
--hash=sha256:0157c269701d88f5faf1fa0e4560e4d814f210c01a5b55df3cab95e9346a8bcc \
@ -1296,7 +1300,7 @@ typing-extensions==3.7.4.3 \
--hash=sha256:7cb407020f00f7bfc3cb3e7881628838e69d8f3fcab2f64742a5e76b2f841918 \
--hash=sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c \
--hash=sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f \
# via -r requirements/common.in, libcst, mypy, pyre-extensions, typing-inspect, zulint
# via -r requirements/common.in, libcst, mypy, pyre-extensions, sqlalchemy-stubs, typing-inspect, zulint
typing-inspect==0.6.0 \
--hash=sha256:3b98390df4d999a28cf5b35d8b333425af5da2ece8a4ea9e98f71e7591347b4f \
--hash=sha256:8f1b1dd25908dbfd81d3bebc218011531e7ab614ba6e5bf7826d887c834afab7 \

View File

@ -3,3 +3,4 @@
# and requirements/mypy.txt.
# See requirements/README.md for more detail.
mypy
sqlalchemy-stubs

View File

@ -26,6 +26,10 @@ mypy==0.790 \
--hash=sha256:e86bdace26c5fe9cf8cb735e7cedfe7850ad92b327ac5d797c656717d2ca66de \
--hash=sha256:e97e9c13d67fbe524be17e4d8025d51a7dca38f90de2e462243ab8ed8a9178d1 \
--hash=sha256:eea260feb1830a627fb526d22fbb426b750d9f5a47b624e8d5e7e004359b219c \
# via -r requirements/mypy.in, sqlalchemy-stubs
sqlalchemy-stubs==0.3 \
--hash=sha256:a3318c810697164e8c818aa2d90bac570c1a0e752ced3ec25455b309c0bee8fd \
--hash=sha256:ca1250605a39648cc433f5c70cb1a6f9fe0b60bdda4c51e1f9a2ab3651daadc8 \
# via -r requirements/mypy.in
typed-ast==1.4.1 \
--hash=sha256:0666aa36131496aed8f7be0410ff974562ab7eeac11ef351def9ea6fa28f6355 \
@ -63,4 +67,4 @@ typing-extensions==3.7.4.3 \
--hash=sha256:7cb407020f00f7bfc3cb3e7881628838e69d8f3fcab2f64742a5e76b2f841918 \
--hash=sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c \
--hash=sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f \
# via mypy
# via mypy, sqlalchemy-stubs

View File

@ -43,4 +43,4 @@ API_FEATURE_LEVEL = 35
# historical commits sharing the same major version, in which case a
# minor version bump suffices.
PROVISION_VERSION = '115.0'
PROVISION_VERSION = '115.1'

View File

@ -16,13 +16,15 @@ class NonClosingPool(sqlalchemy.pool.NullPool):
pass
def recreate(self) -> 'NonClosingPool':
return self.__class__(creator=self._creator,
recycle=self._recycle,
use_threadlocal=self._use_threadlocal,
reset_on_return=self._reset_on_return,
echo=self.echo,
logging_name=self._orig_logging_name,
_dispatch=self.dispatch)
return self.__class__(
creator=self._creator, # type: ignore[attr-defined] # implementation detail
recycle=self._recycle, # type: ignore[attr-defined] # implementation detail
use_threadlocal=self._use_threadlocal, # type: ignore[attr-defined] # implementation detail
reset_on_return=self._reset_on_return, # type: ignore[attr-defined] # implementation detail
echo=self.echo,
logging_name=self._orig_logging_name, # type: ignore[attr-defined] # implementation detail
_dispatch=self.dispatch, # type: ignore[attr-defined] # implementation detail
)
sqlalchemy_engine: Optional[Any] = None
def get_sqlalchemy_connection() -> sqlalchemy.engine.base.Connection:

View File

@ -2,7 +2,8 @@ from typing import Any, Dict, List, Optional, Tuple
from django.db import connection
from django.db.models.query import Q, QuerySet
from sqlalchemy.sql import column, func, literal
from sqlalchemy import Text
from sqlalchemy.sql import ColumnElement, column, func, literal
from zerver.lib.request import REQ
from zerver.models import Message, Stream, UserMessage, UserProfile
@ -65,14 +66,14 @@ using "subject" in the DB sense, and nothing customer facing.
DB_TOPIC_NAME = "subject"
MESSAGE__TOPIC = 'message__subject'
def topic_match_sa(topic_name: str) -> Any:
def topic_match_sa(topic_name: str) -> "ColumnElement[bool]":
# _sa is short for SQLAlchemy, which we use mostly for
# queries that search messages
topic_cond = func.upper(column("subject")) == func.upper(literal(topic_name))
topic_cond = func.upper(column("subject", Text)) == func.upper(literal(topic_name))
return topic_cond
def topic_column_sa() -> Any:
return column("subject")
def topic_column_sa() -> "ColumnElement[str]":
return column("subject", Text)
def filter_by_exact_message_topic(query: QuerySet, message: Message) -> QuerySet:
topic_name = message.topic_name()

View File

@ -2,7 +2,7 @@ import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple
from django.utils.timezone import now as timezone_now
from sqlalchemy.sql import Selectable, and_, column, not_, or_
from sqlalchemy.sql import ClauseElement, and_, column, not_, or_
from zerver.lib.timestamp import datetime_to_timestamp
from zerver.lib.topic import topic_match_sa
@ -74,9 +74,9 @@ def topic_is_muted(user_profile: UserProfile, stream_id: int, topic_name: str) -
).exists()
return is_muted
def exclude_topic_mutes(conditions: List[Selectable],
def exclude_topic_mutes(conditions: List[ClauseElement],
user_profile: UserProfile,
stream_id: Optional[int]) -> List[Selectable]:
stream_id: Optional[int]) -> List[ClauseElement]:
query = MutedTopic.objects.filter(
user_profile=user_profile,
)
@ -95,7 +95,7 @@ def exclude_topic_mutes(conditions: List[Selectable],
if not rows:
return conditions
def mute_cond(row: Dict[str, Any]) -> Selectable:
def mute_cond(row: Dict[str, Any]) -> ClauseElement:
recipient_id = row['recipient_id']
topic_name = row['topic_name']
stream_cond = column("recipient_id") == recipient_id

View File

@ -8,7 +8,7 @@ from django.db import connection
from django.http import HttpResponse
from django.test import override_settings
from django.utils.timezone import now as timezone_now
from sqlalchemy.sql import and_, column, select, table
from sqlalchemy.sql import Select, and_, column, select, table
from sqlalchemy.sql.elements import ClauseElement
from analytics.lib.counts import COUNT_STATS
@ -54,7 +54,6 @@ from zerver.views.message_fetch import (
LARGER_THAN_MAX_MESSAGE_ID,
BadNarrowOperator,
NarrowBuilder,
Query,
exclude_muting_conditions,
find_first_unread_anchor,
get_messages_backend,
@ -454,7 +453,7 @@ class NarrowBuilderTest(ZulipTestCase):
term = dict(operator='stream', operand='non-web-public-stream')
builder = NarrowBuilder(self.user_profile, column('id'), self.realm, True)
def _build_query(term: Dict[str, Any]) -> Query:
def _build_query(term: Dict[str, Any]) -> Select:
return builder.add_term(self.raw_query, term)
self.assertRaises(BadNarrowOperator, _build_query, term)
@ -467,7 +466,7 @@ class NarrowBuilderTest(ZulipTestCase):
self.assertEqual(actual_params, params)
self.assertIn(where_clause, get_sqlalchemy_sql(query))
def _build_query(self, term: Dict[str, Any]) -> Query:
def _build_query(self, term: Dict[str, Any]) -> Select:
return self.builder.add_term(self.raw_query, term)
class NarrowLibraryTest(ZulipTestCase):

View File

@ -1,5 +1,5 @@
import re
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import orjson
from django.conf import settings
@ -11,9 +11,12 @@ from django.utils.html import escape as escape_html
from django.utils.translation import ugettext as _
from sqlalchemy import func
from sqlalchemy.dialects import postgresql
from sqlalchemy.engine import Connection, RowProxy
from sqlalchemy.sql import (
ClauseElement,
ColumnElement,
Selectable,
FromClause,
Select,
alias,
and_,
column,
@ -26,6 +29,7 @@ from sqlalchemy.sql import (
table,
union_all,
)
from sqlalchemy.types import Boolean, Integer, Text
from zerver.context_processors import get_valid_realm_from_request
from zerver.decorator import REQ, has_request_variables
@ -84,11 +88,7 @@ class BadNarrowOperator(JsonableError):
def msg_format() -> str:
return _('Invalid narrow operator: {desc}')
# TODO: Should be Select, but sqlalchemy stubs are busted
Query = Any
# TODO: should be Callable[[ColumnElement], ColumnElement], but sqlalchemy stubs are busted
ConditionTransform = Any
ConditionTransform = Callable[[ClauseElement], ClauseElement]
OptionalNarrowListT = Optional[List[Dict[str, Any]]]
@ -97,21 +97,24 @@ TS_START = "<ts-match>"
TS_STOP = "</ts-match>"
def ts_locs_array(
config: ColumnElement, text: ColumnElement, tsquery: ColumnElement,
) -> ColumnElement:
config: "ColumnElement[str]", text: "ColumnElement[str]", tsquery: "ColumnElement[object]",
) -> "ColumnElement[List[List[int]]]":
options = f"HighlightAll = TRUE, StartSel = {TS_START}, StopSel = {TS_STOP}"
delimited = func.ts_headline(config, text, tsquery, options)
parts = func.unnest(func.string_to_array(delimited, TS_START)).alias()
part = column(parts.name)
part = column(parts.name, Text)
part_len = func.length(part) - len(TS_STOP)
match_pos = func.sum(part_len).over(rows=(None, -1)) + len(TS_STOP)
match_len = func.strpos(part, TS_STOP) - 1
return func.array(
select([postgresql.array([match_pos, match_len])])
ret = func.array(
select([
postgresql.array([match_pos, match_len]), # type: ignore[call-overload] # https://github.com/dropbox/sqlalchemy-stubs/issues/188
])
.select_from(parts)
.offset(1)
.as_scalar(),
)
return ret
# When you add a new operator to this, also update zerver/lib/narrow.py
class NarrowBuilder:
@ -124,8 +127,8 @@ class NarrowBuilder:
# None of these methods ever *add* messages to a query's result.
#
# That is, the `add_term` method, and its helpers the `by_*` methods,
# are passed a Query object representing a query for messages; they may
# call some methods on it, and then they return a resulting Query
# are passed a Select object representing a query for messages; they may
# call some methods on it, and then they return a resulting Select
# object. Things these methods may do to the queries they handle
# include
# * add conditions to filter out rows (i.e., messages), with `query.where`
@ -136,14 +139,14 @@ class NarrowBuilder:
# * anything that would pull in additional rows, or information on
# other messages.
def __init__(self, user_profile: Optional[UserProfile], msg_id_column: str,
def __init__(self, user_profile: Optional[UserProfile], msg_id_column: "ColumnElement[int]",
realm: Realm, is_web_public_query: bool=False) -> None:
self.user_profile = user_profile
self.msg_id_column = msg_id_column
self.realm = realm
self.is_web_public_query = is_web_public_query
def add_term(self, query: Query, term: Dict[str, Any]) -> Query:
def add_term(self, query: Select, term: Dict[str, Any]) -> Select:
"""
Extend the given query to one narrowed by the given term, and return the result.
@ -176,14 +179,14 @@ class NarrowBuilder:
return method(query, operand, maybe_negate)
def by_has(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query:
def by_has(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
if operand not in ['attachment', 'image', 'link']:
raise BadNarrowOperator("unknown 'has' operand " + operand)
col_name = 'has_' + operand
cond = column(col_name)
cond = column(col_name, Boolean)
return query.where(maybe_negate(cond))
def by_in(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query:
def by_in(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
# This operator does not support is_web_public_query.
assert not self.is_web_public_query
assert self.user_profile is not None
@ -196,27 +199,27 @@ class NarrowBuilder:
raise BadNarrowOperator("unknown 'in' operand " + operand)
def by_is(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query:
def by_is(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
# This operator class does not support is_web_public_query.
assert not self.is_web_public_query
assert self.user_profile is not None
if operand == 'private':
cond = column("flags").op("&")(UserMessage.flags.is_private.mask) != 0
cond = column("flags", Integer).op("&")(UserMessage.flags.is_private.mask) != 0
return query.where(maybe_negate(cond))
elif operand == 'starred':
cond = column("flags").op("&")(UserMessage.flags.starred.mask) != 0
cond = column("flags", Integer).op("&")(UserMessage.flags.starred.mask) != 0
return query.where(maybe_negate(cond))
elif operand == 'unread':
cond = column("flags").op("&")(UserMessage.flags.read.mask) == 0
cond = column("flags", Integer).op("&")(UserMessage.flags.read.mask) == 0
return query.where(maybe_negate(cond))
elif operand == 'mentioned':
cond1 = column("flags").op("&")(UserMessage.flags.mentioned.mask) != 0
cond2 = column("flags").op("&")(UserMessage.flags.wildcard_mentioned.mask) != 0
cond1 = column("flags", Integer).op("&")(UserMessage.flags.mentioned.mask) != 0
cond2 = column("flags", Integer).op("&")(UserMessage.flags.wildcard_mentioned.mask) != 0
cond = or_(cond1, cond2)
return query.where(maybe_negate(cond))
elif operand == 'alerted':
cond = column("flags").op("&")(UserMessage.flags.has_alert_word.mask) != 0
cond = column("flags", Integer).op("&")(UserMessage.flags.has_alert_word.mask) != 0
return query.where(maybe_negate(cond))
raise BadNarrowOperator("unknown 'is' operand " + operand)
@ -242,7 +245,7 @@ class NarrowBuilder:
s[i] = '\\' + c
return ''.join(s)
def by_stream(self, query: Query, operand: Union[str, int], maybe_negate: ConditionTransform) -> Query:
def by_stream(self, query: Select, operand: Union[str, int], maybe_negate: ConditionTransform) -> Select:
try:
# Because you can see your own message history for
# private streams you are no longer subscribed to, we
@ -274,14 +277,14 @@ class NarrowBuilder:
matching_streams = get_active_streams(self.realm).filter(
name__iregex=fr'^(un)*{self._pg_re_escape(base_stream_name)}(\.d)*$')
recipient_ids = [matching_stream.recipient_id for matching_stream in matching_streams]
cond = column("recipient_id").in_(recipient_ids)
cond = column("recipient_id", Integer).in_(recipient_ids)
return query.where(maybe_negate(cond))
recipient = stream.recipient
cond = column("recipient_id") == recipient.id
cond = column("recipient_id", Integer) == recipient.id
return query.where(maybe_negate(cond))
def by_streams(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query:
def by_streams(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
if operand == 'public':
# Get all both subscribed and non subscribed public streams
# but exclude any private subscribed streams.
@ -292,10 +295,10 @@ class NarrowBuilder:
raise BadNarrowOperator('unknown streams operand ' + operand)
recipient_ids = recipient_queryset.values_list("recipient_id", flat=True).order_by('id')
cond = column("recipient_id").in_(recipient_ids)
cond = column("recipient_id", Integer).in_(recipient_ids)
return query.where(maybe_negate(cond))
def by_topic(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query:
def by_topic(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
if self.realm.is_zephyr_mirror_realm:
# MIT users expect narrowing to topic "foo" to also show messages to /^foo(.d)*$/
# (foo, foo.d, foo.d.d, etc)
@ -307,7 +310,7 @@ class NarrowBuilder:
# Additionally, MIT users expect the empty instance and
# instance "personal" to be the same.
if base_topic in ('', 'personal', '(instance "")'):
cond = or_(
cond: ClauseElement = or_(
topic_match_sa(""),
topic_match_sa(".d"),
topic_match_sa(".d.d"),
@ -340,7 +343,7 @@ class NarrowBuilder:
cond = topic_match_sa(operand)
return query.where(maybe_negate(cond))
def by_sender(self, query: Query, operand: Union[str, int], maybe_negate: ConditionTransform) -> Query:
def by_sender(self, query: Select, operand: Union[str, int], maybe_negate: ConditionTransform) -> Select:
try:
if isinstance(operand, str):
sender = get_user_including_cross_realm(operand, self.realm)
@ -349,20 +352,20 @@ class NarrowBuilder:
except UserProfile.DoesNotExist:
raise BadNarrowOperator('unknown user ' + str(operand))
cond = column("sender_id") == literal(sender.id)
cond = column("sender_id", Integer) == literal(sender.id)
return query.where(maybe_negate(cond))
def by_near(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query:
def by_near(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
return query
def by_id(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query:
def by_id(self, query: Select, operand: Union[int, str], maybe_negate: ConditionTransform) -> Select:
if not str(operand).isdigit():
raise BadNarrowOperator("Invalid message ID")
cond = self.msg_id_column == literal(operand)
return query.where(maybe_negate(cond))
def by_pm_with(self, query: Query, operand: Union[str, Iterable[int]],
maybe_negate: ConditionTransform) -> Query:
def by_pm_with(self, query: Select, operand: Union[str, Iterable[int]],
maybe_negate: ConditionTransform) -> Select:
# This operator does not support is_web_public_query.
assert not self.is_web_public_query
assert self.user_profile is not None
@ -394,7 +397,7 @@ class NarrowBuilder:
# Group DM
if recipient.type == Recipient.HUDDLE:
cond = column("recipient_id") == recipient.id
cond = column("recipient_id", Integer) == recipient.id
return query.where(maybe_negate(cond))
# 1:1 PM
@ -414,19 +417,19 @@ class NarrowBuilder:
# complex query to get messages between these two users
# with either of them as the sender.
self_recipient_id = self.user_profile.recipient_id
cond = or_(and_(column("sender_id") == other_participant.id,
column("recipient_id") == self_recipient_id),
and_(column("sender_id") == self.user_profile.id,
column("recipient_id") == recipient.id))
cond = or_(and_(column("sender_id", Integer) == other_participant.id,
column("recipient_id", Integer) == self_recipient_id),
and_(column("sender_id", Integer) == self.user_profile.id,
column("recipient_id", Integer) == recipient.id))
return query.where(maybe_negate(cond))
# PM with self
cond = and_(column("sender_id") == self.user_profile.id,
column("recipient_id") == recipient.id)
cond = and_(column("sender_id", Integer) == self.user_profile.id,
column("recipient_id", Integer) == recipient.id)
return query.where(maybe_negate(cond))
def by_group_pm_with(self, query: Query, operand: Union[str, int],
maybe_negate: ConditionTransform) -> Query:
def by_group_pm_with(self, query: Select, operand: Union[str, int],
maybe_negate: ConditionTransform) -> Select:
# This operator does not support is_web_public_query.
assert not self.is_web_public_query
assert self.user_profile is not None
@ -453,33 +456,33 @@ class NarrowBuilder:
).values("recipient_id")]
recipient_ids = set(self_recipient_ids) & set(narrow_recipient_ids)
cond = column("recipient_id").in_(recipient_ids)
cond = column("recipient_id", Integer).in_(recipient_ids)
return query.where(maybe_negate(cond))
def by_search(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query:
def by_search(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
if settings.USING_PGROONGA:
return self._by_search_pgroonga(query, operand, maybe_negate)
else:
return self._by_search_tsearch(query, operand, maybe_negate)
def _by_search_pgroonga(self, query: Query, operand: str,
maybe_negate: ConditionTransform) -> Query:
def _by_search_pgroonga(self, query: Select, operand: str,
maybe_negate: ConditionTransform) -> Select:
match_positions_character = func.pgroonga_match_positions_character
query_extract_keywords = func.pgroonga_query_extract_keywords
operand_escaped = func.escape_html(operand)
keywords = query_extract_keywords(operand_escaped)
query = query.column(match_positions_character(column("rendered_content"),
query = query.column(match_positions_character(column("rendered_content", Text),
keywords).label("content_matches"))
query = query.column(match_positions_character(func.escape_html(topic_column_sa()),
keywords).label("topic_matches"))
condition = column("search_pgroonga").op("&@~")(operand_escaped)
return query.where(maybe_negate(condition))
def _by_search_tsearch(self, query: Query, operand: str,
maybe_negate: ConditionTransform) -> Query:
def _by_search_tsearch(self, query: Select, operand: str,
maybe_negate: ConditionTransform) -> Select:
tsquery = func.plainto_tsquery(literal("zulip.english_us_search"), literal(operand))
query = query.column(ts_locs_array(literal("zulip.english_us_search"),
column("rendered_content"),
column("rendered_content", Text),
tsquery).label("content_matches"))
# We HTML-escape the topic in PostgreSQL to avoid doing a server round-trip
query = query.column(ts_locs_array(literal("zulip.english_us_search"),
@ -494,11 +497,11 @@ class NarrowBuilder:
if term[0] == '"' and term[-1] == '"':
term = term[1:-1]
term = '%' + connection.ops.prep_for_like_query(term) + '%'
cond = or_(column("content").ilike(term),
cond = or_(column("content", Text).ilike(term),
topic_column_sa().ilike(term))
query = query.where(maybe_negate(cond))
cond = column("search_tsvector").op("@@")(tsquery)
cond = column("search_tsvector", postgresql.TSVECTOR).op("@@")(tsquery)
return query.where(maybe_negate(cond))
def highlight_string(text: str, locs: Iterable[Tuple[int, int]]) -> str:
@ -658,7 +661,7 @@ def get_stream_from_narrow_access_unchecked(narrow: OptionalNarrowListT, realm:
return None
def exclude_muting_conditions(user_profile: UserProfile,
narrow: OptionalNarrowListT) -> List[Selectable]:
narrow: OptionalNarrowListT) -> List[ClauseElement]:
conditions = []
stream_id = None
try:
@ -681,7 +684,7 @@ def exclude_muting_conditions(user_profile: UserProfile,
muted_recipient_ids = [row['recipient_id'] for row in rows]
if len(muted_recipient_ids) > 0:
# Only add the condition if we have muted streams to simplify/avoid warnings.
condition = not_(column("recipient_id").in_(muted_recipient_ids))
condition = not_(column("recipient_id", Integer).in_(muted_recipient_ids))
conditions.append(condition)
conditions = exclude_topic_mutes(conditions, user_profile, stream_id)
@ -690,38 +693,39 @@ def exclude_muting_conditions(user_profile: UserProfile,
def get_base_query_for_search(user_profile: Optional[UserProfile],
need_message: bool,
need_user_message: bool) -> Tuple[Query, ColumnElement]:
need_user_message: bool) -> Tuple[Select, "ColumnElement[int]"]:
# Handle the simple case where user_message isn't involved first.
if not need_user_message:
assert(need_message)
query = select([column("id").label("message_id")],
query = select([column("id", Integer).label("message_id")],
None,
table("zerver_message"))
inner_msg_id_col = literal_column("zerver_message.id")
inner_msg_id_col: ColumnElement[int]
inner_msg_id_col = literal_column("zerver_message.id", Integer) # type: ignore[assignment] # https://github.com/dropbox/sqlalchemy-stubs/pull/189
return (query, inner_msg_id_col)
assert user_profile is not None
if need_message:
query = select([column("message_id"), column("flags")],
query = select([column("message_id"), column("flags", Integer)],
column("user_profile_id") == literal(user_profile.id),
join(table("zerver_usermessage"), table("zerver_message"),
literal_column("zerver_usermessage.message_id") ==
literal_column("zerver_message.id")))
inner_msg_id_col = column("message_id")
literal_column("zerver_usermessage.message_id", Integer) ==
literal_column("zerver_message.id", Integer)))
inner_msg_id_col = column("message_id", Integer)
return (query, inner_msg_id_col)
query = select([column("message_id"), column("flags")],
query = select([column("message_id"), column("flags", Integer)],
column("user_profile_id") == literal(user_profile.id),
table("zerver_usermessage"))
inner_msg_id_col = column("message_id")
inner_msg_id_col = column("message_id", Integer)
return (query, inner_msg_id_col)
def add_narrow_conditions(user_profile: Optional[UserProfile],
inner_msg_id_col: ColumnElement,
query: Query,
inner_msg_id_col: "ColumnElement[int]",
query: Select,
narrow: OptionalNarrowListT,
is_web_public_query: bool,
realm: Realm) -> Tuple[Query, bool]:
realm: Realm) -> Tuple[Select, bool]:
is_search = False # for now
if narrow is None:
@ -742,7 +746,7 @@ def add_narrow_conditions(user_profile: Optional[UserProfile],
if search_operands:
is_search = True
query = query.column(topic_column_sa()).column(column("rendered_content"))
query = query.column(topic_column_sa()).column(column("rendered_content", Text))
search_term = dict(
operator='search',
operand=' '.join(search_operands),
@ -751,7 +755,7 @@ def add_narrow_conditions(user_profile: Optional[UserProfile],
return (query, is_search)
def find_first_unread_anchor(sa_conn: Any,
def find_first_unread_anchor(sa_conn: Connection,
user_profile: Optional[UserProfile],
narrow: OptionalNarrowListT) -> int:
# For anonymous web users, all messages are treated as read, and so
@ -785,7 +789,7 @@ def find_first_unread_anchor(sa_conn: Any,
realm=user_profile.realm,
)
condition = column("flags").op("&")(UserMessage.flags.read.mask) == 0
condition = column("flags", Integer).op("&")(UserMessage.flags.read.mask) == 0
# We exclude messages on muted topics when finding the first unread
# message in this narrow
@ -913,6 +917,7 @@ def get_messages_backend(request: HttpRequest,
need_message = True
need_user_message = True
query: FromClause
query, inner_msg_id_col = get_base_query_for_search(
user_profile=user_profile,
need_message=need_message,
@ -970,7 +975,7 @@ def get_messages_backend(request: HttpRequest,
)
main_query = alias(query)
query = select(main_query.c, None, main_query).order_by(column("message_id").asc())
query = select(main_query.c, None, main_query).order_by(column("message_id", Integer).asc())
# This is a hack to tag the query we use for testing
query = query.prefix_with("/* get_messages */")
rows = list(sa_conn.execute(query).fetchall())
@ -1059,14 +1064,14 @@ def get_messages_backend(request: HttpRequest,
)
return json_success(ret)
def limit_query_to_range(query: Query,
def limit_query_to_range(query: Select,
num_before: int,
num_after: int,
anchor: int,
anchored_to_left: bool,
anchored_to_right: bool,
id_col: ColumnElement,
first_visible_message_id: int) -> Query:
id_col: "ColumnElement[int]",
first_visible_message_id: int) -> FromClause:
'''
This code is actually generic enough that we could move it to a
library, but our only caller for now is message search.
@ -1137,7 +1142,7 @@ def limit_query_to_range(query: Query,
# return at most one row here.
return query.where(id_col == anchor)
def post_process_limited_query(rows: List[Any],
def post_process_limited_query(rows: Sequence[Union[RowProxy, Sequence[Any]]],
num_before: int,
num_after: int,
anchor: int,
@ -1153,7 +1158,9 @@ def post_process_limited_query(rows: List[Any],
# that the clients will know that they got complete results.
if first_visible_message_id > 0:
visible_rows = [r for r in rows if r[0] >= first_visible_message_id]
visible_rows: Sequence[Union[RowProxy, Sequence[Any]]] = [
r for r in rows if r[0] >= first_visible_message_id
]
else:
visible_rows = rows
@ -1162,8 +1169,8 @@ def post_process_limited_query(rows: List[Any],
if anchored_to_right:
num_after = 0
before_rows = visible_rows[:]
anchor_rows: List[Any] = []
after_rows: List[Any] = []
anchor_rows = []
after_rows = []
else:
before_rows = [r for r in visible_rows if r[0] < anchor]
anchor_rows = [r for r in visible_rows if r[0] == anchor]
@ -1175,7 +1182,7 @@ def post_process_limited_query(rows: List[Any],
if num_after:
after_rows = after_rows[:num_after]
visible_rows = before_rows + anchor_rows + after_rows
visible_rows = [*before_rows, *anchor_rows, *after_rows]
found_anchor = len(anchor_rows) == 1
found_oldest = anchored_to_left or (len(before_rows) < num_before)
@ -1211,14 +1218,14 @@ def messages_in_narrow_backend(request: HttpRequest, user_profile: UserProfile,
msg_ids = [message_id for message_id in msg_ids if message_id >= first_visible_message_id]
# This query is limited to messages the user has access to because they
# actually received them, as reflected in `zerver_usermessage`.
query = select([column("message_id"), topic_column_sa(), column("rendered_content")],
and_(column("user_profile_id") == literal(user_profile.id),
column("message_id").in_(msg_ids)),
query = select([column("message_id", Integer), topic_column_sa(), column("rendered_content", Text)],
and_(column("user_profile_id", Integer) == literal(user_profile.id),
column("message_id", Integer).in_(msg_ids)),
join(table("zerver_usermessage"), table("zerver_message"),
literal_column("zerver_usermessage.message_id") ==
literal_column("zerver_message.id")))
literal_column("zerver_usermessage.message_id", Integer) ==
literal_column("zerver_message.id", Integer)))
builder = NarrowBuilder(user_profile, column("message_id"), user_profile.realm)
builder = NarrowBuilder(user_profile, column("message_id", Integer), user_profile.realm)
if narrow is not None:
for term in narrow:
query = builder.add_term(query, term)