stripe: Move `update_license_ledger_if_needed` to BillingSession.

This commit moves the 'update_license_ledger_if_needed' and its
helper function 'update_license_ledger_for_automanaged_plan'
to the 'BillingSession' abstract class.

This refactoring will help in minimizing duplicate code while
supporting both realm and remote_server customers.
This commit is contained in:
Prakhar Pratyush 2023-12-06 01:39:28 +05:30 committed by Tim Abbott
parent 133291ec2d
commit e5d71fe5ac
4 changed files with 83 additions and 72 deletions

View File

@ -2311,6 +2311,35 @@ class BillingSession(ABC):
else:
raise AssertionError("Pass licenses or licenses_at_next_renewal")
def update_license_ledger_for_automanaged_plan(
self, plan: CustomerPlan, event_time: datetime
) -> None:
new_plan, last_ledger_entry = self.make_end_of_cycle_updates_if_needed(plan, event_time)
if last_ledger_entry is None:
return
if new_plan is not None:
plan = new_plan
licenses_at_next_renewal = self.current_count_for_billed_licenses()
licenses = max(licenses_at_next_renewal, last_ledger_entry.licenses)
LicenseLedger.objects.create(
plan=plan,
event_time=event_time,
licenses=licenses,
licenses_at_next_renewal=licenses_at_next_renewal,
)
def update_license_ledger_if_needed(self, event_time: datetime) -> None:
customer = self.get_customer()
if customer is None:
return
plan = get_current_plan_by_customer(customer)
if plan is None:
return
if not plan.automanage_licenses:
return
self.update_license_ledger_for_automanaged_plan(plan, event_time)
class RealmBillingSession(BillingSession):
def __init__(
@ -3393,37 +3422,6 @@ def do_deactivate_remote_server(remote_server: RemoteZulipServer) -> None:
)
def update_license_ledger_for_automanaged_plan(
realm: Realm, plan: CustomerPlan, event_time: datetime
) -> None:
billing_session = RealmBillingSession(user=None, realm=realm)
new_plan, last_ledger_entry = billing_session.make_end_of_cycle_updates_if_needed(
plan, event_time
)
if last_ledger_entry is None:
return
if new_plan is not None:
plan = new_plan
licenses_at_next_renewal = get_latest_seat_count(realm)
licenses = max(licenses_at_next_renewal, last_ledger_entry.licenses)
LicenseLedger.objects.create(
plan=plan,
event_time=event_time,
licenses=licenses,
licenses_at_next_renewal=licenses_at_next_renewal,
)
def update_license_ledger_if_needed(realm: Realm, event_time: datetime) -> None:
plan = get_current_plan_by_realm(realm)
if plan is None:
return
if not plan.automanage_licenses:
return
update_license_ledger_for_automanaged_plan(realm, plan, event_time)
def get_plan_renewal_or_end_date(plan: CustomerPlan, event_time: datetime) -> datetime:
billing_period_end = start_of_next_billing_cycle(plan, event_time)

View File

