scheduled_messages: Migrate to typed_endpoint.

Migrate `scheduled_message.py` to typed_endpoint.

Perform Json parsing in the endpoint itself instead of
in `recipient_parsing.py`.
This commit is contained in:
Kenneth Rodrigues 2024-08-18 17:46:43 +05:30 committed by Tim Abbott
parent 7f38c95384
commit 50712bfa81
6 changed files with 62 additions and 71 deletions

View File

@ -133,7 +133,7 @@ def edit_scheduled_message(
client: Client, client: Client,
scheduled_message_id: int, scheduled_message_id: int,
recipient_type_name: str | None, recipient_type_name: str | None,
message_to: str | None, message_to: int | list[int] | None,
topic_name: str | None, topic_name: str | None,
message_content: str | None, message_content: str | None,
deliver_at: datetime | None, deliver_at: datetime | None,

View File

@ -1,29 +1,17 @@
import orjson
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from zerver.lib.exceptions import JsonableError from zerver.lib.exceptions import JsonableError
def extract_stream_id(req_to: str) -> int: def extract_stream_id(req_to: int | list[int]) -> int:
# Recipient should only be a single stream ID. # Recipient should only be a single stream ID.
try: if isinstance(req_to, list):
stream_id = int(req_to)
except ValueError:
raise JsonableError(_("Invalid data type for channel ID")) raise JsonableError(_("Invalid data type for channel ID"))
return stream_id return req_to
def extract_direct_message_recipient_ids(req_to: str) -> list[int]: def extract_direct_message_recipient_ids(req_to: int | list[int]) -> list[int]:
try: if not isinstance(req_to, list):
user_ids = orjson.loads(req_to)
except orjson.JSONDecodeError:
user_ids = req_to
if not isinstance(user_ids, list):
raise JsonableError(_("Invalid data type for recipients")) raise JsonableError(_("Invalid data type for recipients"))
for user_id in user_ids: return list(set(req_to))
if not isinstance(user_id, int):
raise JsonableError(_("Recipient list may only contain user IDs"))
return list(set(user_ids))

View File

@ -363,9 +363,12 @@ def has_request_variables(
# #
# TODO: Either run validators for path_only parameters # TODO: Either run validators for path_only parameters
# or don't declare them using REQ. # or don't declare them using REQ.
assert func_var_name in kwargs
# no coverage because has_request_variables will be removed once
# all the endpoints have been migrated to use typed_endpoint.
assert func_var_name in kwargs # nocoverage
if func_var_name in kwargs: if func_var_name in kwargs:
continue continue # nocoverage
assert func_var_name is not None assert func_var_name is not None
post_var_name: str | None post_var_name: str | None

View File

@ -1,5 +1,3 @@
import orjson
from zerver.lib.exceptions import JsonableError from zerver.lib.exceptions import JsonableError
from zerver.lib.recipient_parsing import extract_direct_message_recipient_ids, extract_stream_id from zerver.lib.recipient_parsing import extract_direct_message_recipient_ids, extract_stream_id
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
@ -8,39 +6,25 @@ from zerver.lib.test_classes import ZulipTestCase
class TestRecipientParsing(ZulipTestCase): class TestRecipientParsing(ZulipTestCase):
def test_extract_stream_id(self) -> None: def test_extract_stream_id(self) -> None:
# stream message recipient = single stream ID. # stream message recipient = single stream ID.
stream_id = extract_stream_id("1") stream_id = extract_stream_id(1)
self.assertEqual(stream_id, 1) self.assertEqual(stream_id, 1)
with self.assertRaisesRegex(JsonableError, "Invalid data type for channel ID"): with self.assertRaisesRegex(JsonableError, "Invalid data type for channel ID"):
extract_stream_id("1,2") extract_stream_id([1, 2])
with self.assertRaisesRegex(JsonableError, "Invalid data type for channel ID"): with self.assertRaisesRegex(JsonableError, "Invalid data type for channel ID"):
extract_stream_id("[1]") extract_stream_id([1])
with self.assertRaisesRegex(JsonableError, "Invalid data type for channel ID"):
extract_stream_id("general")
def test_extract_recipient_ids(self) -> None: def test_extract_recipient_ids(self) -> None:
# direct message recipients = user IDs. # direct message recipients = user IDs.
user_ids = "[3,2,1]" user_ids = [3, 2, 1]
result = sorted(extract_direct_message_recipient_ids(user_ids)) result = sorted(extract_direct_message_recipient_ids(user_ids))
self.assertEqual(result, [1, 2, 3]) self.assertEqual(result, [1, 2, 3])
# JSON list w/duplicates # list w/duplicates
user_ids = orjson.dumps([3, 3, 12]).decode() user_ids = [3, 3, 12]
result = sorted(extract_direct_message_recipient_ids(user_ids)) result = sorted(extract_direct_message_recipient_ids(user_ids))
self.assertEqual(result, [3, 12]) self.assertEqual(result, [3, 12])
# Invalid data
user_ids = "1, 12"
with self.assertRaisesRegex(JsonableError, "Invalid data type for recipients"): with self.assertRaisesRegex(JsonableError, "Invalid data type for recipients"):
extract_direct_message_recipient_ids(user_ids) extract_direct_message_recipient_ids(1)
user_ids = orjson.dumps(dict(recipient=12)).decode()
with self.assertRaisesRegex(JsonableError, "Invalid data type for recipients"):
extract_direct_message_recipient_ids(user_ids)
# Heterogeneous lists are not supported
user_ids = orjson.dumps([3, 4, "eeshan@example.com"]).decode()
with self.assertRaisesRegex(JsonableError, "Recipient list may only contain user IDs"):
extract_direct_message_recipient_ids(user_ids)

View File

@ -91,7 +91,7 @@ class ScheduledMessageTest(ZulipTestCase):
result = self.do_schedule_message( result = self.do_schedule_message(
"direct", [othello.email], content + " 4", scheduled_delivery_timestamp "direct", [othello.email], content + " 4", scheduled_delivery_timestamp
) )
self.assert_json_error(result, "Recipient list may only contain user IDs") self.assert_json_error(result, 'to["int"] is not an integer')
def create_scheduled_message(self) -> None: def create_scheduled_message(self) -> None:
content = "Test message" content = "Test message"

View File

@ -1,6 +1,9 @@
from typing import Annotated
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.utils.timezone import now as timezone_now from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from pydantic import Json, NonNegativeInt
from zerver.actions.scheduled_messages import ( from zerver.actions.scheduled_messages import (
check_schedule_message, check_schedule_message,
@ -9,48 +12,57 @@ from zerver.actions.scheduled_messages import (
) )
from zerver.lib.exceptions import JsonableError from zerver.lib.exceptions import JsonableError
from zerver.lib.recipient_parsing import extract_direct_message_recipient_ids, extract_stream_id from zerver.lib.recipient_parsing import extract_direct_message_recipient_ids, extract_stream_id
from zerver.lib.request import REQ, RequestNotes, has_request_variables from zerver.lib.request import RequestNotes
from zerver.lib.response import json_success from zerver.lib.response import json_success
from zerver.lib.scheduled_messages import get_undelivered_scheduled_messages from zerver.lib.scheduled_messages import get_undelivered_scheduled_messages
from zerver.lib.timestamp import timestamp_to_datetime from zerver.lib.timestamp import timestamp_to_datetime
from zerver.lib.topic import REQ_topic from zerver.lib.typed_endpoint import (
from zerver.lib.validator import check_bool, check_int, check_string_in, to_non_negative_int ApiParamConfig,
OptionalTopic,
PathOnly,
typed_endpoint,
typed_endpoint_without_parameters,
)
from zerver.lib.typed_endpoint_validators import check_string_in_validator
from zerver.models import Message, UserProfile from zerver.models import Message, UserProfile
@has_request_variables @typed_endpoint_without_parameters
def fetch_scheduled_messages(request: HttpRequest, user_profile: UserProfile) -> HttpResponse: def fetch_scheduled_messages(request: HttpRequest, user_profile: UserProfile) -> HttpResponse:
return json_success( return json_success(
request, data={"scheduled_messages": get_undelivered_scheduled_messages(user_profile)} request, data={"scheduled_messages": get_undelivered_scheduled_messages(user_profile)}
) )
@has_request_variables @typed_endpoint
def delete_scheduled_messages( def delete_scheduled_messages(
request: HttpRequest, request: HttpRequest,
user_profile: UserProfile, user_profile: UserProfile,
scheduled_message_id: int = REQ(converter=to_non_negative_int, path_only=True), *,
scheduled_message_id: PathOnly[NonNegativeInt],
) -> HttpResponse: ) -> HttpResponse:
delete_scheduled_message(user_profile, scheduled_message_id) delete_scheduled_message(user_profile, scheduled_message_id)
return json_success(request) return json_success(request)
@has_request_variables @typed_endpoint
def update_scheduled_message_backend( def update_scheduled_message_backend(
request: HttpRequest, request: HttpRequest,
user_profile: UserProfile, user_profile: UserProfile,
scheduled_message_id: int = REQ(converter=to_non_negative_int, path_only=True), *,
req_type: str | None = REQ( scheduled_message_id: PathOnly[NonNegativeInt],
"type", str_validator=check_string_in(Message.API_RECIPIENT_TYPES), default=None req_type: Annotated[
), Annotated[str, check_string_in_validator(Message.API_RECIPIENT_TYPES)] | None,
req_to: str | None = REQ("to", default=None), ApiParamConfig("type"),
topic_name: str | None = REQ_topic(), ] = None,
message_content: str | None = REQ("content", default=None), to: Json[int | list[int]] | None = None,
scheduled_delivery_timestamp: int | None = REQ(json_validator=check_int, default=None), topic_name: OptionalTopic = None,
message_content: Annotated[str | None, ApiParamConfig("content")] = None,
scheduled_delivery_timestamp: Json[int] | None = None,
) -> HttpResponse: ) -> HttpResponse:
if ( if (
req_type is None req_type is None
and req_to is None and to is None
and topic_name is None and topic_name is None
and message_content is None and message_content is None
and scheduled_delivery_timestamp is None and scheduled_delivery_timestamp is None
@ -59,7 +71,7 @@ def update_scheduled_message_backend(
recipient_type_name = None recipient_type_name = None
if req_type: if req_type:
if req_to is None: if to is None:
raise JsonableError(_("Recipient required when updating type of scheduled message.")) raise JsonableError(_("Recipient required when updating type of scheduled message."))
else: else:
recipient_type_name = req_type recipient_type_name = req_type
@ -80,10 +92,10 @@ def update_scheduled_message_backend(
recipient_type_name = "private" recipient_type_name = "private"
message_to = None message_to = None
if req_to is not None: if to is not None:
# Because the recipient_type_name may not be updated/changed, # Because the recipient_type_name may not be updated/changed,
# we extract these updated recipient IDs in edit_scheduled_message. # we extract these updated recipient IDs in edit_scheduled_message.
message_to = req_to message_to = to
deliver_at = None deliver_at = None
if scheduled_delivery_timestamp is not None: if scheduled_delivery_timestamp is not None:
@ -110,16 +122,20 @@ def update_scheduled_message_backend(
return json_success(request) return json_success(request)
@has_request_variables @typed_endpoint
def create_scheduled_message_backend( def create_scheduled_message_backend(
request: HttpRequest, request: HttpRequest,
user_profile: UserProfile, user_profile: UserProfile,
req_type: str = REQ("type", str_validator=check_string_in(Message.API_RECIPIENT_TYPES)), *,
req_to: str = REQ("to"), req_type: Annotated[
topic_name: str | None = REQ_topic(), Annotated[str, check_string_in_validator(Message.API_RECIPIENT_TYPES)],
message_content: str = REQ("content"), ApiParamConfig("type"),
scheduled_delivery_timestamp: int = REQ(json_validator=check_int), ],
read_by_sender: bool | None = REQ(json_validator=check_bool, default=None), req_to: Annotated[Json[int | list[int]], ApiParamConfig("to")],
message_content: Annotated[str, ApiParamConfig("content")],
scheduled_delivery_timestamp: Json[int],
topic_name: OptionalTopic = None,
read_by_sender: Json[bool] | None = None,
) -> HttpResponse: ) -> HttpResponse:
recipient_type_name = req_type recipient_type_name = req_type
if recipient_type_name == "direct": if recipient_type_name == "direct":