billing: Add command for switching plans from Standard to Plus.

This commit is contained in:
Vishnu KS 2021-09-22 00:51:03 +05:30 committed by Tim Abbott
parent 87c1b9e3bc
commit fcab2ea5f7
10 changed files with 214 additions and 8 deletions

View File

@ -376,12 +376,11 @@ def make_end_of_cycle_updates_if_needed(
assert last_ledger_renewal is not None
last_renewal = last_ledger_renewal.event_time
if plan.is_free_trial():
if plan.is_free_trial() or plan.status == CustomerPlan.SWITCH_NOW_FROM_STANDARD_TO_PLUS:
assert plan.next_invoice_date is not None
next_billing_cycle = plan.next_invoice_date
else:
next_billing_cycle = start_of_next_billing_cycle(plan, last_renewal)
if next_billing_cycle <= event_time and last_ledger_entry is not None:
licenses_at_next_renewal = last_ledger_entry.licenses_at_next_renewal
assert licenses_at_next_renewal is not None
@ -457,6 +456,47 @@ def make_end_of_cycle_updates_if_needed(
)
return new_plan, new_plan_ledger_entry
if plan.status == CustomerPlan.SWITCH_NOW_FROM_STANDARD_TO_PLUS:
standard_plan = plan
standard_plan.end_date = next_billing_cycle
standard_plan.status = CustomerPlan.ENDED
standard_plan.save(update_fields=["status", "end_date"])
(_, _, _, plus_plan_price_per_license) = compute_plan_parameters(
CustomerPlan.PLUS,
standard_plan.automanage_licenses,
standard_plan.billing_schedule,
standard_plan.customer.default_discount,
)
plus_plan_billing_cycle_anchor = standard_plan.end_date.replace(microsecond=0)
plus_plan = CustomerPlan.objects.create(
customer=standard_plan.customer,
status=CustomerPlan.ACTIVE,
automanage_licenses=standard_plan.automanage_licenses,
charge_automatically=standard_plan.charge_automatically,
price_per_license=plus_plan_price_per_license,
discount=standard_plan.customer.default_discount,
billing_schedule=standard_plan.billing_schedule,
tier=CustomerPlan.PLUS,
billing_cycle_anchor=plus_plan_billing_cycle_anchor,
invoicing_status=CustomerPlan.INITIAL_INVOICE_TO_BE_SENT,
next_invoice_date=plus_plan_billing_cycle_anchor,
)
standard_plan_last_ledger = (
LicenseLedger.objects.filter(plan=standard_plan).order_by("id").last()
)
licenses_for_plus_plan = standard_plan_last_ledger.licenses_at_next_renewal
plus_plan_ledger_entry = LicenseLedger.objects.create(
plan=plus_plan,
is_renewal=True,
event_time=plus_plan_billing_cycle_anchor,
licenses=licenses_for_plus_plan,
licenses_at_next_renewal=licenses_for_plus_plan,
)
return plus_plan, plus_plan_ledger_entry
if plan.status == CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE:
process_downgrade(plan)
return None, None
@ -743,6 +783,14 @@ def update_license_ledger_if_needed(realm: Realm, event_time: datetime) -> None:
update_license_ledger_for_automanaged_plan(realm, plan, event_time)
def get_plan_renewal_or_end_date(plan: CustomerPlan, event_time: datetime) -> datetime:
billing_period_end = start_of_next_billing_cycle(plan, event_time)
if plan.end_date is not None and plan.end_date < billing_period_end:
return plan.end_date
return billing_period_end
def invoice_plan(plan: CustomerPlan, event_time: datetime) -> None:
if plan.invoicing_status == CustomerPlan.STARTED:
raise NotImplementedError("Plan with invoicing_status==STARTED needs manual resolution.")
@ -777,7 +825,7 @@ def invoice_plan(plan: CustomerPlan, event_time: datetime) -> None:
}
description = f"{plan.name} - renewal"
elif licenses_base is not None and ledger_entry.licenses != licenses_base:
assert plan.price_per_license and ledger_entry is not None
assert plan.price_per_license
last_ledger_entry_renewal = (
LicenseLedger.objects.filter(
plan=plan, is_renewal=True, event_time__lte=ledger_entry.event_time
@ -787,16 +835,18 @@ def invoice_plan(plan: CustomerPlan, event_time: datetime) -> None:
)
assert last_ledger_entry_renewal is not None
last_renewal = last_ledger_entry_renewal.event_time
period_end = start_of_next_billing_cycle(plan, ledger_entry.event_time)
proration_fraction = (period_end - ledger_entry.event_time) / (
period_end - last_renewal
billing_period_end = start_of_next_billing_cycle(plan, ledger_entry.event_time)
plan_renewal_or_end_date = get_plan_renewal_or_end_date(plan, ledger_entry.event_time)
proration_fraction = (plan_renewal_or_end_date - ledger_entry.event_time) / (
billing_period_end - last_renewal
)
price_args = {
"unit_amount": int(plan.price_per_license * proration_fraction + 0.5),
"quantity": ledger_entry.licenses - licenses_base,
}
description = "Additional license ({} - {})".format(
ledger_entry.event_time.strftime("%b %-d, %Y"), period_end.strftime("%b %-d, %Y")
ledger_entry.event_time.strftime("%b %-d, %Y"),
plan_renewal_or_end_date.strftime("%b %-d, %Y"),
)
if price_args:
@ -811,7 +861,7 @@ def invoice_plan(plan: CustomerPlan, event_time: datetime) -> None:
period={
"start": datetime_to_timestamp(ledger_entry.event_time),
"end": datetime_to_timestamp(
start_of_next_billing_cycle(plan, ledger_entry.event_time)
get_plan_renewal_or_end_date(plan, ledger_entry.event_time)
),
},
idempotency_key=get_idempotency_key(ledger_entry),
@ -1071,6 +1121,52 @@ def downgrade_small_realms_behind_on_payments_as_needed() -> None:
void_all_open_invoices(realm)
def switch_realm_from_standard_to_plus_plan(realm: Realm) -> None:
standard_plan = get_current_plan_by_realm(realm)
if (
not standard_plan
or standard_plan.status != CustomerPlan.ACTIVE
or standard_plan.tier != CustomerPlan.STANDARD
):
raise BillingError("Organization does not have an active Standard plan")
if not standard_plan.customer.stripe_customer_id:
raise BillingError("Organization missing Stripe customer.")
plan_switch_time = timezone_now()
standard_plan.status = CustomerPlan.SWITCH_NOW_FROM_STANDARD_TO_PLUS
standard_plan.next_invoice_date = plan_switch_time
standard_plan.save(update_fields=["status", "next_invoice_date"])
standard_plan_next_renewal_date = start_of_next_billing_cycle(standard_plan, plan_switch_time)
standard_plan_last_renewal_ledger = (
LicenseLedger.objects.filter(is_renewal=True, plan=standard_plan).order_by("id").last()
)
standard_plan_last_renewal_amount = (
standard_plan_last_renewal_ledger.licenses * standard_plan.price_per_license
)
standard_plan_last_renewal_date = standard_plan_last_renewal_ledger.event_time
unused_proration_fraction = 1 - (plan_switch_time - standard_plan_last_renewal_date) / (
standard_plan_next_renewal_date - standard_plan_last_renewal_date
)
amount_to_credit_back_to_realm = math.ceil(
standard_plan_last_renewal_amount * unused_proration_fraction
)
stripe.Customer.create_balance_transaction(
standard_plan.customer.stripe_customer_id,
amount=-1 * amount_to_credit_back_to_realm,
currency="usd",
description="Credit from early termination of Standard plan",
)
invoice_plan(standard_plan, plan_switch_time)
plus_plan = get_current_plan_by_realm(realm)
assert plus_plan is not None # for mypy
invoice_plan(plus_plan, plan_switch_time)
def update_billing_method_of_current_plan(
realm: Realm, charge_automatically: bool, *, acting_user: Optional[UserProfile]
) -> None:

View File

@ -0,0 +1,18 @@
# Generated by Django 3.2.6 on 2021-09-17 10:52
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("corporate", "0013_alter_zulipsponsorshiprequest_org_website"),
]
operations = [
migrations.AddField(
model_name="customerplan",
name="end_date",
field=models.DateTimeField(null=True),
),
]

View File

@ -85,6 +85,7 @@ class CustomerPlan(models.Model):
invoiced_through: Optional["LicenseLedger"] = models.ForeignKey(
"LicenseLedger", null=True, on_delete=CASCADE, related_name="+"
)
end_date: Optional[datetime.datetime] = models.DateTimeField(null=True)
DONE = 1
STARTED = 2
@ -103,6 +104,7 @@ class CustomerPlan(models.Model):
DOWNGRADE_AT_END_OF_CYCLE = 2
FREE_TRIAL = 3
SWITCH_TO_ANNUAL_AT_END_OF_CYCLE = 4
SWITCH_NOW_FROM_STANDARD_TO_PLUS = 5
# "Live" plans should have a value < LIVE_STATUS_THRESHOLD.
# There should be at most one live plan per customer.
LIVE_STATUS_THRESHOLD = 10

View File

@ -50,6 +50,7 @@ from corporate.lib.stripe import (
downgrade_small_realms_behind_on_payments_as_needed,
get_discount_for_realm,
get_latest_seat_count,
get_plan_renewal_or_end_date,
get_price_per_license,
get_realms_to_default_discount_dict,
invoice_plan,
@ -62,6 +63,7 @@ from corporate.lib.stripe import (
sign_string,
stripe_customer_has_credit_card_as_default_source,
stripe_get_customer,
switch_realm_from_standard_to_plus_plan,
unsign_string,
update_billing_method_of_current_plan,
update_license_ledger_for_automanaged_plan,
@ -297,6 +299,7 @@ MOCKED_STRIPE_FUNCTION_NAMES = [
"Charge.list",
"Coupon.create",
"Customer.create",
"Customer.create_balance_transaction",
"Customer.retrieve",
"Customer.save",
"Invoice.create",
@ -2977,6 +2980,44 @@ class StripeTest(StripeTestCase):
email_found = True
self.assertEqual(row.email_expected_to_be_sent, email_found)
@mock_stripe()
def test_switch_realm_from_standard_to_plus_plan(self, *mock: Mock) -> None:
realm = get_realm("zulip")
# Test upgrading to Plus when realm has no Standard subscription
with self.assertRaises(BillingError) as billing_context:
switch_realm_from_standard_to_plus_plan(realm)
self.assertEqual(
"Organization does not have an active Standard plan",
billing_context.exception.error_description,
)
plan, ledger = self.subscribe_realm_to_manual_license_management_plan(
realm, 9, 9, CustomerPlan.MONTHLY
)
# Test upgrading to Plus when realm has no stripe_customer_id
with self.assertRaises(BillingError) as billing_context:
switch_realm_from_standard_to_plus_plan(realm)
self.assertEqual(
"Organization missing Stripe customer.", billing_context.exception.error_description
)
plan.customer.stripe_customer_id = "cus_12345"
plan.customer.save(update_fields=["stripe_customer_id"])
plan.price_per_license = get_price_per_license(CustomerPlan.STANDARD, CustomerPlan.MONTHLY)
plan.automanage_licenses = True
plan.invoiced_through = ledger
plan.save(update_fields=["price_per_license", "automanage_licenses", "invoiced_through"])
switch_realm_from_standard_to_plus_plan(realm)
plan.refresh_from_db()
self.assertEqual(plan.status, CustomerPlan.ENDED)
plus_plan = get_current_plan_by_realm(realm)
assert plus_plan is not None
self.assertEqual(plus_plan.tier, CustomerPlan.PLUS)
self.assertEqual(LicenseLedger.objects.filter(plan=plus_plan).count(), 1)
def test_update_billing_method_of_current_plan(self) -> None:
realm = get_realm("zulip")
customer = Customer.objects.create(realm=realm, stripe_customer_id="cus_12345")
@ -3239,6 +3280,28 @@ class BillingHelpersTest(ZulipTestCase):
with self.assertRaisesRegex(InvalidTier, "Unknown tier: 10"):
get_price_per_license(CustomerPlan.ENTERPRISE, CustomerPlan.ANNUAL)
def test_get_plan_renewal_or_end_date(self) -> None:
realm = get_realm("zulip")
customer = Customer.objects.create(realm=realm, stripe_customer_id="cus_12345")
billing_cycle_anchor = timezone_now()
plan = CustomerPlan.objects.create(
customer=customer,
status=CustomerPlan.ACTIVE,
billing_cycle_anchor=billing_cycle_anchor,
billing_schedule=CustomerPlan.MONTHLY,
tier=CustomerPlan.STANDARD,
)
renewal_date = get_plan_renewal_or_end_date(plan, billing_cycle_anchor)
self.assertEqual(renewal_date, add_months(billing_cycle_anchor, 1))
# When the plan ends 2 days before the start of the next billing cycle,
# the function should return the end_date.
plan_end_date = add_months(billing_cycle_anchor, 1) - timedelta(days=2)
plan.end_date = plan_end_date
plan.save(update_fields=["end_date"])
renewal_date = get_plan_renewal_or_end_date(plan, billing_cycle_anchor)
self.assertEqual(renewal_date, plan_end_date)
def test_update_or_create_stripe_customer_logic(self) -> None:
user = self.example_user("hamlet")
# No existing Customer object

View File

@ -39,6 +39,10 @@ class Customer:
def delete_discount(customer: Customer) -> None: ...
@staticmethod
def list(limit: Optional[int] = ...) -> List[Customer]: ...
@staticmethod
def create_balance_transaction(
customer_id: str, amount: int, currency: str, description: str
) -> None: ...
class Invoice:
id: str

View File

@ -0,0 +1,23 @@
from typing import Any
from django.conf import settings
from django.core.management.base import CommandError, CommandParser
from zerver.lib.management import ZulipBaseCommand
if settings.BILLING_ENABLED:
from corporate.lib.stripe import switch_realm_from_standard_to_plus_plan
class Command(ZulipBaseCommand):
def add_arguments(self, parser: CommandParser) -> None:
self.add_realm_args(parser)
def handle(self, *args: Any, **options: Any) -> None:
realm = self.get_realm(options)
if not realm:
raise CommandError("No realm found.")
if settings.BILLING_ENABLED:
switch_realm_from_standard_to_plus_plan(realm)