@ -72,8 +72,6 @@ from corporate.lib.stripe import (
stripe_customer_has_credit_card_as_default_payment_method,
stripe_get_customer,
unsign_string,
update_license_ledger_for_automanaged_plan,
update_license_ledger_if_needed,
)
from corporate.models import (
Customer,
@ -1136,8 +1134,9 @@ class StripeTest(StripeTestCase):
self.assert_in_response(substring, response)
self.assert_not_in_success_response(["Go to your Zulip organization"], response)
billing_session = RealmBillingSession(user=user, realm=realm)
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=12):
update_license_ledger_if_needed(realm, self.now)
billing_session.update_license_ledger_if_needed(self.now)
self.assertEqual(
LicenseLedger.objects.order_by("-id")
.values_list("licenses", "licenses_at_next_renewal")
@ -1146,7 +1145,7 @@ class StripeTest(StripeTestCase):
)
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=15):
update_license_ledger_if_needed(realm, self.next_month)
billing_session.update_license_ledger_if_needed(self.next_month)
self.assertEqual(
LicenseLedger.objects.order_by("-id")
.values_list("licenses", "licenses_at_next_renewal")
@ -1201,7 +1200,7 @@ class StripeTest(StripeTestCase):
[invoice] = iter(stripe.Invoice.list(customer=stripe_customer.id))
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=19):
update_license_ledger_if_needed(realm, add_months(free_trial_end_date, 10))
billing_session.update_license_ledger_if_needed(add_months(free_trial_end_date, 10))
self.assertEqual(
LicenseLedger.objects.order_by("-id")
.values_list("licenses", "licenses_at_next_renewal")
@ -2250,8 +2249,9 @@ class StripeTest(StripeTestCase):
# Verify that we still write LicenseLedger rows during the remaining
# part of the cycle
billing_session = RealmBillingSession(user=user, realm=user.realm)
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=20):
update_license_ledger_if_needed(user.realm, self.now)
billing_session.update_license_ledger_if_needed(self.now)
self.assertEqual(
LicenseLedger.objects.order_by("-id")
.values_list("licenses", "licenses_at_next_renewal")
@ -2267,7 +2267,7 @@ class StripeTest(StripeTestCase):
# Check that we downgrade properly if the cycle is over
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=30):
update_license_ledger_if_needed(user.realm, self.next_year)
billing_session.update_license_ledger_if_needed(self.next_year)
plan = CustomerPlan.objects.first()
assert plan is not None
self.assertEqual(get_realm("zulip").plan_type, Realm.PLAN_TYPE_LIMITED)
@ -2281,7 +2281,7 @@ class StripeTest(StripeTestCase):
# Verify that we don't write LicenseLedger rows once we've downgraded
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=40):
update_license_ledger_if_needed(user.realm, self.next_year)
billing_session.update_license_ledger_if_needed(self.next_year)
self.assertEqual(
LicenseLedger.objects.order_by("-id")
.values_list("licenses", "licenses_at_next_renewal")
@ -2309,7 +2309,7 @@ class StripeTest(StripeTestCase):
# Check that we don't call invoice_plan after that final call
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=50):
update_license_ledger_if_needed(user.realm, self.next_year + timedelta(days=80))
billing_session.update_license_ledger_if_needed(self.next_year + timedelta(days=80))
mocked = self.setup_mocked_stripe(
invoice_plans_as_needed, self.next_year + timedelta(days=400)
@ -2352,8 +2352,9 @@ class StripeTest(StripeTestCase):
["Your plan will switch to annual billing on February 2, 2012"], response
)
billing_session = RealmBillingSession(user=user, realm=user.realm)
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=20):
update_license_ledger_if_needed(user.realm, self.now)
billing_session.update_license_ledger_if_needed(self.now)
self.assertEqual(LicenseLedger.objects.filter(plan=monthly_plan).count(), 2)
self.assertEqual(
LicenseLedger.objects.order_by("-id")
@ -2364,7 +2365,7 @@ class StripeTest(StripeTestCase):
with time_machine.travel(self.next_month, tick=False):
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=25):
update_license_ledger_if_needed(user.realm, self.next_month)
billing_session.update_license_ledger_if_needed(self.next_month)
self.assertEqual(LicenseLedger.objects.filter(plan=monthly_plan).count(), 2)
customer = get_customer_by_realm(user.realm)
assert customer is not None
@ -2464,7 +2465,7 @@ class StripeTest(StripeTestCase):
self.assertEqual(monthly_plan_invoice_item[key], value)
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=30):
update_license_ledger_if_needed(user.realm, add_months(self.next_month, 1))
billing_session.update_license_ledger_if_needed(add_months(self.next_month, 1))
invoice_plans_as_needed(add_months(self.next_month, 1))
[invoice0, invoice1, invoice2, invoice3] = iter(
@ -2660,8 +2661,9 @@ class StripeTest(StripeTestCase):
["Your plan will switch to monthly billing on January 2, 2013"], response
)
billing_session = RealmBillingSession(user=user, realm=user.realm)
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=20):
update_license_ledger_if_needed(user.realm, self.now)
billing_session.update_license_ledger_if_needed(self.now)
self.assertEqual(LicenseLedger.objects.filter(plan=annual_plan).count(), 2)
self.assertEqual(
LicenseLedger.objects.order_by("-id")
@ -2676,7 +2678,7 @@ class StripeTest(StripeTestCase):
assert annual_plan.next_invoice_date is not None
with time_machine.travel(annual_plan.next_invoice_date, tick=False):
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=25):
update_license_ledger_if_needed(user.realm, annual_plan.next_invoice_date)
billing_session.update_license_ledger_if_needed(annual_plan.next_invoice_date)
annual_plan.refresh_from_db()
self.assertEqual(annual_plan.status, CustomerPlan.SWITCH_TO_MONTHLY_AT_END_OF_CYCLE)
@ -2731,7 +2733,7 @@ class StripeTest(StripeTestCase):
# Check that we switch to monthly plan at the end of current billing cycle.
with time_machine.travel(self.next_year, tick=False):
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=25):
update_license_ledger_if_needed(user.realm, self.next_year)
billing_session.update_license_ledger_if_needed(self.next_year)
self.assertEqual(LicenseLedger.objects.filter(plan=annual_plan).count(), 3)
customer = get_customer_by_realm(user.realm)
assert customer is not None
@ -2997,8 +2999,9 @@ class StripeTest(StripeTestCase):
self.assertEqual(plan.status, CustomerPlan.FREE_TRIAL)
# Add some extra users before the realm is deactivated
billing_session = RealmBillingSession(user=user, realm=user.realm)
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=21):
update_license_ledger_if_needed(user.realm, self.now)
billing_session.update_license_ledger_if_needed(self.now)
last_ledger_entry = LicenseLedger.objects.order_by("id").last()
assert last_ledger_entry is not None
@ -3084,10 +3087,11 @@ class StripeTest(StripeTestCase):
response,
)
billing_session = RealmBillingSession(user=user, realm=user.realm)
# Verify that we still write LicenseLedger rows during the remaining
# part of the cycle
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=20):
update_license_ledger_if_needed(user.realm, self.now)
billing_session.update_license_ledger_if_needed(self.now)
self.assertEqual(
LicenseLedger.objects.order_by("-id")
.values_list("licenses", "licenses_at_next_renewal")
@ -3103,7 +3107,7 @@ class StripeTest(StripeTestCase):
# Check that we downgrade properly if the cycle is over
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=30):
update_license_ledger_if_needed(user.realm, free_trial_end_date)
billing_session.update_license_ledger_if_needed(free_trial_end_date)
plan = CustomerPlan.objects.first()
assert plan is not None
self.assertIsNone(plan.next_invoice_date)
@ -3118,7 +3122,7 @@ class StripeTest(StripeTestCase):
# Verify that we don't write LicenseLedger rows once we've downgraded
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=40):
update_license_ledger_if_needed(user.realm, self.next_year)
billing_session.update_license_ledger_if_needed(self.next_year)
self.assertEqual(
LicenseLedger.objects.order_by("-id")
.values_list("licenses", "licenses_at_next_renewal")
@ -3546,8 +3550,9 @@ class StripeTest(StripeTestCase):
self.assertEqual(plan.status, CustomerPlan.ACTIVE)
# Add some extra users before the realm is deactivated
billing_session = RealmBillingSession(user=user, realm=user.realm)
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=20):
update_license_ledger_if_needed(user.realm, self.now)
billing_session.update_license_ledger_if_needed(self.now)
last_ledger_entry = LicenseLedger.objects.order_by("id").last()
assert last_ledger_entry is not None
@ -4681,8 +4686,9 @@ class LicenseLedgerTest(StripeTestCase):
def test_update_license_ledger_if_needed(self) -> None:
realm = get_realm("zulip")
billing_session = RealmBillingSession(user=None, realm=realm)
# Test no Customer
update_license_ledger_if_needed(realm, self.now)
billing_session.update_license_ledger_if_needed(self.now)
self.assertFalse(LicenseLedger.objects.exists())
# Test plan not automanaged
self.local_upgrade(
@ -4692,18 +4698,18 @@ class LicenseLedgerTest(StripeTestCase):
self.assertEqual(LicenseLedger.objects.count(), 1)
self.assertEqual(plan.licenses(), self.seat_count + 1)
self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count + 1)
update_license_ledger_if_needed(realm, self.now)
billing_session.update_license_ledger_if_needed(self.now)
self.assertEqual(LicenseLedger.objects.count(), 1)
# Test no active plan
plan.automanage_licenses = True
plan.status = CustomerPlan.ENDED
plan.save(update_fields=["automanage_licenses", "status"])
update_license_ledger_if_needed(realm, self.now)
billing_session.update_license_ledger_if_needed(self.now)
self.assertEqual(LicenseLedger.objects.count(), 1)
# Test update needed
plan.status = CustomerPlan.ACTIVE
plan.save(update_fields=["status"])
update_license_ledger_if_needed(realm, self.now)
billing_session.update_license_ledger_if_needed(self.now)
self.assertEqual(LicenseLedger.objects.count(), 2)
def test_update_license_ledger_for_automanaged_plan(self) -> None:
@ -4716,25 +4722,27 @@ class LicenseLedgerTest(StripeTestCase):
assert plan is not None
self.assertEqual(plan.licenses(), self.seat_count)
self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count)
billing_session = RealmBillingSession(user=None, realm=realm)
# Simple increase
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=23):
update_license_ledger_for_automanaged_plan(realm, plan, self.now)
billing_session.update_license_ledger_for_automanaged_plan(plan, self.now)
self.assertEqual(plan.licenses(), 23)
self.assertEqual(plan.licenses_at_next_renewal(), 23)
# Decrease
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=20):
update_license_ledger_for_automanaged_plan(realm, plan, self.now)
billing_session.update_license_ledger_for_automanaged_plan(plan, self.now)
self.assertEqual(plan.licenses(), 23)
self.assertEqual(plan.licenses_at_next_renewal(), 20)
# Increase, but not past high watermark
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=21):
update_license_ledger_for_automanaged_plan(realm, plan, self.now)
billing_session.update_license_ledger_for_automanaged_plan(plan, self.now)
self.assertEqual(plan.licenses(), 23)
self.assertEqual(plan.licenses_at_next_renewal(), 21)
# Increase, but after renewal date, and below last year's high watermark
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=22):
update_license_ledger_for_automanaged_plan(
realm, plan, self.next_year + timedelta(seconds=1)
billing_session.update_license_ledger_for_automanaged_plan(
plan, self.next_year + timedelta(seconds=1)
)
self.assertEqual(plan.licenses(), 22)
self.assertEqual(plan.licenses_at_next_renewal(), 22)
@ -4884,24 +4892,25 @@ class InvoiceTest(StripeTestCase):
self.login_user(user)
with time_machine.travel(self.now, tick=False):
self.add_card_and_upgrade(user)
realm = get_realm("zulip")
billing_session = RealmBillingSession(user=user, realm=realm)
# Increase
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count + 3):
update_license_ledger_if_needed(get_realm("zulip"), self.now + timedelta(days=100))
billing_session.update_license_ledger_if_needed(self.now + timedelta(days=100))
# Decrease
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count):
update_license_ledger_if_needed(get_realm("zulip"), self.now + timedelta(days=200))
billing_session.update_license_ledger_if_needed(self.now + timedelta(days=200))
# Increase, but not past high watermark
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count + 1):
update_license_ledger_if_needed(get_realm("zulip"), self.now + timedelta(days=300))
billing_session.update_license_ledger_if_needed(self.now + timedelta(days=300))
# Increase, but after renewal date, and below last year's high watermark
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count + 2):
update_license_ledger_if_needed(get_realm("zulip"), self.now + timedelta(days=400))
billing_session.update_license_ledger_if_needed(self.now + timedelta(days=400))
# Increase, but after event_time
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count + 3):
update_license_ledger_if_needed(get_realm("zulip"), self.now + timedelta(days=500))
billing_session.update_license_ledger_if_needed(self.now + timedelta(days=500))
plan = CustomerPlan.objects.first()
assert plan is not None
billing_session = RealmBillingSession(realm=user.realm)
billing_session.invoice_plan(plan, self.now + timedelta(days=400))
stripe_customer_id = plan.customer.stripe_customer_id
assert stripe_customer_id is not None

