diff --git a/corporate/lib/stripe.py b/corporate/lib/stripe.py index d8a1d00bb7..7d511bb534 100644 --- a/corporate/lib/stripe.py +++ b/corporate/lib/stripe.py @@ -31,6 +31,8 @@ from corporate.models import ( get_current_plan_by_customer, get_current_plan_by_realm, get_customer_by_realm, + get_customer_by_remote_realm, + get_customer_by_remote_server, ) from zerver.lib.exceptions import JsonableError from zerver.lib.logging_util import log_to_file @@ -38,7 +40,12 @@ from zerver.lib.send_email import FromAddress, send_email_to_billing_admins_and_ from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime from zerver.lib.utils import assert_is_not_none from zerver.models import Realm, RealmAuditLog, UserProfile, get_system_bot -from zilencer.models import RemoteZulipServer, RemoteZulipServerAuditLog +from zilencer.models import ( + RemoteRealm, + RemoteRealmAuditLog, + RemoteZulipServer, + RemoteZulipServerAuditLog, +) from zproject.config import get_secret stripe.api_key = get_secret("stripe_secret_key") @@ -1195,6 +1202,9 @@ class RealmBillingSession(BillingSession): assert self.support_session is False assert self.user is not None amount = price_per_license * licenses + + # TODO: Don't hardcode plan name; it should be looked up for + # the tier. description = f"Upgrade to Zulip Cloud Standard, ${price_per_license/100} x {licenses}" plan_name = "Zulip Cloud Standard" return StripePaymentIntentData( @@ -1275,6 +1285,7 @@ class RealmBillingSession(BillingSession): "If you could {begin_link}list Zulip as a sponsor on your website{end_link}, " "we would really appreciate it!" ).format( + # TODO: Don't hardcode plan names. plan_name="Zulip Cloud Standard", emoji=":tada:", begin_link="[", @@ -1283,6 +1294,338 @@ class RealmBillingSession(BillingSession): internal_send_private_message(notification_bot, user, message) +class RemoteRealmBillingSession(BillingSession): # nocoverage + def __init__( + self, remote_realm: RemoteRealm, support_staff: Optional[UserProfile] = None + ) -> None: + self.remote_realm = remote_realm + if support_staff is not None: + assert support_staff.is_staff + self.support_session = True + else: + self.support_session = False + + @override + @property + def billing_session_url(self) -> str: + return "TBD" + + @override + def get_customer(self) -> Optional[Customer]: + return get_customer_by_remote_realm(self.remote_realm) + + @override + def current_count_for_billed_licenses(self) -> int: + # TODO: Do the proper calculation here. + return 10 + + @override + def get_audit_log_event(self, event_type: AuditLogEventType) -> int: + if event_type is AuditLogEventType.STRIPE_CUSTOMER_CREATED: + return RemoteRealmAuditLog.STRIPE_CUSTOMER_CREATED + elif event_type is AuditLogEventType.STRIPE_CARD_CHANGED: + return RemoteRealmAuditLog.STRIPE_CARD_CHANGED + elif event_type is AuditLogEventType.CUSTOMER_PLAN_CREATED: + return RemoteRealmAuditLog.CUSTOMER_PLAN_CREATED + elif event_type is AuditLogEventType.DISCOUNT_CHANGED: + return RemoteRealmAuditLog.REMOTE_SERVER_DISCOUNT_CHANGED + elif event_type is AuditLogEventType.SPONSORSHIP_APPROVED: + return RemoteRealmAuditLog.REMOTE_SERVER_SPONSORSHIP_APPROVED + elif event_type is AuditLogEventType.SPONSORSHIP_PENDING_STATUS_CHANGED: + return RemoteRealmAuditLog.REMOTE_SERVER_SPONSORSHIP_PENDING_STATUS_CHANGED + elif event_type is AuditLogEventType.BILLING_METHOD_CHANGED: + return RemoteRealmAuditLog.REMOTE_SERVER_BILLING_METHOD_CHANGED + else: + raise BillingSessionAuditLogEventError(event_type) + + @override + def write_to_audit_log( + self, + event_type: AuditLogEventType, + event_time: datetime, + *, + extra_data: Optional[Dict[str, Any]] = None, + ) -> None: + # BUG: This doesn't have a way to pass realm_id ! + audit_log_event = self.get_audit_log_event(event_type) + if extra_data: + RemoteRealmAuditLog.objects.create( + server=self.remote_realm.server, + remote_realm=self.remote_realm, + event_type=audit_log_event, + event_time=event_time, + extra_data=extra_data, + ) + else: + RemoteRealmAuditLog.objects.create( + server=self.remote_realm.server, + remote_realm=self.remote_realm, + event_type=audit_log_event, + event_time=event_time, + ) + + @override + def get_data_for_stripe_customer(self) -> StripeCustomerData: + # Support requests do not set any stripe billing information. + assert self.support_session is False + metadata: Dict[str, Any] = {} + metadata["remote_realm_uuid"] = self.remote_realm.uuid + metadata["remote_realm_host"] = str(self.remote_realm.host) + realm_stripe_customer_data = StripeCustomerData( + description=str(self.remote_realm), + # BUG: This is an email for the whole server. We probably + # need a separable field here. + email=self.remote_realm.server.contact_email, + metadata=metadata, + ) + return realm_stripe_customer_data + + @override + def update_data_for_checkout_session_and_payment_intent( + self, metadata: Dict[str, Any] + ) -> Dict[str, Any]: + # TODO: Figure out what this should do. + updated_metadata = dict( + **metadata, + ) + return updated_metadata + + @override + def get_data_for_stripe_payment_intent( + self, price_per_license: int, licenses: int + ) -> StripePaymentIntentData: + # Support requests do not set any stripe billing information. + assert self.support_session is False + amount = price_per_license * licenses + # TODO: Don't hardcode plan names. + description = f"Upgrade to Zulip X Standard, ${price_per_license/100} x {licenses}" + plan_name = "Zulip X Standard" + return StripePaymentIntentData( + amount=amount, + description=description, + plan_name=plan_name, + # BUG: This is an email for the whole server. We probably + # need a separable field here. + email=self.remote_realm.server.contact_email, + ) + + @override + def update_or_create_customer( + self, stripe_customer_id: Optional[str] = None, *, defaults: Optional[Dict[str, Any]] = None + ) -> Customer: + if stripe_customer_id is not None: + # Support requests do not set any stripe billing information. + assert self.support_session is False + customer, created = Customer.objects.update_or_create( + remote_realm=self.remote_realm, + defaults={"stripe_customer_id": stripe_customer_id}, + ) + return customer + else: + customer, created = Customer.objects.update_or_create( + remote_realm=self.remote_realm, defaults=defaults + ) + return customer + + @override + def do_change_plan_type(self, *, tier: Optional[int], is_sponsored: bool = False) -> None: + # TODO: Create actual plan types. + + # This function needs to translate between the different + # formats of CustomerPlan.tier and Realm.plan_type. + if is_sponsored: + plan_type = RemoteRealm.PLAN_TYPE_COMMUNITY + elif tier == CustomerPlan.STANDARD: + plan_type = RemoteRealm.PLAN_TYPE_BUSINESS + elif tier == CustomerPlan.PLUS: # nocoverage # Plus plan doesn't use this code path yet. + plan_type = RemoteRealm.PLAN_TYPE_ENTERPRISE + else: + raise AssertionError("Unexpected tier") + + # TODO: Audit logging. + + self.remote_realm.plan_type = plan_type + self.remote_realm.save(update_fields=["plan_type"]) + + @override + def approve_sponsorship(self) -> None: + # TBD + pass + + @override + def process_downgrade(self, plan: CustomerPlan) -> None: + self.remote_realm.plan_type = RemoteRealm.PLAN_TYPE_SELF_HOSTED + self.remote_realm.save(update_fields=["plan_type"]) + + # TODO: Write audit log entry + plan.status = CustomerPlan.ENDED + plan.save(update_fields=["status"]) + + +class RemoteServerBillingSession(BillingSession): # nocoverage + """Billing session for pre-8.0 servers that do not yet support + creating RemoteRealm objects.""" + + def __init__( + self, remote_server: RemoteZulipServer, support_staff: Optional[UserProfile] = None + ) -> None: + self.remote_server = remote_server + if support_staff is not None: + assert support_staff.is_staff + self.support_session = True + else: + self.support_session = False + + @override + @property + def billing_session_url(self) -> str: + return "TBD" + + @override + def get_customer(self) -> Optional[Customer]: + return get_customer_by_remote_server(self.remote_server) + + @override + def current_count_for_billed_licenses(self) -> int: + # TODO: Do the proper calculation here. + return 10 + + @override + def get_audit_log_event(self, event_type: AuditLogEventType) -> int: + if event_type is AuditLogEventType.STRIPE_CUSTOMER_CREATED: + return RemoteZulipServerAuditLog.STRIPE_CUSTOMER_CREATED + elif event_type is AuditLogEventType.STRIPE_CARD_CHANGED: + return RemoteZulipServerAuditLog.STRIPE_CARD_CHANGED + elif event_type is AuditLogEventType.CUSTOMER_PLAN_CREATED: + return RemoteZulipServerAuditLog.CUSTOMER_PLAN_CREATED + elif event_type is AuditLogEventType.DISCOUNT_CHANGED: + return RemoteZulipServerAuditLog.REMOTE_SERVER_DISCOUNT_CHANGED + elif event_type is AuditLogEventType.SPONSORSHIP_APPROVED: + return RemoteZulipServerAuditLog.REMOTE_SERVER_SPONSORSHIP_APPROVED + elif event_type is AuditLogEventType.SPONSORSHIP_PENDING_STATUS_CHANGED: + return RemoteZulipServerAuditLog.REMOTE_SERVER_SPONSORSHIP_PENDING_STATUS_CHANGED + elif event_type is AuditLogEventType.BILLING_METHOD_CHANGED: + return RemoteZulipServerAuditLog.REMOTE_SERVER_BILLING_METHOD_CHANGED + else: + raise BillingSessionAuditLogEventError(event_type) + + @override + def write_to_audit_log( + self, + event_type: AuditLogEventType, + event_time: datetime, + *, + extra_data: Optional[Dict[str, Any]] = None, + ) -> None: + audit_log_event = self.get_audit_log_event(event_type) + if extra_data: + RemoteZulipServerAuditLog.objects.create( + server=self.remote_server, + event_type=audit_log_event, + event_time=event_time, + extra_data=extra_data, + ) + else: + RemoteZulipServerAuditLog.objects.create( + server=self.remote_server, + event_type=audit_log_event, + event_time=event_time, + ) + + @override + def get_data_for_stripe_customer(self) -> StripeCustomerData: + # Support requests do not set any stripe billing information. + assert self.support_session is False + metadata: Dict[str, Any] = {} + metadata["remote_server_uuid"] = self.remote_server.uuid + metadata["remote_server_str"] = str(self.remote_server) + realm_stripe_customer_data = StripeCustomerData( + description=str(self.remote_server), + email=self.remote_server.contact_email, + metadata=metadata, + ) + return realm_stripe_customer_data + + @override + def update_data_for_checkout_session_and_payment_intent( + self, metadata: Dict[str, Any] + ) -> Dict[str, Any]: + updated_metadata = dict( + server=self.remote_server, + email=self.remote_server.contact_email, + **metadata, + ) + return updated_metadata + + @override + def get_data_for_stripe_payment_intent( + self, price_per_license: int, licenses: int + ) -> StripePaymentIntentData: + # Support requests do not set any stripe billing information. + assert self.support_session is False + amount = price_per_license * licenses + description = f"Upgrade to Zulip X Standard, ${price_per_license/100} x {licenses}" + plan_name = "Zulip X Standard" + return StripePaymentIntentData( + amount=amount, + description=description, + plan_name=plan_name, + email=self.remote_server.contact_email, + ) + + @override + def update_or_create_customer( + self, stripe_customer_id: Optional[str] = None, *, defaults: Optional[Dict[str, Any]] = None + ) -> Customer: + if stripe_customer_id is not None: + # Support requests do not set any stripe billing information. + assert self.support_session is False + customer, created = Customer.objects.update_or_create( + remote_server=self.remote_server, + defaults={"stripe_customer_id": stripe_customer_id}, + ) + return customer + else: + customer, created = Customer.objects.update_or_create( + remote_server=self.remote_server, defaults=defaults + ) + return customer + + @override + def do_change_plan_type(self, *, tier: Optional[int], is_sponsored: bool = False) -> None: + # TODO: Create actual plan types. + + # This function needs to translate between the different + # formats of CustomerPlan.tier and RealmZulipServer.plan_type. + if is_sponsored: + plan_type = RemoteZulipServer.PLAN_TYPE_COMMUNITY + elif tier == CustomerPlan.STANDARD: + plan_type = RemoteZulipServer.PLAN_TYPE_BUSINESS + elif tier == CustomerPlan.PLUS: # nocoverage # Plus plan doesn't use this code path yet. + plan_type = RemoteZulipServer.PLAN_TYPE_ENTERPRISE + else: + raise AssertionError("Unexpected tier") + + # TODO: Audit logging. + + self.remote_server.plan_type = plan_type + self.remote_server.save(update_fields=["plan_type"]) + + @override + def approve_sponsorship(self) -> None: + # TBD + pass + + @override + def process_downgrade(self, plan: CustomerPlan) -> None: + self.remote_server.plan_type = RemoteZulipServer.PLAN_TYPE_SELF_HOSTED + self.remote_server.save(update_fields=["plan_type"]) + + # TODO: Write audit log entry + plan.status = CustomerPlan.ENDED + plan.save(update_fields=["status"]) + + def stripe_customer_has_credit_card_as_default_payment_method( stripe_customer: stripe.Customer, ) -> bool: diff --git a/corporate/models.py b/corporate/models.py index 7bc499ba06..7b292a396d 100644 --- a/corporate/models.py +++ b/corporate/models.py @@ -7,7 +7,7 @@ from django.db.models import CASCADE, Q from typing_extensions import override from zerver.models import Realm, UserProfile -from zilencer.models import RemoteZulipServer +from zilencer.models import RemoteRealm, RemoteZulipServer class Customer(models.Model): @@ -56,6 +56,14 @@ def get_customer_by_realm(realm: Realm) -> Optional[Customer]: return Customer.objects.filter(realm=realm).first() +def get_customer_by_remote_server(remote_server: RemoteZulipServer) -> Optional[Customer]: + return Customer.objects.filter(remote_server=remote_server).first() + + +def get_customer_by_remote_realm(remote_realm: RemoteRealm) -> Optional[Customer]: # nocoverage + return Customer.objects.filter(remote_realm=remote_realm).first() + + class Event(models.Model): stripe_event_id = models.CharField(max_length=255) diff --git a/corporate/tests/test_stripe.py b/corporate/tests/test_stripe.py index 0546930ff4..8cd9e4e264 100644 --- a/corporate/tests/test_stripe.py +++ b/corporate/tests/test_stripe.py @@ -46,6 +46,7 @@ from corporate.lib.stripe import ( InvalidBillingScheduleError, InvalidTierError, RealmBillingSession, + RemoteServerBillingSession, StripeCardError, add_months, catch_stripe_errors, @@ -4901,6 +4902,64 @@ class TestRealmBillingSession(StripeTestCase): ): billing_session.get_audit_log_event(event_type=fake_audit_log) + def test_get_customer(self) -> None: + user = self.example_user("hamlet") + billing_session = RealmBillingSession(user) + customer = billing_session.get_customer() + self.assertEqual(customer, None) + + customer = Customer.objects.create(realm=user.realm, stripe_customer_id="cus_12345") + self.assertEqual(billing_session.get_customer(), customer) + + +class TestRemoteServerBillingSession(StripeTestCase): + def test_get_audit_log_error(self) -> None: + server_uuid = str(uuid.uuid4()) + remote_server = RemoteZulipServer.objects.create( + uuid=server_uuid, + api_key="magic_secret_api_key", + hostname="demo.example.com", + contact_email="email@example.com", + ) + billing_session = RemoteServerBillingSession(remote_server) + fake_audit_log = typing.cast(AuditLogEventType, 0) + with self.assertRaisesRegex( + BillingSessionAuditLogEventError, "Unknown audit log event type: 0" + ): + billing_session.get_audit_log_event(event_type=fake_audit_log) + + def test_get_customer(self) -> None: + server_uuid = str(uuid.uuid4()) + remote_server = RemoteZulipServer.objects.create( + uuid=server_uuid, + api_key="magic_secret_api_key", + hostname="demo.example.com", + contact_email="email@example.com", + ) + billing_session = RemoteServerBillingSession(remote_server) + customer = billing_session.get_customer() + self.assertEqual(customer, None) + + customer = Customer.objects.create( + remote_server=remote_server, stripe_customer_id="cus_12345" + ) + self.assertEqual(billing_session.get_customer(), customer) + + # @mock_stripe + # def test_update_or_create_stripe_customer(self) -> None: + # server_uuid = str(uuid.uuid4()) + # remote_server = RemoteZulipServer.objects.create( + # uuid=server_uuid, + # api_key="magic_secret_api_key", + # hostname="demo.example.com", + # contact_email="email@example.com", + # ) + # billing_session = RemoteServerBillingSession(remote_server) + # # We need to generate stripe fixture for this type of test. + # customer = billing_session.update_or_create_stripe_customer() + # assert customer.stripe_customer_id + # # Confirm audit log, etc. + class TestSupportBillingHelpers(StripeTestCase): def test_get_discount_for_realm(self) -> None: diff --git a/zerver/models.py b/zerver/models.py index 6a24ef7ce3..dfcff72005 100644 --- a/zerver/models.py +++ b/zerver/models.py @@ -4806,9 +4806,13 @@ class AbstractRealmAuditLog(models.Model): # Values should be exactly 10000 greater than the corresponding # value used for the same purpose in RealmAuditLog (e.g. # REALM_DEACTIVATED = 201, and REMOTE_SERVER_DEACTIVATED = 10201). - REMOTE_SERVER_CREATED = 10215 - REMOTE_SERVER_PLAN_TYPE_CHANGED = 10204 REMOTE_SERVER_DEACTIVATED = 10201 + REMOTE_SERVER_PLAN_TYPE_CHANGED = 10204 + REMOTE_SERVER_DISCOUNT_CHANGED = 10209 + REMOTE_SERVER_SPONSORSHIP_APPROVED = 10210 + REMOTE_SERVER_BILLING_METHOD_CHANGED = 10211 + REMOTE_SERVER_SPONSORSHIP_PENDING_STATUS_CHANGED = 10213 + REMOTE_SERVER_CREATED = 10215 # This value is for RemoteRealmAuditLog entries tracking changes to the # RemoteRealm model resulting from modified realm information sent to us