diff --git a/corporate/lib/stripe.py b/corporate/lib/stripe.py index 18cdbe803e..caba02ee7f 100644 --- a/corporate/lib/stripe.py +++ b/corporate/lib/stripe.py @@ -683,7 +683,7 @@ class BillingSession(ABC): pass @abstractmethod - def current_count_for_billed_licenses(self) -> int: + def current_count_for_billed_licenses(self, event_time: datetime = timezone_now()) -> int: pass @abstractmethod @@ -2422,13 +2422,13 @@ class BillingSession(ABC): def update_license_ledger_for_automanaged_plan( self, plan: CustomerPlan, event_time: datetime - ) -> None: + ) -> Optional[CustomerPlan]: new_plan, last_ledger_entry = self.make_end_of_cycle_updates_if_needed(plan, event_time) if last_ledger_entry is None: - return + return None if new_plan is not None: plan = new_plan - licenses_at_next_renewal = self.current_count_for_billed_licenses() + licenses_at_next_renewal = self.current_count_for_billed_licenses(event_time) licenses = max(licenses_at_next_renewal, last_ledger_entry.licenses) LicenseLedger.objects.create( @@ -2438,16 +2438,10 @@ class BillingSession(ABC): licenses_at_next_renewal=licenses_at_next_renewal, ) - def update_license_ledger_if_needed(self, event_time: datetime) -> None: - customer = self.get_customer() - if customer is None: - return - plan = get_current_plan_by_customer(customer) - if plan is None: - return - if not plan.automanage_licenses: - return - self.update_license_ledger_for_automanaged_plan(plan, event_time) + # Returning plan is particularly helpful for 'sync_license_ledger_if_needed'. + # If a new plan is created during the end of cycle update, then that function + # needs the updated plan for a correct LicenseLedger update. + return plan class RealmBillingSession(BillingSession): @@ -2507,7 +2501,7 @@ class RealmBillingSession(BillingSession): return self.user.delivery_email @override - def current_count_for_billed_licenses(self) -> int: + def current_count_for_billed_licenses(self, event_time: datetime = timezone_now()) -> int: return get_latest_seat_count(self.realm) @override @@ -2759,6 +2753,17 @@ class RealmBillingSession(BillingSession): self.realm.org_type = org_type self.realm.save(update_fields=["org_type"]) + def update_license_ledger_if_needed(self, event_time: datetime) -> None: + customer = self.get_customer() + if customer is None: + return + plan = get_current_plan_by_customer(customer) + if plan is None: + return + if not plan.automanage_licenses: + return + self.update_license_ledger_for_automanaged_plan(plan, event_time) + class RemoteRealmBillingSession(BillingSession): # nocoverage def __init__( @@ -2803,10 +2808,12 @@ class RemoteRealmBillingSession(BillingSession): # nocoverage return self.remote_realm.server.contact_email @override - def current_count_for_billed_licenses(self) -> int: + def current_count_for_billed_licenses(self, event_time: datetime = timezone_now()) -> int: if has_stale_audit_log(self.remote_realm.server): raise MissingDataError - remote_realm_counts = get_remote_realm_guest_and_non_guest_count(self.remote_realm) + remote_realm_counts = get_remote_realm_guest_and_non_guest_count( + self.remote_realm, event_time + ) return remote_realm_counts.non_guest_user_count + remote_realm_counts.guest_user_count @override @@ -3073,6 +3080,50 @@ class RemoteRealmBillingSession(BillingSession): # nocoverage self.remote_realm.org_type = org_type self.remote_realm.save(update_fields=["org_type"]) + def sync_license_ledger_if_needed(self) -> None: + # Updates the license ledger based on RemoteRealmAuditLog + # entries. + # + # Supports backfilling entries from weeks if the past if + # needed when we receive audit logs, making any end-of-cycle + # updates that happen to be scheduled inside the interval that + # we are processing. + # + # But this support is fragile, in that it does not handle the + # possibility that some other code path changed or ended the + # customer's current plan at some point after + # last_ledger.event_time but before the event times for the + # audit logs we will be processing. + customer = self.get_customer() + if customer is None: + return + plan = get_current_plan_by_customer(customer) + if plan is None: + return + if not plan.automanage_licenses: + return + + # It's an invariant that any current plan have at least an + # initial ledger entry. + last_ledger = LicenseLedger.objects.filter(plan=plan).order_by("id").last() + assert last_ledger is not None + + # New audit logs since last_ledger for the plan was created. + new_audit_logs = ( + RemoteRealmAuditLog.objects.filter( + remote_realm=self.remote_realm, + event_time__gt=last_ledger.event_time, + event_type__in=RemoteRealmAuditLog.SYNCED_BILLING_EVENTS, + ) + .exclude(extra_data={}) + .order_by("event_time") + ) + + for audit_log in new_audit_logs: + plan = self.update_license_ledger_for_automanaged_plan(plan, audit_log.event_time) + if plan is None: + return + class RemoteServerBillingSession(BillingSession): # nocoverage """Billing session for pre-8.0 servers that do not yet support @@ -3118,10 +3169,12 @@ class RemoteServerBillingSession(BillingSession): # nocoverage return self.remote_server.contact_email @override - def current_count_for_billed_licenses(self) -> int: + def current_count_for_billed_licenses(self, event_time: datetime = timezone_now()) -> int: if has_stale_audit_log(self.remote_server): raise MissingDataError - remote_server_counts = get_remote_server_guest_and_non_guest_count(self.remote_server.id) + remote_server_counts = get_remote_server_guest_and_non_guest_count( + self.remote_server.id, event_time + ) return remote_server_counts.non_guest_user_count + remote_server_counts.guest_user_count @override diff --git a/zilencer/views.py b/zilencer/views.py index 500bd42d63..c81c7556db 100644 --- a/zilencer/views.py +++ b/zilencer/views.py @@ -753,6 +753,7 @@ def remote_server_post_analytics( # updated last_audit_log_update even if there are no new rows, # to help identify server whose ability to connect to this # endpoint is broken by a networking problem. + remote_realms_set = set() remote_realm_audit_logs = [] for row in realmauditlog_rows: extra_data = {} @@ -764,6 +765,7 @@ def remote_server_post_analytics( elif row.extra_data is not None: assert isinstance(row.extra_data, dict) extra_data = row.extra_data + remote_realms_set.add(realm_id_to_remote_realm.get(row.realm)) remote_realm_audit_logs.append( RemoteRealmAuditLog( remote_realm=realm_id_to_remote_realm.get(row.realm), @@ -777,6 +779,15 @@ def remote_server_post_analytics( ) ) batch_create_table_data(server, RemoteRealmAuditLog, remote_realm_audit_logs) + + # Update LicenseLedger using logs in RemoteRealmAuditlog. + for remote_realm in remote_realms_set: + if remote_realm: + billing_session = RemoteRealmBillingSession(remote_realm=remote_realm) + billing_session.sync_license_ledger_if_needed() + + # Do this last, so we can assume LicenseLedger is always + # up-to-date through last_audit_log_update. RemoteZulipServer.objects.filter(uuid=server.uuid).update( last_audit_log_update=timezone_now() )