From 385328524174e0e433ba1bd6fea721fdc3a22e4b Mon Sep 17 00:00:00 2001 From: Anders Kaseorg Date: Mon, 7 Jun 2021 17:45:49 -0700 Subject: [PATCH] push_notifications: Replace PyAPNs2 with aioapns. Signed-off-by: Anders Kaseorg --- requirements/common.in | 4 +- requirements/dev.in | 2 +- requirements/dev.txt | 33 ++++----- requirements/prod.txt | 34 ++++------ zerver/lib/push_notifications.py | 73 ++++++++++---------- zerver/tests/test_push_notifications.py | 90 +++++++++++++++---------- 6 files changed, 122 insertions(+), 114 deletions(-) diff --git a/requirements/common.in b/requirements/common.in index 6c5b1cf9d0..c46097e5c5 100644 --- a/requirements/common.in +++ b/requirements/common.in @@ -100,7 +100,7 @@ tornado==4.* # https://github.com/zulip/zulip/issues/8913 orjson # Needed for iOS push notifications -apns2 +aioapns==1.* # 2.0 needs PyJWT 2: https://github.com/twilio/twilio-python/issues/556 python-twitter @@ -130,7 +130,7 @@ py3dns # Install Python Social Auth social-auth-app-django -social-auth-core[azuread,openidconnect,saml]<4.0.3 # 4.0.3 needs PyJWT 2: https://github.com/Pr0Ger/PyAPNs2/pull/122, https://github.com/twilio/twilio-python/issues/556 +social-auth-core[azuread,openidconnect,saml]<4.0.3 # 4.0.3 needs PyJWT 2: https://github.com/twilio/twilio-python/issues/556 # For encrypting a login token to the desktop app cryptography diff --git a/requirements/dev.in b/requirements/dev.in index 5c1f26555f..66305a9f19 100644 --- a/requirements/dev.in +++ b/requirements/dev.in @@ -12,7 +12,7 @@ moto[s3] Twisted # Needed for documentation links test -Scrapy<2.5.0 # 2.5.0 needs h2 3.0: https://github.com/Pr0Ger/PyAPNs2/issues/126 +Scrapy # Needed to compute test coverage coverage diff --git a/requirements/dev.txt b/requirements/dev.txt index d4b048ee47..8b9465dc53 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -7,14 +7,13 @@ # # For details, see requirements/README.md . # +aioapns==1.12 \ + --hash=sha256:263e36188bb218105c35bcbfde9252d78780805168fa2071d3f40b08bee14b17 + # via -r requirements/common.in alabaster==0.7.12 \ --hash=sha256:446438bdcca0e05bd45ea2de1668c1d9b032e1a9154c2c259092d77031ddd359 \ --hash=sha256:a661d72d58e6ea8a57f7a86e37d86716863ee5e92788398526d58b26a4e4dc02 # via sphinx -apns2==0.7.2 \ - --hash=sha256:4f2dae8c608961d1768f734acb1d0809a60ac71a0cdcca60f46529b73f20fb34 \ - --hash=sha256:f64a50181d0206a02943c835814a34fc1b1e12914931b74269a0f0fb4f39fd45 - # via -r requirements/common.in appdirs==1.4.4 \ --hash=sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41 \ --hash=sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128 @@ -304,7 +303,6 @@ cryptography==3.4.7 \ --hash=sha256:ee77aa129f481be46f8d92a1a7db57269a2f23052d5f2433b4621bb457081cc9 # via # -r requirements/common.in - # apns2 # moto # pyopenssl # requests @@ -437,10 +435,10 @@ gitlint==0.15.1 \ --hash=sha256:4b22916dcbdca381244aee6cb8d8743756cfd98f27e7d1f02e78733f07c3c21c \ --hash=sha256:7ebdb8e7d333e577e956225cbc3ad8e0e96d05e638e6d461c9b66b784f9d2ac4 # via -r requirements/dev.in -h2==2.6.2 \ - --hash=sha256:93cbd1013a2218539af05cdf9fc37b786655b93bbc94f5296b7dabd1c5cadf41 \ - --hash=sha256:af35878673c83a44afbc12b13ac91a489da2819b5dc1e11768f3c2406f740fe9 - # via hyper +h2==3.2.0 \ + --hash=sha256:61e0f6601fa709f35cdb730863b4e5ec7ad449792add80d1410d4174ed139af5 \ + --hash=sha256:875f41ebd6f2c44781259005b157faed1a5031df3ae5aa7bcb4628a6c0782f14 + # via aioapns hpack==3.0.0 \ --hash=sha256:0edd79eda27a53ba5be2dfabf3b15780928a0dff6eb0c60a3d6767720e970c89 \ --hash=sha256:8eec9c1f4bfae3408a3f30500261f7e6a65912dc138526ea054f9ad98892e9d2 @@ -455,16 +453,10 @@ html5lib==1.1 \ --hash=sha256:0d78f8fde1c230e99fe37986a60526d7049ed4bf8a9fadbad5f00e22e58e041d \ --hash=sha256:b2e5b40261e20f354d198eae92afc10d750afb487ed5e50f9c4eaf07c184146f # via talon-core -hyper==0.7.0 \ - --hash=sha256:069514f54231fb7b5df2fb910a114663a83306d5296f588fffcb0a9be19407fc \ - --hash=sha256:12c82eacd122a659673484c1ea0d34576430afbe5aa6b8f63fe37fcb06a2458c - # via apns2 -hyperframe==3.2.0 \ - --hash=sha256:05f0e063e117c16fcdd13c12c93a4424a2c40668abfac3bb419a10f57698204e \ - --hash=sha256:4dcab11967482d400853b396d042038e4c492a15a5d2f57259e2b5f89a32f755 - # via - # h2 - # hyper +hyperframe==5.2.0 \ + --hash=sha256:5187962cb16dcc078f23cb5a4b110098d546c3f41ff2d4038a9896893bbd0b40 \ + --hash=sha256:a9f5c17f2cc3c719b917c4f33ed1c61bd1f8dfac4b1bd23b7c80b3400971b41f + # via h2 hyperlink==21.0.0 \ --hash=sha256:427af957daa58bc909471c6c40f74c5450fa123dd093fc53efd2e91d2705a56b \ --hash=sha256:e6b14c37ecb73e89c77d78cdb4c2cc8f3fb59a885c5b3f819ff4ed80f25af1b4 @@ -1048,7 +1040,7 @@ pyjwt==1.7.1 \ --hash=sha256:8d59a976fb773f3e6a39c85636357c4f0e242707394cadadd9814f5cbaa20e96 # via # -r requirements/common.in - # apns2 + # aioapns # social-auth-core # twilio pyoembed==0.1.2 \ @@ -1058,6 +1050,7 @@ pyopenssl==20.0.1 \ --hash=sha256:4c231c759543ba02560fcd2480c48dcec4dae34c9da7d3747c508227e0624b51 \ --hash=sha256:818ae18e06922c066f777a33f1fca45786d85edfe71cd043de6379337a7f274b # via + # aioapns # requests # scrapy pyparsing==2.4.7 \ diff --git a/requirements/prod.txt b/requirements/prod.txt index ce273b614a..26d7573b43 100644 --- a/requirements/prod.txt +++ b/requirements/prod.txt @@ -7,9 +7,8 @@ # # For details, see requirements/README.md . # -apns2==0.7.2 \ - --hash=sha256:4f2dae8c608961d1768f734acb1d0809a60ac71a0cdcca60f46529b73f20fb34 \ - --hash=sha256:f64a50181d0206a02943c835814a34fc1b1e12914931b74269a0f0fb4f39fd45 +aioapns==1.12 \ + --hash=sha256:263e36188bb218105c35bcbfde9252d78780805168fa2071d3f40b08bee14b17 # via -r requirements/common.in argon2-cffi==20.1.0 \ --hash=sha256:05a8ac07c7026542377e38389638a8a1e9b78f1cd8439cd7493b39f08dd75fbf \ @@ -186,7 +185,6 @@ cryptography==3.4.7 \ --hash=sha256:ee77aa129f481be46f8d92a1a7db57269a2f23052d5f2433b4621bb457081cc9 # via # -r requirements/common.in - # apns2 # pyopenssl # requests # social-auth-core @@ -286,10 +284,10 @@ ecdsa==0.17.0 \ future==0.18.2 \ --hash=sha256:b1bead90b70cf6ec3f0710ae53a525360fa360d306a86583adc6bf83a4db537d # via python-twitter -h2==2.6.2 \ - --hash=sha256:93cbd1013a2218539af05cdf9fc37b786655b93bbc94f5296b7dabd1c5cadf41 \ - --hash=sha256:af35878673c83a44afbc12b13ac91a489da2819b5dc1e11768f3c2406f740fe9 - # via hyper +h2==3.2.0 \ + --hash=sha256:61e0f6601fa709f35cdb730863b4e5ec7ad449792add80d1410d4174ed139af5 \ + --hash=sha256:875f41ebd6f2c44781259005b157faed1a5031df3ae5aa7bcb4628a6c0782f14 + # via aioapns hpack==3.0.0 \ --hash=sha256:0edd79eda27a53ba5be2dfabf3b15780928a0dff6eb0c60a3d6767720e970c89 \ --hash=sha256:8eec9c1f4bfae3408a3f30500261f7e6a65912dc138526ea054f9ad98892e9d2 @@ -304,16 +302,10 @@ html5lib==1.1 \ --hash=sha256:0d78f8fde1c230e99fe37986a60526d7049ed4bf8a9fadbad5f00e22e58e041d \ --hash=sha256:b2e5b40261e20f354d198eae92afc10d750afb487ed5e50f9c4eaf07c184146f # via talon-core -hyper==0.7.0 \ - --hash=sha256:069514f54231fb7b5df2fb910a114663a83306d5296f588fffcb0a9be19407fc \ - --hash=sha256:12c82eacd122a659673484c1ea0d34576430afbe5aa6b8f63fe37fcb06a2458c - # via apns2 -hyperframe==3.2.0 \ - --hash=sha256:05f0e063e117c16fcdd13c12c93a4424a2c40668abfac3bb419a10f57698204e \ - --hash=sha256:4dcab11967482d400853b396d042038e4c492a15a5d2f57259e2b5f89a32f755 - # via - # h2 - # hyper +hyperframe==5.2.0 \ + --hash=sha256:5187962cb16dcc078f23cb5a4b110098d546c3f41ff2d4038a9896893bbd0b40 \ + --hash=sha256:a9f5c17f2cc3c719b917c4f33ed1c61bd1f8dfac4b1bd23b7c80b3400971b41f + # via h2 idna==2.10 \ --hash=sha256:b307872f855b18632ce0c21c5e45be78c0ea7ae4c15c828c20788b26921eb3f6 \ --hash=sha256:b97d804b1e9b523befed77c48dacec60e6dcb0b5391d57af6a65a312a90648c0 @@ -672,7 +664,7 @@ pyjwt==1.7.1 \ --hash=sha256:8d59a976fb773f3e6a39c85636357c4f0e242707394cadadd9814f5cbaa20e96 # via # -r requirements/common.in - # apns2 + # aioapns # social-auth-core # twilio pyoembed==0.1.2 \ @@ -681,7 +673,9 @@ pyoembed==0.1.2 \ pyopenssl==20.0.1 \ --hash=sha256:4c231c759543ba02560fcd2480c48dcec4dae34c9da7d3747c508227e0624b51 \ --hash=sha256:818ae18e06922c066f777a33f1fca45786d85edfe71cd043de6379337a7f274b - # via requests + # via + # aioapns + # requests pyrsistent==0.18.0 \ --hash=sha256:097b96f129dd36a8c9e33594e7ebb151b1515eb52cceb08474c10a5479e799f2 \ --hash=sha256:2aaf19dc8ce517a8653746d98e962ef480ff34b6bc563fc067be6401ffb457c7 \ diff --git a/zerver/lib/push_notifications.py b/zerver/lib/push_notifications.py index 3ce8e3e529..51bf60c377 100644 --- a/zerver/lib/push_notifications.py +++ b/zerver/lib/push_notifications.py @@ -1,9 +1,10 @@ # See https://zulip.readthedocs.io/en/latest/subsystems/notifications.html +import asyncio import base64 import logging import re -import time +from dataclasses import dataclass from functools import lru_cache from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union @@ -38,7 +39,7 @@ from zerver.models import ( ) if TYPE_CHECKING: - from apns2.client import APNsClient + import aioapns logger = logging.getLogger(__name__) @@ -61,11 +62,17 @@ def hex_to_b64(data: str) -> str: # +@dataclass +class APNsContext: + apns: "aioapns.APNs" + loop: asyncio.AbstractEventLoop + + @lru_cache(maxsize=None) -def get_apns_client() -> "Optional[APNsClient]": +def get_apns_context() -> Optional[APNsContext]: # We lazily do this import as part of optimizing Zulip's base # import time. - from apns2.client import APNsClient + import aioapns if settings.APNS_CERT_FILE is None: return None @@ -73,12 +80,18 @@ def get_apns_client() -> "Optional[APNsClient]": # NB if called concurrently, this will make excess connections. # That's a little sloppy, but harmless unless a server gets # hammered with a ton of these all at once after startup. - return APNsClient(credentials=settings.APNS_CERT_FILE, use_sandbox=settings.APNS_SANDBOX) + loop = asyncio.new_event_loop() + apns = aioapns.APNs( + client_cert=settings.APNS_CERT_FILE, + topic=settings.APNS_TOPIC, + loop=loop, + use_sandbox=settings.APNS_SANDBOX, + ) + return APNsContext(apns=apns, loop=loop) def apns_enabled() -> bool: - client = get_apns_client() - return client is not None + return get_apns_context() is not None def modernize_apns_payload(data: Dict[str, Any]) -> Dict[str, Any]: @@ -120,11 +133,11 @@ def send_apple_push_notification( # import time; since these are only needed in the push # notification queue worker, it's best to only import them in the # code that needs them. - from apns2.payload import Payload as APNsPayload - from hyper.http20.exceptions import HTTP20Error + import aioapns + import aioapns.exceptions - client = get_apns_client() - if client is None: + apns_context = get_apns_context() + if apns_context is None: logger.debug( "APNs: Dropping a notification because nothing configured. " "Set PUSH_NOTIFICATION_BOUNCER_URL (or APNS_CERT_FILE)." @@ -138,36 +151,28 @@ def send_apple_push_notification( DeviceTokenClass = PushDeviceToken logger.info("APNs: Sending notification for user %d to %d devices", user_id, len(devices)) - payload = APNsPayload(**modernize_apns_payload(payload_data)) - expiration = int(time.time() + 24 * 3600) + message = {"aps": modernize_apns_payload(payload_data)} retries_left = APNS_MAX_RETRIES for device in devices: # TODO obviously this should be made to actually use the async + request = aioapns.NotificationRequest( + device_token=device.token, message=message, time_to_live=24 * 3600 + ) - def attempt_send() -> Optional[str]: - assert client is not None + async def attempt_send() -> Optional[str]: + assert apns_context is not None try: - stream_id = client.send_notification_async( - device.token, payload, topic=settings.APNS_TOPIC, expiration=expiration - ) - return client.get_notification_result(stream_id) - except HTTP20Error as e: + result = await apns_context.apns.send_notification(request) + return "Success" if result.is_successful else result.description + except aioapns.exceptions.ConnectionClosed as e: # nocoverage logger.warning( - "APNs: HTTP error sending for user %d to device %s: %s", + "APNs: ConnectionClosed sending for user %d to device %s: %s", user_id, device.token, e.__class__.__name__, ) return None - except BrokenPipeError as e: - logger.warning( - "APNs: BrokenPipeError sending for user %d to device %s: %s", - user_id, - device.token, - e.__class__.__name__, - ) - return None - except ConnectionError as e: # nocoverage + except aioapns.exceptions.ConnectionError as e: # nocoverage logger.warning( "APNs: ConnectionError sending for user %d to device %s: %s", user_id, @@ -176,17 +181,13 @@ def send_apple_push_notification( ) return None - result = attempt_send() + result = apns_context.loop.run_until_complete(attempt_send()) while result is None and retries_left > 0: retries_left -= 1 - result = attempt_send() + result = apns_context.loop.run_until_complete(attempt_send()) if result is None: result = "HTTP error, retries exhausted" - if result[0] == "Unregistered": - # For some reason, "Unregistered" result values have a - # different format, as a tuple of the pair ("Unregistered", 12345132131). - result = result[0] if result == "Success": logger.info("APNs: Success sending for user %d to device %s", user_id, device.token) elif result in ["Unregistered", "BadDeviceToken", "DeviceTokenNotForTopic"]: diff --git a/zerver/tests/test_push_notifications.py b/zerver/tests/test_push_notifications.py index a60cf4a44f..0f47c6a68d 100644 --- a/zerver/tests/test_push_notifications.py +++ b/zerver/tests/test_push_notifications.py @@ -1,3 +1,4 @@ +import asyncio import base64 import datetime import itertools @@ -30,13 +31,14 @@ from zerver.lib.actions import ( do_update_message_flags, ) from zerver.lib.push_notifications import ( + APNsContext, DeviceToken, absolute_avatar_url, b64_to_hex, datetime_to_timestamp, get_apns_badge_count, get_apns_badge_count_future, - get_apns_client, + get_apns_context, get_display_recipient, get_message_payload_apns, get_message_payload_gcm, @@ -746,8 +748,8 @@ class PushNotificationTest(BouncerTestCase): @contextmanager def mock_apns(self) -> Iterator[mock.MagicMock]: mock_apns = mock.Mock() - with mock.patch("zerver.lib.push_notifications.get_apns_client") as mock_get: - mock_get.return_value = mock_apns + with mock.patch("zerver.lib.push_notifications.get_apns_context") as mock_get: + mock_get.return_value = APNsContext(apns=mock_apns, loop=asyncio.new_event_loop()) yield mock_apns def setup_apns_tokens(self) -> None: @@ -833,7 +835,10 @@ class HandlePushNotificationTest(PushNotificationTest): for device in RemotePushDeviceToken.objects.filter(kind=PushDeviceToken.GCM) ] mock_gcm.json_request.return_value = {"success": {gcm_devices[0][2]: message.id}} - mock_apns.get_notification_result.return_value = "Success" + result = mock.Mock() + result.is_successful = True + mock_apns.send_notification.return_value = asyncio.Future() + mock_apns.send_notification.return_value.set_result(result) handle_push_notification(self.user_profile.id, missed_message) for _, _, token in apns_devices: mock_info.assert_any_call( @@ -879,8 +884,11 @@ class HandlePushNotificationTest(PushNotificationTest): for device in RemotePushDeviceToken.objects.filter(kind=PushDeviceToken.GCM) ] mock_gcm.json_request.return_value = {"success": {gcm_devices[0][2]: message.id}} - - mock_apns.get_notification_result.return_value = ("Unregistered", 1234567) + result = mock.Mock() + result.is_successful = False + result.description = "Unregistered" + mock_apns.send_notification.return_value = asyncio.Future() + mock_apns.send_notification.return_value.set_result(result) handle_push_notification(self.user_profile.id, missed_message) for _, _, token in apns_devices: mock_info.assert_any_call( @@ -1306,28 +1314,27 @@ class TestAPNs(PushNotificationTest): devices = self.devices() send_apple_push_notification(self.user_profile.id, devices, payload_data) - def test_get_apns_client(self) -> None: + def test_get_apns_context(self) -> None: """This test is pretty hacky, and needs to carefully reset the state it modifies in order to avoid leaking state that can lead to nondeterministic results for other tests. """ import zerver.lib.push_notifications - zerver.lib.push_notifications.get_apns_client.cache_clear() + zerver.lib.push_notifications.get_apns_context.cache_clear() try: - with self.settings(APNS_CERT_FILE="/foo.pem"), mock.patch( - "apns2.client.APNsClient" - ) as mock_client: - client = get_apns_client() - self.assertEqual(mock_client.return_value, client) + with self.settings(APNS_CERT_FILE="/foo.pem"), mock.patch("aioapns.APNs") as mock_apns: + apns_context = get_apns_context() + assert apns_context is not None + self.assertEqual(mock_apns.return_value, apns_context.apns) finally: - # Reset the cache for `get_apns_client` so that we don't + # Reset the cache for `get_apns_context` so that we don't # leak changes to the rest of the world. - zerver.lib.push_notifications.get_apns_client.cache_clear() + zerver.lib.push_notifications.get_apns_context.cache_clear() def test_not_configured(self) -> None: self.setup_apns_tokens() - with mock.patch("zerver.lib.push_notifications.get_apns_client") as mock_get, mock.patch( + with mock.patch("zerver.lib.push_notifications.get_apns_context") as mock_get, mock.patch( "zerver.lib.push_notifications.logger" ) as mock_logging: mock_get.return_value = None @@ -1350,7 +1357,10 @@ class TestAPNs(PushNotificationTest): with self.mock_apns() as mock_apns, mock.patch( "zerver.lib.push_notifications.logger" ) as mock_logging: - mock_apns.get_notification_result.return_value = "Success" + result = mock.Mock() + result.is_successful = True + mock_apns.send_notification.return_value = asyncio.Future() + mock_apns.send_notification.return_value.set_result(result) self.send() mock_logging.warning.assert_not_called() for device in self.devices(): @@ -1361,21 +1371,27 @@ class TestAPNs(PushNotificationTest): ) def test_http_retry(self) -> None: - import hyper + import aioapns self.setup_apns_tokens() with self.mock_apns() as mock_apns, mock.patch( "zerver.lib.push_notifications.logger" ) as mock_logging: - mock_apns.get_notification_result.side_effect = itertools.chain( - [hyper.http20.exceptions.StreamResetError()], itertools.repeat("Success") + exception: asyncio.Future[object] = asyncio.Future() + exception.set_exception(aioapns.exceptions.ConnectionError()) + result = mock.Mock() + result.is_successful = True + future: asyncio.Future[object] = asyncio.Future() + future.set_result(result) + mock_apns.send_notification.side_effect = itertools.chain( + [exception], itertools.repeat(future) ) self.send() mock_logging.warning.assert_called_once_with( - "APNs: HTTP error sending for user %d to device %s: %s", + "APNs: ConnectionError sending for user %d to device %s: %s", self.user_profile.id, self.devices()[0].token, - "StreamResetError", + "ConnectionError", ) for device in self.devices(): mock_logging.info.assert_any_call( @@ -1384,20 +1400,28 @@ class TestAPNs(PushNotificationTest): device.token, ) - def test_http_retry_pipefail(self) -> None: + def test_http_retry_closed(self) -> None: + import aioapns + self.setup_apns_tokens() with self.mock_apns() as mock_apns, mock.patch( "zerver.lib.push_notifications.logger" ) as mock_logging: - mock_apns.get_notification_result.side_effect = itertools.chain( - [BrokenPipeError()], itertools.repeat("Success") + exception: asyncio.Future[object] = asyncio.Future() + exception.set_exception(aioapns.exceptions.ConnectionClosed()) + result = mock.Mock() + result.is_successful = True + future: asyncio.Future[object] = asyncio.Future() + future.set_result(result) + mock_apns.send_notification.side_effect = itertools.chain( + [exception], itertools.repeat(future) ) self.send() mock_logging.warning.assert_called_once_with( - "APNs: BrokenPipeError sending for user %d to device %s: %s", + "APNs: ConnectionClosed sending for user %d to device %s: %s", self.user_profile.id, self.devices()[0].token, - "BrokenPipeError", + "ConnectionClosed", ) for device in self.devices(): mock_logging.info.assert_any_call( @@ -1407,19 +1431,15 @@ class TestAPNs(PushNotificationTest): ) def test_http_retry_eventually_fails(self) -> None: - import hyper + import aioapns self.setup_apns_tokens() with self.mock_apns() as mock_apns, mock.patch( "zerver.lib.push_notifications.logger" ) as mock_logging: - mock_apns.get_notification_result.side_effect = itertools.chain( - [hyper.http20.exceptions.StreamResetError()], - [hyper.http20.exceptions.StreamResetError()], - [hyper.http20.exceptions.StreamResetError()], - [hyper.http20.exceptions.StreamResetError()], - [hyper.http20.exceptions.StreamResetError()], - ) + exception: asyncio.Future[object] = asyncio.Future() + exception.set_exception(aioapns.exceptions.ConnectionError()) + mock_apns.send_notification.side_effect = iter([exception] * 5) self.send(devices=self.devices()[0:1]) self.assertEqual(mock_logging.warning.call_count, 5)