stripe: Raise 'MissingDataError' while fetching license count.

If the RemoteRealmAuditLog has stale data, it means the server
stopped or never uploaded data. We raise MissingDataError in such
cases when a user action led to calculating licenses count from
stale data.
This commit is contained in:
Prakhar Pratyush 2023-12-06 23:55:49 +05:30 committed by Tim Abbott
parent 40621478cb
commit ed9b0d330d
6 changed files with 144 additions and 11 deletions

View File

@ -56,6 +56,7 @@ from zerver.models import (
get_realm, get_realm,
get_system_bot, get_system_bot,
) )
from zilencer.lib.remote_counts import MissingDataError
from zilencer.models import ( from zilencer.models import (
RemoteRealm, RemoteRealm,
RemoteRealmAuditLog, RemoteRealmAuditLog,
@ -64,6 +65,7 @@ from zilencer.models import (
RemoteZulipServerAuditLog, RemoteZulipServerAuditLog,
get_remote_realm_guest_and_non_guest_count, get_remote_realm_guest_and_non_guest_count,
get_remote_server_guest_and_non_guest_count, get_remote_server_guest_and_non_guest_count,
has_stale_audit_log,
) )
from zproject.config import get_secret from zproject.config import get_secret
@ -2802,6 +2804,8 @@ class RemoteRealmBillingSession(BillingSession): # nocoverage
@override @override
def current_count_for_billed_licenses(self) -> int: def current_count_for_billed_licenses(self) -> int:
if has_stale_audit_log(self.remote_realm.server):
raise MissingDataError
remote_realm_counts = get_remote_realm_guest_and_non_guest_count(self.remote_realm) remote_realm_counts = get_remote_realm_guest_and_non_guest_count(self.remote_realm)
return remote_realm_counts.non_guest_user_count + remote_realm_counts.guest_user_count return remote_realm_counts.non_guest_user_count + remote_realm_counts.guest_user_count
@ -3115,6 +3119,8 @@ class RemoteServerBillingSession(BillingSession): # nocoverage
@override @override
def current_count_for_billed_licenses(self) -> int: def current_count_for_billed_licenses(self) -> int:
if has_stale_audit_log(self.remote_server):
raise MissingDataError
remote_server_counts = get_remote_server_guest_and_non_guest_count(self.remote_server.id) remote_server_counts = get_remote_server_guest_and_non_guest_count(self.remote_server.id)
return remote_server_counts.non_guest_user_count + remote_server_counts.guest_user_count return remote_server_counts.non_guest_user_count + remote_server_counts.guest_user_count

View File

@ -344,10 +344,13 @@ class RemoteBillingAuthenticationTest(BouncerTestCase):
# Go to the URL we're redirected to after authentication and assert # Go to the URL we're redirected to after authentication and assert
# some basic expected content. # some basic expected content.
result = self.client_get(result["Location"], subdomain="selfhosting") # TODO: Add test for the case when redirected to error page (not yet implemented)
self.assert_in_success_response( # due to MissingDataError ('has_stale_audit_log' is True).
["Upgrade", "Purchase Zulip", "Your subscription will renew automatically."], result with mock.patch("corporate.lib.stripe.has_stale_audit_log", return_value=False):
) result = self.client_get(result["Location"], subdomain="selfhosting")
self.assert_in_success_response(
["Upgrade", "Purchase Zulip", "Your subscription will renew automatically."], result
)
class LegacyServerLoginTest(BouncerTestCase): class LegacyServerLoginTest(BouncerTestCase):
@ -424,8 +427,11 @@ class LegacyServerLoginTest(BouncerTestCase):
self.assertEqual(result["Location"], f"/server/{self.uuid}/upgrade/") self.assertEqual(result["Location"], f"/server/{self.uuid}/upgrade/")
# Access on the upgrade page is granted, assert a basic string proving that. # Access on the upgrade page is granted, assert a basic string proving that.
result = self.client_get(result["Location"], subdomain="selfhosting") # TODO: Add test for the case when redirected to error page (not yet implemented)
self.assert_in_success_response([f"Upgrade {self.server.hostname}"], result) # due to MissingDataError ('has_stale_audit_log' is True).
with mock.patch("corporate.lib.stripe.has_stale_audit_log", return_value=False):
result = self.client_get(result["Location"], subdomain="selfhosting")
self.assert_in_success_response([f"Upgrade {self.server.hostname}"], result)
def test_server_login_success_with_next_page(self) -> None: def test_server_login_success_with_next_page(self) -> None:
# First test an invalid next_page value. # First test an invalid next_page value.
@ -501,8 +507,11 @@ class LegacyServerLoginTest(BouncerTestCase):
self.assertEqual(result["Location"], f"/server/{self.uuid}/upgrade/") self.assertEqual(result["Location"], f"/server/{self.uuid}/upgrade/")
# Sanity check: access on the upgrade page is granted. # Sanity check: access on the upgrade page is granted.
result = self.client_get(result["Location"], subdomain="selfhosting") # TODO: Add test for the case when redirected to error page (Not yet implemented)
self.assert_in_success_response([f"Upgrade {self.server.hostname}"], result) # due to MissingDataError i.e., when 'has_stale_audit_log' is True.
with mock.patch("corporate.lib.stripe.has_stale_audit_log", return_value=False):
result = self.client_get(result["Location"], subdomain="selfhosting")
self.assert_in_success_response([f"Upgrade {self.server.hostname}"], result)
# Now we can simulate an expired identity dict in the session. # Now we can simulate an expired identity dict in the session.
with time_machine.travel( with time_machine.travel(

View File

@ -49,6 +49,7 @@ from corporate.lib.stripe import (
InvalidBillingScheduleError, InvalidBillingScheduleError,
InvalidTierError, InvalidTierError,
RealmBillingSession, RealmBillingSession,
RemoteRealmBillingSession,
RemoteServerBillingSession, RemoteServerBillingSession,
StripeCardError, StripeCardError,
SupportType, SupportType,
@ -104,7 +105,13 @@ from zerver.models import (
get_realm, get_realm,
get_system_bot, get_system_bot,
) )
from zilencer.models import RemoteZulipServer, RemoteZulipServerAuditLog from zilencer.lib.remote_counts import MissingDataError
from zilencer.models import (
RemoteRealm,
RemoteRealmAuditLog,
RemoteZulipServer,
RemoteZulipServerAuditLog,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from django.test.client import _MonkeyPatchedWSGIResponse as TestHttpResponse from django.test.client import _MonkeyPatchedWSGIResponse as TestHttpResponse
@ -5089,6 +5096,80 @@ class TestRealmBillingSession(StripeTestCase):
self.assertEqual(billing_session.get_customer(), customer) self.assertEqual(billing_session.get_customer(), customer)
class TestRemoteRealmBillingSession(StripeTestCase):
def test_current_count_for_billed_licenses(self) -> None:
server_uuid = str(uuid.uuid4())
remote_server = RemoteZulipServer.objects.create(
uuid=server_uuid,
api_key="magic_secret_api_key",
hostname="demo.example.com",
contact_email="email@example.com",
)
realm_uuid = str(uuid.uuid4())
remote_realm = RemoteRealm.objects.create(
server=remote_server,
uuid=realm_uuid,
uuid_owner_secret="dummy-owner-secret",
host="dummy-hostname",
realm_date_created=timezone_now(),
)
billing_session = RemoteRealmBillingSession(remote_realm=remote_realm)
# remote server never uploaded statistics. 'last_audit_log_update' is None.
with self.assertRaises(MissingDataError):
billing_session.current_count_for_billed_licenses()
# Available statistics is stale.
remote_server.last_audit_log_update = timezone_now() - timedelta(days=5)
remote_server.save()
with self.assertRaises(MissingDataError):
billing_session.current_count_for_billed_licenses()
# Available statistics is not stale.
event_time = timezone_now() - timedelta(days=1)
data_list = [
{
"server": remote_server,
"remote_realm": remote_realm,
"event_type": RemoteRealmAuditLog.USER_CREATED,
"event_time": event_time,
"extra_data": {
RemoteRealmAuditLog.ROLE_COUNT: {
RemoteRealmAuditLog.ROLE_COUNT_HUMANS: {
UserProfile.ROLE_REALM_ADMINISTRATOR: 10,
UserProfile.ROLE_REALM_OWNER: 10,
UserProfile.ROLE_MODERATOR: 10,
UserProfile.ROLE_MEMBER: 10,
UserProfile.ROLE_GUEST: 10,
}
}
},
},
{
"server": remote_server,
"remote_realm": remote_realm,
"event_type": RemoteRealmAuditLog.USER_ROLE_CHANGED,
"event_time": event_time,
"extra_data": {
RemoteRealmAuditLog.ROLE_COUNT: {
RemoteRealmAuditLog.ROLE_COUNT_HUMANS: {
UserProfile.ROLE_REALM_ADMINISTRATOR: 20,
UserProfile.ROLE_REALM_OWNER: 10,
UserProfile.ROLE_MODERATOR: 0,
UserProfile.ROLE_MEMBER: 30,
UserProfile.ROLE_GUEST: 10,
}
}
},
},
]
RemoteRealmAuditLog.objects.bulk_create([RemoteRealmAuditLog(**data) for data in data_list])
remote_server.last_audit_log_update = timezone_now() - timedelta(days=1)
remote_server.save()
self.assertEqual(billing_session.current_count_for_billed_licenses(), 70)
class TestRemoteServerBillingSession(StripeTestCase): class TestRemoteServerBillingSession(StripeTestCase):
def test_get_audit_log_error(self) -> None: def test_get_audit_log_error(self) -> None:
server_uuid = str(uuid.uuid4()) server_uuid = str(uuid.uuid4())

View File

@ -0,0 +1,17 @@
# Generated by Django 4.2.7 on 2023-12-06 18:23
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("zilencer", "0045_remoterealmauditlog_zilencer_remoterealmauditlog_server_realm_and_more"),
]
operations = [
migrations.AddField(
model_name="remotezulipserver",
name="last_audit_log_update",
field=models.DateTimeField(null=True),
),
]

View File

@ -2,7 +2,7 @@
# mypy: disable-error-code="explicit-override" # mypy: disable-error-code="explicit-override"
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime, timedelta
from typing import List, Tuple from typing import List, Tuple
from django.conf import settings from django.conf import settings
@ -67,6 +67,9 @@ class RemoteZulipServer(models.Model):
choices=[(t["id"], t["name"]) for t in Realm.ORG_TYPES.values()], choices=[(t["id"], t["name"]) for t in Realm.ORG_TYPES.values()],
) )
# The last time 'RemoteRealmAuditlog' was updated for this server.
last_audit_log_update = models.DateTimeField(null=True)
@override @override
def __str__(self) -> str: def __str__(self) -> str:
return f"{self.hostname} {str(self.uuid)[0:12]}" return f"{self.hostname} {str(self.uuid)[0:12]}"
@ -399,7 +402,7 @@ def get_remote_server_guest_and_non_guest_count(
def get_remote_realm_guest_and_non_guest_count( def get_remote_realm_guest_and_non_guest_count(
remote_realm: RemoteRealm, event_time: datetime = timezone_now() remote_realm: RemoteRealm, event_time: datetime = timezone_now()
) -> RemoteCustomerUserCount: # nocoverage ) -> RemoteCustomerUserCount:
latest_audit_log = ( latest_audit_log = (
RemoteRealmAuditLog.objects.filter( RemoteRealmAuditLog.objects.filter(
remote_realm=remote_realm, remote_realm=remote_realm,
@ -430,3 +433,13 @@ def get_remote_realm_guest_and_non_guest_count(
return RemoteCustomerUserCount( return RemoteCustomerUserCount(
non_guest_user_count=non_guest_count, guest_user_count=guest_count non_guest_user_count=non_guest_count, guest_user_count=guest_count
) )
def has_stale_audit_log(server: RemoteZulipServer) -> bool:
if server.last_audit_log_update is None:
return True
if timezone_now() - server.last_audit_log_update > timedelta(days=2):
return True
return False

View File

@ -749,6 +749,10 @@ def remote_server_post_analytics(
batch_create_table_data(server, RemoteInstallationCount, remote_installation_counts) batch_create_table_data(server, RemoteInstallationCount, remote_installation_counts)
if realmauditlog_rows is not None: if realmauditlog_rows is not None:
# Important: Do not return early if we receive 0 rows; we must
# updated last_audit_log_update even if there are no new rows,
# to help identify server whose ability to connect to this
# endpoint is broken by a networking problem.
remote_realm_audit_logs = [] remote_realm_audit_logs = []
for row in realmauditlog_rows: for row in realmauditlog_rows:
extra_data = {} extra_data = {}
@ -773,6 +777,9 @@ def remote_server_post_analytics(
) )
) )
batch_create_table_data(server, RemoteRealmAuditLog, remote_realm_audit_logs) batch_create_table_data(server, RemoteRealmAuditLog, remote_realm_audit_logs)
RemoteZulipServer.objects.filter(uuid=server.uuid).update(
last_audit_log_update=timezone_now()
)
remote_realm_dict: Dict[str, RemoteRealmDictValue] = {} remote_realm_dict: Dict[str, RemoteRealmDictValue] = {}
remote_realms = RemoteRealm.objects.filter(server=server) remote_realms = RemoteRealm.objects.filter(server=server)