push_notifications: Replace PyAPNs2 with aioapns.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2021-06-07 17:45:49 -07:00 committed by Tim Abbott
parent 0bc002270c
commit 3853285241
6 changed files with 122 additions and 114 deletions

View File

@ -100,7 +100,7 @@ tornado==4.* # https://github.com/zulip/zulip/issues/8913
orjson orjson
# Needed for iOS push notifications # Needed for iOS push notifications
apns2 aioapns==1.* # 2.0 needs PyJWT 2: https://github.com/twilio/twilio-python/issues/556
python-twitter python-twitter
@ -130,7 +130,7 @@ py3dns
# Install Python Social Auth # Install Python Social Auth
social-auth-app-django social-auth-app-django
social-auth-core[azuread,openidconnect,saml]<4.0.3 # 4.0.3 needs PyJWT 2: https://github.com/Pr0Ger/PyAPNs2/pull/122, https://github.com/twilio/twilio-python/issues/556 social-auth-core[azuread,openidconnect,saml]<4.0.3 # 4.0.3 needs PyJWT 2: https://github.com/twilio/twilio-python/issues/556
# For encrypting a login token to the desktop app # For encrypting a login token to the desktop app
cryptography cryptography

View File

@ -12,7 +12,7 @@ moto[s3]
Twisted Twisted
# Needed for documentation links test # Needed for documentation links test
Scrapy<2.5.0 # 2.5.0 needs h2 3.0: https://github.com/Pr0Ger/PyAPNs2/issues/126 Scrapy
# Needed to compute test coverage # Needed to compute test coverage
coverage coverage

View File

