diff --git a/zerver/lib/push_notifications.py b/zerver/lib/push_notifications.py index 4daed27e64..3ce8e3e529 100644 --- a/zerver/lib/push_notifications.py +++ b/zerver/lib/push_notifications.py @@ -4,6 +4,7 @@ import base64 import logging import re import time +from functools import lru_cache from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union import gcm @@ -59,26 +60,20 @@ def hex_to_b64(data: str) -> str: # Sending to APNs, for iOS # -_apns_client: Optional["APNsClient"] = None -_apns_client_initialized = False - +@lru_cache(maxsize=None) def get_apns_client() -> "Optional[APNsClient]": # We lazily do this import as part of optimizing Zulip's base # import time. from apns2.client import APNsClient - global _apns_client, _apns_client_initialized - if not _apns_client_initialized: - # NB if called concurrently, this will make excess connections. - # That's a little sloppy, but harmless unless a server gets - # hammered with a ton of these all at once after startup. - if settings.APNS_CERT_FILE is not None: - _apns_client = APNsClient( - credentials=settings.APNS_CERT_FILE, use_sandbox=settings.APNS_SANDBOX - ) - _apns_client_initialized = True - return _apns_client + if settings.APNS_CERT_FILE is None: + return None + + # NB if called concurrently, this will make excess connections. + # That's a little sloppy, but harmless unless a server gets + # hammered with a ton of these all at once after startup. + return APNsClient(credentials=settings.APNS_CERT_FILE, use_sandbox=settings.APNS_SANDBOX) def apns_enabled() -> bool: diff --git a/zerver/tests/test_push_notifications.py b/zerver/tests/test_push_notifications.py index ee1784a0e6..a60cf4a44f 100644 --- a/zerver/tests/test_push_notifications.py +++ b/zerver/tests/test_push_notifications.py @@ -1313,7 +1313,7 @@ class TestAPNs(PushNotificationTest): """ import zerver.lib.push_notifications - zerver.lib.push_notifications._apns_client_initialized = False + zerver.lib.push_notifications.get_apns_client.cache_clear() try: with self.settings(APNS_CERT_FILE="/foo.pem"), mock.patch( "apns2.client.APNsClient" @@ -1321,10 +1321,9 @@ class TestAPNs(PushNotificationTest): client = get_apns_client() self.assertEqual(mock_client.return_value, client) finally: - # Reset the values set by `get_apns_client` so that we don't + # Reset the cache for `get_apns_client` so that we don't # leak changes to the rest of the world. - zerver.lib.push_notifications._apns_client_initialized = False - zerver.lib.push_notifications._apns_client = None + zerver.lib.push_notifications.get_apns_client.cache_clear() def test_not_configured(self) -> None: self.setup_apns_tokens()