push_notifications: Replace PyAPNs2 with aioapns.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2021-06-07 17:45:49 -07:00 committed by Tim Abbott
parent 0bc002270c
commit 3853285241
6 changed files with 122 additions and 114 deletions

View File

@ -100,7 +100,7 @@ tornado==4.* # https://github.com/zulip/zulip/issues/8913
orjson
# Needed for iOS push notifications
apns2
aioapns==1.* # 2.0 needs PyJWT 2: https://github.com/twilio/twilio-python/issues/556
python-twitter
@ -130,7 +130,7 @@ py3dns
# Install Python Social Auth
social-auth-app-django
social-auth-core[azuread,openidconnect,saml]<4.0.3 # 4.0.3 needs PyJWT 2: https://github.com/Pr0Ger/PyAPNs2/pull/122, https://github.com/twilio/twilio-python/issues/556
social-auth-core[azuread,openidconnect,saml]<4.0.3 # 4.0.3 needs PyJWT 2: https://github.com/twilio/twilio-python/issues/556
# For encrypting a login token to the desktop app
cryptography

View File

@ -12,7 +12,7 @@ moto[s3]
Twisted
# Needed for documentation links test
Scrapy<2.5.0 # 2.5.0 needs h2 3.0: https://github.com/Pr0Ger/PyAPNs2/issues/126
Scrapy
# Needed to compute test coverage
coverage

View File

@ -7,14 +7,13 @@
#
# For details, see requirements/README.md .
#
aioapns==1.12 \
--hash=sha256:263e36188bb218105c35bcbfde9252d78780805168fa2071d3f40b08bee14b17
# via -r requirements/common.in
alabaster==0.7.12 \
--hash=sha256:446438bdcca0e05bd45ea2de1668c1d9b032e1a9154c2c259092d77031ddd359 \
--hash=sha256:a661d72d58e6ea8a57f7a86e37d86716863ee5e92788398526d58b26a4e4dc02
# via sphinx
apns2==0.7.2 \
--hash=sha256:4f2dae8c608961d1768f734acb1d0809a60ac71a0cdcca60f46529b73f20fb34 \
--hash=sha256:f64a50181d0206a02943c835814a34fc1b1e12914931b74269a0f0fb4f39fd45
# via -r requirements/common.in
appdirs==1.4.4 \
--hash=sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41 \
--hash=sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128
@ -304,7 +303,6 @@ cryptography==3.4.7 \
--hash=sha256:ee77aa129f481be46f8d92a1a7db57269a2f23052d5f2433b4621bb457081cc9
# via
# -r requirements/common.in
# apns2
# moto
# pyopenssl
# requests
@ -437,10 +435,10 @@ gitlint==0.15.1 \
--hash=sha256:4b22916dcbdca381244aee6cb8d8743756cfd98f27e7d1f02e78733f07c3c21c \
--hash=sha256:7ebdb8e7d333e577e956225cbc3ad8e0e96d05e638e6d461c9b66b784f9d2ac4
# via -r requirements/dev.in
h2==2.6.2 \
--hash=sha256:93cbd1013a2218539af05cdf9fc37b786655b93bbc94f5296b7dabd1c5cadf41 \
--hash=sha256:af35878673c83a44afbc12b13ac91a489da2819b5dc1e11768f3c2406f740fe9
# via hyper
h2==3.2.0 \
--hash=sha256:61e0f6601fa709f35cdb730863b4e5ec7ad449792add80d1410d4174ed139af5 \
--hash=sha256:875f41ebd6f2c44781259005b157faed1a5031df3ae5aa7bcb4628a6c0782f14
# via aioapns
hpack==3.0.0 \
--hash=sha256:0edd79eda27a53ba5be2dfabf3b15780928a0dff6eb0c60a3d6767720e970c89 \
--hash=sha256:8eec9c1f4bfae3408a3f30500261f7e6a65912dc138526ea054f9ad98892e9d2
@ -455,16 +453,10 @@ html5lib==1.1 \
--hash=sha256:0d78f8fde1c230e99fe37986a60526d7049ed4bf8a9fadbad5f00e22e58e041d \
--hash=sha256:b2e5b40261e20f354d198eae92afc10d750afb487ed5e50f9c4eaf07c184146f
# via talon-core
hyper==0.7.0 \
--hash=sha256:069514f54231fb7b5df2fb910a114663a83306d5296f588fffcb0a9be19407fc \
--hash=sha256:12c82eacd122a659673484c1ea0d34576430afbe5aa6b8f63fe37fcb06a2458c
# via apns2
hyperframe==3.2.0 \
--hash=sha256:05f0e063e117c16fcdd13c12c93a4424a2c40668abfac3bb419a10f57698204e \
--hash=sha256:4dcab11967482d400853b396d042038e4c492a15a5d2f57259e2b5f89a32f755
# via
# h2
# hyper
hyperframe==5.2.0 \
--hash=sha256:5187962cb16dcc078f23cb5a4b110098d546c3f41ff2d4038a9896893bbd0b40 \
--hash=sha256:a9f5c17f2cc3c719b917c4f33ed1c61bd1f8dfac4b1bd23b7c80b3400971b41f
# via h2
hyperlink==21.0.0 \
--hash=sha256:427af957daa58bc909471c6c40f74c5450fa123dd093fc53efd2e91d2705a56b \
--hash=sha256:e6b14c37ecb73e89c77d78cdb4c2cc8f3fb59a885c5b3f819ff4ed80f25af1b4
@ -1048,7 +1040,7 @@ pyjwt==1.7.1 \
--hash=sha256:8d59a976fb773f3e6a39c85636357c4f0e242707394cadadd9814f5cbaa20e96
# via
# -r requirements/common.in
# apns2
# aioapns
# social-auth-core
# twilio
pyoembed==0.1.2 \
@ -1058,6 +1050,7 @@ pyopenssl==20.0.1 \
--hash=sha256:4c231c759543ba02560fcd2480c48dcec4dae34c9da7d3747c508227e0624b51 \
--hash=sha256:818ae18e06922c066f777a33f1fca45786d85edfe71cd043de6379337a7f274b
# via
# aioapns
# requests
# scrapy
pyparsing==2.4.7 \

