mypy: Use sqlalchemy-stubs.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2020-11-16 13:52:27 -08:00 committed by Anders Kaseorg
parent 8e0240300a
commit 13e35bfa94
9 changed files with 132 additions and 114 deletions

View File

@ -680,7 +680,7 @@ mypy==0.790 \
--hash=sha256:e86bdace26c5fe9cf8cb735e7cedfe7850ad92b327ac5d797c656717d2ca66de \ --hash=sha256:e86bdace26c5fe9cf8cb735e7cedfe7850ad92b327ac5d797c656717d2ca66de \
--hash=sha256:e97e9c13d67fbe524be17e4d8025d51a7dca38f90de2e462243ab8ed8a9178d1 \ --hash=sha256:e97e9c13d67fbe524be17e4d8025d51a7dca38f90de2e462243ab8ed8a9178d1 \
--hash=sha256:eea260feb1830a627fb526d22fbb426b750d9f5a47b624e8d5e7e004359b219c \ --hash=sha256:eea260feb1830a627fb526d22fbb426b750d9f5a47b624e8d5e7e004359b219c \
# via -r requirements/mypy.in # via -r requirements/mypy.in, sqlalchemy-stubs
networkx==2.5 \ networkx==2.5 \
--hash=sha256:7978955423fbc9639c10498878be59caf99b44dc304c2286162fd24b458c1602 \ --hash=sha256:7978955423fbc9639c10498878be59caf99b44dc304c2286162fd24b458c1602 \
--hash=sha256:8c5812e9f798d37c50570d15c4a69d5710a18d77bafc903ee9c5fba7454c616c \ --hash=sha256:8c5812e9f798d37c50570d15c4a69d5710a18d77bafc903ee9c5fba7454c616c \
@ -1143,6 +1143,10 @@ sphinxcontrib-serializinghtml==1.1.4 \
--hash=sha256:eaa0eccc86e982a9b939b2b82d12cc5d013385ba5eadcc7e4fed23f4405f77bc \ --hash=sha256:eaa0eccc86e982a9b939b2b82d12cc5d013385ba5eadcc7e4fed23f4405f77bc \
--hash=sha256:f242a81d423f59617a8e5cf16f5d4d74e28ee9a66f9e5b637a18082991db5a9a \ --hash=sha256:f242a81d423f59617a8e5cf16f5d4d74e28ee9a66f9e5b637a18082991db5a9a \
# via sphinx # via sphinx
sqlalchemy-stubs==0.3 \
--hash=sha256:a3318c810697164e8c818aa2d90bac570c1a0e752ced3ec25455b309c0bee8fd \
--hash=sha256:ca1250605a39648cc433f5c70cb1a6f9fe0b60bdda4c51e1f9a2ab3651daadc8 \
# via -r requirements/mypy.in
sqlalchemy==1.3.20 \ sqlalchemy==1.3.20 \
--hash=sha256:009e8388d4d551a2107632921320886650b46332f61dc935e70c8bcf37d8e0d6 \ --hash=sha256:009e8388d4d551a2107632921320886650b46332f61dc935e70c8bcf37d8e0d6 \
--hash=sha256:0157c269701d88f5faf1fa0e4560e4d814f210c01a5b55df3cab95e9346a8bcc \ --hash=sha256:0157c269701d88f5faf1fa0e4560e4d814f210c01a5b55df3cab95e9346a8bcc \
@ -1296,7 +1300,7 @@ typing-extensions==3.7.4.3 \
--hash=sha256:7cb407020f00f7bfc3cb3e7881628838e69d8f3fcab2f64742a5e76b2f841918 \ --hash=sha256:7cb407020f00f7bfc3cb3e7881628838e69d8f3fcab2f64742a5e76b2f841918 \
--hash=sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c \ --hash=sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c \
--hash=sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f \ --hash=sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f \
# via -r requirements/common.in, libcst, mypy, pyre-extensions, typing-inspect, zulint # via -r requirements/common.in, libcst, mypy, pyre-extensions, sqlalchemy-stubs, typing-inspect, zulint
typing-inspect==0.6.0 \ typing-inspect==0.6.0 \
--hash=sha256:3b98390df4d999a28cf5b35d8b333425af5da2ece8a4ea9e98f71e7591347b4f \ --hash=sha256:3b98390df4d999a28cf5b35d8b333425af5da2ece8a4ea9e98f71e7591347b4f \
--hash=sha256:8f1b1dd25908dbfd81d3bebc218011531e7ab614ba6e5bf7826d887c834afab7 \ --hash=sha256:8f1b1dd25908dbfd81d3bebc218011531e7ab614ba6e5bf7826d887c834afab7 \

View File

@ -3,3 +3,4 @@
# and requirements/mypy.txt. # and requirements/mypy.txt.
# See requirements/README.md for more detail. # See requirements/README.md for more detail.
mypy mypy
sqlalchemy-stubs

View File