View File

@ -55,7 +55,7 @@ from zerver.models import (
from zerver.tornado.django_api import send_event_on_commit
if settings.BILLING_ENABLED:
from corporate.lib.stripe import update_license_ledger_if_needed
from corporate.lib.stripe import RealmBillingSession
MAX_NUM_ONBOARDING_MESSAGES = 1000
@ -514,7 +514,8 @@ def do_create_user(
event_time,
)
if settings.BILLING_ENABLED:
update_license_ledger_if_needed(user_profile.realm, event_time)
billing_session = RealmBillingSession(user=user_profile, realm=user_profile.realm)
billing_session.update_license_ledger_if_needed(event_time)
system_user_group = get_system_user_group_for_user(user_profile)
UserGroupMembership.objects.create(user_profile=user_profile, user_group=system_user_group)
@ -624,7 +625,8 @@ def do_activate_mirror_dummy_user(
event_time,
)
if settings.BILLING_ENABLED:
update_license_ledger_if_needed(user_profile.realm, event_time)
billing_session = RealmBillingSession(user=user_profile, realm=user_profile.realm)
billing_session.update_license_ledger_if_needed(event_time)
notify_created_user(user_profile, [])
@ -676,7 +678,8 @@ def do_reactivate_user(user_profile: UserProfile, *, acting_user: Optional[UserP
event_time,
)
if settings.BILLING_ENABLED:
update_license_ledger_if_needed(user_profile.realm, event_time)
billing_session = RealmBillingSession(user=user_profile, realm=user_profile.realm)
billing_session.update_license_ledger_if_needed(event_time)
event = dict(
type="realm_user", op="update", person=dict(user_id=user_profile.id, is_active=True)

View File

@ -51,7 +51,7 @@ from zerver.models import (
from zerver.tornado.django_api import send_event, send_event_on_commit
if settings.BILLING_ENABLED:
from corporate.lib.stripe import update_license_ledger_if_needed
from corporate.lib.stripe import RealmBillingSession
def do_delete_user(user_profile: UserProfile, *, acting_user: Optional[UserProfile]) -> None:
@ -364,7 +364,8 @@ def do_deactivate_user(
increment=-1,
)
if settings.BILLING_ENABLED:
update_license_ledger_if_needed(user_profile.realm, event_time)
billing_session = RealmBillingSession(user=user_profile, realm=user_profile.realm)
billing_session.update_license_ledger_if_needed(event_time)
transaction.on_commit(lambda: delete_user_sessions(user_profile))