diff --git a/corporate/lib/stripe.py b/corporate/lib/stripe.py index a78e3334fd..529942c3ab 100644 --- a/corporate/lib/stripe.py +++ b/corporate/lib/stripe.py @@ -700,9 +700,12 @@ def process_initial_upgrade( # TODO: The correctness of this relies on user creation, deactivation, etc being # in a transaction.atomic() with the relevant RealmAuditLog entries with transaction.atomic(): - # billed_licenses can greater than licenses if users are added between the start of - # this function (process_initial_upgrade) and now - billed_licenses = max(get_latest_seat_count(realm), licenses) + if customer.exempt_from_license_number_check: + billed_licenses = licenses + else: + # billed_licenses can be greater than licenses if users are added between the start of + # this function (process_initial_upgrade) and now + billed_licenses = max(get_latest_seat_count(realm), licenses) plan_params = { "automanage_licenses": automanage_licenses, "charge_automatically": charge_automatically, @@ -777,14 +780,16 @@ def update_license_ledger_for_manual_plan( ) -> None: if licenses is not None: assert plan.customer.realm is not None - assert get_latest_seat_count(plan.customer.realm) <= licenses + if not plan.customer.exempt_from_license_number_check: + assert get_latest_seat_count(plan.customer.realm) <= licenses assert licenses > plan.licenses() LicenseLedger.objects.create( plan=plan, event_time=event_time, licenses=licenses, licenses_at_next_renewal=licenses ) elif licenses_at_next_renewal is not None: assert plan.customer.realm is not None - assert get_latest_seat_count(plan.customer.realm) <= licenses_at_next_renewal + if not plan.customer.exempt_from_license_number_check: + assert get_latest_seat_count(plan.customer.realm) <= licenses_at_next_renewal LicenseLedger.objects.create( plan=plan, event_time=event_time, diff --git a/corporate/tests/test_stripe.py b/corporate/tests/test_stripe.py index 32a6dc250c..1834d05a3b 100644 --- a/corporate/tests/test_stripe.py +++ b/corporate/tests/test_stripe.py @@ -3397,6 +3397,62 @@ class StripeTest(StripeTestCase): for key, value in line_item_params.items(): self.assertEqual(renewal_item.get(key), value) + def test_update_licenses_of_manual_plan_from_billing_page_exempt_from_license_number_check( + self, + ) -> None: + """ + Verifies that an organization exempt from the license number check can reduce their number + of licenses. + """ + user = self.example_user("hamlet") + self.login_user(user) + + customer = Customer.objects.get_or_create(realm=user.realm)[0] + customer.exempt_from_license_number_check = True + customer.save() + + with patch("corporate.lib.stripe.timezone_now", return_value=self.now): + self.local_upgrade(100, False, CustomerPlan.ANNUAL, True, False) + + with patch("corporate.views.billing_page.timezone_now", return_value=self.now): + result = self.client_patch( + "/json/billing/plan", + {"licenses_at_next_renewal": get_latest_seat_count(user.realm) - 2}, + ) + + self.assert_json_success(result) + latest_license_ledger = LicenseLedger.objects.last() + assert latest_license_ledger is not None + self.assertEqual( + latest_license_ledger.licenses_at_next_renewal, get_latest_seat_count(user.realm) - 2 + ) + + def test_upgrade_exempt_from_license_number_check_realm_less_licenses_than_seat_count( + self, + ) -> None: + """ + Verifies that an organization exempt from the license number check can upgrade their plan, + specifying a number of licenses less than their current number of licenses and be charged + for the number of licenses specified. Tests against a former bug, where the organization + was charged for the current seat count, despite specifying a lower number of licenses. + """ + user = self.example_user("hamlet") + self.login_user(user) + + customer = Customer.objects.get_or_create(realm=user.realm)[0] + customer.exempt_from_license_number_check = True + customer.save() + + reduced_seat_count = get_latest_seat_count(user.realm) - 2 + + with patch("corporate.lib.stripe.timezone_now", return_value=self.now): + self.local_upgrade(reduced_seat_count, False, CustomerPlan.ANNUAL, True, False) + + latest_license_ledger = LicenseLedger.objects.last() + assert latest_license_ledger is not None + self.assertEqual(latest_license_ledger.licenses_at_next_renewal, reduced_seat_count) + self.assertEqual(latest_license_ledger.licenses, reduced_seat_count) + def test_update_licenses_of_automatic_plan_from_billing_page(self) -> None: user = self.example_user("hamlet") self.login_user(user)