diff --git a/zerver/lib/validator.py b/zerver/lib/validator.py index 5e4e962714..1710d04e9c 100644 --- a/zerver/lib/validator.py +++ b/zerver/lib/validator.py @@ -337,3 +337,9 @@ def check_string_or_int_list(var_name: str, val: object) -> Optional[str]: return _('%s is not a string or an integer list') % (var_name,) return check_list(check_int)(var_name, val) + +def check_string_or_int(var_name: str, val: object) -> Optional[str]: + if isinstance(val, str) or isinstance(val, int): + return None + + return _('%s is not a string or integer') % (var_name,) diff --git a/zerver/tests/test_decorators.py b/zerver/tests/test_decorators.py index de04065f4c..26a52077f4 100644 --- a/zerver/tests/test_decorators.py +++ b/zerver/tests/test_decorators.py @@ -44,7 +44,7 @@ from zerver.lib.validator import ( check_string, check_dict, check_dict_only, check_bool, check_float, check_int, check_list, Validator, check_variable_type, equals, check_none_or, check_url, check_short_string, check_string_fixed_length, check_capped_string, check_color, to_non_negative_int, - check_string_or_int_list + check_string_or_int_list, check_string_or_int ) from zerver.models import \ get_realm, get_user, UserProfile, Realm @@ -919,6 +919,16 @@ class ValidatorTestCase(TestCase): x = [1, 2, '3'] self.assertEqual(check_string_or_int_list('x', x), 'x[2] is not an integer') + def test_check_string_or_int(self) -> None: + x = "string" # type: Any + self.assertEqual(check_string_or_int('x', x), None) + + x = 1 + self.assertEqual(check_string_or_int('x', x), None) + + x = None + self.assertEqual(check_string_or_int('x', x), 'x is not a string or integer') + class DeactivatedRealmTest(ZulipTestCase): def test_send_deactivated_realm(self) -> None: diff --git a/zerver/tests/test_narrow.py b/zerver/tests/test_narrow.py index 1f04151dfd..ddf475868e 100644 --- a/zerver/tests/test_narrow.py +++ b/zerver/tests/test_narrow.py @@ -1426,11 +1426,13 @@ class GetOldMessagesTest(ZulipTestCase): self.send_personal_message(self.example_email("othello"), self.example_email("hamlet")) self.send_stream_message(self.example_email("iago"), "Scotland") - narrow = [dict(operator='sender', operand=self.example_email("othello"))] - result = self.get_and_check_messages(dict(narrow=ujson.dumps(narrow))) + test_operands = [self.example_email("othello"), self.example_user("othello").id] + for operand in test_operands: + narrow = [dict(operator='sender', operand=operand)] + result = self.get_and_check_messages(dict(narrow=ujson.dumps(narrow))) - for message in result["messages"]: - self.assertEqual(message["sender_email"], self.example_email("othello")) + for message in result["messages"]: + self.assertEqual(message["sender_email"], self.example_email("othello")) def _update_tsvector_index(self) -> None: # We use brute force here and update our text search index diff --git a/zerver/views/messages.py b/zerver/views/messages.py index b38e53f862..51e5e8ae25 100644 --- a/zerver/views/messages.py +++ b/zerver/views/messages.py @@ -45,12 +45,13 @@ from zerver.lib.topic import ( from zerver.lib.topic_mutes import exclude_topic_mutes from zerver.lib.utils import statsd from zerver.lib.validator import \ - check_list, check_int, check_dict, check_string, check_bool, check_string_or_int_list + check_list, check_int, check_dict, check_string, check_bool, \ + check_string_or_int_list, check_string_or_int from zerver.lib.zephyr import compute_mit_user_fullname from zerver.models import Message, UserProfile, Stream, Subscription, Client,\ Realm, RealmDomain, Recipient, UserMessage, bulk_get_recipients, get_personal_recipient, \ get_stream, email_to_domain, get_realm, get_active_streams, \ - get_user_including_cross_realm, get_stream_recipient + get_user_including_cross_realm, get_user_by_id_in_realm_including_cross_realm, get_stream_recipient from sqlalchemy import func from sqlalchemy.sql import select, join, column, literal_column, literal, and_, \ @@ -282,11 +283,14 @@ class NarrowBuilder: cond = topic_match_sa(operand) return query.where(maybe_negate(cond)) - def by_sender(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query: + def by_sender(self, query: Query, operand: Union[str, int], maybe_negate: ConditionTransform) -> Query: try: - sender = get_user_including_cross_realm(operand, self.user_realm) + if isinstance(operand, str): + sender = get_user_including_cross_realm(operand, self.user_realm) + else: + sender = get_user_by_id_in_realm_including_cross_realm(operand, self.user_realm) except UserProfile.DoesNotExist: - raise BadNarrowOperator('unknown user ' + operand) + raise BadNarrowOperator('unknown user ' + str(operand)) cond = column("sender_id") == literal(sender.id) return query.where(maybe_negate(cond)) @@ -513,9 +517,14 @@ def narrow_parameter(json: str) -> Optional[List[Dict[str, Any]]]: # Make sure to sync this list to frontend also when adding a new operator. # that supports user IDs. Relevant code is located in static/js/message_fetch.js # in handle_user_ids_supported_operators function where you will need to update - # the user_ids_supported_operator. + # user_id_supported_operator, or user_ids_supported_operator array. + user_id_supported_operator = ['sender'] user_ids_supported_operators = ['pm-with'] - if elem.get('operator', '') in user_ids_supported_operators: + + operator = elem.get('operator', '') + if operator in user_id_supported_operator: + operand_validator = check_string_or_int + elif operator in user_ids_supported_operators: operand_validator = check_string_or_int_list else: operand_validator = check_string