@ -7,14 +7,13 @@
# #
# For details, see requirements/README.md . # For details, see requirements/README.md .
# #
aioapns==1.12 \
--hash=sha256:263e36188bb218105c35bcbfde9252d78780805168fa2071d3f40b08bee14b17
# via -r requirements/common.in
alabaster==0.7.12 \ alabaster==0.7.12 \
--hash=sha256:446438bdcca0e05bd45ea2de1668c1d9b032e1a9154c2c259092d77031ddd359 \ --hash=sha256:446438bdcca0e05bd45ea2de1668c1d9b032e1a9154c2c259092d77031ddd359 \
--hash=sha256:a661d72d58e6ea8a57f7a86e37d86716863ee5e92788398526d58b26a4e4dc02 --hash=sha256:a661d72d58e6ea8a57f7a86e37d86716863ee5e92788398526d58b26a4e4dc02
# via sphinx # via sphinx
apns2==0.7.2 \
--hash=sha256:4f2dae8c608961d1768f734acb1d0809a60ac71a0cdcca60f46529b73f20fb34 \
--hash=sha256:f64a50181d0206a02943c835814a34fc1b1e12914931b74269a0f0fb4f39fd45
# via -r requirements/common.in
appdirs==1.4.4 \ appdirs==1.4.4 \
--hash=sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41 \ --hash=sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41 \
--hash=sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128 --hash=sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128
@ -304,7 +303,6 @@ cryptography==3.4.7 \
--hash=sha256:ee77aa129f481be46f8d92a1a7db57269a2f23052d5f2433b4621bb457081cc9 --hash=sha256:ee77aa129f481be46f8d92a1a7db57269a2f23052d5f2433b4621bb457081cc9
# via # via
# -r requirements/common.in # -r requirements/common.in
# apns2
# moto # moto
# pyopenssl # pyopenssl
# requests # requests
@ -437,10 +435,10 @@ gitlint==0.15.1 \
--hash=sha256:4b22916dcbdca381244aee6cb8d8743756cfd98f27e7d1f02e78733f07c3c21c \ --hash=sha256:4b22916dcbdca381244aee6cb8d8743756cfd98f27e7d1f02e78733f07c3c21c \
--hash=sha256:7ebdb8e7d333e577e956225cbc3ad8e0e96d05e638e6d461c9b66b784f9d2ac4 --hash=sha256:7ebdb8e7d333e577e956225cbc3ad8e0e96d05e638e6d461c9b66b784f9d2ac4
# via -r requirements/dev.in # via -r requirements/dev.in
h2==2.6.2 \ h2==3.2.0 \
--hash=sha256:93cbd1013a2218539af05cdf9fc37b786655b93bbc94f5296b7dabd1c5cadf41 \ --hash=sha256:61e0f6601fa709f35cdb730863b4e5ec7ad449792add80d1410d4174ed139af5 \
--hash=sha256:af35878673c83a44afbc12b13ac91a489da2819b5dc1e11768f3c2406f740fe9 --hash=sha256:875f41ebd6f2c44781259005b157faed1a5031df3ae5aa7bcb4628a6c0782f14
# via hyper # via aioapns
hpack==3.0.0 \ hpack==3.0.0 \
--hash=sha256:0edd79eda27a53ba5be2dfabf3b15780928a0dff6eb0c60a3d6767720e970c89 \ --hash=sha256:0edd79eda27a53ba5be2dfabf3b15780928a0dff6eb0c60a3d6767720e970c89 \
--hash=sha256:8eec9c1f4bfae3408a3f30500261f7e6a65912dc138526ea054f9ad98892e9d2 --hash=sha256:8eec9c1f4bfae3408a3f30500261f7e6a65912dc138526ea054f9ad98892e9d2
@ -455,16 +453,10 @@ html5lib==1.1 \
--hash=sha256:0d78f8fde1c230e99fe37986a60526d7049ed4bf8a9fadbad5f00e22e58e041d \ --hash=sha256:0d78f8fde1c230e99fe37986a60526d7049ed4bf8a9fadbad5f00e22e58e041d \
--hash=sha256:b2e5b40261e20f354d198eae92afc10d750afb487ed5e50f9c4eaf07c184146f --hash=sha256:b2e5b40261e20f354d198eae92afc10d750afb487ed5e50f9c4eaf07c184146f
# via talon-core # via talon-core
hyper==0.7.0 \ hyperframe==5.2.0 \
--hash=sha256:069514f54231fb7b5df2fb910a114663a83306d5296f588fffcb0a9be19407fc \ --hash=sha256:5187962cb16dcc078f23cb5a4b110098d546c3f41ff2d4038a9896893bbd0b40 \
--hash=sha256:12c82eacd122a659673484c1ea0d34576430afbe5aa6b8f63fe37fcb06a2458c --hash=sha256:a9f5c17f2cc3c719b917c4f33ed1c61bd1f8dfac4b1bd23b7c80b3400971b41f
# via apns2 # via h2
hyperframe==3.2.0 \
--hash=sha256:05f0e063e117c16fcdd13c12c93a4424a2c40668abfac3bb419a10f57698204e \
--hash=sha256:4dcab11967482d400853b396d042038e4c492a15a5d2f57259e2b5f89a32f755
# via
# h2
# hyper
hyperlink==21.0.0 \ hyperlink==21.0.0 \
--hash=sha256:427af957daa58bc909471c6c40f74c5450fa123dd093fc53efd2e91d2705a56b \ --hash=sha256:427af957daa58bc909471c6c40f74c5450fa123dd093fc53efd2e91d2705a56b \
--hash=sha256:e6b14c37ecb73e89c77d78cdb4c2cc8f3fb59a885c5b3f819ff4ed80f25af1b4 --hash=sha256:e6b14c37ecb73e89c77d78cdb4c2cc8f3fb59a885c5b3f819ff4ed80f25af1b4
@ -1048,7 +1040,7 @@ pyjwt==1.7.1 \
--hash=sha256:8d59a976fb773f3e6a39c85636357c4f0e242707394cadadd9814f5cbaa20e96 --hash=sha256:8d59a976fb773f3e6a39c85636357c4f0e242707394cadadd9814f5cbaa20e96
# via # via
# -r requirements/common.in # -r requirements/common.in
# apns2 # aioapns
# social-auth-core # social-auth-core
# twilio # twilio
pyoembed==0.1.2 \ pyoembed==0.1.2 \
@ -1058,6 +1050,7 @@ pyopenssl==20.0.1 \
--hash=sha256:4c231c759543ba02560fcd2480c48dcec4dae34c9da7d3747c508227e0624b51 \ --hash=sha256:4c231c759543ba02560fcd2480c48dcec4dae34c9da7d3747c508227e0624b51 \
--hash=sha256:818ae18e06922c066f777a33f1fca45786d85edfe71cd043de6379337a7f274b --hash=sha256:818ae18e06922c066f777a33f1fca45786d85edfe71cd043de6379337a7f274b
# via # via
# aioapns
# requests # requests
# scrapy # scrapy
pyparsing==2.4.7 \ pyparsing==2.4.7 \

