diff --git a/corporate/lib/stripe.py b/corporate/lib/stripe.py index e2b0da4e80..18cdbe803e 100644 --- a/corporate/lib/stripe.py +++ b/corporate/lib/stripe.py @@ -56,6 +56,7 @@ from zerver.models import ( get_realm, get_system_bot, ) +from zilencer.lib.remote_counts import MissingDataError from zilencer.models import ( RemoteRealm, RemoteRealmAuditLog, @@ -64,6 +65,7 @@ from zilencer.models import ( RemoteZulipServerAuditLog, get_remote_realm_guest_and_non_guest_count, get_remote_server_guest_and_non_guest_count, + has_stale_audit_log, ) from zproject.config import get_secret @@ -2802,6 +2804,8 @@ class RemoteRealmBillingSession(BillingSession): # nocoverage @override def current_count_for_billed_licenses(self) -> int: + if has_stale_audit_log(self.remote_realm.server): + raise MissingDataError remote_realm_counts = get_remote_realm_guest_and_non_guest_count(self.remote_realm) return remote_realm_counts.non_guest_user_count + remote_realm_counts.guest_user_count @@ -3115,6 +3119,8 @@ class RemoteServerBillingSession(BillingSession): # nocoverage @override def current_count_for_billed_licenses(self) -> int: + if has_stale_audit_log(self.remote_server): + raise MissingDataError remote_server_counts = get_remote_server_guest_and_non_guest_count(self.remote_server.id) return remote_server_counts.non_guest_user_count + remote_server_counts.guest_user_count diff --git a/corporate/tests/test_remote_billing.py b/corporate/tests/test_remote_billing.py index 7c840732c8..c6be9ae8cf 100644 --- a/corporate/tests/test_remote_billing.py +++ b/corporate/tests/test_remote_billing.py @@ -344,10 +344,13 @@ class RemoteBillingAuthenticationTest(BouncerTestCase): # Go to the URL we're redirected to after authentication and assert # some basic expected content. - result = self.client_get(result["Location"], subdomain="selfhosting") - self.assert_in_success_response( - ["Upgrade", "Purchase Zulip", "Your subscription will renew automatically."], result - ) + # TODO: Add test for the case when redirected to error page (not yet implemented) + # due to MissingDataError ('has_stale_audit_log' is True). + with mock.patch("corporate.lib.stripe.has_stale_audit_log", return_value=False): + result = self.client_get(result["Location"], subdomain="selfhosting") + self.assert_in_success_response( + ["Upgrade", "Purchase Zulip", "Your subscription will renew automatically."], result + ) class LegacyServerLoginTest(BouncerTestCase): @@ -424,8 +427,11 @@ class LegacyServerLoginTest(BouncerTestCase): self.assertEqual(result["Location"], f"/server/{self.uuid}/upgrade/") # Access on the upgrade page is granted, assert a basic string proving that. - result = self.client_get(result["Location"], subdomain="selfhosting") - self.assert_in_success_response([f"Upgrade {self.server.hostname}"], result) + # TODO: Add test for the case when redirected to error page (not yet implemented) + # due to MissingDataError ('has_stale_audit_log' is True). + with mock.patch("corporate.lib.stripe.has_stale_audit_log", return_value=False): + result = self.client_get(result["Location"], subdomain="selfhosting") + self.assert_in_success_response([f"Upgrade {self.server.hostname}"], result) def test_server_login_success_with_next_page(self) -> None: # First test an invalid next_page value. @@ -501,8 +507,11 @@ class LegacyServerLoginTest(BouncerTestCase): self.assertEqual(result["Location"], f"/server/{self.uuid}/upgrade/") # Sanity check: access on the upgrade page is granted. - result = self.client_get(result["Location"], subdomain="selfhosting") - self.assert_in_success_response([f"Upgrade {self.server.hostname}"], result) + # TODO: Add test for the case when redirected to error page (Not yet implemented) + # due to MissingDataError i.e., when 'has_stale_audit_log' is True. + with mock.patch("corporate.lib.stripe.has_stale_audit_log", return_value=False): + result = self.client_get(result["Location"], subdomain="selfhosting") + self.assert_in_success_response([f"Upgrade {self.server.hostname}"], result) # Now we can simulate an expired identity dict in the session. with time_machine.travel( diff --git a/corporate/tests/test_stripe.py b/corporate/tests/test_stripe.py index aa9832a431..4d0369c2fc 100644 --- a/corporate/tests/test_stripe.py +++ b/corporate/tests/test_stripe.py @@ -49,6 +49,7 @@ from corporate.lib.stripe import ( InvalidBillingScheduleError, InvalidTierError, RealmBillingSession, + RemoteRealmBillingSession, RemoteServerBillingSession, StripeCardError, SupportType, @@ -104,7 +105,13 @@ from zerver.models import ( get_realm, get_system_bot, ) -from zilencer.models import RemoteZulipServer, RemoteZulipServerAuditLog +from zilencer.lib.remote_counts import MissingDataError +from zilencer.models import ( + RemoteRealm, + RemoteRealmAuditLog, + RemoteZulipServer, + RemoteZulipServerAuditLog, +) if TYPE_CHECKING: from django.test.client import _MonkeyPatchedWSGIResponse as TestHttpResponse @@ -5089,6 +5096,80 @@ class TestRealmBillingSession(StripeTestCase): self.assertEqual(billing_session.get_customer(), customer) +class TestRemoteRealmBillingSession(StripeTestCase): + def test_current_count_for_billed_licenses(self) -> None: + server_uuid = str(uuid.uuid4()) + remote_server = RemoteZulipServer.objects.create( + uuid=server_uuid, + api_key="magic_secret_api_key", + hostname="demo.example.com", + contact_email="email@example.com", + ) + realm_uuid = str(uuid.uuid4()) + remote_realm = RemoteRealm.objects.create( + server=remote_server, + uuid=realm_uuid, + uuid_owner_secret="dummy-owner-secret", + host="dummy-hostname", + realm_date_created=timezone_now(), + ) + billing_session = RemoteRealmBillingSession(remote_realm=remote_realm) + + # remote server never uploaded statistics. 'last_audit_log_update' is None. + with self.assertRaises(MissingDataError): + billing_session.current_count_for_billed_licenses() + + # Available statistics is stale. + remote_server.last_audit_log_update = timezone_now() - timedelta(days=5) + remote_server.save() + with self.assertRaises(MissingDataError): + billing_session.current_count_for_billed_licenses() + + # Available statistics is not stale. + event_time = timezone_now() - timedelta(days=1) + data_list = [ + { + "server": remote_server, + "remote_realm": remote_realm, + "event_type": RemoteRealmAuditLog.USER_CREATED, + "event_time": event_time, + "extra_data": { + RemoteRealmAuditLog.ROLE_COUNT: { + RemoteRealmAuditLog.ROLE_COUNT_HUMANS: { + UserProfile.ROLE_REALM_ADMINISTRATOR: 10, + UserProfile.ROLE_REALM_OWNER: 10, + UserProfile.ROLE_MODERATOR: 10, + UserProfile.ROLE_MEMBER: 10, + UserProfile.ROLE_GUEST: 10, + } + } + }, + }, + { + "server": remote_server, + "remote_realm": remote_realm, + "event_type": RemoteRealmAuditLog.USER_ROLE_CHANGED, + "event_time": event_time, + "extra_data": { + RemoteRealmAuditLog.ROLE_COUNT: { + RemoteRealmAuditLog.ROLE_COUNT_HUMANS: { + UserProfile.ROLE_REALM_ADMINISTRATOR: 20, + UserProfile.ROLE_REALM_OWNER: 10, + UserProfile.ROLE_MODERATOR: 0, + UserProfile.ROLE_MEMBER: 30, + UserProfile.ROLE_GUEST: 10, + } + } + }, + }, + ] + RemoteRealmAuditLog.objects.bulk_create([RemoteRealmAuditLog(**data) for data in data_list]) + remote_server.last_audit_log_update = timezone_now() - timedelta(days=1) + remote_server.save() + + self.assertEqual(billing_session.current_count_for_billed_licenses(), 70) + + class TestRemoteServerBillingSession(StripeTestCase): def test_get_audit_log_error(self) -> None: server_uuid = str(uuid.uuid4()) diff --git a/zilencer/migrations/0046_remotezulipserver_last_audit_log_update.py b/zilencer/migrations/0046_remotezulipserver_last_audit_log_update.py new file mode 100644 index 0000000000..2d2e638873 --- /dev/null +++ b/zilencer/migrations/0046_remotezulipserver_last_audit_log_update.py @@ -0,0 +1,17 @@ +# Generated by Django 4.2.7 on 2023-12-06 18:23 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("zilencer", "0045_remoterealmauditlog_zilencer_remoterealmauditlog_server_realm_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="remotezulipserver", + name="last_audit_log_update", + field=models.DateTimeField(null=True), + ), + ] diff --git a/zilencer/models.py b/zilencer/models.py index c245d77661..ff10d03aa1 100644 --- a/zilencer/models.py +++ b/zilencer/models.py @@ -2,7 +2,7 @@ # mypy: disable-error-code="explicit-override" from dataclasses import dataclass -from datetime import datetime +from datetime import datetime, timedelta from typing import List, Tuple from django.conf import settings @@ -67,6 +67,9 @@ class RemoteZulipServer(models.Model): choices=[(t["id"], t["name"]) for t in Realm.ORG_TYPES.values()], ) + # The last time 'RemoteRealmAuditlog' was updated for this server. + last_audit_log_update = models.DateTimeField(null=True) + @override def __str__(self) -> str: return f"{self.hostname} {str(self.uuid)[0:12]}" @@ -399,7 +402,7 @@ def get_remote_server_guest_and_non_guest_count( def get_remote_realm_guest_and_non_guest_count( remote_realm: RemoteRealm, event_time: datetime = timezone_now() -) -> RemoteCustomerUserCount: # nocoverage +) -> RemoteCustomerUserCount: latest_audit_log = ( RemoteRealmAuditLog.objects.filter( remote_realm=remote_realm, @@ -430,3 +433,13 @@ def get_remote_realm_guest_and_non_guest_count( return RemoteCustomerUserCount( non_guest_user_count=non_guest_count, guest_user_count=guest_count ) + + +def has_stale_audit_log(server: RemoteZulipServer) -> bool: + if server.last_audit_log_update is None: + return True + + if timezone_now() - server.last_audit_log_update > timedelta(days=2): + return True + + return False diff --git a/zilencer/views.py b/zilencer/views.py index 80f62eb76e..500bd42d63 100644 --- a/zilencer/views.py +++ b/zilencer/views.py @@ -749,6 +749,10 @@ def remote_server_post_analytics( batch_create_table_data(server, RemoteInstallationCount, remote_installation_counts) if realmauditlog_rows is not None: + # Important: Do not return early if we receive 0 rows; we must + # updated last_audit_log_update even if there are no new rows, + # to help identify server whose ability to connect to this + # endpoint is broken by a networking problem. remote_realm_audit_logs = [] for row in realmauditlog_rows: extra_data = {} @@ -773,6 +777,9 @@ def remote_server_post_analytics( ) ) batch_create_table_data(server, RemoteRealmAuditLog, remote_realm_audit_logs) + RemoteZulipServer.objects.filter(uuid=server.uuid).update( + last_audit_log_update=timezone_now() + ) remote_realm_dict: Dict[str, RemoteRealmDictValue] = {} remote_realms = RemoteRealm.objects.filter(server=server)