diff --git a/analytics/tests/test_activity_views.py b/analytics/tests/test_activity_views.py index a1b8140bb9..b18bf397f4 100644 --- a/analytics/tests/test_activity_views.py +++ b/analytics/tests/test_activity_views.py @@ -187,7 +187,7 @@ class ActivityTest(ZulipTestCase): event_time=timezone_now() - timedelta(days=1), extra_data=extra_data, ) - with self.assert_database_query_count(11): + with self.assert_database_query_count(9): result = self.client_get("/activity/remote") self.assertEqual(result.status_code, 200) @@ -362,6 +362,6 @@ class ActivityTest(ZulipTestCase): add_audit_log_data(realm.server, remote_realm=realm, realm_id=None) self.login("iago") - with self.assert_database_query_count(27): + with self.assert_database_query_count(11): result = self.client_get("/activity/remote") self.assertEqual(result.status_code, 200) diff --git a/corporate/lib/analytics.py b/corporate/lib/analytics.py index f69b8745bc..7a88532451 100644 --- a/corporate/lib/analytics.py +++ b/corporate/lib/analytics.py @@ -4,6 +4,7 @@ from datetime import datetime from decimal import Decimal from typing import Any, Dict, List +from django.db.models import Prefetch from django.utils.timezone import now as timezone_now from corporate.lib.stripe import ( @@ -11,7 +12,7 @@ from corporate.lib.stripe import ( RemoteRealmBillingSession, RemoteServerBillingSession, ) -from corporate.models import Customer, CustomerPlan +from corporate.models import Customer, CustomerPlan, LicenseLedger from zerver.lib.utils import assert_is_not_none from zilencer.models import ( RemoteCustomerUserCount, @@ -44,41 +45,75 @@ def get_realms_with_default_discount_dict() -> Dict[str, Decimal]: def estimate_annual_recurring_revenue_by_realm() -> Dict[str, int]: # nocoverage annual_revenue = {} - for plan in CustomerPlan.objects.filter(status=CustomerPlan.ACTIVE).select_related( - "customer__realm" - ): - if plan.customer.realm is not None: - # TODO: figure out what to do for plans that don't automatically - # renew, but which probably will renew - renewal_cents = RealmBillingSession( - realm=plan.customer.realm - ).get_customer_plan_renewal_amount(plan, timezone_now()) - if plan.billing_schedule == CustomerPlan.BILLING_SCHEDULE_MONTHLY: - renewal_cents *= 12 - # TODO: Decimal stuff - annual_revenue[plan.customer.realm.string_id] = int(renewal_cents / 100) + plans = ( + CustomerPlan.objects.filter( + status=CustomerPlan.ACTIVE, + customer__remote_realm__isnull=True, + customer__remote_server__isnull=True, + ) + .prefetch_related( + Prefetch( + "licenseledger_set", + queryset=LicenseLedger.objects.order_by("plan", "-id").distinct("plan"), + to_attr="latest_ledger_entry", + ) + ) + .select_related("customer__realm") + ) + + for plan in plans: + assert plan.customer.realm is not None + latest_ledger_entry = plan.latest_ledger_entry[0] # type: ignore[attr-defined] # attribute from prefetch_related query + assert latest_ledger_entry is not None + renewal_cents = RealmBillingSession( + realm=plan.customer.realm + ).get_customer_plan_renewal_amount(plan, latest_ledger_entry) + if plan.billing_schedule == CustomerPlan.BILLING_SCHEDULE_MONTHLY: + renewal_cents *= 12 + annual_revenue[plan.customer.realm.string_id] = int(renewal_cents / 100) return annual_revenue def get_plan_data_by_remote_server() -> Dict[int, RemoteActivityPlanData]: # nocoverage remote_server_plan_data: Dict[int, RemoteActivityPlanData] = {} - for plan in CustomerPlan.objects.filter( - status__lt=CustomerPlan.LIVE_STATUS_THRESHOLD, - customer__realm__isnull=True, - customer__remote_realm__isnull=True, - customer__remote_server__deactivated=False, - ).select_related("customer__remote_server"): + plans = ( + CustomerPlan.objects.filter( + status__lt=CustomerPlan.LIVE_STATUS_THRESHOLD, + customer__realm__isnull=True, + customer__remote_realm__isnull=True, + customer__remote_server__deactivated=False, + ) + .prefetch_related( + Prefetch( + "licenseledger_set", + queryset=LicenseLedger.objects.order_by("plan", "-id").distinct("plan"), + to_attr="latest_ledger_entry", + ) + ) + .select_related("customer__remote_server") + ) + + for plan in plans: renewal_cents = 0 server_id = None assert plan.customer.remote_server is not None server_id = plan.customer.remote_server.id - renewal_cents = RemoteServerBillingSession( - remote_server=plan.customer.remote_server - ).get_customer_plan_renewal_amount(plan, timezone_now()) - assert server_id is not None - + latest_ledger_entry = plan.latest_ledger_entry[0] # type: ignore[attr-defined] # attribute from prefetch_related query + assert latest_ledger_entry is not None + if plan.tier in ( + CustomerPlan.TIER_SELF_HOSTED_LEGACY, + CustomerPlan.TIER_SELF_HOSTED_COMMUNITY, + ) or plan.status in ( + CustomerPlan.DOWNGRADE_AT_END_OF_FREE_TRIAL, + CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE, + ): + renewal_cents = 0 + else: + renewal_cents = RemoteServerBillingSession( + remote_server=plan.customer.remote_server + ).get_customer_plan_renewal_amount(plan, latest_ledger_entry) if plan.billing_schedule == CustomerPlan.BILLING_SCHEDULE_MONTHLY: renewal_cents *= 12 @@ -104,24 +139,45 @@ def get_plan_data_by_remote_server() -> Dict[int, RemoteActivityPlanData]: # no def get_plan_data_by_remote_realm() -> Dict[int, Dict[int, RemoteActivityPlanData]]: # nocoverage remote_server_plan_data_by_realm: Dict[int, Dict[int, RemoteActivityPlanData]] = {} - for plan in CustomerPlan.objects.filter( - status__lt=CustomerPlan.LIVE_STATUS_THRESHOLD, - customer__realm__isnull=True, - customer__remote_server__isnull=True, - customer__remote_realm__is_system_bot_realm=False, - customer__remote_realm__realm_deactivated=False, - ).select_related("customer__remote_realm"): + plans = ( + CustomerPlan.objects.filter( + status__lt=CustomerPlan.LIVE_STATUS_THRESHOLD, + customer__realm__isnull=True, + customer__remote_server__isnull=True, + customer__remote_realm__is_system_bot_realm=False, + customer__remote_realm__realm_deactivated=False, + ) + .prefetch_related( + Prefetch( + "licenseledger_set", + queryset=LicenseLedger.objects.order_by("plan", "-id").distinct("plan"), + to_attr="latest_ledger_entry", + ) + ) + .select_related("customer__remote_realm") + ) + + for plan in plans: renewal_cents = 0 server_id = None assert plan.customer.remote_realm is not None - server_id = plan.customer.remote_realm.server.id - renewal_cents = RemoteRealmBillingSession( - remote_realm=plan.customer.remote_realm - ).get_customer_plan_renewal_amount(plan, timezone_now()) - + server_id = plan.customer.remote_realm.server_id assert server_id is not None - + latest_ledger_entry = plan.latest_ledger_entry[0] # type: ignore[attr-defined] # attribute from prefetch_related query + assert latest_ledger_entry is not None + if plan.tier in ( + CustomerPlan.TIER_SELF_HOSTED_LEGACY, + CustomerPlan.TIER_SELF_HOSTED_COMMUNITY, + ) or plan.status in ( + CustomerPlan.DOWNGRADE_AT_END_OF_FREE_TRIAL, + CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE, + ): + renewal_cents = 0 + else: + renewal_cents = RemoteRealmBillingSession( + remote_realm=plan.customer.remote_realm + ).get_customer_plan_renewal_amount(plan, latest_ledger_entry) if plan.billing_schedule == CustomerPlan.BILLING_SCHEDULE_MONTHLY: renewal_cents *= 12 @@ -175,6 +231,7 @@ def get_remote_realm_user_counts( .exclude(extra_data={}) .order_by("remote_realm", "-event_time") .distinct("remote_realm") + .select_related("remote_realm") ): assert log.remote_realm is not None user_counts_by_realm[log.remote_realm.id] = get_remote_customer_user_count([log]) diff --git a/corporate/lib/stripe.py b/corporate/lib/stripe.py index f5fd997e28..6ecdad6b06 100644 --- a/corporate/lib/stripe.py +++ b/corporate/lib/stripe.py @@ -1847,20 +1847,12 @@ class BillingSession(ABC): def get_customer_plan_renewal_amount( self, plan: CustomerPlan, - event_time: datetime, - last_ledger_entry: Optional[LicenseLedger] = None, + last_ledger_entry: LicenseLedger, ) -> int: if plan.fixed_price is not None: return plan.fixed_price - new_plan = None - if last_ledger_entry is None: - new_plan, last_ledger_entry = self.make_end_of_cycle_updates_if_needed(plan, event_time) - if last_ledger_entry is None: - return 0 # nocoverage if last_ledger_entry.licenses_at_next_renewal is None: return 0 # nocoverage - if new_plan is not None: - plan = new_plan # nocoverage assert plan.price_per_license is not None # for mypy return plan.price_per_license * last_ledger_entry.licenses_at_next_renewal @@ -1919,7 +1911,7 @@ class BillingSession(ABC): num_months_next_cycle = ( 12 if plan.billing_schedule == CustomerPlan.BILLING_SCHEDULE_ANNUAL else 1 ) - renewal_cents = self.get_customer_plan_renewal_amount(plan, now, last_ledger_entry) + renewal_cents = self.get_customer_plan_renewal_amount(plan, last_ledger_entry) if plan.price_per_license is None: price_per_license = "" diff --git a/corporate/lib/support.py b/corporate/lib/support.py index a9b5eb47e4..8ffe4ba55b 100644 --- a/corporate/lib/support.py +++ b/corporate/lib/support.py @@ -203,12 +203,15 @@ def get_current_plan_data_for_support_view(billing_session: BillingSession) -> P ) plan_data.has_fixed_price = plan_data.current_plan.fixed_price is not None annual_invoice_count = get_annual_invoice_count(plan_data.current_plan.billing_schedule) - plan_data.annual_recurring_revenue = ( - billing_session.get_customer_plan_renewal_amount( - plan_data.current_plan, timezone_now(), last_ledger_entry + if last_ledger_entry is not None: + plan_data.annual_recurring_revenue = ( + billing_session.get_customer_plan_renewal_amount( + plan_data.current_plan, last_ledger_entry + ) + * annual_invoice_count ) - * annual_invoice_count - ) + else: + plan_data.annual_recurring_revenue = 0 # nocoverage return plan_data