test_stripe: Add end-to-end test for RemoteRealm billing flow.

This commit is contained in:
Prakhar Pratyush 2023-12-13 16:55:23 +05:30 committed by Tim Abbott
parent 33e04362e1
commit 1588f49b4f
29 changed files with 133 additions and 23 deletions

View File

@ -3384,7 +3384,7 @@ class RemoteRealmBillingSession(BillingSession):
self.remote_realm.save(update_fields=["org_type"]) self.remote_realm.save(update_fields=["org_type"])
@override @override
def sync_license_ledger_if_needed(self) -> None: # nocoverage def sync_license_ledger_if_needed(self) -> None:
last_ledger = self.get_last_ledger_for_automanaged_plan_if_exists() last_ledger = self.get_last_ledger_for_automanaged_plan_if_exists()
if last_ledger is None: if last_ledger is None:
return return
@ -3406,7 +3406,7 @@ class RemoteRealmBillingSession(BillingSession):
current_plan, audit_log.event_time current_plan, audit_log.event_time
) )
if end_of_cycle_plan is None: if end_of_cycle_plan is None:
return return # nocoverage
current_plan = end_of_cycle_plan current_plan = end_of_cycle_plan
def get_push_service_validity_dict(self) -> RemoteRealmDictValue: def get_push_service_validity_dict(self) -> RemoteRealmDictValue:

View File

