corporate-upgrade: Migrate to @typed_endpoint.

migrate the following endpoints from @has_request_variables
to @typed_endpoint :

- upgrade()
- remote_realm_upgrade()
- upgrade_page()
- remote_server_upgrade()
This commit is contained in:
bedo 2024-07-08 08:39:03 +03:00 committed by Tim Abbott
parent c6d975f44d
commit 88a0a3061e
2 changed files with 67 additions and 39 deletions

View File

@ -4,7 +4,8 @@ from typing import Optional
from django.conf import settings from django.conf import settings
from django.http import HttpRequest, HttpResponse, HttpResponseRedirect from django.http import HttpRequest, HttpResponse, HttpResponseRedirect
from django.shortcuts import render from django.shortcuts import render
from pydantic import Json from pydantic import AfterValidator, Json
from typing_extensions import Annotated
from corporate.lib.decorator import ( from corporate.lib.decorator import (
authenticated_remote_realm_management_endpoint, authenticated_remote_realm_management_endpoint,
@ -23,10 +24,9 @@ from corporate.lib.stripe import (
) )
from corporate.models import CustomerPlan from corporate.models import CustomerPlan
from zerver.decorator import require_organization_member, zulip_login_required from zerver.decorator import require_organization_member, zulip_login_required
from zerver.lib.request import REQ, has_request_variables
from zerver.lib.response import json_success from zerver.lib.response import json_success
from zerver.lib.typed_endpoint import typed_endpoint from zerver.lib.typed_endpoint import typed_endpoint
from zerver.lib.validator import check_bool, check_int, check_string_in from zerver.lib.typed_endpoint_validators import check_string_in
from zerver.models import UserProfile from zerver.models import UserProfile
from zilencer.lib.remote_counts import MissingDataError from zilencer.lib.remote_counts import MissingDataError
@ -34,19 +34,26 @@ billing_logger = logging.getLogger("corporate.stripe")
@require_organization_member @require_organization_member
@has_request_variables @typed_endpoint
def upgrade( def upgrade(
request: HttpRequest, request: HttpRequest,
user: UserProfile, user: UserProfile,
billing_modality: str = REQ(str_validator=check_string_in(VALID_BILLING_MODALITY_VALUES)), *,
schedule: str = REQ(str_validator=check_string_in(VALID_BILLING_SCHEDULE_VALUES)), billing_modality: Annotated[
signed_seat_count: str = REQ(), str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_MODALITY_VALUES))
salt: str = REQ(), ],
license_management: Optional[str] = REQ( schedule: Annotated[
default=None, str_validator=check_string_in(VALID_LICENSE_MANAGEMENT_VALUES) str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_SCHEDULE_VALUES))
), ],
licenses: Optional[int] = REQ(json_validator=check_int, default=None), signed_seat_count: str,
tier: int = REQ(default=CustomerPlan.TIER_CLOUD_STANDARD, json_validator=check_int), salt: str,
license_management: Optional[
Annotated[
str, AfterValidator(lambda val: check_string_in(val, VALID_LICENSE_MANAGEMENT_VALUES))
]
] = None,
licenses: Optional[Json[int]] = None,
tier: Json[int] = CustomerPlan.TIER_CLOUD_STANDARD,
) -> HttpResponse: ) -> HttpResponse:
try: try:
upgrade_request = UpgradeRequest( upgrade_request = UpgradeRequest(
@ -84,20 +91,27 @@ def upgrade(
@authenticated_remote_realm_management_endpoint @authenticated_remote_realm_management_endpoint
@has_request_variables @typed_endpoint
def remote_realm_upgrade( def remote_realm_upgrade(
request: HttpRequest, request: HttpRequest,
billing_session: RemoteRealmBillingSession, billing_session: RemoteRealmBillingSession,
billing_modality: str = REQ(str_validator=check_string_in(VALID_BILLING_MODALITY_VALUES)), *,
schedule: str = REQ(str_validator=check_string_in(VALID_BILLING_SCHEDULE_VALUES)), billing_modality: Annotated[
signed_seat_count: str = REQ(), str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_MODALITY_VALUES))
salt: str = REQ(), ],
license_management: Optional[str] = REQ( schedule: Annotated[
default=None, str_validator=check_string_in(VALID_LICENSE_MANAGEMENT_VALUES) str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_SCHEDULE_VALUES))
), ],
licenses: Optional[int] = REQ(json_validator=check_int, default=None), signed_seat_count: str,
remote_server_plan_start_date: Optional[str] = REQ(default=None), salt: str,
tier: int = REQ(default=CustomerPlan.TIER_SELF_HOSTED_BUSINESS, json_validator=check_int), license_management: Optional[
Annotated[
str, AfterValidator(lambda val: check_string_in(val, VALID_LICENSE_MANAGEMENT_VALUES))
]
] = None,
licenses: Optional[Json[int]] = None,
remote_server_plan_start_date: Optional[str] = None,
tier: Json[int] = CustomerPlan.TIER_SELF_HOSTED_BUSINESS,
) -> HttpResponse: ) -> HttpResponse:
try: try:
upgrade_request = UpgradeRequest( upgrade_request = UpgradeRequest(
@ -133,20 +147,27 @@ def remote_realm_upgrade(
@authenticated_remote_server_management_endpoint @authenticated_remote_server_management_endpoint
@has_request_variables @typed_endpoint
def remote_server_upgrade( def remote_server_upgrade(
request: HttpRequest, request: HttpRequest,
billing_session: RemoteServerBillingSession, billing_session: RemoteServerBillingSession,
billing_modality: str = REQ(str_validator=check_string_in(VALID_BILLING_MODALITY_VALUES)), *,
schedule: str = REQ(str_validator=check_string_in(VALID_BILLING_SCHEDULE_VALUES)), billing_modality: Annotated[
signed_seat_count: str = REQ(), str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_MODALITY_VALUES))
salt: str = REQ(), ],
license_management: Optional[str] = REQ( schedule: Annotated[
default=None, str_validator=check_string_in(VALID_LICENSE_MANAGEMENT_VALUES) str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_SCHEDULE_VALUES))
), ],
licenses: Optional[int] = REQ(json_validator=check_int, default=None), signed_seat_count: str,
remote_server_plan_start_date: Optional[str] = REQ(default=None), salt: str,
tier: int = REQ(default=CustomerPlan.TIER_SELF_HOSTED_BUSINESS, json_validator=check_int), license_management: Optional[
Annotated[
str, AfterValidator(lambda val: check_string_in(val, VALID_LICENSE_MANAGEMENT_VALUES))
]
] = None,
licenses: Optional[Json[int]] = None,
remote_server_plan_start_date: Optional[str] = None,
tier: Json[int] = CustomerPlan.TIER_SELF_HOSTED_BUSINESS,
) -> HttpResponse: ) -> HttpResponse:
try: try:
upgrade_request = UpgradeRequest( upgrade_request = UpgradeRequest(
@ -182,12 +203,13 @@ def remote_server_upgrade(
@zulip_login_required @zulip_login_required
@has_request_variables @typed_endpoint
def upgrade_page( def upgrade_page(
request: HttpRequest, request: HttpRequest,
manual_license_management: bool = REQ(default=False, json_validator=check_bool), *,
tier: int = REQ(default=CustomerPlan.TIER_CLOUD_STANDARD, json_validator=check_int), manual_license_management: Json[bool] = False,
setup_payment_by_invoice: bool = REQ(default=False, json_validator=check_bool), tier: Json[int] = CustomerPlan.TIER_CLOUD_STANDARD,
setup_payment_by_invoice: Json[bool] = False,
) -> HttpResponse: ) -> HttpResponse:
user = request.user user = request.user
assert user.is_authenticated assert user.is_authenticated

View File

@ -21,6 +21,12 @@ def check_string_fixed_length(string: str, length: int) -> Optional[str]:
return string return string
def check_string_in(val: str, possible_values: List[str]) -> str:
if val not in possible_values:
raise ValueError(_("Not in the list of possible values"))
return val
def check_int_in(val: int, possible_values: List[int]) -> int: def check_int_in(val: int, possible_values: List[int]) -> int:
if val not in possible_values: if val not in possible_values:
raise ValueError(_("Not in the list of possible values")) raise ValueError(_("Not in the list of possible values"))