corporate: Use Literal types for upgrade parameters.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2024-09-24 14:39:58 -07:00 committed by Tim Abbott
parent 88782f2917
commit 5e62903d29
4 changed files with 33 additions and 47 deletions

View File

@ -0,0 +1,5 @@
from typing import Literal
BillingModality = Literal["send_invoice", "charge_automatically"]
BillingSchedule = Literal["annual", "monthly"]
LicenseManagement = Literal["automatic", "manual"]

View File

@ -27,6 +27,7 @@ from django.utils.translation import gettext_lazy
from django.utils.translation import override as override_language from django.utils.translation import override as override_language
from typing_extensions import ParamSpec, override from typing_extensions import ParamSpec, override
from corporate.lib.billing_types import BillingModality, BillingSchedule, LicenseManagement
from corporate.models import ( from corporate.models import (
Customer, Customer,
CustomerPlan, CustomerPlan,
@ -90,10 +91,6 @@ MIN_INVOICED_LICENSES = 30
MAX_INVOICED_LICENSES = 1000 MAX_INVOICED_LICENSES = 1000
DEFAULT_INVOICE_DAYS_UNTIL_DUE = 15 DEFAULT_INVOICE_DAYS_UNTIL_DUE = 15
VALID_BILLING_MODALITY_VALUES = ["send_invoice", "charge_automatically"]
VALID_BILLING_SCHEDULE_VALUES = ["annual", "monthly"]
VALID_LICENSE_MANAGEMENT_VALUES = ["automatic", "manual"]
CARD_CAPITALIZATION = { CARD_CAPITALIZATION = {
"amex": "American Express", "amex": "American Express",
"diners": "Diners Club", "diners": "Diners Club",
@ -232,19 +229,15 @@ def validate_licenses(
def check_upgrade_parameters( def check_upgrade_parameters(
billing_modality: str, billing_modality: BillingModality,
schedule: str, schedule: BillingSchedule,
license_management: str | None, license_management: LicenseManagement | None,
licenses: int | None, licenses: int | None,
seat_count: int, seat_count: int,
exempt_from_license_number_check: bool, exempt_from_license_number_check: bool,
min_licenses_for_plan: int, min_licenses_for_plan: int,
) -> None: ) -> None:
if billing_modality not in VALID_BILLING_MODALITY_VALUES: # nocoverage if license_management is None: # nocoverage
raise BillingError("unknown billing_modality", "")
if schedule not in VALID_BILLING_SCHEDULE_VALUES: # nocoverage
raise BillingError("unknown schedule")
if license_management not in VALID_LICENSE_MANAGEMENT_VALUES: # nocoverage
raise BillingError("unknown license_management") raise BillingError("unknown license_management")
validate_licenses( validate_licenses(
billing_modality == "charge_automatically", billing_modality == "charge_automatically",
@ -541,11 +534,11 @@ class StripeCustomerData:
@dataclass @dataclass
class UpgradeRequest: class UpgradeRequest:
billing_modality: str billing_modality: BillingModality
schedule: str schedule: BillingSchedule
signed_seat_count: str signed_seat_count: str
salt: str salt: str
license_management: str | None license_management: LicenseManagement | None
licenses: int | None licenses: int | None
tier: int tier: int
remote_server_plan_start_date: str | None remote_server_plan_start_date: str | None
@ -592,7 +585,7 @@ class SupportViewRequest(TypedDict, total=False):
sponsorship_status: bool | None sponsorship_status: bool | None
monthly_discounted_price: int | None monthly_discounted_price: int | None
annual_discounted_price: int | None annual_discounted_price: int | None
billing_modality: str | None billing_modality: BillingModality | None
plan_modification: str | None plan_modification: str | None
new_plan_tier: int | None new_plan_tier: int | None
minimum_licenses: int | None minimum_licenses: int | None
@ -3497,7 +3490,6 @@ class BillingSession(ABC):
success_message = self.configure_temporary_courtesy_plan(temporary_plan_end_date) success_message = self.configure_temporary_courtesy_plan(temporary_plan_end_date)
elif support_type == SupportType.update_billing_modality: elif support_type == SupportType.update_billing_modality:
assert support_request["billing_modality"] is not None assert support_request["billing_modality"] is not None
assert support_request["billing_modality"] in VALID_BILLING_MODALITY_VALUES
charge_automatically = support_request["billing_modality"] == "charge_automatically" charge_automatically = support_request["billing_modality"] == "charge_automatically"
success_message = self.update_billing_modality_of_current_plan(charge_automatically) success_message = self.update_billing_modality_of_current_plan(charge_automatically)
elif support_type == SupportType.update_plan_end_date: elif support_type == SupportType.update_plan_end_date:

View File

@ -23,6 +23,7 @@ from pydantic import AfterValidator, Json, NonNegativeInt
from confirmation.models import Confirmation, confirmation_url from confirmation.models import Confirmation, confirmation_url
from confirmation.settings import STATUS_USED from confirmation.settings import STATUS_USED
from corporate.lib.activity import format_optional_datetime, remote_installation_stats_link from corporate.lib.activity import format_optional_datetime, remote_installation_stats_link
from corporate.lib.billing_types import BillingModality
from corporate.lib.stripe import ( from corporate.lib.stripe import (
BILLING_SUPPORT_EMAIL, BILLING_SUPPORT_EMAIL,
RealmBillingSession, RealmBillingSession,
@ -330,19 +331,14 @@ def check_update_max_invites(realm: Realm, new_max: int, default_max: int) -> bo
return new_max > default_max return new_max > default_max
VALID_MODIFY_PLAN_METHODS = Literal[ ModifyPlan = Literal[
"downgrade_at_billing_cycle_end", "downgrade_at_billing_cycle_end",
"downgrade_now_without_additional_licenses", "downgrade_now_without_additional_licenses",
"downgrade_now_void_open_invoices", "downgrade_now_void_open_invoices",
"upgrade_plan_tier", "upgrade_plan_tier",
] ]
VALID_STATUS_VALUES = Literal["active", "deactivated"] RemoteServerStatus = Literal["active", "deactivated"]
VALID_BILLING_MODALITY_VALUES = Literal[
"send_invoice",
"charge_automatically",
]
SHARED_SUPPORT_CONTEXT = { SHARED_SUPPORT_CONTEXT = {
"get_org_type_display_name": get_org_type_display_name, "get_org_type_display_name": get_org_type_display_name,
@ -363,11 +359,11 @@ def support(
minimum_licenses: Json[NonNegativeInt] | None = None, minimum_licenses: Json[NonNegativeInt] | None = None,
required_plan_tier: Json[NonNegativeInt] | None = None, required_plan_tier: Json[NonNegativeInt] | None = None,
new_subdomain: str | None = None, new_subdomain: str | None = None,
status: VALID_STATUS_VALUES | None = None, status: RemoteServerStatus | None = None,
billing_modality: VALID_BILLING_MODALITY_VALUES | None = None, billing_modality: BillingModality | None = None,
sponsorship_pending: Json[bool] | None = None, sponsorship_pending: Json[bool] | None = None,
approve_sponsorship: Json[bool] = False, approve_sponsorship: Json[bool] = False,
modify_plan: VALID_MODIFY_PLAN_METHODS | None = None, modify_plan: ModifyPlan | None = None,
scrub_realm: Json[bool] = False, scrub_realm: Json[bool] = False,
delete_user_by_id: Json[NonNegativeInt] | None = None, delete_user_by_id: Json[NonNegativeInt] | None = None,
query: Annotated[str | None, ApiParamConfig("q")] = None, query: Annotated[str | None, ApiParamConfig("q")] = None,
@ -685,12 +681,12 @@ def remote_servers_support(
sent_invoice_id: str | None = None, sent_invoice_id: str | None = None,
sponsorship_pending: Json[bool] | None = None, sponsorship_pending: Json[bool] | None = None,
approve_sponsorship: Json[bool] = False, approve_sponsorship: Json[bool] = False,
billing_modality: VALID_BILLING_MODALITY_VALUES | None = None, billing_modality: BillingModality | None = None,
plan_end_date: Annotated[str, AfterValidator(lambda x: check_date("plan_end_date", x))] plan_end_date: Annotated[str, AfterValidator(lambda x: check_date("plan_end_date", x))]
| None = None, | None = None,
modify_plan: VALID_MODIFY_PLAN_METHODS | None = None, modify_plan: ModifyPlan | None = None,
delete_fixed_price_next_plan: Json[bool] = False, delete_fixed_price_next_plan: Json[bool] = False,
remote_server_status: VALID_STATUS_VALUES | None = None, remote_server_status: RemoteServerStatus | None = None,
temporary_courtesy_plan: Annotated[ temporary_courtesy_plan: Annotated[
str, AfterValidator(lambda x: check_date("temporary_courtesy_plan", x)) str, AfterValidator(lambda x: check_date("temporary_courtesy_plan", x))
] ]

View File

@ -1,19 +1,16 @@
import logging import logging
from typing import Annotated
from django.conf import settings from django.conf import settings
from django.http import HttpRequest, HttpResponse, HttpResponseRedirect from django.http import HttpRequest, HttpResponse, HttpResponseRedirect
from django.shortcuts import render from django.shortcuts import render
from pydantic import Json from pydantic import Json
from corporate.lib.billing_types import BillingModality, BillingSchedule, LicenseManagement
from corporate.lib.decorator import ( from corporate.lib.decorator import (
authenticated_remote_realm_management_endpoint, authenticated_remote_realm_management_endpoint,
authenticated_remote_server_management_endpoint, authenticated_remote_server_management_endpoint,
) )
from corporate.lib.stripe import ( from corporate.lib.stripe import (
VALID_BILLING_MODALITY_VALUES,
VALID_BILLING_SCHEDULE_VALUES,
VALID_LICENSE_MANAGEMENT_VALUES,
BillingError, BillingError,
InitialUpgradeRequest, InitialUpgradeRequest,
RealmBillingSession, RealmBillingSession,
@ -25,7 +22,6 @@ from corporate.models import CustomerPlan
from zerver.decorator import require_organization_member, zulip_login_required from zerver.decorator import require_organization_member, zulip_login_required
from zerver.lib.response import json_success from zerver.lib.response import json_success
from zerver.lib.typed_endpoint import typed_endpoint from zerver.lib.typed_endpoint import typed_endpoint
from zerver.lib.typed_endpoint_validators import check_string_in_validator
from zerver.models import UserProfile from zerver.models import UserProfile
from zilencer.lib.remote_counts import MissingDataError from zilencer.lib.remote_counts import MissingDataError
@ -38,12 +34,11 @@ def upgrade(
request: HttpRequest, request: HttpRequest,
user: UserProfile, user: UserProfile,
*, *,
billing_modality: Annotated[str, check_string_in_validator(VALID_BILLING_MODALITY_VALUES)], billing_modality: BillingModality,
schedule: Annotated[str, check_string_in_validator(VALID_BILLING_SCHEDULE_VALUES)], schedule: BillingSchedule,
signed_seat_count: str, signed_seat_count: str,
salt: str, salt: str,
license_management: Annotated[str, check_string_in_validator(VALID_LICENSE_MANAGEMENT_VALUES)] license_management: LicenseManagement | None = None,
| None = None,
licenses: Json[int] | None = None, licenses: Json[int] | None = None,
tier: Json[int] = CustomerPlan.TIER_CLOUD_STANDARD, tier: Json[int] = CustomerPlan.TIER_CLOUD_STANDARD,
) -> HttpResponse: ) -> HttpResponse:
@ -88,12 +83,11 @@ def remote_realm_upgrade(
request: HttpRequest, request: HttpRequest,
billing_session: RemoteRealmBillingSession, billing_session: RemoteRealmBillingSession,
*, *,
billing_modality: Annotated[str, check_string_in_validator(VALID_BILLING_MODALITY_VALUES)], billing_modality: BillingModality,
schedule: Annotated[str, check_string_in_validator(VALID_BILLING_SCHEDULE_VALUES)], schedule: BillingSchedule,
signed_seat_count: str, signed_seat_count: str,
salt: str, salt: str,
license_management: Annotated[str, check_string_in_validator(VALID_LICENSE_MANAGEMENT_VALUES)] license_management: LicenseManagement | None = None,
| None = None,
licenses: Json[int] | None = None, licenses: Json[int] | None = None,
remote_server_plan_start_date: str | None = None, remote_server_plan_start_date: str | None = None,
tier: Json[int] = CustomerPlan.TIER_SELF_HOSTED_BUSINESS, tier: Json[int] = CustomerPlan.TIER_SELF_HOSTED_BUSINESS,
@ -137,12 +131,11 @@ def remote_server_upgrade(
request: HttpRequest, request: HttpRequest,
billing_session: RemoteServerBillingSession, billing_session: RemoteServerBillingSession,
*, *,
billing_modality: Annotated[str, check_string_in_validator(VALID_BILLING_MODALITY_VALUES)], billing_modality: BillingModality,
schedule: Annotated[str, check_string_in_validator(VALID_BILLING_SCHEDULE_VALUES)], schedule: BillingSchedule,
signed_seat_count: str, signed_seat_count: str,
salt: str, salt: str,
license_management: Annotated[str, check_string_in_validator(VALID_LICENSE_MANAGEMENT_VALUES)] license_management: LicenseManagement | None = None,
| None = None,
licenses: Json[int] | None = None, licenses: Json[int] | None = None,
remote_server_plan_start_date: str | None = None, remote_server_plan_start_date: str | None = None,
tier: Json[int] = CustomerPlan.TIER_SELF_HOSTED_BUSINESS, tier: Json[int] = CustomerPlan.TIER_SELF_HOSTED_BUSINESS,