@ -26,6 +26,10 @@ mypy==0.790 \
--hash=sha256:e86bdace26c5fe9cf8cb735e7cedfe7850ad92b327ac5d797c656717d2ca66de \ --hash=sha256:e86bdace26c5fe9cf8cb735e7cedfe7850ad92b327ac5d797c656717d2ca66de \
--hash=sha256:e97e9c13d67fbe524be17e4d8025d51a7dca38f90de2e462243ab8ed8a9178d1 \ --hash=sha256:e97e9c13d67fbe524be17e4d8025d51a7dca38f90de2e462243ab8ed8a9178d1 \
--hash=sha256:eea260feb1830a627fb526d22fbb426b750d9f5a47b624e8d5e7e004359b219c \ --hash=sha256:eea260feb1830a627fb526d22fbb426b750d9f5a47b624e8d5e7e004359b219c \
# via -r requirements/mypy.in, sqlalchemy-stubs
sqlalchemy-stubs==0.3 \
--hash=sha256:a3318c810697164e8c818aa2d90bac570c1a0e752ced3ec25455b309c0bee8fd \
--hash=sha256:ca1250605a39648cc433f5c70cb1a6f9fe0b60bdda4c51e1f9a2ab3651daadc8 \
# via -r requirements/mypy.in # via -r requirements/mypy.in
typed-ast==1.4.1 \ typed-ast==1.4.1 \
--hash=sha256:0666aa36131496aed8f7be0410ff974562ab7eeac11ef351def9ea6fa28f6355 \ --hash=sha256:0666aa36131496aed8f7be0410ff974562ab7eeac11ef351def9ea6fa28f6355 \
@ -63,4 +67,4 @@ typing-extensions==3.7.4.3 \
--hash=sha256:7cb407020f00f7bfc3cb3e7881628838e69d8f3fcab2f64742a5e76b2f841918 \ --hash=sha256:7cb407020f00f7bfc3cb3e7881628838e69d8f3fcab2f64742a5e76b2f841918 \
--hash=sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c \ --hash=sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c \
--hash=sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f \ --hash=sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f \
# via mypy # via mypy, sqlalchemy-stubs

View File

@ -43,4 +43,4 @@ API_FEATURE_LEVEL = 35
# historical commits sharing the same major version, in which case a # historical commits sharing the same major version, in which case a
# minor version bump suffices. # minor version bump suffices.
PROVISION_VERSION = '115.0' PROVISION_VERSION = '115.1'

View File

@ -16,13 +16,15 @@ class NonClosingPool(sqlalchemy.pool.NullPool):
pass pass
def recreate(self) -> 'NonClosingPool': def recreate(self) -> 'NonClosingPool':
return self.__class__(creator=self._creator, return self.__class__(
recycle=self._recycle, creator=self._creator, # type: ignore[attr-defined] # implementation detail
use_threadlocal=self._use_threadlocal, recycle=self._recycle, # type: ignore[attr-defined] # implementation detail
reset_on_return=self._reset_on_return, use_threadlocal=self._use_threadlocal, # type: ignore[attr-defined] # implementation detail
echo=self.echo, reset_on_return=self._reset_on_return, # type: ignore[attr-defined] # implementation detail
logging_name=self._orig_logging_name, echo=self.echo,
_dispatch=self.dispatch) logging_name=self._orig_logging_name, # type: ignore[attr-defined] # implementation detail
_dispatch=self.dispatch, # type: ignore[attr-defined] # implementation detail
)
sqlalchemy_engine: Optional[Any] = None sqlalchemy_engine: Optional[Any] = None
def get_sqlalchemy_connection() -> sqlalchemy.engine.base.Connection: def get_sqlalchemy_connection() -> sqlalchemy.engine.base.Connection:

View File

@ -2,7 +2,8 @@ from typing import Any, Dict, List, Optional, Tuple
from django.db import connection from django.db import connection
from django.db.models.query import Q, QuerySet from django.db.models.query import Q, QuerySet
from sqlalchemy.sql import column, func, literal from sqlalchemy import Text
from sqlalchemy.sql import ColumnElement, column, func, literal
from zerver.lib.request import REQ from zerver.lib.request import REQ
from zerver.models import Message, Stream, UserMessage, UserProfile from zerver.models import Message, Stream, UserMessage, UserProfile
@ -65,14 +66,14 @@ using "subject" in the DB sense, and nothing customer facing.
DB_TOPIC_NAME = "subject" DB_TOPIC_NAME = "subject"
MESSAGE__TOPIC = 'message__subject' MESSAGE__TOPIC = 'message__subject'
def topic_match_sa(topic_name: str) -> Any: def topic_match_sa(topic_name: str) -> "ColumnElement[bool]":
# _sa is short for SQLAlchemy, which we use mostly for # _sa is short for SQLAlchemy, which we use mostly for
# queries that search messages # queries that search messages
topic_cond = func.upper(column("subject")) == func.upper(literal(topic_name)) topic_cond = func.upper(column("subject", Text)) == func.upper(literal(topic_name))
return topic_cond return topic_cond
def topic_column_sa() -> Any: def topic_column_sa() -> "ColumnElement[str]":
return column("subject") return column("subject", Text)
def filter_by_exact_message_topic(query: QuerySet, message: Message) -> QuerySet: def filter_by_exact_message_topic(query: QuerySet, message: Message) -> QuerySet:
topic_name = message.topic_name() topic_name = message.topic_name()

View File