@ -34,6 +34,7 @@ import stripe.util
import time_machine import time_machine
from django.conf import settings from django.conf import settings
from django.core import signing from django.core import signing
from django.test import override_settings
from django.urls.resolvers import get_resolver from django.urls.resolvers import get_resolver
from django.utils.crypto import get_random_string from django.utils.crypto import get_random_string
from django.utils.timezone import now as timezone_now from django.utils.timezone import now as timezone_now
@ -96,7 +97,8 @@ from zerver.actions.create_user import (
) )
from zerver.actions.realm_settings import do_deactivate_realm, do_reactivate_realm from zerver.actions.realm_settings import do_deactivate_realm, do_reactivate_realm
from zerver.actions.users import do_deactivate_user from zerver.actions.users import do_deactivate_user
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.remote_server import send_server_data_to_push_bouncer
from zerver.lib.test_classes import BouncerTestCase, ZulipTestCase
from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime
from zerver.lib.utils import assert_is_not_none from zerver.lib.utils import assert_is_not_none
from zerver.models import ( from zerver.models import (
@ -432,7 +434,9 @@ class StripeTestCase(ZulipTestCase):
hamlet.is_billing_admin = True hamlet.is_billing_admin = True
hamlet.save(update_fields=["is_billing_admin"]) hamlet.save(update_fields=["is_billing_admin"])
self.billing_session = RealmBillingSession(user=hamlet, realm=realm) self.billing_session: Union[
RealmBillingSession, RemoteRealmBillingSession
] = RealmBillingSession(user=hamlet, realm=realm)
def get_signed_seat_count_from_response(self, response: "TestHttpResponse") -> Optional[str]: def get_signed_seat_count_from_response(self, response: "TestHttpResponse") -> Optional[str]:
match = re.search(r"name=\"signed_seat_count\" value=\"(.+)\"", response.content.decode()) match = re.search(r"name=\"signed_seat_count\" value=\"(.+)\"", response.content.decode())
@ -575,7 +579,14 @@ class StripeTestCase(ZulipTestCase):
**kwargs: Any, **kwargs: Any,
) -> "TestHttpResponse": ) -> "TestHttpResponse":
if upgrade_page_response is None: if upgrade_page_response is None:
upgrade_page_response = self.client_get("/upgrade/", {}) if self.billing_session.billing_base_url:
upgrade_page_response = self.client_get(
f"{self.billing_session.billing_base_url}/upgrade/", {}, subdomain="selfhosting"
)
else:
upgrade_page_response = self.client_get(
f"{self.billing_session.billing_base_url}/upgrade/", {}
)
params: Dict[str, Any] = { params: Dict[str, Any] = {
"schedule": "annual", "schedule": "annual",
"signed_seat_count": self.get_signed_seat_count_from_response(upgrade_page_response), "signed_seat_count": self.get_signed_seat_count_from_response(upgrade_page_response),
@ -629,15 +640,22 @@ class StripeTestCase(ZulipTestCase):
self.send_stripe_webhook_events(last_event) self.send_stripe_webhook_events(last_event)
return upgrade_json_response return upgrade_json_response
def add_card_and_upgrade(self, user: UserProfile, **kwargs: Any) -> stripe.Customer: def add_card_and_upgrade(
self, user: Optional[UserProfile] = None, **kwargs: Any
) -> stripe.Customer:
# Add card # Add card
with time_machine.travel(self.now, tick=False): with time_machine.travel(self.now, tick=False):
self.add_card_to_customer_for_upgrade() self.add_card_to_customer_for_upgrade()
# Check that we correctly created a Customer object in Stripe # Check that we correctly created a Customer object in Stripe
stripe_customer = stripe_get_customer( if user is not None:
assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id) stripe_customer = stripe_get_customer(
) assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id)
)
else:
customer = self.billing_session.get_customer()
assert customer is not None
stripe_customer = stripe_get_customer(assert_is_not_none(customer.stripe_customer_id))
self.assertTrue(stripe_customer_has_credit_card_as_default_payment_method(stripe_customer)) self.assertTrue(stripe_customer_has_credit_card_as_default_payment_method(stripe_customer))
with time_machine.travel(self.now, tick=False): with time_machine.travel(self.now, tick=False):
@ -714,7 +732,7 @@ class StripeTestCase(ZulipTestCase):
def client_billing_patch(self, url_suffix: str, info: Mapping[str, Any] = {}) -> Any: def client_billing_patch(self, url_suffix: str, info: Mapping[str, Any] = {}) -> Any:
url = f"/json{self.billing_session.billing_base_url}" + url_suffix url = f"/json{self.billing_session.billing_base_url}" + url_suffix
if self.billing_session.billing_base_url: if self.billing_session.billing_base_url:
response = self.client_patch(url, info, subdomain="selfhosting") response = self.client_patch(url, info, subdomain="selfhosting") # nocoverage
else: else:
response = self.client_patch(url, info) response = self.client_patch(url, info)
return response return response
@ -5587,3 +5605,95 @@ class TestRemoteBillingWriteAuditLog(StripeTestCase):
assert_audit_log( assert_audit_log(
audit_log, None, support_admin, audit_log_class.CUSTOMER_PLAN_CREATED, event_time audit_log, None, support_admin, audit_log_class.CUSTOMER_PLAN_CREATED, event_time
) )
@override_settings(PUSH_NOTIFICATION_BOUNCER_URL="https://push.zulip.org.example.com")
class TestRemoteRealmBillingFlow(StripeTestCase, BouncerTestCase):
@override
def setUp(self) -> None:
# We need to time travel to 2012-1-2 because super().setUp()
# creates users and changes roles with event_time=timezone_now().
# That affects the LicenseLedger queries as their event_time would
# be more recent than other operations we perform in this test.
with time_machine.travel(datetime(2012, 1, 2, 3, 4, 5, tzinfo=timezone.utc), tick=False):
super().setUp()
hamlet = self.example_user("hamlet")
remote_realm = RemoteRealm.objects.get(uuid=hamlet.realm.uuid)
self.billing_session = RemoteRealmBillingSession(remote_realm=remote_realm)
@responses.activate
@mock_stripe()
def test_non_sponsorship_billing(self, *mocks: Mock) -> None:
self.add_mock_response()
with time_machine.travel(self.now, tick=False):
send_server_data_to_push_bouncer(consider_usage_statistics=False)
self.login("hamlet")
hamlet = self.example_user("hamlet")
result = self.execute_remote_billing_authentication_flow(hamlet)
self.assertEqual(result.status_code, 302)
self.assertEqual(result["Location"], f"{self.billing_session.billing_base_url}/plans/")
# upgrade to business plan
with time_machine.travel(self.now, tick=False):
result = self.client_get(
f"{self.billing_session.billing_base_url}/upgrade/", subdomain="selfhosting"
)
self.assertEqual(result.status_code, 200)
self.assert_in_success_response(["Add card", "Purchase Zulip Business"], result)
self.assertFalse(Customer.objects.exists())
self.assertFalse(CustomerPlan.objects.exists())
self.assertFalse(LicenseLedger.objects.exists())
with time_machine.travel(self.now, tick=False):
stripe_customer = self.add_card_and_upgrade()
customer = Customer.objects.get(stripe_customer_id=stripe_customer.id)
plan = CustomerPlan.objects.get(customer=customer)
LicenseLedger.objects.get(plan=plan)
with time_machine.travel(self.now + timedelta(days=1), tick=False):
response = self.client_get(
f"{self.billing_session.billing_base_url}/billing/", subdomain="selfhosting"
)
for substring in [
"Zulip Business",
"Number of licenses",
"10 (managed automatically)",
"Your plan will automatically renew on",
"Visa ending in 4242",
"Update card",
]:
self.assert_in_response(substring, response)
# Verify that change in user count updates LicenseLedger.
audit_log_count = RemoteRealmAuditLog.objects.count()
self.assertEqual(LicenseLedger.objects.count(), 1)
with time_machine.travel(self.now + timedelta(days=2), tick=False):
user_count = self.billing_session.current_count_for_billed_licenses(
self.now + timedelta(days=2)
)
for count in range(10):
do_create_user(
f"email {count}",
f"password {count}",
hamlet.realm,
"name",
role=UserProfile.ROLE_MEMBER,
acting_user=None,
)
with time_machine.travel(self.now + timedelta(days=3), tick=False):
send_server_data_to_push_bouncer(consider_usage_statistics=False)
self.assertEqual(
RemoteRealmAuditLog.objects.count(),
audit_log_count + 10,
)
latest_ledger = LicenseLedger.objects.last()
assert latest_ledger is not None
self.assertEqual(latest_ledger.licenses, user_count + 10)

