# zulip/zerver/lib/cache.py
# (source listing metadata: 509 lines, 20 KiB, Python)

from functools import wraps
from django.utils.lru_cache import lru_cache
from django.core.cache import cache as djcache
from django.core.cache import caches
from django.conf import settings
from django.db.models import Q
from django.core.cache.backends.base import BaseCache
from typing import cast, Any, Callable, Dict, Iterable, List, Optional, Union, Set, TypeVar, Tuple
from zerver.lib.utils import statsd, statsd_key, make_safe_digest
import subprocess
import time
import base64
import random
import sys
import os
import hashlib
if False:
    # These modules have to be imported for type annotations but
    # they cannot be imported at runtime due to cyclic dependency.
    # The `if False:` guard means this import never executes at runtime.
    from zerver.models import UserProfile, Realm, Message
# Type variable tying a decorated function's return type to the return
# type of the wrapper the decorators below produce.
ReturnT = TypeVar('ReturnT') # Useful for matching return types via Callable[..., ReturnT]
class NotFoundInCache(Exception):
    """Signals that a requested key has no entry in the cache.

    Raised by functions decorated with get_cache_with_key on a miss. An
    exception is used rather than a sentinel return value because None is
    itself a legitimate cached value.
    """
# Module-level instrumentation for remote cache calls. The cache_* helpers
# in this module bracket each backend call with remote_cache_stats_start()
# and remote_cache_stats_finish(), which update these globals.
remote_cache_time_start = 0.0  # time.time() recorded by the most recent start
remote_cache_total_time = 0.0  # accumulated seconds across finished calls
remote_cache_total_requests = 0  # number of finished calls
def get_remote_cache_time() -> float:
    """Return the total seconds accumulated by remote_cache_stats_finish()."""
    accumulated_seconds = remote_cache_total_time
    return accumulated_seconds
def get_remote_cache_requests() -> int:
    """Return how many remote cache calls remote_cache_stats_finish() has counted."""
    request_count = remote_cache_total_requests
    return request_count
def remote_cache_stats_start() -> None:
    """Record the wall-clock time at which a remote cache call begins."""
    global remote_cache_time_start
    started_at = time.time()
    remote_cache_time_start = started_at
def remote_cache_stats_finish() -> None:
    """Fold the call begun by remote_cache_stats_start() into the running totals.

    Increments the request counter and adds the seconds elapsed since the
    recorded start time to the accumulated total.
    """
    global remote_cache_total_time, remote_cache_total_requests, remote_cache_time_start
    remote_cache_total_requests += 1
    elapsed = time.time() - remote_cache_time_start
    remote_cache_total_time += elapsed
def _read_existing_key_prefix(filename: str) -> str:
    """Read a prefix written to `filename` by another process, retrying briefly.

    The writer may not have finished writing yet, so the file is re-read up
    to 9 times, 0.5s apart. Exits the process with status 1 if no complete
    prefix ever appears.
    """
    for _ in range(9):
        with open(filename, 'r') as f:
            prefix = f.readline()[:-1]  # drop the trailing newline
        # A complete prefix is 32 hex characters followed by ':'.
        if len(prefix) == 33:
            return prefix
        time.sleep(0.5)
    print("Could not read remote cache key prefix file")
    sys.exit(1)