View File

@ -7,9 +7,8 @@
#
# For details, see requirements/README.md .
#
apns2==0.7.2 \
--hash=sha256:4f2dae8c608961d1768f734acb1d0809a60ac71a0cdcca60f46529b73f20fb34 \
--hash=sha256:f64a50181d0206a02943c835814a34fc1b1e12914931b74269a0f0fb4f39fd45
aioapns==1.12 \
--hash=sha256:263e36188bb218105c35bcbfde9252d78780805168fa2071d3f40b08bee14b17
# via -r requirements/common.in
argon2-cffi==20.1.0 \
--hash=sha256:05a8ac07c7026542377e38389638a8a1e9b78f1cd8439cd7493b39f08dd75fbf \
@ -186,7 +185,6 @@ cryptography==3.4.7 \
--hash=sha256:ee77aa129f481be46f8d92a1a7db57269a2f23052d5f2433b4621bb457081cc9
# via
# -r requirements/common.in
# apns2
# pyopenssl
# requests
# social-auth-core
@ -286,10 +284,10 @@ ecdsa==0.17.0 \
future==0.18.2 \
--hash=sha256:b1bead90b70cf6ec3f0710ae53a525360fa360d306a86583adc6bf83a4db537d
# via python-twitter
h2==2.6.2 \
--hash=sha256:93cbd1013a2218539af05cdf9fc37b786655b93bbc94f5296b7dabd1c5cadf41 \
--hash=sha256:af35878673c83a44afbc12b13ac91a489da2819b5dc1e11768f3c2406f740fe9
# via hyper
h2==3.2.0 \
--hash=sha256:61e0f6601fa709f35cdb730863b4e5ec7ad449792add80d1410d4174ed139af5 \
--hash=sha256:875f41ebd6f2c44781259005b157faed1a5031df3ae5aa7bcb4628a6c0782f14
# via aioapns
hpack==3.0.0 \
--hash=sha256:0edd79eda27a53ba5be2dfabf3b15780928a0dff6eb0c60a3d6767720e970c89 \
--hash=sha256:8eec9c1f4bfae3408a3f30500261f7e6a65912dc138526ea054f9ad98892e9d2
@ -304,16 +302,10 @@ html5lib==1.1 \
--hash=sha256:0d78f8fde1c230e99fe37986a60526d7049ed4bf8a9fadbad5f00e22e58e041d \
--hash=sha256:b2e5b40261e20f354d198eae92afc10d750afb487ed5e50f9c4eaf07c184146f
# via talon-core
hyper==0.7.0 \
--hash=sha256:069514f54231fb7b5df2fb910a114663a83306d5296f588fffcb0a9be19407fc \
--hash=sha256:12c82eacd122a659673484c1ea0d34576430afbe5aa6b8f63fe37fcb06a2458c
# via apns2
hyperframe==3.2.0 \
--hash=sha256:05f0e063e117c16fcdd13c12c93a4424a2c40668abfac3bb419a10f57698204e \
--hash=sha256:4dcab11967482d400853b396d042038e4c492a15a5d2f57259e2b5f89a32f755
# via
# h2
# hyper
hyperframe==5.2.0 \
--hash=sha256:5187962cb16dcc078f23cb5a4b110098d546c3f41ff2d4038a9896893bbd0b40 \
--hash=sha256:a9f5c17f2cc3c719b917c4f33ed1c61bd1f8dfac4b1bd23b7c80b3400971b41f
# via h2
idna==2.10 \
--hash=sha256:b307872f855b18632ce0c21c5e45be78c0ea7ae4c15c828c20788b26921eb3f6 \
--hash=sha256:b97d804b1e9b523befed77c48dacec60e6dcb0b5391d57af6a65a312a90648c0
@ -672,7 +664,7 @@ pyjwt==1.7.1 \
--hash=sha256:8d59a976fb773f3e6a39c85636357c4f0e242707394cadadd9814f5cbaa20e96
# via
# -r requirements/common.in
# apns2
# aioapns
# social-auth-core
# twilio
pyoembed==0.1.2 \
@ -681,7 +673,9 @@ pyoembed==0.1.2 \
pyopenssl==20.0.1 \
--hash=sha256:4c231c759543ba02560fcd2480c48dcec4dae34c9da7d3747c508227e0624b51 \
--hash=sha256:818ae18e06922c066f777a33f1fca45786d85edfe71cd043de6379337a7f274b
# via requests
# via
# aioapns
# requests
pyrsistent==0.18.0 \
--hash=sha256:097b96f129dd36a8c9e33594e7ebb151b1515eb52cceb08474c10a5479e799f2 \
--hash=sha256:2aaf19dc8ce517a8653746d98e962ef480ff34b6bc563fc067be6401ffb457c7 \

