billing: Rename get_current_plan to get_current_plan_by_customer.

Also add tests.
This commit is contained in:
Vishnu KS 2020-03-24 18:44:03 +05:30 committed by Tim Abbott
parent 9a2c64f3f4
commit 83da23c0d4
4 changed files with 36 additions and 10 deletions

View File

@ -19,7 +19,7 @@ from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime
from zerver.lib.utils import generate_random_token from zerver.lib.utils import generate_random_token
from zerver.models import Realm, UserProfile, RealmAuditLog from zerver.models import Realm, UserProfile, RealmAuditLog
from corporate.models import Customer, CustomerPlan, LicenseLedger, \ from corporate.models import Customer, CustomerPlan, LicenseLedger, \
get_current_plan, get_customer_by_realm get_current_plan_by_customer, get_customer_by_realm
from zproject.config import get_secret from zproject.config import get_secret
STRIPE_PUBLISHABLE_KEY = get_secret('stripe_publishable_key') STRIPE_PUBLISHABLE_KEY = get_secret('stripe_publishable_key')
@ -279,7 +279,7 @@ def process_initial_upgrade(user: UserProfile, licenses: int, automanage_license
billing_schedule: int, stripe_token: Optional[str]) -> None: billing_schedule: int, stripe_token: Optional[str]) -> None:
realm = user.realm realm = user.realm
customer = update_or_create_stripe_customer(user, stripe_token=stripe_token) customer = update_or_create_stripe_customer(user, stripe_token=stripe_token)
if get_current_plan(customer) is not None: if get_current_plan_by_customer(customer) is not None:
# Unlikely race condition from two people upgrading (clicking "Make payment") # Unlikely race condition from two people upgrading (clicking "Make payment")
# at exactly the same time. Doesn't fully resolve the race condition, but having # at exactly the same time. Doesn't fully resolve the race condition, but having
# a check here reduces the likelihood. # a check here reduces the likelihood.
@ -384,7 +384,7 @@ def update_license_ledger_if_needed(realm: Realm, event_time: datetime) -> None:
customer = get_customer_by_realm(realm) customer = get_customer_by_realm(realm)
if customer is None: if customer is None:
return return
plan = get_current_plan(customer) plan = get_current_plan_by_customer(customer)
if plan is None: if plan is None:
return return
if not plan.automanage_licenses: if not plan.automanage_licenses:
@ -505,7 +505,7 @@ def estimate_annual_recurring_revenue_by_realm() -> Dict[str, int]: # nocoverag
def downgrade_for_realm_deactivation(realm: Realm) -> None: def downgrade_for_realm_deactivation(realm: Realm) -> None:
customer = get_customer_by_realm(realm) customer = get_customer_by_realm(realm)
if customer is not None: if customer is not None:
plan = get_current_plan(customer) plan = get_current_plan_by_customer(customer)
if plan: if plan:
process_downgrade(plan) process_downgrade(plan)
plan.invoiced_through = LicenseLedger.objects.filter(plan=plan).order_by('id').last() plan.invoiced_through = LicenseLedger.objects.filter(plan=plan).order_by('id').last()

View File

@ -61,7 +61,7 @@ class CustomerPlan(models.Model):
# TODO maybe override setattr to ensure billing_cycle_anchor, etc are immutable # TODO maybe override setattr to ensure billing_cycle_anchor, etc are immutable
def get_current_plan(customer: Customer) -> Optional[CustomerPlan]: def get_current_plan_by_customer(customer: Customer) -> Optional[CustomerPlan]:
return CustomerPlan.objects.filter( return CustomerPlan.objects.filter(
customer=customer, status__lt=CustomerPlan.LIVE_STATUS_THRESHOLD).first() customer=customer, status__lt=CustomerPlan.LIVE_STATUS_THRESHOLD).first()

View File

@ -16,6 +16,7 @@ from django.urls.resolvers import get_resolver
from django.http import HttpResponse from django.http import HttpResponse
from django.utils.timezone import utc as timezone_utc from django.utils.timezone import utc as timezone_utc
from django.conf import settings from django.conf import settings
from django.utils.timezone import now as timezone_now
import stripe import stripe
@ -35,7 +36,7 @@ from corporate.lib.stripe import catch_stripe_errors, attach_discount_to_realm,
update_license_ledger_if_needed, update_license_ledger_for_automanaged_plan, \ update_license_ledger_if_needed, update_license_ledger_for_automanaged_plan, \
invoice_plan, invoice_plans_as_needed, get_discount_for_realm invoice_plan, invoice_plans_as_needed, get_discount_for_realm
from corporate.models import Customer, CustomerPlan, LicenseLedger, \ from corporate.models import Customer, CustomerPlan, LicenseLedger, \
get_customer_by_realm get_customer_by_realm, get_current_plan_by_customer
CallableT = TypeVar('CallableT', bound=Callable[..., Any]) CallableT = TypeVar('CallableT', bound=Callable[..., Any])
@ -1156,6 +1157,30 @@ class BillingHelpersTest(ZulipTestCase):
customer = Customer.objects.create(realm=realm, stripe_customer_id='cus_12345') customer = Customer.objects.create(realm=realm, stripe_customer_id='cus_12345')
self.assertEqual(get_customer_by_realm(realm), customer) self.assertEqual(get_customer_by_realm(realm), customer)
def test_get_current_plan_by_customer(self) -> None:
realm = get_realm("zulip")
customer = Customer.objects.create(realm=realm, stripe_customer_id='cus_12345')
self.assertEqual(get_current_plan_by_customer(customer), None)
plan = CustomerPlan.objects.create(customer=customer, status=CustomerPlan.ACTIVE,
billing_cycle_anchor=timezone_now(),
billing_schedule=CustomerPlan.ANNUAL,
tier=CustomerPlan.STANDARD)
self.assertEqual(get_current_plan_by_customer(customer), plan)
plan.status = CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE
plan.save(update_fields=["status"])
self.assertEqual(get_current_plan_by_customer(customer), plan)
plan.status = CustomerPlan.ENDED
plan.save(update_fields=["status"])
self.assertEqual(get_current_plan_by_customer(customer), None)
plan.status = CustomerPlan.NEVER_STARTED
plan.save(update_fields=["status"])
self.assertEqual(get_current_plan_by_customer(customer), None)
class LicenseLedgerTest(StripeTestCase): class LicenseLedgerTest(StripeTestCase):
def test_add_plan_renewal_if_needed(self) -> None: def test_add_plan_renewal_if_needed(self) -> None:
with patch('corporate.lib.stripe.timezone_now', return_value=self.now): with patch('corporate.lib.stripe.timezone_now', return_value=self.now):

View File

@ -23,7 +23,8 @@ from corporate.lib.stripe import STRIPE_PUBLISHABLE_KEY, \
MIN_INVOICED_LICENSES, DEFAULT_INVOICE_DAYS_UNTIL_DUE, \ MIN_INVOICED_LICENSES, DEFAULT_INVOICE_DAYS_UNTIL_DUE, \
start_of_next_billing_cycle, renewal_amount, \ start_of_next_billing_cycle, renewal_amount, \
make_end_of_cycle_updates_if_needed make_end_of_cycle_updates_if_needed
from corporate.models import CustomerPlan, get_current_plan, get_customer_by_realm from corporate.models import CustomerPlan, get_current_plan_by_customer, \
get_customer_by_realm
billing_logger = logging.getLogger('corporate.stripe') billing_logger = logging.getLogger('corporate.stripe')
@ -118,7 +119,7 @@ def initial_upgrade(request: HttpRequest) -> HttpResponse:
user = request.user user = request.user
customer = get_customer_by_realm(user.realm) customer = get_customer_by_realm(user.realm)
if customer is not None and get_current_plan(customer) is not None: if customer is not None and get_current_plan_by_customer(customer) is not None:
return HttpResponseRedirect(reverse('corporate.views.billing_home')) return HttpResponseRedirect(reverse('corporate.views.billing_home'))
percent_off = Decimal(0) percent_off = Decimal(0)
@ -168,7 +169,7 @@ def billing_home(request: HttpRequest) -> HttpResponse:
charge_automatically = False charge_automatically = False
stripe_customer = stripe_get_customer(customer.stripe_customer_id) stripe_customer = stripe_get_customer(customer.stripe_customer_id)
plan = get_current_plan(customer) plan = get_current_plan_by_customer(customer)
if plan is not None: if plan is not None:
plan_name = { plan_name = {
CustomerPlan.STANDARD: 'Zulip Standard', CustomerPlan.STANDARD: 'Zulip Standard',
@ -209,7 +210,7 @@ def change_plan_at_end_of_cycle(request: HttpRequest, user: UserProfile,
customer = get_customer_by_realm(user.realm) customer = get_customer_by_realm(user.realm)
assert(customer is not None) # for mypy assert(customer is not None) # for mypy
plan = get_current_plan(customer) plan = get_current_plan_by_customer(customer)
assert(plan is not None) # for mypy assert(plan is not None) # for mypy
do_change_plan_status(plan, status) do_change_plan_status(plan, status)
return json_success() return json_success()