billing: Make attach discount update the current price per license.

This commit is contained in:
Vishnu KS 2020-12-04 19:44:59 +05:30 committed by Tim Abbott
parent 480288643c
commit bd2642a7b8
23 changed files with 25 additions and 8 deletions

View File

@ -589,6 +589,11 @@ def invoice_plans_as_needed(event_time: datetime=timezone_now()) -> None:
def attach_discount_to_realm(realm: Realm, discount: Decimal) -> None:
Customer.objects.update_or_create(realm=realm, defaults={'default_discount': discount})
plan = get_current_plan_by_realm(realm)
if plan is not None:
plan.price_per_license = get_price_per_license(plan.tier, plan.billing_schedule, discount)
plan.discount = discount
plan.save(update_fields=["price_per_license", "discount"])
def update_sponsorship_status(realm: Realm, sponsorship_pending: bool) -> None:
customer, _ = Customer.objects.get_or_create(realm=realm)

View File

@ -1244,10 +1244,10 @@ class StripeTest(StripeTestCase):
self.assert_in_success_response(['85'], self.client_get("/upgrade/"))
# Check that the customer was charged the discounted amount
self.upgrade()
stripe_customer_id = Customer.objects.values_list('stripe_customer_id', flat=True).first()
[charge] = stripe.Charge.list(customer=stripe_customer_id)
customer = Customer.objects.first()
[charge] = stripe.Charge.list(customer=customer.stripe_customer_id)
self.assertEqual(1200 * self.seat_count, charge.amount)
[invoice] = stripe.Invoice.list(customer=stripe_customer_id)
[invoice] = stripe.Invoice.list(customer=customer.stripe_customer_id)
self.assertEqual([1200 * self.seat_count, -1200 * self.seat_count],
[item.amount for item in invoice.lines])
# Check CustomerPlan reflects the discount
@ -1257,14 +1257,26 @@ class StripeTest(StripeTestCase):
plan.status = CustomerPlan.ENDED
plan.save(update_fields=['status'])
attach_discount_to_realm(user.realm, Decimal(25))
with patch('corporate.lib.stripe.timezone_now', return_value=self.now):
process_initial_upgrade(user, self.seat_count, True, CustomerPlan.ANNUAL, stripe_create_token().id)
[charge0, charge1] = stripe.Charge.list(customer=stripe_customer_id)
self.assertEqual(6000 * self.seat_count, charge0.amount)
[invoice0, invoice1] = stripe.Invoice.list(customer=stripe_customer_id)
[charge, _] = stripe.Charge.list(customer=customer.stripe_customer_id)
self.assertEqual(6000 * self.seat_count, charge.amount)
[invoice, _] = stripe.Invoice.list(customer=customer.stripe_customer_id)
self.assertEqual([6000 * self.seat_count, -6000 * self.seat_count],
[item.amount for item in invoice0.lines])
[item.amount for item in invoice.lines])
plan = CustomerPlan.objects.get(price_per_license=6000, discount=Decimal(25))
attach_discount_to_realm(user.realm, Decimal(50))
plan.refresh_from_db()
self.assertEqual(plan.price_per_license, 4000)
self.assertEqual(plan.discount, 50)
customer.refresh_from_db()
self.assertEqual(customer.default_discount, 50)
invoice_plans_as_needed(self.next_year + timedelta(days=10))
[invoice, _, _] = stripe.Invoice.list(customer=customer.stripe_customer_id)
self.assertEqual([4000 * self.seat_count],
[item.amount for item in invoice.lines])
def test_get_discount_for_realm(self) -> None:
user = self.example_user('hamlet')
self.assertEqual(get_discount_for_realm(user.realm), None)