billing: Move discount to local Customer object.

A lot of the seemingly unrelated test fixture changes are because we're
removing a query to stripe in the upgrade path, in cases when the user's
realm has an existing Customer object.
This commit is contained in:
Rishi Gupta 2018-12-12 10:41:03 -08:00
parent 8ec91fc42d
commit 7b5d15d254
30 changed files with 45 additions and 40 deletions

View File

@ -1,4 +1,5 @@
import datetime import datetime
from decimal import Decimal
from functools import wraps from functools import wraps
import logging import logging
import os import os
@ -125,17 +126,14 @@ def estimate_customer_arr(stripe_customer: stripe.Customer) -> int: # nocoverag
estimated_arr = stripe_subscription.plan.amount * stripe_subscription.quantity / 100. estimated_arr = stripe_subscription.plan.amount * stripe_subscription.quantity / 100.
if stripe_subscription.plan.interval == 'month': if stripe_subscription.plan.interval == 'month':
estimated_arr *= 12 estimated_arr *= 12
if stripe_customer.discount is not None: discount = Customer.objects.get(stripe_customer_id=stripe_customer.id).default_discount
estimated_arr *= 1 - stripe_customer.discount.coupon.percent_off/100. if discount is not None:
estimated_arr *= 1 - discount/100.
return int(estimated_arr) return int(estimated_arr)
@catch_stripe_errors @catch_stripe_errors
def do_create_customer(user: UserProfile, stripe_token: Optional[str]=None, def do_create_customer(user: UserProfile, stripe_token: Optional[str]=None) -> stripe.Customer:
coupon: Optional[Coupon]=None) -> stripe.Customer:
realm = user.realm realm = user.realm
stripe_coupon_id = None
if coupon is not None:
stripe_coupon_id = coupon.stripe_coupon_id
# We could do a better job of handling race conditions here, but if two # We could do a better job of handling race conditions here, but if two
# people from a realm try to upgrade at exactly the same time, the main # people from a realm try to upgrade at exactly the same time, the main
# bad thing that will happen is that we will create an extra stripe # bad thing that will happen is that we will create an extra stripe
@ -144,8 +142,7 @@ def do_create_customer(user: UserProfile, stripe_token: Optional[str]=None,
description="%s (%s)" % (realm.string_id, realm.name), description="%s (%s)" % (realm.string_id, realm.name),
email=user.email, email=user.email,
metadata={'realm_id': realm.id, 'realm_str': realm.string_id}, metadata={'realm_id': realm.id, 'realm_str': realm.string_id},
source=stripe_token, source=stripe_token)
coupon=stripe_coupon_id)
event_time = timestamp_to_datetime(stripe_customer.created) event_time = timestamp_to_datetime(stripe_customer.created)
with transaction.atomic(): with transaction.atomic():
RealmAuditLog.objects.create( RealmAuditLog.objects.create(
@ -173,12 +170,6 @@ def do_replace_payment_source(user: UserProfile, stripe_token: str) -> stripe.Cu
event_time=timezone_now()) event_time=timezone_now())
return updated_stripe_customer return updated_stripe_customer
@catch_stripe_errors
def do_replace_coupon(user: UserProfile, coupon: Coupon) -> stripe.Customer:
stripe_customer = stripe_get_customer(Customer.objects.get(realm=user.realm).stripe_customer_id)
stripe_customer.coupon = coupon.stripe_coupon_id
return stripe.Customer.save(stripe_customer)
@catch_stripe_errors @catch_stripe_errors
def do_subscribe_customer_to_plan(user: UserProfile, stripe_customer: stripe.Customer, stripe_plan_id: str, def do_subscribe_customer_to_plan(user: UserProfile, stripe_customer: stripe.Customer, stripe_plan_id: str,
seat_count: int, tax_percent: float, charge_automatically: bool) -> None: seat_count: int, tax_percent: float, charge_automatically: bool) -> None:
@ -258,13 +249,13 @@ def process_initial_upgrade(user: UserProfile, plan: Plan, seat_count: int,
charge_automatically=(stripe_token is not None)) charge_automatically=(stripe_token is not None))
do_change_plan_type(user.realm, Realm.STANDARD) do_change_plan_type(user.realm, Realm.STANDARD)
def attach_discount_to_realm(user: UserProfile, percent_off: int) -> None: def attach_discount_to_realm(user: UserProfile, discount: Decimal) -> None:
coupon = Coupon.objects.get(percent_off=percent_off)
customer = Customer.objects.filter(realm=user.realm).first() customer = Customer.objects.filter(realm=user.realm).first()
if customer is None: if customer is None:
do_create_customer(user, coupon=coupon) do_create_customer(user)
else: customer = Customer.objects.filter(realm=user.realm).first()
do_replace_coupon(user, coupon) customer.default_discount = discount
customer.save()
def process_downgrade(user: UserProfile) -> None: # nocoverage def process_downgrade(user: UserProfile) -> None: # nocoverage
pass pass

View File

@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.16 on 2018-12-12 20:19
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('corporate', '0001_initial'),
]
operations = [
migrations.AddField(
model_name='customer',
name='default_discount',
field=models.DecimalField(decimal_places=4, max_digits=7, null=True),
),
]

View File

