diff --git a/corporate/lib/stripe.py b/corporate/lib/stripe.py index 92c614fbc5..6d19982d83 100644 --- a/corporate/lib/stripe.py +++ b/corporate/lib/stripe.py @@ -19,7 +19,7 @@ from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime from zerver.lib.utils import generate_random_token from zerver.models import Realm, UserProfile, RealmAuditLog from corporate.models import Customer, CustomerPlan, LicenseLedger, \ - get_current_plan + get_current_plan, get_customer_by_realm from zproject.config import get_secret STRIPE_PUBLISHABLE_KEY = get_secret('stripe_publishable_key') @@ -199,7 +199,10 @@ def do_create_stripe_customer(user: UserProfile, stripe_token: Optional[str]=Non @catch_stripe_errors def do_replace_payment_source(user: UserProfile, stripe_token: str, pay_invoices: bool=False) -> stripe.Customer: - stripe_customer = stripe_get_customer(Customer.objects.get(realm=user.realm).stripe_customer_id) + customer = get_customer_by_realm(user.realm) + assert(customer is not None) # for mypy + + stripe_customer = stripe_get_customer(customer.stripe_customer_id) stripe_customer.source = stripe_token # Deletes existing card: https://stripe.com/docs/api#update_customer-source updated_stripe_customer = stripe.Customer.save(stripe_customer) @@ -239,7 +242,7 @@ def make_end_of_cycle_updates_if_needed(plan: CustomerPlan, # API call if there's nothing to update def update_or_create_stripe_customer(user: UserProfile, stripe_token: Optional[str]=None) -> Customer: realm = user.realm - customer = Customer.objects.filter(realm=realm).first() + customer = get_customer_by_realm(realm) if customer is None or customer.stripe_customer_id is None: return do_create_stripe_customer(user, stripe_token=stripe_token) if stripe_token is not None: @@ -378,7 +381,7 @@ def update_license_ledger_for_automanaged_plan(realm: Realm, plan: CustomerPlan, licenses_at_next_renewal=licenses_at_next_renewal) def update_license_ledger_if_needed(realm: Realm, event_time: datetime) -> None: - customer = Customer.objects.filter(realm=realm).first() + customer = get_customer_by_realm(realm) if customer is None: return plan = get_current_plan(customer) @@ -467,7 +470,7 @@ def attach_discount_to_realm(realm: Realm, discount: Decimal) -> None: Customer.objects.update_or_create(realm=realm, defaults={'default_discount': discount}) def get_discount_for_realm(realm: Realm) -> Optional[Decimal]: - customer = Customer.objects.filter(realm=realm).first() + customer = get_customer_by_realm(realm) if customer is not None: return customer.default_discount return None @@ -500,7 +503,7 @@ def estimate_annual_recurring_revenue_by_realm() -> Dict[str, int]: # nocoverag # During realm deactivation we instantly downgrade the plan to Limited. # Extra users added in the final month are not charged. def downgrade_for_realm_deactivation(realm: Realm) -> None: - customer = Customer.objects.filter(realm=realm).first() + customer = get_customer_by_realm(realm) if customer is not None: plan = get_current_plan(customer) if plan: diff --git a/corporate/models.py b/corporate/models.py index 235f9b4676..663f0b5646 100644 --- a/corporate/models.py +++ b/corporate/models.py @@ -16,6 +16,9 @@ class Customer(models.Model): def __str__(self) -> str: return "" % (self.realm, self.stripe_customer_id) +def get_customer_by_realm(realm: Realm) -> Optional[Customer]: + return Customer.objects.filter(realm=realm).first() + class CustomerPlan(models.Model): customer = models.ForeignKey(Customer, on_delete=CASCADE) # type: Customer automanage_licenses = models.BooleanField(default=False) # type: bool diff --git a/corporate/tests/test_stripe.py b/corporate/tests/test_stripe.py index 4ae30b93db..a754e58d48 100644 --- a/corporate/tests/test_stripe.py +++ b/corporate/tests/test_stripe.py @@ -34,7 +34,8 @@ from corporate.lib.stripe import catch_stripe_errors, attach_discount_to_realm, process_initial_upgrade, make_end_of_cycle_updates_if_needed, \ update_license_ledger_if_needed, update_license_ledger_for_automanaged_plan, \ invoice_plan, invoice_plans_as_needed, get_discount_for_realm -from corporate.models import Customer, CustomerPlan, LicenseLedger +from corporate.models import Customer, CustomerPlan, LicenseLedger, \ + get_customer_by_realm CallableT = TypeVar('CallableT', bound=Callable[..., Any]) @@ -1147,6 +1148,14 @@ class BillingHelpersTest(ZulipTestCase): mocked3.assert_not_called() self.assertTrue(isinstance(customer, Customer)) + def test_get_customer_by_realm(self) -> None: + realm = get_realm('zulip') + + self.assertEqual(get_customer_by_realm(realm), None) + + customer = Customer.objects.create(realm=realm, stripe_customer_id='cus_12345') + self.assertEqual(get_customer_by_realm(realm), customer) + class LicenseLedgerTest(StripeTestCase): def test_add_plan_renewal_if_needed(self) -> None: with patch('corporate.lib.stripe.timezone_now', return_value=self.now): diff --git a/corporate/views.py b/corporate/views.py index ba2be25fdb..56050a1897 100644 --- a/corporate/views.py +++ b/corporate/views.py @@ -1,4 +1,5 @@ import logging +from decimal import Decimal import stripe from typing import Any, Dict, cast, Optional, Union @@ -22,8 +23,7 @@ from corporate.lib.stripe import STRIPE_PUBLISHABLE_KEY, \ MIN_INVOICED_LICENSES, DEFAULT_INVOICE_DAYS_UNTIL_DUE, \ start_of_next_billing_cycle, renewal_amount, \ make_end_of_cycle_updates_if_needed -from corporate.models import Customer, CustomerPlan, \ - get_current_plan +from corporate.models import CustomerPlan, get_current_plan, get_customer_by_realm billing_logger = logging.getLogger('corporate.stripe') @@ -117,11 +117,11 @@ def initial_upgrade(request: HttpRequest) -> HttpResponse: return render(request, "404.html") user = request.user - customer = Customer.objects.filter(realm=user.realm).first() + customer = get_customer_by_realm(user.realm) if customer is not None and get_current_plan(customer) is not None: return HttpResponseRedirect(reverse('corporate.views.billing_home')) - percent_off = 0 + percent_off = Decimal(0) if customer is not None and customer.default_discount is not None: percent_off = customer.default_discount @@ -149,7 +149,7 @@ def initial_upgrade(request: HttpRequest) -> HttpResponse: @zulip_login_required def billing_home(request: HttpRequest) -> HttpResponse: user = request.user - customer = Customer.objects.filter(realm=user.realm).first() + customer = get_customer_by_realm(user.realm) if customer is None: return HttpResponseRedirect(reverse('corporate.views.initial_upgrade')) if not CustomerPlan.objects.filter(customer=customer).exists(): @@ -206,7 +206,10 @@ def billing_home(request: HttpRequest) -> HttpResponse: def change_plan_at_end_of_cycle(request: HttpRequest, user: UserProfile, status: int=REQ("status", validator=check_int)) -> HttpResponse: assert(status in [CustomerPlan.ACTIVE, CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE]) - plan = get_current_plan(Customer.objects.get(realm=user.realm)) + customer = get_customer_by_realm(user.realm) + assert(customer is not None) # for mypy + + plan = get_current_plan(customer) assert(plan is not None) # for mypy do_change_plan_status(plan, status) return json_success()