def get_or_create_key_prefix() -> str:
    """Return the string prepended to every remote cache key.

    * Casper tests use the fixed prefix 'casper_tests:'.
    * The Python test suite returns 'django_tests_unused:'; those tests
      overwrite KEY_PREFIX per test anyway.
    * Otherwise the prefix is persisted in DEPLOY_ROOT/var/remote_cache_prefix.
      The first process to create that file (atomically, via O_EXCL)
      generates a random 32-hex-character prefix plus ':'. Every other
      process reads it back.

    Raises OSError if the prefix file can neither be created nor found to
    already exist. Exits the process if an existing file never contains a
    complete prefix.
    """
    if settings.CASPER_TESTS:
        # This sets the prefix for the benefit of the Casper tests.
        #
        # Having a fixed key is OK since we don't support running
        # multiple copies of the casper tests at the same time anyway.
        return 'casper_tests:'
    elif settings.TEST_SUITE:
        # The Python tests overwrite KEY_PREFIX on each test, but use
        # this codepath as well, just to save running the more complex
        # code below for reading the normal key prefix.
        return 'django_tests_unused:'

    # directory `var` should exist in production; exist_ok makes concurrent
    # creation by several processes harmless (no need to fork `mkdir -p`).
    var_dir = os.path.join(settings.DEPLOY_ROOT, "var")
    os.makedirs(var_dir, exist_ok=True)
    filename = os.path.join(var_dir, "remote_cache_prefix")

    try:
        # O_EXCL makes creation atomic: exactly one process wins the race
        # and is responsible for generating the prefix.
        fd = os.open(filename, os.O_CREAT | os.O_EXCL | os.O_RDWR, 0o444)
    except FileExistsError:
        # Another process created (or is creating) the file; read its prefix.
        # Other OSErrors (e.g. permissions) propagate instead of being
        # misreported as an unreadable prefix file.
        return _read_existing_key_prefix(filename)

    random_hash = hashlib.sha256(str(random.getrandbits(256)).encode('utf-8')).digest()
    # 32 lowercase hex characters plus ':' -- 33 characters in total.
    prefix = base64.b16encode(random_hash)[:32].decode('utf-8').lower() + ':'
    # This does close the underlying file
    with os.fdopen(fd, 'w') as f:
        f.write(prefix + "\n")
    return prefix
# String prepended to every key by the cache_* helpers in this module.
# Computed once, at import time; bounce_key_prefix_for_testing() rebinds it.
KEY_PREFIX = get_or_create_key_prefix() # type: str
def bounce_key_prefix_for_testing(test_name: str) -> None:
    """Point KEY_PREFIX at a namespace unique to `test_name` and this process.

    The raw "<test_name>:<pid>:" string is replaced by its SHA-1 hex digest.
    That keeps the prefix short and fixed-length, since memcached keys must
    stay under 256 characters.
    """
    global KEY_PREFIX
    raw_prefix = ':'.join([test_name, str(os.getpid())]) + ':'
    KEY_PREFIX = hashlib.sha1(raw_prefix.encode('utf-8')).hexdigest()
def get_cache_backend(cache_name: Optional[str]) -> BaseCache:
    """Return the Django cache named `cache_name`, or the default cache for None."""
    return djcache if cache_name is None else caches[cache_name]
def get_cache_with_key(
    keyfunc: Callable[..., str],
    cache_name: Optional[str]=None
) -> Callable[[Callable[..., ReturnT]], Callable[..., ReturnT]]:
    """Decorator that reads a value from the cache without ever computing it.

    As with cache_with_key, `keyfunc` builds the cache key from the decorated
    function's arguments. Unlike cache_with_key, the decorated function's body
    is never run. A hit returns the cached value, and a miss raises
    NotFoundInCache. An exception is used because None is itself a valid
    cached value.
    """
    def decorator(func: Callable[..., ReturnT]) -> Callable[..., ReturnT]:
        @wraps(func)
        def func_with_caching(*args: Any, **kwargs: Any) -> ReturnT:
            key = keyfunc(*args, **kwargs)
            val = cache_get(key, cache_name=cache_name)
            # Entries are stored as 1-tuples (see cache_set), so a stored
            # None is distinguishable from a missing key.
            if val is not None:
                return val[0]
            raise NotFoundInCache()
        return func_with_caching
    return decorator
def cache_with_key(
    keyfunc: Callable[..., str], cache_name: Optional[str]=None,
    timeout: Optional[int]=None, with_statsd_key: Optional[str]=None
) -> Callable[[Callable[..., ReturnT]], Callable[..., ReturnT]]:
    """Decorator which applies Django caching to a function.

    `keyfunc` receives the decorated function's arguments and returns the
    cache key. You are responsible for avoiding collisions with other uses of
    this decorator or other uses of caching.

    On a miss the function runs and its result is stored with `timeout`.
    Every call increments a statsd hit/miss counter. The counter is named
    after `with_statsd_key`, or after statsd_key(key) when that is None, and
    carries an extra ".dbcache" segment for the 'database' cache.
    """
    statsd_segment = ".dbcache" if cache_name == 'database' else ""

    def decorator(func: Callable[..., ReturnT]) -> Callable[..., ReturnT]:
        @wraps(func)
        def func_with_caching(*args: Any, **kwargs: Any) -> ReturnT:
            key = keyfunc(*args, **kwargs)
            cached = cache_get(key, cache_name=cache_name)

            metric_key = statsd_key(key) if with_statsd_key is None else with_statsd_key
            outcome = "miss" if cached is None else "hit"
            statsd.incr("cache%s.%s.%s" % (statsd_segment, metric_key, outcome))

            # Entries are singleton tuples so that a cached None result can
            # be told apart from a missing key.
            if cached is not None:
                return cached[0]

            result = func(*args, **kwargs)
            cache_set(key, result, cache_name=cache_name, timeout=timeout)
            return result
        return func_with_caching
    return decorator
