diff --git a/zerver/lib/push_notifications.py b/zerver/lib/push_notifications.py index 4b8dbb21fa..0ed00d59dc 100644 --- a/zerver/lib/push_notifications.py +++ b/zerver/lib/push_notifications.py @@ -547,6 +547,9 @@ def add_push_device_token( post_data = { "server_uuid": settings.ZULIP_ORG_ID, "user_uuid": str(user_profile.uuid), + # user_id is sent so that the bouncer can delete any pre-existing registrations + # for this user+device to avoid duplication upon adding the uuid registration. + "user_id": str(user_profile.id), "token": token_str, "token_kind": kind, } diff --git a/zerver/tests/test_push_notifications.py b/zerver/tests/test_push_notifications.py index a7e85ea8c4..ccdb8b5e9c 100644 --- a/zerver/tests/test_push_notifications.py +++ b/zerver/tests/test_push_notifications.py @@ -192,12 +192,6 @@ class PushBouncerNotificationTest(BouncerTestCase): self.server_uuid, endpoint, {"token": token, "token_kind": token_kind} ) self.assert_json_error(result, "Missing user_id or user_uuid") - result = self.uuid_post( - self.server_uuid, - endpoint, - {"user_id": user_id, "user_uuid": "xxx", "token": token, "token_kind": token_kind}, - ) - self.assert_json_error(result, "Specify only one of user_id or user_uuid") result = self.uuid_post( self.server_uuid, endpoint, {"user_id": user_id, "token": token, "token_kind": 17} ) @@ -280,6 +274,40 @@ class PushBouncerNotificationTest(BouncerTestCase): 401, ) + def test_register_device_deduplication(self) -> None: + hamlet = self.example_user("hamlet") + token = "111222" + user_id = hamlet.id + user_uuid = str(hamlet.uuid) + token_kind = PushDeviceToken.GCM + + endpoint = "/api/v1/remotes/push/register" + + # First we create a legacy user_id registration. + result = self.uuid_post( + self.server_uuid, + endpoint, + {"user_id": user_id, "token_kind": token_kind, "token": token}, + ) + self.assert_json_success(result) + + registrations = list(RemotePushDeviceToken.objects.filter(token=token)) + self.assert_length(registrations, 1) + self.assertEqual(registrations[0].user_id, user_id) + self.assertEqual(registrations[0].user_uuid, None) + + # Register same user+device with uuid now. The old registration should be deleted + # to avoid duplication. + result = self.uuid_post( + self.server_uuid, + endpoint, + {"user_id": user_id, "user_uuid": user_uuid, "token_kind": token_kind, "token": token}, + ) + registrations = list(RemotePushDeviceToken.objects.filter(token=token)) + self.assert_length(registrations, 1) + self.assertEqual(registrations[0].user_id, None) + self.assertEqual(str(registrations[0].user_uuid), user_uuid) + def test_remote_push_user_endpoints(self) -> None: endpoints = [ ("/api/v1/remotes/push/register", "register"), diff --git a/zilencer/views.py b/zilencer/views.py index dbe9607ecf..50b1ea3c44 100644 --- a/zilencer/views.py +++ b/zilencer/views.py @@ -157,23 +157,26 @@ def register_remote_push_device( if user_id is None and user_uuid is None: raise JsonableError(_("Missing user_id or user_uuid")) if user_id is not None and user_uuid is not None: - # We don't want "hybrid" registrations with both. - # Our RemotePushDeviceToken should be either in the new uuid format - # or the legacy id one. - raise JsonableError(_("Specify only one of user_id or user_uuid")) - + kwargs: Dict[str, object] = {"user_uuid": user_uuid, "user_id": None} + # Delete pre-existing user_id registration for this user+device to avoid + # duplication. Further down, uuid registration will be created. + RemotePushDeviceToken.objects.filter( + server=server, token=token, kind=token_kind, user_id=user_id + ).delete() + else: + # One of these is None, so these will kwargs will leads to a proper registration + # of either user_id or user_uuid type + kwargs = {"user_id": user_id, "user_uuid": user_uuid} try: with transaction.atomic(): RemotePushDeviceToken.objects.create( - # Exactly one of these two user identity fields will be None. - user_id=user_id, - user_uuid=user_uuid, server=server, kind=token_kind, token=token, ios_app_id=ios_app_id, # last_updated is to be renamed to date_created. last_updated=timezone.now(), + **kwargs, ) except IntegrityError: pass