diff --git a/corporate/lib/stripe.py b/corporate/lib/stripe.py index a6bd9b416f..dfaa0e22eb 100644 --- a/corporate/lib/stripe.py +++ b/corporate/lib/stripe.py @@ -1889,7 +1889,7 @@ class BillingSession(ABC): self.current_count_for_billed_licenses(), plan.customer.exempt_from_license_number_check, ) - update_license_ledger_for_manual_plan(plan, timezone_now(), licenses=licenses) + self.update_license_ledger_for_manual_plan(plan, timezone_now(), licenses=licenses) return licenses_at_next_renewal = update_plan_request.licenses_at_next_renewal @@ -1912,7 +1912,7 @@ class BillingSession(ABC): self.current_count_for_billed_licenses(), plan.customer.exempt_from_license_number_check, ) - update_license_ledger_for_manual_plan( + self.update_license_ledger_for_manual_plan( plan, timezone_now(), licenses_at_next_renewal=licenses_at_next_renewal ) return @@ -2282,6 +2282,35 @@ class BillingSession(ABC): return success_message + def update_license_ledger_for_manual_plan( + self, + plan: CustomerPlan, + event_time: datetime, + licenses: Optional[int] = None, + licenses_at_next_renewal: Optional[int] = None, + ) -> None: + if licenses is not None: + if not plan.customer.exempt_from_license_number_check: + assert self.current_count_for_billed_licenses() <= licenses + assert licenses > plan.licenses() + LicenseLedger.objects.create( + plan=plan, + event_time=event_time, + licenses=licenses, + licenses_at_next_renewal=licenses, + ) + elif licenses_at_next_renewal is not None: + if not plan.customer.exempt_from_license_number_check: + assert self.current_count_for_billed_licenses() <= licenses_at_next_renewal + LicenseLedger.objects.create( + plan=plan, + event_time=event_time, + licenses=plan.licenses(), + licenses_at_next_renewal=licenses_at_next_renewal, + ) + else: + raise AssertionError("Pass licenses or licenses_at_next_renewal") + class RealmBillingSession(BillingSession): def __init__( @@ -3364,34 +3393,6 @@ def do_deactivate_remote_server(remote_server: RemoteZulipServer) -> None: ) -def update_license_ledger_for_manual_plan( - plan: CustomerPlan, - event_time: datetime, - licenses: Optional[int] = None, - licenses_at_next_renewal: Optional[int] = None, -) -> None: - if licenses is not None: - assert plan.customer.realm is not None - if not plan.customer.exempt_from_license_number_check: - assert get_latest_seat_count(plan.customer.realm) <= licenses - assert licenses > plan.licenses() - LicenseLedger.objects.create( - plan=plan, event_time=event_time, licenses=licenses, licenses_at_next_renewal=licenses - ) - elif licenses_at_next_renewal is not None: - assert plan.customer.realm is not None - if not plan.customer.exempt_from_license_number_check: - assert get_latest_seat_count(plan.customer.realm) <= licenses_at_next_renewal - LicenseLedger.objects.create( - plan=plan, - event_time=event_time, - licenses=plan.licenses(), - licenses_at_next_renewal=licenses_at_next_renewal, - ) - else: - raise AssertionError("Pass licenses or licenses_at_next_renewal") - - def update_license_ledger_for_automanaged_plan( realm: Realm, plan: CustomerPlan, event_time: datetime ) -> None: diff --git a/corporate/tests/test_stripe.py b/corporate/tests/test_stripe.py index 30b4e996db..1c07516a00 100644 --- a/corporate/tests/test_stripe.py +++ b/corporate/tests/test_stripe.py @@ -73,7 +73,6 @@ from corporate.lib.stripe import ( stripe_get_customer, unsign_string, update_license_ledger_for_automanaged_plan, - update_license_ledger_for_manual_plan, update_license_ledger_if_needed, ) from corporate.models import ( @@ -4765,20 +4764,25 @@ class LicenseLedgerTest(StripeTestCase): self.seat_count + 1, False, CustomerPlan.BILLING_SCHEDULE_ANNUAL, True, False ) + billing_session = RealmBillingSession(user=None, realm=realm) plan = get_current_plan_by_realm(realm) assert plan is not None with patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count): - update_license_ledger_for_manual_plan(plan, self.now, licenses=self.seat_count + 3) + billing_session.update_license_ledger_for_manual_plan( + plan, self.now, licenses=self.seat_count + 3 + ) self.assertEqual(plan.licenses(), self.seat_count + 3) self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count + 3) with patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count): with self.assertRaises(AssertionError): - update_license_ledger_for_manual_plan(plan, self.now, licenses=self.seat_count) + billing_session.update_license_ledger_for_manual_plan( + plan, self.now, licenses=self.seat_count + ) with patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count): - update_license_ledger_for_manual_plan( + billing_session.update_license_ledger_for_manual_plan( plan, self.now, licenses_at_next_renewal=self.seat_count ) self.assertEqual(plan.licenses(), self.seat_count + 3) @@ -4786,16 +4790,17 @@ class LicenseLedgerTest(StripeTestCase): with patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count): with self.assertRaises(AssertionError): - update_license_ledger_for_manual_plan( + billing_session.update_license_ledger_for_manual_plan( plan, self.now, licenses_at_next_renewal=self.seat_count - 1 ) with patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count): - update_license_ledger_for_manual_plan(plan, self.now, licenses=self.seat_count + 10) + billing_session.update_license_ledger_for_manual_plan( + plan, self.now, licenses=self.seat_count + 10 + ) self.assertEqual(plan.licenses(), self.seat_count + 10) self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count + 10) - billing_session = RealmBillingSession(user=None, realm=realm) billing_session.make_end_of_cycle_updates_if_needed(plan, self.next_year) self.assertEqual(plan.licenses(), self.seat_count + 10) @@ -4817,7 +4822,7 @@ class LicenseLedgerTest(StripeTestCase): ) with self.assertRaises(AssertionError): - update_license_ledger_for_manual_plan(plan, self.now) + billing_session.update_license_ledger_for_manual_plan(plan, self.now) def test_user_changes(self) -> None: self.local_upgrade(self.seat_count, True, CustomerPlan.BILLING_SCHEDULE_ANNUAL, True, False)