remote-activity: Prefetch LicenseLedger entry for current plan.

To estimate the annual recurring revenue for remote server and
remote realm CustomerPlans, we prefetch the current license
ledger as part of the CustomerPlan query.

Also adds a select related to the remote realm audit log query
so that we don't go to the database for the remote realm ID.

With the test added in the previous commit, the query count for
the remote activity view goes from 27 to 11, as we are no longer
hitting the database multiple times for every current plan or
for every remote realm with audit log data.

Refactors get_customer_plan_renewal_amount so that a license
ledger is always passed and make_end_of_cycle_updates_if_needed
does not need to be called.
This commit is contained in:
Lauryn Menard 2024-01-22 17:12:11 +01:00 committed by Tim Abbott
parent 842dcb6546
commit dcae35196c
4 changed files with 107 additions and 55 deletions

View File

@ -187,7 +187,7 @@ class ActivityTest(ZulipTestCase):
event_time=timezone_now() - timedelta(days=1), event_time=timezone_now() - timedelta(days=1),
extra_data=extra_data, extra_data=extra_data,
) )
with self.assert_database_query_count(11): with self.assert_database_query_count(9):
result = self.client_get("/activity/remote") result = self.client_get("/activity/remote")
self.assertEqual(result.status_code, 200) self.assertEqual(result.status_code, 200)
@ -362,6 +362,6 @@ class ActivityTest(ZulipTestCase):
add_audit_log_data(realm.server, remote_realm=realm, realm_id=None) add_audit_log_data(realm.server, remote_realm=realm, realm_id=None)
self.login("iago") self.login("iago")
with self.assert_database_query_count(27): with self.assert_database_query_count(11):
result = self.client_get("/activity/remote") result = self.client_get("/activity/remote")
self.assertEqual(result.status_code, 200) self.assertEqual(result.status_code, 200)

View File

