cache: Strength types of cache decorators with ParamSpec.

This demonstrates a way to resolve the long-standing issue
of typing higher-order identity functions without using
`cast` and in a type-safe manner for decorators in `cache.py`.

Signed-off-by: Zixuan James Li <359101898@qq.com>
This commit is contained in:
Zixuan James Li 2022-04-12 23:42:12 -04:00 committed by Tim Abbott
parent c3317ebff8
commit f21746ba0b
4 changed files with 45 additions and 28 deletions

View File

@ -1634,8 +1634,14 @@ def compute_jabber_user_fullname(email: str) -> str:
return email.split("@")[0] + " (XMPP)"
def get_user_profile_delivery_email_cache_key(
realm: Realm, email: str, email_to_fullname: Callable[[str], str]
) -> str:
return user_profile_delivery_email_cache_key(email, realm)
@cache_with_key(
lambda realm, email, f: user_profile_delivery_email_cache_key(email, realm),
get_user_profile_delivery_email_cache_key,
timeout=3600 * 24 * 7,
)
def create_mirror_user_if_needed(

View File

@ -19,7 +19,6 @@ from typing import (
Sequence,
Tuple,
TypeVar,
cast,
)
from django.conf import settings
@ -28,6 +27,7 @@ from django.core.cache import caches
from django.core.cache.backends.base import BaseCache
from django.db.models import Q
from django.http import HttpRequest
from typing_extensions import ParamSpec
from zerver.lib.utils import make_safe_digest, statsd, statsd_key
@ -38,7 +38,8 @@ if TYPE_CHECKING:
MEMCACHED_MAX_KEY_LENGTH = 250
FuncT = TypeVar("FuncT", bound=Callable[..., object])
ParamT = ParamSpec("ParamT")
ReturnT = TypeVar("ReturnT")
logger = logging.getLogger()
@ -130,18 +131,18 @@ def get_cache_backend(cache_name: Optional[str]) -> BaseCache:
def get_cache_with_key(
keyfunc: Callable[..., str],
keyfunc: Callable[ParamT, str],
cache_name: Optional[str] = None,
) -> Callable[[FuncT], FuncT]:
) -> Callable[[Callable[ParamT, ReturnT]], Callable[ParamT, ReturnT]]:
"""
The main goal of this function getting value from the cache like in the "cache_with_key".
A cache value can contain any data including the "None", so
here used exception for case if value isn't found in the cache.
"""
def decorator(func: FuncT) -> FuncT:
def decorator(func: Callable[ParamT, ReturnT]) -> Callable[ParamT, ReturnT]:
@wraps(func)
def func_with_caching(*args: object, **kwargs: object) -> object:
def func_with_caching(*args: ParamT.args, **kwargs: ParamT.kwargs) -> ReturnT:
key = keyfunc(*args, **kwargs)
try:
val = cache_get(key, cache_name=cache_name)
@ -154,17 +155,17 @@ def get_cache_with_key(
return val[0]
raise NotFoundInCache()
return cast(FuncT, func_with_caching) # https://github.com/python/mypy/issues/1927
return func_with_caching
return decorator
def cache_with_key(
keyfunc: Callable[..., str],
keyfunc: Callable[ParamT, str],
cache_name: Optional[str] = None,
timeout: Optional[int] = None,
with_statsd_key: Optional[str] = None,
) -> Callable[[FuncT], FuncT]:
) -> Callable[[Callable[ParamT, ReturnT]], Callable[ParamT, ReturnT]]:
"""Decorator which applies Django caching to a function.
Decorator argument is a function which computes a cache key
@ -172,9 +173,9 @@ def cache_with_key(
for avoiding collisions with other uses of this decorator or
other uses of caching."""
def decorator(func: FuncT) -> FuncT:
def decorator(func: Callable[ParamT, ReturnT]) -> Callable[ParamT, ReturnT]:
@wraps(func)
def func_with_caching(*args: object, **kwargs: object) -> object:
def func_with_caching(*args: ParamT.args, **kwargs: ParamT.kwargs) -> ReturnT:
key = keyfunc(*args, **kwargs)
try:
@ -207,7 +208,7 @@ def cache_with_key(
return val
return cast(FuncT, func_with_caching) # https://github.com/python/mypy/issues/1927
return func_with_caching
return decorator
@ -499,7 +500,7 @@ def user_profile_delivery_email_cache_key(delivery_email: str, realm: "Realm") -
return f"user_profile_by_delivery_email:{make_safe_digest(delivery_email.strip())}:{realm.id}"
def bot_profile_cache_key(email: str, realm_id: Optional[int] = None) -> str:
def bot_profile_cache_key(email: str, realm_id: int) -> str:
return f"bot_profile:{make_safe_digest(email.strip())}"
@ -592,7 +593,7 @@ def delete_user_profile_caches(user_profiles: Iterable["UserProfile"]) -> None:
)
if user_profile.is_bot and is_cross_realm_bot_email(user_profile.email):
# Handle clearing system bots from their special cache.
keys.append(bot_profile_cache_key(user_profile.email))
keys.append(bot_profile_cache_key(user_profile.email, user_profile.realm_id))
cache_delete_many(keys)
@ -769,7 +770,7 @@ def flush_submessage(*, instance: "SubMessage", **kwargs: object) -> None:
def ignore_unhashable_lru_cache(
maxsize: int = 128, typed: bool = False
) -> Callable[[FuncT], FuncT]:
) -> Callable[[Callable[ParamT, ReturnT]], Callable[ParamT, ReturnT]]:
"""
This is a wrapper over lru_cache function. It adds following features on
top of lru_cache:
@ -779,7 +780,7 @@ def ignore_unhashable_lru_cache(
"""
internal_decorator = lru_cache(maxsize=maxsize, typed=typed)
def decorator(user_function: FuncT) -> FuncT:
def decorator(user_function: Callable[ParamT, ReturnT]) -> Callable[ParamT, ReturnT]:
if settings.DEVELOPMENT and not settings.TEST_SUITE: # nocoverage
# In the development environment, we want every file
# change to refresh the source files from disk.
@ -788,7 +789,7 @@ def ignore_unhashable_lru_cache(
# Casting to Any since we're about to monkey-patch this.
cache_enabled_user_function: Any = internal_decorator(user_function)
def wrapper(*args: object, **kwargs: object) -> object:
def wrapper(*args: ParamT.args, **kwargs: ParamT.kwargs) -> ReturnT:
if not hasattr(cache_enabled_user_function, "key_prefix"):
cache_enabled_user_function.key_prefix = KEY_PREFIX
@ -813,7 +814,7 @@ def ignore_unhashable_lru_cache(
setattr(wrapper, "cache_info", cache_enabled_user_function.cache_info)
setattr(wrapper, "cache_clear", cache_enabled_user_function.cache_clear)
return cast(FuncT, wrapper) # https://github.com/python/mypy/issues/1927
return wrapper
return decorator

View File

@ -25,7 +25,13 @@ class TinyStreamResult(TypedDict):
name: str
@cache_with_key(lambda *args: display_recipient_cache_key(args[0]), timeout=3600 * 24 * 7)
def get_display_recipient_cache_key(
recipient_id: int, recipient_type: int, recipient_type_id: Optional[int]
) -> str:
return display_recipient_cache_key(recipient_id)
@cache_with_key(get_display_recipient_cache_key, timeout=3600 * 24 * 7)
def get_display_recipient_remote_cache(
recipient_id: int, recipient_type: int, recipient_type_id: Optional[int]
) -> DisplayRecipientT:

View File

@ -774,13 +774,15 @@ class Realm(models.Model):
def __str__(self) -> str:
return f"<Realm: {self.string_id} {self.id}>"
# `realm` instead of `self` here to make sure the parameters of the cache key
# function matches the original method.
@cache_with_key(get_realm_emoji_cache_key, timeout=3600 * 24 * 7)
def get_emoji(self) -> Dict[str, EmojiInfo]:
return get_realm_emoji_uncached(self)
def get_emoji(realm) -> Dict[str, EmojiInfo]:
return get_realm_emoji_uncached(realm)
@cache_with_key(get_active_realm_emoji_cache_key, timeout=3600 * 24 * 7)
def get_active_emoji(self) -> Dict[str, EmojiInfo]:
return get_active_realm_emoji_uncached(self)
def get_active_emoji(realm) -> Dict[str, EmojiInfo]:
return get_active_realm_emoji_uncached(realm)
def get_admin_users_and_bots(
self, include_realm_owners: bool = True
@ -880,9 +882,11 @@ class Realm(models.Model):
# it as gibibytes (GiB) to be a bit more generous in case of confusion.
return self.upload_quota_gb << 30
# `realm` instead of `self` here to make sure the parameters of the cache key
# function matches the original method.
@cache_with_key(get_realm_used_upload_space_cache_key, timeout=3600 * 24 * 7)
def currently_used_upload_space_bytes(self) -> int:
used_space = Attachment.objects.filter(realm=self).aggregate(Sum("size"))["size__sum"]
def currently_used_upload_space_bytes(realm) -> int:
used_space = Attachment.objects.filter(realm=realm).aggregate(Sum("size"))["size__sum"]
if used_space is None:
return 0
return used_space
@ -3565,8 +3569,8 @@ class Subscription(models.Model):
@cache_with_key(user_profile_by_id_cache_key, timeout=3600 * 24 * 7)
def get_user_profile_by_id(uid: int) -> UserProfile:
return UserProfile.objects.select_related().get(id=uid)
def get_user_profile_by_id(user_profile_id: int) -> UserProfile:
return UserProfile.objects.select_related().get(id=user_profile_id)
def get_user_profile_by_email(email: str) -> UserProfile: