billing: Do subscription management in-house instead of with Stripe Billing.

This is a major rewrite of the billing system. It moves subscription
information off of stripe Subscriptions and into a local CustomerPlan
table.

To keep this manageable, it leaves several things unimplemented
(downgrading, etc), and a variety of other TODOs in the code. There are also
some known regressions, e.g. error-handling on /upgrade is broken.
This commit is contained in:
Rishi Gupta 2018-12-15 00:33:25 -08:00
parent 5633049292
commit e7220fd71f
107 changed files with 654 additions and 456 deletions

View File

@ -495,21 +495,12 @@ def realm_summary_table(realm_minutes: Dict[str, float]) -> str:
# estimate annual subscription revenue # estimate annual subscription revenue
total_amount = 0 total_amount = 0
if settings.BILLING_ENABLED: if settings.BILLING_ENABLED:
from corporate.lib.stripe import estimate_customer_arr from corporate.lib.stripe import estimate_annual_recurring_revenue_by_realm
from corporate.models import Customer estimated_arrs = estimate_annual_recurring_revenue_by_realm()
stripe.api_key = get_secret('stripe_secret_key')
estimated_arr = {}
try:
for stripe_customer in stripe.Customer.list(limit=100):
# TODO: could do a select_related to get the realm.string_id, potentially
customer = Customer.objects.filter(stripe_customer_id=stripe_customer.id).first()
if customer is not None:
estimated_arr[customer.realm.string_id] = estimate_customer_arr(stripe_customer)
except stripe.error.StripeError:
pass
for row in rows: for row in rows:
row['amount'] = estimated_arr.get(row['string_id'], None) if row['string_id'] in estimated_arrs:
total_amount = sum(estimated_arr.values()) row['amount'] = estimated_arrs[row['string_id']]
total_amount += sum(estimated_arrs.values())
# augment data with realm_minutes # augment data with realm_minutes
total_hours = 0.0 total_hours = 0.0

View File

@ -1,9 +1,9 @@
import datetime from datetime import datetime
from decimal import Decimal from decimal import Decimal
from functools import wraps from functools import wraps
import logging import logging
import os import os
from typing import Any, Callable, Dict, Optional, TypeVar, Tuple from typing import Any, Callable, Dict, Optional, TypeVar, Tuple, cast
import ujson import ujson
from django.conf import settings from django.conf import settings
@ -19,7 +19,8 @@ from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime
from zerver.lib.utils import generate_random_token from zerver.lib.utils import generate_random_token
from zerver.lib.actions import do_change_plan_type from zerver.lib.actions import do_change_plan_type
from zerver.models import Realm, UserProfile, RealmAuditLog from zerver.models import Realm, UserProfile, RealmAuditLog
from corporate.models import Customer, CustomerPlan, Plan, Coupon from corporate.models import Customer, CustomerPlan, Plan, Coupon, \
get_active_plan
from zproject.settings import get_secret from zproject.settings import get_secret
STRIPE_PUBLISHABLE_KEY = get_secret('stripe_publishable_key') STRIPE_PUBLISHABLE_KEY = get_secret('stripe_publishable_key')
@ -50,6 +51,61 @@ def unsign_string(signed_string: str, salt: str) -> str:
signer = Signer(salt=salt) signer = Signer(salt=salt)
return signer.unsign(signed_string) return signer.unsign(signed_string)
# Be extremely careful changing this function. Historical billing periods
# are not stored anywhere, and are just computed on the fly using this
# function. Any change you make here should return the same value (or be
# within a few seconds) for basically any value from when the billing system
# went online to within a year from now.
def add_months(dt: datetime, months: int) -> datetime:
assert(months >= 0)
# It's fine that the max day in Feb is 28 for leap years.
MAX_DAY_FOR_MONTH = {1: 31, 2: 28, 3: 31, 4: 30, 5: 31, 6: 30,
7: 31, 8: 31, 9: 30, 10: 31, 11: 30, 12: 31}
year = dt.year
month = dt.month + months
while month > 12:
year += 1
month -= 12
day = min(dt.day, MAX_DAY_FOR_MONTH[month])
# datetimes don't support leap seconds, so don't need to worry about those
return dt.replace(year=year, month=month, day=day)
def next_month(billing_cycle_anchor: datetime, dt: datetime) -> datetime:
estimated_months = round((dt - billing_cycle_anchor).days * 12. / 365)
for months in range(max(estimated_months - 1, 0), estimated_months + 2):
proposed_next_month = add_months(billing_cycle_anchor, months)
if 20 < (proposed_next_month - dt).days < 40:
return proposed_next_month
raise AssertionError('Something wrong in next_month calculation with '
'billing_cycle_anchor: %s, dt: %s' % (billing_cycle_anchor, dt))
# TODO take downgrade into account
def next_renewal_date(plan: CustomerPlan) -> datetime:
months_per_period = {
CustomerPlan.ANNUAL: 12,
CustomerPlan.MONTHLY: 1,
}[plan.billing_schedule]
periods = 1
dt = plan.billing_cycle_anchor
while dt <= plan.billed_through:
dt = add_months(plan.billing_cycle_anchor, months_per_period * periods)
periods += 1
return dt
def renewal_amount(plan: CustomerPlan) -> int: # nocoverage: TODO
if plan.fixed_price is not None:
basis = plan.fixed_price
elif plan.automanage_licenses:
assert(plan.price_per_license is not None)
basis = plan.price_per_license * get_seat_count(plan.customer.realm)
else:
assert(plan.price_per_license is not None)
basis = plan.price_per_license * plan.licenses
if plan.discount is None:
return basis
# TODO: figure out right thing to do with Decimal
return int(float(basis * (100 - plan.discount) / 100) + .00001)
class BillingError(Exception): class BillingError(Exception):
# error messages # error messages
CONTACT_SUPPORT = _("Something went wrong. Please contact %s." % (settings.ZULIP_ADMINISTRATOR,)) CONTACT_SUPPORT = _("Something went wrong. Please contact %s." % (settings.ZULIP_ADMINISTRATOR,))
@ -73,9 +129,6 @@ def catch_stripe_errors(func: CallableT) -> CallableT:
if STRIPE_PUBLISHABLE_KEY is None: if STRIPE_PUBLISHABLE_KEY is None:
raise BillingError('missing stripe config', "Missing Stripe config. " raise BillingError('missing stripe config', "Missing Stripe config. "
"See https://zulip.readthedocs.io/en/latest/subsystems/billing.html.") "See https://zulip.readthedocs.io/en/latest/subsystems/billing.html.")
if not Plan.objects.exists():
raise BillingError('missing plans',
"Plan objects not created. Please run ./manage.py setup_stripe")
try: try:
return func(*args, **kwargs) return func(*args, **kwargs)
# See https://stripe.com/docs/api/python#error_handling, though # See https://stripe.com/docs/api/python#error_handling, though
@ -101,38 +154,7 @@ def stripe_get_customer(stripe_customer_id: str) -> stripe.Customer:
return stripe.Customer.retrieve(stripe_customer_id, expand=["default_source"]) return stripe.Customer.retrieve(stripe_customer_id, expand=["default_source"])
@catch_stripe_errors @catch_stripe_errors
def stripe_get_upcoming_invoice(stripe_customer_id: str) -> stripe.Invoice: def do_create_customer(user: UserProfile, stripe_token: Optional[str]=None) -> Customer:
return stripe.Invoice.upcoming(customer=stripe_customer_id)
# This allows us to access /billing in tests without having to mock the
# whole invoice object
def upcoming_invoice_total(stripe_customer_id: str) -> int:
return stripe_get_upcoming_invoice(stripe_customer_id).total
# Return type should be Optional[stripe.Subscription], which throws a mypy error.
# Will fix once we add type stubs for the Stripe API.
def extract_current_subscription(stripe_customer: stripe.Customer) -> Any:
if not stripe_customer.subscriptions:
return None
for stripe_subscription in stripe_customer.subscriptions:
if stripe_subscription.status != "canceled":
return stripe_subscription
def estimate_customer_arr(stripe_customer: stripe.Customer) -> int: # nocoverage
stripe_subscription = extract_current_subscription(stripe_customer)
if stripe_subscription is None:
return 0
# This is an overestimate for those paying by invoice
estimated_arr = stripe_subscription.plan.amount * stripe_subscription.quantity / 100.
if stripe_subscription.plan.interval == 'month':
estimated_arr *= 12
discount = Customer.objects.get(stripe_customer_id=stripe_customer.id).default_discount
if discount is not None:
estimated_arr *= 1 - discount/100.
return int(estimated_arr)
@catch_stripe_errors
def do_create_customer(user: UserProfile, stripe_token: Optional[str]=None) -> stripe.Customer:
realm = user.realm realm = user.realm
# We could do a better job of handling race conditions here, but if two # We could do a better job of handling race conditions here, but if two
# people from a realm try to upgrade at exactly the same time, the main # people from a realm try to upgrade at exactly the same time, the main
@ -152,10 +174,10 @@ def do_create_customer(user: UserProfile, stripe_token: Optional[str]=None) -> s
RealmAuditLog.objects.create( RealmAuditLog.objects.create(
realm=user.realm, acting_user=user, event_type=RealmAuditLog.STRIPE_CARD_CHANGED, realm=user.realm, acting_user=user, event_type=RealmAuditLog.STRIPE_CARD_CHANGED,
event_time=event_time) event_time=event_time)
Customer.objects.create(realm=realm, stripe_customer_id=stripe_customer.id) customer = Customer.objects.create(realm=realm, stripe_customer_id=stripe_customer.id)
user.is_billing_admin = True user.is_billing_admin = True
user.save(update_fields=["is_billing_admin"]) user.save(update_fields=["is_billing_admin"])
return stripe_customer return customer
@catch_stripe_errors @catch_stripe_errors
def do_replace_payment_source(user: UserProfile, stripe_token: str) -> stripe.Customer: def do_replace_payment_source(user: UserProfile, stripe_token: str) -> stripe.Customer:
@ -170,96 +192,154 @@ def do_replace_payment_source(user: UserProfile, stripe_token: str) -> stripe.Cu
event_time=timezone_now()) event_time=timezone_now())
return updated_stripe_customer return updated_stripe_customer
# Returns Customer instead of stripe_customer so that we don't make a Stripe
# API call if there's nothing to update
def update_or_create_stripe_customer(user: UserProfile, stripe_token: Optional[str]=None) -> Customer:
realm = user.realm
customer = Customer.objects.filter(realm=realm).first()
if customer is None:
return do_create_customer(user, stripe_token=stripe_token)
if stripe_token is not None:
do_replace_payment_source(user, stripe_token)
return customer
def compute_plan_parameters(
automanage_licenses: bool, billing_schedule: int,
discount: Optional[Decimal]) -> Tuple[datetime, datetime, datetime, int]:
# Everything in Stripe is stored as timestamps with 1 second resolution,
# so standardize on 1 second resolution.
# TODO talk about leapseconds?
billing_cycle_anchor = timezone_now().replace(microsecond=0)
if billing_schedule == CustomerPlan.ANNUAL:
# TODO use variables to account for Zulip Plus
price_per_license = 8000
period_end = add_months(billing_cycle_anchor, 12)
elif billing_schedule == CustomerPlan.MONTHLY:
price_per_license = 800
period_end = add_months(billing_cycle_anchor, 1)
else:
raise AssertionError('Unknown billing_schedule: {}'.format(billing_schedule))
if discount is not None:
# There are no fractional cents in Stripe, so round down to nearest integer.
price_per_license = int(float(price_per_license * (1 - discount / 100)) + .00001)
next_billing_date = period_end
if automanage_licenses:
next_billing_date = add_months(billing_cycle_anchor, 1)
return billing_cycle_anchor, next_billing_date, period_end, price_per_license
# Only used for cloud signups
@catch_stripe_errors @catch_stripe_errors
def do_subscribe_customer_to_plan(user: UserProfile, stripe_customer: stripe.Customer, stripe_plan_id: str, def process_initial_upgrade(user: UserProfile, licenses: int, automanage_licenses: bool,
seat_count: int, tax_percent: float, charge_automatically: bool) -> None: billing_schedule: int, stripe_token: Optional[str]) -> None:
if extract_current_subscription(stripe_customer) is not None: # nocoverage realm = user.realm
customer = update_or_create_stripe_customer(user, stripe_token=stripe_token)
# TODO write a test for this
if CustomerPlan.objects.filter(customer=customer, status=CustomerPlan.ACTIVE).exists(): # nocoverage
# Unlikely race condition from two people upgrading (clicking "Make payment") # Unlikely race condition from two people upgrading (clicking "Make payment")
# at exactly the same time. Doesn't fully resolve the race condition, but having # at exactly the same time. Doesn't fully resolve the race condition, but having
# a check here reduces the likelihood. # a check here reduces the likelihood.
billing_logger.error("Stripe customer %s trying to subscribe to %s, " billing_logger.warning(
"but has an active subscription" % (stripe_customer.id, stripe_plan_id)) "Customer {} trying to upgrade, but has an active subscription".format(customer))
raise BillingError('subscribing with existing subscription', BillingError.TRY_RELOADING) raise BillingError('subscribing with existing subscription', BillingError.TRY_RELOADING)
customer = Customer.objects.get(stripe_customer_id=stripe_customer.id)
billing_cycle_anchor, next_billing_date, period_end, price_per_license = compute_plan_parameters(
automanage_licenses, billing_schedule, customer.default_discount)
# The main design constraint in this function is that if you upgrade with a credit card, and the
# charge fails, everything should be rolled back as if nothing had happened. This is because we
# expect frequent card failures on initial signup.
# Hence, if we're going to charge a card, do it at the beginning, even if we later may have to
# adjust the number of licenses.
charge_automatically = stripe_token is not None
if charge_automatically:
stripe_charge = stripe.Charge.create(
amount=price_per_license * licenses,
currency='usd',
customer=customer.stripe_customer_id,
description="Upgrade to Zulip Standard, ${} x {}".format(price_per_license/100, licenses),
receipt_email=user.email,
statement_descriptor='Zulip Standard')
# Not setting a period start and end, but maybe we should? Unclear what will make things
# most similar to the renewal case from an accounting perspective.
stripe.InvoiceItem.create(
amount=price_per_license * licenses * -1,
currency='usd',
customer=customer.stripe_customer_id,
description="Payment (Card ending in {})".format(cast(stripe.Card, stripe_charge.source).last4),
discountable=False)
# TODO: The correctness of this relies on user creation, deactivation, etc being
# in a transaction.atomic() with the relevant RealmAuditLog entries
with transaction.atomic():
# billed_licenses can greater than licenses if users are added between the start of
# this function (process_initial_upgrade) and now
billed_licenses = max(get_seat_count(realm), licenses)
plan_params = {
'licenses': billed_licenses,
'automanage_licenses': automanage_licenses,
'charge_automatically': charge_automatically,
'price_per_license': price_per_license,
'discount': customer.default_discount,
'billing_cycle_anchor': billing_cycle_anchor,
'billing_schedule': billing_schedule,
'tier': CustomerPlan.STANDARD}
CustomerPlan.objects.create(
customer=customer,
billed_through=billing_cycle_anchor,
next_billing_date=next_billing_date,
**plan_params)
RealmAuditLog.objects.create(
realm=realm, acting_user=user, event_time=billing_cycle_anchor,
event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED,
# TODO: add tests for licenses
# Only 'licenses' is guaranteed to be useful to automated tools. The other extra_data
# fields can change in the future and are only meant to assist manual debugging.
extra_data=ujson.dumps(plan_params))
description = 'Zulip Standard'
if customer.default_discount is not None: # nocoverage: TODO
description += ' (%s%% off)' % (customer.default_discount,)
stripe.InvoiceItem.create(
currency='usd',
customer=customer.stripe_customer_id,
description=description,
discountable=False,
period = {'start': datetime_to_timestamp(billing_cycle_anchor),
'end': datetime_to_timestamp(period_end)},
quantity=billed_licenses,
unit_amount=price_per_license)
if charge_automatically: if charge_automatically:
billing_method = 'charge_automatically' billing_method = 'charge_automatically'
days_until_due = None days_until_due = None
else: else:
billing_method = 'send_invoice' billing_method = 'send_invoice'
days_until_due = DEFAULT_INVOICE_DAYS_UNTIL_DUE days_until_due = DEFAULT_INVOICE_DAYS_UNTIL_DUE
# Note that there is a race condition here, where if two users upgrade at exactly the stripe_invoice = stripe.Invoice.create(
# same time, they will have two subscriptions, and get charged twice. We could try to auto_advance=True,
# reduce the chance of it with a well-designed idempotency_key, but it's not easy since
# we also need to be careful not to block the customer from retrying if their
# subscription attempt fails (e.g. due to insufficient funds).
# Success here implies the stripe_customer was charged: https://stripe.com/docs/billing/lifecycle#active
# Otherwise we should expect it to throw a stripe.error.
stripe_subscription = stripe.Subscription.create(
customer=stripe_customer.id,
billing=billing_method, billing=billing_method,
customer=customer.stripe_customer_id,
days_until_due=days_until_due, days_until_due=days_until_due,
items=[{ statement_descriptor='Zulip Standard')
'plan': stripe_plan_id, stripe.Invoice.finalize_invoice(stripe_invoice)
'quantity': seat_count,
}],
prorate=True,
tax_percent=tax_percent)
with transaction.atomic():
customer.has_billing_relationship = True
customer.save(update_fields=['has_billing_relationship'])
customer.realm.has_seat_based_plan = True
customer.realm.save(update_fields=['has_seat_based_plan'])
RealmAuditLog.objects.create(
realm=customer.realm,
acting_user=user,
event_type=RealmAuditLog.STRIPE_PLAN_CHANGED,
event_time=timestamp_to_datetime(stripe_subscription.created),
extra_data=ujson.dumps({'plan': stripe_plan_id, 'quantity': seat_count,
'billing_method': billing_method}))
current_seat_count = get_seat_count(customer.realm) do_change_plan_type(realm, Realm.STANDARD)
if seat_count != current_seat_count:
RealmAuditLog.objects.create(
realm=customer.realm,
event_type=RealmAuditLog.STRIPE_PLAN_QUANTITY_RESET,
event_time=timestamp_to_datetime(stripe_subscription.created),
requires_billing_update=True,
extra_data=ujson.dumps({'quantity': current_seat_count}))
def process_initial_upgrade(user: UserProfile, seat_count: int, schedule: int,
stripe_token: Optional[str]) -> None:
if schedule == CustomerPlan.ANNUAL:
plan = Plan.objects.get(nickname=Plan.CLOUD_ANNUAL)
else: # schedule == CustomerPlan.MONTHLY:
plan = Plan.objects.get(nickname=Plan.CLOUD_MONTHLY)
customer = Customer.objects.filter(realm=user.realm).first()
if customer is None:
stripe_customer = do_create_customer(user, stripe_token=stripe_token)
# elif instead of if since we want to avoid doing two round trips to
# stripe if we can
elif stripe_token is not None:
stripe_customer = do_replace_payment_source(user, stripe_token)
else:
stripe_customer = stripe_get_customer(customer.stripe_customer_id)
do_subscribe_customer_to_plan(
user=user,
stripe_customer=stripe_customer,
stripe_plan_id=plan.stripe_plan_id,
seat_count=seat_count,
# TODO: billing address details are passed to us in the request;
# use that to calculate taxes.
tax_percent=0,
charge_automatically=(stripe_token is not None))
do_change_plan_type(user.realm, Realm.STANDARD)
def attach_discount_to_realm(user: UserProfile, discount: Decimal) -> None: def attach_discount_to_realm(user: UserProfile, discount: Decimal) -> None:
customer = Customer.objects.filter(realm=user.realm).first() customer = Customer.objects.filter(realm=user.realm).first()
if customer is None: if customer is None:
do_create_customer(user) customer = do_create_customer(user)
customer = Customer.objects.filter(realm=user.realm).first()
customer.default_discount = discount customer.default_discount = discount
customer.save() customer.save()
def process_downgrade(user: UserProfile) -> None: # nocoverage def process_downgrade(user: UserProfile) -> None: # nocoverage
pass pass
def estimate_annual_recurring_revenue_by_realm() -> Dict[str, int]: # nocoverage
annual_revenue = {}
for plan in CustomerPlan.objects.filter(
status=CustomerPlan.ACTIVE).select_related('customer__realm'):
renewal_cents = renewal_amount(plan)
if plan.billing_schedule == CustomerPlan.MONTHLY:
renewal_cents *= 12
# TODO: Decimal stuff
annual_revenue[plan.customer.realm.string_id] = int(renewal_cents / 100)
return annual_revenue

