From 5e62903d29d2b9ee6ae7aef40b5a8713824051e3 Mon Sep 17 00:00:00 2001 From: Anders Kaseorg Date: Tue, 24 Sep 2024 14:39:58 -0700 Subject: [PATCH] corporate: Use Literal types for upgrade parameters. Signed-off-by: Anders Kaseorg --- corporate/lib/billing_types.py | 5 +++++ corporate/lib/stripe.py | 26 +++++++++----------------- corporate/views/support.py | 22 +++++++++------------- corporate/views/upgrade.py | 27 ++++++++++----------------- 4 files changed, 33 insertions(+), 47 deletions(-) create mode 100644 corporate/lib/billing_types.py diff --git a/corporate/lib/billing_types.py b/corporate/lib/billing_types.py new file mode 100644 index 0000000000..c8b99ac1c9 --- /dev/null +++ b/corporate/lib/billing_types.py @@ -0,0 +1,5 @@ +from typing import Literal + +BillingModality = Literal["send_invoice", "charge_automatically"] +BillingSchedule = Literal["annual", "monthly"] +LicenseManagement = Literal["automatic", "manual"] diff --git a/corporate/lib/stripe.py b/corporate/lib/stripe.py index 8cd14f76e1..8a2a8a0661 100644 --- a/corporate/lib/stripe.py +++ b/corporate/lib/stripe.py @@ -27,6 +27,7 @@ from django.utils.translation import gettext_lazy from django.utils.translation import override as override_language from typing_extensions import ParamSpec, override +from corporate.lib.billing_types import BillingModality, BillingSchedule, LicenseManagement from corporate.models import ( Customer, CustomerPlan, @@ -90,10 +91,6 @@ MIN_INVOICED_LICENSES = 30 MAX_INVOICED_LICENSES = 1000 DEFAULT_INVOICE_DAYS_UNTIL_DUE = 15 -VALID_BILLING_MODALITY_VALUES = ["send_invoice", "charge_automatically"] -VALID_BILLING_SCHEDULE_VALUES = ["annual", "monthly"] -VALID_LICENSE_MANAGEMENT_VALUES = ["automatic", "manual"] - CARD_CAPITALIZATION = { "amex": "American Express", "diners": "Diners Club", @@ -232,19 +229,15 @@ def validate_licenses( def check_upgrade_parameters( - billing_modality: str, - schedule: str, - license_management: str | None, + billing_modality: BillingModality, + schedule: BillingSchedule, + license_management: LicenseManagement | None, licenses: int | None, seat_count: int, exempt_from_license_number_check: bool, min_licenses_for_plan: int, ) -> None: - if billing_modality not in VALID_BILLING_MODALITY_VALUES: # nocoverage - raise BillingError("unknown billing_modality", "") - if schedule not in VALID_BILLING_SCHEDULE_VALUES: # nocoverage - raise BillingError("unknown schedule") - if license_management not in VALID_LICENSE_MANAGEMENT_VALUES: # nocoverage + if license_management is None: # nocoverage raise BillingError("unknown license_management") validate_licenses( billing_modality == "charge_automatically", @@ -541,11 +534,11 @@ class StripeCustomerData: @dataclass class UpgradeRequest: - billing_modality: str - schedule: str + billing_modality: BillingModality + schedule: BillingSchedule signed_seat_count: str salt: str - license_management: str | None + license_management: LicenseManagement | None licenses: int | None tier: int remote_server_plan_start_date: str | None @@ -592,7 +585,7 @@ class SupportViewRequest(TypedDict, total=False): sponsorship_status: bool | None monthly_discounted_price: int | None annual_discounted_price: int | None - billing_modality: str | None + billing_modality: BillingModality | None plan_modification: str | None new_plan_tier: int | None minimum_licenses: int | None @@ -3497,7 +3490,6 @@ class BillingSession(ABC): success_message = self.configure_temporary_courtesy_plan(temporary_plan_end_date) elif support_type == SupportType.update_billing_modality: assert support_request["billing_modality"] is not None - assert support_request["billing_modality"] in VALID_BILLING_MODALITY_VALUES charge_automatically = support_request["billing_modality"] == "charge_automatically" success_message = self.update_billing_modality_of_current_plan(charge_automatically) elif support_type == SupportType.update_plan_end_date: diff --git a/corporate/views/support.py b/corporate/views/support.py index cb13c229b9..e63f8b57bf 100644 --- a/corporate/views/support.py +++ b/corporate/views/support.py @@ -23,6 +23,7 @@ from pydantic import AfterValidator, Json, NonNegativeInt from confirmation.models import Confirmation, confirmation_url from confirmation.settings import STATUS_USED from corporate.lib.activity import format_optional_datetime, remote_installation_stats_link +from corporate.lib.billing_types import BillingModality from corporate.lib.stripe import ( BILLING_SUPPORT_EMAIL, RealmBillingSession, @@ -330,19 +331,14 @@ def check_update_max_invites(realm: Realm, new_max: int, default_max: int) -> bo return new_max > default_max -VALID_MODIFY_PLAN_METHODS = Literal[ +ModifyPlan = Literal[ "downgrade_at_billing_cycle_end", "downgrade_now_without_additional_licenses", "downgrade_now_void_open_invoices", "upgrade_plan_tier", ] -VALID_STATUS_VALUES = Literal["active", "deactivated"] - -VALID_BILLING_MODALITY_VALUES = Literal[ - "send_invoice", - "charge_automatically", -] +RemoteServerStatus = Literal["active", "deactivated"] SHARED_SUPPORT_CONTEXT = { "get_org_type_display_name": get_org_type_display_name, @@ -363,11 +359,11 @@ def support( minimum_licenses: Json[NonNegativeInt] | None = None, required_plan_tier: Json[NonNegativeInt] | None = None, new_subdomain: str | None = None, - status: VALID_STATUS_VALUES | None = None, - billing_modality: VALID_BILLING_MODALITY_VALUES | None = None, + status: RemoteServerStatus | None = None, + billing_modality: BillingModality | None = None, sponsorship_pending: Json[bool] | None = None, approve_sponsorship: Json[bool] = False, - modify_plan: VALID_MODIFY_PLAN_METHODS | None = None, + modify_plan: ModifyPlan | None = None, scrub_realm: Json[bool] = False, delete_user_by_id: Json[NonNegativeInt] | None = None, query: Annotated[str | None, ApiParamConfig("q")] = None, @@ -685,12 +681,12 @@ def remote_servers_support( sent_invoice_id: str | None = None, sponsorship_pending: Json[bool] | None = None, approve_sponsorship: Json[bool] = False, - billing_modality: VALID_BILLING_MODALITY_VALUES | None = None, + billing_modality: BillingModality | None = None, plan_end_date: Annotated[str, AfterValidator(lambda x: check_date("plan_end_date", x))] | None = None, - modify_plan: VALID_MODIFY_PLAN_METHODS | None = None, + modify_plan: ModifyPlan | None = None, delete_fixed_price_next_plan: Json[bool] = False, - remote_server_status: VALID_STATUS_VALUES | None = None, + remote_server_status: RemoteServerStatus | None = None, temporary_courtesy_plan: Annotated[ str, AfterValidator(lambda x: check_date("temporary_courtesy_plan", x)) ] diff --git a/corporate/views/upgrade.py b/corporate/views/upgrade.py index 6ebe61ad4f..ea4382972c 100644 --- a/corporate/views/upgrade.py +++ b/corporate/views/upgrade.py @@ -1,19 +1,16 @@ import logging -from typing import Annotated from django.conf import settings from django.http import HttpRequest, HttpResponse, HttpResponseRedirect from django.shortcuts import render from pydantic import Json +from corporate.lib.billing_types import BillingModality, BillingSchedule, LicenseManagement from corporate.lib.decorator import ( authenticated_remote_realm_management_endpoint, authenticated_remote_server_management_endpoint, ) from corporate.lib.stripe import ( - VALID_BILLING_MODALITY_VALUES, - VALID_BILLING_SCHEDULE_VALUES, - VALID_LICENSE_MANAGEMENT_VALUES, BillingError, InitialUpgradeRequest, RealmBillingSession, @@ -25,7 +22,6 @@ from corporate.models import CustomerPlan from zerver.decorator import require_organization_member, zulip_login_required from zerver.lib.response import json_success from zerver.lib.typed_endpoint import typed_endpoint -from zerver.lib.typed_endpoint_validators import check_string_in_validator from zerver.models import UserProfile from zilencer.lib.remote_counts import MissingDataError @@ -38,12 +34,11 @@ def upgrade( request: HttpRequest, user: UserProfile, *, - billing_modality: Annotated[str, check_string_in_validator(VALID_BILLING_MODALITY_VALUES)], - schedule: Annotated[str, check_string_in_validator(VALID_BILLING_SCHEDULE_VALUES)], + billing_modality: BillingModality, + schedule: BillingSchedule, signed_seat_count: str, salt: str, - license_management: Annotated[str, check_string_in_validator(VALID_LICENSE_MANAGEMENT_VALUES)] - | None = None, + license_management: LicenseManagement | None = None, licenses: Json[int] | None = None, tier: Json[int] = CustomerPlan.TIER_CLOUD_STANDARD, ) -> HttpResponse: @@ -88,12 +83,11 @@ def remote_realm_upgrade( request: HttpRequest, billing_session: RemoteRealmBillingSession, *, - billing_modality: Annotated[str, check_string_in_validator(VALID_BILLING_MODALITY_VALUES)], - schedule: Annotated[str, check_string_in_validator(VALID_BILLING_SCHEDULE_VALUES)], + billing_modality: BillingModality, + schedule: BillingSchedule, signed_seat_count: str, salt: str, - license_management: Annotated[str, check_string_in_validator(VALID_LICENSE_MANAGEMENT_VALUES)] - | None = None, + license_management: LicenseManagement | None = None, licenses: Json[int] | None = None, remote_server_plan_start_date: str | None = None, tier: Json[int] = CustomerPlan.TIER_SELF_HOSTED_BUSINESS, @@ -137,12 +131,11 @@ def remote_server_upgrade( request: HttpRequest, billing_session: RemoteServerBillingSession, *, - billing_modality: Annotated[str, check_string_in_validator(VALID_BILLING_MODALITY_VALUES)], - schedule: Annotated[str, check_string_in_validator(VALID_BILLING_SCHEDULE_VALUES)], + billing_modality: BillingModality, + schedule: BillingSchedule, signed_seat_count: str, salt: str, - license_management: Annotated[str, check_string_in_validator(VALID_LICENSE_MANAGEMENT_VALUES)] - | None = None, + license_management: LicenseManagement | None = None, licenses: Json[int] | None = None, remote_server_plan_start_date: str | None = None, tier: Json[int] = CustomerPlan.TIER_SELF_HOSTED_BUSINESS,