@ -2,7 +2,7 @@ import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
from django.utils.timezone import now as timezone_now from django.utils.timezone import now as timezone_now
from sqlalchemy.sql import Selectable, and_, column, not_, or_ from sqlalchemy.sql import ClauseElement, and_, column, not_, or_
from zerver.lib.timestamp import datetime_to_timestamp from zerver.lib.timestamp import datetime_to_timestamp
from zerver.lib.topic import topic_match_sa from zerver.lib.topic import topic_match_sa
@ -74,9 +74,9 @@ def topic_is_muted(user_profile: UserProfile, stream_id: int, topic_name: str) -
).exists() ).exists()
return is_muted return is_muted
def exclude_topic_mutes(conditions: List[Selectable], def exclude_topic_mutes(conditions: List[ClauseElement],
user_profile: UserProfile, user_profile: UserProfile,
stream_id: Optional[int]) -> List[Selectable]: stream_id: Optional[int]) -> List[ClauseElement]:
query = MutedTopic.objects.filter( query = MutedTopic.objects.filter(
user_profile=user_profile, user_profile=user_profile,
) )
@ -95,7 +95,7 @@ def exclude_topic_mutes(conditions: List[Selectable],
if not rows: if not rows:
return conditions return conditions
def mute_cond(row: Dict[str, Any]) -> Selectable: def mute_cond(row: Dict[str, Any]) -> ClauseElement:
recipient_id = row['recipient_id'] recipient_id = row['recipient_id']
topic_name = row['topic_name'] topic_name = row['topic_name']
stream_cond = column("recipient_id") == recipient_id stream_cond = column("recipient_id") == recipient_id

View File

@ -8,7 +8,7 @@ from django.db import connection
from django.http import HttpResponse from django.http import HttpResponse
from django.test import override_settings from django.test import override_settings
from django.utils.timezone import now as timezone_now from django.utils.timezone import now as timezone_now
from sqlalchemy.sql import and_, column, select, table from sqlalchemy.sql import Select, and_, column, select, table
from sqlalchemy.sql.elements import ClauseElement from sqlalchemy.sql.elements import ClauseElement
from analytics.lib.counts import COUNT_STATS from analytics.lib.counts import COUNT_STATS
@ -54,7 +54,6 @@ from zerver.views.message_fetch import (
LARGER_THAN_MAX_MESSAGE_ID, LARGER_THAN_MAX_MESSAGE_ID,
BadNarrowOperator, BadNarrowOperator,
NarrowBuilder, NarrowBuilder,
Query,
exclude_muting_conditions, exclude_muting_conditions,
find_first_unread_anchor, find_first_unread_anchor,
get_messages_backend, get_messages_backend,
@ -454,7 +453,7 @@ class NarrowBuilderTest(ZulipTestCase):
term = dict(operator='stream', operand='non-web-public-stream') term = dict(operator='stream', operand='non-web-public-stream')
builder = NarrowBuilder(self.user_profile, column('id'), self.realm, True) builder = NarrowBuilder(self.user_profile, column('id'), self.realm, True)
def _build_query(term: Dict[str, Any]) -> Query: def _build_query(term: Dict[str, Any]) -> Select:
return builder.add_term(self.raw_query, term) return builder.add_term(self.raw_query, term)
self.assertRaises(BadNarrowOperator, _build_query, term) self.assertRaises(BadNarrowOperator, _build_query, term)
@ -467,7 +466,7 @@ class NarrowBuilderTest(ZulipTestCase):
self.assertEqual(actual_params, params) self.assertEqual(actual_params, params)
self.assertIn(where_clause, get_sqlalchemy_sql(query)) self.assertIn(where_clause, get_sqlalchemy_sql(query))
def _build_query(self, term: Dict[str, Any]) -> Query: def _build_query(self, term: Dict[str, Any]) -> Select:
return self.builder.add_term(self.raw_query, term) return self.builder.add_term(self.raw_query, term)
class NarrowLibraryTest(ZulipTestCase): class NarrowLibraryTest(ZulipTestCase):

View File

@ -1,5 +1,5 @@
import re import re
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import orjson import orjson
from django.conf import settings from django.conf import settings
@ -11,9 +11,12 @@ from django.utils.html import escape as escape_html
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from sqlalchemy.engine import Connection, RowProxy
from sqlalchemy.sql import ( from sqlalchemy.sql import (
ClauseElement,
ColumnElement, ColumnElement,
Selectable, FromClause,
Select,
alias, alias,
and_, and_,
column, column,
@ -26,6 +29,7 @@ from sqlalchemy.sql import (
table, table,
union_all, union_all,
) )
from sqlalchemy.types import Boolean, Integer, Text
from zerver.context_processors import get_valid_realm_from_request from zerver.context_processors import get_valid_realm_from_request
from zerver.decorator import REQ, has_request_variables from zerver.decorator import REQ, has_request_variables
@ -84,11 +88,7 @@ class BadNarrowOperator(JsonableError):
def msg_format() -> str: def msg_format() -> str:
return _('Invalid narrow operator: {desc}') return _('Invalid narrow operator: {desc}')
# TODO: Should be Select, but sqlalchemy stubs are busted ConditionTransform = Callable[[ClauseElement], ClauseElement]
Query = Any
# TODO: should be Callable[[ColumnElement], ColumnElement], but sqlalchemy stubs are busted
ConditionTransform = Any
OptionalNarrowListT = Optional[List[Dict[str, Any]]] OptionalNarrowListT = Optional[List[Dict[str, Any]]]
@ -97,21 +97,24 @@ TS_START = "<ts-match>"
TS_STOP = "</ts-match>" TS_STOP = "</ts-match>"
def ts_locs_array( def ts_locs_array(
config: ColumnElement, text: ColumnElement, tsquery: ColumnElement, config: "ColumnElement[str]", text: "ColumnElement[str]", tsquery: "ColumnElement[object]",
) -> ColumnElement: ) -> "ColumnElement[List[List[int]]]":
options = f"HighlightAll = TRUE, StartSel = {TS_START}, StopSel = {TS_STOP}" options = f"HighlightAll = TRUE, StartSel = {TS_START}, StopSel = {TS_STOP}"
delimited = func.ts_headline(config, text, tsquery, options) delimited = func.ts_headline(config, text, tsquery, options)
parts = func.unnest(func.string_to_array(delimited, TS_START)).alias() parts = func.unnest(func.string_to_array(delimited, TS_START)).alias()
part = column(parts.name) part = column(parts.name, Text)
part_len = func.length(part) - len(TS_STOP) part_len = func.length(part) - len(TS_STOP)
match_pos = func.sum(part_len).over(rows=(None, -1)) + len(TS_STOP) match_pos = func.sum(part_len).over(rows=(None, -1)) + len(TS_STOP)
match_len = func.strpos(part, TS_STOP) - 1 match_len = func.strpos(part, TS_STOP) - 1
return func.array( ret = func.array(
select([postgresql.array([match_pos, match_len])]) select([
postgresql.array([match_pos, match_len]), # type: ignore[call-overload] # https://github.com/dropbox/sqlalchemy-stubs/issues/188
])
.select_from(parts) .select_from(parts)
.offset(1) .offset(1)
.as_scalar(), .as_scalar(),
) )
return ret
# When you add a new operator to this, also update zerver/lib/narrow.py # When you add a new operator to this, also update zerver/lib/narrow.py
class NarrowBuilder: class NarrowBuilder:
@ -124,8 +127,8 @@ class NarrowBuilder:
# None of these methods ever *add* messages to a query's result. # None of these methods ever *add* messages to a query's result.
# #
# That is, the `add_term` method, and its helpers the `by_*` methods, # That is, the `add_term` method, and its helpers the `by_*` methods,
# are passed a Query object representing a query for messages; they may # are passed a Select object representing a query for messages; they may
# call some methods on it, and then they return a resulting Query # call some methods on it, and then they return a resulting Select
# object. Things these methods may do to the queries they handle # object. Things these methods may do to the queries they handle
# include # include
# * add conditions to filter out rows (i.e., messages), with `query.where` # * add conditions to filter out rows (i.e., messages), with `query.where`
@ -136,14 +139,14 @@ class NarrowBuilder:
# * anything that would pull in additional rows, or information on # * anything that would pull in additional rows, or information on
# other messages. # other messages.
def __init__(self, user_profile: Optional[UserProfile], msg_id_column: str, def __init__(self, user_profile: Optional[UserProfile], msg_id_column: "ColumnElement[int]",
realm: Realm, is_web_public_query: bool=False) -> None: realm: Realm, is_web_public_query: bool=False) -> None:
self.user_profile = user_profile self.user_profile = user_profile
self.msg_id_column = msg_id_column self.msg_id_column = msg_id_column
self.realm = realm self.realm = realm
self.is_web_public_query = is_web_public_query self.is_web_public_query = is_web_public_query
def add_term(self, query: Query, term: Dict[str, Any]) -> Query: def add_term(self, query: Select, term: Dict[str, Any]) -> Select:
""" """
Extend the given query to one narrowed by the given term, and return the result. Extend the given query to one narrowed by the given term, and return the result.
@ -176,14 +179,14 @@ class NarrowBuilder:
return method(query, operand, maybe_negate) return method(query, operand, maybe_negate)
def by_has(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query: def by_has(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
if operand not in ['attachment', 'image', 'link']: if operand not in ['attachment', 'image', 'link']:
raise BadNarrowOperator("unknown 'has' operand " + operand) raise BadNarrowOperator("unknown 'has' operand " + operand)
col_name = 'has_' + operand col_name = 'has_' + operand
cond = column(col_name) cond = column(col_name, Boolean)
return query.where(maybe_negate(cond)) return query.where(maybe_negate(cond))
def by_in(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query: def by_in(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
# This operator does not support is_web_public_query. # This operator does not support is_web_public_query.
assert not self.is_web_public_query assert not self.is_web_public_query
assert self.user_profile is not None assert self.user_profile is not None
@ -196,27 +199,27 @@ class NarrowBuilder:
raise BadNarrowOperator("unknown 'in' operand " + operand) raise BadNarrowOperator("unknown 'in' operand " + operand)
def by_is(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query: def by_is(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
# This operator class does not support is_web_public_query. # This operator class does not support is_web_public_query.
assert not self.is_web_public_query assert not self.is_web_public_query
assert self.user_profile is not None assert self.user_profile is not None
if operand == 'private': if operand == 'private':
cond = column("flags").op("&")(UserMessage.flags.is_private.mask) != 0 cond = column("flags", Integer).op("&")(UserMessage.flags.is_private.mask) != 0
return query.where(maybe_negate(cond)) return query.where(maybe_negate(cond))
elif operand == 'starred': elif operand == 'starred':
cond = column("flags").op("&")(UserMessage.flags.starred.mask) != 0 cond = column("flags", Integer).op("&")(UserMessage.flags.starred.mask) != 0
return query.where(maybe_negate(cond)) return query.where(maybe_negate(cond))
elif operand == 'unread': elif operand == 'unread':
cond = column("flags").op("&")(UserMessage.flags.read.mask) == 0 cond = column("flags", Integer).op("&")(UserMessage.flags.read.mask) == 0
return query.where(maybe_negate(cond)) return query.where(maybe_negate(cond))
elif operand == 'mentioned': elif operand == 'mentioned':
cond1 = column("flags").op("&")(UserMessage.flags.mentioned.mask) != 0 cond1 = column("flags", Integer).op("&")(UserMessage.flags.mentioned.mask) != 0
cond2 = column("flags").op("&")(UserMessage.flags.wildcard_mentioned.mask) != 0 cond2 = column("flags", Integer).op("&")(UserMessage.flags.wildcard_mentioned.mask) != 0
cond = or_(cond1, cond2) cond = or_(cond1, cond2)
return query.where(maybe_negate(cond)) return query.where(maybe_negate(cond))
elif operand == 'alerted': elif operand == 'alerted':
cond = column("flags").op("&")(UserMessage.flags.has_alert_word.mask) != 0 cond = column("flags", Integer).op("&")(UserMessage.flags.has_alert_word.mask) != 0
return query.where(maybe_negate(cond)) return query.where(maybe_negate(cond))
raise BadNarrowOperator("unknown 'is' operand " + operand) raise BadNarrowOperator("unknown 'is' operand " + operand)
@ -242,7 +245,7 @@ class NarrowBuilder:
s[i] = '\\' + c s[i] = '\\' + c
return ''.join(s) return ''.join(s)
def by_stream(self, query: Query, operand: Union[str, int], maybe_negate: ConditionTransform) -> Query: def by_stream(self, query: Select, operand: Union[str, int], maybe_negate: ConditionTransform) -> Select:
try: try:
# Because you can see your own message history for # Because you can see your own message history for
# private streams you are no longer subscribed to, we # private streams you are no longer subscribed to, we
@ -274,14 +277,14 @@ class NarrowBuilder:
matching_streams = get_active_streams(self.realm).filter( matching_streams = get_active_streams(self.realm).filter(
name__iregex=fr'^(un)*{self._pg_re_escape(base_stream_name)}(\.d)*$') name__iregex=fr'^(un)*{self._pg_re_escape(base_stream_name)}(\.d)*$')
recipient_ids = [matching_stream.recipient_id for matching_stream in matching_streams] recipient_ids = [matching_stream.recipient_id for matching_stream in matching_streams]
cond = column("recipient_id").in_(recipient_ids) cond = column("recipient_id", Integer).in_(recipient_ids)
return query.where(maybe_negate(cond)) return query.where(maybe_negate(cond))
recipient = stream.recipient recipient = stream.recipient
cond = column("recipient_id") == recipient.id cond = column("recipient_id", Integer) == recipient.id
return query.where(maybe_negate(cond)) return query.where(maybe_negate(cond))
def by_streams(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query: def by_streams(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
if operand == 'public': if operand == 'public':
# Get all both subscribed and non subscribed public streams # Get all both subscribed and non subscribed public streams
# but exclude any private subscribed streams. # but exclude any private subscribed streams.
@ -292,10 +295,10 @@ class NarrowBuilder:
raise BadNarrowOperator('unknown streams operand ' + operand) raise BadNarrowOperator('unknown streams operand ' + operand)
recipient_ids = recipient_queryset.values_list("recipient_id", flat=True).order_by('id') recipient_ids = recipient_queryset.values_list("recipient_id", flat=True).order_by('id')
cond = column("recipient_id").in_(recipient_ids) cond = column("recipient_id", Integer).in_(recipient_ids)
return query.where(maybe_negate(cond)) return query.where(maybe_negate(cond))
def by_topic(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query: def by_topic(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
if self.realm.is_zephyr_mirror_realm: if self.realm.is_zephyr_mirror_realm:
# MIT users expect narrowing to topic "foo" to also show messages to /^foo(.d)*$/ # MIT users expect narrowing to topic "foo" to also show messages to /^foo(.d)*$/
# (foo, foo.d, foo.d.d, etc) # (foo, foo.d, foo.d.d, etc)
@ -307,7 +310,7 @@ class NarrowBuilder:
# Additionally, MIT users expect the empty instance and # Additionally, MIT users expect the empty instance and
# instance "personal" to be the same. # instance "personal" to be the same.
if base_topic in ('', 'personal', '(instance "")'): if base_topic in ('', 'personal', '(instance "")'):
cond = or_( cond: ClauseElement = or_(
topic_match_sa(""), topic_match_sa(""),
topic_match_sa(".d"), topic_match_sa(".d"),
topic_match_sa(".d.d"), topic_match_sa(".d.d"),
@ -340,7 +343,7 @@ class NarrowBuilder:
cond = topic_match_sa(operand) cond = topic_match_sa(operand)
return query.where(maybe_negate(cond)) return query.where(maybe_negate(cond))
def by_sender(self, query: Query, operand: Union[str, int], maybe_negate: ConditionTransform) -> Query: def by_sender(self, query: Select, operand: Union[str, int], maybe_negate: ConditionTransform) -> Select:
try: try:
if isinstance(operand, str): if isinstance(operand, str):
sender = get_user_including_cross_realm(operand, self.realm) sender = get_user_including_cross_realm(operand, self.realm)
@ -349,20 +352,20 @@ class NarrowBuilder:
except UserProfile.DoesNotExist: except UserProfile.DoesNotExist:
raise BadNarrowOperator('unknown user ' + str(operand)) raise BadNarrowOperator('unknown user ' + str(operand))
cond = column("sender_id") == literal(sender.id) cond = column("sender_id", Integer) == literal(sender.id)
return query.where(maybe_negate(cond)) return query.where(maybe_negate(cond))
def by_near(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query: def by_near(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
return query return query
def by_id(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query: def by_id(self, query: Select, operand: Union[int, str], maybe_negate: ConditionTransform) -> Select:
if not str(operand).isdigit(): if not str(operand).isdigit():
raise BadNarrowOperator("Invalid message ID") raise BadNarrowOperator("Invalid message ID")
cond = self.msg_id_column == literal(operand) cond = self.msg_id_column == literal(operand)
return query.where(maybe_negate(cond)) return query.where(maybe_negate(cond))
def by_pm_with(self, query: Query, operand: Union[str, Iterable[int]], def by_pm_with(self, query: Select, operand: Union[str, Iterable[int]],
maybe_negate: ConditionTransform) -> Query: maybe_negate: ConditionTransform) -> Select:
# This operator does not support is_web_public_query. # This operator does not support is_web_public_query.
assert not self.is_web_public_query assert not self.is_web_public_query
assert self.user_profile is not None assert self.user_profile is not None
@ -394,7 +397,7 @@ class NarrowBuilder:
# Group DM # Group DM
if recipient.type == Recipient.HUDDLE: if recipient.type == Recipient.HUDDLE:
cond = column("recipient_id") == recipient.id cond = column("recipient_id", Integer) == recipient.id
return query.where(maybe_negate(cond)) return query.where(maybe_negate(cond))
# 1:1 PM # 1:1 PM
@ -414,19 +417,19 @@ class NarrowBuilder:
# complex query to get messages between these two users # complex query to get messages between these two users
# with either of them as the sender. # with either of them as the sender.
self_recipient_id = self.user_profile.recipient_id self_recipient_id = self.user_profile.recipient_id
cond = or_(and_(column("sender_id") == other_participant.id, cond = or_(and_(column("sender_id", Integer) == other_participant.id,
column("recipient_id") == self_recipient_id), column("recipient_id", Integer) == self_recipient_id),
and_(column("sender_id") == self.user_profile.id, and_(column("sender_id", Integer) == self.user_profile.id,
column("recipient_id") == recipient.id)) column("recipient_id", Integer) == recipient.id))
return query.where(maybe_negate(cond)) return query.where(maybe_negate(cond))
# PM with self # PM with self
cond = and_(column("sender_id") == self.user_profile.id, cond = and_(column("sender_id", Integer) == self.user_profile.id,
column("recipient_id") == recipient.id) column("recipient_id", Integer) == recipient.id)
return query.where(maybe_negate(cond)) return query.where(maybe_negate(cond))
def by_group_pm_with(self, query: Query, operand: Union[str, int], def by_group_pm_with(self, query: Select, operand: Union[str, int],
maybe_negate: ConditionTransform) -> Query: maybe_negate: ConditionTransform) -> Select:
# This operator does not support is_web_public_query. # This operator does not support is_web_public_query.
assert not self.is_web_public_query assert not self.is_web_public_query
assert self.user_profile is not None assert self.user_profile is not None
@ -453,33 +456,33 @@ class NarrowBuilder:
).values("recipient_id")] ).values("recipient_id")]
recipient_ids = set(self_recipient_ids) & set(narrow_recipient_ids) recipient_ids = set(self_recipient_ids) & set(narrow_recipient_ids)
cond = column("recipient_id").in_(recipient_ids) cond = column("recipient_id", Integer).in_(recipient_ids)
return query.where(maybe_negate(cond)) return query.where(maybe_negate(cond))
def by_search(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query: def by_search(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
if settings.USING_PGROONGA: if settings.USING_PGROONGA:
return self._by_search_pgroonga(query, operand, maybe_negate) return self._by_search_pgroonga(query, operand, maybe_negate)
else: else:
return self._by_search_tsearch(query, operand, maybe_negate) return self._by_search_tsearch(query, operand, maybe_negate)
def _by_search_pgroonga(self, query: Query, operand: str, def _by_search_pgroonga(self, query: Select, operand: str,
maybe_negate: ConditionTransform) -> Query: maybe_negate: ConditionTransform) -> Select:
match_positions_character = func.pgroonga_match_positions_character match_positions_character = func.pgroonga_match_positions_character
query_extract_keywords = func.pgroonga_query_extract_keywords query_extract_keywords = func.pgroonga_query_extract_keywords
operand_escaped = func.escape_html(operand) operand_escaped = func.escape_html(operand)
keywords = query_extract_keywords(operand_escaped) keywords = query_extract_keywords(operand_escaped)
query = query.column(match_positions_character(column("rendered_content"), query = query.column(match_positions_character(column("rendered_content", Text),
keywords).label("content_matches")) keywords).label("content_matches"))
query = query.column(match_positions_character(func.escape_html(topic_column_sa()), query = query.column(match_positions_character(func.escape_html(topic_column_sa()),
keywords).label("topic_matches")) keywords).label("topic_matches"))
condition = column("search_pgroonga").op("&@~")(operand_escaped) condition = column("search_pgroonga").op("&@~")(operand_escaped)
return query.where(maybe_negate(condition)) return query.where(maybe_negate(condition))
def _by_search_tsearch(self, query: Query, operand: str, def _by_search_tsearch(self, query: Select, operand: str,
maybe_negate: ConditionTransform) -> Query: maybe_negate: ConditionTransform) -> Select:
tsquery = func.plainto_tsquery(literal("zulip.english_us_search"), literal(operand)) tsquery = func.plainto_tsquery(literal("zulip.english_us_search"), literal(operand))
query = query.column(ts_locs_array(literal("zulip.english_us_search"), query = query.column(ts_locs_array(literal("zulip.english_us_search"),
column("rendered_content"), column("rendered_content", Text),
tsquery).label("content_matches")) tsquery).label("content_matches"))
# We HTML-escape the topic in PostgreSQL to avoid doing a server round-trip # We HTML-escape the topic in PostgreSQL to avoid doing a server round-trip
query = query.column(ts_locs_array(literal("zulip.english_us_search"), query = query.column(ts_locs_array(literal("zulip.english_us_search"),
@ -494,11 +497,11 @@ class NarrowBuilder:
if term[0] == '"' and term[-1] == '"': if term[0] == '"' and term[-1] == '"':
term = term[1:-1] term = term[1:-1]
term = '%' + connection.ops.prep_for_like_query(term) + '%' term = '%' + connection.ops.prep_for_like_query(term) + '%'
cond = or_(column("content").ilike(term), cond = or_(column("content", Text).ilike(term),
topic_column_sa().ilike(term)) topic_column_sa().ilike(term))
query = query.where(maybe_negate(cond)) query = query.where(maybe_negate(cond))
cond = column("search_tsvector").op("@@")(tsquery) cond = column("search_tsvector", postgresql.TSVECTOR).op("@@")(tsquery)
return query.where(maybe_negate(cond)) return query.where(maybe_negate(cond))
def highlight_string(text: str, locs: Iterable[Tuple[int, int]]) -> str: def highlight_string(text: str, locs: Iterable[Tuple[int, int]]) -> str:
@ -658,7 +661,7 @@ def get_stream_from_narrow_access_unchecked(narrow: OptionalNarrowListT, realm:
return None return None
def exclude_muting_conditions(user_profile: UserProfile, def exclude_muting_conditions(user_profile: UserProfile,
narrow: OptionalNarrowListT) -> List[Selectable]: narrow: OptionalNarrowListT) -> List[ClauseElement]:
conditions = [] conditions = []
stream_id = None stream_id = None
try: try:
@ -681,7 +684,7 @@ def exclude_muting_conditions(user_profile: UserProfile,
muted_recipient_ids = [row['recipient_id'] for row in rows] muted_recipient_ids = [row['recipient_id'] for row in rows]
if len(muted_recipient_ids) > 0: if len(muted_recipient_ids) > 0:
# Only add the condition if we have muted streams to simplify/avoid warnings. # Only add the condition if we have muted streams to simplify/avoid warnings.
condition = not_(column("recipient_id").in_(muted_recipient_ids)) condition = not_(column("recipient_id", Integer).in_(muted_recipient_ids))
conditions.append(condition) conditions.append(condition)
conditions = exclude_topic_mutes(conditions, user_profile, stream_id) conditions = exclude_topic_mutes(conditions, user_profile, stream_id)
@ -690,38 +693,39 @@ def exclude_muting_conditions(user_profile: UserProfile,
def get_base_query_for_search(user_profile: Optional[UserProfile],
                              need_message: bool,
                              need_user_message: bool) -> Tuple[Select, "ColumnElement[int]"]:
    """Build the starting SELECT for a message search.

    Returns a ``(query, inner_msg_id_col)`` pair: the base query and the
    column expression that names the message id inside that query.

    Three shapes are produced:

    * ``need_user_message`` false: select ``id`` (labelled ``message_id``)
      from ``zerver_message`` only; ``need_message`` must be true.
    * ``need_user_message`` and ``need_message`` both true: select
      ``message_id`` and ``flags`` from ``zerver_usermessage`` joined to
      ``zerver_message``, restricted to ``user_profile``'s rows.
    * only ``need_user_message`` true: the same selection from
      ``zerver_usermessage`` alone.

    ``user_profile`` must not be None whenever ``need_user_message`` is true.
    """
    if not need_user_message:
        # Simplest case: user_message is not involved at all, so only the
        # message table is queried.
        assert need_message
        message_only_query = select([column("id", Integer).label("message_id")],
                                    None,
                                    table("zerver_message"))
        message_table_id: ColumnElement[int]
        message_table_id = literal_column("zerver_message.id", Integer)  # type: ignore[assignment] # https://github.com/dropbox/sqlalchemy-stubs/pull/189
        return (message_only_query, message_table_id)

    assert user_profile is not None

    # Both remaining shapes select the same columns and filter on the same user.
    usermessage_columns = [column("message_id"), column("flags", Integer)]
    belongs_to_user = column("user_profile_id") == literal(user_profile.id)

    if need_message:
        # Message data is needed as well, so join usermessage to message.
        usermessage_with_message = join(
            table("zerver_usermessage"), table("zerver_message"),
            literal_column("zerver_usermessage.message_id", Integer) ==
            literal_column("zerver_message.id", Integer))
        joined_query = select(usermessage_columns,
                              belongs_to_user,
                              usermessage_with_message)
        return (joined_query, column("message_id", Integer))

    usermessage_query = select(usermessage_columns,
                               belongs_to_user,
                               table("zerver_usermessage"))
    return (usermessage_query, column("message_id", Integer))
def add_narrow_conditions(user_profile: Optional[UserProfile], def add_narrow_conditions(user_profile: Optional[UserProfile],
inner_msg_id_col: ColumnElement, inner_msg_id_col: "ColumnElement[int]",
query: Query, query: Select,
narrow: OptionalNarrowListT, narrow: OptionalNarrowListT,
is_web_public_query: bool, is_web_public_query: bool,
realm: Realm) -> Tuple[Query, bool]: realm: Realm) -> Tuple[Select, bool]:
is_search = False # for now is_search = False # for now
if narrow is None: if narrow is None:
@ -742,7 +746,7 @@ def add_narrow_conditions(user_profile: Optional[UserProfile],
if search_operands: if search_operands:
is_search = True is_search = True
query = query.column(topic_column_sa()).column(column("rendered_content")) query = query.column(topic_column_sa()).column(column("rendered_content", Text))
search_term = dict( search_term = dict(
operator='search', operator='search',
operand=' '.join(search_operands), operand=' '.join(search_operands),
@ -751,7 +755,7 @@ def add_narrow_conditions(user_profile: Optional[UserProfile],
return (query, is_search) return (query, is_search)
def find_first_unread_anchor(sa_conn: Any, def find_first_unread_anchor(sa_conn: Connection,
user_profile: Optional[UserProfile], user_profile: Optional[UserProfile],
narrow: OptionalNarrowListT) -> int: narrow: OptionalNarrowListT) -> int:
# For anonymous web users, all messages are treated as read, and so # For anonymous web users, all messages are treated as read, and so
@ -785,7 +789,7 @@ def find_first_unread_anchor(sa_conn: Any,
realm=user_profile.realm, realm=user_profile.realm,
) )
condition = column("flags").op("&")(UserMessage.flags.read.mask) == 0 condition = column("flags", Integer).op("&")(UserMessage.flags.read.mask) == 0
# We exclude messages on muted topics when finding the first unread # We exclude messages on muted topics when finding the first unread
# message in this narrow # message in this narrow
@ -913,6 +917,7 @@ def get_messages_backend(request: HttpRequest,
need_message = True need_message = True
need_user_message = True need_user_message = True
query: FromClause
query, inner_msg_id_col = get_base_query_for_search( query, inner_msg_id_col = get_base_query_for_search(
user_profile=user_profile, user_profile=user_profile,
need_message=need_message, need_message=need_message,
@ -970,7 +975,7 @@ def get_messages_backend(request: HttpRequest,
) )
main_query = alias(query) main_query = alias(query)
query = select(main_query.c, None, main_query).order_by(column("message_id").asc()) query = select(main_query.c, None, main_query).order_by(column("message_id", Integer).asc())
# This is a hack to tag the query we use for testing # This is a hack to tag the query we use for testing
query = query.prefix_with("/* get_messages */") query = query.prefix_with("/* get_messages */")
rows = list(sa_conn.execute(query).fetchall()) rows = list(sa_conn.execute(query).fetchall())
@ -1059,14 +1064,14 @@ def get_messages_backend(request: HttpRequest,
) )
return json_success(ret) return json_success(ret)
def limit_query_to_range(query: Query, def limit_query_to_range(query: Select,
num_before: int, num_before: int,
num_after: int, num_after: int,
anchor: int, anchor: int,
anchored_to_left: bool, anchored_to_left: bool,
anchored_to_right: bool, anchored_to_right: bool,
id_col: ColumnElement, id_col: "ColumnElement[int]",
first_visible_message_id: int) -> Query: first_visible_message_id: int) -> FromClause:
''' '''
This code is actually generic enough that we could move it to a This code is actually generic enough that we could move it to a
library, but our only caller for now is message search. library, but our only caller for now is message search.
@ -1137,7 +1142,7 @@ def limit_query_to_range(query: Query,
# return at most one row here. # return at most one row here.
return query.where(id_col == anchor) return query.where(id_col == anchor)
def post_process_limited_query(rows: List[Any], def post_process_limited_query(rows: Sequence[Union[RowProxy, Sequence[Any]]],
num_before: int, num_before: int,
num_after: int, num_after: int,
anchor: int, anchor: int,
@ -1153,7 +1158,9 @@ def post_process_limited_query(rows: List[Any],
# that the clients will know that they got complete results. # that the clients will know that they got complete results.
if first_visible_message_id > 0: if first_visible_message_id > 0:
visible_rows = [r for r in rows if r[0] >= first_visible_message_id] visible_rows: Sequence[Union[RowProxy, Sequence[Any]]] = [
r for r in rows if r[0] >= first_visible_message_id
]
else: else:
visible_rows = rows visible_rows = rows
@ -1162,8 +1169,8 @@ def post_process_limited_query(rows: List[Any],
if anchored_to_right: if anchored_to_right:
num_after = 0 num_after = 0
before_rows = visible_rows[:] before_rows = visible_rows[:]
anchor_rows: List[Any] = [] anchor_rows = []
after_rows: List[Any] = [] after_rows = []
else: else:
before_rows = [r for r in visible_rows if r[0] < anchor] before_rows = [r for r in visible_rows if r[0] < anchor]
anchor_rows = [r for r in visible_rows if r[0] == anchor] anchor_rows = [r for r in visible_rows if r[0] == anchor]
@ -1175,7 +1182,7 @@ def post_process_limited_query(rows: List[Any],
if num_after: if num_after:
after_rows = after_rows[:num_after] after_rows = after_rows[:num_after]
visible_rows = before_rows + anchor_rows + after_rows visible_rows = [*before_rows, *anchor_rows, *after_rows]
found_anchor = len(anchor_rows) == 1 found_anchor = len(anchor_rows) == 1
found_oldest = anchored_to_left or (len(before_rows) < num_before) found_oldest = anchored_to_left or (len(before_rows) < num_before)
@ -1211,14 +1218,14 @@ def messages_in_narrow_backend(request: HttpRequest, user_profile: UserProfile,
msg_ids = [message_id for message_id in msg_ids if message_id >= first_visible_message_id] msg_ids = [message_id for message_id in msg_ids if message_id >= first_visible_message_id]
# This query is limited to messages the user has access to because they # This query is limited to messages the user has access to because they
# actually received them, as reflected in `zerver_usermessage`. # actually received them, as reflected in `zerver_usermessage`.
query = select([column("message_id"), topic_column_sa(), column("rendered_content")], query = select([column("message_id", Integer), topic_column_sa(), column("rendered_content", Text)],
and_(column("user_profile_id") == literal(user_profile.id), and_(column("user_profile_id", Integer) == literal(user_profile.id),
column("message_id").in_(msg_ids)), column("message_id", Integer).in_(msg_ids)),
join(table("zerver_usermessage"), table("zerver_message"), join(table("zerver_usermessage"), table("zerver_message"),
literal_column("zerver_usermessage.message_id") == literal_column("zerver_usermessage.message_id", Integer) ==
literal_column("zerver_message.id"))) literal_column("zerver_message.id", Integer)))
builder = NarrowBuilder(user_profile, column("message_id"), user_profile.realm) builder = NarrowBuilder(user_profile, column("message_id", Integer), user_profile.realm)
if narrow is not None: if narrow is not None:
for term in narrow: for term in narrow:
query = builder.add_term(query, term) query = builder.add_term(query, term)