From fb9e258a6565864b2943e67e11bfc3eb357fbf17 Mon Sep 17 00:00:00 2001 From: Prakhar Pratyush Date: Tue, 14 Nov 2023 16:29:48 +0530 Subject: [PATCH] stripe: Add 'do_upgrade' method to the 'BillingSession' class. This commit moves a major portion of the 'upgrade` view to a new shared 'BillingSession.do_upgrade' method. This refactoring will help in minimizing duplicate code while supporting both realm and remote_server customers. --- corporate/lib/stripe.py | 105 +++++++++++++++++++++++++++++++++++ corporate/views/upgrade.py | 110 +++++-------------------------------- 2 files changed, 119 insertions(+), 96 deletions(-) diff --git a/corporate/lib/stripe.py b/corporate/lib/stripe.py index 6315f8f91e..407920411e 100644 --- a/corporate/lib/stripe.py +++ b/corporate/lib/stripe.py @@ -12,6 +12,7 @@ from typing import Any, Callable, Dict, Generator, Optional, Tuple, TypeVar, Uni import stripe from django.conf import settings +from django.core import signing from django.core.signing import Signer from django.db import transaction from django.urls import reverse @@ -57,6 +58,10 @@ MIN_INVOICED_LICENSES = 30 MAX_INVOICED_LICENSES = 1000 DEFAULT_INVOICE_DAYS_UNTIL_DUE = 30 +VALID_BILLING_MODALITY_VALUES = ["send_invoice", "charge_automatically"] +VALID_BILLING_SCHEDULE_VALUES = ["annual", "monthly"] +VALID_LICENSE_MANAGEMENT_VALUES = ["automatic", "manual"] + # The version of Stripe API the billing system supports. STRIPE_API_VERSION = "2020-08-27" @@ -101,6 +106,13 @@ def unsign_string(signed_string: str, salt: str) -> str: return signer.unsign(signed_string) +def unsign_seat_count(signed_seat_count: str, salt: str) -> int: + try: + return int(unsign_string(signed_seat_count, salt)) + except signing.BadSignature: + raise BillingError("tampered seat count") + + def validate_licenses( charge_automatically: bool, licenses: Optional[int], @@ -129,6 +141,28 @@ def validate_licenses( raise BillingError("too many licenses", message) +def check_upgrade_parameters( + billing_modality: str, + schedule: str, + license_management: Optional[str], + licenses: Optional[int], + seat_count: int, + exempt_from_license_number_check: bool, +) -> None: + if billing_modality not in VALID_BILLING_MODALITY_VALUES: # nocoverage + raise BillingError("unknown billing_modality", "") + if schedule not in VALID_BILLING_SCHEDULE_VALUES: # nocoverage + raise BillingError("unknown schedule") + if license_management not in VALID_LICENSE_MANAGEMENT_VALUES: # nocoverage + raise BillingError("unknown license_management") + validate_licenses( + billing_modality == "charge_automatically", + licenses, + seat_count, + exempt_from_license_number_check, + ) + + # Be extremely careful changing this function. Historical billing periods # are not stored anywhere, and are just computed on the fly using this # function. Any change you make here should return the same value (or be @@ -335,6 +369,17 @@ class StripePaymentIntentData: email: str +@dataclass +class UpgradeRequest: + billing_modality: str + schedule: str + signed_seat_count: str + salt: str + onboarding: bool + license_management: Optional[str] + licenses: Optional[int] + + class AuditLogEventType(Enum): STRIPE_CUSTOMER_CREATED = 1 STRIPE_CARD_CHANGED = 2 @@ -702,6 +747,66 @@ class BillingSession(ABC): self.do_change_plan_type(tier=plan_tier) + def do_upgrade(self, upgrade_request: UpgradeRequest) -> Dict[str, Any]: + customer = self.get_customer() + if customer is not None: + ensure_customer_does_not_have_active_plan(customer) + billing_modality = upgrade_request.billing_modality + schedule = upgrade_request.schedule + license_management = upgrade_request.license_management + licenses = upgrade_request.licenses + + seat_count = unsign_seat_count(upgrade_request.signed_seat_count, upgrade_request.salt) + if billing_modality == "charge_automatically" and license_management == "automatic": + licenses = seat_count + if billing_modality == "send_invoice": + schedule = "annual" + license_management = "manual" + + exempt_from_license_number_check = ( + customer is not None and customer.exempt_from_license_number_check + ) + check_upgrade_parameters( + billing_modality, + schedule, + license_management, + licenses, + seat_count, + exempt_from_license_number_check, + ) + assert licenses is not None and license_management is not None + automanage_licenses = license_management == "automatic" + charge_automatically = billing_modality == "charge_automatically" + + billing_schedule = {"annual": CustomerPlan.ANNUAL, "monthly": CustomerPlan.MONTHLY}[ + schedule + ] + data: Dict[str, Any] = {} + if charge_automatically: + stripe_checkout_session = self.setup_upgrade_checkout_session_and_payment_intent( + CustomerPlan.STANDARD, + seat_count, + licenses, + license_management, + billing_schedule, + billing_modality, + upgrade_request.onboarding, + ) + data = { + "stripe_session_url": stripe_checkout_session.url, + "stripe_session_id": stripe_checkout_session.id, + } + else: + self.process_initial_upgrade( + CustomerPlan.STANDARD, + licenses, + automanage_licenses, + billing_schedule, + False, + is_free_trial_offer_enabled(), + ) + return data + class RealmBillingSession(BillingSession): def __init__(self, user: UserProfile, realm: Optional[Realm] = None) -> None: diff --git a/corporate/views/upgrade.py b/corporate/views/upgrade.py index 17079db640..685b1e1003 100644 --- a/corporate/views/upgrade.py +++ b/corporate/views/upgrade.py @@ -4,7 +4,6 @@ from typing import Any, Dict, Optional from django import forms from django.conf import settings -from django.core import signing from django.db import transaction from django.http import HttpRequest, HttpResponse, HttpResponseRedirect from django.shortcuts import render @@ -13,18 +12,17 @@ from django.urls import reverse from corporate.lib.stripe import ( DEFAULT_INVOICE_DAYS_UNTIL_DUE, MIN_INVOICED_LICENSES, + VALID_BILLING_MODALITY_VALUES, + VALID_BILLING_SCHEDULE_VALUES, + VALID_LICENSE_MANAGEMENT_VALUES, BillingError, RealmBillingSession, - ensure_customer_does_not_have_active_plan, + UpgradeRequest, get_latest_seat_count, - is_free_trial_offer_enabled, sign_string, - unsign_string, - validate_licenses, ) from corporate.lib.support import get_support_url from corporate.models import ( - CustomerPlan, ZulipSponsorshipRequest, get_current_plan_by_customer, get_customer_by_realm, @@ -40,39 +38,6 @@ from zerver.models import UserProfile, get_org_type_display_name billing_logger = logging.getLogger("corporate.stripe") -VALID_BILLING_MODALITY_VALUES = ["send_invoice", "charge_automatically"] -VALID_BILLING_SCHEDULE_VALUES = ["annual", "monthly"] -VALID_LICENSE_MANAGEMENT_VALUES = ["automatic", "manual"] - - -def unsign_seat_count(signed_seat_count: str, salt: str) -> int: - try: - return int(unsign_string(signed_seat_count, salt)) - except signing.BadSignature: - raise BillingError("tampered seat count") - - -def check_upgrade_parameters( - billing_modality: str, - schedule: str, - license_management: Optional[str], - licenses: Optional[int], - seat_count: int, - exempt_from_license_number_check: bool, -) -> None: - if billing_modality not in VALID_BILLING_MODALITY_VALUES: # nocoverage - raise BillingError("unknown billing_modality", "") - if schedule not in VALID_BILLING_SCHEDULE_VALUES: # nocoverage - raise BillingError("unknown schedule") - if license_management not in VALID_LICENSE_MANAGEMENT_VALUES: # nocoverage - raise BillingError("unknown license_management") - validate_licenses( - billing_modality == "charge_automatically", - licenses, - seat_count, - exempt_from_license_number_check, - ) - @require_organization_member @has_request_variables @@ -89,66 +54,19 @@ def upgrade( ), licenses: Optional[int] = REQ(json_validator=check_int, default=None), ) -> HttpResponse: - customer = get_customer_by_realm(user.realm) - if customer is not None: - ensure_customer_does_not_have_active_plan(customer) try: - seat_count = unsign_seat_count(signed_seat_count, salt) - if billing_modality == "charge_automatically" and license_management == "automatic": - licenses = seat_count - if billing_modality == "send_invoice": - schedule = "annual" - license_management = "manual" - - exempt_from_license_number_check = ( - customer is not None and customer.exempt_from_license_number_check + upgrade_request = UpgradeRequest( + billing_modality=billing_modality, + schedule=schedule, + signed_seat_count=signed_seat_count, + salt=salt, + onboarding=onboarding, + license_management=license_management, + licenses=licenses, ) - check_upgrade_parameters( - billing_modality, - schedule, - license_management, - licenses, - seat_count, - exempt_from_license_number_check, - ) - assert licenses is not None and license_management is not None - automanage_licenses = license_management == "automatic" - charge_automatically = billing_modality == "charge_automatically" - - billing_schedule = {"annual": CustomerPlan.ANNUAL, "monthly": CustomerPlan.MONTHLY}[ - schedule - ] billing_session = RealmBillingSession(user) - if charge_automatically: - stripe_checkout_session = ( - billing_session.setup_upgrade_checkout_session_and_payment_intent( - CustomerPlan.STANDARD, - seat_count, - licenses, - license_management, - billing_schedule, - billing_modality, - onboarding, - ) - ) - return json_success( - request, - data={ - "stripe_session_url": stripe_checkout_session.url, - "stripe_session_id": stripe_checkout_session.id, - }, - ) - else: - billing_session.process_initial_upgrade( - CustomerPlan.STANDARD, - licenses, - automanage_licenses, - billing_schedule, - False, - is_free_trial_offer_enabled(), - ) - return json_success(request) - + data = billing_session.do_upgrade(upgrade_request) + return json_success(request, data) except BillingError as e: billing_logger.warning( "BillingError during upgrade: %s. user=%s, realm=%s (%s), billing_modality=%s, "