From 50712bfa810bc6e314e6d9e98e8997f0802639fa Mon Sep 17 00:00:00 2001 From: Kenneth Rodrigues Date: Sun, 18 Aug 2024 17:46:43 +0530 Subject: [PATCH] scheduled_messages: Migrate to typed_endpoint. Migrate `scheduled_message.py` to typed_endpoint. Perform Json parsing in the endpoint itself instead of in `recipient_parsing.py`. --- zerver/actions/scheduled_messages.py | 2 +- zerver/lib/recipient_parsing.py | 24 +++------ zerver/lib/request.py | 7 ++- zerver/tests/test_recipient_parsing.py | 30 +++-------- zerver/tests/test_scheduled_messages.py | 2 +- zerver/views/scheduled_messages.py | 68 +++++++++++++++---------- 6 files changed, 62 insertions(+), 71 deletions(-) diff --git a/zerver/actions/scheduled_messages.py b/zerver/actions/scheduled_messages.py index 6a6e08995a..3729b532ef 100644 --- a/zerver/actions/scheduled_messages.py +++ b/zerver/actions/scheduled_messages.py @@ -133,7 +133,7 @@ def edit_scheduled_message( client: Client, scheduled_message_id: int, recipient_type_name: str | None, - message_to: str | None, + message_to: int | list[int] | None, topic_name: str | None, message_content: str | None, deliver_at: datetime | None, diff --git a/zerver/lib/recipient_parsing.py b/zerver/lib/recipient_parsing.py index 8e80904d5c..fa75c07ebe 100644 --- a/zerver/lib/recipient_parsing.py +++ b/zerver/lib/recipient_parsing.py @@ -1,29 +1,17 @@ -import orjson from django.utils.translation import gettext as _ from zerver.lib.exceptions import JsonableError -def extract_stream_id(req_to: str) -> int: +def extract_stream_id(req_to: int | list[int]) -> int: # Recipient should only be a single stream ID. - try: - stream_id = int(req_to) - except ValueError: + if isinstance(req_to, list): raise JsonableError(_("Invalid data type for channel ID")) - return stream_id + return req_to -def extract_direct_message_recipient_ids(req_to: str) -> list[int]: - try: - user_ids = orjson.loads(req_to) - except orjson.JSONDecodeError: - user_ids = req_to - - if not isinstance(user_ids, list): +def extract_direct_message_recipient_ids(req_to: int | list[int]) -> list[int]: + if not isinstance(req_to, list): raise JsonableError(_("Invalid data type for recipients")) - for user_id in user_ids: - if not isinstance(user_id, int): - raise JsonableError(_("Recipient list may only contain user IDs")) - - return list(set(user_ids)) + return list(set(req_to)) diff --git a/zerver/lib/request.py b/zerver/lib/request.py index f24886797e..6d9ea9d48b 100644 --- a/zerver/lib/request.py +++ b/zerver/lib/request.py @@ -363,9 +363,12 @@ def has_request_variables( # # TODO: Either run validators for path_only parameters # or don't declare them using REQ. - assert func_var_name in kwargs + + # no coverage because has_request_variables will be removed once + # all the endpoints have been migrated to use typed_endpoint. + assert func_var_name in kwargs # nocoverage if func_var_name in kwargs: - continue + continue # nocoverage assert func_var_name is not None post_var_name: str | None diff --git a/zerver/tests/test_recipient_parsing.py b/zerver/tests/test_recipient_parsing.py index 029555186f..345ccf8fea 100644 --- a/zerver/tests/test_recipient_parsing.py +++ b/zerver/tests/test_recipient_parsing.py @@ -1,5 +1,3 @@ -import orjson - from zerver.lib.exceptions import JsonableError from zerver.lib.recipient_parsing import extract_direct_message_recipient_ids, extract_stream_id from zerver.lib.test_classes import ZulipTestCase @@ -8,39 +6,25 @@ from zerver.lib.test_classes import ZulipTestCase class TestRecipientParsing(ZulipTestCase): def test_extract_stream_id(self) -> None: # stream message recipient = single stream ID. - stream_id = extract_stream_id("1") + stream_id = extract_stream_id(1) self.assertEqual(stream_id, 1) with self.assertRaisesRegex(JsonableError, "Invalid data type for channel ID"): - extract_stream_id("1,2") + extract_stream_id([1, 2]) with self.assertRaisesRegex(JsonableError, "Invalid data type for channel ID"): - extract_stream_id("[1]") - - with self.assertRaisesRegex(JsonableError, "Invalid data type for channel ID"): - extract_stream_id("general") + extract_stream_id([1]) def test_extract_recipient_ids(self) -> None: # direct message recipients = user IDs. - user_ids = "[3,2,1]" + user_ids = [3, 2, 1] result = sorted(extract_direct_message_recipient_ids(user_ids)) self.assertEqual(result, [1, 2, 3]) - # JSON list w/duplicates - user_ids = orjson.dumps([3, 3, 12]).decode() + # list w/duplicates + user_ids = [3, 3, 12] result = sorted(extract_direct_message_recipient_ids(user_ids)) self.assertEqual(result, [3, 12]) - # Invalid data - user_ids = "1, 12" with self.assertRaisesRegex(JsonableError, "Invalid data type for recipients"): - extract_direct_message_recipient_ids(user_ids) - - user_ids = orjson.dumps(dict(recipient=12)).decode() - with self.assertRaisesRegex(JsonableError, "Invalid data type for recipients"): - extract_direct_message_recipient_ids(user_ids) - - # Heterogeneous lists are not supported - user_ids = orjson.dumps([3, 4, "eeshan@example.com"]).decode() - with self.assertRaisesRegex(JsonableError, "Recipient list may only contain user IDs"): - extract_direct_message_recipient_ids(user_ids) + extract_direct_message_recipient_ids(1) diff --git a/zerver/tests/test_scheduled_messages.py b/zerver/tests/test_scheduled_messages.py index d5a5001633..03ee1c115d 100644 --- a/zerver/tests/test_scheduled_messages.py +++ b/zerver/tests/test_scheduled_messages.py @@ -91,7 +91,7 @@ class ScheduledMessageTest(ZulipTestCase): result = self.do_schedule_message( "direct", [othello.email], content + " 4", scheduled_delivery_timestamp ) - self.assert_json_error(result, "Recipient list may only contain user IDs") + self.assert_json_error(result, 'to["int"] is not an integer') def create_scheduled_message(self) -> None: content = "Test message" diff --git a/zerver/views/scheduled_messages.py b/zerver/views/scheduled_messages.py index cc1360be5e..538fb027e5 100644 --- a/zerver/views/scheduled_messages.py +++ b/zerver/views/scheduled_messages.py @@ -1,6 +1,9 @@ +from typing import Annotated + from django.http import HttpRequest, HttpResponse from django.utils.timezone import now as timezone_now from django.utils.translation import gettext as _ +from pydantic import Json, NonNegativeInt from zerver.actions.scheduled_messages import ( check_schedule_message, @@ -9,48 +12,57 @@ from zerver.actions.scheduled_messages import ( ) from zerver.lib.exceptions import JsonableError from zerver.lib.recipient_parsing import extract_direct_message_recipient_ids, extract_stream_id -from zerver.lib.request import REQ, RequestNotes, has_request_variables +from zerver.lib.request import RequestNotes from zerver.lib.response import json_success from zerver.lib.scheduled_messages import get_undelivered_scheduled_messages from zerver.lib.timestamp import timestamp_to_datetime -from zerver.lib.topic import REQ_topic -from zerver.lib.validator import check_bool, check_int, check_string_in, to_non_negative_int +from zerver.lib.typed_endpoint import ( + ApiParamConfig, + OptionalTopic, + PathOnly, + typed_endpoint, + typed_endpoint_without_parameters, +) +from zerver.lib.typed_endpoint_validators import check_string_in_validator from zerver.models import Message, UserProfile -@has_request_variables +@typed_endpoint_without_parameters def fetch_scheduled_messages(request: HttpRequest, user_profile: UserProfile) -> HttpResponse: return json_success( request, data={"scheduled_messages": get_undelivered_scheduled_messages(user_profile)} ) -@has_request_variables +@typed_endpoint def delete_scheduled_messages( request: HttpRequest, user_profile: UserProfile, - scheduled_message_id: int = REQ(converter=to_non_negative_int, path_only=True), + *, + scheduled_message_id: PathOnly[NonNegativeInt], ) -> HttpResponse: delete_scheduled_message(user_profile, scheduled_message_id) return json_success(request) -@has_request_variables +@typed_endpoint def update_scheduled_message_backend( request: HttpRequest, user_profile: UserProfile, - scheduled_message_id: int = REQ(converter=to_non_negative_int, path_only=True), - req_type: str | None = REQ( - "type", str_validator=check_string_in(Message.API_RECIPIENT_TYPES), default=None - ), - req_to: str | None = REQ("to", default=None), - topic_name: str | None = REQ_topic(), - message_content: str | None = REQ("content", default=None), - scheduled_delivery_timestamp: int | None = REQ(json_validator=check_int, default=None), + *, + scheduled_message_id: PathOnly[NonNegativeInt], + req_type: Annotated[ + Annotated[str, check_string_in_validator(Message.API_RECIPIENT_TYPES)] | None, + ApiParamConfig("type"), + ] = None, + to: Json[int | list[int]] | None = None, + topic_name: OptionalTopic = None, + message_content: Annotated[str | None, ApiParamConfig("content")] = None, + scheduled_delivery_timestamp: Json[int] | None = None, ) -> HttpResponse: if ( req_type is None - and req_to is None + and to is None and topic_name is None and message_content is None and scheduled_delivery_timestamp is None @@ -59,7 +71,7 @@ def update_scheduled_message_backend( recipient_type_name = None if req_type: - if req_to is None: + if to is None: raise JsonableError(_("Recipient required when updating type of scheduled message.")) else: recipient_type_name = req_type @@ -80,10 +92,10 @@ def update_scheduled_message_backend( recipient_type_name = "private" message_to = None - if req_to is not None: + if to is not None: # Because the recipient_type_name may not be updated/changed, # we extract these updated recipient IDs in edit_scheduled_message. - message_to = req_to + message_to = to deliver_at = None if scheduled_delivery_timestamp is not None: @@ -110,16 +122,20 @@ def update_scheduled_message_backend( return json_success(request) -@has_request_variables +@typed_endpoint def create_scheduled_message_backend( request: HttpRequest, user_profile: UserProfile, - req_type: str = REQ("type", str_validator=check_string_in(Message.API_RECIPIENT_TYPES)), - req_to: str = REQ("to"), - topic_name: str | None = REQ_topic(), - message_content: str = REQ("content"), - scheduled_delivery_timestamp: int = REQ(json_validator=check_int), - read_by_sender: bool | None = REQ(json_validator=check_bool, default=None), + *, + req_type: Annotated[ + Annotated[str, check_string_in_validator(Message.API_RECIPIENT_TYPES)], + ApiParamConfig("type"), + ], + req_to: Annotated[Json[int | list[int]], ApiParamConfig("to")], + message_content: Annotated[str, ApiParamConfig("content")], + scheduled_delivery_timestamp: Json[int], + topic_name: OptionalTopic = None, + read_by_sender: Json[bool] | None = None, ) -> HttpResponse: recipient_type_name = req_type if recipient_type_name == "direct":