@ -4,6 +4,7 @@ from datetime import datetime
from decimal import Decimal from decimal import Decimal
from typing import Any, Dict, List from typing import Any, Dict, List
from django.db.models import Prefetch
from django.utils.timezone import now as timezone_now from django.utils.timezone import now as timezone_now
from corporate.lib.stripe import ( from corporate.lib.stripe import (
@ -11,7 +12,7 @@ from corporate.lib.stripe import (
RemoteRealmBillingSession, RemoteRealmBillingSession,
RemoteServerBillingSession, RemoteServerBillingSession,
) )
from corporate.models import Customer, CustomerPlan from corporate.models import Customer, CustomerPlan, LicenseLedger
from zerver.lib.utils import assert_is_not_none from zerver.lib.utils import assert_is_not_none
from zilencer.models import ( from zilencer.models import (
RemoteCustomerUserCount, RemoteCustomerUserCount,
@ -44,41 +45,75 @@ def get_realms_with_default_discount_dict() -> Dict[str, Decimal]:
def estimate_annual_recurring_revenue_by_realm() -> Dict[str, int]: # nocoverage def estimate_annual_recurring_revenue_by_realm() -> Dict[str, int]: # nocoverage
annual_revenue = {} annual_revenue = {}
for plan in CustomerPlan.objects.filter(status=CustomerPlan.ACTIVE).select_related( plans = (
"customer__realm" CustomerPlan.objects.filter(
): status=CustomerPlan.ACTIVE,
if plan.customer.realm is not None: customer__remote_realm__isnull=True,
# TODO: figure out what to do for plans that don't automatically customer__remote_server__isnull=True,
# renew, but which probably will renew )
renewal_cents = RealmBillingSession( .prefetch_related(
realm=plan.customer.realm Prefetch(
).get_customer_plan_renewal_amount(plan, timezone_now()) "licenseledger_set",
if plan.billing_schedule == CustomerPlan.BILLING_SCHEDULE_MONTHLY: queryset=LicenseLedger.objects.order_by("plan", "-id").distinct("plan"),
renewal_cents *= 12 to_attr="latest_ledger_entry",
# TODO: Decimal stuff )
annual_revenue[plan.customer.realm.string_id] = int(renewal_cents / 100) )
.select_related("customer__realm")
)
for plan in plans:
assert plan.customer.realm is not None
latest_ledger_entry = plan.latest_ledger_entry[0] # type: ignore[attr-defined] # attribute from prefetch_related query
assert latest_ledger_entry is not None
renewal_cents = RealmBillingSession(
realm=plan.customer.realm
).get_customer_plan_renewal_amount(plan, latest_ledger_entry)
if plan.billing_schedule == CustomerPlan.BILLING_SCHEDULE_MONTHLY:
renewal_cents *= 12
annual_revenue[plan.customer.realm.string_id] = int(renewal_cents / 100)
return annual_revenue return annual_revenue
def get_plan_data_by_remote_server() -> Dict[int, RemoteActivityPlanData]: # nocoverage def get_plan_data_by_remote_server() -> Dict[int, RemoteActivityPlanData]: # nocoverage
remote_server_plan_data: Dict[int, RemoteActivityPlanData] = {} remote_server_plan_data: Dict[int, RemoteActivityPlanData] = {}
for plan in CustomerPlan.objects.filter( plans = (
status__lt=CustomerPlan.LIVE_STATUS_THRESHOLD, CustomerPlan.objects.filter(
customer__realm__isnull=True, status__lt=CustomerPlan.LIVE_STATUS_THRESHOLD,
customer__remote_realm__isnull=True, customer__realm__isnull=True,
customer__remote_server__deactivated=False, customer__remote_realm__isnull=True,
).select_related("customer__remote_server"): customer__remote_server__deactivated=False,
)
.prefetch_related(
Prefetch(
"licenseledger_set",
queryset=LicenseLedger.objects.order_by("plan", "-id").distinct("plan"),
to_attr="latest_ledger_entry",
)
)
.select_related("customer__remote_server")
)
for plan in plans:
renewal_cents = 0 renewal_cents = 0
server_id = None server_id = None
assert plan.customer.remote_server is not None assert plan.customer.remote_server is not None
server_id = plan.customer.remote_server.id server_id = plan.customer.remote_server.id
renewal_cents = RemoteServerBillingSession(
remote_server=plan.customer.remote_server
).get_customer_plan_renewal_amount(plan, timezone_now())
assert server_id is not None assert server_id is not None
latest_ledger_entry = plan.latest_ledger_entry[0] # type: ignore[attr-defined] # attribute from prefetch_related query
assert latest_ledger_entry is not None
if plan.tier in (
CustomerPlan.TIER_SELF_HOSTED_LEGACY,
CustomerPlan.TIER_SELF_HOSTED_COMMUNITY,
) or plan.status in (
CustomerPlan.DOWNGRADE_AT_END_OF_FREE_TRIAL,
CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE,
):
renewal_cents = 0
else:
renewal_cents = RemoteServerBillingSession(
remote_server=plan.customer.remote_server
).get_customer_plan_renewal_amount(plan, latest_ledger_entry)
if plan.billing_schedule == CustomerPlan.BILLING_SCHEDULE_MONTHLY: if plan.billing_schedule == CustomerPlan.BILLING_SCHEDULE_MONTHLY:
renewal_cents *= 12 renewal_cents *= 12
@ -104,24 +139,45 @@ def get_plan_data_by_remote_server() -> Dict[int, RemoteActivityPlanData]: # no
def get_plan_data_by_remote_realm() -> Dict[int, Dict[int, RemoteActivityPlanData]]: # nocoverage def get_plan_data_by_remote_realm() -> Dict[int, Dict[int, RemoteActivityPlanData]]: # nocoverage
remote_server_plan_data_by_realm: Dict[int, Dict[int, RemoteActivityPlanData]] = {} remote_server_plan_data_by_realm: Dict[int, Dict[int, RemoteActivityPlanData]] = {}
for plan in CustomerPlan.objects.filter( plans = (
status__lt=CustomerPlan.LIVE_STATUS_THRESHOLD, CustomerPlan.objects.filter(
customer__realm__isnull=True, status__lt=CustomerPlan.LIVE_STATUS_THRESHOLD,
customer__remote_server__isnull=True, customer__realm__isnull=True,
customer__remote_realm__is_system_bot_realm=False, customer__remote_server__isnull=True,
customer__remote_realm__realm_deactivated=False, customer__remote_realm__is_system_bot_realm=False,
).select_related("customer__remote_realm"): customer__remote_realm__realm_deactivated=False,
)
.prefetch_related(
Prefetch(
"licenseledger_set",
queryset=LicenseLedger.objects.order_by("plan", "-id").distinct("plan"),
to_attr="latest_ledger_entry",
)
)
.select_related("customer__remote_realm")
)
for plan in plans:
renewal_cents = 0 renewal_cents = 0
server_id = None server_id = None
assert plan.customer.remote_realm is not None assert plan.customer.remote_realm is not None
server_id = plan.customer.remote_realm.server.id server_id = plan.customer.remote_realm.server_id
renewal_cents = RemoteRealmBillingSession(
remote_realm=plan.customer.remote_realm
).get_customer_plan_renewal_amount(plan, timezone_now())
assert server_id is not None assert server_id is not None
latest_ledger_entry = plan.latest_ledger_entry[0] # type: ignore[attr-defined] # attribute from prefetch_related query
assert latest_ledger_entry is not None
if plan.tier in (
CustomerPlan.TIER_SELF_HOSTED_LEGACY,
CustomerPlan.TIER_SELF_HOSTED_COMMUNITY,
) or plan.status in (
CustomerPlan.DOWNGRADE_AT_END_OF_FREE_TRIAL,
CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE,
):
renewal_cents = 0
else:
renewal_cents = RemoteRealmBillingSession(
remote_realm=plan.customer.remote_realm
).get_customer_plan_renewal_amount(plan, latest_ledger_entry)
if plan.billing_schedule == CustomerPlan.BILLING_SCHEDULE_MONTHLY: if plan.billing_schedule == CustomerPlan.BILLING_SCHEDULE_MONTHLY:
renewal_cents *= 12 renewal_cents *= 12
@ -175,6 +231,7 @@ def get_remote_realm_user_counts(
.exclude(extra_data={}) .exclude(extra_data={})
.order_by("remote_realm", "-event_time") .order_by("remote_realm", "-event_time")
.distinct("remote_realm") .distinct("remote_realm")
.select_related("remote_realm")
): ):
assert log.remote_realm is not None assert log.remote_realm is not None
user_counts_by_realm[log.remote_realm.id] = get_remote_customer_user_count([log]) user_counts_by_realm[log.remote_realm.id] = get_remote_customer_user_count([log])

View File

@ -1847,20 +1847,12 @@ class BillingSession(ABC):
def get_customer_plan_renewal_amount( def get_customer_plan_renewal_amount(
self, self,
plan: CustomerPlan, plan: CustomerPlan,
event_time: datetime, last_ledger_entry: LicenseLedger,
last_ledger_entry: Optional[LicenseLedger] = None,
) -> int: ) -> int:
if plan.fixed_price is not None: if plan.fixed_price is not None:
return plan.fixed_price return plan.fixed_price
new_plan = None
if last_ledger_entry is None:
new_plan, last_ledger_entry = self.make_end_of_cycle_updates_if_needed(plan, event_time)
if last_ledger_entry is None:
return 0 # nocoverage
if last_ledger_entry.licenses_at_next_renewal is None: if last_ledger_entry.licenses_at_next_renewal is None:
return 0 # nocoverage return 0 # nocoverage
if new_plan is not None:
plan = new_plan # nocoverage
assert plan.price_per_license is not None # for mypy assert plan.price_per_license is not None # for mypy
return plan.price_per_license * last_ledger_entry.licenses_at_next_renewal return plan.price_per_license * last_ledger_entry.licenses_at_next_renewal
@ -1919,7 +1911,7 @@ class BillingSession(ABC):
num_months_next_cycle = ( num_months_next_cycle = (
12 if plan.billing_schedule == CustomerPlan.BILLING_SCHEDULE_ANNUAL else 1 12 if plan.billing_schedule == CustomerPlan.BILLING_SCHEDULE_ANNUAL else 1
) )
renewal_cents = self.get_customer_plan_renewal_amount(plan, now, last_ledger_entry) renewal_cents = self.get_customer_plan_renewal_amount(plan, last_ledger_entry)
if plan.price_per_license is None: if plan.price_per_license is None:
price_per_license = "" price_per_license = ""

View File

@ -203,12 +203,15 @@ def get_current_plan_data_for_support_view(billing_session: BillingSession) -> P
) )
plan_data.has_fixed_price = plan_data.current_plan.fixed_price is not None plan_data.has_fixed_price = plan_data.current_plan.fixed_price is not None
annual_invoice_count = get_annual_invoice_count(plan_data.current_plan.billing_schedule) annual_invoice_count = get_annual_invoice_count(plan_data.current_plan.billing_schedule)
plan_data.annual_recurring_revenue = ( if last_ledger_entry is not None:
billing_session.get_customer_plan_renewal_amount( plan_data.annual_recurring_revenue = (
plan_data.current_plan, timezone_now(), last_ledger_entry billing_session.get_customer_plan_renewal_amount(
plan_data.current_plan, last_ledger_entry
)
* annual_invoice_count
) )
* annual_invoice_count else:
) plan_data.annual_recurring_revenue = 0 # nocoverage
return plan_data return plan_data