View File

@ -1,58 +0,0 @@
from corporate.models import Plan, Coupon, Customer
from django.conf import settings
from zerver.lib.management import ZulipBaseCommand
from zproject.settings import get_secret
from typing import Any
import stripe
stripe.api_key = get_secret('stripe_secret_key')
class Command(ZulipBaseCommand):
help = """Script to add the appropriate products and plans to Stripe."""
def handle(self, *args: Any, **options: Any) -> None:
assert (settings.DEVELOPMENT or settings.TEST_SUITE)
Customer.objects.all().delete()
Plan.objects.all().delete()
Coupon.objects.all().delete()
# Zulip Cloud offerings
product = stripe.Product.create(
name="Zulip Cloud Standard",
type='service',
statement_descriptor="Zulip Cloud Standard",
unit_label="user")
plan = stripe.Plan.create(
currency='usd',
interval='month',
product=product.id,
amount=800,
billing_scheme='per_unit',
nickname=Plan.CLOUD_MONTHLY,
usage_type='licensed')
Plan.objects.create(nickname=Plan.CLOUD_MONTHLY, stripe_plan_id=plan.id)
plan = stripe.Plan.create(
currency='usd',
interval='year',
product=product.id,
amount=8000,
billing_scheme='per_unit',
nickname=Plan.CLOUD_ANNUAL,
usage_type='licensed')
Plan.objects.create(nickname=Plan.CLOUD_ANNUAL, stripe_plan_id=plan.id)
coupon = stripe.Coupon.create(
duration='forever',
name='25% discount',
percent_off=25)
Coupon.objects.create(percent_off=25, stripe_coupon_id=coupon.id)
coupon = stripe.Coupon.create(
duration='forever',
name='85% discount',
percent_off=85)
Coupon.objects.create(percent_off=85, stripe_coupon_id=coupon.id)

View File

@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.16 on 2018-12-22 21:05
from __future__ import unicode_literals
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('corporate', '0002_customer_default_discount'),
]
operations = [
migrations.CreateModel(
name='CustomerPlan',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('licenses', models.IntegerField()),
('automanage_licenses', models.BooleanField(default=False)),
('charge_automatically', models.BooleanField(default=False)),
('price_per_license', models.IntegerField(null=True)),
('fixed_price', models.IntegerField(null=True)),
('discount', models.DecimalField(decimal_places=4, max_digits=6, null=True)),
('billing_cycle_anchor', models.DateTimeField()),
('billing_schedule', models.SmallIntegerField()),
('billed_through', models.DateTimeField()),
('next_billing_date', models.DateTimeField(db_index=True)),
('tier', models.SmallIntegerField()),
('status', models.SmallIntegerField(default=1)),
('customer', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='corporate.Customer')),
],
),
]

View File

@ -9,17 +9,52 @@ from zerver.models import Realm, RealmAuditLog
class Customer(models.Model): class Customer(models.Model):
realm = models.OneToOneField(Realm, on_delete=models.CASCADE) # type: Realm realm = models.OneToOneField(Realm, on_delete=models.CASCADE) # type: Realm
stripe_customer_id = models.CharField(max_length=255, unique=True) # type: str stripe_customer_id = models.CharField(max_length=255, unique=True) # type: str
# Becomes True the first time a payment successfully goes through, and never # Deprecated .. delete once everyone is migrated to new billing system
# goes back to being False
has_billing_relationship = models.BooleanField(default=False) # type: bool has_billing_relationship = models.BooleanField(default=False) # type: bool
default_discount = models.DecimalField(decimal_places=4, max_digits=7, null=True) # type: Optional[Decimal] default_discount = models.DecimalField(decimal_places=4, max_digits=7, null=True) # type: Optional[Decimal]
def __str__(self) -> str: def __str__(self) -> str:
return "<Customer %s %s>" % (self.realm, self.stripe_customer_id) return "<Customer %s %s>" % (self.realm, self.stripe_customer_id)
class CustomerPlan(object): class CustomerPlan(models.Model):
customer = models.ForeignKey(Customer, on_delete=models.CASCADE) # type: Customer
licenses = models.IntegerField() # type: int
automanage_licenses = models.BooleanField(default=False) # type: bool
charge_automatically = models.BooleanField(default=False) # type: bool
# Both of these are in cents. Exactly one of price_per_license or
# fixed_price should be set. fixed_price is only for manual deals, and
# can't be set via the self-serve billing system.
price_per_license = models.IntegerField(null=True) # type: Optional[int]
fixed_price = models.IntegerField(null=True) # type: Optional[int]
# A percentage, like 85
discount = models.DecimalField(decimal_places=4, max_digits=6, null=True) # type: Optional[Decimal]
billing_cycle_anchor = models.DateTimeField() # type: datetime.datetime
ANNUAL = 1 ANNUAL = 1
MONTHLY = 2 MONTHLY = 2
billing_schedule = models.SmallIntegerField() # type: int
# This is like analytic's FillState, but for billing
billed_through = models.DateTimeField() # type: datetime.datetime
next_billing_date = models.DateTimeField(db_index=True) # type: datetime.datetime
STANDARD = 1
PLUS = 2 # not available through self-serve signup
ENTERPRISE = 10
tier = models.SmallIntegerField() # type: int
ACTIVE = 1
ENDED = 2
NEVER_STARTED = 3
# You can only have 1 active subscription at a time
status = models.SmallIntegerField(default=ACTIVE) # type: int
# TODO maybe override setattr to ensure billing_cycle_anchor, etc are immutable
def get_active_plan(customer: Customer) -> Optional[CustomerPlan]:
return CustomerPlan.objects.filter(customer=customer, status=CustomerPlan.ACTIVE).first()
# Everything below here is legacy # Everything below here is legacy

View File

@ -1,4 +1,4 @@
import datetime from datetime import datetime
from decimal import Decimal from decimal import Decimal
from functools import wraps from functools import wraps
from mock import Mock, patch from mock import Mock, patch
@ -24,11 +24,12 @@ from zerver.lib.actions import do_deactivate_user, do_create_user, \
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.timestamp import timestamp_to_datetime, datetime_to_timestamp from zerver.lib.timestamp import timestamp_to_datetime, datetime_to_timestamp
from zerver.models import Realm, UserProfile, get_realm, RealmAuditLog from zerver.models import Realm, UserProfile, get_realm, RealmAuditLog
from corporate.lib.stripe import catch_stripe_errors, \ from corporate.lib.stripe import catch_stripe_errors, attach_discount_to_realm, \
do_subscribe_customer_to_plan, attach_discount_to_realm, \ get_seat_count, sign_string, unsign_string, \
get_seat_count, extract_current_subscription, sign_string, unsign_string, \
BillingError, StripeCardError, StripeConnectionError, stripe_get_customer, \ BillingError, StripeCardError, StripeConnectionError, stripe_get_customer, \
DEFAULT_INVOICE_DAYS_UNTIL_DUE, MIN_INVOICED_LICENSES, do_create_customer DEFAULT_INVOICE_DAYS_UNTIL_DUE, MIN_INVOICED_LICENSES, do_create_customer, \
add_months, next_month, next_renewal_date, renewal_amount, \
compute_plan_parameters, update_or_create_stripe_customer
from corporate.models import Customer, CustomerPlan, Plan, Coupon from corporate.models import Customer, CustomerPlan, Plan, Coupon
from corporate.views import payment_method_string from corporate.views import payment_method_string
import corporate.urls import corporate.urls
@ -165,11 +166,11 @@ def normalize_fixture_data(decorated_function: CallableT,
f.write(file_content) f.write(file_content)
MOCKED_STRIPE_FUNCTION_NAMES = ["stripe.{}".format(name) for name in [ MOCKED_STRIPE_FUNCTION_NAMES = ["stripe.{}".format(name) for name in [
"Charge.list", "Charge.create", "Charge.list",
"Coupon.create", "Coupon.create",
"Customer.create", "Customer.retrieve", "Customer.save", "Customer.create", "Customer.retrieve", "Customer.save",
"Invoice.list", "Invoice.upcoming", "Invoice.create", "Invoice.finalize_invoice", "Invoice.list", "Invoice.upcoming",
"InvoiceItem.create", "InvoiceItem.create", "InvoiceItem.list",
"Plan.create", "Plan.create",
"Product.create", "Product.create",
"Subscription.create", "Subscription.delete", "Subscription.retrieve", "Subscription.save", "Subscription.create", "Subscription.delete", "Subscription.retrieve", "Subscription.save",
@ -205,14 +206,13 @@ def mock_stripe(tested_timestamp_fields: List[str]=[],
# A Kandra is a fictional character that can become anything. Used as a # A Kandra is a fictional character that can become anything. Used as a
# wildcard when testing for equality. # wildcard when testing for equality.
class Kandra(object): class Kandra(object): # nocoverage: TODO
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
return True return True
class StripeTest(ZulipTestCase): class StripeTest(ZulipTestCase):
@mock_stripe(generate=False)
def setUp(self, *mocks: Mock) -> None: def setUp(self, *mocks: Mock) -> None:
call_command("setup_stripe") # TODO
# Unfortunately this test suite is likely not robust to users being # Unfortunately this test suite is likely not robust to users being
# added in populate_db. A quick hack for now to ensure get_seat_count is 8 # added in populate_db. A quick hack for now to ensure get_seat_count is 8
# for these tests (8, since that's what it was when the tests were written). # for these tests (8, since that's what it was when the tests were written).
@ -229,6 +229,11 @@ class StripeTest(ZulipTestCase):
self.assertEqual(get_seat_count(get_realm('zulip')), 8) self.assertEqual(get_seat_count(get_realm('zulip')), 8)
self.seat_count = 8 self.seat_count = 8
self.signed_seat_count, self.salt = sign_string(str(self.seat_count)) self.signed_seat_count, self.salt = sign_string(str(self.seat_count))
# Choosing dates with corresponding timestamps below 1500000000 so that they are
# not caught by our timestamp normalization regex in normalize_fixture_data
self.now = datetime(2012, 1, 2, 3, 4, 5).replace(tzinfo=timezone_utc)
self.next_month = datetime(2012, 2, 2, 3, 4, 5).replace(tzinfo=timezone_utc)
self.next_year = datetime(2013, 1, 2, 3, 4, 5).replace(tzinfo=timezone_utc)
def get_signed_seat_count_from_response(self, response: HttpResponse) -> Optional[str]: def get_signed_seat_count_from_response(self, response: HttpResponse) -> Optional[str]:
match = re.search(r'name=\"signed_seat_count\" value=\"(.+)\"', response.content.decode("utf-8")) match = re.search(r'name=\"signed_seat_count\" value=\"(.+)\"', response.content.decode("utf-8"))
@ -242,7 +247,7 @@ class StripeTest(ZulipTestCase):
realm: Optional[Realm]=None, del_args: List[str]=[], realm: Optional[Realm]=None, del_args: List[str]=[],
**kwargs: Any) -> HttpResponse: **kwargs: Any) -> HttpResponse:
host_args = {} host_args = {}
if realm is not None: if realm is not None: # nocoverage: TODO
host_args['HTTP_HOST'] = realm.host host_args['HTTP_HOST'] = realm.host
response = self.client_get("/upgrade/", **host_args) response = self.client_get("/upgrade/", **host_args)
params = { params = {
@ -304,19 +309,19 @@ class StripeTest(ZulipTestCase):
self.assert_in_success_response(["Page not found (404)"], response) self.assert_in_success_response(["Page not found (404)"], response)
@mock_stripe(tested_timestamp_fields=["created"]) @mock_stripe(tested_timestamp_fields=["created"])
def test_initial_upgrade(self, *mocks: Mock) -> None: def test_upgrade_by_card(self, *mocks: Mock) -> None:
user = self.example_user("hamlet") user = self.example_user("hamlet")
self.login(user.email) self.login(user.email)
response = self.client_get("/upgrade/") response = self.client_get("/upgrade/")
self.assert_in_success_response(['Pay annually'], response) self.assert_in_success_response(['Pay annually'], response)
self.assertFalse(user.realm.has_seat_based_plan)
self.assertNotEqual(user.realm.plan_type, Realm.STANDARD) self.assertNotEqual(user.realm.plan_type, Realm.STANDARD)
self.assertFalse(Customer.objects.filter(realm=user.realm).exists()) self.assertFalse(Customer.objects.filter(realm=user.realm).exists())
# Click "Make payment" in Stripe Checkout # Click "Make payment" in Stripe Checkout
self.upgrade() with patch('corporate.lib.stripe.timezone_now', return_value=self.now):
self.upgrade()
# Check that we correctly created Customer and Subscription objects in Stripe # Check that we correctly created a Customer object in Stripe
stripe_customer = stripe_get_customer(Customer.objects.get(realm=user.realm).stripe_customer_id) stripe_customer = stripe_get_customer(Customer.objects.get(realm=user.realm).stripe_customer_id)
self.assertEqual(stripe_customer.default_source.id[:5], 'card_') self.assertEqual(stripe_customer.default_source.id[:5], 'card_')
self.assertEqual(stripe_customer.description, "zulip (Zulip Dev)") self.assertEqual(stripe_customer.description, "zulip (Zulip Dev)")
@ -324,32 +329,69 @@ class StripeTest(ZulipTestCase):
self.assertEqual(stripe_customer.email, user.email) self.assertEqual(stripe_customer.email, user.email)
self.assertEqual(dict(stripe_customer.metadata), self.assertEqual(dict(stripe_customer.metadata),
{'realm_id': str(user.realm.id), 'realm_str': 'zulip'}) {'realm_id': str(user.realm.id), 'realm_str': 'zulip'})
# Check Charges in Stripe
stripe_charges = [charge for charge in stripe.Charge.list(customer=stripe_customer.id)]
self.assertEqual(len(stripe_charges), 1)
self.assertEqual(stripe_charges[0].amount, 8000 * self.seat_count)
# TODO: fix Decimal
self.assertEqual(stripe_charges[0].description,
"Upgrade to Zulip Standard, $80.0 x {}".format(self.seat_count))
self.assertEqual(stripe_charges[0].receipt_email, user.email)
self.assertEqual(stripe_charges[0].statement_descriptor, "Zulip Standard")
# Check Invoices in Stripe
stripe_invoices = [invoice for invoice in stripe.Invoice.list(customer=stripe_customer.id)]
self.assertEqual(len(stripe_invoices), 1)
self.assertIsNotNone(stripe_invoices[0].finalized_at)
invoice_params = {
# auto_advance is False because the invoice has been paid
'amount_due': 0, 'amount_paid': 0, 'auto_advance': False, 'billing': 'charge_automatically',
'charge': None, 'status': 'paid', 'total': 0}
for key, value in invoice_params.items():
self.assertEqual(stripe_invoices[0].get(key), value)
# Check Line Items on Stripe Invoice
stripe_line_items = [item for item in stripe_invoices[0].lines]
self.assertEqual(len(stripe_line_items), 2)
line_item_params = {
'amount': 8000 * self.seat_count, 'description': 'Zulip Standard', 'discountable': False,
'period': {
'end': datetime_to_timestamp(self.next_year),
'start': datetime_to_timestamp(self.now)},
# There's no unit_amount on Line Items, probably because it doesn't show up on the
# user-facing invoice. We could pull the Invoice Item instead and test unit_amount there,
# but testing the amount and quantity seems sufficient.
'plan': None, 'proration': False, 'quantity': self.seat_count}
for key, value in line_item_params.items():
self.assertEqual(stripe_line_items[0].get(key), value)
line_item_params = {
'amount': -8000 * self.seat_count, 'description': 'Payment (Card ending in 4242)',
'discountable': False, 'plan': None, 'proration': False, 'quantity': 1}
for key, value in line_item_params.items():
self.assertEqual(stripe_line_items[1].get(key), value)
stripe_subscription = extract_current_subscription(stripe_customer) # Check that we correctly populated Customer and CustomerPlan in Zulip
self.assertEqual(stripe_subscription.billing, 'charge_automatically') customer = Customer.objects.filter(stripe_customer_id=stripe_customer.id,
self.assertEqual(stripe_subscription.days_until_due, None) realm=user.realm).first()
self.assertEqual(stripe_subscription.plan.id, self.assertTrue(CustomerPlan.objects.filter(
Plan.objects.get(nickname=Plan.CLOUD_ANNUAL).stripe_plan_id) customer=customer, licenses=self.seat_count, automanage_licenses=True,
self.assertEqual(stripe_subscription.quantity, self.seat_count) price_per_license=8000, fixed_price=None, discount=None, billing_cycle_anchor=self.now,
self.assertEqual(stripe_subscription.status, 'active') billing_schedule=CustomerPlan.ANNUAL, billed_through=self.now,
self.assertEqual(stripe_subscription.tax_percent, 0) next_billing_date=self.next_month, tier=CustomerPlan.STANDARD,
status=CustomerPlan.ACTIVE).exists())
# Check that we correctly populated Customer and RealmAuditLog in Zulip # Check RealmAuditLog
self.assertEqual(1, Customer.objects.filter(stripe_customer_id=stripe_customer.id,
realm=user.realm).count())
audit_log_entries = list(RealmAuditLog.objects.filter(acting_user=user) audit_log_entries = list(RealmAuditLog.objects.filter(acting_user=user)
.values_list('event_type', 'event_time').order_by('id')) .values_list('event_type', 'event_time').order_by('id'))
self.assertEqual(audit_log_entries, [ self.assertEqual(audit_log_entries, [
(RealmAuditLog.STRIPE_CUSTOMER_CREATED, timestamp_to_datetime(stripe_customer.created)), (RealmAuditLog.STRIPE_CUSTOMER_CREATED, timestamp_to_datetime(stripe_customer.created)),
(RealmAuditLog.STRIPE_CARD_CHANGED, timestamp_to_datetime(stripe_customer.created)), (RealmAuditLog.STRIPE_CARD_CHANGED, timestamp_to_datetime(stripe_customer.created)),
# TODO: Add a test where stripe_customer.created != stripe_subscription.created (RealmAuditLog.CUSTOMER_PLAN_CREATED, self.now),
(RealmAuditLog.STRIPE_PLAN_CHANGED, timestamp_to_datetime(stripe_subscription.created)),
# TODO: Check for REALM_PLAN_TYPE_CHANGED # TODO: Check for REALM_PLAN_TYPE_CHANGED
# (RealmAuditLog.REALM_PLAN_TYPE_CHANGED, Kandra()), # (RealmAuditLog.REALM_PLAN_TYPE_CHANGED, Kandra()),
]) ])
self.assertEqual(ujson.loads(RealmAuditLog.objects.filter(
event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED).values_list(
'extra_data', flat=True).first())['licenses'], self.seat_count)
# Check that we correctly updated Realm # Check that we correctly updated Realm
realm = get_realm("zulip") realm = get_realm("zulip")
self.assertTrue(realm.has_seat_based_plan)
self.assertEqual(realm.plan_type, Realm.STANDARD) self.assertEqual(realm.plan_type, Realm.STANDARD)
self.assertEqual(realm.max_invites, Realm.INVITES_STANDARD_REALM_DAILY_MAX) self.assertEqual(realm.max_invites, Realm.INVITES_STANDARD_REALM_DAILY_MAX)
# Check that we can no longer access /upgrade # Check that we can no longer access /upgrade
@ -357,12 +399,90 @@ class StripeTest(ZulipTestCase):
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
self.assertEqual('/billing/', response.url) self.assertEqual('/billing/', response.url)
# Check /billing has the correct information # TODO: Check /billing has the correct information
response = self.client_get("/billing/") # response = self.client_get("/billing/")
self.assert_not_in_success_response(['Pay annually'], response) # self.assert_not_in_success_response(['Pay annually'], response)
for substring in ['Your plan will renew on', '$%s.00' % (80 * self.seat_count,), # for substring in ['Your plan will renew on', '$%s.00' % (80 * self.seat_count,),
'Card ending in 4242', 'Update card']: # 'Card ending in 4242', 'Update card']:
self.assert_in_response(substring, response) # self.assert_in_response(substring, response)
@mock_stripe(tested_timestamp_fields=["created"])
def test_upgrade_by_invoice(self, *mocks: Mock) -> None:
user = self.example_user("hamlet")
self.login(user.email)
# Click "Make payment" in Stripe Checkout
with patch('corporate.lib.stripe.timezone_now', return_value=self.now):
self.upgrade(invoice=True)
# Check that we correctly created a Customer in Stripe
stripe_customer = stripe_get_customer(Customer.objects.get(realm=user.realm).stripe_customer_id)
# It can take a second for Stripe to attach the source to the customer, and in
# particular it may not be attached at the time stripe_get_customer is called above,
# causing test flakes.
# So commenting the next line out, but leaving it here so future readers know what
# is supposed to happen here
# self.assertEqual(stripe_customer.default_source.type, 'ach_credit_transfer')
# Check Charges in Stripe
self.assertFalse(stripe.Charge.list(customer=stripe_customer.id))
# Check Invoices in Stripe
stripe_invoices = [invoice for invoice in stripe.Invoice.list(customer=stripe_customer.id)]
self.assertEqual(len(stripe_invoices), 1)
self.assertIsNotNone(stripe_invoices[0].due_date)
self.assertIsNotNone(stripe_invoices[0].finalized_at)
invoice_params = {
'amount_due': 8000 * 123, 'amount_paid': 0, 'attempt_count': 0,
'auto_advance': True, 'billing': 'send_invoice', 'statement_descriptor': 'Zulip Standard',
'status': 'open', 'total': 8000 * 123}
for key, value in invoice_params.items():
self.assertEqual(stripe_invoices[0].get(key), value)
# Check Line Items on Stripe Invoice
stripe_line_items = [item for item in stripe_invoices[0].lines]
self.assertEqual(len(stripe_line_items), 1)
line_item_params = {
'amount': 8000 * 123, 'description': 'Zulip Standard', 'discountable': False,
'period': {
'end': datetime_to_timestamp(self.next_year),
'start': datetime_to_timestamp(self.now)},
'plan': None, 'proration': False, 'quantity': 123}
for key, value in line_item_params.items():
self.assertEqual(stripe_line_items[0].get(key), value)
# Check that we correctly populated Customer and CustomerPlan in Zulip
customer = Customer.objects.filter(stripe_customer_id=stripe_customer.id,
realm=user.realm).first()
self.assertTrue(CustomerPlan.objects.filter(
customer=customer, licenses=123, automanage_licenses=False, charge_automatically=False,
price_per_license=8000, fixed_price=None, discount=None, billing_cycle_anchor=self.now,
billing_schedule=CustomerPlan.ANNUAL, billed_through=self.now,
next_billing_date=self.next_year, tier=CustomerPlan.STANDARD,
status=CustomerPlan.ACTIVE).exists())
# Check RealmAuditLog
audit_log_entries = list(RealmAuditLog.objects.filter(acting_user=user)
.values_list('event_type', 'event_time').order_by('id'))
self.assertEqual(audit_log_entries, [
(RealmAuditLog.STRIPE_CUSTOMER_CREATED, timestamp_to_datetime(stripe_customer.created)),
(RealmAuditLog.CUSTOMER_PLAN_CREATED, self.now),
# TODO: Check for REALM_PLAN_TYPE_CHANGED
# (RealmAuditLog.REALM_PLAN_TYPE_CHANGED, Kandra()),
])
self.assertEqual(ujson.loads(RealmAuditLog.objects.filter(
event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED).values_list(
'extra_data', flat=True).first())['licenses'], 123)
# Check that we correctly updated Realm
realm = get_realm("zulip")
self.assertEqual(realm.plan_type, Realm.STANDARD)
self.assertEqual(realm.max_invites, Realm.INVITES_STANDARD_REALM_DAILY_MAX)
# Check that we can no longer access /upgrade
response = self.client_get("/upgrade/")
self.assertEqual(response.status_code, 302)
self.assertEqual('/billing/', response.url)
# TODO: Check /billing has the correct information
# response = self.client_get("/billing/")
# self.assert_not_in_success_response(['Pay annually'], response)
# for substring in ['Your plan will renew on', '$%s.00' % (80 * self.seat_count,),
# 'Card ending in 4242', 'Update card']:
# self.assert_in_response(substring, response)
@mock_stripe() @mock_stripe()
def test_billing_page_permissions(self, *mocks: Mock) -> None: def test_billing_page_permissions(self, *mocks: Mock) -> None:
@ -386,49 +506,46 @@ class StripeTest(ZulipTestCase):
self.assert_in_success_response(["You must be an organization administrator"], response) self.assert_in_success_response(["You must be an organization administrator"], response)
@mock_stripe(tested_timestamp_fields=["created"]) @mock_stripe(tested_timestamp_fields=["created"])
def test_upgrade_with_outdated_seat_count(self, *mocks: Mock) -> None: def test_upgrade_by_card_with_outdated_seat_count(self, *mocks: Mock) -> None:
self.login(self.example_email("hamlet")) self.login(self.example_email("hamlet"))
new_seat_count = 123 new_seat_count = 23
# Change the seat count while the user is going through the upgrade flow # Change the seat count while the user is going through the upgrade flow
with patch('corporate.lib.stripe.get_seat_count', return_value=new_seat_count): with patch('corporate.lib.stripe.get_seat_count', return_value=new_seat_count):
self.upgrade() self.upgrade()
# Check that the subscription call used the old quantity, not new_seat_count stripe_customer_id = Customer.objects.first().stripe_customer_id
stripe_customer = stripe_get_customer( # Check that the Charge used the old quantity, not new_seat_count
Customer.objects.get(realm=get_realm('zulip')).stripe_customer_id) self.assertEqual(8000 * self.seat_count,
stripe_subscription = extract_current_subscription(stripe_customer) [charge for charge in stripe.Charge.list(customer=stripe_customer_id)][0].amount)
self.assertEqual(stripe_subscription.quantity, self.seat_count) # Check that the invoice has a credit for the old amount and a charge for the new one
stripe_invoice = [invoice for invoice in stripe.Invoice.list(customer=stripe_customer_id)][0]
# Check that we have the STRIPE_PLAN_QUANTITY_RESET entry, and that we self.assertEqual([8000 * new_seat_count, -8000 * self.seat_count],
# correctly handled the requires_billing_update field [item.amount for item in stripe_invoice.lines])
audit_log_entries = list(RealmAuditLog.objects.order_by('-id') # Check CustomerPlan and RealmAuditLog have the new amount
.values_list('event_type', 'event_time', self.assertEqual(CustomerPlan.objects.first().licenses, new_seat_count)
'requires_billing_update')[:5])[::-1]
self.assertEqual(audit_log_entries, [
(RealmAuditLog.STRIPE_CUSTOMER_CREATED, timestamp_to_datetime(stripe_customer.created), False),
(RealmAuditLog.STRIPE_CARD_CHANGED, timestamp_to_datetime(stripe_customer.created), False),
# TODO: Ideally this test would force stripe_customer.created != stripe_subscription.created
(RealmAuditLog.STRIPE_PLAN_CHANGED, timestamp_to_datetime(stripe_subscription.created), False),
(RealmAuditLog.STRIPE_PLAN_QUANTITY_RESET, timestamp_to_datetime(stripe_subscription.created), True),
(RealmAuditLog.REALM_PLAN_TYPE_CHANGED, Kandra(), False),
])
self.assertEqual(ujson.loads(RealmAuditLog.objects.filter( self.assertEqual(ujson.loads(RealmAuditLog.objects.filter(
event_type=RealmAuditLog.STRIPE_PLAN_QUANTITY_RESET).values_list('extra_data', flat=True).first()), event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED).values_list(
{'quantity': new_seat_count}) 'extra_data', flat=True).first())['licenses'], new_seat_count)
@mock_stripe() @mock_stripe()
def test_upgrade_where_subscription_save_fails_at_first(self, *mocks: Mock) -> None: def test_upgrade_where_first_card_fails(self, *mocks: Mock) -> None:
user = self.example_user("hamlet") user = self.example_user("hamlet")
self.login(user.email) self.login(user.email)
# From https://stripe.com/docs/testing#cards: Attaching this card to # From https://stripe.com/docs/testing#cards: Attaching this card to
# a Customer object succeeds, but attempts to charge the customer fail. # a Customer object succeeds, but attempts to charge the customer fail.
self.upgrade(stripe_token=stripe_create_token('4000000000000341').id) with patch("corporate.lib.stripe.billing_logger.error") as mock_billing_logger:
# Check that we created a Customer object with has_billing_relationship False self.upgrade(stripe_token=stripe_create_token('4000000000000341').id)
customer = Customer.objects.get(realm=get_realm('zulip')) mock_billing_logger.assert_called()
self.assertFalse(customer.has_billing_relationship) # Check that we created a Customer object but no CustomerPlan
original_stripe_customer_id = customer.stripe_customer_id stripe_customer_id = Customer.objects.get(realm=get_realm('zulip')).stripe_customer_id
# Check that we created a customer in stripe, with no subscription self.assertFalse(CustomerPlan.objects.exists())
stripe_customer = stripe_get_customer(customer.stripe_customer_id) # Check that we created a Customer in stripe, a failed Charge, and no Invoices or Invoice Items
self.assertFalse(extract_current_subscription(stripe_customer)) self.assertTrue(stripe_get_customer(stripe_customer_id))
stripe_charges = [charge for charge in stripe.Charge.list(customer=stripe_customer_id)]
self.assertEqual(len(stripe_charges), 1)
self.assertEqual(stripe_charges[0].failure_code, 'card_declined')
# TODO: figure out what these actually are
self.assertFalse(stripe.Invoice.list(customer=stripe_customer_id))
self.assertFalse(stripe.InvoiceItem.list(customer=stripe_customer_id))
# Check that we correctly populated RealmAuditLog # Check that we correctly populated RealmAuditLog
audit_log_entries = list(RealmAuditLog.objects.filter(acting_user=user) audit_log_entries = list(RealmAuditLog.objects.filter(acting_user=user)
.values_list('event_type', flat=True).order_by('id')) .values_list('event_type', flat=True).order_by('id'))
@ -436,22 +553,28 @@ class StripeTest(ZulipTestCase):
RealmAuditLog.STRIPE_CARD_CHANGED]) RealmAuditLog.STRIPE_CARD_CHANGED])
# Check that we did not update Realm # Check that we did not update Realm
realm = get_realm("zulip") realm = get_realm("zulip")
self.assertFalse(realm.has_seat_based_plan) self.assertNotEqual(realm.plan_type, Realm.STANDARD)
# Check that we still get redirected to /upgrade # Check that we still get redirected to /upgrade
response = self.client_get("/billing/") response = self.client_get("/billing/")
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
self.assertEqual('/upgrade/', response.url) self.assertEqual('/upgrade/', response.url)
# Try again, with a valid card # Try again, with a valid card, after they added a few users
self.upgrade() with patch('corporate.lib.stripe.get_seat_count', return_value=23):
with patch('corporate.views.get_seat_count', return_value=23):
self.upgrade()
customer = Customer.objects.get(realm=get_realm('zulip')) customer = Customer.objects.get(realm=get_realm('zulip'))
# Impossible to create two Customers, but check that we didn't # It's impossible to create two Customers, but check that we didn't
# change stripe_customer_id and that we updated has_billing_relationship # change stripe_customer_id
self.assertEqual(customer.stripe_customer_id, original_stripe_customer_id) self.assertEqual(customer.stripe_customer_id, stripe_customer_id)
self.assertTrue(customer.has_billing_relationship) # Check that we successfully added a CustomerPlan
# Check that we successfully added a subscription self.assertTrue(CustomerPlan.objects.filter(customer=customer, licenses=23).exists())
stripe_customer = stripe_get_customer(customer.stripe_customer_id) # Check the Charges and Invoices in Stripe
self.assertTrue(extract_current_subscription(stripe_customer)) self.assertEqual(8000 * 23, [charge for charge in
stripe.Charge.list(customer=stripe_customer_id)][0].amount)
stripe_invoice = [invoice for invoice in stripe.Invoice.list(customer=stripe_customer_id)][0]
self.assertEqual([8000 * 23, -8000 * 23],
[item.amount for item in stripe_invoice.lines])
# Check that we correctly populated RealmAuditLog # Check that we correctly populated RealmAuditLog
audit_log_entries = list(RealmAuditLog.objects.filter(acting_user=user) audit_log_entries = list(RealmAuditLog.objects.filter(acting_user=user)
.values_list('event_type', flat=True).order_by('id')) .values_list('event_type', flat=True).order_by('id'))
@ -459,10 +582,10 @@ class StripeTest(ZulipTestCase):
self.assertEqual(audit_log_entries, [RealmAuditLog.STRIPE_CUSTOMER_CREATED, self.assertEqual(audit_log_entries, [RealmAuditLog.STRIPE_CUSTOMER_CREATED,
RealmAuditLog.STRIPE_CARD_CHANGED, RealmAuditLog.STRIPE_CARD_CHANGED,
RealmAuditLog.STRIPE_CARD_CHANGED, RealmAuditLog.STRIPE_CARD_CHANGED,
RealmAuditLog.STRIPE_PLAN_CHANGED]) RealmAuditLog.CUSTOMER_PLAN_CREATED])
# Check that we correctly updated Realm # Check that we correctly updated Realm
realm = get_realm("zulip") realm = get_realm("zulip")
self.assertTrue(realm.has_seat_based_plan) self.assertEqual(realm.plan_type, Realm.STANDARD)
# Check that we can no longer access /upgrade # Check that we can no longer access /upgrade
response = self.client_get("/upgrade/") response = self.client_get("/upgrade/")
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -543,69 +666,6 @@ class StripeTest(ZulipTestCase):
self.assert_json_error_contains(response, "Something went wrong. Please contact zulip-admin@example.com.") self.assert_json_error_contains(response, "Something went wrong. Please contact zulip-admin@example.com.")
self.assertEqual(ujson.loads(response.content)['error_description'], 'uncaught exception during upgrade') self.assertEqual(ujson.loads(response.content)['error_description'], 'uncaught exception during upgrade')
@mock_stripe(tested_timestamp_fields=["created"])
def test_upgrade_billing_by_invoice(self, *mocks: Mock) -> None:
user = self.example_user("hamlet")
self.login(user.email)
self.upgrade(invoice=True)
# Check that we correctly created a Customer in Stripe
stripe_customer = stripe_get_customer(Customer.objects.get(realm=user.realm).stripe_customer_id)
self.assertEqual(stripe_customer.email, user.email)
# It can take a second for Stripe to attach the source to the
# customer, and in particular it may not be attached at the time
# stripe_get_customer is called above, causing test flakes.
# So commenting the next line out, but leaving it here so future readers know what
# is supposed to happen here (e.g. the default_source is not None as it would be if
# we had not added a Subscription).
# self.assertEqual(stripe_customer.default_source.type, 'ach_credit_transfer')
# Check that we correctly created a Subscription in Stripe
stripe_subscription = extract_current_subscription(stripe_customer)
self.assertEqual(stripe_subscription.billing, 'send_invoice')
self.assertEqual(stripe_subscription.days_until_due, DEFAULT_INVOICE_DAYS_UNTIL_DUE)
self.assertEqual(stripe_subscription.plan.id,
Plan.objects.get(nickname=Plan.CLOUD_ANNUAL).stripe_plan_id)
# In the middle of migrating off of this billing algorithm
# self.assertEqual(stripe_subscription.quantity, get_seat_count(user.realm))
self.assertEqual(stripe_subscription.status, 'active')
# Check that we correctly created an initial Invoice in Stripe
for stripe_invoice in stripe.Invoice.list(customer=stripe_customer.id, limit=1):
self.assertTrue(stripe_invoice.auto_advance)
self.assertEqual(stripe_invoice.billing, 'send_invoice')
self.assertEqual(stripe_invoice.billing_reason, 'subscription_create')
# Transitions to 'open' after 1-2 hours
self.assertEqual(stripe_invoice.status, 'draft')
# Very important. Check that we're invoicing for 123, and not get_seat_count
self.assertEqual(stripe_invoice.amount_due, 8000*123)
# Check that we correctly updated Realm
realm = get_realm("zulip")
self.assertTrue(realm.has_seat_based_plan)
self.assertEqual(realm.plan_type, Realm.STANDARD)
# Check that we created a Customer in Zulip
self.assertEqual(1, Customer.objects.filter(stripe_customer_id=stripe_customer.id,
realm=realm).count())
# Check that RealmAuditLog has STRIPE_PLAN_QUANTITY_RESET, and doesn't have STRIPE_CARD_CHANGED
audit_log_entries = list(RealmAuditLog.objects.order_by('-id')
.values_list('event_type', 'event_time',
'requires_billing_update')[:4])[::-1]
self.assertEqual(audit_log_entries, [
(RealmAuditLog.STRIPE_CUSTOMER_CREATED, timestamp_to_datetime(stripe_customer.created), False),
(RealmAuditLog.STRIPE_PLAN_CHANGED, timestamp_to_datetime(stripe_subscription.created), False),
(RealmAuditLog.STRIPE_PLAN_QUANTITY_RESET, timestamp_to_datetime(stripe_subscription.created), True),
(RealmAuditLog.REALM_PLAN_TYPE_CHANGED, Kandra(), False),
])
self.assertEqual(ujson.loads(RealmAuditLog.objects.filter(
event_type=RealmAuditLog.STRIPE_PLAN_QUANTITY_RESET).values_list('extra_data', flat=True).first()),
{'quantity': self.seat_count})
# Check /billing has the correct information
response = self.client_get("/billing/")
self.assert_not_in_success_response(['Pay annually', 'Update card'], response)
for substring in ['Your plan will renew on', 'Billed by invoice']:
self.assert_in_response(substring, response)
def test_redirect_for_billing_home(self) -> None: def test_redirect_for_billing_home(self) -> None:
user = self.example_user("iago") user = self.example_user("iago")
self.login(user.email) self.login(user.email)
@ -650,17 +710,18 @@ class StripeTest(ZulipTestCase):
# histories don't throw errors # histories don't throw errors
@mock_stripe() @mock_stripe()
def test_payment_method_string(self, *mocks: Mock) -> None: def test_payment_method_string(self, *mocks: Mock) -> None:
pass
# If you signup with a card, we should show your card as the payment method # If you signup with a card, we should show your card as the payment method
# Already tested in test_initial_upgrade # Already tested in test_initial_upgrade
# If you pay by invoice, your payment method should be # If you pay by invoice, your payment method should be
# "Billed by invoice", even if you have a card on file # "Billed by invoice", even if you have a card on file
user = self.example_user("hamlet") # user = self.example_user("hamlet")
do_create_customer(user, stripe_create_token().id) # do_create_customer(user, stripe_create_token().id)
self.login(user.email) # self.login(user.email)
self.upgrade(invoice=True) # self.upgrade(invoice=True)
stripe_customer = stripe_get_customer(Customer.objects.get(realm=user.realm).stripe_customer_id) # stripe_customer = stripe_get_customer(Customer.objects.get(realm=user.realm).stripe_customer_id)
self.assertEqual('Billed by invoice', payment_method_string(stripe_customer)) # self.assertEqual('Billed by invoice', payment_method_string(stripe_customer))
# If you signup with a card and then downgrade, we still have your # If you signup with a card and then downgrade, we still have your
# card on file, and should show it # card on file, and should show it
@ -806,3 +867,82 @@ class RequiresBillingAccessTest(ZulipTestCase):
json_endpoints.remove("json/billing/upgrade") json_endpoints.remove("json/billing/upgrade")
self.assertEqual(len(json_endpoints), len(params)) self.assertEqual(len(json_endpoints), len(params))
class BillingHelpersTest(ZulipTestCase):
def test_next_month(self) -> None:
anchor = datetime(2019, 12, 31, 1, 2, 3).replace(tzinfo=timezone_utc)
period_boundaries = [
anchor,
datetime(2020, 1, 31, 1, 2, 3).replace(tzinfo=timezone_utc),
# Test that this is the 28th even during leap years
datetime(2020, 2, 28, 1, 2, 3).replace(tzinfo=timezone_utc),
datetime(2020, 3, 31, 1, 2, 3).replace(tzinfo=timezone_utc),
datetime(2020, 4, 30, 1, 2, 3).replace(tzinfo=timezone_utc),
datetime(2020, 5, 31, 1, 2, 3).replace(tzinfo=timezone_utc),
datetime(2020, 6, 30, 1, 2, 3).replace(tzinfo=timezone_utc),
datetime(2020, 7, 31, 1, 2, 3).replace(tzinfo=timezone_utc),
datetime(2020, 8, 31, 1, 2, 3).replace(tzinfo=timezone_utc),
datetime(2020, 9, 30, 1, 2, 3).replace(tzinfo=timezone_utc),
datetime(2020, 10, 31, 1, 2, 3).replace(tzinfo=timezone_utc),
datetime(2020, 11, 30, 1, 2, 3).replace(tzinfo=timezone_utc),
datetime(2020, 12, 31, 1, 2, 3).replace(tzinfo=timezone_utc),
datetime(2021, 1, 31, 1, 2, 3).replace(tzinfo=timezone_utc),
datetime(2021, 2, 28, 1, 2, 3).replace(tzinfo=timezone_utc)]
with self.assertRaises(AssertionError):
add_months(anchor, -1)
# Explictly test add_months for each value of MAX_DAY_FOR_MONTH and
# for crossing a year boundary
for i, boundary in enumerate(period_boundaries):
self.assertEqual(add_months(anchor, i), boundary)
# Test next_month for small values
for last, next_ in zip(period_boundaries[:-1], period_boundaries[1:]):
self.assertEqual(next_month(anchor, last), next_)
# Test next_month for large values
period_boundaries = [dt.replace(year=dt.year+100) for dt in period_boundaries]
for last, next_ in zip(period_boundaries[:-1], period_boundaries[1:]):
self.assertEqual(next_month(anchor, last), next_)
def test_compute_plan_parameters(self) -> None:
# TODO: test rounding down microseconds
anchor = datetime(2019, 12, 31, 1, 2, 3).replace(tzinfo=timezone_utc)
month_later = datetime(2020, 1, 31, 1, 2, 3).replace(tzinfo=timezone_utc)
year_later = datetime(2020, 12, 31, 1, 2, 3).replace(tzinfo=timezone_utc)
test_cases = [
# TODO test with Decimal(85), not 85
# TODO fix the mypy error by specifying the exact type
# test all possibilities, since there aren't that many
[(True, CustomerPlan.ANNUAL, None), (anchor, month_later, year_later, 8000)], # lint:ignore
[(True, CustomerPlan.ANNUAL, 85), (anchor, month_later, year_later, 1200)], # lint:ignore
[(True, CustomerPlan.MONTHLY, None), (anchor, month_later, month_later, 800)], # lint:ignore
[(True, CustomerPlan.MONTHLY, 85), (anchor, month_later, month_later, 120)], # lint:ignore
[(False, CustomerPlan.ANNUAL, None), (anchor, year_later, year_later, 8000)], # lint:ignore
[(False, CustomerPlan.ANNUAL, 85), (anchor, year_later, year_later, 1200)], # lint:ignore
[(False, CustomerPlan.MONTHLY, None), (anchor, month_later, month_later, 800)], # lint:ignore
[(False, CustomerPlan.MONTHLY, 85), (anchor, month_later, month_later, 120)], # lint:ignore
# test exact math of Decimals; 800 * (1 - 87.25) = 101.9999999..
[(False, CustomerPlan.MONTHLY, 87.25), (anchor, month_later, month_later, 102)],
# test dropping of fractional cents; without the int it's 102.8
[(False, CustomerPlan.MONTHLY, 87.15), (anchor, month_later, month_later, 102)]]
with patch('corporate.lib.stripe.timezone_now', return_value=anchor):
for input_, output in test_cases:
output_ = compute_plan_parameters(*input_) # type: ignore # TODO
self.assertEqual(output_, output)
def test_update_or_create_stripe_customer_logic(self) -> None:
user = self.example_user('hamlet')
# No existing Customer object
with patch('corporate.lib.stripe.do_create_customer', return_value='returned') as mocked1:
returned = update_or_create_stripe_customer(user, stripe_token='token')
mocked1.assert_called()
self.assertEqual(returned, 'returned')
# Customer exists, replace payment source
Customer.objects.create(realm=get_realm('zulip'), stripe_customer_id='cus_12345')
with patch('corporate.lib.stripe.do_replace_payment_source') as mocked2:
customer = update_or_create_stripe_customer(self.example_user('hamlet'), 'token')
mocked2.assert_called()
self.assertTrue(isinstance(customer, Customer))
# Customer exists, do nothing
with patch('corporate.lib.stripe.do_replace_payment_source') as mocked3:
customer = update_or_create_stripe_customer(self.example_user('hamlet'), None)
mocked3.assert_not_called()
self.assertTrue(isinstance(customer, Customer))

View File

@ -14,15 +14,16 @@ from zerver.decorator import zulip_login_required, require_billing_access
from zerver.lib.json_encoder_for_html import JSONEncoderForHTML from zerver.lib.json_encoder_for_html import JSONEncoderForHTML
from zerver.lib.request import REQ, has_request_variables from zerver.lib.request import REQ, has_request_variables
from zerver.lib.response import json_error, json_success from zerver.lib.response import json_error, json_success
from zerver.lib.validator import check_string, check_int from zerver.lib.validator import check_string, check_int, check_bool
from zerver.lib.timestamp import timestamp_to_datetime from zerver.lib.timestamp import timestamp_to_datetime
from zerver.models import UserProfile, Realm from zerver.models import UserProfile, Realm
from corporate.lib.stripe import STRIPE_PUBLISHABLE_KEY, \ from corporate.lib.stripe import STRIPE_PUBLISHABLE_KEY, \
stripe_get_customer, upcoming_invoice_total, get_seat_count, \ stripe_get_customer, get_seat_count, \
extract_current_subscription, process_initial_upgrade, sign_string, \ process_initial_upgrade, sign_string, \
unsign_string, BillingError, process_downgrade, do_replace_payment_source, \ unsign_string, BillingError, process_downgrade, do_replace_payment_source, \
MIN_INVOICED_LICENSES, DEFAULT_INVOICE_DAYS_UNTIL_DUE MIN_INVOICED_LICENSES, DEFAULT_INVOICE_DAYS_UNTIL_DUE, \
from corporate.models import Customer, CustomerPlan, Plan next_renewal_date, renewal_amount
from corporate.models import Customer, CustomerPlan, Plan, get_active_plan
billing_logger = logging.getLogger('corporate.stripe') billing_logger = logging.getLogger('corporate.stripe')
@ -53,8 +54,9 @@ def check_upgrade_parameters(
raise BillingError('not enough licenses', raise BillingError('not enough licenses',
_("You must invoice for at least {} users.".format(min_licenses))) _("You must invoice for at least {} users.".format(min_licenses)))
def payment_method_string(stripe_customer: stripe.Customer) -> str: # TODO
subscription = extract_current_subscription(stripe_customer) def payment_method_string(stripe_customer: stripe.Customer) -> str: # nocoverage: TODO
subscription = None # extract_current_subscription(stripe_customer)
if subscription is not None and subscription.billing == "send_invoice": if subscription is not None and subscription.billing == "send_invoice":
return _("Billed by invoice") return _("Billed by invoice")
stripe_source = stripe_customer.default_source stripe_source = stripe_customer.default_source
@ -91,10 +93,11 @@ def upgrade(request: HttpRequest, user: UserProfile,
check_upgrade_parameters( check_upgrade_parameters(
billing_modality, schedule, license_management, licenses, billing_modality, schedule, license_management, licenses,
stripe_token is not None, seat_count) stripe_token is not None, seat_count)
automanage_licenses = license_management in ['automatic', 'mix']
billing_schedule = {'annual': CustomerPlan.ANNUAL, billing_schedule = {'annual': CustomerPlan.ANNUAL,
'monthly': CustomerPlan.MONTHLY}[schedule] 'monthly': CustomerPlan.MONTHLY}[schedule]
process_initial_upgrade(user, licenses, billing_schedule, stripe_token) process_initial_upgrade(user, licenses, automanage_licenses, billing_schedule, stripe_token)
except BillingError as e: except BillingError as e:
# TODO add a billing_logger.warning with all the upgrade parameters # TODO add a billing_logger.warning with all the upgrade parameters
return json_error(e.message, data={'error_description': e.description}) return json_error(e.message, data={'error_description': e.description})
@ -113,7 +116,7 @@ def initial_upgrade(request: HttpRequest) -> HttpResponse:
user = request.user user = request.user
customer = Customer.objects.filter(realm=user.realm).first() customer = Customer.objects.filter(realm=user.realm).first()
if customer is not None and customer.has_billing_relationship: if customer is not None and CustomerPlan.objects.filter(customer=customer).exists():
return HttpResponseRedirect(reverse('corporate.views.billing_home')) return HttpResponseRedirect(reverse('corporate.views.billing_home'))
percent_off = 0 percent_off = 0
@ -152,7 +155,7 @@ def billing_home(request: HttpRequest) -> HttpResponse:
customer = Customer.objects.filter(realm=user.realm).first() customer = Customer.objects.filter(realm=user.realm).first()
if customer is None: if customer is None:
return HttpResponseRedirect(reverse('corporate.views.initial_upgrade')) return HttpResponseRedirect(reverse('corporate.views.initial_upgrade'))
if not customer.has_billing_relationship: if not CustomerPlan.objects.filter(customer=customer).exists():
return HttpResponseRedirect(reverse('corporate.views.initial_upgrade')) return HttpResponseRedirect(reverse('corporate.views.initial_upgrade'))
if not user.is_realm_admin and not user.is_billing_admin: if not user.is_realm_admin and not user.is_billing_admin:
@ -160,40 +163,44 @@ def billing_home(request: HttpRequest) -> HttpResponse:
return render(request, 'corporate/billing.html', context=context) return render(request, 'corporate/billing.html', context=context)
context = {'admin_access': True} context = {'admin_access': True}
stripe_customer = stripe_get_customer(customer.stripe_customer_id) charge_automatically = False
if stripe_customer.account_balance > 0: # nocoverage, waiting for mock_stripe to mature plan = get_active_plan(customer)
context.update({'account_charges': '{:,.2f}'.format(stripe_customer.account_balance / 100.)}) if plan is not None:
if stripe_customer.account_balance < 0: # nocoverage plan_name = {
context.update({'account_credits': '{:,.2f}'.format(-stripe_customer.account_balance / 100.)}) CustomerPlan.STANDARD: 'Zulip Standard',
CustomerPlan.PLUS: 'Zulip Plus',
billed_by_invoice = False }[plan.tier]
subscription = extract_current_subscription(stripe_customer) licenses = plan.licenses
if subscription:
plan_name = PLAN_NAMES[Plan.objects.get(stripe_plan_id=subscription.plan.id).nickname]
licenses = subscription.quantity
# Need user's timezone to do this properly # Need user's timezone to do this properly
renewal_date = '{dt:%B} {dt.day}, {dt.year}'.format( renewal_date = '{dt:%B} {dt.day}, {dt.year}'.format(dt=next_renewal_date(plan))
dt=timestamp_to_datetime(subscription.current_period_end)) renewal_cents = renewal_amount(plan)
renewal_amount = upcoming_invoice_total(customer.stripe_customer_id) charge_automatically = plan.charge_automatically
if subscription.billing == 'send_invoice': if charge_automatically: # nocoverage: TODO
billed_by_invoice = True # TODO get last4
payment_method = 'Card on file'
else: # nocoverage: TODO
payment_method = 'Billed by invoice'
billed_by_invoice = not plan.charge_automatically
# Can only get here by subscribing and then downgrading. We don't support downgrading # Can only get here by subscribing and then downgrading. We don't support downgrading
# yet, but keeping this code here since we will soon. # yet, but keeping this code here since we will soon.
else: # nocoverage else: # nocoverage
plan_name = "Zulip Free" plan_name = "Zulip Free"
licenses = 0 licenses = 0
renewal_date = '' renewal_date = ''
renewal_amount = 0 renewal_cents = 0
payment_method = ''
context.update({ context.update({
'plan_name': plan_name, 'plan_name': plan_name,
'licenses': licenses, 'licenses': licenses,
'renewal_date': renewal_date, 'renewal_date': renewal_date,
'renewal_amount': '{:,.2f}'.format(renewal_amount / 100.), 'renewal_amount': '{:,.2f}'.format(renewal_cents / 100.),
'payment_method': payment_method_string(stripe_customer), 'payment_method': payment_method,
# TODO: Rename to charge_automatically
'billed_by_invoice': billed_by_invoice, 'billed_by_invoice': billed_by_invoice,
'publishable_key': STRIPE_PUBLISHABLE_KEY, 'publishable_key': STRIPE_PUBLISHABLE_KEY,
'stripe_email': stripe_customer.email, # TODO: get actual stripe email?
'stripe_email': user.email,
}) })
return render(request, 'corporate/billing.html', context=context) return render(request, 'corporate/billing.html', context=context)

Some files were not shown because too many files have changed in this diff Show More