View File

@ -99,7 +99,7 @@ def remote_realm_billing_page(
billing_session: RemoteRealmBillingSession, billing_session: RemoteRealmBillingSession,
*, *,
success_message: str = "", success_message: str = "",
) -> HttpResponse: # nocoverage ) -> HttpResponse:
realm_uuid = billing_session.remote_realm.uuid realm_uuid = billing_session.remote_realm.uuid
context: Dict[str, Any] = { context: Dict[str, Any] = {
# We wouldn't be here if user didn't have access. # We wouldn't be here if user didn't have access.
@ -109,11 +109,11 @@ def remote_realm_billing_page(
"billing_base_url": billing_session.billing_base_url, "billing_base_url": billing_session.billing_base_url,
} }
if billing_session.remote_realm.plan_type == RemoteRealm.PLAN_TYPE_COMMUNITY: if billing_session.remote_realm.plan_type == RemoteRealm.PLAN_TYPE_COMMUNITY: # nocoverage
return HttpResponseRedirect(reverse("remote_realm_sponsorship_page", args=(realm_uuid,))) return HttpResponseRedirect(reverse("remote_realm_sponsorship_page", args=(realm_uuid,)))
customer = billing_session.get_customer() customer = billing_session.get_customer()
if customer is not None and customer.sponsorship_pending: if customer is not None and customer.sponsorship_pending: # nocoverage
# Don't redirect to sponsorship page if the remote realm is on a paid plan or scheduled for an upgrade. # Don't redirect to sponsorship page if the remote realm is on a paid plan or scheduled for an upgrade.
if ( if (
not billing_session.on_paid_plan() not billing_session.on_paid_plan()
@ -136,12 +136,12 @@ def remote_realm_billing_page(
RemoteRealm.PLAN_TYPE_SELF_MANAGED_LEGACY, RemoteRealm.PLAN_TYPE_SELF_MANAGED_LEGACY,
] ]
) )
): ): # nocoverage
return HttpResponseRedirect(reverse("remote_realm_plans_page", args=(realm_uuid,))) return HttpResponseRedirect(reverse("remote_realm_plans_page", args=(realm_uuid,)))
try: try:
main_context = billing_session.get_billing_page_context() main_context = billing_session.get_billing_page_context()
except MissingDataError: except MissingDataError: # nocoverage
return billing_session.missing_data_error_page(request) return billing_session.missing_data_error_page(request)
if main_context: if main_context:

View File

