From 357dceb05a64abb0a2620f158eae71fef7913710 Mon Sep 17 00:00:00 2001 From: Mateusz Mandera Date: Fri, 17 Nov 2023 14:07:41 +0100 Subject: [PATCH] typing: Rewrite remote_server_post_analytics to use @typed_endpoint. The main point is the RealmDataForAnalytics structure, which we can next re-use for other endpoints that will take it in in their params. --- zerver/lib/remote_server.py | 30 +++- zerver/tests/test_push_notifications.py | 4 +- zilencer/views.py | 175 +++++++++--------------- 3 files changed, 93 insertions(+), 116 deletions(-) diff --git a/zerver/lib/remote_server.py b/zerver/lib/remote_server.py index 89b5bb771a..8ca378df96 100644 --- a/zerver/lib/remote_server.py +++ b/zerver/lib/remote_server.py @@ -7,6 +7,7 @@ import requests from django.conf import settings from django.forms.models import model_to_dict from django.utils.translation import gettext as _ +from pydantic import BaseModel, ConfigDict from analytics.models import InstallationCount, RealmCount from version import ZULIP_VERSION @@ -29,6 +30,19 @@ class PushNotificationBouncerRetryLaterError(JsonableError): http_status_code = 502 +class RealmDataForAnalytics(BaseModel): + model_config = ConfigDict(extra="forbid") + + id: int + host: str + url: str + date_created: float + deactivated: bool + + uuid: str + uuid_owner_secret: str + + def send_to_push_bouncer( method: str, endpoint: str, @@ -172,10 +186,10 @@ def build_analytics_data( ) -def get_realms_info_for_push_bouncer() -> List[Dict[str, Any]]: +def get_realms_info_for_push_bouncer() -> List[RealmDataForAnalytics]: realms = Realm.objects.order_by("id") - realm_info_dicts = [ - dict( + realm_info_list = [ + RealmDataForAnalytics( id=realm.id, uuid=str(realm.uuid), uuid_owner_secret=realm.uuid_owner_secret, @@ -187,7 +201,7 @@ def get_realms_info_for_push_bouncer() -> List[Dict[str, Any]]: for realm in realms ] - return realm_info_dicts + return realm_info_list def send_analytics_to_push_bouncer() -> None: @@ -221,7 +235,9 @@ def send_analytics_to_push_bouncer() -> None: "realm_counts": orjson.dumps(realm_count_data).decode(), "installation_counts": orjson.dumps(installation_count_data).decode(), "realmauditlog_rows": orjson.dumps(realmauditlog_data).decode(), - "realms": orjson.dumps(get_realms_info_for_push_bouncer()).decode(), + "realms": orjson.dumps( + [dict(realm_data) for realm_data in get_realms_info_for_push_bouncer()] + ).decode(), "version": orjson.dumps(ZULIP_VERSION).decode(), } @@ -235,7 +251,9 @@ def send_realms_only_to_push_bouncer() -> None: request = { "realm_counts": "[]", "installation_counts": "[]", - "realms": orjson.dumps(get_realms_info_for_push_bouncer()).decode(), + "realms": orjson.dumps( + [dict(realm_data) for realm_data in get_realms_info_for_push_bouncer()] + ).decode(), "version": orjson.dumps(ZULIP_VERSION).decode(), } diff --git a/zerver/tests/test_push_notifications.py b/zerver/tests/test_push_notifications.py index 66fb08bac6..a580b31b55 100644 --- a/zerver/tests/test_push_notifications.py +++ b/zerver/tests/test_push_notifications.py @@ -1493,7 +1493,9 @@ class AnalyticsBouncerTest(BouncerTestCase): post_data = { "realm_counts": "[]", "installation_counts": "[]", - "realms": orjson.dumps(get_realms_info_for_push_bouncer()).decode(), + "realms": orjson.dumps( + [dict(realm_data) for realm_data in get_realms_info_for_push_bouncer()] + ).decode(), "version": orjson.dumps(ZULIP_VERSION).decode(), } mock_send_to_push_bouncer.assert_called_with( diff --git a/zilencer/views.py b/zilencer/views.py index c8fae7c8d3..f39609d8f1 100644 --- a/zilencer/views.py +++ b/zilencer/views.py @@ -15,7 +15,7 @@ from django.utils.timezone import now as timezone_now from django.utils.translation import gettext as _ from django.utils.translation import gettext as err_ from django.views.decorators.csrf import csrf_exempt -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Json from analytics.lib.counts import ( BOUNCER_ONLY_REMOTE_COUNT_STAT_PROPERTIES, @@ -33,23 +33,12 @@ from zerver.lib.push_notifications import ( send_apple_push_notification, send_test_push_notification_directly_to_devices, ) +from zerver.lib.remote_server import RealmDataForAnalytics from zerver.lib.request import REQ, has_request_variables from zerver.lib.response import json_success from zerver.lib.timestamp import timestamp_to_datetime from zerver.lib.typed_endpoint import JsonBodyPayload, typed_endpoint -from zerver.lib.validator import ( - check_bool, - check_capped_string, - check_dict, - check_dict_only, - check_float, - check_int, - check_list, - check_none_or, - check_string, - check_string_fixed_length, - check_union, -) +from zerver.lib.validator import check_capped_string, check_int, check_string_fixed_length from zerver.views.push_notifications import check_app_id, validate_token from zilencer.auth import InvalidZulipServerKeyError from zilencer.models import ( @@ -601,87 +590,65 @@ def update_remote_realm_data_for_server( RemoteRealmAuditLog.objects.bulk_create(remote_realm_audit_logs) -@has_request_variables +class RealmAuditLogDataForAnalytics(BaseModel): + id: int + realm: int + event_time: float + backfilled: bool + extra_data: Optional[Union[str, Dict[str, Any]]] + event_type: int + + +class RealmCountDataForAnalytics(BaseModel): + property: str + realm: int + id: int + end_time: float + subgroup: Optional[str] + value: int + + +class InstallationCountDataForAnalytics(BaseModel): + property: str + id: int + end_time: float + subgroup: Optional[str] + value: int + + +@typed_endpoint def remote_server_post_analytics( request: HttpRequest, server: RemoteZulipServer, - realm_counts: List[Dict[str, Any]] = REQ( - json_validator=check_list( - check_dict_only( - [ - ("property", check_string), - ("realm", check_int), - ("id", check_int), - ("end_time", check_float), - ("subgroup", check_none_or(check_string)), - ("value", check_int), - ] - ) - ) - ), - installation_counts: List[Dict[str, Any]] = REQ( - json_validator=check_list( - check_dict_only( - [ - ("property", check_string), - ("id", check_int), - ("end_time", check_float), - ("subgroup", check_none_or(check_string)), - ("value", check_int), - ] - ) - ) - ), - realmauditlog_rows: Optional[List[Dict[str, Any]]] = REQ( - json_validator=check_list( - check_dict_only( - [ - ("id", check_int), - ("realm", check_int), - ("event_time", check_float), - ("backfilled", check_bool), - ("extra_data", check_none_or(check_union([check_string, check_dict()]))), - ("event_type", check_int), - ] - ) - ), - default=None, - ), - realms: Optional[List[Dict[str, Any]]] = REQ( - # Pre-8.0 servers don't send this data. - default=None, - json_validator=check_list( - check_dict_only( - [ - ("id", check_int), - ("uuid", check_string), - ("uuid_owner_secret", check_string), - ("host", check_string), - ("url", check_string), - ("deactivated", check_bool), - ("date_created", check_float), - ] - ) - ), - ), + *, + realm_counts: Json[List[RealmCountDataForAnalytics]], + installation_counts: Json[List[InstallationCountDataForAnalytics]], + realmauditlog_rows: Optional[Json[List[RealmAuditLogDataForAnalytics]]] = None, + realms: Optional[Json[List[RealmDataForAnalytics]]] = None, ) -> HttpResponse: - validate_incoming_table_data(server, RemoteRealmCount, realm_counts, True) - validate_incoming_table_data(server, RemoteInstallationCount, installation_counts, True) + validate_incoming_table_data( + server, RemoteRealmCount, [dict(count) for count in realm_counts], True + ) + validate_incoming_table_data( + server, RemoteInstallationCount, [dict(count) for count in installation_counts], True + ) if realmauditlog_rows is not None: - validate_incoming_table_data(server, RemoteRealmAuditLog, realmauditlog_rows) + validate_incoming_table_data( + server, RemoteRealmAuditLog, [dict(row) for row in realmauditlog_rows] + ) if realms is not None: - update_remote_realm_data_for_server(server, realms) + update_remote_realm_data_for_server(server, [dict(realm) for realm in realms]) remote_realm_counts = [ RemoteRealmCount( - property=row["property"], - realm_id=row["realm"], - remote_id=row["id"], + property=row.property, + realm_id=row.realm, + remote_id=row.id, server=server, - end_time=datetime.datetime.fromtimestamp(row["end_time"], tz=datetime.timezone.utc), - subgroup=row["subgroup"], - value=row["value"], + end_time=datetime.datetime.fromtimestamp(row.end_time, tz=datetime.timezone.utc), + subgroup=row.subgroup, + value=row.value, ) for row in realm_counts ] @@ -689,12 +656,12 @@ def remote_server_post_analytics( remote_installation_counts = [ RemoteInstallationCount( - property=row["property"], - remote_id=row["id"], + property=row.property, + remote_id=row.id, server=server, - end_time=datetime.datetime.fromtimestamp(row["end_time"], tz=datetime.timezone.utc), - subgroup=row["subgroup"], - value=row["value"], + end_time=datetime.datetime.fromtimestamp(row.end_time, tz=datetime.timezone.utc), + subgroup=row.subgroup, + value=row.value, ) for row in installation_counts ] @@ -704,35 +671,25 @@ def remote_server_post_analytics( remote_realm_audit_logs = [] for row in realmauditlog_rows: extra_data = {} - # Remote servers that do support JSONField will pass extra_data - # as a dict. Otherwise, extra_data will be either a string or None. - if isinstance(row["extra_data"], str): - # A valid "extra_data" as a str, if present, should always be generated from - # orjson.dumps because the POSTed analytics data for RealmAuditLog is restricted - # to event types in SYNC_BILLING_EVENTS. - # For these event types, we don't create extra_data that requires special - # handling to fit into the JSONField. + if isinstance(row.extra_data, str): try: - extra_data = orjson.loads(row["extra_data"]) + extra_data = orjson.loads(row.extra_data) except orjson.JSONDecodeError: raise JsonableError(_("Malformed audit log data")) - elif row["extra_data"] is not None: - # This is guaranteed to succeed because row["extra_data"] would be parsed - # from JSON with our json validator and validated with check_dict if it - # is not a str or None. - assert isinstance(row["extra_data"], dict) - extra_data = row["extra_data"] + elif row.extra_data is not None: + assert isinstance(row.extra_data, dict) + extra_data = row.extra_data remote_realm_audit_logs.append( RemoteRealmAuditLog( - realm_id=row["realm"], - remote_id=row["id"], + realm_id=row.realm, + remote_id=row.id, server=server, event_time=datetime.datetime.fromtimestamp( - row["event_time"], tz=datetime.timezone.utc + row.event_time, tz=datetime.timezone.utc ), - backfilled=row["backfilled"], + backfilled=row.backfilled, extra_data=extra_data, - event_type=row["event_type"], + event_type=row.event_type, ) ) batch_create_table_data(server, RemoteRealmAuditLog, remote_realm_audit_logs)