@ -1,4 +1,6 @@
import datetime import datetime
from decimal import Decimal
from typing import Optional
from django.db import models from django.db import models
@ -10,6 +12,7 @@ class Customer(models.Model):
# Becomes True the first time a payment successfully goes through, and never # Becomes True the first time a payment successfully goes through, and never
# goes back to being False # goes back to being False
has_billing_relationship = models.BooleanField(default=False) # type: bool has_billing_relationship = models.BooleanField(default=False) # type: bool
default_discount = models.DecimalField(decimal_places=4, max_digits=7, null=True) # type: Optional[Decimal]
def __str__(self) -> str: def __str__(self) -> str:
return "<Customer %s %s>" % (self.realm, self.stripe_customer_id) return "<Customer %s %s>" % (self.realm, self.stripe_customer_id)
@ -22,6 +25,8 @@ class Plan(models.Model):
stripe_plan_id = models.CharField(max_length=255, unique=True) # type: str stripe_plan_id = models.CharField(max_length=255, unique=True) # type: str
# Everything below here is legacy
class Coupon(models.Model): class Coupon(models.Model):
percent_off = models.SmallIntegerField(unique=True) # type: int percent_off = models.SmallIntegerField(unique=True) # type: int
stripe_coupon_id = models.CharField(max_length=255, unique=True) # type: str stripe_coupon_id = models.CharField(max_length=255, unique=True) # type: str
@ -29,7 +34,6 @@ class Coupon(models.Model):
def __str__(self) -> str: def __str__(self) -> str:
return '<Coupon: %s %s %s>' % (self.percent_off, self.stripe_coupon_id, self.id) return '<Coupon: %s %s %s>' % (self.percent_off, self.stripe_coupon_id, self.id)
# legacy
class BillingProcessor(models.Model): class BillingProcessor(models.Model):
log_row = models.ForeignKey(RealmAuditLog, on_delete=models.CASCADE) # RealmAuditLog log_row = models.ForeignKey(RealmAuditLog, on_delete=models.CASCADE) # RealmAuditLog
# Exactly one processor, the global processor, has realm=None. # Exactly one processor, the global processor, has realm=None.

View File

@ -1,4 +1,5 @@
import datetime import datetime
from decimal import Decimal
from functools import wraps from functools import wraps
from mock import Mock, patch from mock import Mock, patch
import operator import operator
@ -634,27 +635,18 @@ class StripeTest(ZulipTestCase):
def test_attach_discount_to_realm(self, *mocks: Mock) -> None: def test_attach_discount_to_realm(self, *mocks: Mock) -> None:
# Attach discount before Stripe customer exists # Attach discount before Stripe customer exists
user = self.example_user('hamlet') user = self.example_user('hamlet')
attach_discount_to_realm(user, 85) attach_discount_to_realm(user, Decimal(85))
self.login(user.email) self.login(user.email)
# Check that the discount appears in page_params # Check that the discount appears in page_params
self.assert_in_success_response(['85'], self.client_get("/upgrade/")) self.assert_in_success_response(['85'], self.client_get("/upgrade/"))
self.upgrade()
stripe_customer = stripe_get_customer(Customer.objects.get(realm=user.realm).stripe_customer_id)
assert(stripe_customer.discount is not None) # for mypy
self.assertEqual(stripe_customer.discount.coupon.percent_off, 85.0)
# Check that the customer was charged the discounted amount # Check that the customer was charged the discounted amount
charges = stripe.Charge.list(customer=stripe_customer.id) # TODO
for charge in charges:
self.assertEqual(charge.amount, get_seat_count(user.realm) * 80 * 15)
# Check upcoming invoice reflects the discount # Check upcoming invoice reflects the discount
upcoming_invoice = stripe.Invoice.upcoming(customer=stripe_customer.id) # TODO
self.assertEqual(upcoming_invoice.amount_due, get_seat_count(user.realm) * 80 * 15)
# Attach discount to existing Stripe customer # Attach discount to existing Stripe customer
attach_discount_to_realm(user, 25) attach_discount_to_realm(user, Decimal(25))
# Check upcoming invoice reflects the new discount # Check upcoming invoice reflects the new discount
upcoming_invoice = stripe.Invoice.upcoming(customer=stripe_customer.id) # TODO
self.assertEqual(upcoming_invoice.amount_due, get_seat_count(user.realm) * 80 * 75)
@mock_stripe() @mock_stripe()
def test_replace_payment_source(self, *mocks: Mock) -> None: def test_replace_payment_source(self, *mocks: Mock) -> None:

View File

@ -106,10 +106,8 @@ def initial_upgrade(request: HttpRequest) -> HttpResponse:
return HttpResponseRedirect(reverse('corporate.views.billing_home')) return HttpResponseRedirect(reverse('corporate.views.billing_home'))
percent_off = 0 percent_off = 0
if customer is not None: if customer is not None and customer.default_discount is not None:
stripe_customer = stripe_get_customer(customer.stripe_customer_id) percent_off = customer.default_discount
if stripe_customer.discount is not None:
percent_off = stripe_customer.discount.coupon.percent_off
seat_count = get_seat_count(user.realm) seat_count = get_seat_count(user.realm)
signed_seat_count, salt = sign_string(str(seat_count)) signed_seat_count, salt = sign_string(str(seat_count))
@ -130,7 +128,7 @@ def initial_upgrade(request: HttpRequest) -> HttpResponse:
'nickname_monthly': Plan.CLOUD_MONTHLY, 'nickname_monthly': Plan.CLOUD_MONTHLY,
'annual_price': 8000, 'annual_price': 8000,
'monthly_price': 800, 'monthly_price': 800,
'percent_off': percent_off, 'percent_off': float(percent_off),
}), }),
} # type: Dict[str, Any] } # type: Dict[str, Any]
response = render(request, 'corporate/upgrade.html', context=context) response = render(request, 'corporate/upgrade.html', context=context)