billing: Add prototype remote billing sessions.

These new models are incomplete and totally untested, but merging this
will provide valuable scaffolding for doing smaller PRs working on
individual gaps, and reveals a clear set of TODOs/refactoring/model
changes needed to support where want to end up.

Co-authored-by: Tim Abbott <tabbott@zulip.com>
This commit is contained in:
Lauryn Menard 2023-11-09 20:40:42 +01:00 committed by Tim Abbott
parent f916385cab
commit 11cb37c9a4
4 changed files with 418 additions and 4 deletions

View File

@ -31,6 +31,8 @@ from corporate.models import (
get_current_plan_by_customer, get_current_plan_by_customer,
get_current_plan_by_realm, get_current_plan_by_realm,
get_customer_by_realm, get_customer_by_realm,
get_customer_by_remote_realm,
get_customer_by_remote_server,
) )
from zerver.lib.exceptions import JsonableError from zerver.lib.exceptions import JsonableError
from zerver.lib.logging_util import log_to_file from zerver.lib.logging_util import log_to_file
@ -38,7 +40,12 @@ from zerver.lib.send_email import FromAddress, send_email_to_billing_admins_and_
from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime
from zerver.lib.utils import assert_is_not_none from zerver.lib.utils import assert_is_not_none
from zerver.models import Realm, RealmAuditLog, UserProfile, get_system_bot from zerver.models import Realm, RealmAuditLog, UserProfile, get_system_bot
from zilencer.models import RemoteZulipServer, RemoteZulipServerAuditLog from zilencer.models import (
RemoteRealm,
RemoteRealmAuditLog,
RemoteZulipServer,
RemoteZulipServerAuditLog,
)
from zproject.config import get_secret from zproject.config import get_secret
stripe.api_key = get_secret("stripe_secret_key") stripe.api_key = get_secret("stripe_secret_key")
@ -1195,6 +1202,9 @@ class RealmBillingSession(BillingSession):
assert self.support_session is False assert self.support_session is False
assert self.user is not None assert self.user is not None
amount = price_per_license * licenses amount = price_per_license * licenses
# TODO: Don't hardcode plan name; it should be looked up for
# the tier.
description = f"Upgrade to Zulip Cloud Standard, ${price_per_license/100} x {licenses}" description = f"Upgrade to Zulip Cloud Standard, ${price_per_license/100} x {licenses}"
plan_name = "Zulip Cloud Standard" plan_name = "Zulip Cloud Standard"
return StripePaymentIntentData( return StripePaymentIntentData(
@ -1275,6 +1285,7 @@ class RealmBillingSession(BillingSession):
"If you could {begin_link}list Zulip as a sponsor on your website{end_link}, " "If you could {begin_link}list Zulip as a sponsor on your website{end_link}, "
"we would really appreciate it!" "we would really appreciate it!"
).format( ).format(
# TODO: Don't hardcode plan names.
plan_name="Zulip Cloud Standard", plan_name="Zulip Cloud Standard",
emoji=":tada:", emoji=":tada:",
begin_link="[", begin_link="[",
@ -1283,6 +1294,338 @@ class RealmBillingSession(BillingSession):
internal_send_private_message(notification_bot, user, message) internal_send_private_message(notification_bot, user, message)
class RemoteRealmBillingSession(BillingSession): # nocoverage
def __init__(
self, remote_realm: RemoteRealm, support_staff: Optional[UserProfile] = None
) -> None:
self.remote_realm = remote_realm
if support_staff is not None:
assert support_staff.is_staff
self.support_session = True
else:
self.support_session = False
@override
@property
def billing_session_url(self) -> str:
return "TBD"
@override
def get_customer(self) -> Optional[Customer]:
return get_customer_by_remote_realm(self.remote_realm)
@override
def current_count_for_billed_licenses(self) -> int:
# TODO: Do the proper calculation here.
return 10
@override
def get_audit_log_event(self, event_type: AuditLogEventType) -> int:
if event_type is AuditLogEventType.STRIPE_CUSTOMER_CREATED:
return RemoteRealmAuditLog.STRIPE_CUSTOMER_CREATED
elif event_type is AuditLogEventType.STRIPE_CARD_CHANGED:
return RemoteRealmAuditLog.STRIPE_CARD_CHANGED
elif event_type is AuditLogEventType.CUSTOMER_PLAN_CREATED:
return RemoteRealmAuditLog.CUSTOMER_PLAN_CREATED
elif event_type is AuditLogEventType.DISCOUNT_CHANGED:
return RemoteRealmAuditLog.REMOTE_SERVER_DISCOUNT_CHANGED
elif event_type is AuditLogEventType.SPONSORSHIP_APPROVED:
return RemoteRealmAuditLog.REMOTE_SERVER_SPONSORSHIP_APPROVED
elif event_type is AuditLogEventType.SPONSORSHIP_PENDING_STATUS_CHANGED:
return RemoteRealmAuditLog.REMOTE_SERVER_SPONSORSHIP_PENDING_STATUS_CHANGED
elif event_type is AuditLogEventType.BILLING_METHOD_CHANGED:
return RemoteRealmAuditLog.REMOTE_SERVER_BILLING_METHOD_CHANGED
else:
raise BillingSessionAuditLogEventError(event_type)
@override
def write_to_audit_log(
self,
event_type: AuditLogEventType,
event_time: datetime,
*,
extra_data: Optional[Dict[str, Any]] = None,
) -> None:
# BUG: This doesn't have a way to pass realm_id !
audit_log_event = self.get_audit_log_event(event_type)
if extra_data:
RemoteRealmAuditLog.objects.create(
server=self.remote_realm.server,
remote_realm=self.remote_realm,
event_type=audit_log_event,
event_time=event_time,
extra_data=extra_data,
)
else:
RemoteRealmAuditLog.objects.create(
server=self.remote_realm.server,
remote_realm=self.remote_realm,
event_type=audit_log_event,
event_time=event_time,
)
@override
def get_data_for_stripe_customer(self) -> StripeCustomerData:
# Support requests do not set any stripe billing information.
assert self.support_session is False
metadata: Dict[str, Any] = {}
metadata["remote_realm_uuid"] = self.remote_realm.uuid
metadata["remote_realm_host"] = str(self.remote_realm.host)
realm_stripe_customer_data = StripeCustomerData(
description=str(self.remote_realm),
# BUG: This is an email for the whole server. We probably
# need a separable field here.
email=self.remote_realm.server.contact_email,
metadata=metadata,
)
return realm_stripe_customer_data
@override
def update_data_for_checkout_session_and_payment_intent(
self, metadata: Dict[str, Any]
) -> Dict[str, Any]:
# TODO: Figure out what this should do.
updated_metadata = dict(
**metadata,
)
return updated_metadata
@override
def get_data_for_stripe_payment_intent(
self, price_per_license: int, licenses: int
) -> StripePaymentIntentData:
# Support requests do not set any stripe billing information.
assert self.support_session is False
amount = price_per_license * licenses
# TODO: Don't hardcode plan names.
description = f"Upgrade to Zulip X Standard, ${price_per_license/100} x {licenses}"
plan_name = "Zulip X Standard"
return StripePaymentIntentData(
amount=amount,
description=description,
plan_name=plan_name,
# BUG: This is an email for the whole server. We probably
# need a separable field here.
email=self.remote_realm.server.contact_email,
)
@override
def update_or_create_customer(
self, stripe_customer_id: Optional[str] = None, *, defaults: Optional[Dict[str, Any]] = None
) -> Customer:
if stripe_customer_id is not None:
# Support requests do not set any stripe billing information.
assert self.support_session is False
customer, created = Customer.objects.update_or_create(
remote_realm=self.remote_realm,
defaults={"stripe_customer_id": stripe_customer_id},
)
return customer
else:
customer, created = Customer.objects.update_or_create(
remote_realm=self.remote_realm, defaults=defaults
)
return customer
@override
def do_change_plan_type(self, *, tier: Optional[int], is_sponsored: bool = False) -> None:
# TODO: Create actual plan types.
# This function needs to translate between the different
# formats of CustomerPlan.tier and Realm.plan_type.
if is_sponsored:
plan_type = RemoteRealm.PLAN_TYPE_COMMUNITY
elif tier == CustomerPlan.STANDARD:
plan_type = RemoteRealm.PLAN_TYPE_BUSINESS
elif tier == CustomerPlan.PLUS: # nocoverage # Plus plan doesn't use this code path yet.
plan_type = RemoteRealm.PLAN_TYPE_ENTERPRISE
else:
raise AssertionError("Unexpected tier")
# TODO: Audit logging.
self.remote_realm.plan_type = plan_type
self.remote_realm.save(update_fields=["plan_type"])
@override
def approve_sponsorship(self) -> None:
# TBD
pass
@override
def process_downgrade(self, plan: CustomerPlan) -> None:
self.remote_realm.plan_type = RemoteRealm.PLAN_TYPE_SELF_HOSTED
self.remote_realm.save(update_fields=["plan_type"])
# TODO: Write audit log entry
plan.status = CustomerPlan.ENDED
plan.save(update_fields=["status"])
class RemoteServerBillingSession(BillingSession): # nocoverage
"""Billing session for pre-8.0 servers that do not yet support
creating RemoteRealm objects."""
def __init__(
self, remote_server: RemoteZulipServer, support_staff: Optional[UserProfile] = None
) -> None:
self.remote_server = remote_server
if support_staff is not None:
assert support_staff.is_staff
self.support_session = True
else:
self.support_session = False
@override
@property
def billing_session_url(self) -> str:
return "TBD"
@override
def get_customer(self) -> Optional[Customer]:
return get_customer_by_remote_server(self.remote_server)
@override
def current_count_for_billed_licenses(self) -> int:
# TODO: Do the proper calculation here.
return 10
@override
def get_audit_log_event(self, event_type: AuditLogEventType) -> int:
if event_type is AuditLogEventType.STRIPE_CUSTOMER_CREATED:
return RemoteZulipServerAuditLog.STRIPE_CUSTOMER_CREATED
elif event_type is AuditLogEventType.STRIPE_CARD_CHANGED:
return RemoteZulipServerAuditLog.STRIPE_CARD_CHANGED
elif event_type is AuditLogEventType.CUSTOMER_PLAN_CREATED:
return RemoteZulipServerAuditLog.CUSTOMER_PLAN_CREATED
elif event_type is AuditLogEventType.DISCOUNT_CHANGED:
return RemoteZulipServerAuditLog.REMOTE_SERVER_DISCOUNT_CHANGED
elif event_type is AuditLogEventType.SPONSORSHIP_APPROVED:
return RemoteZulipServerAuditLog.REMOTE_SERVER_SPONSORSHIP_APPROVED
elif event_type is AuditLogEventType.SPONSORSHIP_PENDING_STATUS_CHANGED:
return RemoteZulipServerAuditLog.REMOTE_SERVER_SPONSORSHIP_PENDING_STATUS_CHANGED
elif event_type is AuditLogEventType.BILLING_METHOD_CHANGED:
return RemoteZulipServerAuditLog.REMOTE_SERVER_BILLING_METHOD_CHANGED
else:
raise BillingSessionAuditLogEventError(event_type)
@override
def write_to_audit_log(
self,
event_type: AuditLogEventType,
event_time: datetime,
*,
extra_data: Optional[Dict[str, Any]] = None,
) -> None:
audit_log_event = self.get_audit_log_event(event_type)
if extra_data:
RemoteZulipServerAuditLog.objects.create(
server=self.remote_server,
event_type=audit_log_event,
event_time=event_time,
extra_data=extra_data,
)
else:
RemoteZulipServerAuditLog.objects.create(
server=self.remote_server,
event_type=audit_log_event,
event_time=event_time,
)
@override
def get_data_for_stripe_customer(self) -> StripeCustomerData:
# Support requests do not set any stripe billing information.
assert self.support_session is False
metadata: Dict[str, Any] = {}
metadata["remote_server_uuid"] = self.remote_server.uuid
metadata["remote_server_str"] = str(self.remote_server)
realm_stripe_customer_data = StripeCustomerData(
description=str(self.remote_server),
email=self.remote_server.contact_email,
metadata=metadata,
)
return realm_stripe_customer_data
@override
def update_data_for_checkout_session_and_payment_intent(
self, metadata: Dict[str, Any]
) -> Dict[str, Any]:
updated_metadata = dict(
server=self.remote_server,
email=self.remote_server.contact_email,
**metadata,
)
return updated_metadata
@override
def get_data_for_stripe_payment_intent(
self, price_per_license: int, licenses: int
) -> StripePaymentIntentData:
# Support requests do not set any stripe billing information.
assert self.support_session is False
amount = price_per_license * licenses
description = f"Upgrade to Zulip X Standard, ${price_per_license/100} x {licenses}"
plan_name = "Zulip X Standard"
return StripePaymentIntentData(
amount=amount,
description=description,
plan_name=plan_name,
email=self.remote_server.contact_email,
)
@override
def update_or_create_customer(
self, stripe_customer_id: Optional[str] = None, *, defaults: Optional[Dict[str, Any]] = None
) -> Customer:
if stripe_customer_id is not None:
# Support requests do not set any stripe billing information.
assert self.support_session is False
customer, created = Customer.objects.update_or_create(
remote_server=self.remote_server,
defaults={"stripe_customer_id": stripe_customer_id},
)
return customer
else:
customer, created = Customer.objects.update_or_create(
remote_server=self.remote_server, defaults=defaults
)
return customer
@override
def do_change_plan_type(self, *, tier: Optional[int], is_sponsored: bool = False) -> None:
# TODO: Create actual plan types.
# This function needs to translate between the different
# formats of CustomerPlan.tier and RealmZulipServer.plan_type.
if is_sponsored:
plan_type = RemoteZulipServer.PLAN_TYPE_COMMUNITY
elif tier == CustomerPlan.STANDARD:
plan_type = RemoteZulipServer.PLAN_TYPE_BUSINESS
elif tier == CustomerPlan.PLUS: # nocoverage # Plus plan doesn't use this code path yet.
plan_type = RemoteZulipServer.PLAN_TYPE_ENTERPRISE
else:
raise AssertionError("Unexpected tier")
# TODO: Audit logging.
self.remote_server.plan_type = plan_type
self.remote_server.save(update_fields=["plan_type"])
@override
def approve_sponsorship(self) -> None:
# TBD
pass
@override
def process_downgrade(self, plan: CustomerPlan) -> None:
self.remote_server.plan_type = RemoteZulipServer.PLAN_TYPE_SELF_HOSTED
self.remote_server.save(update_fields=["plan_type"])
# TODO: Write audit log entry
plan.status = CustomerPlan.ENDED
plan.save(update_fields=["status"])
def stripe_customer_has_credit_card_as_default_payment_method( def stripe_customer_has_credit_card_as_default_payment_method(
stripe_customer: stripe.Customer, stripe_customer: stripe.Customer,
) -> bool: ) -> bool:

View File

@ -7,7 +7,7 @@ from django.db.models import CASCADE, Q
from typing_extensions import override from typing_extensions import override
from zerver.models import Realm, UserProfile from zerver.models import Realm, UserProfile
from zilencer.models import RemoteZulipServer from zilencer.models import RemoteRealm, RemoteZulipServer
class Customer(models.Model): class Customer(models.Model):
@ -56,6 +56,14 @@ def get_customer_by_realm(realm: Realm) -> Optional[Customer]:
return Customer.objects.filter(realm=realm).first() return Customer.objects.filter(realm=realm).first()
def get_customer_by_remote_server(remote_server: RemoteZulipServer) -> Optional[Customer]:
return Customer.objects.filter(remote_server=remote_server).first()
def get_customer_by_remote_realm(remote_realm: RemoteRealm) -> Optional[Customer]: # nocoverage
return Customer.objects.filter(remote_realm=remote_realm).first()
class Event(models.Model): class Event(models.Model):
stripe_event_id = models.CharField(max_length=255) stripe_event_id = models.CharField(max_length=255)

View File

@ -46,6 +46,7 @@ from corporate.lib.stripe import (
InvalidBillingScheduleError, InvalidBillingScheduleError,
InvalidTierError, InvalidTierError,
RealmBillingSession, RealmBillingSession,
RemoteServerBillingSession,
StripeCardError, StripeCardError,
add_months, add_months,
catch_stripe_errors, catch_stripe_errors,
@ -4901,6 +4902,64 @@ class TestRealmBillingSession(StripeTestCase):
): ):
billing_session.get_audit_log_event(event_type=fake_audit_log) billing_session.get_audit_log_event(event_type=fake_audit_log)
def test_get_customer(self) -> None:
user = self.example_user("hamlet")
billing_session = RealmBillingSession(user)
customer = billing_session.get_customer()
self.assertEqual(customer, None)
customer = Customer.objects.create(realm=user.realm, stripe_customer_id="cus_12345")
self.assertEqual(billing_session.get_customer(), customer)
class TestRemoteServerBillingSession(StripeTestCase):
def test_get_audit_log_error(self) -> None:
server_uuid = str(uuid.uuid4())
remote_server = RemoteZulipServer.objects.create(
uuid=server_uuid,
api_key="magic_secret_api_key",
hostname="demo.example.com",
contact_email="email@example.com",
)
billing_session = RemoteServerBillingSession(remote_server)
fake_audit_log = typing.cast(AuditLogEventType, 0)
with self.assertRaisesRegex(
BillingSessionAuditLogEventError, "Unknown audit log event type: 0"
):
billing_session.get_audit_log_event(event_type=fake_audit_log)
def test_get_customer(self) -> None:
server_uuid = str(uuid.uuid4())
remote_server = RemoteZulipServer.objects.create(
uuid=server_uuid,
api_key="magic_secret_api_key",
hostname="demo.example.com",
contact_email="email@example.com",
)
billing_session = RemoteServerBillingSession(remote_server)
customer = billing_session.get_customer()
self.assertEqual(customer, None)
customer = Customer.objects.create(
remote_server=remote_server, stripe_customer_id="cus_12345"
)
self.assertEqual(billing_session.get_customer(), customer)
# @mock_stripe
# def test_update_or_create_stripe_customer(self) -> None:
# server_uuid = str(uuid.uuid4())
# remote_server = RemoteZulipServer.objects.create(
# uuid=server_uuid,
# api_key="magic_secret_api_key",
# hostname="demo.example.com",
# contact_email="email@example.com",
# )
# billing_session = RemoteServerBillingSession(remote_server)
# # We need to generate stripe fixture for this type of test.
# customer = billing_session.update_or_create_stripe_customer()
# assert customer.stripe_customer_id
# # Confirm audit log, etc.
class TestSupportBillingHelpers(StripeTestCase): class TestSupportBillingHelpers(StripeTestCase):
def test_get_discount_for_realm(self) -> None: def test_get_discount_for_realm(self) -> None:

View File

@ -4806,9 +4806,13 @@ class AbstractRealmAuditLog(models.Model):
# Values should be exactly 10000 greater than the corresponding # Values should be exactly 10000 greater than the corresponding
# value used for the same purpose in RealmAuditLog (e.g. # value used for the same purpose in RealmAuditLog (e.g.
# REALM_DEACTIVATED = 201, and REMOTE_SERVER_DEACTIVATED = 10201). # REALM_DEACTIVATED = 201, and REMOTE_SERVER_DEACTIVATED = 10201).
REMOTE_SERVER_CREATED = 10215
REMOTE_SERVER_PLAN_TYPE_CHANGED = 10204
REMOTE_SERVER_DEACTIVATED = 10201 REMOTE_SERVER_DEACTIVATED = 10201
REMOTE_SERVER_PLAN_TYPE_CHANGED = 10204
REMOTE_SERVER_DISCOUNT_CHANGED = 10209
REMOTE_SERVER_SPONSORSHIP_APPROVED = 10210
REMOTE_SERVER_BILLING_METHOD_CHANGED = 10211
REMOTE_SERVER_SPONSORSHIP_PENDING_STATUS_CHANGED = 10213
REMOTE_SERVER_CREATED = 10215
# This value is for RemoteRealmAuditLog entries tracking changes to the # This value is for RemoteRealmAuditLog entries tracking changes to the
# RemoteRealm model resulting from modified realm information sent to us # RemoteRealm model resulting from modified realm information sent to us