View File

@ -7,9 +7,8 @@
# #
# For details, see requirements/README.md . # For details, see requirements/README.md .
# #
apns2==0.7.2 \ aioapns==1.12 \
--hash=sha256:4f2dae8c608961d1768f734acb1d0809a60ac71a0cdcca60f46529b73f20fb34 \ --hash=sha256:263e36188bb218105c35bcbfde9252d78780805168fa2071d3f40b08bee14b17
--hash=sha256:f64a50181d0206a02943c835814a34fc1b1e12914931b74269a0f0fb4f39fd45
# via -r requirements/common.in # via -r requirements/common.in
argon2-cffi==20.1.0 \ argon2-cffi==20.1.0 \
--hash=sha256:05a8ac07c7026542377e38389638a8a1e9b78f1cd8439cd7493b39f08dd75fbf \ --hash=sha256:05a8ac07c7026542377e38389638a8a1e9b78f1cd8439cd7493b39f08dd75fbf \
@ -186,7 +185,6 @@ cryptography==3.4.7 \
--hash=sha256:ee77aa129f481be46f8d92a1a7db57269a2f23052d5f2433b4621bb457081cc9 --hash=sha256:ee77aa129f481be46f8d92a1a7db57269a2f23052d5f2433b4621bb457081cc9
# via # via
# -r requirements/common.in # -r requirements/common.in
# apns2
# pyopenssl # pyopenssl
# requests # requests
# social-auth-core # social-auth-core
@ -286,10 +284,10 @@ ecdsa==0.17.0 \
future==0.18.2 \ future==0.18.2 \
--hash=sha256:b1bead90b70cf6ec3f0710ae53a525360fa360d306a86583adc6bf83a4db537d --hash=sha256:b1bead90b70cf6ec3f0710ae53a525360fa360d306a86583adc6bf83a4db537d
# via python-twitter # via python-twitter
h2==2.6.2 \ h2==3.2.0 \
--hash=sha256:93cbd1013a2218539af05cdf9fc37b786655b93bbc94f5296b7dabd1c5cadf41 \ --hash=sha256:61e0f6601fa709f35cdb730863b4e5ec7ad449792add80d1410d4174ed139af5 \
--hash=sha256:af35878673c83a44afbc12b13ac91a489da2819b5dc1e11768f3c2406f740fe9 --hash=sha256:875f41ebd6f2c44781259005b157faed1a5031df3ae5aa7bcb4628a6c0782f14
# via hyper # via aioapns
hpack==3.0.0 \ hpack==3.0.0 \
--hash=sha256:0edd79eda27a53ba5be2dfabf3b15780928a0dff6eb0c60a3d6767720e970c89 \ --hash=sha256:0edd79eda27a53ba5be2dfabf3b15780928a0dff6eb0c60a3d6767720e970c89 \
--hash=sha256:8eec9c1f4bfae3408a3f30500261f7e6a65912dc138526ea054f9ad98892e9d2 --hash=sha256:8eec9c1f4bfae3408a3f30500261f7e6a65912dc138526ea054f9ad98892e9d2
@ -304,16 +302,10 @@ html5lib==1.1 \
--hash=sha256:0d78f8fde1c230e99fe37986a60526d7049ed4bf8a9fadbad5f00e22e58e041d \ --hash=sha256:0d78f8fde1c230e99fe37986a60526d7049ed4bf8a9fadbad5f00e22e58e041d \
--hash=sha256:b2e5b40261e20f354d198eae92afc10d750afb487ed5e50f9c4eaf07c184146f --hash=sha256:b2e5b40261e20f354d198eae92afc10d750afb487ed5e50f9c4eaf07c184146f
# via talon-core # via talon-core
hyper==0.7.0 \ hyperframe==5.2.0 \
--hash=sha256:069514f54231fb7b5df2fb910a114663a83306d5296f588fffcb0a9be19407fc \ --hash=sha256:5187962cb16dcc078f23cb5a4b110098d546c3f41ff2d4038a9896893bbd0b40 \
--hash=sha256:12c82eacd122a659673484c1ea0d34576430afbe5aa6b8f63fe37fcb06a2458c --hash=sha256:a9f5c17f2cc3c719b917c4f33ed1c61bd1f8dfac4b1bd23b7c80b3400971b41f
# via apns2 # via h2
hyperframe==3.2.0 \
--hash=sha256:05f0e063e117c16fcdd13c12c93a4424a2c40668abfac3bb419a10f57698204e \
--hash=sha256:4dcab11967482d400853b396d042038e4c492a15a5d2f57259e2b5f89a32f755
# via
# h2
# hyper
idna==2.10 \ idna==2.10 \
--hash=sha256:b307872f855b18632ce0c21c5e45be78c0ea7ae4c15c828c20788b26921eb3f6 \ --hash=sha256:b307872f855b18632ce0c21c5e45be78c0ea7ae4c15c828c20788b26921eb3f6 \
--hash=sha256:b97d804b1e9b523befed77c48dacec60e6dcb0b5391d57af6a65a312a90648c0 --hash=sha256:b97d804b1e9b523befed77c48dacec60e6dcb0b5391d57af6a65a312a90648c0
@ -672,7 +664,7 @@ pyjwt==1.7.1 \
--hash=sha256:8d59a976fb773f3e6a39c85636357c4f0e242707394cadadd9814f5cbaa20e96 --hash=sha256:8d59a976fb773f3e6a39c85636357c4f0e242707394cadadd9814f5cbaa20e96
# via # via
# -r requirements/common.in # -r requirements/common.in
# apns2 # aioapns
# social-auth-core # social-auth-core
# twilio # twilio
pyoembed==0.1.2 \ pyoembed==0.1.2 \
@ -681,7 +673,9 @@ pyoembed==0.1.2 \
pyopenssl==20.0.1 \ pyopenssl==20.0.1 \
--hash=sha256:4c231c759543ba02560fcd2480c48dcec4dae34c9da7d3747c508227e0624b51 \ --hash=sha256:4c231c759543ba02560fcd2480c48dcec4dae34c9da7d3747c508227e0624b51 \
--hash=sha256:818ae18e06922c066f777a33f1fca45786d85edfe71cd043de6379337a7f274b --hash=sha256:818ae18e06922c066f777a33f1fca45786d85edfe71cd043de6379337a7f274b
# via requests # via
# aioapns
# requests
pyrsistent==0.18.0 \ pyrsistent==0.18.0 \
--hash=sha256:097b96f129dd36a8c9e33594e7ebb151b1515eb52cceb08474c10a5479e799f2 \ --hash=sha256:097b96f129dd36a8c9e33594e7ebb151b1515eb52cceb08474c10a5479e799f2 \
--hash=sha256:2aaf19dc8ce517a8653746d98e962ef480ff34b6bc563fc067be6401ffb457c7 \ --hash=sha256:2aaf19dc8ce517a8653746d98e962ef480ff34b6bc563fc067be6401ffb457c7 \