def cache_set(key: str, val: Any, cache_name: Optional[str]=None, timeout: Optional[int]=None) -> None:
    """Store `val` under KEY_PREFIX + `key`, wrapped in a 1-tuple.

    The wrapping lets readers tell a cached None apart from a missing key.
    """
    remote_cache_stats_start()
    backend = get_cache_backend(cache_name)
    prefixed_key = KEY_PREFIX + key
    backend.set(prefixed_key, (val,), timeout=timeout)
    remote_cache_stats_finish()
def cache_get(key: str, cache_name: Optional[str]=None) -> Any:
    """Return the raw entry stored under KEY_PREFIX + `key`.

    The entry is not unwrapped. Callers in this module treat None as a miss
    and index [0] into anything else (the 1-tuple written by cache_set).
    """
    remote_cache_stats_start()
    entry = get_cache_backend(cache_name).get(KEY_PREFIX + key)
    remote_cache_stats_finish()
    return entry
def cache_get_many(keys: List[str], cache_name: Optional[str]=None) -> Dict[str, Any]:
    """Fetch several entries in a single backend round-trip.

    `keys` are unprefixed. KEY_PREFIX is added before the lookup and stripped
    from the keys of the returned dict. Keys the backend does not return are
    absent from the result. Values are the raw stored entries, not unwrapped.
    """
    # Read the prefix once so the same value is both added and stripped.
    prefix = KEY_PREFIX
    prefixed_keys = [prefix + key for key in keys]
    remote_cache_stats_start()
    found = get_cache_backend(cache_name).get_many(prefixed_keys)
    remote_cache_stats_finish()
    return {key[len(prefix):]: value for key, value in found.items()}
def cache_set_many(items: Dict[str, Any], cache_name: Optional[str]=None,
                   timeout: Optional[int]=None) -> None:
    """Store several entries in a single backend round-trip.

    `items` maps unprefixed keys to values. Unlike cache_set, values are
    stored as given and NOT wrapped in a 1-tuple. Callers that read them back
    via cache_get/cache_get_many should therefore pass pre-wrapped values,
    e.g. {key: (value,)}.
    """
    prefixed_items = {KEY_PREFIX + key: value for key, value in items.items()}
    remote_cache_stats_start()
    get_cache_backend(cache_name).set_many(prefixed_items, timeout=timeout)
    remote_cache_stats_finish()
def cache_delete(key: str, cache_name: Optional[str]=None) -> None:
    """Remove the entry stored under KEY_PREFIX + `key`."""
    remote_cache_stats_start()
    backend = get_cache_backend(cache_name)
    backend.delete(KEY_PREFIX + key)
    remote_cache_stats_finish()
def cache_delete_many(items: Iterable[str], cache_name: Optional[str]=None) -> None:
    """Remove the entries for several unprefixed keys in one backend call.

    The prefixed keys are materialized into a list because Django's
    delete_many() is documented to take a list of keys. A one-shot generator
    would be silently exhausted by any backend that iterates it twice.
    """
    prefixed_keys = [KEY_PREFIX + item for item in items]
    remote_cache_stats_start()
    get_cache_backend(cache_name).delete_many(prefixed_keys)
    remote_cache_stats_finish()
# generic_bulk_cached_fetch and its helpers
# Type variables for generic_bulk_cached_fetch and its default helpers.
# ObjKT is the type of an object id (the lookup key). ItemT is the type of a
# value handed back to the caller. CompressedItemT is the type of a value as
# stored in the remote cache (the output of `setter`, the input of `extractor`).
ObjKT = TypeVar('ObjKT')
ItemT = TypeVar('ItemT')
CompressedItemT = TypeVar('CompressedItemT')
def default_extractor(obj: CompressedItemT) -> ItemT:
    """Default `extractor` for generic_bulk_cached_fetch: return `obj` unchanged."""
    unchanged = obj
    return unchanged # type: ignore # Need a type assert that ItemT=CompressedItemT
def default_setter(obj: ItemT) -> CompressedItemT:
    """Default `setter` for generic_bulk_cached_fetch: store `obj` unchanged."""
    unchanged = obj
    return unchanged # type: ignore # Need a type assert that ItemT=CompressedItemT
def default_id_fetcher(obj: ItemT) -> ObjKT:
    """Default `id_fetcher` for generic_bulk_cached_fetch: return `obj.id`."""
    object_id = obj.id # type: ignore # Need ItemT/CompressedItemT typevars to be a Django protocol
    return object_id
def default_cache_transformer(obj: ItemT) -> ItemT:
    """Default `cache_transformer` for generic_bulk_cached_fetch: cache `obj` itself."""
    transformed = obj
    return transformed
def generic_bulk_cached_fetch(
    cache_key_function: Callable[[ObjKT], str],
    query_function: Callable[[List[ObjKT]], Iterable[Any]],
    object_ids: Iterable[ObjKT],
    extractor: Callable[[CompressedItemT], ItemT] = default_extractor,
    setter: Callable[[ItemT], CompressedItemT] = default_setter,
    id_fetcher: Callable[[ItemT], ObjKT] = default_id_fetcher,
    cache_transformer: Callable[[ItemT], ItemT] = default_cache_transformer
) -> Dict[ObjKT, ItemT]:
    """Fetch many objects by id, consulting the remote cache before the database.

    Required arguments:
    * cache_key_function: object_id => cache key
    * query_function: [object_ids] => [objects from database]. Only called
      when at least one id is missing from the cache.
    * object_ids: the ids to look up. Any iterable is accepted. It is
      materialized into a list once, so one-shot iterators work.

    Optional keyword arguments:
    * setter: function to call before storing items to cache
      (e.g. compression).
    * extractor: function to call on items returned from cache
      (e.g. decompression). Should be the inverse of `setter`.
    * id_fetcher: function mapping an object from database => object_id
      (in case we're using a key more complex than obj.id).
    * cache_transformer: function mapping an object from database =>
      value for cache (in case the values that we're caching are some
      function of the objects, not the objects themselves).

    Returns a dict mapping each requested id that was found (in the cache or
    via query_function) to its value. Ids found in neither place are absent.
    Objects loaded from the database are written back to the cache.
    """
    # The ids are traversed several times below. A generator would be
    # exhausted after the first pass and silently produce an empty result.
    id_list = list(object_ids)
    if not id_list:
        # Nothing to fetch; skip the cache and database round-trips.
        return {}

    cache_keys = {}  # type: Dict[ObjKT, str]
    for object_id in id_list:
        cache_keys[object_id] = cache_key_function(object_id)

    cached_objects_compressed = cache_get_many(
        [cache_keys[object_id] for object_id in id_list]
    )  # type: Dict[str, Tuple[CompressedItemT]]

    cached_objects = {}  # type: Dict[str, ItemT]
    for key, val in cached_objects_compressed.items():
        # Stored entries are 1-tuples; unwrap, then undo `setter`.
        cached_objects[key] = extractor(val[0])

    needed_ids = [object_id for object_id in id_list
                  if cache_keys[object_id] not in cached_objects]

    # Only query the database when something is actually missing.
    db_objects = query_function(needed_ids) if needed_ids else []  # type: Iterable[Any]

    items_for_remote_cache = {}  # type: Dict[str, Tuple[CompressedItemT]]
    for obj in db_objects:
        key = cache_keys[id_fetcher(obj)]
        item = cache_transformer(obj)
        # Wrap in a 1-tuple, matching the format cache_set writes.
        items_for_remote_cache[key] = (setter(item),)
        cached_objects[key] = item
    if items_for_remote_cache:
        cache_set_many(items_for_remote_cache)

    return {object_id: cached_objects[cache_keys[object_id]]
            for object_id in id_list
            if cache_keys[object_id] in cached_objects}
def cache(func: Callable[..., ReturnT]) -> Callable[..., ReturnT]:
    """Decorator which applies Django caching to a function.

    The key combines the function's source filename and name with the repr()
    of its positional and keyword arguments. Calls whose arguments have equal
    repr() therefore share a cache entry.
    """
    func_uniqifier = '%s-%s' % (func.__code__.co_filename, func.__name__)

    @wraps(func)
    def keyfunc(*args: Any, **kwargs: Any) -> str:
        raw_key = func_uniqifier + repr((args, kwargs))
        # Django complains about spaces because memcached rejects them.
        # Doubling every '-' first keeps the '-s' escape for spaces unambiguous.
        dashes_escaped = raw_key.replace('-', '--')
        return dashes_escaped.replace(' ', '-s')

    return cache_with_key(keyfunc)(func)
def display_recipient_cache_key(recipient_id: int) -> str:
    """Cache key for the display_recipient dict of recipient `recipient_id`."""
    template = "display_recipient_dict:%d"
    return template % (recipient_id,)
def user_profile_by_email_cache_key(email: str) -> str:
    """Cache key for a UserProfile looked up by email alone (no realm in the key).

    The email is stripped of surrounding whitespace and then digested. See
    the comment in zerver/lib/avatar_hash.py:gravatar_hash for why we are
    proactively encoding email addresses even though they will with high
    likelihood be ASCII-only for the foreseeable future.
    """
    digest = make_safe_digest(email.strip())
    return 'user_profile_by_email:%s' % (digest,)
def user_profile_cache_key_id(email: str, realm_id: int) -> str:
    """Cache key for the UserProfile with `email` in the realm with id `realm_id`."""
    email_digest = make_safe_digest(email.strip())
    return u"user_profile:%s:%s" % (email_digest, realm_id)
def user_profile_cache_key(email: str, realm: 'Realm') -> str:
    """Same key as user_profile_cache_key_id, taking a Realm object instead of its id."""
    realm_id = realm.id
    return user_profile_cache_key_id(email, realm_id)
def bot_profile_cache_key(email: str) -> str:
    """Cache key for a bot profile, from a digest of the bot's stripped email."""
    email_digest = make_safe_digest(email.strip())
    return "bot_profile:%s" % (email_digest,)
def user_profile_by_id_cache_key(user_profile_id: int) -> str:
    """Cache key for a UserProfile looked up by its id."""
    return "user_profile_by_id:" + str(user_profile_id)
def user_profile_by_api_key_cache_key(api_key: str) -> str:
    """Cache key for a UserProfile looked up by its API key."""
    return "user_profile_by_api_key:" + str(api_key)
# UserProfile field names tied to the per-realm user dicts cache.
# flush_user_profile deletes realm_user_dicts_cache_key(realm_id) whenever a
# save touches any of them. NOTE(review): presumably these are also the
# fields selected into those dicts elsewhere; keep the two in sync.
realm_user_dict_fields = [
    'id', 'full_name', 'short_name', 'email',
    'avatar_source', 'avatar_version', 'is_active',
    'is_realm_admin', 'is_bot', 'realm_id', 'timezone',
    'date_joined'
] # type: List[str]
def realm_user_dicts_cache_key(realm_id: int) -> str:
    """Cache key for the user dicts of the realm with id `realm_id`."""
    return "realm_user_dicts:" + str(realm_id)
def active_user_ids_cache_key(realm_id: int) -> str:
    """Cache key for the active user ids of the realm with id `realm_id`."""
    return "active_user_ids:" + str(realm_id)
# Field names tied to the per-realm bot dicts cache. flush_user_profile
# deletes bot_dicts_in_realm_cache_key(realm) when a bot is saved with any
# of them changed. NOTE(review): presumably these are also the fields
# selected into the cached bot dicts elsewhere; keep the two in sync.
bot_dict_fields = ['id', 'full_name', 'short_name', 'bot_type', 'email',
                   'is_active', 'default_sending_stream__name',
                   'realm_id',
                   'default_events_register_stream__name',
                   'default_all_public_streams', 'api_key',
                   'bot_owner__email', 'avatar_source',
                   'avatar_version'] # type: List[str]
def bot_dicts_in_realm_cache_key(realm: 'Realm') -> str:
    """Cache key for the bot dicts of `realm`, keyed by the realm's id."""
    return "bot_dicts_in_realm:" + str(realm.id)
def get_stream_cache_key(stream_name: str, realm_id: int) -> str:
    """Cache key for a stream looked up by realm id and name.

    The name is stripped and lowercased before being digested, so names that
    differ only in case or surrounding whitespace share one key.
    """
    normalized_name = stream_name.strip().lower()
    return "stream_by_realm_and_name:%s:%s" % (realm_id, make_safe_digest(normalized_name))
def delete_user_profile_caches(user_profiles: Iterable['UserProfile']) -> None:
    """Delete the per-user UserProfile cache entries for `user_profiles` in one batch.

    For each profile this covers the by-email, by-id, by-API-key and
    (email, realm) keys.
    """
    keys = []  # type: List[str]
    for profile in user_profiles:
        keys.extend([
            user_profile_by_email_cache_key(profile.email),
            user_profile_by_id_cache_key(profile.id),
            user_profile_by_api_key_cache_key(profile.api_key),
            user_profile_cache_key(profile.email, profile.realm),
        ])
    cache_delete_many(keys)
def delete_display_recipient_cache(user_profile: 'UserProfile') -> None:
    """Delete the display_recipient entry for every recipient in a Subscription of `user_profile`."""
    from zerver.models import Subscription  # We need to import here to avoid cyclic dependency.
    subscriptions = Subscription.objects.filter(user_profile=user_profile)
    recipient_ids = subscriptions.values_list('recipient_id', flat=True)
    keys = [display_recipient_cache_key(recipient_id) for recipient_id in recipient_ids]
    cache_delete_many(keys)
def flush_user_profile(sender: Any, **kwargs: Any) -> None:
    """Signal handler: invalidate caches affected by saving a UserProfile.

    Called by models.py whenever a user_profile object is saved. The
    per-user entries are always dropped. Realm-wide entries are dropped only
    when the save may have touched their fields. That is always assumed when
    `update_fields` is None, e.g. for adds/deletes.
    """
    user_profile = kwargs['instance']
    delete_user_profile_caches([user_profile])

    def changed(fields: List[str]) -> bool:
        """True if this save may have modified any of `fields`."""
        update_fields = kwargs.get('update_fields')
        if update_fields is None:
            # adds/deletes should invalidate the cache
            return True
        touched = set(update_fields)
        return any(field in touched for field in fields)

    # Invalidate our active_users_in_realm info dict if any user has changed
    # the fields in the dict or become (in)active
    if changed(realm_user_dict_fields):
        cache_delete(realm_user_dicts_cache_key(user_profile.realm_id))
    if changed(['is_active']):
        cache_delete(active_user_ids_cache_key(user_profile.realm_id))

    # NOTE(review): presumably these fields are what display_recipient dicts show.
    if changed(['email', 'full_name', 'short_name', 'id', 'is_mirror_dummy']):
        delete_display_recipient_cache(user_profile)

    # Invalidate our bots_in_realm info dict if any bot has
    # changed the fields in the dict or become (in)active
    if user_profile.is_bot and changed(bot_dict_fields):
        cache_delete(bot_dicts_in_realm_cache_key(user_profile.realm))

    # Invalidate realm-wide alert words cache if any user in the realm has changed
    # alert words
    if changed(['alert_words']):
        cache_delete(realm_alert_words_cache_key(user_profile.realm))
def flush_realm(sender: Any, **kwargs: Any) -> None:
    """Signal handler: invalidate caches affected by saving a Realm.

    Called by models.py whenever a Realm object is saved. Realm info is
    generally cached indirectly through user_profile objects, so the profile
    caches of the realm's active users are dropped first. Realm-wide entries
    are dropped too once the realm is deactivated.
    """
    realm = kwargs['instance']
    delete_user_profile_caches(realm.get_active_users())

    update_fields = kwargs.get('update_fields')
    # Deleting realm or updating message_visibility_limit
    # attribute should clear the first_visible_message_id cache.
    if update_fields is None or "message_visibility_limit" in update_fields:
        cache_delete(realm_first_visible_message_id_cache_key(realm))

    if not realm.deactivated:
        return
    realm_wide_keys = (
        realm_user_dicts_cache_key(realm.id),
        active_user_ids_cache_key(realm.id),
        bot_dicts_in_realm_cache_key(realm),
        realm_alert_words_cache_key(realm),
    )
    for key in realm_wide_keys:
        cache_delete(key)
def realm_alert_words_cache_key(realm: 'Realm') -> str:
    """Cache key for the alert words of `realm`, keyed by the realm's string_id."""
    return "realm_alert_words:" + str(realm.string_id)
def realm_first_visible_message_id_cache_key(realm: 'Realm') -> str:
    """Cache key for the first visible message id of `realm`, keyed by string_id."""
    return u"realm_first_visible_message_id:" + str(realm.string_id)
