diff --git a/corporate/lib/stripe.py b/corporate/lib/stripe.py index 160b1e6735..92280a8697 100644 --- a/corporate/lib/stripe.py +++ b/corporate/lib/stripe.py @@ -48,7 +48,6 @@ from zerver.lib.send_email import ( send_email_to_billing_admins_and_realm_owners, ) from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime -from zerver.lib.types import RemoteRealmDictValue from zerver.lib.url_encoding import append_url_query_string from zerver.lib.utils import assert_is_not_none from zerver.models import ( @@ -3409,28 +3408,6 @@ class RemoteRealmBillingSession(BillingSession): return # nocoverage current_plan = end_of_cycle_plan - def get_push_service_validity_dict(self) -> RemoteRealmDictValue: - customer = self.get_customer() - if customer is None: - return {"can_push": True, "expected_end_timestamp": None} - - current_plan = get_current_plan_by_customer(customer) - if current_plan is None: - return {"can_push": True, "expected_end_timestamp": None} - - expected_end_timestamp = None - if current_plan.status in [ - CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE, - CustomerPlan.DOWNGRADE_AT_END_OF_FREE_TRIAL, - ]: - expected_end_timestamp = datetime_to_timestamp( - self.get_next_billing_cycle(current_plan) - ) - return { - "can_push": True, - "expected_end_timestamp": expected_end_timestamp, - } - class RemoteServerBillingSession(BillingSession): """Billing session for pre-8.0 servers that do not yet support @@ -4083,3 +4060,62 @@ def downgrade_small_realms_behind_on_payments_as_needed() -> None: # the last invoice open, void the open invoices. billing_session = RealmBillingSession(user=None, realm=realm) billing_session.void_all_open_invoices() + + +@dataclass +class PushNotificationsEnabledStatus: + can_push: bool + expected_end_timestamp: Optional[int] + + # Not sent to clients, just for debugging + message: str + + +def get_push_status_for_remote_request( + remote_server: RemoteZulipServer, remote_realm: Optional[RemoteRealm] +) -> PushNotificationsEnabledStatus: + # First, get the operative Customer object for this + # installation. If there's a `RemoteRealm` customer, that + # takes precedence. + customer = None + + if remote_realm is not None: + billing_session: BillingSession = RemoteRealmBillingSession(remote_realm) + customer = billing_session.get_customer() + + if customer is None: + billing_session = RemoteServerBillingSession(remote_server) + customer = billing_session.get_customer() + + if customer is not None: + current_plan = get_current_plan_by_customer(customer) + else: + current_plan = None + + if current_plan is not None: + if current_plan.status in [ + CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE, + CustomerPlan.DOWNGRADE_AT_END_OF_FREE_TRIAL, + ]: + # Plans scheduled to end + expected_end_timestamp = datetime_to_timestamp( + billing_session.get_next_billing_cycle(current_plan) + ) + return PushNotificationsEnabledStatus( + can_push=True, + expected_end_timestamp=expected_end_timestamp, + message="Scheduled end", + ) + + # Current plan, no expected end. + return PushNotificationsEnabledStatus( + can_push=True, + expected_end_timestamp=None, + message="Active plan", + ) + + return PushNotificationsEnabledStatus( + can_push=True, + expected_end_timestamp=None, + message="No plan", + ) diff --git a/zerver/tests/test_push_notifications.py b/zerver/tests/test_push_notifications.py index 5f3ba70adf..426f52ee5a 100644 --- a/zerver/tests/test_push_notifications.py +++ b/zerver/tests/test_push_notifications.py @@ -1890,7 +1890,7 @@ class AnalyticsBouncerTest(BouncerTestCase): self.add_mock_response() with mock.patch( - "zilencer.views.RemoteRealmBillingSession.get_customer", return_value=None + "corporate.lib.stripe.RemoteRealmBillingSession.get_customer", return_value=None ) as m: send_server_data_to_push_bouncer(consider_usage_statistics=False) m.assert_called() @@ -1901,7 +1901,8 @@ class AnalyticsBouncerTest(BouncerTestCase): dummy_customer = mock.MagicMock() with mock.patch( - "zilencer.views.RemoteRealmBillingSession.get_customer", return_value=dummy_customer + "corporate.lib.stripe.RemoteRealmBillingSession.get_customer", + return_value=dummy_customer, ): with mock.patch( "corporate.lib.stripe.get_current_plan_by_customer", return_value=None @@ -1917,14 +1918,15 @@ class AnalyticsBouncerTest(BouncerTestCase): dummy_customer_plan.status = CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE dummy_date = datetime(year=2023, month=12, day=3, tzinfo=timezone.utc) with mock.patch( - "zilencer.views.RemoteRealmBillingSession.get_customer", return_value=dummy_customer + "corporate.lib.stripe.RemoteRealmBillingSession.get_customer", + return_value=dummy_customer, ): with mock.patch( "corporate.lib.stripe.get_current_plan_by_customer", return_value=dummy_customer_plan, ): with mock.patch( - "zilencer.views.RemoteRealmBillingSession.get_next_billing_cycle", + "corporate.lib.stripe.RemoteRealmBillingSession.get_next_billing_cycle", return_value=dummy_date, ) as m, self.assertLogs("zulip.analytics", level="INFO") as info_log: send_server_data_to_push_bouncer(consider_usage_statistics=False) @@ -1941,6 +1943,31 @@ class AnalyticsBouncerTest(BouncerTestCase): info_log.output[0], ) + dummy_customer_plan = mock.MagicMock() + dummy_customer_plan.status = CustomerPlan.ACTIVE + with mock.patch( + "corporate.lib.stripe.RemoteRealmBillingSession.get_customer", + return_value=dummy_customer, + ): + with mock.patch( + "corporate.lib.stripe.get_current_plan_by_customer", + return_value=dummy_customer_plan, + ): + with self.assertLogs("zulip.analytics", level="INFO") as info_log: + send_server_data_to_push_bouncer(consider_usage_statistics=False) + m.assert_called() + realms = Realm.objects.all() + for realm in realms: + self.assertEqual(realm.push_notifications_enabled, True) + self.assertEqual( + realm.push_notifications_enabled_end_timestamp, + None, + ) + self.assertIn( + "INFO:zulip.analytics:Reported 0 records", + info_log.output[0], + ) + with mock.patch("zerver.lib.remote_server.send_to_push_bouncer") as m, self.assertLogs( "zulip.analytics", level="WARNING" ) as exception_log: diff --git a/zilencer/views.py b/zilencer/views.py index fca570df8e..6241e3d697 100644 --- a/zilencer/views.py +++ b/zilencer/views.py @@ -28,6 +28,7 @@ from corporate.lib.stripe import ( RemoteRealmBillingSession, RemoteServerBillingSession, do_deactivate_remote_server, + get_push_status_for_remote_request, ) from corporate.models import CustomerPlan, get_current_plan_by_customer from zerver.decorator import require_post @@ -532,8 +533,11 @@ def remote_server_notify_push( timezone_now(), increment=android_successfully_delivered + apple_successfully_delivered, ) - billing_session = RemoteRealmBillingSession(remote_realm) - remote_realm_dict = billing_session.get_push_service_validity_dict() + push_status = get_push_status_for_remote_request(server, remote_realm) + remote_realm_dict = { + "can_push": push_status.can_push, + "expected_end_timestamp": push_status.expected_end_timestamp, + } deleted_devices = get_deleted_devices( user_identity, @@ -1045,8 +1049,11 @@ def remote_server_post_analytics( remote_realms = RemoteRealm.objects.filter(server=server, realm_locally_deleted=False) for remote_realm in remote_realms: uuid = str(remote_realm.uuid) - billing_session = RemoteRealmBillingSession(remote_realm) - remote_realm_dict[uuid] = billing_session.get_push_service_validity_dict() + status = get_push_status_for_remote_request(server, remote_realm) + remote_realm_dict[uuid] = { + "can_push": status.can_push, + "expected_end_timestamp": status.expected_end_timestamp, + } return json_success(request, data={"realms": remote_realm_dict})