mirror of https://github.com/zulip/zulip.git
cache: Validate keys before passing them to memcached.
Fixes #13504. This commit is purely an improvement in error handling. We used to not do any validation on keys before passing them to memcached, which meant for invalid keys, memcached's own key validation would throw an exception. Unfortunately, the resulting error messages are super hard to read; the traceback structure doesn't even show where the call into memcached happened. In this commit we add validation to all the basic cache_* functions, and appropriate handling in their callers. We also add a lot of tests for the new behavior, which has the nice effect of giving us decent coverage of all these core caching functions which previously had been primarily tested manually.
This commit is contained in:
parent
5bb84a2255
commit
4f2897fafc
|
@ -15,8 +15,11 @@ from typing import Any, Callable, Dict, Iterable, List, \
|
|||
from zerver.lib.utils import statsd, statsd_key, make_safe_digest
|
||||
import time
|
||||
import base64
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
|
@ -25,8 +28,12 @@ if TYPE_CHECKING:
|
|||
# they cannot be imported at runtime due to cyclic dependency.
|
||||
from zerver.models import UserProfile, Realm, Message
|
||||
|
||||
MEMCACHED_MAX_KEY_LENGTH = 250
|
||||
|
||||
ReturnT = TypeVar('ReturnT') # Useful for matching return types via Callable[..., ReturnT]
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
class NotFoundInCache(Exception):
|
||||
pass
|
||||
|
||||
|
@ -100,7 +107,7 @@ def bounce_key_prefix_for_testing(test_name: str) -> None:
|
|||
global KEY_PREFIX
|
||||
KEY_PREFIX = test_name + ':' + str(os.getpid()) + ':'
|
||||
# We are taking the hash of the KEY_PREFIX to decrease the size of the key.
|
||||
# Memcached keys should have a length of less than 256.
|
||||
# Memcached keys should have a length of less than 250.
|
||||
KEY_PREFIX = hashlib.sha1(KEY_PREFIX.encode('utf-8')).hexdigest() + ":"
|
||||
|
||||
def get_cache_backend(cache_name: Optional[str]) -> BaseCache:
|
||||
|
@ -121,7 +128,13 @@ def get_cache_with_key(
|
|||
@wraps(func)
|
||||
def func_with_caching(*args: Any, **kwargs: Any) -> Callable[..., ReturnT]:
|
||||
key = keyfunc(*args, **kwargs)
|
||||
val = cache_get(key, cache_name=cache_name)
|
||||
try:
|
||||
val = cache_get(key, cache_name=cache_name)
|
||||
except InvalidCacheKeyException:
|
||||
stack_trace = traceback.format_exc()
|
||||
log_invalid_cache_keys(stack_trace, [key])
|
||||
val = None
|
||||
|
||||
if val is not None:
|
||||
return val[0]
|
||||
raise NotFoundInCache()
|
||||
|
@ -146,7 +159,12 @@ def cache_with_key(
|
|||
def func_with_caching(*args: Any, **kwargs: Any) -> ReturnT:
|
||||
key = keyfunc(*args, **kwargs)
|
||||
|
||||
val = cache_get(key, cache_name=cache_name)
|
||||
try:
|
||||
val = cache_get(key, cache_name=cache_name)
|
||||
except InvalidCacheKeyException:
|
||||
stack_trace = traceback.format_exc()
|
||||
log_invalid_cache_keys(stack_trace, [key])
|
||||
return func(*args, **kwargs)
|
||||
|
||||
extra = ""
|
||||
if cache_name == 'database':
|
||||
|
@ -175,47 +193,134 @@ def cache_with_key(
|
|||
|
||||
return decorator
|
||||
|
||||
class InvalidCacheKeyException(Exception):
|
||||
pass
|
||||
|
||||
def log_invalid_cache_keys(stack_trace: str, key: List[str]) -> None:
|
||||
logger.warning(
|
||||
"Invalid cache key used: {}\nStack trace: {}\n".format(key, stack_trace)
|
||||
)
|
||||
|
||||
def validate_cache_key(key: str) -> None:
|
||||
if not key.startswith(KEY_PREFIX):
|
||||
key = KEY_PREFIX + key
|
||||
|
||||
# Theoretically memcached can handle non-ascii characters
|
||||
# and only "control" characters are strictly disallowed, see:
|
||||
# https://github.com/memcached/memcached/blob/master/doc/protocol.txt
|
||||
# However, limiting the characters we allow in keys simiplifies things,
|
||||
# and anyway we use make_safe_digest when forming some keys to ensure
|
||||
# the resulting keys fit the regex below.
|
||||
# The regex checks "all characters between ! and ~ in the ascii table",
|
||||
# which happens to be the set of all "nice" ascii characters.
|
||||
if not bool(re.fullmatch(r"([!-~])+", key)):
|
||||
raise InvalidCacheKeyException("Invalid characters in the cache key: " + key)
|
||||
if len(key) > MEMCACHED_MAX_KEY_LENGTH:
|
||||
raise InvalidCacheKeyException("Cache key too long: {} Length: {}".format(key, len(key)))
|
||||
|
||||
def cache_set(key: str, val: Any, cache_name: Optional[str]=None, timeout: Optional[int]=None) -> None:
|
||||
final_key = KEY_PREFIX + key
|
||||
validate_cache_key(final_key)
|
||||
|
||||
remote_cache_stats_start()
|
||||
cache_backend = get_cache_backend(cache_name)
|
||||
cache_backend.set(KEY_PREFIX + key, (val,), timeout=timeout)
|
||||
cache_backend.set(final_key, (val,), timeout=timeout)
|
||||
remote_cache_stats_finish()
|
||||
|
||||
def cache_get(key: str, cache_name: Optional[str]=None) -> Any:
|
||||
final_key = KEY_PREFIX + key
|
||||
validate_cache_key(final_key)
|
||||
|
||||
remote_cache_stats_start()
|
||||
cache_backend = get_cache_backend(cache_name)
|
||||
ret = cache_backend.get(KEY_PREFIX + key)
|
||||
ret = cache_backend.get(final_key)
|
||||
remote_cache_stats_finish()
|
||||
return ret
|
||||
|
||||
def cache_get_many(keys: List[str], cache_name: Optional[str]=None) -> Dict[str, Any]:
|
||||
keys = [KEY_PREFIX + key for key in keys]
|
||||
for key in keys:
|
||||
validate_cache_key(key)
|
||||
remote_cache_stats_start()
|
||||
ret = get_cache_backend(cache_name).get_many(keys)
|
||||
remote_cache_stats_finish()
|
||||
return dict([(key[len(KEY_PREFIX):], value) for key, value in ret.items()])
|
||||
|
||||
def safe_cache_get_many(keys: List[str], cache_name: Optional[str]=None) -> Dict[str, Any]:
|
||||
"""Variant of cache_get_many that drops any keys that fail
|
||||
validation, rather than throwing an exception visible to the
|
||||
caller."""
|
||||
try:
|
||||
# Almost always the keys will all be correct, so we just try
|
||||
# to do normal cache_get_many to avoid the overhead of
|
||||
# validating all the keys here.
|
||||
return cache_get_many(keys, cache_name)
|
||||
except InvalidCacheKeyException:
|
||||
stack_trace = traceback.format_exc()
|
||||
good_keys, bad_keys = filter_good_and_bad_keys(keys)
|
||||
|
||||
log_invalid_cache_keys(stack_trace, bad_keys)
|
||||
return cache_get_many(good_keys, cache_name)
|
||||
|
||||
def cache_set_many(items: Dict[str, Any], cache_name: Optional[str]=None,
|
||||
timeout: Optional[int]=None) -> None:
|
||||
new_items = {}
|
||||
for key in items:
|
||||
new_items[KEY_PREFIX + key] = items[key]
|
||||
new_key = KEY_PREFIX + key
|
||||
validate_cache_key(new_key)
|
||||
new_items[new_key] = items[key]
|
||||
items = new_items
|
||||
remote_cache_stats_start()
|
||||
get_cache_backend(cache_name).set_many(items, timeout=timeout)
|
||||
remote_cache_stats_finish()
|
||||
|
||||
def safe_cache_set_many(items: Dict[str, Any], cache_name: Optional[str]=None,
|
||||
timeout: Optional[int]=None) -> None:
|
||||
"""Variant of cache_set_many that drops saving any keys that fail
|
||||
validation, rather than throwing an exception visible to the
|
||||
caller."""
|
||||
try:
|
||||
# Almost always the keys will all be correct, so we just try
|
||||
# to do normal cache_set_many to avoid the overhead of
|
||||
# validating all the keys here.
|
||||
return cache_set_many(items, cache_name, timeout)
|
||||
except InvalidCacheKeyException:
|
||||
stack_trace = traceback.format_exc()
|
||||
|
||||
good_keys, bad_keys = filter_good_and_bad_keys(list(items.keys()))
|
||||
log_invalid_cache_keys(stack_trace, bad_keys)
|
||||
|
||||
good_items = dict((key, items[key]) for key in good_keys)
|
||||
return cache_set_many(good_items, cache_name, timeout)
|
||||
|
||||
def cache_delete(key: str, cache_name: Optional[str]=None) -> None:
|
||||
final_key = KEY_PREFIX + key
|
||||
validate_cache_key(final_key)
|
||||
|
||||
remote_cache_stats_start()
|
||||
get_cache_backend(cache_name).delete(KEY_PREFIX + key)
|
||||
get_cache_backend(cache_name).delete(final_key)
|
||||
remote_cache_stats_finish()
|
||||
|
||||
def cache_delete_many(items: Iterable[str], cache_name: Optional[str]=None) -> None:
|
||||
keys = [KEY_PREFIX + item for item in items]
|
||||
for key in keys:
|
||||
validate_cache_key(key)
|
||||
remote_cache_stats_start()
|
||||
get_cache_backend(cache_name).delete_many(
|
||||
KEY_PREFIX + item for item in items)
|
||||
get_cache_backend(cache_name).delete_many(keys)
|
||||
remote_cache_stats_finish()
|
||||
|
||||
def filter_good_and_bad_keys(keys: List[str]) -> Tuple[List[str], List[str]]:
|
||||
good_keys = []
|
||||
bad_keys = []
|
||||
for key in keys:
|
||||
try:
|
||||
validate_cache_key(key)
|
||||
good_keys.append(key)
|
||||
except InvalidCacheKeyException:
|
||||
bad_keys.append(key)
|
||||
|
||||
return good_keys, bad_keys
|
||||
|
||||
# Generic_bulk_cached fetch and its helpers. We start with declaring
|
||||
# a few type variables that help define its interface.
|
||||
|
||||
|
@ -275,8 +380,10 @@ def generic_bulk_cached_fetch(
|
|||
cache_keys = {} # type: Dict[ObjKT, str]
|
||||
for object_id in object_ids:
|
||||
cache_keys[object_id] = cache_key_function(object_id)
|
||||
cached_objects_compressed = cache_get_many([cache_keys[object_id]
|
||||
for object_id in object_ids]) # type: Dict[str, Tuple[CompressedItemT]]
|
||||
|
||||
cached_objects_compressed = safe_cache_get_many([cache_keys[object_id]
|
||||
for object_id in object_ids]) # type: Dict[str, Tuple[CompressedItemT]]
|
||||
|
||||
cached_objects = {} # type: Dict[str, CacheItemT]
|
||||
for (key, val) in cached_objects_compressed.items():
|
||||
cached_objects[key] = extractor(cached_objects_compressed[key][0])
|
||||
|
@ -296,7 +403,7 @@ def generic_bulk_cached_fetch(
|
|||
items_for_remote_cache[key] = (setter(item),)
|
||||
cached_objects[key] = item
|
||||
if len(items_for_remote_cache) > 0:
|
||||
cache_set_many(items_for_remote_cache)
|
||||
safe_cache_set_many(items_for_remote_cache)
|
||||
return dict((object_id, cached_objects[cache_keys[object_id]]) for object_id in object_ids
|
||||
if cache_keys[object_id] in cached_objects)
|
||||
|
||||
|
|
|
@ -1,11 +1,15 @@
|
|||
from django.conf import settings
|
||||
|
||||
from mock import Mock, patch
|
||||
from typing import List, Dict
|
||||
from typing import Any, List, Dict
|
||||
|
||||
from zerver.apps import flush_cache
|
||||
from zerver.lib.cache import generic_bulk_cached_fetch, user_profile_by_email_cache_key
|
||||
from zerver.lib.cache import generic_bulk_cached_fetch, user_profile_by_email_cache_key, cache_with_key, \
|
||||
validate_cache_key, InvalidCacheKeyException, MEMCACHED_MAX_KEY_LENGTH, get_cache_with_key, \
|
||||
NotFoundInCache, cache_set, cache_get, cache_delete, cache_delete_many, cache_get_many, cache_set_many, \
|
||||
safe_cache_get_many, safe_cache_set_many
|
||||
from zerver.lib.test_classes import ZulipTestCase
|
||||
from zerver.lib.test_helpers import queries_captured
|
||||
from zerver.models import get_system_bot, get_user_profile_by_email, UserProfile
|
||||
|
||||
class AppsTest(ZulipTestCase):
|
||||
|
@ -17,6 +21,194 @@ class AppsTest(ZulipTestCase):
|
|||
mock.assert_called_once()
|
||||
mock_logging.assert_called_once()
|
||||
|
||||
class CacheKeyValidationTest(ZulipTestCase):
|
||||
def test_validate_cache_key(self) -> None:
|
||||
validate_cache_key('nice_Ascii:string!~')
|
||||
with self.assertRaises(InvalidCacheKeyException):
|
||||
validate_cache_key('utf8_character:ą')
|
||||
with self.assertRaises(InvalidCacheKeyException):
|
||||
validate_cache_key('new_line_character:\n')
|
||||
with self.assertRaises(InvalidCacheKeyException):
|
||||
validate_cache_key('control_character:\r')
|
||||
with self.assertRaises(InvalidCacheKeyException):
|
||||
validate_cache_key('whitespace_character: ')
|
||||
with self.assertRaises(InvalidCacheKeyException):
|
||||
validate_cache_key('too_long:' + 'X'*MEMCACHED_MAX_KEY_LENGTH)
|
||||
|
||||
with self.assertRaises(InvalidCacheKeyException):
|
||||
# validate_cache_key does validation on a key with the
|
||||
# KEY_PREFIX appended to the start, so even though we're
|
||||
# passing something "short enough" here, it becomes too
|
||||
# long after appending KEY_PREFIX.
|
||||
validate_cache_key('X' * (MEMCACHED_MAX_KEY_LENGTH - 2))
|
||||
|
||||
def test_cache_functions_raise_exception(self) -> None:
|
||||
invalid_key = 'invalid_character:\n'
|
||||
good_key = "good_key"
|
||||
with self.assertRaises(InvalidCacheKeyException):
|
||||
cache_get(invalid_key)
|
||||
with self.assertRaises(InvalidCacheKeyException):
|
||||
cache_set(invalid_key, 0)
|
||||
with self.assertRaises(InvalidCacheKeyException):
|
||||
cache_delete(invalid_key)
|
||||
|
||||
with self.assertRaises(InvalidCacheKeyException):
|
||||
cache_get_many([good_key, invalid_key])
|
||||
with self.assertRaises(InvalidCacheKeyException):
|
||||
cache_set_many({good_key: 0, invalid_key: 1})
|
||||
with self.assertRaises(InvalidCacheKeyException):
|
||||
cache_delete_many([good_key, invalid_key])
|
||||
|
||||
class CacheWithKeyDecoratorTest(ZulipTestCase):
|
||||
def test_cache_with_key_invalid_character(self) -> None:
|
||||
def invalid_characters_cache_key_function(user_id: int) -> str:
|
||||
return 'CacheWithKeyDecoratorTest:invalid_character:ą:{}'.format(user_id)
|
||||
|
||||
@cache_with_key(invalid_characters_cache_key_function, timeout=1000)
|
||||
def get_user_function_with_bad_cache_keys(user_id: int) -> UserProfile:
|
||||
return UserProfile.objects.get(id=user_id)
|
||||
|
||||
hamlet = self.example_user('hamlet')
|
||||
with patch('zerver.lib.cache.cache_set') as mock_set, \
|
||||
patch('zerver.lib.cache.logger.warning') as mock_warn:
|
||||
with queries_captured() as queries:
|
||||
result = get_user_function_with_bad_cache_keys(hamlet.id)
|
||||
|
||||
self.assertEqual(result, hamlet)
|
||||
self.assert_length(queries, 1)
|
||||
mock_set.assert_not_called()
|
||||
mock_warn.assert_called_once()
|
||||
|
||||
def test_cache_with_key_key_too_long(self) -> None:
|
||||
def too_long_cache_key_function(user_id: int) -> str:
|
||||
return 'CacheWithKeyDecoratorTest:very_long_key:{}:{}'.format('a'*250, user_id)
|
||||
|
||||
@cache_with_key(too_long_cache_key_function, timeout=1000)
|
||||
def get_user_function_with_bad_cache_keys(user_id: int) -> UserProfile:
|
||||
return UserProfile.objects.get(id=user_id)
|
||||
|
||||
hamlet = self.example_user('hamlet')
|
||||
|
||||
with patch('zerver.lib.cache.cache_set') as mock_set, \
|
||||
patch('zerver.lib.cache.logger.warning') as mock_warn:
|
||||
with queries_captured() as queries:
|
||||
result = get_user_function_with_bad_cache_keys(hamlet.id)
|
||||
|
||||
self.assertEqual(result, hamlet)
|
||||
self.assert_length(queries, 1)
|
||||
mock_set.assert_not_called()
|
||||
mock_warn.assert_called_once()
|
||||
|
||||
def test_cache_with_key_good_key(self) -> None:
|
||||
def good_cache_key_function(user_id: int) -> str:
|
||||
return 'CacheWithKeyDecoratorTest:good_cache_key:{}'.format(user_id)
|
||||
|
||||
@cache_with_key(good_cache_key_function, timeout=1000)
|
||||
def get_user_function_with_good_cache_keys(user_id: int) -> UserProfile:
|
||||
return UserProfile.objects.get(id=user_id)
|
||||
|
||||
hamlet = self.example_user('hamlet')
|
||||
|
||||
with queries_captured() as queries:
|
||||
result = get_user_function_with_good_cache_keys(hamlet.id)
|
||||
|
||||
self.assertEqual(result, hamlet)
|
||||
self.assert_length(queries, 1)
|
||||
|
||||
# The previous function call should have cached the result correctly, so now
|
||||
# no database queries should happen:
|
||||
with queries_captured() as queries_two:
|
||||
result_two = get_user_function_with_good_cache_keys(hamlet.id)
|
||||
|
||||
self.assertEqual(result_two, hamlet)
|
||||
self.assert_length(queries_two, 0)
|
||||
|
||||
class GetCacheWithKeyDecoratorTest(ZulipTestCase):
|
||||
def test_get_cache_with_good_key(self) -> None:
|
||||
# Test with a good cache key function, but a get_user function
|
||||
# that always returns None just to make it convenient to tell
|
||||
# whether the cache was used (whatever we put in the cache) or
|
||||
# we got the result from calling the function (None)
|
||||
|
||||
def good_cache_key_function(user_id: int) -> str:
|
||||
return 'CacheWithKeyDecoratorTest:good_cache_key:{}'.format(user_id)
|
||||
|
||||
@get_cache_with_key(good_cache_key_function)
|
||||
def get_user_function_with_good_cache_keys(user_id: int) -> Any: # nocoverage
|
||||
return
|
||||
|
||||
hamlet = self.example_user('hamlet')
|
||||
with patch('zerver.lib.cache.logger.warning') as mock_warn:
|
||||
with self.assertRaises(NotFoundInCache):
|
||||
get_user_function_with_good_cache_keys(hamlet.id)
|
||||
mock_warn.assert_not_called()
|
||||
|
||||
cache_set(good_cache_key_function(hamlet.id), hamlet)
|
||||
result = get_user_function_with_good_cache_keys(hamlet.id)
|
||||
self.assertEqual(result, hamlet)
|
||||
|
||||
def test_get_cache_with_bad_key(self) -> None:
|
||||
def bad_cache_key_function(user_id: int) -> str:
|
||||
return 'CacheWithKeyDecoratorTest:invalid_character:ą:{}'.format(user_id)
|
||||
|
||||
@get_cache_with_key(bad_cache_key_function)
|
||||
def get_user_function_with_bad_cache_keys(user_id: int) -> Any: # nocoverage
|
||||
return
|
||||
|
||||
hamlet = self.example_user('hamlet')
|
||||
with patch('zerver.lib.cache.logger.warning') as mock_warn:
|
||||
with self.assertRaises(NotFoundInCache):
|
||||
get_user_function_with_bad_cache_keys(hamlet.id)
|
||||
mock_warn.assert_called_once()
|
||||
|
||||
class SafeCacheFunctionsTest(ZulipTestCase):
|
||||
def test_safe_cache_functions_with_all_good_keys(self) -> None:
|
||||
items = {"SafeFunctionsTest:key1": 1, "SafeFunctionsTest:key2": 2, "SafeFunctionsTest:key3": 3}
|
||||
safe_cache_set_many(items)
|
||||
|
||||
result = safe_cache_get_many(list(items.keys()))
|
||||
for key, value in result.items():
|
||||
self.assertEqual(value, items[key])
|
||||
|
||||
def test_safe_cache_functions_with_all_bad_keys(self) -> None:
|
||||
items = {"SafeFunctionsTest:\nbadkey1": 1, "SafeFunctionsTest:\nbadkey2": 2}
|
||||
with patch('zerver.lib.cache.logger.warning') as mock_warn:
|
||||
safe_cache_set_many(items)
|
||||
mock_warn.assert_called_once()
|
||||
warning_string = mock_warn.call_args[0][0]
|
||||
self.assertIn("badkey1", warning_string)
|
||||
self.assertIn("badkey2", warning_string)
|
||||
|
||||
with patch('zerver.lib.cache.logger.warning') as mock_warn:
|
||||
result = safe_cache_get_many(list(items.keys()))
|
||||
mock_warn.assert_called_once()
|
||||
warning_string = mock_warn.call_args[0][0]
|
||||
self.assertIn("badkey1", warning_string)
|
||||
self.assertIn("badkey2", warning_string)
|
||||
|
||||
self.assertEqual(result, {})
|
||||
|
||||
def test_safe_cache_functions_with_good_and_bad_keys(self) -> None:
|
||||
bad_items = {"SafeFunctionsTest:\nbadkey1": 1, "SafeFunctionsTest:\nbadkey2": 2}
|
||||
good_items = {"SafeFunctionsTest:goodkey1": 3, "SafeFunctionsTest:goodkey2": 4}
|
||||
items = {**good_items, **bad_items}
|
||||
|
||||
with patch('zerver.lib.cache.logger.warning') as mock_warn:
|
||||
safe_cache_set_many(items)
|
||||
mock_warn.assert_called_once()
|
||||
warning_string = mock_warn.call_args[0][0]
|
||||
self.assertIn("badkey1", warning_string)
|
||||
self.assertIn("badkey2", warning_string)
|
||||
|
||||
with patch('zerver.lib.cache.logger.warning') as mock_warn:
|
||||
result = safe_cache_get_many(list(items.keys()))
|
||||
mock_warn.assert_called_once()
|
||||
warning_string = mock_warn.call_args[0][0]
|
||||
self.assertIn("badkey1", warning_string)
|
||||
self.assertIn("badkey2", warning_string)
|
||||
|
||||
self.assertEqual(result, good_items)
|
||||
|
||||
class BotCacheKeyTest(ZulipTestCase):
|
||||
def test_bot_profile_key_deleted_on_save(self) -> None:
|
||||
# Get the profile cached on both cache keys:
|
||||
|
|
Loading…
Reference in New Issue