mirror of https://github.com/zulip/zulip.git
mypy: Use sqlalchemy-stubs.
Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
parent
8e0240300a
commit
13e35bfa94
|
@ -680,7 +680,7 @@ mypy==0.790 \
|
||||||
--hash=sha256:e86bdace26c5fe9cf8cb735e7cedfe7850ad92b327ac5d797c656717d2ca66de \
|
--hash=sha256:e86bdace26c5fe9cf8cb735e7cedfe7850ad92b327ac5d797c656717d2ca66de \
|
||||||
--hash=sha256:e97e9c13d67fbe524be17e4d8025d51a7dca38f90de2e462243ab8ed8a9178d1 \
|
--hash=sha256:e97e9c13d67fbe524be17e4d8025d51a7dca38f90de2e462243ab8ed8a9178d1 \
|
||||||
--hash=sha256:eea260feb1830a627fb526d22fbb426b750d9f5a47b624e8d5e7e004359b219c \
|
--hash=sha256:eea260feb1830a627fb526d22fbb426b750d9f5a47b624e8d5e7e004359b219c \
|
||||||
# via -r requirements/mypy.in
|
# via -r requirements/mypy.in, sqlalchemy-stubs
|
||||||
networkx==2.5 \
|
networkx==2.5 \
|
||||||
--hash=sha256:7978955423fbc9639c10498878be59caf99b44dc304c2286162fd24b458c1602 \
|
--hash=sha256:7978955423fbc9639c10498878be59caf99b44dc304c2286162fd24b458c1602 \
|
||||||
--hash=sha256:8c5812e9f798d37c50570d15c4a69d5710a18d77bafc903ee9c5fba7454c616c \
|
--hash=sha256:8c5812e9f798d37c50570d15c4a69d5710a18d77bafc903ee9c5fba7454c616c \
|
||||||
|
@ -1143,6 +1143,10 @@ sphinxcontrib-serializinghtml==1.1.4 \
|
||||||
--hash=sha256:eaa0eccc86e982a9b939b2b82d12cc5d013385ba5eadcc7e4fed23f4405f77bc \
|
--hash=sha256:eaa0eccc86e982a9b939b2b82d12cc5d013385ba5eadcc7e4fed23f4405f77bc \
|
||||||
--hash=sha256:f242a81d423f59617a8e5cf16f5d4d74e28ee9a66f9e5b637a18082991db5a9a \
|
--hash=sha256:f242a81d423f59617a8e5cf16f5d4d74e28ee9a66f9e5b637a18082991db5a9a \
|
||||||
# via sphinx
|
# via sphinx
|
||||||
|
sqlalchemy-stubs==0.3 \
|
||||||
|
--hash=sha256:a3318c810697164e8c818aa2d90bac570c1a0e752ced3ec25455b309c0bee8fd \
|
||||||
|
--hash=sha256:ca1250605a39648cc433f5c70cb1a6f9fe0b60bdda4c51e1f9a2ab3651daadc8 \
|
||||||
|
# via -r requirements/mypy.in
|
||||||
sqlalchemy==1.3.20 \
|
sqlalchemy==1.3.20 \
|
||||||
--hash=sha256:009e8388d4d551a2107632921320886650b46332f61dc935e70c8bcf37d8e0d6 \
|
--hash=sha256:009e8388d4d551a2107632921320886650b46332f61dc935e70c8bcf37d8e0d6 \
|
||||||
--hash=sha256:0157c269701d88f5faf1fa0e4560e4d814f210c01a5b55df3cab95e9346a8bcc \
|
--hash=sha256:0157c269701d88f5faf1fa0e4560e4d814f210c01a5b55df3cab95e9346a8bcc \
|
||||||
|
@ -1296,7 +1300,7 @@ typing-extensions==3.7.4.3 \
|
||||||
--hash=sha256:7cb407020f00f7bfc3cb3e7881628838e69d8f3fcab2f64742a5e76b2f841918 \
|
--hash=sha256:7cb407020f00f7bfc3cb3e7881628838e69d8f3fcab2f64742a5e76b2f841918 \
|
||||||
--hash=sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c \
|
--hash=sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c \
|
||||||
--hash=sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f \
|
--hash=sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f \
|
||||||
# via -r requirements/common.in, libcst, mypy, pyre-extensions, typing-inspect, zulint
|
# via -r requirements/common.in, libcst, mypy, pyre-extensions, sqlalchemy-stubs, typing-inspect, zulint
|
||||||
typing-inspect==0.6.0 \
|
typing-inspect==0.6.0 \
|
||||||
--hash=sha256:3b98390df4d999a28cf5b35d8b333425af5da2ece8a4ea9e98f71e7591347b4f \
|
--hash=sha256:3b98390df4d999a28cf5b35d8b333425af5da2ece8a4ea9e98f71e7591347b4f \
|
||||||
--hash=sha256:8f1b1dd25908dbfd81d3bebc218011531e7ab614ba6e5bf7826d887c834afab7 \
|
--hash=sha256:8f1b1dd25908dbfd81d3bebc218011531e7ab614ba6e5bf7826d887c834afab7 \
|
||||||
|
|
|
@ -3,3 +3,4 @@
|
||||||
# and requirements/mypy.txt.
|
# and requirements/mypy.txt.
|
||||||
# See requirements/README.md for more detail.
|
# See requirements/README.md for more detail.
|
||||||
mypy
|
mypy
|
||||||
|
sqlalchemy-stubs
|
||||||
|
|
|
@ -26,6 +26,10 @@ mypy==0.790 \
|
||||||
--hash=sha256:e86bdace26c5fe9cf8cb735e7cedfe7850ad92b327ac5d797c656717d2ca66de \
|
--hash=sha256:e86bdace26c5fe9cf8cb735e7cedfe7850ad92b327ac5d797c656717d2ca66de \
|
||||||
--hash=sha256:e97e9c13d67fbe524be17e4d8025d51a7dca38f90de2e462243ab8ed8a9178d1 \
|
--hash=sha256:e97e9c13d67fbe524be17e4d8025d51a7dca38f90de2e462243ab8ed8a9178d1 \
|
||||||
--hash=sha256:eea260feb1830a627fb526d22fbb426b750d9f5a47b624e8d5e7e004359b219c \
|
--hash=sha256:eea260feb1830a627fb526d22fbb426b750d9f5a47b624e8d5e7e004359b219c \
|
||||||
|
# via -r requirements/mypy.in, sqlalchemy-stubs
|
||||||
|
sqlalchemy-stubs==0.3 \
|
||||||
|
--hash=sha256:a3318c810697164e8c818aa2d90bac570c1a0e752ced3ec25455b309c0bee8fd \
|
||||||
|
--hash=sha256:ca1250605a39648cc433f5c70cb1a6f9fe0b60bdda4c51e1f9a2ab3651daadc8 \
|
||||||
# via -r requirements/mypy.in
|
# via -r requirements/mypy.in
|
||||||
typed-ast==1.4.1 \
|
typed-ast==1.4.1 \
|
||||||
--hash=sha256:0666aa36131496aed8f7be0410ff974562ab7eeac11ef351def9ea6fa28f6355 \
|
--hash=sha256:0666aa36131496aed8f7be0410ff974562ab7eeac11ef351def9ea6fa28f6355 \
|
||||||
|
@ -63,4 +67,4 @@ typing-extensions==3.7.4.3 \
|
||||||
--hash=sha256:7cb407020f00f7bfc3cb3e7881628838e69d8f3fcab2f64742a5e76b2f841918 \
|
--hash=sha256:7cb407020f00f7bfc3cb3e7881628838e69d8f3fcab2f64742a5e76b2f841918 \
|
||||||
--hash=sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c \
|
--hash=sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c \
|
||||||
--hash=sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f \
|
--hash=sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f \
|
||||||
# via mypy
|
# via mypy, sqlalchemy-stubs
|
||||||
|
|
|
@ -43,4 +43,4 @@ API_FEATURE_LEVEL = 35
|
||||||
# historical commits sharing the same major version, in which case a
|
# historical commits sharing the same major version, in which case a
|
||||||
# minor version bump suffices.
|
# minor version bump suffices.
|
||||||
|
|
||||||
PROVISION_VERSION = '115.0'
|
PROVISION_VERSION = '115.1'
|
||||||
|
|
|
@ -16,13 +16,15 @@ class NonClosingPool(sqlalchemy.pool.NullPool):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def recreate(self) -> 'NonClosingPool':
|
def recreate(self) -> 'NonClosingPool':
|
||||||
return self.__class__(creator=self._creator,
|
return self.__class__(
|
||||||
recycle=self._recycle,
|
creator=self._creator, # type: ignore[attr-defined] # implementation detail
|
||||||
use_threadlocal=self._use_threadlocal,
|
recycle=self._recycle, # type: ignore[attr-defined] # implementation detail
|
||||||
reset_on_return=self._reset_on_return,
|
use_threadlocal=self._use_threadlocal, # type: ignore[attr-defined] # implementation detail
|
||||||
|
reset_on_return=self._reset_on_return, # type: ignore[attr-defined] # implementation detail
|
||||||
echo=self.echo,
|
echo=self.echo,
|
||||||
logging_name=self._orig_logging_name,
|
logging_name=self._orig_logging_name, # type: ignore[attr-defined] # implementation detail
|
||||||
_dispatch=self.dispatch)
|
_dispatch=self.dispatch, # type: ignore[attr-defined] # implementation detail
|
||||||
|
)
|
||||||
|
|
||||||
sqlalchemy_engine: Optional[Any] = None
|
sqlalchemy_engine: Optional[Any] = None
|
||||||
def get_sqlalchemy_connection() -> sqlalchemy.engine.base.Connection:
|
def get_sqlalchemy_connection() -> sqlalchemy.engine.base.Connection:
|
||||||
|
|
|
@ -2,7 +2,8 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from django.db import connection
|
from django.db import connection
|
||||||
from django.db.models.query import Q, QuerySet
|
from django.db.models.query import Q, QuerySet
|
||||||
from sqlalchemy.sql import column, func, literal
|
from sqlalchemy import Text
|
||||||
|
from sqlalchemy.sql import ColumnElement, column, func, literal
|
||||||
|
|
||||||
from zerver.lib.request import REQ
|
from zerver.lib.request import REQ
|
||||||
from zerver.models import Message, Stream, UserMessage, UserProfile
|
from zerver.models import Message, Stream, UserMessage, UserProfile
|
||||||
|
@ -65,14 +66,14 @@ using "subject" in the DB sense, and nothing customer facing.
|
||||||
DB_TOPIC_NAME = "subject"
|
DB_TOPIC_NAME = "subject"
|
||||||
MESSAGE__TOPIC = 'message__subject'
|
MESSAGE__TOPIC = 'message__subject'
|
||||||
|
|
||||||
def topic_match_sa(topic_name: str) -> Any:
|
def topic_match_sa(topic_name: str) -> "ColumnElement[bool]":
|
||||||
# _sa is short for SQLAlchemy, which we use mostly for
|
# _sa is short for SQLAlchemy, which we use mostly for
|
||||||
# queries that search messages
|
# queries that search messages
|
||||||
topic_cond = func.upper(column("subject")) == func.upper(literal(topic_name))
|
topic_cond = func.upper(column("subject", Text)) == func.upper(literal(topic_name))
|
||||||
return topic_cond
|
return topic_cond
|
||||||
|
|
||||||
def topic_column_sa() -> Any:
|
def topic_column_sa() -> "ColumnElement[str]":
|
||||||
return column("subject")
|
return column("subject", Text)
|
||||||
|
|
||||||
def filter_by_exact_message_topic(query: QuerySet, message: Message) -> QuerySet:
|
def filter_by_exact_message_topic(query: QuerySet, message: Message) -> QuerySet:
|
||||||
topic_name = message.topic_name()
|
topic_name = message.topic_name()
|
||||||
|
|
|
@ -2,7 +2,7 @@ import datetime
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from django.utils.timezone import now as timezone_now
|
from django.utils.timezone import now as timezone_now
|
||||||
from sqlalchemy.sql import Selectable, and_, column, not_, or_
|
from sqlalchemy.sql import ClauseElement, and_, column, not_, or_
|
||||||
|
|
||||||
from zerver.lib.timestamp import datetime_to_timestamp
|
from zerver.lib.timestamp import datetime_to_timestamp
|
||||||
from zerver.lib.topic import topic_match_sa
|
from zerver.lib.topic import topic_match_sa
|
||||||
|
@ -74,9 +74,9 @@ def topic_is_muted(user_profile: UserProfile, stream_id: int, topic_name: str) -
|
||||||
).exists()
|
).exists()
|
||||||
return is_muted
|
return is_muted
|
||||||
|
|
||||||
def exclude_topic_mutes(conditions: List[Selectable],
|
def exclude_topic_mutes(conditions: List[ClauseElement],
|
||||||
user_profile: UserProfile,
|
user_profile: UserProfile,
|
||||||
stream_id: Optional[int]) -> List[Selectable]:
|
stream_id: Optional[int]) -> List[ClauseElement]:
|
||||||
query = MutedTopic.objects.filter(
|
query = MutedTopic.objects.filter(
|
||||||
user_profile=user_profile,
|
user_profile=user_profile,
|
||||||
)
|
)
|
||||||
|
@ -95,7 +95,7 @@ def exclude_topic_mutes(conditions: List[Selectable],
|
||||||
if not rows:
|
if not rows:
|
||||||
return conditions
|
return conditions
|
||||||
|
|
||||||
def mute_cond(row: Dict[str, Any]) -> Selectable:
|
def mute_cond(row: Dict[str, Any]) -> ClauseElement:
|
||||||
recipient_id = row['recipient_id']
|
recipient_id = row['recipient_id']
|
||||||
topic_name = row['topic_name']
|
topic_name = row['topic_name']
|
||||||
stream_cond = column("recipient_id") == recipient_id
|
stream_cond = column("recipient_id") == recipient_id
|
||||||
|
|
|
@ -8,7 +8,7 @@ from django.db import connection
|
||||||
from django.http import HttpResponse
|
from django.http import HttpResponse
|
||||||
from django.test import override_settings
|
from django.test import override_settings
|
||||||
from django.utils.timezone import now as timezone_now
|
from django.utils.timezone import now as timezone_now
|
||||||
from sqlalchemy.sql import and_, column, select, table
|
from sqlalchemy.sql import Select, and_, column, select, table
|
||||||
from sqlalchemy.sql.elements import ClauseElement
|
from sqlalchemy.sql.elements import ClauseElement
|
||||||
|
|
||||||
from analytics.lib.counts import COUNT_STATS
|
from analytics.lib.counts import COUNT_STATS
|
||||||
|
@ -54,7 +54,6 @@ from zerver.views.message_fetch import (
|
||||||
LARGER_THAN_MAX_MESSAGE_ID,
|
LARGER_THAN_MAX_MESSAGE_ID,
|
||||||
BadNarrowOperator,
|
BadNarrowOperator,
|
||||||
NarrowBuilder,
|
NarrowBuilder,
|
||||||
Query,
|
|
||||||
exclude_muting_conditions,
|
exclude_muting_conditions,
|
||||||
find_first_unread_anchor,
|
find_first_unread_anchor,
|
||||||
get_messages_backend,
|
get_messages_backend,
|
||||||
|
@ -454,7 +453,7 @@ class NarrowBuilderTest(ZulipTestCase):
|
||||||
term = dict(operator='stream', operand='non-web-public-stream')
|
term = dict(operator='stream', operand='non-web-public-stream')
|
||||||
builder = NarrowBuilder(self.user_profile, column('id'), self.realm, True)
|
builder = NarrowBuilder(self.user_profile, column('id'), self.realm, True)
|
||||||
|
|
||||||
def _build_query(term: Dict[str, Any]) -> Query:
|
def _build_query(term: Dict[str, Any]) -> Select:
|
||||||
return builder.add_term(self.raw_query, term)
|
return builder.add_term(self.raw_query, term)
|
||||||
|
|
||||||
self.assertRaises(BadNarrowOperator, _build_query, term)
|
self.assertRaises(BadNarrowOperator, _build_query, term)
|
||||||
|
@ -467,7 +466,7 @@ class NarrowBuilderTest(ZulipTestCase):
|
||||||
self.assertEqual(actual_params, params)
|
self.assertEqual(actual_params, params)
|
||||||
self.assertIn(where_clause, get_sqlalchemy_sql(query))
|
self.assertIn(where_clause, get_sqlalchemy_sql(query))
|
||||||
|
|
||||||
def _build_query(self, term: Dict[str, Any]) -> Query:
|
def _build_query(self, term: Dict[str, Any]) -> Select:
|
||||||
return self.builder.add_term(self.raw_query, term)
|
return self.builder.add_term(self.raw_query, term)
|
||||||
|
|
||||||
class NarrowLibraryTest(ZulipTestCase):
|
class NarrowLibraryTest(ZulipTestCase):
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import re
|
import re
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
@ -11,9 +11,12 @@ from django.utils.html import escape as escape_html
|
||||||
from django.utils.translation import ugettext as _
|
from django.utils.translation import ugettext as _
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy.dialects import postgresql
|
from sqlalchemy.dialects import postgresql
|
||||||
|
from sqlalchemy.engine import Connection, RowProxy
|
||||||
from sqlalchemy.sql import (
|
from sqlalchemy.sql import (
|
||||||
|
ClauseElement,
|
||||||
ColumnElement,
|
ColumnElement,
|
||||||
Selectable,
|
FromClause,
|
||||||
|
Select,
|
||||||
alias,
|
alias,
|
||||||
and_,
|
and_,
|
||||||
column,
|
column,
|
||||||
|
@ -26,6 +29,7 @@ from sqlalchemy.sql import (
|
||||||
table,
|
table,
|
||||||
union_all,
|
union_all,
|
||||||
)
|
)
|
||||||
|
from sqlalchemy.types import Boolean, Integer, Text
|
||||||
|
|
||||||
from zerver.context_processors import get_valid_realm_from_request
|
from zerver.context_processors import get_valid_realm_from_request
|
||||||
from zerver.decorator import REQ, has_request_variables
|
from zerver.decorator import REQ, has_request_variables
|
||||||
|
@ -84,11 +88,7 @@ class BadNarrowOperator(JsonableError):
|
||||||
def msg_format() -> str:
|
def msg_format() -> str:
|
||||||
return _('Invalid narrow operator: {desc}')
|
return _('Invalid narrow operator: {desc}')
|
||||||
|
|
||||||
# TODO: Should be Select, but sqlalchemy stubs are busted
|
ConditionTransform = Callable[[ClauseElement], ClauseElement]
|
||||||
Query = Any
|
|
||||||
|
|
||||||
# TODO: should be Callable[[ColumnElement], ColumnElement], but sqlalchemy stubs are busted
|
|
||||||
ConditionTransform = Any
|
|
||||||
|
|
||||||
OptionalNarrowListT = Optional[List[Dict[str, Any]]]
|
OptionalNarrowListT = Optional[List[Dict[str, Any]]]
|
||||||
|
|
||||||
|
@ -97,21 +97,24 @@ TS_START = "<ts-match>"
|
||||||
TS_STOP = "</ts-match>"
|
TS_STOP = "</ts-match>"
|
||||||
|
|
||||||
def ts_locs_array(
|
def ts_locs_array(
|
||||||
config: ColumnElement, text: ColumnElement, tsquery: ColumnElement,
|
config: "ColumnElement[str]", text: "ColumnElement[str]", tsquery: "ColumnElement[object]",
|
||||||
) -> ColumnElement:
|
) -> "ColumnElement[List[List[int]]]":
|
||||||
options = f"HighlightAll = TRUE, StartSel = {TS_START}, StopSel = {TS_STOP}"
|
options = f"HighlightAll = TRUE, StartSel = {TS_START}, StopSel = {TS_STOP}"
|
||||||
delimited = func.ts_headline(config, text, tsquery, options)
|
delimited = func.ts_headline(config, text, tsquery, options)
|
||||||
parts = func.unnest(func.string_to_array(delimited, TS_START)).alias()
|
parts = func.unnest(func.string_to_array(delimited, TS_START)).alias()
|
||||||
part = column(parts.name)
|
part = column(parts.name, Text)
|
||||||
part_len = func.length(part) - len(TS_STOP)
|
part_len = func.length(part) - len(TS_STOP)
|
||||||
match_pos = func.sum(part_len).over(rows=(None, -1)) + len(TS_STOP)
|
match_pos = func.sum(part_len).over(rows=(None, -1)) + len(TS_STOP)
|
||||||
match_len = func.strpos(part, TS_STOP) - 1
|
match_len = func.strpos(part, TS_STOP) - 1
|
||||||
return func.array(
|
ret = func.array(
|
||||||
select([postgresql.array([match_pos, match_len])])
|
select([
|
||||||
|
postgresql.array([match_pos, match_len]), # type: ignore[call-overload] # https://github.com/dropbox/sqlalchemy-stubs/issues/188
|
||||||
|
])
|
||||||
.select_from(parts)
|
.select_from(parts)
|
||||||
.offset(1)
|
.offset(1)
|
||||||
.as_scalar(),
|
.as_scalar(),
|
||||||
)
|
)
|
||||||
|
return ret
|
||||||
|
|
||||||
# When you add a new operator to this, also update zerver/lib/narrow.py
|
# When you add a new operator to this, also update zerver/lib/narrow.py
|
||||||
class NarrowBuilder:
|
class NarrowBuilder:
|
||||||
|
@ -124,8 +127,8 @@ class NarrowBuilder:
|
||||||
# None of these methods ever *add* messages to a query's result.
|
# None of these methods ever *add* messages to a query's result.
|
||||||
#
|
#
|
||||||
# That is, the `add_term` method, and its helpers the `by_*` methods,
|
# That is, the `add_term` method, and its helpers the `by_*` methods,
|
||||||
# are passed a Query object representing a query for messages; they may
|
# are passed a Select object representing a query for messages; they may
|
||||||
# call some methods on it, and then they return a resulting Query
|
# call some methods on it, and then they return a resulting Select
|
||||||
# object. Things these methods may do to the queries they handle
|
# object. Things these methods may do to the queries they handle
|
||||||
# include
|
# include
|
||||||
# * add conditions to filter out rows (i.e., messages), with `query.where`
|
# * add conditions to filter out rows (i.e., messages), with `query.where`
|
||||||
|
@ -136,14 +139,14 @@ class NarrowBuilder:
|
||||||
# * anything that would pull in additional rows, or information on
|
# * anything that would pull in additional rows, or information on
|
||||||
# other messages.
|
# other messages.
|
||||||
|
|
||||||
def __init__(self, user_profile: Optional[UserProfile], msg_id_column: str,
|
def __init__(self, user_profile: Optional[UserProfile], msg_id_column: "ColumnElement[int]",
|
||||||
realm: Realm, is_web_public_query: bool=False) -> None:
|
realm: Realm, is_web_public_query: bool=False) -> None:
|
||||||
self.user_profile = user_profile
|
self.user_profile = user_profile
|
||||||
self.msg_id_column = msg_id_column
|
self.msg_id_column = msg_id_column
|
||||||
self.realm = realm
|
self.realm = realm
|
||||||
self.is_web_public_query = is_web_public_query
|
self.is_web_public_query = is_web_public_query
|
||||||
|
|
||||||
def add_term(self, query: Query, term: Dict[str, Any]) -> Query:
|
def add_term(self, query: Select, term: Dict[str, Any]) -> Select:
|
||||||
"""
|
"""
|
||||||
Extend the given query to one narrowed by the given term, and return the result.
|
Extend the given query to one narrowed by the given term, and return the result.
|
||||||
|
|
||||||
|
@ -176,14 +179,14 @@ class NarrowBuilder:
|
||||||
|
|
||||||
return method(query, operand, maybe_negate)
|
return method(query, operand, maybe_negate)
|
||||||
|
|
||||||
def by_has(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query:
|
def by_has(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
|
||||||
if operand not in ['attachment', 'image', 'link']:
|
if operand not in ['attachment', 'image', 'link']:
|
||||||
raise BadNarrowOperator("unknown 'has' operand " + operand)
|
raise BadNarrowOperator("unknown 'has' operand " + operand)
|
||||||
col_name = 'has_' + operand
|
col_name = 'has_' + operand
|
||||||
cond = column(col_name)
|
cond = column(col_name, Boolean)
|
||||||
return query.where(maybe_negate(cond))
|
return query.where(maybe_negate(cond))
|
||||||
|
|
||||||
def by_in(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query:
|
def by_in(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
|
||||||
# This operator does not support is_web_public_query.
|
# This operator does not support is_web_public_query.
|
||||||
assert not self.is_web_public_query
|
assert not self.is_web_public_query
|
||||||
assert self.user_profile is not None
|
assert self.user_profile is not None
|
||||||
|
@ -196,27 +199,27 @@ class NarrowBuilder:
|
||||||
|
|
||||||
raise BadNarrowOperator("unknown 'in' operand " + operand)
|
raise BadNarrowOperator("unknown 'in' operand " + operand)
|
||||||
|
|
||||||
def by_is(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query:
|
def by_is(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
|
||||||
# This operator class does not support is_web_public_query.
|
# This operator class does not support is_web_public_query.
|
||||||
assert not self.is_web_public_query
|
assert not self.is_web_public_query
|
||||||
assert self.user_profile is not None
|
assert self.user_profile is not None
|
||||||
|
|
||||||
if operand == 'private':
|
if operand == 'private':
|
||||||
cond = column("flags").op("&")(UserMessage.flags.is_private.mask) != 0
|
cond = column("flags", Integer).op("&")(UserMessage.flags.is_private.mask) != 0
|
||||||
return query.where(maybe_negate(cond))
|
return query.where(maybe_negate(cond))
|
||||||
elif operand == 'starred':
|
elif operand == 'starred':
|
||||||
cond = column("flags").op("&")(UserMessage.flags.starred.mask) != 0
|
cond = column("flags", Integer).op("&")(UserMessage.flags.starred.mask) != 0
|
||||||
return query.where(maybe_negate(cond))
|
return query.where(maybe_negate(cond))
|
||||||
elif operand == 'unread':
|
elif operand == 'unread':
|
||||||
cond = column("flags").op("&")(UserMessage.flags.read.mask) == 0
|
cond = column("flags", Integer).op("&")(UserMessage.flags.read.mask) == 0
|
||||||
return query.where(maybe_negate(cond))
|
return query.where(maybe_negate(cond))
|
||||||
elif operand == 'mentioned':
|
elif operand == 'mentioned':
|
||||||
cond1 = column("flags").op("&")(UserMessage.flags.mentioned.mask) != 0
|
cond1 = column("flags", Integer).op("&")(UserMessage.flags.mentioned.mask) != 0
|
||||||
cond2 = column("flags").op("&")(UserMessage.flags.wildcard_mentioned.mask) != 0
|
cond2 = column("flags", Integer).op("&")(UserMessage.flags.wildcard_mentioned.mask) != 0
|
||||||
cond = or_(cond1, cond2)
|
cond = or_(cond1, cond2)
|
||||||
return query.where(maybe_negate(cond))
|
return query.where(maybe_negate(cond))
|
||||||
elif operand == 'alerted':
|
elif operand == 'alerted':
|
||||||
cond = column("flags").op("&")(UserMessage.flags.has_alert_word.mask) != 0
|
cond = column("flags", Integer).op("&")(UserMessage.flags.has_alert_word.mask) != 0
|
||||||
return query.where(maybe_negate(cond))
|
return query.where(maybe_negate(cond))
|
||||||
raise BadNarrowOperator("unknown 'is' operand " + operand)
|
raise BadNarrowOperator("unknown 'is' operand " + operand)
|
||||||
|
|
||||||
|
@ -242,7 +245,7 @@ class NarrowBuilder:
|
||||||
s[i] = '\\' + c
|
s[i] = '\\' + c
|
||||||
return ''.join(s)
|
return ''.join(s)
|
||||||
|
|
||||||
def by_stream(self, query: Query, operand: Union[str, int], maybe_negate: ConditionTransform) -> Query:
|
def by_stream(self, query: Select, operand: Union[str, int], maybe_negate: ConditionTransform) -> Select:
|
||||||
try:
|
try:
|
||||||
# Because you can see your own message history for
|
# Because you can see your own message history for
|
||||||
# private streams you are no longer subscribed to, we
|
# private streams you are no longer subscribed to, we
|
||||||
|
@ -274,14 +277,14 @@ class NarrowBuilder:
|
||||||
matching_streams = get_active_streams(self.realm).filter(
|
matching_streams = get_active_streams(self.realm).filter(
|
||||||
name__iregex=fr'^(un)*{self._pg_re_escape(base_stream_name)}(\.d)*$')
|
name__iregex=fr'^(un)*{self._pg_re_escape(base_stream_name)}(\.d)*$')
|
||||||
recipient_ids = [matching_stream.recipient_id for matching_stream in matching_streams]
|
recipient_ids = [matching_stream.recipient_id for matching_stream in matching_streams]
|
||||||
cond = column("recipient_id").in_(recipient_ids)
|
cond = column("recipient_id", Integer).in_(recipient_ids)
|
||||||
return query.where(maybe_negate(cond))
|
return query.where(maybe_negate(cond))
|
||||||
|
|
||||||
recipient = stream.recipient
|
recipient = stream.recipient
|
||||||
cond = column("recipient_id") == recipient.id
|
cond = column("recipient_id", Integer) == recipient.id
|
||||||
return query.where(maybe_negate(cond))
|
return query.where(maybe_negate(cond))
|
||||||
|
|
||||||
def by_streams(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query:
|
def by_streams(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
|
||||||
if operand == 'public':
|
if operand == 'public':
|
||||||
# Get all both subscribed and non subscribed public streams
|
# Get all both subscribed and non subscribed public streams
|
||||||
# but exclude any private subscribed streams.
|
# but exclude any private subscribed streams.
|
||||||
|
@ -292,10 +295,10 @@ class NarrowBuilder:
|
||||||
raise BadNarrowOperator('unknown streams operand ' + operand)
|
raise BadNarrowOperator('unknown streams operand ' + operand)
|
||||||
|
|
||||||
recipient_ids = recipient_queryset.values_list("recipient_id", flat=True).order_by('id')
|
recipient_ids = recipient_queryset.values_list("recipient_id", flat=True).order_by('id')
|
||||||
cond = column("recipient_id").in_(recipient_ids)
|
cond = column("recipient_id", Integer).in_(recipient_ids)
|
||||||
return query.where(maybe_negate(cond))
|
return query.where(maybe_negate(cond))
|
||||||
|
|
||||||
def by_topic(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query:
|
def by_topic(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
|
||||||
if self.realm.is_zephyr_mirror_realm:
|
if self.realm.is_zephyr_mirror_realm:
|
||||||
# MIT users expect narrowing to topic "foo" to also show messages to /^foo(.d)*$/
|
# MIT users expect narrowing to topic "foo" to also show messages to /^foo(.d)*$/
|
||||||
# (foo, foo.d, foo.d.d, etc)
|
# (foo, foo.d, foo.d.d, etc)
|
||||||
|
@ -307,7 +310,7 @@ class NarrowBuilder:
|
||||||
# Additionally, MIT users expect the empty instance and
|
# Additionally, MIT users expect the empty instance and
|
||||||
# instance "personal" to be the same.
|
# instance "personal" to be the same.
|
||||||
if base_topic in ('', 'personal', '(instance "")'):
|
if base_topic in ('', 'personal', '(instance "")'):
|
||||||
cond = or_(
|
cond: ClauseElement = or_(
|
||||||
topic_match_sa(""),
|
topic_match_sa(""),
|
||||||
topic_match_sa(".d"),
|
topic_match_sa(".d"),
|
||||||
topic_match_sa(".d.d"),
|
topic_match_sa(".d.d"),
|
||||||
|
@ -340,7 +343,7 @@ class NarrowBuilder:
|
||||||
cond = topic_match_sa(operand)
|
cond = topic_match_sa(operand)
|
||||||
return query.where(maybe_negate(cond))
|
return query.where(maybe_negate(cond))
|
||||||
|
|
||||||
def by_sender(self, query: Query, operand: Union[str, int], maybe_negate: ConditionTransform) -> Query:
|
def by_sender(self, query: Select, operand: Union[str, int], maybe_negate: ConditionTransform) -> Select:
|
||||||
try:
|
try:
|
||||||
if isinstance(operand, str):
|
if isinstance(operand, str):
|
||||||
sender = get_user_including_cross_realm(operand, self.realm)
|
sender = get_user_including_cross_realm(operand, self.realm)
|
||||||
|
@ -349,20 +352,20 @@ class NarrowBuilder:
|
||||||
except UserProfile.DoesNotExist:
|
except UserProfile.DoesNotExist:
|
||||||
raise BadNarrowOperator('unknown user ' + str(operand))
|
raise BadNarrowOperator('unknown user ' + str(operand))
|
||||||
|
|
||||||
cond = column("sender_id") == literal(sender.id)
|
cond = column("sender_id", Integer) == literal(sender.id)
|
||||||
return query.where(maybe_negate(cond))
|
return query.where(maybe_negate(cond))
|
||||||
|
|
||||||
def by_near(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query:
|
def by_near(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
|
||||||
return query
|
return query
|
||||||
|
|
||||||
def by_id(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query:
|
def by_id(self, query: Select, operand: Union[int, str], maybe_negate: ConditionTransform) -> Select:
|
||||||
if not str(operand).isdigit():
|
if not str(operand).isdigit():
|
||||||
raise BadNarrowOperator("Invalid message ID")
|
raise BadNarrowOperator("Invalid message ID")
|
||||||
cond = self.msg_id_column == literal(operand)
|
cond = self.msg_id_column == literal(operand)
|
||||||
return query.where(maybe_negate(cond))
|
return query.where(maybe_negate(cond))
|
||||||
|
|
||||||
def by_pm_with(self, query: Query, operand: Union[str, Iterable[int]],
|
def by_pm_with(self, query: Select, operand: Union[str, Iterable[int]],
|
||||||
maybe_negate: ConditionTransform) -> Query:
|
maybe_negate: ConditionTransform) -> Select:
|
||||||
# This operator does not support is_web_public_query.
|
# This operator does not support is_web_public_query.
|
||||||
assert not self.is_web_public_query
|
assert not self.is_web_public_query
|
||||||
assert self.user_profile is not None
|
assert self.user_profile is not None
|
||||||
|
@ -394,7 +397,7 @@ class NarrowBuilder:
|
||||||
|
|
||||||
# Group DM
|
# Group DM
|
||||||
if recipient.type == Recipient.HUDDLE:
|
if recipient.type == Recipient.HUDDLE:
|
||||||
cond = column("recipient_id") == recipient.id
|
cond = column("recipient_id", Integer) == recipient.id
|
||||||
return query.where(maybe_negate(cond))
|
return query.where(maybe_negate(cond))
|
||||||
|
|
||||||
# 1:1 PM
|
# 1:1 PM
|
||||||
|
@ -414,19 +417,19 @@ class NarrowBuilder:
|
||||||
# complex query to get messages between these two users
|
# complex query to get messages between these two users
|
||||||
# with either of them as the sender.
|
# with either of them as the sender.
|
||||||
self_recipient_id = self.user_profile.recipient_id
|
self_recipient_id = self.user_profile.recipient_id
|
||||||
cond = or_(and_(column("sender_id") == other_participant.id,
|
cond = or_(and_(column("sender_id", Integer) == other_participant.id,
|
||||||
column("recipient_id") == self_recipient_id),
|
column("recipient_id", Integer) == self_recipient_id),
|
||||||
and_(column("sender_id") == self.user_profile.id,
|
and_(column("sender_id", Integer) == self.user_profile.id,
|
||||||
column("recipient_id") == recipient.id))
|
column("recipient_id", Integer) == recipient.id))
|
||||||
return query.where(maybe_negate(cond))
|
return query.where(maybe_negate(cond))
|
||||||
|
|
||||||
# PM with self
|
# PM with self
|
||||||
cond = and_(column("sender_id") == self.user_profile.id,
|
cond = and_(column("sender_id", Integer) == self.user_profile.id,
|
||||||
column("recipient_id") == recipient.id)
|
column("recipient_id", Integer) == recipient.id)
|
||||||
return query.where(maybe_negate(cond))
|
return query.where(maybe_negate(cond))
|
||||||
|
|
||||||
def by_group_pm_with(self, query: Query, operand: Union[str, int],
|
def by_group_pm_with(self, query: Select, operand: Union[str, int],
|
||||||
maybe_negate: ConditionTransform) -> Query:
|
maybe_negate: ConditionTransform) -> Select:
|
||||||
# This operator does not support is_web_public_query.
|
# This operator does not support is_web_public_query.
|
||||||
assert not self.is_web_public_query
|
assert not self.is_web_public_query
|
||||||
assert self.user_profile is not None
|
assert self.user_profile is not None
|
||||||
|
@ -453,33 +456,33 @@ class NarrowBuilder:
|
||||||
).values("recipient_id")]
|
).values("recipient_id")]
|
||||||
|
|
||||||
recipient_ids = set(self_recipient_ids) & set(narrow_recipient_ids)
|
recipient_ids = set(self_recipient_ids) & set(narrow_recipient_ids)
|
||||||
cond = column("recipient_id").in_(recipient_ids)
|
cond = column("recipient_id", Integer).in_(recipient_ids)
|
||||||
return query.where(maybe_negate(cond))
|
return query.where(maybe_negate(cond))
|
||||||
|
|
||||||
def by_search(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query:
|
def by_search(self, query: Select, operand: str, maybe_negate: ConditionTransform) -> Select:
|
||||||
if settings.USING_PGROONGA:
|
if settings.USING_PGROONGA:
|
||||||
return self._by_search_pgroonga(query, operand, maybe_negate)
|
return self._by_search_pgroonga(query, operand, maybe_negate)
|
||||||
else:
|
else:
|
||||||
return self._by_search_tsearch(query, operand, maybe_negate)
|
return self._by_search_tsearch(query, operand, maybe_negate)
|
||||||
|
|
||||||
def _by_search_pgroonga(self, query: Query, operand: str,
|
def _by_search_pgroonga(self, query: Select, operand: str,
|
||||||
maybe_negate: ConditionTransform) -> Query:
|
maybe_negate: ConditionTransform) -> Select:
|
||||||
match_positions_character = func.pgroonga_match_positions_character
|
match_positions_character = func.pgroonga_match_positions_character
|
||||||
query_extract_keywords = func.pgroonga_query_extract_keywords
|
query_extract_keywords = func.pgroonga_query_extract_keywords
|
||||||
operand_escaped = func.escape_html(operand)
|
operand_escaped = func.escape_html(operand)
|
||||||
keywords = query_extract_keywords(operand_escaped)
|
keywords = query_extract_keywords(operand_escaped)
|
||||||
query = query.column(match_positions_character(column("rendered_content"),
|
query = query.column(match_positions_character(column("rendered_content", Text),
|
||||||
keywords).label("content_matches"))
|
keywords).label("content_matches"))
|
||||||
query = query.column(match_positions_character(func.escape_html(topic_column_sa()),
|
query = query.column(match_positions_character(func.escape_html(topic_column_sa()),
|
||||||
keywords).label("topic_matches"))
|
keywords).label("topic_matches"))
|
||||||
condition = column("search_pgroonga").op("&@~")(operand_escaped)
|
condition = column("search_pgroonga").op("&@~")(operand_escaped)
|
||||||
return query.where(maybe_negate(condition))
|
return query.where(maybe_negate(condition))
|
||||||
|
|
||||||
def _by_search_tsearch(self, query: Query, operand: str,
|
def _by_search_tsearch(self, query: Select, operand: str,
|
||||||
maybe_negate: ConditionTransform) -> Query:
|
maybe_negate: ConditionTransform) -> Select:
|
||||||
tsquery = func.plainto_tsquery(literal("zulip.english_us_search"), literal(operand))
|
tsquery = func.plainto_tsquery(literal("zulip.english_us_search"), literal(operand))
|
||||||
query = query.column(ts_locs_array(literal("zulip.english_us_search"),
|
query = query.column(ts_locs_array(literal("zulip.english_us_search"),
|
||||||
column("rendered_content"),
|
column("rendered_content", Text),
|
||||||
tsquery).label("content_matches"))
|
tsquery).label("content_matches"))
|
||||||
# We HTML-escape the topic in PostgreSQL to avoid doing a server round-trip
|
# We HTML-escape the topic in PostgreSQL to avoid doing a server round-trip
|
||||||
query = query.column(ts_locs_array(literal("zulip.english_us_search"),
|
query = query.column(ts_locs_array(literal("zulip.english_us_search"),
|
||||||
|
@ -494,11 +497,11 @@ class NarrowBuilder:
|
||||||
if term[0] == '"' and term[-1] == '"':
|
if term[0] == '"' and term[-1] == '"':
|
||||||
term = term[1:-1]
|
term = term[1:-1]
|
||||||
term = '%' + connection.ops.prep_for_like_query(term) + '%'
|
term = '%' + connection.ops.prep_for_like_query(term) + '%'
|
||||||
cond = or_(column("content").ilike(term),
|
cond = or_(column("content", Text).ilike(term),
|
||||||
topic_column_sa().ilike(term))
|
topic_column_sa().ilike(term))
|
||||||
query = query.where(maybe_negate(cond))
|
query = query.where(maybe_negate(cond))
|
||||||
|
|
||||||
cond = column("search_tsvector").op("@@")(tsquery)
|
cond = column("search_tsvector", postgresql.TSVECTOR).op("@@")(tsquery)
|
||||||
return query.where(maybe_negate(cond))
|
return query.where(maybe_negate(cond))
|
||||||
|
|
||||||
def highlight_string(text: str, locs: Iterable[Tuple[int, int]]) -> str:
|
def highlight_string(text: str, locs: Iterable[Tuple[int, int]]) -> str:
|
||||||
|
@ -658,7 +661,7 @@ def get_stream_from_narrow_access_unchecked(narrow: OptionalNarrowListT, realm:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def exclude_muting_conditions(user_profile: UserProfile,
|
def exclude_muting_conditions(user_profile: UserProfile,
|
||||||
narrow: OptionalNarrowListT) -> List[Selectable]:
|
narrow: OptionalNarrowListT) -> List[ClauseElement]:
|
||||||
conditions = []
|
conditions = []
|
||||||
stream_id = None
|
stream_id = None
|
||||||
try:
|
try:
|
||||||
|
@ -681,7 +684,7 @@ def exclude_muting_conditions(user_profile: UserProfile,
|
||||||
muted_recipient_ids = [row['recipient_id'] for row in rows]
|
muted_recipient_ids = [row['recipient_id'] for row in rows]
|
||||||
if len(muted_recipient_ids) > 0:
|
if len(muted_recipient_ids) > 0:
|
||||||
# Only add the condition if we have muted streams to simplify/avoid warnings.
|
# Only add the condition if we have muted streams to simplify/avoid warnings.
|
||||||
condition = not_(column("recipient_id").in_(muted_recipient_ids))
|
condition = not_(column("recipient_id", Integer).in_(muted_recipient_ids))
|
||||||
conditions.append(condition)
|
conditions.append(condition)
|
||||||
|
|
||||||
conditions = exclude_topic_mutes(conditions, user_profile, stream_id)
|
conditions = exclude_topic_mutes(conditions, user_profile, stream_id)
|
||||||
|
@ -690,38 +693,39 @@ def exclude_muting_conditions(user_profile: UserProfile,
|
||||||
|
|
||||||
def get_base_query_for_search(user_profile: Optional[UserProfile],
|
def get_base_query_for_search(user_profile: Optional[UserProfile],
|
||||||
need_message: bool,
|
need_message: bool,
|
||||||
need_user_message: bool) -> Tuple[Query, ColumnElement]:
|
need_user_message: bool) -> Tuple[Select, "ColumnElement[int]"]:
|
||||||
# Handle the simple case where user_message isn't involved first.
|
# Handle the simple case where user_message isn't involved first.
|
||||||
if not need_user_message:
|
if not need_user_message:
|
||||||
assert(need_message)
|
assert(need_message)
|
||||||
query = select([column("id").label("message_id")],
|
query = select([column("id", Integer).label("message_id")],
|
||||||
None,
|
None,
|
||||||
table("zerver_message"))
|
table("zerver_message"))
|
||||||
inner_msg_id_col = literal_column("zerver_message.id")
|
inner_msg_id_col: ColumnElement[int]
|
||||||
|
inner_msg_id_col = literal_column("zerver_message.id", Integer) # type: ignore[assignment] # https://github.com/dropbox/sqlalchemy-stubs/pull/189
|
||||||
return (query, inner_msg_id_col)
|
return (query, inner_msg_id_col)
|
||||||
|
|
||||||
assert user_profile is not None
|
assert user_profile is not None
|
||||||
if need_message:
|
if need_message:
|
||||||
query = select([column("message_id"), column("flags")],
|
query = select([column("message_id"), column("flags", Integer)],
|
||||||
column("user_profile_id") == literal(user_profile.id),
|
column("user_profile_id") == literal(user_profile.id),
|
||||||
join(table("zerver_usermessage"), table("zerver_message"),
|
join(table("zerver_usermessage"), table("zerver_message"),
|
||||||
literal_column("zerver_usermessage.message_id") ==
|
literal_column("zerver_usermessage.message_id", Integer) ==
|
||||||
literal_column("zerver_message.id")))
|
literal_column("zerver_message.id", Integer)))
|
||||||
inner_msg_id_col = column("message_id")
|
inner_msg_id_col = column("message_id", Integer)
|
||||||
return (query, inner_msg_id_col)
|
return (query, inner_msg_id_col)
|
||||||
|
|
||||||
query = select([column("message_id"), column("flags")],
|
query = select([column("message_id"), column("flags", Integer)],
|
||||||
column("user_profile_id") == literal(user_profile.id),
|
column("user_profile_id") == literal(user_profile.id),
|
||||||
table("zerver_usermessage"))
|
table("zerver_usermessage"))
|
||||||
inner_msg_id_col = column("message_id")
|
inner_msg_id_col = column("message_id", Integer)
|
||||||
return (query, inner_msg_id_col)
|
return (query, inner_msg_id_col)
|
||||||
|
|
||||||
def add_narrow_conditions(user_profile: Optional[UserProfile],
|
def add_narrow_conditions(user_profile: Optional[UserProfile],
|
||||||
inner_msg_id_col: ColumnElement,
|
inner_msg_id_col: "ColumnElement[int]",
|
||||||
query: Query,
|
query: Select,
|
||||||
narrow: OptionalNarrowListT,
|
narrow: OptionalNarrowListT,
|
||||||
is_web_public_query: bool,
|
is_web_public_query: bool,
|
||||||
realm: Realm) -> Tuple[Query, bool]:
|
realm: Realm) -> Tuple[Select, bool]:
|
||||||
is_search = False # for now
|
is_search = False # for now
|
||||||
|
|
||||||
if narrow is None:
|
if narrow is None:
|
||||||
|
@ -742,7 +746,7 @@ def add_narrow_conditions(user_profile: Optional[UserProfile],
|
||||||
|
|
||||||
if search_operands:
|
if search_operands:
|
||||||
is_search = True
|
is_search = True
|
||||||
query = query.column(topic_column_sa()).column(column("rendered_content"))
|
query = query.column(topic_column_sa()).column(column("rendered_content", Text))
|
||||||
search_term = dict(
|
search_term = dict(
|
||||||
operator='search',
|
operator='search',
|
||||||
operand=' '.join(search_operands),
|
operand=' '.join(search_operands),
|
||||||
|
@ -751,7 +755,7 @@ def add_narrow_conditions(user_profile: Optional[UserProfile],
|
||||||
|
|
||||||
return (query, is_search)
|
return (query, is_search)
|
||||||
|
|
||||||
def find_first_unread_anchor(sa_conn: Any,
|
def find_first_unread_anchor(sa_conn: Connection,
|
||||||
user_profile: Optional[UserProfile],
|
user_profile: Optional[UserProfile],
|
||||||
narrow: OptionalNarrowListT) -> int:
|
narrow: OptionalNarrowListT) -> int:
|
||||||
# For anonymous web users, all messages are treated as read, and so
|
# For anonymous web users, all messages are treated as read, and so
|
||||||
|
@ -785,7 +789,7 @@ def find_first_unread_anchor(sa_conn: Any,
|
||||||
realm=user_profile.realm,
|
realm=user_profile.realm,
|
||||||
)
|
)
|
||||||
|
|
||||||
condition = column("flags").op("&")(UserMessage.flags.read.mask) == 0
|
condition = column("flags", Integer).op("&")(UserMessage.flags.read.mask) == 0
|
||||||
|
|
||||||
# We exclude messages on muted topics when finding the first unread
|
# We exclude messages on muted topics when finding the first unread
|
||||||
# message in this narrow
|
# message in this narrow
|
||||||
|
@ -913,6 +917,7 @@ def get_messages_backend(request: HttpRequest,
|
||||||
need_message = True
|
need_message = True
|
||||||
need_user_message = True
|
need_user_message = True
|
||||||
|
|
||||||
|
query: FromClause
|
||||||
query, inner_msg_id_col = get_base_query_for_search(
|
query, inner_msg_id_col = get_base_query_for_search(
|
||||||
user_profile=user_profile,
|
user_profile=user_profile,
|
||||||
need_message=need_message,
|
need_message=need_message,
|
||||||
|
@ -970,7 +975,7 @@ def get_messages_backend(request: HttpRequest,
|
||||||
)
|
)
|
||||||
|
|
||||||
main_query = alias(query)
|
main_query = alias(query)
|
||||||
query = select(main_query.c, None, main_query).order_by(column("message_id").asc())
|
query = select(main_query.c, None, main_query).order_by(column("message_id", Integer).asc())
|
||||||
# This is a hack to tag the query we use for testing
|
# This is a hack to tag the query we use for testing
|
||||||
query = query.prefix_with("/* get_messages */")
|
query = query.prefix_with("/* get_messages */")
|
||||||
rows = list(sa_conn.execute(query).fetchall())
|
rows = list(sa_conn.execute(query).fetchall())
|
||||||
|
@ -1059,14 +1064,14 @@ def get_messages_backend(request: HttpRequest,
|
||||||
)
|
)
|
||||||
return json_success(ret)
|
return json_success(ret)
|
||||||
|
|
||||||
def limit_query_to_range(query: Query,
|
def limit_query_to_range(query: Select,
|
||||||
num_before: int,
|
num_before: int,
|
||||||
num_after: int,
|
num_after: int,
|
||||||
anchor: int,
|
anchor: int,
|
||||||
anchored_to_left: bool,
|
anchored_to_left: bool,
|
||||||
anchored_to_right: bool,
|
anchored_to_right: bool,
|
||||||
id_col: ColumnElement,
|
id_col: "ColumnElement[int]",
|
||||||
first_visible_message_id: int) -> Query:
|
first_visible_message_id: int) -> FromClause:
|
||||||
'''
|
'''
|
||||||
This code is actually generic enough that we could move it to a
|
This code is actually generic enough that we could move it to a
|
||||||
library, but our only caller for now is message search.
|
library, but our only caller for now is message search.
|
||||||
|
@ -1137,7 +1142,7 @@ def limit_query_to_range(query: Query,
|
||||||
# return at most one row here.
|
# return at most one row here.
|
||||||
return query.where(id_col == anchor)
|
return query.where(id_col == anchor)
|
||||||
|
|
||||||
def post_process_limited_query(rows: List[Any],
|
def post_process_limited_query(rows: Sequence[Union[RowProxy, Sequence[Any]]],
|
||||||
num_before: int,
|
num_before: int,
|
||||||
num_after: int,
|
num_after: int,
|
||||||
anchor: int,
|
anchor: int,
|
||||||
|
@ -1153,7 +1158,9 @@ def post_process_limited_query(rows: List[Any],
|
||||||
# that the clients will know that they got complete results.
|
# that the clients will know that they got complete results.
|
||||||
|
|
||||||
if first_visible_message_id > 0:
|
if first_visible_message_id > 0:
|
||||||
visible_rows = [r for r in rows if r[0] >= first_visible_message_id]
|
visible_rows: Sequence[Union[RowProxy, Sequence[Any]]] = [
|
||||||
|
r for r in rows if r[0] >= first_visible_message_id
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
visible_rows = rows
|
visible_rows = rows
|
||||||
|
|
||||||
|
@ -1162,8 +1169,8 @@ def post_process_limited_query(rows: List[Any],
|
||||||
if anchored_to_right:
|
if anchored_to_right:
|
||||||
num_after = 0
|
num_after = 0
|
||||||
before_rows = visible_rows[:]
|
before_rows = visible_rows[:]
|
||||||
anchor_rows: List[Any] = []
|
anchor_rows = []
|
||||||
after_rows: List[Any] = []
|
after_rows = []
|
||||||
else:
|
else:
|
||||||
before_rows = [r for r in visible_rows if r[0] < anchor]
|
before_rows = [r for r in visible_rows if r[0] < anchor]
|
||||||
anchor_rows = [r for r in visible_rows if r[0] == anchor]
|
anchor_rows = [r for r in visible_rows if r[0] == anchor]
|
||||||
|
@ -1175,7 +1182,7 @@ def post_process_limited_query(rows: List[Any],
|
||||||
if num_after:
|
if num_after:
|
||||||
after_rows = after_rows[:num_after]
|
after_rows = after_rows[:num_after]
|
||||||
|
|
||||||
visible_rows = before_rows + anchor_rows + after_rows
|
visible_rows = [*before_rows, *anchor_rows, *after_rows]
|
||||||
|
|
||||||
found_anchor = len(anchor_rows) == 1
|
found_anchor = len(anchor_rows) == 1
|
||||||
found_oldest = anchored_to_left or (len(before_rows) < num_before)
|
found_oldest = anchored_to_left or (len(before_rows) < num_before)
|
||||||
|
@ -1211,14 +1218,14 @@ def messages_in_narrow_backend(request: HttpRequest, user_profile: UserProfile,
|
||||||
msg_ids = [message_id for message_id in msg_ids if message_id >= first_visible_message_id]
|
msg_ids = [message_id for message_id in msg_ids if message_id >= first_visible_message_id]
|
||||||
# This query is limited to messages the user has access to because they
|
# This query is limited to messages the user has access to because they
|
||||||
# actually received them, as reflected in `zerver_usermessage`.
|
# actually received them, as reflected in `zerver_usermessage`.
|
||||||
query = select([column("message_id"), topic_column_sa(), column("rendered_content")],
|
query = select([column("message_id", Integer), topic_column_sa(), column("rendered_content", Text)],
|
||||||
and_(column("user_profile_id") == literal(user_profile.id),
|
and_(column("user_profile_id", Integer) == literal(user_profile.id),
|
||||||
column("message_id").in_(msg_ids)),
|
column("message_id", Integer).in_(msg_ids)),
|
||||||
join(table("zerver_usermessage"), table("zerver_message"),
|
join(table("zerver_usermessage"), table("zerver_message"),
|
||||||
literal_column("zerver_usermessage.message_id") ==
|
literal_column("zerver_usermessage.message_id", Integer) ==
|
||||||
literal_column("zerver_message.id")))
|
literal_column("zerver_message.id", Integer)))
|
||||||
|
|
||||||
builder = NarrowBuilder(user_profile, column("message_id"), user_profile.realm)
|
builder = NarrowBuilder(user_profile, column("message_id", Integer), user_profile.realm)
|
||||||
if narrow is not None:
|
if narrow is not None:
|
||||||
for term in narrow:
|
for term in narrow:
|
||||||
query = builder.add_term(query, term)
|
query = builder.add_term(query, term)
|
||||||
|
|
Loading…
Reference in New Issue