View File

@ -1,9 +1,10 @@
# See https://zulip.readthedocs.io/en/latest/subsystems/notifications.html # See https://zulip.readthedocs.io/en/latest/subsystems/notifications.html
import asyncio
import base64 import base64
import logging import logging
import re import re
import time from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
@ -38,7 +39,7 @@ from zerver.models import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from apns2.client import APNsClient import aioapns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -61,11 +62,17 @@ def hex_to_b64(data: str) -> str:
# #
@dataclass
class APNsContext:
apns: "aioapns.APNs"
loop: asyncio.AbstractEventLoop
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def get_apns_client() -> "Optional[APNsClient]": def get_apns_context() -> Optional[APNsContext]:
# We lazily do this import as part of optimizing Zulip's base # We lazily do this import as part of optimizing Zulip's base
# import time. # import time.
from apns2.client import APNsClient import aioapns
if settings.APNS_CERT_FILE is None: if settings.APNS_CERT_FILE is None:
return None return None
@ -73,12 +80,18 @@ def get_apns_client() -> "Optional[APNsClient]":
# NB if called concurrently, this will make excess connections. # NB if called concurrently, this will make excess connections.
# That's a little sloppy, but harmless unless a server gets # That's a little sloppy, but harmless unless a server gets
# hammered with a ton of these all at once after startup. # hammered with a ton of these all at once after startup.
return APNsClient(credentials=settings.APNS_CERT_FILE, use_sandbox=settings.APNS_SANDBOX) loop = asyncio.new_event_loop()
apns = aioapns.APNs(
client_cert=settings.APNS_CERT_FILE,
topic=settings.APNS_TOPIC,
loop=loop,
use_sandbox=settings.APNS_SANDBOX,
)
return APNsContext(apns=apns, loop=loop)
def apns_enabled() -> bool: def apns_enabled() -> bool:
client = get_apns_client() return get_apns_context() is not None
return client is not None
def modernize_apns_payload(data: Dict[str, Any]) -> Dict[str, Any]: def modernize_apns_payload(data: Dict[str, Any]) -> Dict[str, Any]:
@ -120,11 +133,11 @@ def send_apple_push_notification(
# import time; since these are only needed in the push # import time; since these are only needed in the push
# notification queue worker, it's best to only import them in the # notification queue worker, it's best to only import them in the
# code that needs them. # code that needs them.
from apns2.payload import Payload as APNsPayload import aioapns
from hyper.http20.exceptions import HTTP20Error import aioapns.exceptions
client = get_apns_client() apns_context = get_apns_context()
if client is None: if apns_context is None:
logger.debug( logger.debug(
"APNs: Dropping a notification because nothing configured. " "APNs: Dropping a notification because nothing configured. "
"Set PUSH_NOTIFICATION_BOUNCER_URL (or APNS_CERT_FILE)." "Set PUSH_NOTIFICATION_BOUNCER_URL (or APNS_CERT_FILE)."
@ -138,36 +151,28 @@ def send_apple_push_notification(
DeviceTokenClass = PushDeviceToken DeviceTokenClass = PushDeviceToken
logger.info("APNs: Sending notification for user %d to %d devices", user_id, len(devices)) logger.info("APNs: Sending notification for user %d to %d devices", user_id, len(devices))
payload = APNsPayload(**modernize_apns_payload(payload_data)) message = {"aps": modernize_apns_payload(payload_data)}
expiration = int(time.time() + 24 * 3600)
retries_left = APNS_MAX_RETRIES retries_left = APNS_MAX_RETRIES
for device in devices: for device in devices:
# TODO obviously this should be made to actually use the async # TODO obviously this should be made to actually use the async
request = aioapns.NotificationRequest(
device_token=device.token, message=message, time_to_live=24 * 3600
)
def attempt_send() -> Optional[str]: async def attempt_send() -> Optional[str]:
assert client is not None assert apns_context is not None
try: try:
stream_id = client.send_notification_async( result = await apns_context.apns.send_notification(request)
device.token, payload, topic=settings.APNS_TOPIC, expiration=expiration return "Success" if result.is_successful else result.description
) except aioapns.exceptions.ConnectionClosed as e: # nocoverage
return client.get_notification_result(stream_id)
except HTTP20Error as e:
logger.warning( logger.warning(
"APNs: HTTP error sending for user %d to device %s: %s", "APNs: ConnectionClosed sending for user %d to device %s: %s",
user_id, user_id,
device.token, device.token,
e.__class__.__name__, e.__class__.__name__,
) )
return None return None
except BrokenPipeError as e: except aioapns.exceptions.ConnectionError as e: # nocoverage
logger.warning(
"APNs: BrokenPipeError sending for user %d to device %s: %s",
user_id,
device.token,
e.__class__.__name__,
)
return None
except ConnectionError as e: # nocoverage
logger.warning( logger.warning(
"APNs: ConnectionError sending for user %d to device %s: %s", "APNs: ConnectionError sending for user %d to device %s: %s",
user_id, user_id,
@ -176,17 +181,13 @@ def send_apple_push_notification(
) )
return None return None
result = attempt_send() result = apns_context.loop.run_until_complete(attempt_send())
while result is None and retries_left > 0: while result is None and retries_left > 0:
retries_left -= 1 retries_left -= 1
result = attempt_send() result = apns_context.loop.run_until_complete(attempt_send())
if result is None: if result is None:
result = "HTTP error, retries exhausted" result = "HTTP error, retries exhausted"
if result[0] == "Unregistered":
# For some reason, "Unregistered" result values have a
# different format, as a tuple of the pair ("Unregistered", 12345132131).
result = result[0]
if result == "Success": if result == "Success":
logger.info("APNs: Success sending for user %d to device %s", user_id, device.token) logger.info("APNs: Success sending for user %d to device %s", user_id, device.token)
elif result in ["Unregistered", "BadDeviceToken", "DeviceTokenNotForTopic"]: elif result in ["Unregistered", "BadDeviceToken", "DeviceTokenNotForTopic"]:

View File

@ -1,3 +1,4 @@
import asyncio
import base64 import base64
import datetime import datetime
import itertools import itertools
@ -30,13 +31,14 @@ from zerver.lib.actions import (
do_update_message_flags, do_update_message_flags,
) )
from zerver.lib.push_notifications import ( from zerver.lib.push_notifications import (
APNsContext,
DeviceToken, DeviceToken,
absolute_avatar_url, absolute_avatar_url,
b64_to_hex, b64_to_hex,
datetime_to_timestamp, datetime_to_timestamp,
get_apns_badge_count, get_apns_badge_count,
get_apns_badge_count_future, get_apns_badge_count_future,
get_apns_client, get_apns_context,
get_display_recipient, get_display_recipient,
get_message_payload_apns, get_message_payload_apns,
get_message_payload_gcm, get_message_payload_gcm,
@ -746,8 +748,8 @@ class PushNotificationTest(BouncerTestCase):
@contextmanager @contextmanager
def mock_apns(self) -> Iterator[mock.MagicMock]: def mock_apns(self) -> Iterator[mock.MagicMock]:
mock_apns = mock.Mock() mock_apns = mock.Mock()
with mock.patch("zerver.lib.push_notifications.get_apns_client") as mock_get: with mock.patch("zerver.lib.push_notifications.get_apns_context") as mock_get:
mock_get.return_value = mock_apns mock_get.return_value = APNsContext(apns=mock_apns, loop=asyncio.new_event_loop())
yield mock_apns yield mock_apns
def setup_apns_tokens(self) -> None: def setup_apns_tokens(self) -> None:
@ -833,7 +835,10 @@ class HandlePushNotificationTest(PushNotificationTest):
for device in RemotePushDeviceToken.objects.filter(kind=PushDeviceToken.GCM) for device in RemotePushDeviceToken.objects.filter(kind=PushDeviceToken.GCM)
] ]
mock_gcm.json_request.return_value = {"success": {gcm_devices[0][2]: message.id}} mock_gcm.json_request.return_value = {"success": {gcm_devices[0][2]: message.id}}
mock_apns.get_notification_result.return_value = "Success" result = mock.Mock()
result.is_successful = True
mock_apns.send_notification.return_value = asyncio.Future()
mock_apns.send_notification.return_value.set_result(result)
handle_push_notification(self.user_profile.id, missed_message) handle_push_notification(self.user_profile.id, missed_message)
for _, _, token in apns_devices: for _, _, token in apns_devices:
mock_info.assert_any_call( mock_info.assert_any_call(
@ -879,8 +884,11 @@ class HandlePushNotificationTest(PushNotificationTest):
for device in RemotePushDeviceToken.objects.filter(kind=PushDeviceToken.GCM) for device in RemotePushDeviceToken.objects.filter(kind=PushDeviceToken.GCM)
] ]
mock_gcm.json_request.return_value = {"success": {gcm_devices[0][2]: message.id}} mock_gcm.json_request.return_value = {"success": {gcm_devices[0][2]: message.id}}
result = mock.Mock()
mock_apns.get_notification_result.return_value = ("Unregistered", 1234567) result.is_successful = False
result.description = "Unregistered"
mock_apns.send_notification.return_value = asyncio.Future()
mock_apns.send_notification.return_value.set_result(result)
handle_push_notification(self.user_profile.id, missed_message) handle_push_notification(self.user_profile.id, missed_message)
for _, _, token in apns_devices: for _, _, token in apns_devices:
mock_info.assert_any_call( mock_info.assert_any_call(
@ -1306,28 +1314,27 @@ class TestAPNs(PushNotificationTest):
devices = self.devices() devices = self.devices()
send_apple_push_notification(self.user_profile.id, devices, payload_data) send_apple_push_notification(self.user_profile.id, devices, payload_data)
def test_get_apns_client(self) -> None: def test_get_apns_context(self) -> None:
"""This test is pretty hacky, and needs to carefully reset the state """This test is pretty hacky, and needs to carefully reset the state
it modifies in order to avoid leaking state that can lead to it modifies in order to avoid leaking state that can lead to
nondeterministic results for other tests. nondeterministic results for other tests.
""" """
import zerver.lib.push_notifications import zerver.lib.push_notifications
zerver.lib.push_notifications.get_apns_client.cache_clear() zerver.lib.push_notifications.get_apns_context.cache_clear()
try: try:
with self.settings(APNS_CERT_FILE="/foo.pem"), mock.patch( with self.settings(APNS_CERT_FILE="/foo.pem"), mock.patch("aioapns.APNs") as mock_apns:
"apns2.client.APNsClient" apns_context = get_apns_context()
) as mock_client: assert apns_context is not None
client = get_apns_client() self.assertEqual(mock_apns.return_value, apns_context.apns)
self.assertEqual(mock_client.return_value, client)
finally: finally:
# Reset the cache for `get_apns_client` so that we don't # Reset the cache for `get_apns_context` so that we don't
# leak changes to the rest of the world. # leak changes to the rest of the world.
zerver.lib.push_notifications.get_apns_client.cache_clear() zerver.lib.push_notifications.get_apns_context.cache_clear()
def test_not_configured(self) -> None: def test_not_configured(self) -> None:
self.setup_apns_tokens() self.setup_apns_tokens()
with mock.patch("zerver.lib.push_notifications.get_apns_client") as mock_get, mock.patch( with mock.patch("zerver.lib.push_notifications.get_apns_context") as mock_get, mock.patch(
"zerver.lib.push_notifications.logger" "zerver.lib.push_notifications.logger"
) as mock_logging: ) as mock_logging:
mock_get.return_value = None mock_get.return_value = None
@ -1350,7 +1357,10 @@ class TestAPNs(PushNotificationTest):
with self.mock_apns() as mock_apns, mock.patch( with self.mock_apns() as mock_apns, mock.patch(
"zerver.lib.push_notifications.logger" "zerver.lib.push_notifications.logger"
) as mock_logging: ) as mock_logging:
mock_apns.get_notification_result.return_value = "Success" result = mock.Mock()
result.is_successful = True
mock_apns.send_notification.return_value = asyncio.Future()
mock_apns.send_notification.return_value.set_result(result)
self.send() self.send()
mock_logging.warning.assert_not_called() mock_logging.warning.assert_not_called()
for device in self.devices(): for device in self.devices():
@ -1361,21 +1371,27 @@ class TestAPNs(PushNotificationTest):
) )
def test_http_retry(self) -> None: def test_http_retry(self) -> None:
import hyper import aioapns
self.setup_apns_tokens() self.setup_apns_tokens()
with self.mock_apns() as mock_apns, mock.patch( with self.mock_apns() as mock_apns, mock.patch(
"zerver.lib.push_notifications.logger" "zerver.lib.push_notifications.logger"
) as mock_logging: ) as mock_logging:
mock_apns.get_notification_result.side_effect = itertools.chain( exception: asyncio.Future[object] = asyncio.Future()
[hyper.http20.exceptions.StreamResetError()], itertools.repeat("Success") exception.set_exception(aioapns.exceptions.ConnectionError())
result = mock.Mock()
result.is_successful = True
future: asyncio.Future[object] = asyncio.Future()
future.set_result(result)
mock_apns.send_notification.side_effect = itertools.chain(
[exception], itertools.repeat(future)
) )
self.send() self.send()
mock_logging.warning.assert_called_once_with( mock_logging.warning.assert_called_once_with(
"APNs: HTTP error sending for user %d to device %s: %s", "APNs: ConnectionError sending for user %d to device %s: %s",
self.user_profile.id, self.user_profile.id,
self.devices()[0].token, self.devices()[0].token,
"StreamResetError", "ConnectionError",
) )
for device in self.devices(): for device in self.devices():
mock_logging.info.assert_any_call( mock_logging.info.assert_any_call(
@ -1384,20 +1400,28 @@ class TestAPNs(PushNotificationTest):
device.token, device.token,
) )
def test_http_retry_pipefail(self) -> None: def test_http_retry_closed(self) -> None:
import aioapns
self.setup_apns_tokens() self.setup_apns_tokens()
with self.mock_apns() as mock_apns, mock.patch( with self.mock_apns() as mock_apns, mock.patch(
"zerver.lib.push_notifications.logger" "zerver.lib.push_notifications.logger"
) as mock_logging: ) as mock_logging:
mock_apns.get_notification_result.side_effect = itertools.chain( exception: asyncio.Future[object] = asyncio.Future()
[BrokenPipeError()], itertools.repeat("Success") exception.set_exception(aioapns.exceptions.ConnectionClosed())
result = mock.Mock()
result.is_successful = True
future: asyncio.Future[object] = asyncio.Future()
future.set_result(result)
mock_apns.send_notification.side_effect = itertools.chain(
[exception], itertools.repeat(future)
) )
self.send() self.send()
mock_logging.warning.assert_called_once_with( mock_logging.warning.assert_called_once_with(
"APNs: BrokenPipeError sending for user %d to device %s: %s", "APNs: ConnectionClosed sending for user %d to device %s: %s",
self.user_profile.id, self.user_profile.id,
self.devices()[0].token, self.devices()[0].token,
"BrokenPipeError", "ConnectionClosed",
) )
for device in self.devices(): for device in self.devices():
mock_logging.info.assert_any_call( mock_logging.info.assert_any_call(
@ -1407,19 +1431,15 @@ class TestAPNs(PushNotificationTest):
) )
def test_http_retry_eventually_fails(self) -> None: def test_http_retry_eventually_fails(self) -> None:
import hyper import aioapns
self.setup_apns_tokens() self.setup_apns_tokens()
with self.mock_apns() as mock_apns, mock.patch( with self.mock_apns() as mock_apns, mock.patch(
"zerver.lib.push_notifications.logger" "zerver.lib.push_notifications.logger"
) as mock_logging: ) as mock_logging:
mock_apns.get_notification_result.side_effect = itertools.chain( exception: asyncio.Future[object] = asyncio.Future()
[hyper.http20.exceptions.StreamResetError()], exception.set_exception(aioapns.exceptions.ConnectionError())
[hyper.http20.exceptions.StreamResetError()], mock_apns.send_notification.side_effect = iter([exception] * 5)
[hyper.http20.exceptions.StreamResetError()],
[hyper.http20.exceptions.StreamResetError()],
[hyper.http20.exceptions.StreamResetError()],
)
self.send(devices=self.devices()[0:1]) self.send(devices=self.devices()[0:1])
self.assertEqual(mock_logging.warning.call_count, 5) self.assertEqual(mock_logging.warning.call_count, 5)