diff --git a/zerver/tests/test_push_notifications.py b/zerver/tests/test_push_notifications.py index 526ff9ca56..09c401d18c 100644 --- a/zerver/tests/test_push_notifications.py +++ b/zerver/tests/test_push_notifications.py @@ -892,6 +892,71 @@ class AnalyticsBouncerTest(BouncerTestCase): ) self.assertEqual(remote_log_entry.event_type, RealmAuditLog.USER_REACTIVATED) + # This verify that the bouncer is forward-compatible with remote servers using + # TextField to store extra_data. + @override_settings(PUSH_NOTIFICATION_BOUNCER_URL="https://push.zulip.org.example.com") + @responses.activate + def test_realmauditlog_string_extra_data(self) -> None: + self.add_mock_response() + + def verify_request_with_overridden_extra_data( + request_extra_data: object, expected_extra_data: object + ) -> None: + user = self.example_user("hamlet") + log_entry = RealmAuditLog.objects.create( + realm=user.realm, + modified_user=user, + event_type=RealmAuditLog.USER_REACTIVATED, + event_time=self.TIME_ZERO, + extra_data=orjson.dumps({RealmAuditLog.ROLE_COUNT: 0}).decode(), + ) + + # We use this to patch send_to_push_bouncer so that extra_data in the + # legacy format gets sent to the bouncer. + def transform_realmauditlog_extra_data( + method: str, + endpoint: str, + post_data: Union[bytes, Mapping[str, Union[str, int, None, bytes]]], + extra_headers: Mapping[str, str] = {}, + ) -> Dict[str, Any]: + if endpoint == "server/analytics": + assert isinstance(post_data, dict) + assert isinstance(post_data["realmauditlog_rows"], str) + original_data = orjson.loads(post_data["realmauditlog_rows"]) + # We replace the extra_data with another fake example to verify that + # the bouncer actually gets requested with extra_data being string + new_data = [{**row, "extra_data": request_extra_data} for row in original_data] + post_data["realmauditlog_rows"] = orjson.dumps(new_data).decode() + return send_to_push_bouncer(method, endpoint, post_data, extra_headers) + + with mock.patch( + "zerver.lib.remote_server.send_to_push_bouncer", + side_effect=transform_realmauditlog_extra_data, + ): + send_analytics_to_remote_server() + + remote_log_entry = RemoteRealmAuditLog.objects.order_by("id").last() + assert remote_log_entry is not None + self.assertEqual(str(remote_log_entry.server.uuid), self.server_uuid) + self.assertEqual(remote_log_entry.remote_id, log_entry.id) + self.assertEqual(remote_log_entry.event_time, self.TIME_ZERO) + self.assertEqual(remote_log_entry.extra_data, expected_extra_data) + + # Pre-migration extra_data + verify_request_with_overridden_extra_data( + request_extra_data=orjson.dumps({"fake_data": 42}).decode(), + expected_extra_data=orjson.dumps({"fake_data": 42}).decode(), + ) + verify_request_with_overridden_extra_data(request_extra_data=None, expected_extra_data=None) + # Post-migration extra_data + verify_request_with_overridden_extra_data( + request_extra_data={"fake_data": 42}, + expected_extra_data=orjson.dumps({"fake_data": 42}).decode(), + ) + verify_request_with_overridden_extra_data( + request_extra_data={}, expected_extra_data=orjson.dumps({}).decode() + ) + class PushNotificationTest(BouncerTestCase): def setUp(self) -> None: diff --git a/zilencer/views.py b/zilencer/views.py index 12df8e3eb0..43fb529a60 100644 --- a/zilencer/views.py +++ b/zilencer/views.py @@ -4,6 +4,7 @@ from collections import Counter from typing import Any, Dict, List, Optional, Type, TypeVar from uuid import UUID +import orjson from django.core.exceptions import ValidationError from django.core.validators import URLValidator, validate_email from django.db import IntegrityError, transaction @@ -29,6 +30,7 @@ from zerver.lib.response import json_success from zerver.lib.validator import ( check_bool, check_capped_string, + check_dict, check_dict_only, check_float, check_int, @@ -36,6 +38,7 @@ from zerver.lib.validator import ( check_none_or, check_string, check_string_fixed_length, + check_union, ) from zerver.views.push_notifications import validate_token from zilencer.auth import InvalidZulipServerKeyError @@ -435,7 +438,7 @@ def remote_server_post_analytics( ("realm", check_int), ("event_time", check_float), ("backfilled", check_bool), - ("extra_data", check_none_or(check_string)), + ("extra_data", check_none_or(check_union([check_string, check_dict()]))), ("event_type", check_int), ] ) @@ -476,20 +479,29 @@ def remote_server_post_analytics( batch_create_table_data(server, RemoteInstallationCount, remote_installation_counts) if realmauditlog_rows is not None: - remote_realm_audit_logs = [ - RemoteRealmAuditLog( - realm_id=row["realm"], - remote_id=row["id"], - server=server, - event_time=datetime.datetime.fromtimestamp( - row["event_time"], tz=datetime.timezone.utc - ), - backfilled=row["backfilled"], - extra_data=row["extra_data"], - event_type=row["event_type"], + remote_realm_audit_logs = [] + for row in realmauditlog_rows: + # Remote servers that do support JSONField will pass extra_data + # as a dict. Otherwise, extra_data will be either a string or None. + if isinstance(row["extra_data"], dict): + # This is guaranteed to succeed because row["extra_data"] would be parsed + # from JSON with our json validator if it is a dict. + extra_data = orjson.dumps(row["extra_data"]).decode() + else: + extra_data = row["extra_data"] + remote_realm_audit_logs.append( + RemoteRealmAuditLog( + realm_id=row["realm"], + remote_id=row["id"], + server=server, + event_time=datetime.datetime.fromtimestamp( + row["event_time"], tz=datetime.timezone.utc + ), + backfilled=row["backfilled"], + extra_data=extra_data, + event_type=row["event_type"], + ) ) - for row in realmauditlog_rows - ] batch_create_table_data(server, RemoteRealmAuditLog, remote_realm_audit_logs) return json_success(request)