diff --git a/corporate/lib/stripe.py b/corporate/lib/stripe.py index 70cd46248a..aa4a1870d8 100644 --- a/corporate/lib/stripe.py +++ b/corporate/lib/stripe.py @@ -361,6 +361,10 @@ class BillingSession(ABC): def get_customer(self) -> Optional[Customer]: pass + @abstractmethod + def current_count_for_billed_licenses(self) -> int: + pass + @abstractmethod def get_audit_log_event(self, event_type: AuditLogEventType) -> int: pass @@ -617,6 +621,10 @@ class RealmBillingSession(BillingSession): def get_customer(self) -> Optional[Customer]: return get_customer_by_realm(self.realm) + @override + def current_count_for_billed_licenses(self) -> int: + return get_latest_seat_count(self.realm) + @override def get_audit_log_event(self, event_type: AuditLogEventType) -> int: if event_type is AuditLogEventType.STRIPE_CUSTOMER_CREATED: @@ -1064,7 +1072,8 @@ def process_initial_upgrade( else: # billed_licenses can be greater than licenses if users are added between the start of # this function (process_initial_upgrade) and now - billed_licenses = max(get_latest_seat_count(customer.realm), licenses) + current_licenses_count = billing_session.current_count_for_billed_licenses() + billed_licenses = max(current_licenses_count, licenses) plan_params = { "automanage_licenses": automanage_licenses, "charge_automatically": charge_automatically,