View File

@ -1,9 +1,10 @@
# See https://zulip.readthedocs.io/en/latest/subsystems/notifications.html
import asyncio
import base64
import logging
import re
import time
from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
@ -38,7 +39,7 @@ from zerver.models import (
)
if TYPE_CHECKING:
from apns2.client import APNsClient
import aioapns
logger = logging.getLogger(__name__)
@ -61,11 +62,17 @@ def hex_to_b64(data: str) -> str:
#
@dataclass
class APNsContext:
apns: "aioapns.APNs"
loop: asyncio.AbstractEventLoop
@lru_cache(maxsize=None)
def get_apns_client() -> "Optional[APNsClient]":
def get_apns_context() -> Optional[APNsContext]:
# We lazily do this import as part of optimizing Zulip's base
# import time.
from apns2.client import APNsClient
import aioapns
if settings.APNS_CERT_FILE is None:
return None
@ -73,12 +80,18 @@ def get_apns_client() -> "Optional[APNsClient]":
# NB if called concurrently, this will make excess connections.
# That's a little sloppy, but harmless unless a server gets
# hammered with a ton of these all at once after startup.
return APNsClient(credentials=settings.APNS_CERT_FILE, use_sandbox=settings.APNS_SANDBOX)
loop = asyncio.new_event_loop()
apns = aioapns.APNs(
client_cert=settings.APNS_CERT_FILE,
topic=settings.APNS_TOPIC,
loop=loop,
use_sandbox=settings.APNS_SANDBOX,
)
return APNsContext(apns=apns, loop=loop)
def apns_enabled() -> bool:
client = get_apns_client()
return client is not None
return get_apns_context() is not None
def modernize_apns_payload(data: Dict[str, Any]) -> Dict[str, Any]:
@ -120,11 +133,11 @@ def send_apple_push_notification(
# import time; since these are only needed in the push
# notification queue worker, it's best to only import them in the
# code that needs them.
from apns2.payload import Payload as APNsPayload
from hyper.http20.exceptions import HTTP20Error
import aioapns
import aioapns.exceptions
client = get_apns_client()
if client is None:
apns_context = get_apns_context()
if apns_context is None:
logger.debug(
"APNs: Dropping a notification because nothing configured. "
"Set PUSH_NOTIFICATION_BOUNCER_URL (or APNS_CERT_FILE)."
@ -138,36 +151,28 @@ def send_apple_push_notification(
DeviceTokenClass = PushDeviceToken
logger.info("APNs: Sending notification for user %d to %d devices", user_id, len(devices))
payload = APNsPayload(**modernize_apns_payload(payload_data))
expiration = int(time.time() + 24 * 3600)
message = {"aps": modernize_apns_payload(payload_data)}
retries_left = APNS_MAX_RETRIES
for device in devices:
# TODO obviously this should be made to actually use the async
request = aioapns.NotificationRequest(
device_token=device.token, message=message, time_to_live=24 * 3600
)
def attempt_send() -> Optional[str]:
assert client is not None
async def attempt_send() -> Optional[str]:
assert apns_context is not None
try:
stream_id = client.send_notification_async(
device.token, payload, topic=settings.APNS_TOPIC, expiration=expiration
)
return client.get_notification_result(stream_id)
except HTTP20Error as e:
result = await apns_context.apns.send_notification(request)
return "Success" if result.is_successful else result.description
except aioapns.exceptions.ConnectionClosed as e: # nocoverage
logger.warning(
"APNs: HTTP error sending for user %d to device %s: %s",
"APNs: ConnectionClosed sending for user %d to device %s: %s",
user_id,
device.token,
e.__class__.__name__,
)
return None
except BrokenPipeError as e:
logger.warning(
"APNs: BrokenPipeError sending for user %d to device %s: %s",
user_id,
device.token,
e.__class__.__name__,
)
return None
except ConnectionError as e: # nocoverage
except aioapns.exceptions.ConnectionError as e: # nocoverage
logger.warning(
"APNs: ConnectionError sending for user %d to device %s: %s",
user_id,
@ -176,17 +181,13 @@ def send_apple_push_notification(
)
return None
result = attempt_send()
result = apns_context.loop.run_until_complete(attempt_send())
while result is None and retries_left > 0:
retries_left -= 1
result = attempt_send()
result = apns_context.loop.run_until_complete(attempt_send())
if result is None:
result = "HTTP error, retries exhausted"
if result[0] == "Unregistered":
# For some reason, "Unregistered" result values have a
# different format, as a tuple of the pair ("Unregistered", 12345132131).
result = result[0]
if result == "Success":
logger.info("APNs: Success sending for user %d to device %s", user_id, device.token)
elif result in ["Unregistered", "BadDeviceToken", "DeviceTokenNotForTopic"]:

View File

@ -1,3 +1,4 @@
import asyncio
import base64
import datetime
import itertools
@ -30,13 +31,14 @@ from zerver.lib.actions import (
do_update_message_flags,
)
from zerver.lib.push_notifications import (
APNsContext,
DeviceToken,
absolute_avatar_url,
b64_to_hex,
datetime_to_timestamp,
get_apns_badge_count,
get_apns_badge_count_future,
get_apns_client,
get_apns_context,
get_display_recipient,
get_message_payload_apns,
get_message_payload_gcm,
@ -746,8 +748,8 @@ class PushNotificationTest(BouncerTestCase):
@contextmanager
def mock_apns(self) -> Iterator[mock.MagicMock]:
mock_apns = mock.Mock()
with mock.patch("zerver.lib.push_notifications.get_apns_client") as mock_get:
mock_get.return_value = mock_apns
with mock.patch("zerver.lib.push_notifications.get_apns_context") as mock_get:
mock_get.return_value = APNsContext(apns=mock_apns, loop=asyncio.new_event_loop())
yield mock_apns
def setup_apns_tokens(self) -> None:
@ -833,7 +835,10 @@ class HandlePushNotificationTest(PushNotificationTest):
for device in RemotePushDeviceToken.objects.filter(kind=PushDeviceToken.GCM)
]
mock_gcm.json_request.return_value = {"success": {gcm_devices[0][2]: message.id}}
mock_apns.get_notification_result.return_value = "Success"
result = mock.Mock()
result.is_successful = True
mock_apns.send_notification.return_value = asyncio.Future()
mock_apns.send_notification.return_value.set_result(result)
handle_push_notification(self.user_profile.id, missed_message)
for _, _, token in apns_devices:
mock_info.assert_any_call(
@ -879,8 +884,11 @@ class HandlePushNotificationTest(PushNotificationTest):
for device in RemotePushDeviceToken.objects.filter(kind=PushDeviceToken.GCM)
]
mock_gcm.json_request.return_value = {"success": {gcm_devices[0][2]: message.id}}
mock_apns.get_notification_result.return_value = ("Unregistered", 1234567)
result = mock.Mock()
result.is_successful = False
result.description = "Unregistered"
mock_apns.send_notification.return_value = asyncio.Future()
mock_apns.send_notification.return_value.set_result(result)
handle_push_notification(self.user_profile.id, missed_message)
for _, _, token in apns_devices:
mock_info.assert_any_call(
@ -1306,28 +1314,27 @@ class TestAPNs(PushNotificationTest):
devices = self.devices()
send_apple_push_notification(self.user_profile.id, devices, payload_data)
def test_get_apns_client(self) -> None:
def test_get_apns_context(self) -> None:
"""This test is pretty hacky, and needs to carefully reset the state
it modifies in order to avoid leaking state that can lead to
nondeterministic results for other tests.
"""
import zerver.lib.push_notifications
zerver.lib.push_notifications.get_apns_client.cache_clear()
zerver.lib.push_notifications.get_apns_context.cache_clear()
try:
with self.settings(APNS_CERT_FILE="/foo.pem"), mock.patch(
"apns2.client.APNsClient"
) as mock_client:
client = get_apns_client()
self.assertEqual(mock_client.return_value, client)
with self.settings(APNS_CERT_FILE="/foo.pem"), mock.patch("aioapns.APNs") as mock_apns:
apns_context = get_apns_context()
assert apns_context is not None
self.assertEqual(mock_apns.return_value, apns_context.apns)
finally:
# Reset the cache for `get_apns_client` so that we don't
# Reset the cache for `get_apns_context` so that we don't
# leak changes to the rest of the world.
zerver.lib.push_notifications.get_apns_client.cache_clear()
zerver.lib.push_notifications.get_apns_context.cache_clear()
def test_not_configured(self) -> None:
self.setup_apns_tokens()
with mock.patch("zerver.lib.push_notifications.get_apns_client") as mock_get, mock.patch(
with mock.patch("zerver.lib.push_notifications.get_apns_context") as mock_get, mock.patch(
"zerver.lib.push_notifications.logger"
) as mock_logging:
mock_get.return_value = None
@ -1350,7 +1357,10 @@ class TestAPNs(PushNotificationTest):
with self.mock_apns() as mock_apns, mock.patch(
"zerver.lib.push_notifications.logger"
) as mock_logging:
mock_apns.get_notification_result.return_value = "Success"
result = mock.Mock()
result.is_successful = True
mock_apns.send_notification.return_value = asyncio.Future()
mock_apns.send_notification.return_value.set_result(result)
self.send()
mock_logging.warning.assert_not_called()
for device in self.devices():
@ -1361,21 +1371,27 @@ class TestAPNs(PushNotificationTest):
)
def test_http_retry(self) -> None:
import hyper
import aioapns
self.setup_apns_tokens()
with self.mock_apns() as mock_apns, mock.patch(
"zerver.lib.push_notifications.logger"
) as mock_logging:
mock_apns.get_notification_result.side_effect = itertools.chain(
[hyper.http20.exceptions.StreamResetError()], itertools.repeat("Success")
exception: asyncio.Future[object] = asyncio.Future()
exception.set_exception(aioapns.exceptions.ConnectionError())
result = mock.Mock()
result.is_successful = True
future: asyncio.Future[object] = asyncio.Future()
future.set_result(result)
mock_apns.send_notification.side_effect = itertools.chain(
[exception], itertools.repeat(future)
)
self.send()
mock_logging.warning.assert_called_once_with(
"APNs: HTTP error sending for user %d to device %s: %s",
"APNs: ConnectionError sending for user %d to device %s: %s",
self.user_profile.id,
self.devices()[0].token,
"StreamResetError",
"ConnectionError",
)
for device in self.devices():
mock_logging.info.assert_any_call(
@ -1384,20 +1400,28 @@ class TestAPNs(PushNotificationTest):
device.token,
)
def test_http_retry_pipefail(self) -> None:
def test_http_retry_closed(self) -> None:
import aioapns
self.setup_apns_tokens()
with self.mock_apns() as mock_apns, mock.patch(
"zerver.lib.push_notifications.logger"
) as mock_logging:
mock_apns.get_notification_result.side_effect = itertools.chain(
[BrokenPipeError()], itertools.repeat("Success")
exception: asyncio.Future[object] = asyncio.Future()
exception.set_exception(aioapns.exceptions.ConnectionClosed())
result = mock.Mock()
result.is_successful = True
future: asyncio.Future[object] = asyncio.Future()
future.set_result(result)
mock_apns.send_notification.side_effect = itertools.chain(
[exception], itertools.repeat(future)
)
self.send()
mock_logging.warning.assert_called_once_with(
"APNs: BrokenPipeError sending for user %d to device %s: %s",
"APNs: ConnectionClosed sending for user %d to device %s: %s",
self.user_profile.id,
self.devices()[0].token,
"BrokenPipeError",
"ConnectionClosed",
)
for device in self.devices():
mock_logging.info.assert_any_call(
@ -1407,19 +1431,15 @@ class TestAPNs(PushNotificationTest):
)
def test_http_retry_eventually_fails(self) -> None:
import hyper
import aioapns
self.setup_apns_tokens()
with self.mock_apns() as mock_apns, mock.patch(
"zerver.lib.push_notifications.logger"
) as mock_logging:
mock_apns.get_notification_result.side_effect = itertools.chain(
[hyper.http20.exceptions.StreamResetError()],
[hyper.http20.exceptions.StreamResetError()],
[hyper.http20.exceptions.StreamResetError()],
[hyper.http20.exceptions.StreamResetError()],
[hyper.http20.exceptions.StreamResetError()],
)
exception: asyncio.Future[object] = asyncio.Future()
exception.set_exception(aioapns.exceptions.ConnectionError())
mock_apns.send_notification.side_effect = iter([exception] * 5)
self.send(devices=self.devices()[0:1])
self.assertEqual(mock_logging.warning.call_count, 5)