@ -48,7 +48,7 @@ def remote_realm_event_status(
*, *,
stripe_session_id: Optional[str] = None, stripe_session_id: Optional[str] = None,
stripe_payment_intent_id: Optional[str] = None, stripe_payment_intent_id: Optional[str] = None,
) -> HttpResponse: # nocoverage ) -> HttpResponse:
event_status_request = EventStatusRequest( event_status_request = EventStatusRequest(
stripe_session_id=stripe_session_id, stripe_payment_intent_id=stripe_payment_intent_id stripe_session_id=stripe_session_id, stripe_payment_intent_id=stripe_payment_intent_id
) )

View File

@ -75,7 +75,7 @@ def start_card_update_stripe_session_for_remote_realm_upgrade(
billing_session: RemoteRealmBillingSession, billing_session: RemoteRealmBillingSession,
*, *,
manual_license_management: Json[bool] = False, manual_license_management: Json[bool] = False,
) -> HttpResponse: # nocoverage ) -> HttpResponse:
session_data = billing_session.create_card_update_session_for_upgrade(manual_license_management) session_data = billing_session.create_card_update_session_for_upgrade(manual_license_management)
return json_success( return json_success(
request, request,

View File

@ -97,7 +97,7 @@ def remote_realm_upgrade(
), ),
licenses: Optional[int] = REQ(json_validator=check_int, default=None), licenses: Optional[int] = REQ(json_validator=check_int, default=None),
remote_server_plan_start_date: Optional[str] = REQ(default=None), remote_server_plan_start_date: Optional[str] = REQ(default=None),
) -> HttpResponse: # nocoverage ) -> HttpResponse:
try: try:
upgrade_request = UpgradeRequest( upgrade_request = UpgradeRequest(
billing_modality=billing_modality, billing_modality=billing_modality,
@ -112,7 +112,7 @@ def remote_realm_upgrade(
) )
data = billing_session.do_upgrade(upgrade_request) data = billing_session.do_upgrade(upgrade_request)
return json_success(request, data) return json_success(request, data)
except BillingError as e: except BillingError as e: # nocoverage
billing_logger.warning( billing_logger.warning(
"BillingError during upgrade: %s. remote_realm=%s (%s), billing_modality=%s, " "BillingError during upgrade: %s. remote_realm=%s (%s), billing_modality=%s, "
"schedule=%s, license_management=%s, licenses=%s", "schedule=%s, license_management=%s, licenses=%s",
@ -125,7 +125,7 @@ def remote_realm_upgrade(
licenses, licenses,
) )
raise e raise e
except Exception: except Exception: # nocoverage
billing_logger.exception("Uncaught exception in billing:", stack_info=True) billing_logger.exception("Uncaught exception in billing:", stack_info=True)
error_message = BillingError.CONTACT_SUPPORT.format(email=settings.ZULIP_ADMINISTRATOR) error_message = BillingError.CONTACT_SUPPORT.format(email=settings.ZULIP_ADMINISTRATOR)
error_description = "uncaught exception during upgrade" error_description = "uncaught exception during upgrade"
@ -215,7 +215,7 @@ def remote_realm_upgrade_page(
*, *,
manual_license_management: Json[bool] = False, manual_license_management: Json[bool] = False,
success_message: str = "", success_message: str = "",
) -> HttpResponse: # nocoverage ) -> HttpResponse:
initial_upgrade_request = InitialUpgradeRequest( initial_upgrade_request = InitialUpgradeRequest(
manual_license_management=manual_license_management, manual_license_management=manual_license_management,
tier=CustomerPlan.TIER_SELF_HOSTED_BUSINESS, tier=CustomerPlan.TIER_SELF_HOSTED_BUSINESS,
@ -223,10 +223,10 @@ def remote_realm_upgrade_page(
) )
try: try:
redirect_url, context = billing_session.get_initial_upgrade_context(initial_upgrade_request) redirect_url, context = billing_session.get_initial_upgrade_context(initial_upgrade_request)
except MissingDataError: except MissingDataError: # nocoverage
return billing_session.missing_data_error_page(request) return billing_session.missing_data_error_page(request)
if redirect_url: if redirect_url: # nocoverage
return HttpResponseRedirect(redirect_url) return HttpResponseRedirect(redirect_url)
response = render(request, "corporate/upgrade.html", context=context) response = render(request, "corporate/upgrade.html", context=context)