diff --git a/corporate/lib/stripe.py b/corporate/lib/stripe.py index e63cf38c8b..5d6b277534 100644 --- a/corporate/lib/stripe.py +++ b/corporate/lib/stripe.py @@ -19,7 +19,8 @@ from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime from zerver.lib.utils import generate_random_token from zerver.lib.actions import do_change_plan_type from zerver.models import Realm, UserProfile, RealmAuditLog -from corporate.models import Customer, CustomerPlan, get_active_plan +from corporate.models import Customer, CustomerPlan, LicenseLedger, \ + get_active_plan from zproject.settings import get_secret STRIPE_PUBLISHABLE_KEY = get_secret('stripe_publishable_key') @@ -91,15 +92,15 @@ def next_renewal_date(plan: CustomerPlan) -> datetime: periods += 1 return dt -def renewal_amount(plan: CustomerPlan) -> int: # nocoverage: TODO +def renewal_amount(plan: CustomerPlan) -> Optional[int]: # nocoverage: TODO if plan.fixed_price is not None: basis = plan.fixed_price - elif plan.automanage_licenses: - assert(plan.price_per_license is not None) - basis = plan.price_per_license * get_seat_count(plan.customer.realm) else: - assert(plan.price_per_license is not None) - basis = plan.price_per_license * plan.licenses + last_ledger_entry = add_plan_renewal_to_license_ledger_if_needed(plan, timezone_now()) + if last_ledger_entry.licenses_at_next_renewal is None: + return None + assert(plan.price_per_license is not None) # for mypy + basis = plan.price_per_license * last_ledger_entry.licenses_at_next_renewal if plan.discount is None: return basis # TODO: figure out right thing to do with Decimal @@ -191,6 +192,21 @@ def do_replace_payment_source(user: UserProfile, stripe_token: str) -> stripe.Cu event_time=timezone_now()) return updated_stripe_customer +# event_time should roughly be timezone_now(). Not designed to handle +# event_times in the past or future +# TODO handle downgrade +def add_plan_renewal_to_license_ledger_if_needed(plan: CustomerPlan, event_time: datetime) -> LicenseLedger: + last_ledger_entry = LicenseLedger.objects.filter(plan=plan).order_by('-event_time').first() + plan_renewal_date = next_renewal_date(plan) + if plan_renewal_date < event_time: + if not LicenseLedger.objects.filter( + plan=plan, event_time=plan_renewal_date, is_renewal=True).exists(): + return LicenseLedger.objects.create( + plan=plan, is_renewal=True, event_time=plan_renewal_date, + licenses=last_ledger_entry.licenses_at_next_renewal, + licenses_at_next_renewal=last_ledger_entry.licenses_at_next_renewal) + return last_ledger_entry + # Returns Customer instead of stripe_customer so that we don't make a Stripe # API call if there's nothing to update def update_or_create_stripe_customer(user: UserProfile, stripe_token: Optional[str]=None) -> Customer: @@ -273,7 +289,6 @@ def process_initial_upgrade(user: UserProfile, licenses: int, automanage_license # this function (process_initial_upgrade) and now billed_licenses = max(get_seat_count(realm), licenses) plan_params = { - 'licenses': billed_licenses, 'automanage_licenses': automanage_licenses, 'charge_automatically': charge_automatically, 'price_per_license': price_per_license, @@ -281,17 +296,22 @@ def process_initial_upgrade(user: UserProfile, licenses: int, automanage_license 'billing_cycle_anchor': billing_cycle_anchor, 'billing_schedule': billing_schedule, 'tier': CustomerPlan.STANDARD} - CustomerPlan.objects.create( + plan = CustomerPlan.objects.create( customer=customer, + # Deprecated, remove + licenses=-1, billed_through=billing_cycle_anchor, next_billing_date=next_billing_date, **plan_params) + LicenseLedger.objects.create( + plan=plan, + is_renewal=True, + event_time=billing_cycle_anchor, + licenses=billed_licenses, + licenses_at_next_renewal=billed_licenses) RealmAuditLog.objects.create( realm=realm, acting_user=user, event_time=billing_cycle_anchor, event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED, - # TODO: add tests for licenses - # Only 'licenses' is guaranteed to be useful to automated tools. The other extra_data - # fields can change in the future and are only meant to assist manual debugging. extra_data=ujson.dumps(plan_params)) description = 'Zulip Standard' if customer.default_discount is not None: # nocoverage: TODO @@ -336,7 +356,9 @@ def estimate_annual_recurring_revenue_by_realm() -> Dict[str, int]: # nocoverag annual_revenue = {} for plan in CustomerPlan.objects.filter( status=CustomerPlan.ACTIVE).select_related('customer__realm'): - renewal_cents = renewal_amount(plan) + # TODO: figure out what to do for plans that don't automatically + # renew, but which probably will renew + renewal_cents = renewal_amount(plan) or 0 if plan.billing_schedule == CustomerPlan.MONTHLY: renewal_cents *= 12 # TODO: Decimal stuff diff --git a/corporate/migrations/0004_licenseledger.py b/corporate/migrations/0004_licenseledger.py new file mode 100644 index 0000000000..4df372682d --- /dev/null +++ b/corporate/migrations/0004_licenseledger.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.18 on 2019-01-19 05:01 +from __future__ import unicode_literals + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('corporate', '0003_customerplan'), + ] + + operations = [ + migrations.CreateModel( + name='LicenseLedger', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('is_renewal', models.BooleanField(default=False)), + ('event_time', models.DateTimeField()), + ('licenses', models.IntegerField()), + ('licenses_at_next_renewal', models.IntegerField(null=True)), + ('plan', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='corporate.CustomerPlan')), + ], + ), + ] diff --git a/corporate/models.py b/corporate/models.py index 888e19f3ce..169006659a 100644 --- a/corporate/models.py +++ b/corporate/models.py @@ -19,6 +19,7 @@ class Customer(models.Model): class CustomerPlan(models.Model): customer = models.ForeignKey(Customer, on_delete=CASCADE) # type: Customer + # Deprecated .. delete once everyone is migrated to new billing system licenses = models.IntegerField() # type: int automanage_licenses = models.BooleanField(default=False) # type: bool charge_automatically = models.BooleanField(default=False) # type: bool @@ -57,6 +58,17 @@ class CustomerPlan(models.Model): def get_active_plan(customer: Customer) -> Optional[CustomerPlan]: return CustomerPlan.objects.filter(customer=customer, status=CustomerPlan.ACTIVE).first() +class LicenseLedger(models.Model): + plan = models.ForeignKey(CustomerPlan, on_delete=CASCADE) # type: CustomerPlan + # Also True for the initial upgrade. + is_renewal = models.BooleanField(default=False) # type: bool + event_time = models.DateTimeField() # type: datetime.datetime + licenses = models.IntegerField() # type: int + # None means the plan does not automatically renew. + # 0 means the plan has been explicitly downgraded. + # This cannot be None if plan.automanage_licenses. + licenses_at_next_renewal = models.IntegerField(null=True) # type: Optional[int] + # Everything below here is legacy class Plan(models.Model): diff --git a/corporate/tests/test_stripe.py b/corporate/tests/test_stripe.py index fc198939a7..48c80e01a7 100644 --- a/corporate/tests/test_stripe.py +++ b/corporate/tests/test_stripe.py @@ -1,6 +1,6 @@ -from datetime import datetime +from datetime import datetime, timedelta from decimal import Decimal -from functools import wraps +from functools import partial, wraps from mock import Mock, patch import operator import os @@ -28,8 +28,9 @@ from corporate.lib.stripe import catch_stripe_errors, attach_discount_to_realm, BillingError, StripeCardError, StripeConnectionError, stripe_get_customer, \ DEFAULT_INVOICE_DAYS_UNTIL_DUE, MIN_INVOICED_LICENSES, do_create_customer, \ add_months, next_month, next_renewal_date, renewal_amount, \ - compute_plan_parameters, update_or_create_stripe_customer -from corporate.models import Customer, CustomerPlan + compute_plan_parameters, update_or_create_stripe_customer, \ + process_initial_upgrade, add_plan_renewal_to_license_ledger_if_needed +from corporate.models import Customer, CustomerPlan, LicenseLedger from corporate.views import payment_method_string import corporate.urls @@ -367,15 +368,17 @@ class StripeTest(ZulipTestCase): for key, value in line_item_params.items(): self.assertEqual(stripe_line_items[1].get(key), value) - # Check that we correctly populated Customer and CustomerPlan in Zulip - customer = Customer.objects.filter(stripe_customer_id=stripe_customer.id, - realm=user.realm).first() - self.assertTrue(CustomerPlan.objects.filter( - customer=customer, licenses=self.seat_count, automanage_licenses=True, + # Check that we correctly populated Customer, CustomerPlan, and LicenseLedger in Zulip + customer = Customer.objects.get(stripe_customer_id=stripe_customer.id, realm=user.realm) + plan = CustomerPlan.objects.get( + customer=customer, automanage_licenses=True, price_per_license=8000, fixed_price=None, discount=None, billing_cycle_anchor=self.now, billing_schedule=CustomerPlan.ANNUAL, billed_through=self.now, next_billing_date=self.next_month, tier=CustomerPlan.STANDARD, - status=CustomerPlan.ACTIVE).exists()) + status=CustomerPlan.ACTIVE) + LicenseLedger.objects.get( + plan=plan, is_renewal=True, event_time=self.now, licenses=self.seat_count, + licenses_at_next_renewal=self.seat_count) # Check RealmAuditLog audit_log_entries = list(RealmAuditLog.objects.filter(acting_user=user) .values_list('event_type', 'event_time').order_by('id')) @@ -388,7 +391,7 @@ class StripeTest(ZulipTestCase): ]) self.assertEqual(ujson.loads(RealmAuditLog.objects.filter( event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED).values_list( - 'extra_data', flat=True).first())['licenses'], self.seat_count) + 'extra_data', flat=True).first())['automanage_licenses'], True) # Check that we correctly updated Realm realm = get_realm("zulip") self.assertEqual(realm.plan_type, Realm.STANDARD) @@ -449,15 +452,16 @@ class StripeTest(ZulipTestCase): for key, value in line_item_params.items(): self.assertEqual(stripe_line_items[0].get(key), value) - # Check that we correctly populated Customer and CustomerPlan in Zulip - customer = Customer.objects.filter(stripe_customer_id=stripe_customer.id, - realm=user.realm).first() - self.assertTrue(CustomerPlan.objects.filter( - customer=customer, licenses=123, automanage_licenses=False, charge_automatically=False, + # Check that we correctly populated Customer, CustomerPlan and LicenseLedger in Zulip + customer = Customer.objects.get(stripe_customer_id=stripe_customer.id, realm=user.realm) + plan = CustomerPlan.objects.get( + customer=customer, automanage_licenses=False, charge_automatically=False, price_per_license=8000, fixed_price=None, discount=None, billing_cycle_anchor=self.now, billing_schedule=CustomerPlan.ANNUAL, billed_through=self.now, next_billing_date=self.next_year, tier=CustomerPlan.STANDARD, - status=CustomerPlan.ACTIVE).exists()) + status=CustomerPlan.ACTIVE) + LicenseLedger.objects.get( + plan=plan, is_renewal=True, event_time=self.now, licenses=123, licenses_at_next_renewal=123) # Check RealmAuditLog audit_log_entries = list(RealmAuditLog.objects.filter(acting_user=user) .values_list('event_type', 'event_time').order_by('id')) @@ -469,7 +473,7 @@ class StripeTest(ZulipTestCase): ]) self.assertEqual(ujson.loads(RealmAuditLog.objects.filter( event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED).values_list( - 'extra_data', flat=True).first())['licenses'], 123) + 'extra_data', flat=True).first())['automanage_licenses'], False) # Check that we correctly updated Realm realm = get_realm("zulip") self.assertEqual(realm.plan_type, Realm.STANDARD) @@ -524,11 +528,9 @@ class StripeTest(ZulipTestCase): stripe_invoice = [invoice for invoice in stripe.Invoice.list(customer=stripe_customer_id)][0] self.assertEqual([8000 * new_seat_count, -8000 * self.seat_count], [item.amount for item in stripe_invoice.lines]) - # Check CustomerPlan and RealmAuditLog have the new amount - self.assertEqual(CustomerPlan.objects.first().licenses, new_seat_count) - self.assertEqual(ujson.loads(RealmAuditLog.objects.filter( - event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED).values_list( - 'extra_data', flat=True).first())['licenses'], new_seat_count) + # Check LicenseLedger has the new amount + self.assertEqual(LicenseLedger.objects.first().licenses, new_seat_count) + self.assertEqual(LicenseLedger.objects.first().licenses_at_next_renewal, new_seat_count) @mock_stripe() def test_upgrade_where_first_card_fails(self, *mocks: Mock) -> None: @@ -571,8 +573,11 @@ class StripeTest(ZulipTestCase): # It's impossible to create two Customers, but check that we didn't # change stripe_customer_id self.assertEqual(customer.stripe_customer_id, stripe_customer_id) - # Check that we successfully added a CustomerPlan - self.assertTrue(CustomerPlan.objects.filter(customer=customer, licenses=23).exists()) + # Check that we successfully added a CustomerPlan, and have the right number of licenses + plan = CustomerPlan.objects.get(customer=customer) + ledger_entry = LicenseLedger.objects.get(plan=plan) + self.assertEqual(ledger_entry.licenses, 23) + self.assertEqual(ledger_entry.licenses_at_next_renewal, 23) # Check the Charges and Invoices in Stripe self.assertEqual(8000 * 23, [charge for charge in stripe.Charge.list(customer=stripe_customer_id)][0].amount) @@ -912,3 +917,50 @@ class BillingHelpersTest(ZulipTestCase): customer = update_or_create_stripe_customer(self.example_user('hamlet'), None) mocked3.assert_not_called() self.assertTrue(isinstance(customer, Customer)) + +# todo: Create a StripeTestCase, similar to AnalyticsTestCase +class LicenseLedgerTest(ZulipTestCase): + def setUp(self) -> None: + self.seat_count = get_seat_count(get_realm('zulip')) + self.now = datetime(2012, 1, 2, 3, 4, 5).replace(tzinfo=timezone_utc) + self.next_month = datetime(2012, 2, 2, 3, 4, 5).replace(tzinfo=timezone_utc) + self.next_year = datetime(2013, 1, 2, 3, 4, 5).replace(tzinfo=timezone_utc) + + # Upgrade without talking to Stripe + def local_upgrade(self, *args: Any) -> None: + class StripeMock(object): + def __init__(self, depth: int=1): + self.id = 'id' + self.created = '1000' + self.last4 = '4242' + if depth == 1: + self.source = StripeMock(depth=2) + + def upgrade_func(*args: Any) -> Any: + return process_initial_upgrade(self.example_user('hamlet'), *args[:4]) + + for mocked_function_name in MOCKED_STRIPE_FUNCTION_NAMES: + upgrade_func = patch(mocked_function_name, return_value=StripeMock())(upgrade_func) + upgrade_func(*args) + + def test_add_plan_renewal_if_needed(self) -> None: + with patch('corporate.lib.stripe.timezone_now', return_value=self.now): + self.local_upgrade(self.seat_count, True, CustomerPlan.ANNUAL, 'token') + self.assertEqual(LicenseLedger.objects.count(), 1) + plan = CustomerPlan.objects.get() + # Plan hasn't renewed yet + add_plan_renewal_to_license_ledger_if_needed(plan, self.next_year) + self.assertEqual(LicenseLedger.objects.count(), 1) + # Plan needs to renew + # TODO: do_deactivate_user for a user, so that licenses_at_next_renewal != licenses + ledger_entry = add_plan_renewal_to_license_ledger_if_needed( + plan, self.next_year + timedelta(seconds=1)) + self.assertEqual(LicenseLedger.objects.count(), 2) + ledger_params = { + 'plan': plan, 'is_renewal': True, 'event_time': self.next_year, + 'licenses': self.seat_count, 'licenses_at_next_renewal': self.seat_count} + for key, value in ledger_params.items(): + self.assertEqual(getattr(ledger_entry, key), value) + # Plan needs to renew, but we already added the plan_renewal ledger entry + add_plan_renewal_to_license_ledger_if_needed(plan, self.next_year + timedelta(seconds=1)) + self.assertEqual(LicenseLedger.objects.count(), 2) diff --git a/corporate/views.py b/corporate/views.py index 1433147115..1728d56c77 100644 --- a/corporate/views.py +++ b/corporate/views.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Optional, Tuple, cast from django.core import signing from django.http import HttpRequest, HttpResponse, HttpResponseRedirect -from django.utils import timezone +from django.utils.timezone import now as timezone_now from django.utils.translation import ugettext as _, ugettext as err_ from django.shortcuts import redirect, render from django.urls import reverse @@ -22,8 +22,10 @@ from corporate.lib.stripe import STRIPE_PUBLISHABLE_KEY, \ process_initial_upgrade, sign_string, \ unsign_string, BillingError, process_downgrade, do_replace_payment_source, \ MIN_INVOICED_LICENSES, DEFAULT_INVOICE_DAYS_UNTIL_DUE, \ - next_renewal_date, renewal_amount -from corporate.models import Customer, CustomerPlan, get_active_plan + next_renewal_date, renewal_amount, \ + add_plan_renewal_to_license_ledger_if_needed +from corporate.models import Customer, CustomerPlan, LicenseLedger, \ + get_active_plan billing_logger = logging.getLogger('corporate.stripe') @@ -166,10 +168,15 @@ def billing_home(request: HttpRequest) -> HttpResponse: CustomerPlan.STANDARD: 'Zulip Standard', CustomerPlan.PLUS: 'Zulip Plus', }[plan.tier] - licenses = plan.licenses + last_ledger_entry = add_plan_renewal_to_license_ledger_if_needed(plan, timezone_now()) + # TODO: this is not really correct; need to give the situation as of the "fillstate" + licenses = last_ledger_entry.licenses # Should do this in javascript, using the user's timezone renewal_date = '{dt:%B} {dt.day}, {dt.year}'.format(dt=next_renewal_date(plan)) renewal_cents = renewal_amount(plan) + # TODO: this is the case where the plan doesn't automatically renew + if renewal_cents is None: # nocoverage + renewal_cents = 0 charge_automatically = plan.charge_automatically if charge_automatically: payment_method = payment_method_string(stripe_customer)