billing: Create get_customer_by_realm helper function.

This commit is contained in:
Vishnu KS 2020-03-23 18:05:04 +05:30 committed by Tim Abbott
parent f8ddab58ba
commit 9a2c64f3f4
4 changed files with 31 additions and 13 deletions

View File

@ -19,7 +19,7 @@ from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime
from zerver.lib.utils import generate_random_token
from zerver.models import Realm, UserProfile, RealmAuditLog
from corporate.models import Customer, CustomerPlan, LicenseLedger, \
get_current_plan
get_current_plan, get_customer_by_realm
from zproject.config import get_secret
STRIPE_PUBLISHABLE_KEY = get_secret('stripe_publishable_key')
@ -199,7 +199,10 @@ def do_create_stripe_customer(user: UserProfile, stripe_token: Optional[str]=Non
@catch_stripe_errors
def do_replace_payment_source(user: UserProfile, stripe_token: str,
pay_invoices: bool=False) -> stripe.Customer:
stripe_customer = stripe_get_customer(Customer.objects.get(realm=user.realm).stripe_customer_id)
customer = get_customer_by_realm(user.realm)
assert(customer is not None) # for mypy
stripe_customer = stripe_get_customer(customer.stripe_customer_id)
stripe_customer.source = stripe_token
# Deletes existing card: https://stripe.com/docs/api#update_customer-source
updated_stripe_customer = stripe.Customer.save(stripe_customer)
@ -239,7 +242,7 @@ def make_end_of_cycle_updates_if_needed(plan: CustomerPlan,
# API call if there's nothing to update
def update_or_create_stripe_customer(user: UserProfile, stripe_token: Optional[str]=None) -> Customer:
realm = user.realm
customer = Customer.objects.filter(realm=realm).first()
customer = get_customer_by_realm(realm)
if customer is None or customer.stripe_customer_id is None:
return do_create_stripe_customer(user, stripe_token=stripe_token)
if stripe_token is not None:
@ -378,7 +381,7 @@ def update_license_ledger_for_automanaged_plan(realm: Realm, plan: CustomerPlan,
licenses_at_next_renewal=licenses_at_next_renewal)
def update_license_ledger_if_needed(realm: Realm, event_time: datetime) -> None:
customer = Customer.objects.filter(realm=realm).first()
customer = get_customer_by_realm(realm)
if customer is None:
return
plan = get_current_plan(customer)
@ -467,7 +470,7 @@ def attach_discount_to_realm(realm: Realm, discount: Decimal) -> None:
Customer.objects.update_or_create(realm=realm, defaults={'default_discount': discount})
def get_discount_for_realm(realm: Realm) -> Optional[Decimal]:
customer = Customer.objects.filter(realm=realm).first()
customer = get_customer_by_realm(realm)
if customer is not None:
return customer.default_discount
return None
@ -500,7 +503,7 @@ def estimate_annual_recurring_revenue_by_realm() -> Dict[str, int]: # nocoverag
# During realm deactivation we instantly downgrade the plan to Limited.
# Extra users added in the final month are not charged.
def downgrade_for_realm_deactivation(realm: Realm) -> None:
customer = Customer.objects.filter(realm=realm).first()
customer = get_customer_by_realm(realm)
if customer is not None:
plan = get_current_plan(customer)
if plan:

View File

@ -16,6 +16,9 @@ class Customer(models.Model):
def __str__(self) -> str:
return "<Customer %s %s>" % (self.realm, self.stripe_customer_id)
def get_customer_by_realm(realm: Realm) -> Optional[Customer]:
return Customer.objects.filter(realm=realm).first()
class CustomerPlan(models.Model):
customer = models.ForeignKey(Customer, on_delete=CASCADE) # type: Customer
automanage_licenses = models.BooleanField(default=False) # type: bool

View File

@ -34,7 +34,8 @@ from corporate.lib.stripe import catch_stripe_errors, attach_discount_to_realm,
process_initial_upgrade, make_end_of_cycle_updates_if_needed, \
update_license_ledger_if_needed, update_license_ledger_for_automanaged_plan, \
invoice_plan, invoice_plans_as_needed, get_discount_for_realm
from corporate.models import Customer, CustomerPlan, LicenseLedger
from corporate.models import Customer, CustomerPlan, LicenseLedger, \
get_customer_by_realm
CallableT = TypeVar('CallableT', bound=Callable[..., Any])
@ -1147,6 +1148,14 @@ class BillingHelpersTest(ZulipTestCase):
mocked3.assert_not_called()
self.assertTrue(isinstance(customer, Customer))
def test_get_customer_by_realm(self) -> None:
realm = get_realm('zulip')
self.assertEqual(get_customer_by_realm(realm), None)
customer = Customer.objects.create(realm=realm, stripe_customer_id='cus_12345')
self.assertEqual(get_customer_by_realm(realm), customer)
class LicenseLedgerTest(StripeTestCase):
def test_add_plan_renewal_if_needed(self) -> None:
with patch('corporate.lib.stripe.timezone_now', return_value=self.now):

View File

@ -1,4 +1,5 @@
import logging
from decimal import Decimal
import stripe
from typing import Any, Dict, cast, Optional, Union
@ -22,8 +23,7 @@ from corporate.lib.stripe import STRIPE_PUBLISHABLE_KEY, \
MIN_INVOICED_LICENSES, DEFAULT_INVOICE_DAYS_UNTIL_DUE, \
start_of_next_billing_cycle, renewal_amount, \
make_end_of_cycle_updates_if_needed
from corporate.models import Customer, CustomerPlan, \
get_current_plan
from corporate.models import CustomerPlan, get_current_plan, get_customer_by_realm
billing_logger = logging.getLogger('corporate.stripe')
@ -117,11 +117,11 @@ def initial_upgrade(request: HttpRequest) -> HttpResponse:
return render(request, "404.html")
user = request.user
customer = Customer.objects.filter(realm=user.realm).first()
customer = get_customer_by_realm(user.realm)
if customer is not None and get_current_plan(customer) is not None:
return HttpResponseRedirect(reverse('corporate.views.billing_home'))
percent_off = 0
percent_off = Decimal(0)
if customer is not None and customer.default_discount is not None:
percent_off = customer.default_discount
@ -149,7 +149,7 @@ def initial_upgrade(request: HttpRequest) -> HttpResponse:
@zulip_login_required
def billing_home(request: HttpRequest) -> HttpResponse:
user = request.user
customer = Customer.objects.filter(realm=user.realm).first()
customer = get_customer_by_realm(user.realm)
if customer is None:
return HttpResponseRedirect(reverse('corporate.views.initial_upgrade'))
if not CustomerPlan.objects.filter(customer=customer).exists():
@ -206,7 +206,10 @@ def billing_home(request: HttpRequest) -> HttpResponse:
def change_plan_at_end_of_cycle(request: HttpRequest, user: UserProfile,
status: int=REQ("status", validator=check_int)) -> HttpResponse:
assert(status in [CustomerPlan.ACTIVE, CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE])
plan = get_current_plan(Customer.objects.get(realm=user.realm))
customer = get_customer_by_realm(user.realm)
assert(customer is not None) # for mypy
plan = get_current_plan(customer)
assert(plan is not None) # for mypy
do_change_plan_status(plan, status)
return json_success()