diff --git a/corporate/lib/stripe.py b/corporate/lib/stripe.py index e59eef8a09..9704a1ba2e 100644 --- a/corporate/lib/stripe.py +++ b/corporate/lib/stripe.py @@ -243,6 +243,7 @@ def do_replace_payment_source(user: UserProfile, stripe_token: str, # event_time should roughly be timezone_now(). Not designed to handle # event_times in the past or future +@transaction.atomic def make_end_of_cycle_updates_if_needed(plan: CustomerPlan, event_time: datetime) -> Tuple[Optional[CustomerPlan], Optional[LicenseLedger]]: last_ledger_entry = LicenseLedger.objects.filter(plan=plan).order_by('-id').first() diff --git a/corporate/tests/test_stripe.py b/corporate/tests/test_stripe.py index 3bd16fb407..4b03168d9e 100644 --- a/corporate/tests/test_stripe.py +++ b/corporate/tests/test_stripe.py @@ -1374,6 +1374,10 @@ class StripeTest(StripeTestCase): self.assertEqual(annual_ledger_entries.values_list('licenses', 'licenses_at_next_renewal')[0], (20, 20)) self.assertEqual(annual_ledger_entries[1].is_renewal, False) self.assertEqual(annual_ledger_entries.values_list('licenses', 'licenses_at_next_renewal')[1], (25, 25)) + audit_log = RealmAuditLog.objects.get(event_type=RealmAuditLog.CUSTOMER_SWITCHED_FROM_MONTHLY_TO_ANNUAL_PLAN) + self.assertEqual(audit_log.realm, user.realm) + self.assertEqual(ujson.loads(audit_log.extra_data)["monthly_plan_id"], monthly_plan.id) + self.assertEqual(ujson.loads(audit_log.extra_data)["annual_plan_id"], annual_plan.id) invoice_plans_as_needed(self.next_month) @@ -1491,8 +1495,7 @@ class StripeTest(StripeTestCase): response = self.client_get("/billing/") self.assert_in_success_response(["be switched from monthly to annual billing on February 2, 2012"], response) - with patch('corporate.lib.stripe.timezone_now', return_value=self.next_month): - invoice_plans_as_needed(self.next_month) + invoice_plans_as_needed(self.next_month) self.assertEqual(LicenseLedger.objects.filter(plan=monthly_plan).count(), 1) customer = get_customer_by_realm(user.realm) @@ -1514,8 +1517,9 @@ class StripeTest(StripeTestCase): self.assertEqual(annual_ledger_entries.values_list('licenses', 'licenses_at_next_renewal')[0], (num_licenses, num_licenses)) self.assertEqual(annual_plan.invoiced_through, None) - with patch('corporate.lib.stripe.timezone_now', return_value=self.next_month): - invoice_plans_as_needed(self.next_month + timedelta(days=1)) + # First call of invoice_plans_as_needed creates the new plan. Second call + # calls invoice_plan on the newly created plan. + invoice_plans_as_needed(self.next_month + timedelta(days=1)) annual_plan.refresh_from_db() self.assertEqual(annual_plan.invoiced_through, annual_ledger_entries[0]) diff --git a/corporate/views.py b/corporate/views.py index b49f19083d..0e0b5b7b6d 100644 --- a/corporate/views.py +++ b/corporate/views.py @@ -204,6 +204,7 @@ def billing_home(request: HttpRequest) -> HttpResponse: if last_ledger_entry is not None: if new_plan is not None: # nocoverage plan = new_plan + assert(plan is not None) # for mypy plan_name = { CustomerPlan.STANDARD: 'Zulip Standard', CustomerPlan.PLUS: 'Zulip Plus',