diff --git a/corporate/views/upgrade.py b/corporate/views/upgrade.py index 7949b541db..523381f639 100644 --- a/corporate/views/upgrade.py +++ b/corporate/views/upgrade.py @@ -4,7 +4,8 @@ from typing import Optional from django.conf import settings from django.http import HttpRequest, HttpResponse, HttpResponseRedirect from django.shortcuts import render -from pydantic import Json +from pydantic import AfterValidator, Json +from typing_extensions import Annotated from corporate.lib.decorator import ( authenticated_remote_realm_management_endpoint, @@ -23,10 +24,9 @@ from corporate.lib.stripe import ( ) from corporate.models import CustomerPlan from zerver.decorator import require_organization_member, zulip_login_required -from zerver.lib.request import REQ, has_request_variables from zerver.lib.response import json_success from zerver.lib.typed_endpoint import typed_endpoint -from zerver.lib.validator import check_bool, check_int, check_string_in +from zerver.lib.typed_endpoint_validators import check_string_in from zerver.models import UserProfile from zilencer.lib.remote_counts import MissingDataError @@ -34,19 +34,26 @@ billing_logger = logging.getLogger("corporate.stripe") @require_organization_member -@has_request_variables +@typed_endpoint def upgrade( request: HttpRequest, user: UserProfile, - billing_modality: str = REQ(str_validator=check_string_in(VALID_BILLING_MODALITY_VALUES)), - schedule: str = REQ(str_validator=check_string_in(VALID_BILLING_SCHEDULE_VALUES)), - signed_seat_count: str = REQ(), - salt: str = REQ(), - license_management: Optional[str] = REQ( - default=None, str_validator=check_string_in(VALID_LICENSE_MANAGEMENT_VALUES) - ), - licenses: Optional[int] = REQ(json_validator=check_int, default=None), - tier: int = REQ(default=CustomerPlan.TIER_CLOUD_STANDARD, json_validator=check_int), + *, + billing_modality: Annotated[ + str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_MODALITY_VALUES)) + ], + schedule: Annotated[ + str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_SCHEDULE_VALUES)) + ], + signed_seat_count: str, + salt: str, + license_management: Optional[ + Annotated[ + str, AfterValidator(lambda val: check_string_in(val, VALID_LICENSE_MANAGEMENT_VALUES)) + ] + ] = None, + licenses: Optional[Json[int]] = None, + tier: Json[int] = CustomerPlan.TIER_CLOUD_STANDARD, ) -> HttpResponse: try: upgrade_request = UpgradeRequest( @@ -84,20 +91,27 @@ def upgrade( @authenticated_remote_realm_management_endpoint -@has_request_variables +@typed_endpoint def remote_realm_upgrade( request: HttpRequest, billing_session: RemoteRealmBillingSession, - billing_modality: str = REQ(str_validator=check_string_in(VALID_BILLING_MODALITY_VALUES)), - schedule: str = REQ(str_validator=check_string_in(VALID_BILLING_SCHEDULE_VALUES)), - signed_seat_count: str = REQ(), - salt: str = REQ(), - license_management: Optional[str] = REQ( - default=None, str_validator=check_string_in(VALID_LICENSE_MANAGEMENT_VALUES) - ), - licenses: Optional[int] = REQ(json_validator=check_int, default=None), - remote_server_plan_start_date: Optional[str] = REQ(default=None), - tier: int = REQ(default=CustomerPlan.TIER_SELF_HOSTED_BUSINESS, json_validator=check_int), + *, + billing_modality: Annotated[ + str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_MODALITY_VALUES)) + ], + schedule: Annotated[ + str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_SCHEDULE_VALUES)) + ], + signed_seat_count: str, + salt: str, + license_management: Optional[ + Annotated[ + str, AfterValidator(lambda val: check_string_in(val, VALID_LICENSE_MANAGEMENT_VALUES)) + ] + ] = None, + licenses: Optional[Json[int]] = None, + remote_server_plan_start_date: Optional[str] = None, + tier: Json[int] = CustomerPlan.TIER_SELF_HOSTED_BUSINESS, ) -> HttpResponse: try: upgrade_request = UpgradeRequest( @@ -133,20 +147,27 @@ def remote_realm_upgrade( @authenticated_remote_server_management_endpoint -@has_request_variables +@typed_endpoint def remote_server_upgrade( request: HttpRequest, billing_session: RemoteServerBillingSession, - billing_modality: str = REQ(str_validator=check_string_in(VALID_BILLING_MODALITY_VALUES)), - schedule: str = REQ(str_validator=check_string_in(VALID_BILLING_SCHEDULE_VALUES)), - signed_seat_count: str = REQ(), - salt: str = REQ(), - license_management: Optional[str] = REQ( - default=None, str_validator=check_string_in(VALID_LICENSE_MANAGEMENT_VALUES) - ), - licenses: Optional[int] = REQ(json_validator=check_int, default=None), - remote_server_plan_start_date: Optional[str] = REQ(default=None), - tier: int = REQ(default=CustomerPlan.TIER_SELF_HOSTED_BUSINESS, json_validator=check_int), + *, + billing_modality: Annotated[ + str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_MODALITY_VALUES)) + ], + schedule: Annotated[ + str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_SCHEDULE_VALUES)) + ], + signed_seat_count: str, + salt: str, + license_management: Optional[ + Annotated[ + str, AfterValidator(lambda val: check_string_in(val, VALID_LICENSE_MANAGEMENT_VALUES)) + ] + ] = None, + licenses: Optional[Json[int]] = None, + remote_server_plan_start_date: Optional[str] = None, + tier: Json[int] = CustomerPlan.TIER_SELF_HOSTED_BUSINESS, ) -> HttpResponse: try: upgrade_request = UpgradeRequest( @@ -182,12 +203,13 @@ def remote_server_upgrade( @zulip_login_required -@has_request_variables +@typed_endpoint def upgrade_page( request: HttpRequest, - manual_license_management: bool = REQ(default=False, json_validator=check_bool), - tier: int = REQ(default=CustomerPlan.TIER_CLOUD_STANDARD, json_validator=check_int), - setup_payment_by_invoice: bool = REQ(default=False, json_validator=check_bool), + *, + manual_license_management: Json[bool] = False, + tier: Json[int] = CustomerPlan.TIER_CLOUD_STANDARD, + setup_payment_by_invoice: Json[bool] = False, ) -> HttpResponse: user = request.user assert user.is_authenticated diff --git a/zerver/lib/typed_endpoint_validators.py b/zerver/lib/typed_endpoint_validators.py index efe3f1c585..8b3dd4589c 100644 --- a/zerver/lib/typed_endpoint_validators.py +++ b/zerver/lib/typed_endpoint_validators.py @@ -21,6 +21,12 @@ def check_string_fixed_length(string: str, length: int) -> Optional[str]: return string +def check_string_in(val: str, possible_values: List[str]) -> str: + if val not in possible_values: + raise ValueError(_("Not in the list of possible values")) + return val + + def check_int_in(val: int, possible_values: List[int]) -> int: if val not in possible_values: raise ValueError(_("Not in the list of possible values"))