stripe: Add 'do_update_plan' method to the 'BillingSession' class.

This commit moves a major portion of the 'update_plan`
view to a new shared 'BillingSession.do_update_plan' method.

This refactoring will help in minimizing duplicate code
while supporting both realm and remote_server customers.
This commit is contained in:
Prakhar Pratyush 2023-11-22 17:14:02 +05:30 committed by Tim Abbott
parent efa423395f
commit 51b39cb682
6 changed files with 187 additions and 164 deletions

View File

@ -701,7 +701,9 @@ class TestSupportEndpoint(ZulipTestCase):
iago = self.example_user("iago") iago = self.example_user("iago")
self.login_user(iago) self.login_user(iago)
with mock.patch("analytics.views.support.downgrade_at_the_end_of_billing_cycle") as m: with mock.patch(
"analytics.views.support.RealmBillingSession.downgrade_at_the_end_of_billing_cycle"
) as m:
result = self.client_post( result = self.client_post(
"/activity/support", "/activity/support",
{ {
@ -709,13 +711,13 @@ class TestSupportEndpoint(ZulipTestCase):
"modify_plan": "downgrade_at_billing_cycle_end", "modify_plan": "downgrade_at_billing_cycle_end",
}, },
) )
m.assert_called_once_with(get_realm("zulip")) m.assert_called_once()
self.assert_in_success_response( self.assert_in_success_response(
["zulip marked for downgrade at the end of billing cycle"], result ["zulip marked for downgrade at the end of billing cycle"], result
) )
with mock.patch( with mock.patch(
"analytics.views.support.downgrade_now_without_creating_additional_invoices" "analytics.views.support.RealmBillingSession.downgrade_now_without_creating_additional_invoices"
) as m: ) as m:
result = self.client_post( result = self.client_post(
"/activity/support", "/activity/support",
@ -724,13 +726,13 @@ class TestSupportEndpoint(ZulipTestCase):
"modify_plan": "downgrade_now_without_additional_licenses", "modify_plan": "downgrade_now_without_additional_licenses",
}, },
) )
m.assert_called_once_with(get_realm("zulip")) m.assert_called_once()
self.assert_in_success_response( self.assert_in_success_response(
["zulip downgraded without creating additional invoices"], result ["zulip downgraded without creating additional invoices"], result
) )
with mock.patch( with mock.patch(
"analytics.views.support.downgrade_now_without_creating_additional_invoices" "analytics.views.support.RealmBillingSession.downgrade_now_without_creating_additional_invoices"
) as m1: ) as m1:
with mock.patch("analytics.views.support.void_all_open_invoices", return_value=1) as m2: with mock.patch("analytics.views.support.void_all_open_invoices", return_value=1) as m2:
result = self.client_post( result = self.client_post(
@ -740,7 +742,7 @@ class TestSupportEndpoint(ZulipTestCase):
"modify_plan": "downgrade_now_void_open_invoices", "modify_plan": "downgrade_now_void_open_invoices",
}, },
) )
m1.assert_called_once_with(get_realm("zulip")) m1.assert_called_once()
m2.assert_called_once_with(get_realm("zulip")) m2.assert_called_once_with(get_realm("zulip"))
self.assert_in_success_response( self.assert_in_success_response(
["zulip downgraded and voided 1 open invoices"], result ["zulip downgraded and voided 1 open invoices"], result

View File

@ -55,8 +55,6 @@ if settings.ZILENCER_ENABLED:
if settings.BILLING_ENABLED: if settings.BILLING_ENABLED:
from corporate.lib.stripe import ( from corporate.lib.stripe import (
RealmBillingSession, RealmBillingSession,
downgrade_at_the_end_of_billing_cycle,
downgrade_now_without_creating_additional_invoices,
get_latest_seat_count, get_latest_seat_count,
switch_realm_from_standard_to_plus_plan, switch_realm_from_standard_to_plus_plan,
void_all_open_invoices, void_all_open_invoices,
@ -264,18 +262,21 @@ def support(
approve_realm_sponsorship(realm, acting_user=acting_user) approve_realm_sponsorship(realm, acting_user=acting_user)
context["success_message"] = f"Sponsorship approved for {realm.string_id}" context["success_message"] = f"Sponsorship approved for {realm.string_id}"
elif modify_plan is not None: elif modify_plan is not None:
billing_session = RealmBillingSession(
user=acting_user, realm=realm, support_session=True
)
if modify_plan == "downgrade_at_billing_cycle_end": if modify_plan == "downgrade_at_billing_cycle_end":
downgrade_at_the_end_of_billing_cycle(realm) billing_session.downgrade_at_the_end_of_billing_cycle()
context[ context[
"success_message" "success_message"
] = f"{realm.string_id} marked for downgrade at the end of billing cycle" ] = f"{realm.string_id} marked for downgrade at the end of billing cycle"
elif modify_plan == "downgrade_now_without_additional_licenses": elif modify_plan == "downgrade_now_without_additional_licenses":
downgrade_now_without_creating_additional_invoices(realm) billing_session.downgrade_now_without_creating_additional_invoices()
context[ context[
"success_message" "success_message"
] = f"{realm.string_id} downgraded without creating additional invoices" ] = f"{realm.string_id} downgraded without creating additional invoices"
elif modify_plan == "downgrade_now_void_open_invoices": elif modify_plan == "downgrade_now_void_open_invoices":
downgrade_now_without_creating_additional_invoices(realm) billing_session.downgrade_now_without_creating_additional_invoices()
voided_invoices_count = void_all_open_invoices(realm) voided_invoices_count = void_all_open_invoices(realm)
context[ context[
"success_message" "success_message"

View File

@ -451,6 +451,13 @@ class InitialUpgradeRequest:
tier: int tier: int
@dataclass
class UpdatePlanRequest:
status: Optional[int]
licenses: Optional[int]
licenses_at_next_renewal: Optional[int]
class AuditLogEventType(Enum): class AuditLogEventType(Enum):
STRIPE_CUSTOMER_CREATED = 1 STRIPE_CUSTOMER_CREATED = 1
STRIPE_CARD_CHANGED = 2 STRIPE_CARD_CHANGED = 2
@ -1264,6 +1271,134 @@ class BillingSession(ABC):
return None, context return None, context
def downgrade_at_the_end_of_billing_cycle(self, plan: Optional[CustomerPlan] = None) -> None:
if plan is None: # nocoverage
# TODO: Add test coverage. Right now, this logic is used
# in production but mocked in tests.
customer = self.get_customer()
assert customer is not None
plan = get_current_plan_by_customer(customer)
assert plan is not None
do_change_plan_status(plan, CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE)
# During realm deactivation we instantly downgrade the plan to Limited.
# Extra users added in the final month are not charged. Also used
# for the cancellation of Free Trial.
def downgrade_now_without_creating_additional_invoices(
self,
plan: Optional[CustomerPlan] = None,
) -> None:
if plan is None:
customer = self.get_customer()
if customer is None:
return
plan = get_current_plan_by_customer(customer)
if plan is None:
return # nocoverage
self.process_downgrade(plan)
plan.invoiced_through = LicenseLedger.objects.filter(plan=plan).order_by("id").last()
plan.next_invoice_date = next_invoice_date(plan)
plan.save(update_fields=["invoiced_through", "next_invoice_date"])
def do_update_plan(self, update_plan_request: UpdatePlanRequest) -> None:
customer = self.get_customer()
assert customer is not None
plan = get_current_plan_by_customer(customer)
assert plan is not None # for mypy
new_plan, last_ledger_entry = self.make_end_of_cycle_updates_if_needed(plan, timezone_now())
if new_plan is not None:
raise JsonableError(
_(
"Unable to update the plan. The plan has been expired and replaced with a new plan."
)
)
if last_ledger_entry is None:
raise JsonableError(_("Unable to update the plan. The plan has ended."))
status = update_plan_request.status
if status is not None:
if status == CustomerPlan.ACTIVE:
assert plan.status < CustomerPlan.LIVE_STATUS_THRESHOLD
do_change_plan_status(plan, status)
elif status == CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE:
assert plan.status < CustomerPlan.LIVE_STATUS_THRESHOLD
self.downgrade_at_the_end_of_billing_cycle(plan=plan)
elif status == CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE:
assert plan.billing_schedule == CustomerPlan.MONTHLY
assert plan.status < CustomerPlan.LIVE_STATUS_THRESHOLD
# Customer needs to switch to an active plan first to avoid unexpected behavior.
assert plan.status != CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE
assert plan.fixed_price is None
do_change_plan_status(plan, status)
elif status == CustomerPlan.SWITCH_TO_MONTHLY_AT_END_OF_CYCLE:
assert plan.billing_schedule == CustomerPlan.ANNUAL
assert plan.status < CustomerPlan.LIVE_STATUS_THRESHOLD
# Customer needs to switch to an active plan first to avoid unexpected behavior.
assert plan.status != CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE
assert plan.fixed_price is None
do_change_plan_status(plan, status)
elif status == CustomerPlan.ENDED:
assert plan.is_free_trial()
self.downgrade_now_without_creating_additional_invoices(plan=plan)
return
licenses = update_plan_request.licenses
if licenses is not None:
if plan.automanage_licenses:
raise JsonableError(
_(
"Unable to update licenses manually. Your plan is on automatic license management."
)
)
if last_ledger_entry.licenses == licenses:
raise JsonableError(
_(
"Your plan is already on {licenses} licenses in the current billing period."
).format(licenses=licenses)
)
if last_ledger_entry.licenses > licenses:
raise JsonableError(
_("You cannot decrease the licenses in the current billing period.")
)
validate_licenses(
plan.charge_automatically,
licenses,
self.current_count_for_billed_licenses(),
plan.customer.exempt_from_license_number_check,
)
update_license_ledger_for_manual_plan(plan, timezone_now(), licenses=licenses)
return
licenses_at_next_renewal = update_plan_request.licenses_at_next_renewal
if licenses_at_next_renewal is not None:
if plan.automanage_licenses:
raise JsonableError(
_(
"Unable to update licenses manually. Your plan is on automatic license management."
)
)
if last_ledger_entry.licenses_at_next_renewal == licenses_at_next_renewal:
raise JsonableError(
_(
"Your plan is already scheduled to renew with {licenses_at_next_renewal} licenses."
).format(licenses_at_next_renewal=licenses_at_next_renewal)
)
validate_licenses(
plan.charge_automatically,
licenses_at_next_renewal,
self.current_count_for_billed_licenses(),
plan.customer.exempt_from_license_number_check,
)
update_license_ledger_for_manual_plan(
plan, timezone_now(), licenses_at_next_renewal=licenses_at_next_renewal
)
return
raise JsonableError(_("Nothing to change."))
class RealmBillingSession(BillingSession): class RealmBillingSession(BillingSession):
def __init__( def __init__(
@ -2174,27 +2309,6 @@ def do_change_plan_status(plan: CustomerPlan, status: int) -> None:
) )
# During realm deactivation we instantly downgrade the plan to Limited.
# Extra users added in the final month are not charged. Also used
# for the cancellation of Free Trial.
def downgrade_now_without_creating_additional_invoices(realm: Realm) -> None:
plan = get_current_plan_by_realm(realm)
if plan is None:
return
billing_session = RealmBillingSession(user=None, realm=realm)
billing_session.process_downgrade(plan)
plan.invoiced_through = LicenseLedger.objects.filter(plan=plan).order_by("id").last()
plan.next_invoice_date = next_invoice_date(plan)
plan.save(update_fields=["invoiced_through", "next_invoice_date"])
def downgrade_at_the_end_of_billing_cycle(realm: Realm) -> None:
plan = get_current_plan_by_realm(realm)
assert plan is not None
do_change_plan_status(plan, CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE)
def get_all_invoices_for_customer(customer: Customer) -> Generator[stripe.Invoice, None, None]: def get_all_invoices_for_customer(customer: Customer) -> Generator[stripe.Invoice, None, None]:
if customer.stripe_customer_id is None: if customer.stripe_customer_id is None:
return return
@ -2251,8 +2365,8 @@ def downgrade_small_realms_behind_on_payments_as_needed() -> None:
continue continue
# We've now decided to downgrade this customer and void all invoices, and the below will execute this. # We've now decided to downgrade this customer and void all invoices, and the below will execute this.
billing_session = RealmBillingSession(user=None, realm=realm)
downgrade_now_without_creating_additional_invoices(realm) billing_session.downgrade_now_without_creating_additional_invoices()
void_all_open_invoices(realm) void_all_open_invoices(realm)
context: Dict[str, Union[str, Realm]] = { context: Dict[str, Union[str, Realm]] = {
"upgrade_url": f"{realm.uri}{reverse('initial_upgrade')}", "upgrade_url": f"{realm.uri}{reverse('initial_upgrade')}",

View File

@ -2185,7 +2185,7 @@ class StripeTest(StripeTestCase):
self.assertEqual(plan.licenses(), self.seat_count) self.assertEqual(plan.licenses(), self.seat_count)
self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count) self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count)
with self.assertLogs("corporate.stripe", "INFO") as m: with self.assertLogs("corporate.stripe", "INFO") as m:
with patch("corporate.views.billing_page.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
response = self.client_patch( response = self.client_patch(
"/json/billing/plan", {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE} "/json/billing/plan", {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}
) )
@ -2299,7 +2299,7 @@ class StripeTest(StripeTestCase):
assert new_plan is not None assert new_plan is not None
with self.assertLogs("corporate.stripe", "INFO") as m: with self.assertLogs("corporate.stripe", "INFO") as m:
with patch("corporate.views.billing_page.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
response = self.client_patch( response = self.client_patch(
"/json/billing/plan", "/json/billing/plan",
{"status": CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE}, {"status": CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE},
@ -2488,7 +2488,7 @@ class StripeTest(StripeTestCase):
new_plan = get_current_plan_by_realm(user.realm) new_plan = get_current_plan_by_realm(user.realm)
assert new_plan is not None assert new_plan is not None
with self.assertLogs("corporate.stripe", "INFO") as m: with self.assertLogs("corporate.stripe", "INFO") as m:
with patch("corporate.views.billing_page.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
response = self.client_patch( response = self.client_patch(
"/json/billing/plan", "/json/billing/plan",
{"status": CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE}, {"status": CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE},
@ -2602,7 +2602,7 @@ class StripeTest(StripeTestCase):
assert new_plan is not None assert new_plan is not None
with self.assertLogs("corporate.stripe", "INFO") as m: with self.assertLogs("corporate.stripe", "INFO") as m:
with patch("corporate.views.billing_page.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
response = self.client_patch( response = self.client_patch(
"/json/billing/plan", "/json/billing/plan",
{"status": CustomerPlan.SWITCH_TO_MONTHLY_AT_END_OF_CYCLE}, {"status": CustomerPlan.SWITCH_TO_MONTHLY_AT_END_OF_CYCLE},
@ -2767,7 +2767,7 @@ class StripeTest(StripeTestCase):
with patch("corporate.lib.stripe.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
self.local_upgrade(self.seat_count, True, CustomerPlan.ANNUAL, True, False) self.local_upgrade(self.seat_count, True, CustomerPlan.ANNUAL, True, False)
with self.assertLogs("corporate.stripe", "INFO") as m: with self.assertLogs("corporate.stripe", "INFO") as m:
with patch("corporate.views.billing_page.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
response = self.client_patch( response = self.client_patch(
"/json/billing/plan", {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE} "/json/billing/plan", {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}
) )
@ -2781,7 +2781,7 @@ class StripeTest(StripeTestCase):
assert plan is not None assert plan is not None
self.assertEqual(plan.status, CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE) self.assertEqual(plan.status, CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE)
with self.assertLogs("corporate.stripe", "INFO") as m: with self.assertLogs("corporate.stripe", "INFO") as m:
with patch("corporate.views.billing_page.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
response = self.client_patch("/json/billing/plan", {"status": CustomerPlan.ACTIVE}) response = self.client_patch("/json/billing/plan", {"status": CustomerPlan.ACTIVE})
expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.ACTIVE}" expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.ACTIVE}"
self.assertEqual(m.output[0], expected_log) self.assertEqual(m.output[0], expected_log)
@ -2807,7 +2807,7 @@ class StripeTest(StripeTestCase):
stripe_customer_id = Customer.objects.get(realm=user.realm).id stripe_customer_id = Customer.objects.get(realm=user.realm).id
new_plan = get_current_plan_by_realm(user.realm) new_plan = get_current_plan_by_realm(user.realm)
assert new_plan is not None assert new_plan is not None
with patch("corporate.views.billing_page.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
self.client_patch( self.client_patch(
"/json/billing/plan", {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE} "/json/billing/plan", {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}
) )
@ -2848,7 +2848,7 @@ class StripeTest(StripeTestCase):
self.login_user(user) self.login_user(user)
with patch("corporate.views.billing_page.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
self.client_patch("/json/billing/plan", {"status": CustomerPlan.ENDED}) self.client_patch("/json/billing/plan", {"status": CustomerPlan.ENDED})
plan.refresh_from_db() plan.refresh_from_db()
@ -2880,7 +2880,7 @@ class StripeTest(StripeTestCase):
self.login_user(user) self.login_user(user)
with self.assertLogs("corporate.stripe", "INFO") as m: with self.assertLogs("corporate.stripe", "INFO") as m:
with patch("corporate.views.billing_page.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
self.client_patch( self.client_patch(
"/json/billing/plan", {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE} "/json/billing/plan", {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}
) )
@ -2935,38 +2935,38 @@ class StripeTest(StripeTestCase):
with patch("corporate.lib.stripe.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
self.upgrade(invoice=True, licenses=100) self.upgrade(invoice=True, licenses=100)
with patch("corporate.views.billing_page.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
result = self.client_patch("/json/billing/plan", {"licenses": 100}) result = self.client_patch("/json/billing/plan", {"licenses": 100})
self.assert_json_error_contains( self.assert_json_error_contains(
result, "Your plan is already on 100 licenses in the current billing period." result, "Your plan is already on 100 licenses in the current billing period."
) )
with patch("corporate.views.billing_page.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
result = self.client_patch("/json/billing/plan", {"licenses_at_next_renewal": 100}) result = self.client_patch("/json/billing/plan", {"licenses_at_next_renewal": 100})
self.assert_json_error_contains( self.assert_json_error_contains(
result, "Your plan is already scheduled to renew with 100 licenses." result, "Your plan is already scheduled to renew with 100 licenses."
) )
with patch("corporate.views.billing_page.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
result = self.client_patch("/json/billing/plan", {"licenses": 50}) result = self.client_patch("/json/billing/plan", {"licenses": 50})
self.assert_json_error_contains( self.assert_json_error_contains(
result, "You cannot decrease the licenses in the current billing period." result, "You cannot decrease the licenses in the current billing period."
) )
with patch("corporate.views.billing_page.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
result = self.client_patch("/json/billing/plan", {"licenses_at_next_renewal": 25}) result = self.client_patch("/json/billing/plan", {"licenses_at_next_renewal": 25})
self.assert_json_error_contains( self.assert_json_error_contains(
result, result,
"You must purchase licenses for all active users in your organization (minimum 30).", "You must purchase licenses for all active users in your organization (minimum 30).",
) )
with patch("corporate.views.billing_page.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
result = self.client_patch("/json/billing/plan", {"licenses": 2000}) result = self.client_patch("/json/billing/plan", {"licenses": 2000})
self.assert_json_error_contains( self.assert_json_error_contains(
result, "Invoices with more than 1000 licenses can't be processed from this page." result, "Invoices with more than 1000 licenses can't be processed from this page."
) )
with patch("corporate.views.billing_page.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
result = self.client_patch("/json/billing/plan", {"licenses": 150}) result = self.client_patch("/json/billing/plan", {"licenses": 150})
self.assert_json_success(result) self.assert_json_success(result)
invoice_plans_as_needed(self.next_year) invoice_plans_as_needed(self.next_year)
@ -3016,7 +3016,7 @@ class StripeTest(StripeTestCase):
for key, value in line_item_params.items(): for key, value in line_item_params.items():
self.assertEqual(extra_license_item.get(key), value) self.assertEqual(extra_license_item.get(key), value)
with patch("corporate.views.billing_page.timezone_now", return_value=self.next_year): with patch("corporate.lib.stripe.timezone_now", return_value=self.next_year):
result = self.client_patch("/json/billing/plan", {"licenses_at_next_renewal": 120}) result = self.client_patch("/json/billing/plan", {"licenses_at_next_renewal": 120})
self.assert_json_success(result) self.assert_json_success(result)
invoice_plans_as_needed(self.next_year + timedelta(days=365)) invoice_plans_as_needed(self.next_year + timedelta(days=365))
@ -3069,7 +3069,7 @@ class StripeTest(StripeTestCase):
with patch("corporate.lib.stripe.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
self.local_upgrade(100, False, CustomerPlan.ANNUAL, True, False) self.local_upgrade(100, False, CustomerPlan.ANNUAL, True, False)
with patch("corporate.views.billing_page.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
result = self.client_patch( result = self.client_patch(
"/json/billing/plan", "/json/billing/plan",
{"licenses_at_next_renewal": get_latest_seat_count(user.realm) - 2}, {"licenses_at_next_renewal": get_latest_seat_count(user.realm) - 2},
@ -3115,11 +3115,11 @@ class StripeTest(StripeTestCase):
with patch("corporate.lib.stripe.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
self.local_upgrade(self.seat_count, True, CustomerPlan.ANNUAL, True, False) self.local_upgrade(self.seat_count, True, CustomerPlan.ANNUAL, True, False)
with patch("corporate.views.billing_page.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
result = self.client_patch("/json/billing/plan", {"licenses": 100}) result = self.client_patch("/json/billing/plan", {"licenses": 100})
self.assert_json_error_contains(result, "Your plan is on automatic license management.") self.assert_json_error_contains(result, "Your plan is on automatic license management.")
with patch("corporate.views.billing_page.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
result = self.client_patch("/json/billing/plan", {"licenses_at_next_renewal": 100}) result = self.client_patch("/json/billing/plan", {"licenses_at_next_renewal": 100})
self.assert_json_error_contains(result, "Your plan is on automatic license management.") self.assert_json_error_contains(result, "Your plan is on automatic license management.")
@ -3139,7 +3139,7 @@ class StripeTest(StripeTestCase):
self.local_upgrade(self.seat_count, True, CustomerPlan.ANNUAL, True, False) self.local_upgrade(self.seat_count, True, CustomerPlan.ANNUAL, True, False)
self.login_user(self.example_user("hamlet")) self.login_user(self.example_user("hamlet"))
with patch("corporate.views.billing_page.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
response = self.client_patch("/json/billing/plan", {}) response = self.client_patch("/json/billing/plan", {})
self.assert_json_error_contains(response, "Nothing to change") self.assert_json_error_contains(response, "Nothing to change")
@ -3149,7 +3149,7 @@ class StripeTest(StripeTestCase):
self.login_user(self.example_user("hamlet")) self.login_user(self.example_user("hamlet"))
with self.assertLogs("corporate.stripe", "INFO") as m: with self.assertLogs("corporate.stripe", "INFO") as m:
with patch("corporate.views.billing_page.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
result = self.client_patch( result = self.client_patch(
"/json/billing/plan", {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE} "/json/billing/plan", {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}
) )
@ -3171,7 +3171,7 @@ class StripeTest(StripeTestCase):
self.login_user(self.example_user("hamlet")) self.login_user(self.example_user("hamlet"))
with self.assertLogs("corporate.stripe", "INFO") as m: with self.assertLogs("corporate.stripe", "INFO") as m:
with patch("corporate.views.billing_page.timezone_now", return_value=self.now): with patch("corporate.lib.stripe.timezone_now", return_value=self.now):
result = self.client_patch( result = self.client_patch(
"/json/billing/plan", {"status": CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE} "/json/billing/plan", {"status": CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE}
) )

View File

@ -4,21 +4,10 @@ from typing import Any, Dict, Optional
from django.http import HttpRequest, HttpResponse, HttpResponseRedirect from django.http import HttpRequest, HttpResponse, HttpResponseRedirect
from django.shortcuts import render from django.shortcuts import render
from django.urls import reverse from django.urls import reverse
from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _
from corporate.lib.stripe import ( from corporate.lib.stripe import RealmBillingSession, UpdatePlanRequest
RealmBillingSession, from corporate.models import CustomerPlan, get_customer_by_realm
do_change_plan_status,
downgrade_at_the_end_of_billing_cycle,
downgrade_now_without_creating_additional_invoices,
get_latest_seat_count,
update_license_ledger_for_manual_plan,
validate_licenses,
)
from corporate.models import CustomerPlan, get_current_plan_by_realm, get_customer_by_realm
from zerver.decorator import require_billing_access, zulip_login_required from zerver.decorator import require_billing_access, zulip_login_required
from zerver.lib.exceptions import JsonableError
from zerver.lib.request import REQ, has_request_variables from zerver.lib.request import REQ, has_request_variables
from zerver.lib.response import json_success from zerver.lib.response import json_success
from zerver.lib.validator import check_int, check_int_in, check_string from zerver.lib.validator import check_int, check_int_in, check_string
@ -142,96 +131,11 @@ def update_plan(
"licenses_at_next_renewal", json_validator=check_int, default=None "licenses_at_next_renewal", json_validator=check_int, default=None
), ),
) -> HttpResponse: ) -> HttpResponse:
plan = get_current_plan_by_realm(user.realm) update_plan_request = UpdatePlanRequest(
assert plan is not None # for mypy status=status,
licenses=licenses,
realm = plan.customer.realm licenses_at_next_renewal=licenses_at_next_renewal,
billing_session = RealmBillingSession(user=None, realm=realm)
new_plan, last_ledger_entry = billing_session.make_end_of_cycle_updates_if_needed(
plan, timezone_now()
) )
if new_plan is not None: billing_session = RealmBillingSession(user=user)
raise JsonableError( billing_session.do_update_plan(update_plan_request)
_("Unable to update the plan. The plan has been expired and replaced with a new plan.") return json_success(request)
)
if last_ledger_entry is None:
raise JsonableError(_("Unable to update the plan. The plan has ended."))
if status is not None:
if status == CustomerPlan.ACTIVE:
assert plan.status < CustomerPlan.LIVE_STATUS_THRESHOLD
do_change_plan_status(plan, status)
elif status == CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE:
assert plan.status < CustomerPlan.LIVE_STATUS_THRESHOLD
downgrade_at_the_end_of_billing_cycle(user.realm)
elif status == CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE:
assert plan.billing_schedule == CustomerPlan.MONTHLY
assert plan.status < CustomerPlan.LIVE_STATUS_THRESHOLD
# Customer needs to switch to an active plan first to avoid unexpected behavior.
assert plan.status != CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE
assert plan.fixed_price is None
do_change_plan_status(plan, status)
elif status == CustomerPlan.SWITCH_TO_MONTHLY_AT_END_OF_CYCLE:
assert plan.billing_schedule == CustomerPlan.ANNUAL
assert plan.status < CustomerPlan.LIVE_STATUS_THRESHOLD
# Customer needs to switch to an active plan first to avoid unexpected behavior.
assert plan.status != CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE
assert plan.fixed_price is None
do_change_plan_status(plan, status)
elif status == CustomerPlan.ENDED:
assert plan.is_free_trial()
downgrade_now_without_creating_additional_invoices(user.realm)
return json_success(request)
if licenses is not None:
if plan.automanage_licenses:
raise JsonableError(
_(
"Unable to update licenses manually. Your plan is on automatic license management."
)
)
if last_ledger_entry.licenses == licenses:
raise JsonableError(
_(
"Your plan is already on {licenses} licenses in the current billing period."
).format(licenses=licenses)
)
if last_ledger_entry.licenses > licenses:
raise JsonableError(
_("You cannot decrease the licenses in the current billing period.")
)
validate_licenses(
plan.charge_automatically,
licenses,
get_latest_seat_count(user.realm),
plan.customer.exempt_from_license_number_check,
)
update_license_ledger_for_manual_plan(plan, timezone_now(), licenses=licenses)
return json_success(request)
if licenses_at_next_renewal is not None:
if plan.automanage_licenses:
raise JsonableError(
_(
"Unable to update licenses manually. Your plan is on automatic license management."
)
)
if last_ledger_entry.licenses_at_next_renewal == licenses_at_next_renewal:
raise JsonableError(
_(
"Your plan is already scheduled to renew with {licenses_at_next_renewal} licenses."
).format(licenses_at_next_renewal=licenses_at_next_renewal)
)
validate_licenses(
plan.charge_automatically,
licenses_at_next_renewal,
get_latest_seat_count(user.realm),
plan.customer.exempt_from_license_number_check,
)
update_license_ledger_for_manual_plan(
plan, timezone_now(), licenses_at_next_renewal=licenses_at_next_renewal
)
return json_success(request)
raise JsonableError(_("Nothing to change."))

View File

@ -39,7 +39,7 @@ from zerver.models import (
from zerver.tornado.django_api import send_event, send_event_on_commit from zerver.tornado.django_api import send_event, send_event_on_commit
if settings.BILLING_ENABLED: if settings.BILLING_ENABLED:
from corporate.lib.stripe import downgrade_now_without_creating_additional_invoices from corporate.lib.stripe import RealmBillingSession
def active_humans_in_realm(realm: Realm) -> QuerySet[UserProfile]: def active_humans_in_realm(realm: Realm) -> QuerySet[UserProfile]:
@ -309,7 +309,8 @@ def do_deactivate_realm(realm: Realm, *, acting_user: Optional[UserProfile]) ->
realm.save(update_fields=["deactivated"]) realm.save(update_fields=["deactivated"])
if settings.BILLING_ENABLED: if settings.BILLING_ENABLED:
downgrade_now_without_creating_additional_invoices(realm) billing_session = RealmBillingSession(user=acting_user, realm=realm)
billing_session.downgrade_now_without_creating_additional_invoices()
event_time = timezone_now() event_time = timezone_now()
RealmAuditLog.objects.create( RealmAuditLog.objects.create(
@ -389,7 +390,8 @@ def do_delete_all_realm_attachments(realm: Realm, *, batch_size: int = 1000) ->
def do_scrub_realm(realm: Realm, *, acting_user: Optional[UserProfile]) -> None: def do_scrub_realm(realm: Realm, *, acting_user: Optional[UserProfile]) -> None:
if settings.BILLING_ENABLED: if settings.BILLING_ENABLED:
downgrade_now_without_creating_additional_invoices(realm) billing_session = RealmBillingSession(user=acting_user, realm=realm)
billing_session.downgrade_now_without_creating_additional_invoices()
users = UserProfile.objects.filter(realm=realm) users = UserProfile.objects.filter(realm=realm)
for user in users: for user in users: