stripe: Add 'do_upgrade' method to the 'BillingSession' class.

This commit moves a major portion of the 'upgrade`
view to a new shared 'BillingSession.do_upgrade' method.

This refactoring will help in minimizing duplicate code
while supporting both realm and remote_server customers.
This commit is contained in:
Prakhar Pratyush 2023-11-14 16:29:48 +05:30 committed by Tim Abbott
parent e18c180414
commit fb9e258a65
2 changed files with 119 additions and 96 deletions

View File

@ -12,6 +12,7 @@ from typing import Any, Callable, Dict, Generator, Optional, Tuple, TypeVar, Uni
import stripe import stripe
from django.conf import settings from django.conf import settings
from django.core import signing
from django.core.signing import Signer from django.core.signing import Signer
from django.db import transaction from django.db import transaction
from django.urls import reverse from django.urls import reverse
@ -57,6 +58,10 @@ MIN_INVOICED_LICENSES = 30
MAX_INVOICED_LICENSES = 1000 MAX_INVOICED_LICENSES = 1000
DEFAULT_INVOICE_DAYS_UNTIL_DUE = 30 DEFAULT_INVOICE_DAYS_UNTIL_DUE = 30
VALID_BILLING_MODALITY_VALUES = ["send_invoice", "charge_automatically"]
VALID_BILLING_SCHEDULE_VALUES = ["annual", "monthly"]
VALID_LICENSE_MANAGEMENT_VALUES = ["automatic", "manual"]
# The version of Stripe API the billing system supports. # The version of Stripe API the billing system supports.
STRIPE_API_VERSION = "2020-08-27" STRIPE_API_VERSION = "2020-08-27"
@ -101,6 +106,13 @@ def unsign_string(signed_string: str, salt: str) -> str:
return signer.unsign(signed_string) return signer.unsign(signed_string)
def unsign_seat_count(signed_seat_count: str, salt: str) -> int:
try:
return int(unsign_string(signed_seat_count, salt))
except signing.BadSignature:
raise BillingError("tampered seat count")
def validate_licenses( def validate_licenses(
charge_automatically: bool, charge_automatically: bool,
licenses: Optional[int], licenses: Optional[int],
@ -129,6 +141,28 @@ def validate_licenses(
raise BillingError("too many licenses", message) raise BillingError("too many licenses", message)
def check_upgrade_parameters(
billing_modality: str,
schedule: str,
license_management: Optional[str],
licenses: Optional[int],
seat_count: int,
exempt_from_license_number_check: bool,
) -> None:
if billing_modality not in VALID_BILLING_MODALITY_VALUES: # nocoverage
raise BillingError("unknown billing_modality", "")
if schedule not in VALID_BILLING_SCHEDULE_VALUES: # nocoverage
raise BillingError("unknown schedule")
if license_management not in VALID_LICENSE_MANAGEMENT_VALUES: # nocoverage
raise BillingError("unknown license_management")
validate_licenses(
billing_modality == "charge_automatically",
licenses,
seat_count,
exempt_from_license_number_check,
)
# Be extremely careful changing this function. Historical billing periods # Be extremely careful changing this function. Historical billing periods
# are not stored anywhere, and are just computed on the fly using this # are not stored anywhere, and are just computed on the fly using this
# function. Any change you make here should return the same value (or be # function. Any change you make here should return the same value (or be
@ -335,6 +369,17 @@ class StripePaymentIntentData:
email: str email: str
@dataclass
class UpgradeRequest:
billing_modality: str
schedule: str
signed_seat_count: str
salt: str
onboarding: bool
license_management: Optional[str]
licenses: Optional[int]
class AuditLogEventType(Enum): class AuditLogEventType(Enum):
STRIPE_CUSTOMER_CREATED = 1 STRIPE_CUSTOMER_CREATED = 1
STRIPE_CARD_CHANGED = 2 STRIPE_CARD_CHANGED = 2
@ -702,6 +747,66 @@ class BillingSession(ABC):
self.do_change_plan_type(tier=plan_tier) self.do_change_plan_type(tier=plan_tier)
def do_upgrade(self, upgrade_request: UpgradeRequest) -> Dict[str, Any]:
customer = self.get_customer()
if customer is not None:
ensure_customer_does_not_have_active_plan(customer)
billing_modality = upgrade_request.billing_modality
schedule = upgrade_request.schedule
license_management = upgrade_request.license_management
licenses = upgrade_request.licenses
seat_count = unsign_seat_count(upgrade_request.signed_seat_count, upgrade_request.salt)
if billing_modality == "charge_automatically" and license_management == "automatic":
licenses = seat_count
if billing_modality == "send_invoice":
schedule = "annual"
license_management = "manual"
exempt_from_license_number_check = (
customer is not None and customer.exempt_from_license_number_check
)
check_upgrade_parameters(
billing_modality,
schedule,
license_management,
licenses,
seat_count,
exempt_from_license_number_check,
)
assert licenses is not None and license_management is not None
automanage_licenses = license_management == "automatic"
charge_automatically = billing_modality == "charge_automatically"
billing_schedule = {"annual": CustomerPlan.ANNUAL, "monthly": CustomerPlan.MONTHLY}[
schedule
]
data: Dict[str, Any] = {}
if charge_automatically:
stripe_checkout_session = self.setup_upgrade_checkout_session_and_payment_intent(
CustomerPlan.STANDARD,
seat_count,
licenses,
license_management,
billing_schedule,
billing_modality,
upgrade_request.onboarding,
)
data = {
"stripe_session_url": stripe_checkout_session.url,
"stripe_session_id": stripe_checkout_session.id,
}
else:
self.process_initial_upgrade(
CustomerPlan.STANDARD,
licenses,
automanage_licenses,
billing_schedule,
False,
is_free_trial_offer_enabled(),
)
return data
class RealmBillingSession(BillingSession): class RealmBillingSession(BillingSession):
def __init__(self, user: UserProfile, realm: Optional[Realm] = None) -> None: def __init__(self, user: UserProfile, realm: Optional[Realm] = None) -> None:

View File

@ -4,7 +4,6 @@ from typing import Any, Dict, Optional
from django import forms from django import forms
from django.conf import settings from django.conf import settings
from django.core import signing
from django.db import transaction from django.db import transaction
from django.http import HttpRequest, HttpResponse, HttpResponseRedirect from django.http import HttpRequest, HttpResponse, HttpResponseRedirect
from django.shortcuts import render from django.shortcuts import render
@ -13,18 +12,17 @@ from django.urls import reverse
from corporate.lib.stripe import ( from corporate.lib.stripe import (
DEFAULT_INVOICE_DAYS_UNTIL_DUE, DEFAULT_INVOICE_DAYS_UNTIL_DUE,
MIN_INVOICED_LICENSES, MIN_INVOICED_LICENSES,
VALID_BILLING_MODALITY_VALUES,
VALID_BILLING_SCHEDULE_VALUES,
VALID_LICENSE_MANAGEMENT_VALUES,
BillingError, BillingError,
RealmBillingSession, RealmBillingSession,
ensure_customer_does_not_have_active_plan, UpgradeRequest,
get_latest_seat_count, get_latest_seat_count,
is_free_trial_offer_enabled,
sign_string, sign_string,
unsign_string,
validate_licenses,
) )
from corporate.lib.support import get_support_url from corporate.lib.support import get_support_url
from corporate.models import ( from corporate.models import (
CustomerPlan,
ZulipSponsorshipRequest, ZulipSponsorshipRequest,
get_current_plan_by_customer, get_current_plan_by_customer,
get_customer_by_realm, get_customer_by_realm,
@ -40,39 +38,6 @@ from zerver.models import UserProfile, get_org_type_display_name
billing_logger = logging.getLogger("corporate.stripe") billing_logger = logging.getLogger("corporate.stripe")
VALID_BILLING_MODALITY_VALUES = ["send_invoice", "charge_automatically"]
VALID_BILLING_SCHEDULE_VALUES = ["annual", "monthly"]
VALID_LICENSE_MANAGEMENT_VALUES = ["automatic", "manual"]
def unsign_seat_count(signed_seat_count: str, salt: str) -> int:
try:
return int(unsign_string(signed_seat_count, salt))
except signing.BadSignature:
raise BillingError("tampered seat count")
def check_upgrade_parameters(
billing_modality: str,
schedule: str,
license_management: Optional[str],
licenses: Optional[int],
seat_count: int,
exempt_from_license_number_check: bool,
) -> None:
if billing_modality not in VALID_BILLING_MODALITY_VALUES: # nocoverage
raise BillingError("unknown billing_modality", "")
if schedule not in VALID_BILLING_SCHEDULE_VALUES: # nocoverage
raise BillingError("unknown schedule")
if license_management not in VALID_LICENSE_MANAGEMENT_VALUES: # nocoverage
raise BillingError("unknown license_management")
validate_licenses(
billing_modality == "charge_automatically",
licenses,
seat_count,
exempt_from_license_number_check,
)
@require_organization_member @require_organization_member
@has_request_variables @has_request_variables
@ -89,66 +54,19 @@ def upgrade(
), ),
licenses: Optional[int] = REQ(json_validator=check_int, default=None), licenses: Optional[int] = REQ(json_validator=check_int, default=None),
) -> HttpResponse: ) -> HttpResponse:
customer = get_customer_by_realm(user.realm)
if customer is not None:
ensure_customer_does_not_have_active_plan(customer)
try: try:
seat_count = unsign_seat_count(signed_seat_count, salt) upgrade_request = UpgradeRequest(
if billing_modality == "charge_automatically" and license_management == "automatic": billing_modality=billing_modality,
licenses = seat_count schedule=schedule,
if billing_modality == "send_invoice": signed_seat_count=signed_seat_count,
schedule = "annual" salt=salt,
license_management = "manual" onboarding=onboarding,
license_management=license_management,
exempt_from_license_number_check = ( licenses=licenses,
customer is not None and customer.exempt_from_license_number_check
) )
check_upgrade_parameters(
billing_modality,
schedule,
license_management,
licenses,
seat_count,
exempt_from_license_number_check,
)
assert licenses is not None and license_management is not None
automanage_licenses = license_management == "automatic"
charge_automatically = billing_modality == "charge_automatically"
billing_schedule = {"annual": CustomerPlan.ANNUAL, "monthly": CustomerPlan.MONTHLY}[
schedule
]
billing_session = RealmBillingSession(user) billing_session = RealmBillingSession(user)
if charge_automatically: data = billing_session.do_upgrade(upgrade_request)
stripe_checkout_session = ( return json_success(request, data)
billing_session.setup_upgrade_checkout_session_and_payment_intent(
CustomerPlan.STANDARD,
seat_count,
licenses,
license_management,
billing_schedule,
billing_modality,
onboarding,
)
)
return json_success(
request,
data={
"stripe_session_url": stripe_checkout_session.url,
"stripe_session_id": stripe_checkout_session.id,
},
)
else:
billing_session.process_initial_upgrade(
CustomerPlan.STANDARD,
licenses,
automanage_licenses,
billing_schedule,
False,
is_free_trial_offer_enabled(),
)
return json_success(request)
except BillingError as e: except BillingError as e:
billing_logger.warning( billing_logger.warning(
"BillingError during upgrade: %s. user=%s, realm=%s (%s), billing_modality=%s, " "BillingError during upgrade: %s. user=%s, realm=%s (%s), billing_modality=%s, "