Use topic_match_sa() for topic searches.

Note this introduce literal(), which makes the way
we handle topic mutes more consistent with general
topic searches.
This commit is contained in:
Steve Howell 2018-11-01 21:15:43 +00:00 committed by Tim Abbott
parent 79d5e36ca3
commit ff60055fa4
3 changed files with 32 additions and 28 deletions

View File

@ -6,6 +6,7 @@ from django.utils.timezone import now as timezone_now
from sqlalchemy.sql import (
column,
literal,
func,
)
@ -26,7 +27,7 @@ PREV_TOPIC = "prev_subject"
def topic_match_sa(topic_name: str) -> Any:
# _sa is short for Sql Alchemy, which we use mostly for
# queries that search messages
topic_cond = func.upper(column("subject")) == func.upper(topic_name)
topic_cond = func.upper(column("subject")) == func.upper(literal(topic_name))
return topic_cond
def filter_by_exact_message_topic(query: QuerySet, message: Message) -> QuerySet:

View File

@ -2410,13 +2410,13 @@ class GetOldMessagesTest(ZulipTestCase):
expected_query = '''
SELECT id AS message_id
FROM zerver_message
WHERE NOT (recipient_id = :recipient_id_1 AND upper(subject) = upper(:upper_1))
WHERE NOT (recipient_id = :recipient_id_1 AND upper(subject) = upper(:param_1))
'''
self.assertEqual(fix_ws(query), fix_ws(expected_query))
params = get_sqlalchemy_query_params(query)
self.assertEqual(params['recipient_id_1'], get_recipient_id_for_stream_name(realm, 'Scotland'))
self.assertEqual(params['upper_1'], 'golf')
self.assertEqual(params['param_1'], 'golf')
mute_stream(realm, user_profile, 'Verona')
@ -2435,15 +2435,15 @@ class GetOldMessagesTest(ZulipTestCase):
FROM zerver_message
WHERE recipient_id NOT IN (:recipient_id_1)
AND NOT
(recipient_id = :recipient_id_2 AND upper(subject) = upper(:upper_1) OR
recipient_id = :recipient_id_3 AND upper(subject) = upper(:upper_2))'''
(recipient_id = :recipient_id_2 AND upper(subject) = upper(:param_1) OR
recipient_id = :recipient_id_3 AND upper(subject) = upper(:param_2))'''
self.assertEqual(fix_ws(query), fix_ws(expected_query))
params = get_sqlalchemy_query_params(query)
self.assertEqual(params['recipient_id_1'], get_recipient_id_for_stream_name(realm, 'Verona'))
self.assertEqual(params['recipient_id_2'], get_recipient_id_for_stream_name(realm, 'Scotland'))
self.assertEqual(params['upper_1'], 'golf')
self.assertEqual(params['param_1'], 'golf')
self.assertEqual(params['recipient_id_3'], get_recipient_id_for_stream_name(realm, 'web stuff'))
self.assertEqual(params['upper_2'], 'css')
self.assertEqual(params['param_2'], 'css')
def test_get_messages_queries(self) -> None:
query_ids = self.get_query_ids()

View File

@ -33,6 +33,9 @@ from zerver.lib.sqlalchemy_utils import get_sqlalchemy_connection
from zerver.lib.streams import access_stream_by_id, can_access_stream_history_by_name
from zerver.lib.timestamp import datetime_to_timestamp, convert_to_UTC
from zerver.lib.timezone import get_timezone
from zerver.lib.topic import (
topic_match_sa,
)
from zerver.lib.topic_mutes import exclude_topic_mutes
from zerver.lib.utils import statsd
from zerver.lib.validator import \
@ -241,36 +244,36 @@ class NarrowBuilder:
# instance "personal" to be the same.
if base_topic in ('', 'personal', '(instance "")'):
cond = or_(
func.upper(column("subject")) == func.upper(literal("")),
func.upper(column("subject")) == func.upper(literal(".d")),
func.upper(column("subject")) == func.upper(literal(".d.d")),
func.upper(column("subject")) == func.upper(literal(".d.d.d")),
func.upper(column("subject")) == func.upper(literal(".d.d.d.d")),
func.upper(column("subject")) == func.upper(literal("personal")),
func.upper(column("subject")) == func.upper(literal("personal.d")),
func.upper(column("subject")) == func.upper(literal("personal.d.d")),
func.upper(column("subject")) == func.upper(literal("personal.d.d.d")),
func.upper(column("subject")) == func.upper(literal("personal.d.d.d.d")),
func.upper(column("subject")) == func.upper(literal('(instance "")')),
func.upper(column("subject")) == func.upper(literal('(instance "").d')),
func.upper(column("subject")) == func.upper(literal('(instance "").d.d')),
func.upper(column("subject")) == func.upper(literal('(instance "").d.d.d')),
func.upper(column("subject")) == func.upper(literal('(instance "").d.d.d.d')),
topic_match_sa(""),
topic_match_sa(".d"),
topic_match_sa(".d.d"),
topic_match_sa(".d.d.d"),
topic_match_sa(".d.d.d.d"),
topic_match_sa("personal"),
topic_match_sa("personal.d"),
topic_match_sa("personal.d.d"),
topic_match_sa("personal.d.d.d"),
topic_match_sa("personal.d.d.d.d"),
topic_match_sa('(instance "")'),
topic_match_sa('(instance "").d'),
topic_match_sa('(instance "").d.d'),
topic_match_sa('(instance "").d.d.d'),
topic_match_sa('(instance "").d.d.d.d'),
)
else:
# We limit `.d` counts, since postgres has much better
# query planning for this than they do for a regular
# expression (which would sometimes table scan).
cond = or_(
func.upper(column("subject")) == func.upper(literal(base_topic)),
func.upper(column("subject")) == func.upper(literal(base_topic + ".d")),
func.upper(column("subject")) == func.upper(literal(base_topic + ".d.d")),
func.upper(column("subject")) == func.upper(literal(base_topic + ".d.d.d")),
func.upper(column("subject")) == func.upper(literal(base_topic + ".d.d.d.d")),
topic_match_sa(base_topic),
topic_match_sa(base_topic + ".d"),
topic_match_sa(base_topic + ".d.d"),
topic_match_sa(base_topic + ".d.d.d"),
topic_match_sa(base_topic + ".d.d.d.d"),
)
return query.where(maybe_negate(cond))
cond = func.upper(column("subject")) == func.upper(literal(operand))
cond = topic_match_sa(operand)
return query.where(maybe_negate(cond))
def by_sender(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query: