diff --git a/corporate/lib/stripe.py b/corporate/lib/stripe.py index 0e27530cc4..66dfbed6cb 100644 --- a/corporate/lib/stripe.py +++ b/corporate/lib/stripe.py @@ -230,6 +230,12 @@ class InvalidBillingSchedule(Exception): super().__init__(self.message) +class InvalidTier(Exception): + def __init__(self, tier: int) -> None: + self.message = f"Unknown tier: {tier}" + super().__init__(self.message) + + def catch_stripe_errors(func: CallableT) -> CallableT: @wraps(func) def wrapped(*args: object, **kwargs: object) -> object: @@ -480,16 +486,25 @@ def calculate_discounted_price_per_license( def get_price_per_license( tier: int, billing_schedule: int, discount: Optional[Decimal] = None ) -> int: - # TODO use variables to account for Zulip Plus - assert tier == CustomerPlan.STANDARD - price_per_license: Optional[int] = None - if billing_schedule == CustomerPlan.ANNUAL: - price_per_license = 8000 - elif billing_schedule == CustomerPlan.MONTHLY: - price_per_license = 800 - else: # nocoverage - raise InvalidBillingSchedule(billing_schedule) + + if tier == CustomerPlan.STANDARD: + if billing_schedule == CustomerPlan.ANNUAL: + price_per_license = 8000 + elif billing_schedule == CustomerPlan.MONTHLY: + price_per_license = 800 + else: # nocoverage + raise InvalidBillingSchedule(billing_schedule) + elif tier == CustomerPlan.PLUS: + if billing_schedule == CustomerPlan.ANNUAL: + price_per_license = 16000 + elif billing_schedule == CustomerPlan.MONTHLY: + price_per_license = 1600 + else: # nocoverage + raise InvalidBillingSchedule(billing_schedule) + else: + raise InvalidTier(tier) + if discount is not None: price_per_license = calculate_discounted_price_per_license(price_per_license, discount) return price_per_license diff --git a/corporate/tests/test_stripe.py b/corporate/tests/test_stripe.py index 6989e7599c..4fa4739e87 100644 --- a/corporate/tests/test_stripe.py +++ b/corporate/tests/test_stripe.py @@ -38,6 +38,7 @@ from corporate.lib.stripe import ( MIN_INVOICED_LICENSES, BillingError, InvalidBillingSchedule, + InvalidTier, StripeCardError, add_months, approve_sponsorship, @@ -3225,12 +3226,19 @@ class BillingHelpersTest(ZulipTestCase): 400, ) - with self.assertRaises(AssertionError): - get_price_per_license(CustomerPlan.PLUS, CustomerPlan.MONTHLY) + self.assertEqual(get_price_per_license(CustomerPlan.PLUS, CustomerPlan.ANNUAL), 16000) + self.assertEqual(get_price_per_license(CustomerPlan.PLUS, CustomerPlan.MONTHLY), 1600) + self.assertEqual( + get_price_per_license(CustomerPlan.PLUS, CustomerPlan.MONTHLY, discount=Decimal(50)), + 800, + ) with self.assertRaisesRegex(InvalidBillingSchedule, "Unknown billing_schedule: 1000"): get_price_per_license(CustomerPlan.STANDARD, 1000) + with self.assertRaisesRegex(InvalidTier, "Unknown tier: 10"): + get_price_per_license(CustomerPlan.ENTERPRISE, CustomerPlan.ANNUAL) + def test_update_or_create_stripe_customer_logic(self) -> None: user = self.example_user("hamlet") # No existing Customer object