diff --git a/corporate/models.py b/corporate/models.py index ef1c3cb274..d0c15388b9 100644 --- a/corporate/models.py +++ b/corporate/models.py @@ -91,6 +91,13 @@ class CustomerPlan(models.Model): def licenses(self) -> int: return LicenseLedger.objects.filter(plan=self).order_by("id").last().licenses + def licenses_at_next_renewal(self) -> Optional[int]: + if self.status == CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE: + return None + return ( + LicenseLedger.objects.filter(plan=self).order_by("id").last().licenses_at_next_renewal + ) + def get_current_plan_by_customer(customer: Customer) -> Optional[CustomerPlan]: return CustomerPlan.objects.filter( diff --git a/corporate/tests/test_stripe.py b/corporate/tests/test_stripe.py index 4ef1dc17f5..9f89f4fb06 100644 --- a/corporate/tests/test_stripe.py +++ b/corporate/tests/test_stripe.py @@ -1779,6 +1779,10 @@ class StripeTest(StripeTestCase): self.login_user(user) with patch("corporate.lib.stripe.timezone_now", return_value=self.now): self.local_upgrade(self.seat_count, True, CustomerPlan.ANNUAL, "token") + plan = get_current_plan_by_realm(user.realm) + assert plan is not None + self.assertEqual(plan.licenses(), self.seat_count) + self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count) with self.assertLogs("corporate.stripe", "INFO") as m: with patch("corporate.views.timezone_now", return_value=self.now): response = self.client_patch( @@ -1790,6 +1794,9 @@ class StripeTest(StripeTestCase): expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}" self.assertEqual(m.output[0], expected_log) self.assert_json_success(response) + plan.refresh_from_db() + self.assertEqual(plan.licenses(), self.seat_count) + self.assertEqual(plan.licenses_at_next_renewal(), None) # Verify that we still write LicenseLedger rows during the remaining # part of the cycle @@ -2813,6 +2820,7 @@ class LicenseLedgerTest(StripeTestCase): plan = CustomerPlan.objects.get() self.assertEqual(LicenseLedger.objects.count(), 1) self.assertEqual(plan.licenses(), self.seat_count + 1) + self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count + 1) update_license_ledger_if_needed(realm, self.now) self.assertEqual(LicenseLedger.objects.count(), 1) # Test no active plan @@ -2833,24 +2841,29 @@ class LicenseLedgerTest(StripeTestCase): self.local_upgrade(self.seat_count, True, CustomerPlan.ANNUAL, "token") plan = CustomerPlan.objects.first() self.assertEqual(plan.licenses(), self.seat_count) + self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count) # Simple increase with patch("corporate.lib.stripe.get_latest_seat_count", return_value=23): update_license_ledger_for_automanaged_plan(realm, plan, self.now) self.assertEqual(plan.licenses(), 23) + self.assertEqual(plan.licenses_at_next_renewal(), 23) # Decrease with patch("corporate.lib.stripe.get_latest_seat_count", return_value=20): update_license_ledger_for_automanaged_plan(realm, plan, self.now) self.assertEqual(plan.licenses(), 23) + self.assertEqual(plan.licenses_at_next_renewal(), 20) # Increase, but not past high watermark with patch("corporate.lib.stripe.get_latest_seat_count", return_value=21): update_license_ledger_for_automanaged_plan(realm, plan, self.now) self.assertEqual(plan.licenses(), 23) + self.assertEqual(plan.licenses_at_next_renewal(), 21) # Increase, but after renewal date, and below last year's high watermark with patch("corporate.lib.stripe.get_latest_seat_count", return_value=22): update_license_ledger_for_automanaged_plan( realm, plan, self.next_year + timedelta(seconds=1) ) self.assertEqual(plan.licenses(), 22) + self.assertEqual(plan.licenses_at_next_renewal(), 22) ledger_entries = list( LicenseLedger.objects.values_list(