corporate: Check plan tier for new plan discount calculations.

Now that a customer discount may require a particular plan tier to
be applied, update the billing code to check the plan tier when
getting the customer default_discount field/information for a new
plan.

For billing schedule changes and displaying billing information for
current plans, we explicitly use the discount set on the current,
active plan and do not check the customer object for these actions.
This commit is contained in:
Lauryn Menard 2024-01-10 14:59:49 +01:00 committed by Tim Abbott
parent 7542a676c7
commit fbe7145231
2 changed files with 34 additions and 24 deletions

View File

@ -1210,9 +1210,8 @@ class BillingSession(ABC):
) -> str: ) -> str:
customer = self.update_or_create_stripe_customer() customer = self.update_or_create_stripe_customer()
assert customer is not None # for mypy assert customer is not None # for mypy
price_per_license = get_price_per_license( discount_for_plan = customer.get_discount_for_plan_tier(plan_tier)
plan_tier, billing_schedule, customer.default_discount price_per_license = get_price_per_license(plan_tier, billing_schedule, discount_for_plan)
)
general_metadata = { general_metadata = {
"billing_modality": billing_modality, "billing_modality": billing_modality,
"billing_schedule": billing_schedule, "billing_schedule": billing_schedule,
@ -1286,7 +1285,7 @@ class BillingSession(ABC):
if should_schedule_upgrade_for_legacy_remote_server: if should_schedule_upgrade_for_legacy_remote_server:
assert remote_server_legacy_plan is not None assert remote_server_legacy_plan is not None
billing_cycle_anchor = remote_server_legacy_plan.end_date billing_cycle_anchor = remote_server_legacy_plan.end_date
discount_for_plan = customer.get_discount_for_plan_tier(plan_tier)
( (
billing_cycle_anchor, billing_cycle_anchor,
next_invoice_date, next_invoice_date,
@ -1296,7 +1295,7 @@ class BillingSession(ABC):
plan_tier, plan_tier,
automanage_licenses, automanage_licenses,
billing_schedule, billing_schedule,
customer.default_discount, discount_for_plan,
free_trial, free_trial,
billing_cycle_anchor, billing_cycle_anchor,
is_self_hosted_billing, is_self_hosted_billing,
@ -1316,7 +1315,7 @@ class BillingSession(ABC):
"automanage_licenses": automanage_licenses, "automanage_licenses": automanage_licenses,
"charge_automatically": charge_automatically, "charge_automatically": charge_automatically,
"price_per_license": price_per_license, "price_per_license": price_per_license,
"discount": customer.default_discount, "discount": discount_for_plan,
"billing_cycle_anchor": billing_cycle_anchor, "billing_cycle_anchor": billing_cycle_anchor,
"billing_schedule": billing_schedule, "billing_schedule": billing_schedule,
"tier": plan_tier, "tier": plan_tier,
@ -1544,12 +1543,12 @@ class BillingSession(ABC):
plan.status = CustomerPlan.ENDED plan.status = CustomerPlan.ENDED
plan.save(update_fields=["status"]) plan.save(update_fields=["status"])
discount = plan.customer.default_discount or plan.discount discount_for_current_plan = plan.discount
_, _, _, price_per_license = compute_plan_parameters( _, _, _, price_per_license = compute_plan_parameters(
tier=plan.tier, tier=plan.tier,
automanage_licenses=plan.automanage_licenses, automanage_licenses=plan.automanage_licenses,
billing_schedule=schedule, billing_schedule=schedule,
discount=plan.discount, discount=discount_for_current_plan,
) )
new_plan = CustomerPlan.objects.create( new_plan = CustomerPlan.objects.create(
@ -1558,7 +1557,7 @@ class BillingSession(ABC):
automanage_licenses=plan.automanage_licenses, automanage_licenses=plan.automanage_licenses,
charge_automatically=plan.charge_automatically, charge_automatically=plan.charge_automatically,
price_per_license=price_per_license, price_per_license=price_per_license,
discount=discount, discount=discount_for_current_plan,
billing_cycle_anchor=plan.billing_cycle_anchor, billing_cycle_anchor=plan.billing_cycle_anchor,
tier=plan.tier, tier=plan.tier,
status=CustomerPlan.FREE_TRIAL, status=CustomerPlan.FREE_TRIAL,
@ -1679,12 +1678,12 @@ class BillingSession(ABC):
plan.status = CustomerPlan.ENDED plan.status = CustomerPlan.ENDED
plan.save(update_fields=["status"]) plan.save(update_fields=["status"])
discount = plan.customer.default_discount or plan.discount discount_for_current_plan = plan.discount
_, _, _, price_per_license = compute_plan_parameters( _, _, _, price_per_license = compute_plan_parameters(
tier=plan.tier, tier=plan.tier,
automanage_licenses=plan.automanage_licenses, automanage_licenses=plan.automanage_licenses,
billing_schedule=CustomerPlan.BILLING_SCHEDULE_ANNUAL, billing_schedule=CustomerPlan.BILLING_SCHEDULE_ANNUAL,
discount=plan.discount, discount=discount_for_current_plan,
) )
new_plan = CustomerPlan.objects.create( new_plan = CustomerPlan.objects.create(
@ -1693,7 +1692,7 @@ class BillingSession(ABC):
automanage_licenses=plan.automanage_licenses, automanage_licenses=plan.automanage_licenses,
charge_automatically=plan.charge_automatically, charge_automatically=plan.charge_automatically,
price_per_license=price_per_license, price_per_license=price_per_license,
discount=discount, discount=discount_for_current_plan,
billing_cycle_anchor=next_billing_cycle, billing_cycle_anchor=next_billing_cycle,
tier=plan.tier, tier=plan.tier,
status=CustomerPlan.ACTIVE, status=CustomerPlan.ACTIVE,
@ -1728,12 +1727,12 @@ class BillingSession(ABC):
plan.status = CustomerPlan.ENDED plan.status = CustomerPlan.ENDED
plan.save(update_fields=["status"]) plan.save(update_fields=["status"])
discount = plan.customer.default_discount or plan.discount discount_for_current_plan = plan.discount
_, _, _, price_per_license = compute_plan_parameters( _, _, _, price_per_license = compute_plan_parameters(
tier=plan.tier, tier=plan.tier,
automanage_licenses=plan.automanage_licenses, automanage_licenses=plan.automanage_licenses,
billing_schedule=CustomerPlan.BILLING_SCHEDULE_MONTHLY, billing_schedule=CustomerPlan.BILLING_SCHEDULE_MONTHLY,
discount=plan.discount, discount=discount_for_current_plan,
) )
new_plan = CustomerPlan.objects.create( new_plan = CustomerPlan.objects.create(
@ -1742,7 +1741,7 @@ class BillingSession(ABC):
automanage_licenses=plan.automanage_licenses, automanage_licenses=plan.automanage_licenses,
charge_automatically=plan.charge_automatically, charge_automatically=plan.charge_automatically,
price_per_license=price_per_license, price_per_license=price_per_license,
discount=discount, discount=discount_for_current_plan,
billing_cycle_anchor=next_billing_cycle, billing_cycle_anchor=next_billing_cycle,
tier=plan.tier, tier=plan.tier,
status=CustomerPlan.ACTIVE, status=CustomerPlan.ACTIVE,
@ -1846,18 +1845,19 @@ class BillingSession(ABC):
) )
billing_frequency = CustomerPlan.BILLING_SCHEDULES[plan.billing_schedule] billing_frequency = CustomerPlan.BILLING_SCHEDULES[plan.billing_schedule]
discount_for_current_plan = plan.discount
if switch_to_annual_at_end_of_cycle: if switch_to_annual_at_end_of_cycle:
num_months_next_cycle = 12 num_months_next_cycle = 12
annual_price_per_license = get_price_per_license( annual_price_per_license = get_price_per_license(
plan.tier, CustomerPlan.BILLING_SCHEDULE_ANNUAL, customer.default_discount plan.tier, CustomerPlan.BILLING_SCHEDULE_ANNUAL, discount_for_current_plan
) )
renewal_cents = annual_price_per_license * licenses_at_next_renewal renewal_cents = annual_price_per_license * licenses_at_next_renewal
price_per_license = format_money(annual_price_per_license / 12) price_per_license = format_money(annual_price_per_license / 12)
elif switch_to_monthly_at_end_of_cycle: elif switch_to_monthly_at_end_of_cycle:
num_months_next_cycle = 1 num_months_next_cycle = 1
monthly_price_per_license = get_price_per_license( monthly_price_per_license = get_price_per_license(
plan.tier, CustomerPlan.BILLING_SCHEDULE_MONTHLY, customer.default_discount plan.tier, CustomerPlan.BILLING_SCHEDULE_MONTHLY, discount_for_current_plan
) )
renewal_cents = monthly_price_per_license * licenses_at_next_renewal renewal_cents = monthly_price_per_license * licenses_at_next_renewal
price_per_license = format_money(monthly_price_per_license) price_per_license = format_money(monthly_price_per_license)
@ -1930,7 +1930,7 @@ class BillingSession(ABC):
"sponsorship_plan_name": self.get_sponsorship_plan_name( "sponsorship_plan_name": self.get_sponsorship_plan_name(
customer, is_self_hosted_billing customer, is_self_hosted_billing
), ),
"discount_percent": format_discount_percentage(customer.default_discount), "discount_percent": format_discount_percentage(discount_for_current_plan),
"is_self_hosted_billing": is_self_hosted_billing, "is_self_hosted_billing": is_self_hosted_billing,
"is_server_on_legacy_plan": remote_server_legacy_plan_end_date is not None, "is_server_on_legacy_plan": remote_server_legacy_plan_end_date is not None,
"remote_server_legacy_plan_end_date": remote_server_legacy_plan_end_date, "remote_server_legacy_plan_end_date": remote_server_legacy_plan_end_date,
@ -2020,10 +2020,6 @@ class BillingSession(ABC):
if customer_plan is not None: if customer_plan is not None:
return f"{self.billing_session_url}/billing", None return f"{self.billing_session_url}/billing", None
percent_off = Decimal(0)
if customer is not None and customer.default_discount is not None:
percent_off = customer.default_discount
exempt_from_license_number_check = ( exempt_from_license_number_check = (
customer is not None and customer.exempt_from_license_number_check customer is not None and customer.exempt_from_license_number_check
) )
@ -2038,6 +2034,13 @@ class BillingSession(ABC):
current_payment_method = None if "ending in" not in payment_method else payment_method current_payment_method = None if "ending in" not in payment_method else payment_method
tier = initial_upgrade_request.tier tier = initial_upgrade_request.tier
percent_off = Decimal(0)
if customer is not None:
discount_for_plan_tier = customer.get_discount_for_plan_tier(tier)
if discount_for_plan_tier is not None:
percent_off = discount_for_plan_tier
customer_specific_context = self.get_upgrade_page_session_type_specific_context() customer_specific_context = self.get_upgrade_page_session_type_specific_context()
min_licenses_for_plan = self.min_licenses_for_plan(tier) min_licenses_for_plan = self.min_licenses_for_plan(tier)
seat_count = self.current_count_for_billed_licenses() seat_count = self.current_count_for_billed_licenses()
@ -2326,8 +2329,9 @@ class BillingSession(ABC):
current_plan.status = CustomerPlan.ENDED current_plan.status = CustomerPlan.ENDED
current_plan.save(update_fields=["status", "end_date"]) current_plan.save(update_fields=["status", "end_date"])
discount_for_new_plan_tier = current_plan.customer.get_discount_for_plan_tier(new_plan_tier)
new_price_per_license = get_price_per_license( new_price_per_license = get_price_per_license(
new_plan_tier, current_plan.billing_schedule, current_plan.customer.default_discount new_plan_tier, current_plan.billing_schedule, discount_for_new_plan_tier
) )
new_plan_billing_cycle_anchor = current_plan.end_date.replace(microsecond=0) new_plan_billing_cycle_anchor = current_plan.end_date.replace(microsecond=0)
@ -2338,7 +2342,7 @@ class BillingSession(ABC):
automanage_licenses=current_plan.automanage_licenses, automanage_licenses=current_plan.automanage_licenses,
charge_automatically=current_plan.charge_automatically, charge_automatically=current_plan.charge_automatically,
price_per_license=new_price_per_license, price_per_license=new_price_per_license,
discount=current_plan.customer.default_discount, discount=discount_for_new_plan_tier,
billing_schedule=current_plan.billing_schedule, billing_schedule=current_plan.billing_schedule,
tier=new_plan_tier, tier=new_plan_tier,
billing_cycle_anchor=new_plan_billing_cycle_anchor, billing_cycle_anchor=new_plan_billing_cycle_anchor,

View File

@ -1,3 +1,4 @@
from decimal import Decimal
from enum import Enum from enum import Enum
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Optional, Union
@ -61,6 +62,11 @@ class Customer(models.Model):
else: else:
return f"{self.remote_server!r} (with stripe_customer_id: {self.stripe_customer_id})" return f"{self.remote_server!r} (with stripe_customer_id: {self.stripe_customer_id})"
def get_discount_for_plan_tier(self, plan_tier: int) -> Optional[Decimal]:
if self.required_plan_tier is None or self.required_plan_tier == plan_tier:
return self.default_discount
return None # nocoverage
def get_customer_by_realm(realm: Realm) -> Optional[Customer]: def get_customer_by_realm(realm: Realm) -> Optional[Customer]:
return Customer.objects.filter(realm=realm).first() return Customer.objects.filter(realm=realm).first()