stripe: Update LicenseLedger for remote realms.

When a remote server uploads statistics, we update the
LicenseLedger using the audit logs uploaded.

We iterate over the RemoteRealmAuditlog data for the concerned
realm starting from the event_time of the last LicenseLedger
created for that customer and update the ledger based on each event.
This commit is contained in:
Prakhar Pratyush 2023-12-08 17:49:24 +05:30 committed by Tim Abbott
parent ed9b0d330d
commit bf4fdbff12
2 changed files with 83 additions and 19 deletions

View File

@ -683,7 +683,7 @@ class BillingSession(ABC):
pass pass
@abstractmethod @abstractmethod
def current_count_for_billed_licenses(self) -> int: def current_count_for_billed_licenses(self, event_time: datetime = timezone_now()) -> int:
pass pass
@abstractmethod @abstractmethod
@ -2422,13 +2422,13 @@ class BillingSession(ABC):
def update_license_ledger_for_automanaged_plan( def update_license_ledger_for_automanaged_plan(
self, plan: CustomerPlan, event_time: datetime self, plan: CustomerPlan, event_time: datetime
) -> None: ) -> Optional[CustomerPlan]:
new_plan, last_ledger_entry = self.make_end_of_cycle_updates_if_needed(plan, event_time) new_plan, last_ledger_entry = self.make_end_of_cycle_updates_if_needed(plan, event_time)
if last_ledger_entry is None: if last_ledger_entry is None:
return return None
if new_plan is not None: if new_plan is not None:
plan = new_plan plan = new_plan
licenses_at_next_renewal = self.current_count_for_billed_licenses() licenses_at_next_renewal = self.current_count_for_billed_licenses(event_time)
licenses = max(licenses_at_next_renewal, last_ledger_entry.licenses) licenses = max(licenses_at_next_renewal, last_ledger_entry.licenses)
LicenseLedger.objects.create( LicenseLedger.objects.create(
@ -2438,16 +2438,10 @@ class BillingSession(ABC):
licenses_at_next_renewal=licenses_at_next_renewal, licenses_at_next_renewal=licenses_at_next_renewal,
) )
def update_license_ledger_if_needed(self, event_time: datetime) -> None: # Returning plan is particularly helpful for 'sync_license_ledger_if_needed'.
customer = self.get_customer() # If a new plan is created during the end of cycle update, then that function
if customer is None: # needs the updated plan for a correct LicenseLedger update.
return return plan
plan = get_current_plan_by_customer(customer)
if plan is None:
return
if not plan.automanage_licenses:
return
self.update_license_ledger_for_automanaged_plan(plan, event_time)
class RealmBillingSession(BillingSession): class RealmBillingSession(BillingSession):
@ -2507,7 +2501,7 @@ class RealmBillingSession(BillingSession):
return self.user.delivery_email return self.user.delivery_email
@override @override
def current_count_for_billed_licenses(self) -> int: def current_count_for_billed_licenses(self, event_time: datetime = timezone_now()) -> int:
return get_latest_seat_count(self.realm) return get_latest_seat_count(self.realm)
@override @override
@ -2759,6 +2753,17 @@ class RealmBillingSession(BillingSession):
self.realm.org_type = org_type self.realm.org_type = org_type
self.realm.save(update_fields=["org_type"]) self.realm.save(update_fields=["org_type"])
def update_license_ledger_if_needed(self, event_time: datetime) -> None:
customer = self.get_customer()
if customer is None:
return
plan = get_current_plan_by_customer(customer)
if plan is None:
return
if not plan.automanage_licenses:
return
self.update_license_ledger_for_automanaged_plan(plan, event_time)
class RemoteRealmBillingSession(BillingSession): # nocoverage class RemoteRealmBillingSession(BillingSession): # nocoverage
def __init__( def __init__(
@ -2803,10 +2808,12 @@ class RemoteRealmBillingSession(BillingSession): # nocoverage
return self.remote_realm.server.contact_email return self.remote_realm.server.contact_email
@override @override
def current_count_for_billed_licenses(self) -> int: def current_count_for_billed_licenses(self, event_time: datetime = timezone_now()) -> int:
if has_stale_audit_log(self.remote_realm.server): if has_stale_audit_log(self.remote_realm.server):
raise MissingDataError raise MissingDataError
remote_realm_counts = get_remote_realm_guest_and_non_guest_count(self.remote_realm) remote_realm_counts = get_remote_realm_guest_and_non_guest_count(
self.remote_realm, event_time
)
return remote_realm_counts.non_guest_user_count + remote_realm_counts.guest_user_count return remote_realm_counts.non_guest_user_count + remote_realm_counts.guest_user_count
@override @override
@ -3073,6 +3080,50 @@ class RemoteRealmBillingSession(BillingSession): # nocoverage
self.remote_realm.org_type = org_type self.remote_realm.org_type = org_type
self.remote_realm.save(update_fields=["org_type"]) self.remote_realm.save(update_fields=["org_type"])
def sync_license_ledger_if_needed(self) -> None:
# Updates the license ledger based on RemoteRealmAuditLog
# entries.
#
# Supports backfilling entries from weeks if the past if
# needed when we receive audit logs, making any end-of-cycle
# updates that happen to be scheduled inside the interval that
# we are processing.
#
# But this support is fragile, in that it does not handle the
# possibility that some other code path changed or ended the
# customer's current plan at some point after
# last_ledger.event_time but before the event times for the
# audit logs we will be processing.
customer = self.get_customer()
if customer is None:
return
plan = get_current_plan_by_customer(customer)
if plan is None:
return
if not plan.automanage_licenses:
return
# It's an invariant that any current plan have at least an
# initial ledger entry.
last_ledger = LicenseLedger.objects.filter(plan=plan).order_by("id").last()
assert last_ledger is not None
# New audit logs since last_ledger for the plan was created.
new_audit_logs = (
RemoteRealmAuditLog.objects.filter(
remote_realm=self.remote_realm,
event_time__gt=last_ledger.event_time,
event_type__in=RemoteRealmAuditLog.SYNCED_BILLING_EVENTS,
)
.exclude(extra_data={})
.order_by("event_time")
)
for audit_log in new_audit_logs:
plan = self.update_license_ledger_for_automanaged_plan(plan, audit_log.event_time)
if plan is None:
return
class RemoteServerBillingSession(BillingSession): # nocoverage class RemoteServerBillingSession(BillingSession): # nocoverage
"""Billing session for pre-8.0 servers that do not yet support """Billing session for pre-8.0 servers that do not yet support
@ -3118,10 +3169,12 @@ class RemoteServerBillingSession(BillingSession): # nocoverage
return self.remote_server.contact_email return self.remote_server.contact_email
@override @override
def current_count_for_billed_licenses(self) -> int: def current_count_for_billed_licenses(self, event_time: datetime = timezone_now()) -> int:
if has_stale_audit_log(self.remote_server): if has_stale_audit_log(self.remote_server):
raise MissingDataError raise MissingDataError
remote_server_counts = get_remote_server_guest_and_non_guest_count(self.remote_server.id) remote_server_counts = get_remote_server_guest_and_non_guest_count(
self.remote_server.id, event_time
)
return remote_server_counts.non_guest_user_count + remote_server_counts.guest_user_count return remote_server_counts.non_guest_user_count + remote_server_counts.guest_user_count
@override @override

View File

@ -753,6 +753,7 @@ def remote_server_post_analytics(
# updated last_audit_log_update even if there are no new rows, # updated last_audit_log_update even if there are no new rows,
# to help identify server whose ability to connect to this # to help identify server whose ability to connect to this
# endpoint is broken by a networking problem. # endpoint is broken by a networking problem.
remote_realms_set = set()
remote_realm_audit_logs = [] remote_realm_audit_logs = []
for row in realmauditlog_rows: for row in realmauditlog_rows:
extra_data = {} extra_data = {}
@ -764,6 +765,7 @@ def remote_server_post_analytics(
elif row.extra_data is not None: elif row.extra_data is not None:
assert isinstance(row.extra_data, dict) assert isinstance(row.extra_data, dict)
extra_data = row.extra_data extra_data = row.extra_data
remote_realms_set.add(realm_id_to_remote_realm.get(row.realm))
remote_realm_audit_logs.append( remote_realm_audit_logs.append(
RemoteRealmAuditLog( RemoteRealmAuditLog(
remote_realm=realm_id_to_remote_realm.get(row.realm), remote_realm=realm_id_to_remote_realm.get(row.realm),
@ -777,6 +779,15 @@ def remote_server_post_analytics(
) )
) )
batch_create_table_data(server, RemoteRealmAuditLog, remote_realm_audit_logs) batch_create_table_data(server, RemoteRealmAuditLog, remote_realm_audit_logs)
# Update LicenseLedger using logs in RemoteRealmAuditlog.
for remote_realm in remote_realms_set:
if remote_realm:
billing_session = RemoteRealmBillingSession(remote_realm=remote_realm)
billing_session.sync_license_ledger_if_needed()
# Do this last, so we can assume LicenseLedger is always
# up-to-date through last_audit_log_update.
RemoteZulipServer.objects.filter(uuid=server.uuid).update( RemoteZulipServer.objects.filter(uuid=server.uuid).update(
last_audit_log_update=timezone_now() last_audit_log_update=timezone_now()
) )