diff --git a/analytics/tests/test_support_views.py b/analytics/tests/test_support_views.py index 248cc48cfe..f3441a2e5c 100644 --- a/analytics/tests/test_support_views.py +++ b/analytics/tests/test_support_views.py @@ -9,7 +9,13 @@ from django.utils.timezone import now as timezone_now from typing_extensions import override from corporate.lib.stripe import RealmBillingSession, add_months -from corporate.models import Customer, CustomerPlan, LicenseLedger, get_customer_by_realm +from corporate.models import ( + Customer, + CustomerPlan, + LicenseLedger, + get_current_plan_by_realm, + get_customer_by_realm, +) from zerver.actions.invites import do_create_multiuse_invite_link from zerver.actions.realm_settings import do_change_realm_org_type, do_send_realm_reactivation_email from zerver.actions.user_settings import do_change_user_setting @@ -434,39 +440,50 @@ class TestSupportEndpoint(ZulipTestCase): result, ) - @mock.patch("analytics.views.support.update_realm_billing_modality") - def test_change_billing_modality(self, m: mock.Mock) -> None: + def test_change_billing_modality(self) -> None: + realm = get_realm("zulip") cordelia = self.example_user("cordelia") self.login_user(cordelia) - result = self.client_post( - "/activity/support", {"realm_id": f"{cordelia.realm_id}", "plan_type": "2"} + "/activity/support", + {"realm_id": f"{realm.id}", "billing_method": "charge_automatically"}, ) self.assertEqual(result.status_code, 302) self.assertEqual(result["Location"], "/login/") + customer = Customer.objects.create(realm=realm, stripe_customer_id="cus_12345") + CustomerPlan.objects.create( + customer=customer, + status=CustomerPlan.ACTIVE, + billing_cycle_anchor=timezone_now(), + billing_schedule=CustomerPlan.BILLING_SCHEDULE_ANNUAL, + tier=CustomerPlan.TIER_CLOUD_STANDARD, + ) + iago = self.example_user("iago") self.login_user(iago) result = self.client_post( "/activity/support", - {"realm_id": f"{iago.realm_id}", "billing_modality": "charge_automatically"}, + {"realm_id": f"{realm.id}", "billing_modality": "charge_automatically"}, ) - m.assert_called_once_with(get_realm("zulip"), charge_automatically=True, acting_user=iago) self.assert_in_success_response( ["Billing collection method of zulip updated to charge automatically"], result ) - - m.reset_mock() + plan = get_current_plan_by_realm(realm) + assert plan is not None + self.assertEqual(plan.charge_automatically, True) result = self.client_post( - "/activity/support", - {"realm_id": f"{iago.realm_id}", "billing_modality": "send_invoice"}, + "/activity/support", {"realm_id": f"{realm.id}", "billing_modality": "send_invoice"} ) - m.assert_called_once_with(get_realm("zulip"), charge_automatically=False, acting_user=iago) self.assert_in_success_response( ["Billing collection method of zulip updated to send invoice"], result ) + realm.refresh_from_db() + plan = get_current_plan_by_realm(realm) + assert plan is not None + self.assertEqual(plan.charge_automatically, False) def test_change_realm_plan_type(self) -> None: cordelia = self.example_user("cordelia") diff --git a/analytics/views/support.py b/analytics/views/support.py index 389f1d4eaf..2c7bac46fb 100644 --- a/analytics/views/support.py +++ b/analytics/views/support.py @@ -63,7 +63,6 @@ if settings.BILLING_ENABLED: from corporate.lib.support import ( get_discount_for_realm, switch_realm_from_standard_to_plus_plan, - update_realm_billing_modality, ) from corporate.models import ( Customer, @@ -210,6 +209,11 @@ def support( support_type=SupportType.attach_discount, discount=discount, ) + elif billing_modality is not None: + support_view_request = SupportViewRequest( + support_type=SupportType.update_billing_modality, + billing_modality=billing_modality, + ) elif plan_type is not None: current_plan_type = realm.plan_type do_change_realm_plan_type(realm, plan_type, acting_user=acting_user) @@ -243,21 +247,6 @@ def support( elif status == "deactivated": do_deactivate_realm(realm, acting_user=acting_user) context["success_message"] = f"{realm.string_id} deactivated." - elif billing_modality is not None: - if billing_modality == "send_invoice": - update_realm_billing_modality( - realm, charge_automatically=False, acting_user=acting_user - ) - context[ - "success_message" - ] = f"Billing collection method of {realm.string_id} updated to send invoice." - elif billing_modality == "charge_automatically": - update_realm_billing_modality( - realm, charge_automatically=True, acting_user=acting_user - ) - context[ - "success_message" - ] = f"Billing collection method of {realm.string_id} updated to charge automatically." elif modify_plan is not None: billing_session = RealmBillingSession( user=acting_user, realm=realm, support_session=True diff --git a/corporate/lib/stripe.py b/corporate/lib/stripe.py index baa29e2889..7703c3d498 100644 --- a/corporate/lib/stripe.py +++ b/corporate/lib/stripe.py @@ -526,12 +526,14 @@ class SupportType(Enum): approve_sponsorship = 1 update_sponsorship_status = 2 attach_discount = 3 + update_billing_modality = 4 class SupportViewRequest(TypedDict, total=False): support_type: SupportType sponsorship_status: Optional[bool] discount: Optional[Decimal] + billing_modality: Optional[str] class AuditLogEventType(Enum): @@ -930,7 +932,7 @@ class BillingSession(ABC): ) return success_message - def update_billing_modality_of_current_plan(self, charge_automatically: bool) -> None: + def update_billing_modality_of_current_plan(self, charge_automatically: bool) -> str: customer = self.get_customer() if customer is not None: plan = get_current_plan_by_customer(customer) @@ -942,6 +944,11 @@ class BillingSession(ABC): event_time=timezone_now(), extra_data={"charge_automatically": charge_automatically}, ) + if charge_automatically: + success_message = f"Billing collection method of {self.billing_entity_display_name} updated to charge automatically." + else: + success_message = f"Billing collection method of {self.billing_entity_display_name} updated to send invoice." + return success_message def setup_upgrade_payment_intent_and_charge( self, @@ -2023,6 +2030,11 @@ class BillingSession(ABC): assert support_request["discount"] is not None new_discount = support_request["discount"] success_message = self.attach_discount_to_customer(new_discount) + elif support_type == SupportType.update_billing_modality: + assert support_request["billing_modality"] is not None + assert support_request["billing_modality"] in VALID_BILLING_MODALITY_VALUES + charge_automatically = support_request["billing_modality"] == "charge_automatically" + success_message = self.update_billing_modality_of_current_plan(charge_automatically) return success_message diff --git a/corporate/lib/support.py b/corporate/lib/support.py index ec334beb4c..590b311174 100644 --- a/corporate/lib/support.py +++ b/corporate/lib/support.py @@ -7,7 +7,7 @@ from django.urls import reverse from corporate.lib.stripe import RealmBillingSession from corporate.models import CustomerPlan, get_customer_by_realm -from zerver.models import Realm, UserProfile, get_realm +from zerver.models import Realm, get_realm def get_support_url(realm: Realm) -> str: @@ -26,13 +26,6 @@ def get_discount_for_realm(realm: Realm) -> Optional[Decimal]: return None -def update_realm_billing_modality( - realm: Realm, charge_automatically: bool, *, acting_user: UserProfile -) -> None: - billing_session = RealmBillingSession(acting_user, realm, support_session=True) - billing_session.update_billing_modality_of_current_plan(charge_automatically) - - def switch_realm_from_standard_to_plus_plan(realm: Realm) -> None: billing_session = RealmBillingSession(realm=realm) billing_session.do_change_plan_to_new_tier(new_plan_tier=CustomerPlan.TIER_CLOUD_PLUS) diff --git a/corporate/tests/test_stripe.py b/corporate/tests/test_stripe.py index 1d4381ddfa..cc2f35fb16 100644 --- a/corporate/tests/test_stripe.py +++ b/corporate/tests/test_stripe.py @@ -74,11 +74,7 @@ from corporate.lib.stripe import ( update_license_ledger_if_needed, void_all_open_invoices, ) -from corporate.lib.support import ( - get_discount_for_realm, - switch_realm_from_standard_to_plus_plan, - update_realm_billing_modality, -) +from corporate.lib.support import get_discount_for_realm, switch_realm_from_standard_to_plus_plan from corporate.models import ( Customer, CustomerPlan, @@ -5252,8 +5248,9 @@ class TestSupportBillingHelpers(StripeTestCase): ) self.assertEqual(plan.charge_automatically, False) - iago = self.example_user("iago") - update_realm_billing_modality(realm, True, acting_user=iago) + support_admin = self.example_user("iago") + billing_session = RealmBillingSession(user=support_admin, realm=realm, support_session=True) + billing_session.update_billing_modality_of_current_plan(True) plan.refresh_from_db() self.assertEqual(plan.charge_automatically, True) realm_audit_log = RealmAuditLog.objects.filter( @@ -5261,10 +5258,10 @@ class TestSupportBillingHelpers(StripeTestCase): ).last() assert realm_audit_log is not None expected_extra_data = {"charge_automatically": plan.charge_automatically} - self.assertEqual(realm_audit_log.acting_user, iago) + self.assertEqual(realm_audit_log.acting_user, support_admin) self.assertEqual(realm_audit_log.extra_data, expected_extra_data) - update_realm_billing_modality(realm, False, acting_user=iago) + billing_session.update_billing_modality_of_current_plan(False) plan.refresh_from_db() self.assertEqual(plan.charge_automatically, False) realm_audit_log = RealmAuditLog.objects.filter( @@ -5272,7 +5269,7 @@ class TestSupportBillingHelpers(StripeTestCase): ).last() assert realm_audit_log is not None expected_extra_data = {"charge_automatically": plan.charge_automatically} - self.assertEqual(realm_audit_log.acting_user, iago) + self.assertEqual(realm_audit_log.acting_user, support_admin) self.assertEqual(realm_audit_log.extra_data, expected_extra_data) @mock_stripe()