billing: Create get_current_plan_by_realm helper function.

This commit is contained in:
Vishnu KS 2020-03-24 18:52:27 +05:30 committed by Tim Abbott
parent 83da23c0d4
commit 8b24d40585
4 changed files with 35 additions and 19 deletions

View File

@ -19,7 +19,8 @@ from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime
from zerver.lib.utils import generate_random_token from zerver.lib.utils import generate_random_token
from zerver.models import Realm, UserProfile, RealmAuditLog from zerver.models import Realm, UserProfile, RealmAuditLog
from corporate.models import Customer, CustomerPlan, LicenseLedger, \ from corporate.models import Customer, CustomerPlan, LicenseLedger, \
get_current_plan_by_customer, get_customer_by_realm get_current_plan_by_customer, get_customer_by_realm, \
get_current_plan_by_realm
from zproject.config import get_secret from zproject.config import get_secret
STRIPE_PUBLISHABLE_KEY = get_secret('stripe_publishable_key') STRIPE_PUBLISHABLE_KEY = get_secret('stripe_publishable_key')
@ -381,10 +382,7 @@ def update_license_ledger_for_automanaged_plan(realm: Realm, plan: CustomerPlan,
licenses_at_next_renewal=licenses_at_next_renewal) licenses_at_next_renewal=licenses_at_next_renewal)
def update_license_ledger_if_needed(realm: Realm, event_time: datetime) -> None: def update_license_ledger_if_needed(realm: Realm, event_time: datetime) -> None:
customer = get_customer_by_realm(realm) plan = get_current_plan_by_realm(realm)
if customer is None:
return
plan = get_current_plan_by_customer(customer)
if plan is None: if plan is None:
return return
if not plan.automanage_licenses: if not plan.automanage_licenses:
@ -503,11 +501,11 @@ def estimate_annual_recurring_revenue_by_realm() -> Dict[str, int]: # nocoverag
# During realm deactivation we instantly downgrade the plan to Limited. # During realm deactivation we instantly downgrade the plan to Limited.
# Extra users added in the final month are not charged. # Extra users added in the final month are not charged.
def downgrade_for_realm_deactivation(realm: Realm) -> None: def downgrade_for_realm_deactivation(realm: Realm) -> None:
customer = get_customer_by_realm(realm) plan = get_current_plan_by_realm(realm)
if customer is not None: if plan is None:
plan = get_current_plan_by_customer(customer) return
if plan:
process_downgrade(plan) process_downgrade(plan)
plan.invoiced_through = LicenseLedger.objects.filter(plan=plan).order_by('id').last() plan.invoiced_through = LicenseLedger.objects.filter(plan=plan).order_by('id').last()
plan.next_invoice_date = next_invoice_date(plan) plan.next_invoice_date = next_invoice_date(plan)
plan.save(update_fields=["invoiced_through", "next_invoice_date"]) plan.save(update_fields=["invoiced_through", "next_invoice_date"])

View File

@ -65,6 +65,12 @@ def get_current_plan_by_customer(customer: Customer) -> Optional[CustomerPlan]:
return CustomerPlan.objects.filter( return CustomerPlan.objects.filter(
customer=customer, status__lt=CustomerPlan.LIVE_STATUS_THRESHOLD).first() customer=customer, status__lt=CustomerPlan.LIVE_STATUS_THRESHOLD).first()
def get_current_plan_by_realm(realm: Realm) -> Optional[CustomerPlan]:
customer = get_customer_by_realm(realm)
if customer is None:
return None
return get_current_plan_by_customer(customer)
class LicenseLedger(models.Model): class LicenseLedger(models.Model):
plan = models.ForeignKey(CustomerPlan, on_delete=CASCADE) # type: CustomerPlan plan = models.ForeignKey(CustomerPlan, on_delete=CASCADE) # type: CustomerPlan
# Also True for the initial upgrade. # Also True for the initial upgrade.

View File

@ -36,7 +36,8 @@ from corporate.lib.stripe import catch_stripe_errors, attach_discount_to_realm,
update_license_ledger_if_needed, update_license_ledger_for_automanaged_plan, \ update_license_ledger_if_needed, update_license_ledger_for_automanaged_plan, \
invoice_plan, invoice_plans_as_needed, get_discount_for_realm invoice_plan, invoice_plans_as_needed, get_discount_for_realm
from corporate.models import Customer, CustomerPlan, LicenseLedger, \ from corporate.models import Customer, CustomerPlan, LicenseLedger, \
get_customer_by_realm, get_current_plan_by_customer get_customer_by_realm, get_current_plan_by_customer, \
get_current_plan_by_realm
CallableT = TypeVar('CallableT', bound=Callable[..., Any]) CallableT = TypeVar('CallableT', bound=Callable[..., Any])
@ -1181,6 +1182,20 @@ class BillingHelpersTest(ZulipTestCase):
plan.save(update_fields=["status"]) plan.save(update_fields=["status"])
self.assertEqual(get_current_plan_by_customer(customer), None) self.assertEqual(get_current_plan_by_customer(customer), None)
def test_get_current_plan_by_realm(self) -> None:
realm = get_realm("zulip")
self.assertEqual(get_current_plan_by_realm(realm), None)
customer = Customer.objects.create(realm=realm, stripe_customer_id='cus_12345')
self.assertEqual(get_current_plan_by_realm(realm), None)
plan = CustomerPlan.objects.create(customer=customer, status=CustomerPlan.ACTIVE,
billing_cycle_anchor=timezone_now(),
billing_schedule=CustomerPlan.ANNUAL,
tier=CustomerPlan.STANDARD)
self.assertEqual(get_current_plan_by_realm(realm), plan)
class LicenseLedgerTest(StripeTestCase): class LicenseLedgerTest(StripeTestCase):
def test_add_plan_renewal_if_needed(self) -> None: def test_add_plan_renewal_if_needed(self) -> None:
with patch('corporate.lib.stripe.timezone_now', return_value=self.now): with patch('corporate.lib.stripe.timezone_now', return_value=self.now):

View File

@ -24,7 +24,7 @@ from corporate.lib.stripe import STRIPE_PUBLISHABLE_KEY, \
start_of_next_billing_cycle, renewal_amount, \ start_of_next_billing_cycle, renewal_amount, \
make_end_of_cycle_updates_if_needed make_end_of_cycle_updates_if_needed
from corporate.models import CustomerPlan, get_current_plan_by_customer, \ from corporate.models import CustomerPlan, get_current_plan_by_customer, \
get_customer_by_realm get_customer_by_realm, get_current_plan_by_realm
billing_logger = logging.getLogger('corporate.stripe') billing_logger = logging.getLogger('corporate.stripe')
@ -207,10 +207,7 @@ def billing_home(request: HttpRequest) -> HttpResponse:
def change_plan_at_end_of_cycle(request: HttpRequest, user: UserProfile, def change_plan_at_end_of_cycle(request: HttpRequest, user: UserProfile,
status: int=REQ("status", validator=check_int)) -> HttpResponse: status: int=REQ("status", validator=check_int)) -> HttpResponse:
assert(status in [CustomerPlan.ACTIVE, CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE]) assert(status in [CustomerPlan.ACTIVE, CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE])
customer = get_customer_by_realm(user.realm) plan = get_current_plan_by_realm(user.realm)
assert(customer is not None) # for mypy
plan = get_current_plan_by_customer(customer)
assert(plan is not None) # for mypy assert(plan is not None) # for mypy
do_change_plan_status(plan, status) do_change_plan_status(plan, status)
return json_success() return json_success()