def flush_stream(sender: Any, **kwargs: Any) -> None:
    """Signal handler: refresh the stream cache whenever a Stream is saved.

    Called by models.py. The saved stream is written straight back into the
    by-name cache. The realm's bot dicts are invalidated whenever
    `update_fields` is None. They are also invalidated on a rename when some
    user has this stream as a default sending or events-register stream.
    """
    from zerver.models import UserProfile
    stream = kwargs['instance']
    cache_set_many({get_stream_cache_key(stream.name, stream.realm_id): (stream,)})

    update_fields = kwargs.get('update_fields')
    # Grouping is A or (B and C): with no update_fields we always
    # invalidate, without running the query.
    if update_fields is None or (
            'name' in update_fields and
            UserProfile.objects.filter(
                Q(default_sending_stream=stream) |
                Q(default_events_register_stream=stream)).exists()):
        cache_delete(bot_dicts_in_realm_cache_key(stream.realm))
def to_dict_cache_key_id(message_id: int) -> str:
    """Cache key for the dict representation of the message with id `message_id`."""
    template = 'message_dict:%d'
    return template % (message_id,)
def to_dict_cache_key(message: 'Message') -> str:
    """Same key as to_dict_cache_key_id, taking the Message object itself."""
    message_id = message.id
    return to_dict_cache_key_id(message_id)
def flush_message(sender: Any, **kwargs: Any) -> None:
    """Signal handler: drop the cached dict of the saved Message."""
    message_id = kwargs['instance'].id
    cache_delete(to_dict_cache_key_id(message_id))
def flush_submessage(sender: Any, **kwargs: Any) -> None:
    """Signal handler: drop the cached dict of a saved submessage's parent message.

    Submessages are not cached directly; they are part of their parent
    messages, so the parent's entry is the one to invalidate.
    """
    parent_message_id = kwargs['instance'].message_id
    cache_delete(to_dict_cache_key_id(parent_message_id))
# Type alias for a decorator: takes a function and returns a (wrapped) function.
DECORATOR = Callable[[Callable[..., Any]], Callable[..., Any]]
def _args_are_hashable(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> bool:
    """Return True if every positional and keyword argument value is hashable."""
    try:
        hash(args)
        hash(tuple(kwargs.values()))
    except TypeError:
        return False
    return True


def ignore_unhashable_lru_cache(maxsize: int=128, typed: bool=False) -> DECORATOR:
    """
    This is a wrapper over lru_cache function. It adds following features on
    top of lru_cache:
     * It will not cache result of functions with unhashable arguments; such
       calls go straight to the undecorated function.
     * It will clear cache whenever zerver.lib.cache.KEY_PREFIX changes.

    In development (settings.DEVELOPMENT and not settings.TEST_SUITE) the
    function is returned undecorated, so no caching happens at all.

    The returned wrapper exposes the underlying cache's cache_info() and
    cache_clear().
    """
    internal_decorator = lru_cache(maxsize=maxsize, typed=typed)

    def decorator(user_function: Callable[..., Any]) -> Callable[..., Any]:
        if settings.DEVELOPMENT and not settings.TEST_SUITE:  # nocoverage
            # In the development environment, we want every file
            # change to refresh the source files from disk.
            return user_function
        cache_enabled_user_function = internal_decorator(user_function)

        @wraps(user_function)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if not hasattr(cache_enabled_user_function, 'key_prefix'):
                cache_enabled_user_function.key_prefix = KEY_PREFIX

            if cache_enabled_user_function.key_prefix != KEY_PREFIX:
                # Clear cache when cache.KEY_PREFIX changes. This is used in
                # tests.
                cache_enabled_user_function.cache_clear()
                cache_enabled_user_function.key_prefix = KEY_PREFIX

            # Check hashability up front instead of catching TypeError from
            # the cached call. Otherwise a TypeError raised *inside*
            # user_function would be mistaken for an unhashable argument, and
            # the function would run a second time.
            if _args_are_hashable(args, kwargs):
                return cache_enabled_user_function(*args, **kwargs)

            # args or kwargs contains an element which is unhashable. In
            # this case we don't cache the result.
            return user_function(*args, **kwargs)

        setattr(wrapper, 'cache_info', cache_enabled_user_function.cache_info)
        setattr(wrapper, 'cache_clear', cache_enabled_user_function.cache_clear)
        return wrapper

    return decorator