diff --git a/zerver/lib/typed_endpoint.py b/zerver/lib/typed_endpoint.py index 3887e32a6c..a2937c7520 100644 --- a/zerver/lib/typed_endpoint.py +++ b/zerver/lib/typed_endpoint.py @@ -330,6 +330,8 @@ ERROR_TEMPLATES = { "string_too_short": _("{var_name} is too short."), "string_type": _("{var_name} is not a string"), "unexpected_keyword_argument": _('Argument "{argument}" at {var_name} is unexpected'), + "string_pattern_mismatch": _("{var_name} has invalid format"), + "string_fixed_length": _("{var_name} is not length {length}"), } diff --git a/zerver/lib/typed_endpoint_validators.py b/zerver/lib/typed_endpoint_validators.py new file mode 100644 index 0000000000..3dc0aaff5b --- /dev/null +++ b/zerver/lib/typed_endpoint_validators.py @@ -0,0 +1,18 @@ +from typing import Optional + +from pydantic_core import PydanticCustomError + +# The Pydantic.StringConstraints does not have validation for the string to be +# of the specified length. So, we need to create a custom validator for that. + + +def check_string_fixed_length(string: str, length: int) -> Optional[str]: + if len(string) != length: + raise PydanticCustomError( + "string_fixed_length", + "", + { + "length": length, + }, + ) + return string diff --git a/zerver/tests/test_push_notifications.py b/zerver/tests/test_push_notifications.py index 5c47f7cd1e..52d98a4735 100644 --- a/zerver/tests/test_push_notifications.py +++ b/zerver/tests/test_push_notifications.py @@ -496,14 +496,19 @@ class PushBouncerNotificationTest(BouncerTestCase): def test_register_validate_ios_app_id(self) -> None: endpoint = "/api/v1/remotes/push/register" - args = {"user_id": 11, "token": "1122", "token_kind": PushDeviceToken.APNS} + args = { + "user_id": 11, + "token": "1122", + "token_kind": PushDeviceToken.APNS, + "ios_app_id": "'; tables --", + } - result = self.uuid_post( - self.server_uuid, - endpoint, - {**args, "ios_app_id": "'; tables --"}, - ) - self.assert_json_error(result, "Invalid app ID") + result = self.uuid_post(self.server_uuid, endpoint, args) + self.assert_json_error(result, "ios_app_id has invalid format") + + args["ios_app_id"] = "com.zulip.apple" + result = self.uuid_post(self.server_uuid, endpoint, args) + self.assert_json_success(result) def test_register_device_deduplication(self) -> None: hamlet = self.example_user("hamlet") @@ -5004,6 +5009,24 @@ class PushBouncerSignupTest(ZulipTestCase): result = self.client_post("/api/v1/remotes/server/register", request) self.assert_json_error(result, "Invalid UUID") + # check if zulip org id is of allowed length + zulip_org_id = "18cedb98" + request["zulip_org_id"] = zulip_org_id + result = self.client_post("/api/v1/remotes/server/register", request) + self.assert_json_error(result, "zulip_org_id is not length 36") + + def test_push_signup_invalid_zulip_org_key(self) -> None: + zulip_org_id = str(uuid.uuid4()) + zulip_org_key = get_random_string(63) + request = dict( + zulip_org_id=zulip_org_id, + zulip_org_key=zulip_org_key, + hostname="invalid-host", + contact_email="server-admin@zulip.com", + ) + result = self.client_post("/api/v1/remotes/server/register", request) + self.assert_json_error(result, "zulip_org_key is not length 64") + def test_push_signup_success(self) -> None: zulip_org_id = str(uuid.uuid4()) zulip_org_key = get_random_string(64) diff --git a/zilencer/views.py b/zilencer/views.py index 3eea3a9cd7..4e604a8603 100644 --- a/zilencer/views.py +++ b/zilencer/views.py @@ -18,7 +18,9 @@ from django.utils.timezone import now as timezone_now from django.utils.translation import gettext as _ from django.utils.translation import gettext as err_ from django.views.decorators.csrf import csrf_exempt -from pydantic import BaseModel, ConfigDict, Json +from pydantic import BaseModel, ConfigDict, Json, StringConstraints +from pydantic.functional_validators import AfterValidator +from typing_extensions import Annotated from analytics.lib.counts import ( BOUNCER_ONLY_REMOTE_COUNT_STAT_PROPERTIES, @@ -60,15 +62,15 @@ from zerver.lib.remote_server import ( RealmCountDataForAnalytics, RealmDataForAnalytics, ) -from zerver.lib.request import REQ, RequestNotes, has_request_variables +from zerver.lib.request import RequestNotes, has_request_variables from zerver.lib.response import json_success from zerver.lib.send_email import FromAddress from zerver.lib.timestamp import timestamp_to_datetime -from zerver.lib.typed_endpoint import JsonBodyPayload, typed_endpoint +from zerver.lib.typed_endpoint import JsonBodyPayload, RequiredStringConstraint, typed_endpoint +from zerver.lib.typed_endpoint_validators import check_string_fixed_length from zerver.lib.types import RemoteRealmDictValue -from zerver.lib.validator import check_capped_string, check_int, check_string_fixed_length from zerver.models.realms import DisposableEmailError -from zerver.views.push_notifications import check_app_id, validate_token +from zerver.views.push_notifications import validate_token from zilencer.auth import InvalidZulipServerKeyError from zilencer.models import ( RemoteInstallationCount, @@ -116,20 +118,29 @@ def deactivate_remote_server( @csrf_exempt @require_post -@has_request_variables +@typed_endpoint def register_remote_server( request: HttpRequest, - zulip_org_id: str = REQ(str_validator=check_string_fixed_length(RemoteZulipServer.UUID_LENGTH)), - zulip_org_key: str = REQ( - str_validator=check_string_fixed_length(RemoteZulipServer.API_KEY_LENGTH) - ), - hostname: str = REQ(str_validator=check_capped_string(RemoteZulipServer.HOSTNAME_MAX_LENGTH)), - contact_email: str = REQ(), - new_org_key: Optional[str] = REQ( - str_validator=check_string_fixed_length(RemoteZulipServer.API_KEY_LENGTH), default=None - ), + *, + zulip_org_id: Annotated[ + str, + RequiredStringConstraint, + AfterValidator(lambda s: check_string_fixed_length(s, RemoteZulipServer.UUID_LENGTH)), + ], + zulip_org_key: Annotated[ + str, + RequiredStringConstraint, + AfterValidator(lambda s: check_string_fixed_length(s, RemoteZulipServer.API_KEY_LENGTH)), + ], + hostname: Annotated[str, StringConstraints(max_length=RemoteZulipServer.HOSTNAME_MAX_LENGTH)], + contact_email: str, + new_org_key: Annotated[ + Optional[str], + RequiredStringConstraint, + AfterValidator(lambda s: check_string_fixed_length(s, RemoteZulipServer.API_KEY_LENGTH)), + ] = None, ) -> HttpResponse: - # REQ validated the the field lengths, but we still need to + # StringConstraints validated the the field lengths, but we still need to # validate the format of these fields. try: # TODO: Ideally we'd not abuse the URL validator this way @@ -218,16 +229,17 @@ def register_remote_server( return json_success(request, data={"created": created}) -@has_request_variables +@typed_endpoint def register_remote_push_device( request: HttpRequest, server: RemoteZulipServer, - user_id: Optional[int] = REQ(json_validator=check_int, default=None), - user_uuid: Optional[str] = REQ(default=None), - realm_uuid: Optional[str] = REQ(default=None), - token: str = REQ(), - token_kind: int = REQ(json_validator=check_int), - ios_app_id: Optional[str] = REQ(str_validator=check_app_id, default=None), + *, + user_id: Optional[Json[int]] = None, + user_uuid: Optional[str] = None, + realm_uuid: Optional[str] = None, + token: Annotated[str, RequiredStringConstraint], + token_kind: Json[int], + ios_app_id: Annotated[Optional[str], StringConstraints(pattern="^[.a-zA-Z0-9-]+$")] = None, ) -> HttpResponse: validate_bouncer_token_request(token, token_kind) if token_kind == RemotePushDeviceToken.APNS and ios_app_id is None: @@ -278,15 +290,16 @@ def register_remote_push_device( return json_success(request) -@has_request_variables +@typed_endpoint def unregister_remote_push_device( request: HttpRequest, server: RemoteZulipServer, - token: str = REQ(), - token_kind: int = REQ(json_validator=check_int), - user_id: Optional[int] = REQ(json_validator=check_int, default=None), - user_uuid: Optional[str] = REQ(default=None), - realm_uuid: Optional[str] = REQ(default=None), + *, + token: Annotated[str, RequiredStringConstraint], + token_kind: Json[int], + user_id: Optional[Json[int]] = None, + user_uuid: Optional[str] = None, + realm_uuid: Optional[str] = None, ) -> HttpResponse: validate_bouncer_token_request(token, token_kind) user_identity = UserPushIdentityCompat(user_id=user_id, user_uuid=user_uuid) @@ -302,13 +315,14 @@ def unregister_remote_push_device( return json_success(request) -@has_request_variables +@typed_endpoint def unregister_all_remote_push_devices( request: HttpRequest, server: RemoteZulipServer, - user_id: Optional[int] = REQ(json_validator=check_int, default=None), - user_uuid: Optional[str] = REQ(default=None), - realm_uuid: Optional[str] = REQ(default=None), + *, + user_id: Optional[Json[int]] = None, + user_uuid: Optional[str] = None, + realm_uuid: Optional[str] = None, ) -> HttpResponse: user_identity = UserPushIdentityCompat(user_id=user_id, user_uuid=user_uuid) @@ -490,21 +504,34 @@ class PushNotificationsDisallowedError(JsonableError): super().__init__(msg) -@has_request_variables +class RemoteServerNotificationPayload(BaseModel): + user_id: Optional[int] = None + user_uuid: Optional[str] = None + realm_uuid: Optional[str] = None + gcm_payload: Dict[str, Any] = {} + apns_payload: Dict[str, Any] = {} + gcm_options: Dict[str, Any] = {} + + android_devices: List[str] = [] + apple_devices: List[str] = [] + + +@typed_endpoint def remote_server_notify_push( request: HttpRequest, server: RemoteZulipServer, - payload: Dict[str, Any] = REQ(argument_type="body"), + *, + payload: JsonBodyPayload[RemoteServerNotificationPayload], ) -> HttpResponse: - user_id = payload.get("user_id") - user_uuid = payload.get("user_uuid") + user_id = payload.user_id + user_uuid = payload.user_uuid user_identity = UserPushIdentityCompat(user_id, user_uuid) - gcm_payload = payload["gcm_payload"] - apns_payload = payload["apns_payload"] - gcm_options = payload.get("gcm_options", {}) + gcm_payload = payload.gcm_payload + apns_payload = payload.apns_payload + gcm_options = payload.gcm_options - realm_uuid = payload.get("realm_uuid") + realm_uuid = payload.realm_uuid remote_realm = None if realm_uuid is not None: assert isinstance( @@ -654,8 +681,8 @@ def remote_server_notify_push( deleted_devices = get_deleted_devices( user_identity, server, - android_devices=payload.get("android_devices", []), - apple_devices=payload.get("apple_devices", []), + android_devices=payload.android_devices, + apple_devices=payload.apple_devices, ) return json_success(