diff --git a/zerver/lib/redis_utils.py b/zerver/lib/redis_utils.py index 585297b222..5bb569b008 100644 --- a/zerver/lib/redis_utils.py +++ b/zerver/lib/redis_utils.py @@ -5,6 +5,13 @@ from zerver.lib.utils import generate_random_token import redis import ujson +# Redis accepts keys up to 512MB in size, but there's no reason for us to use such size, +# so we want to stay limited to 1024 characters. +MAX_KEY_LENGTH = 1024 + +class ZulipRedisKeyTooLongError(Exception): + pass + def get_redis_client() -> redis.StrictRedis: return redis.StrictRedis(host=settings.REDIS_HOST, port=settings.REDIS_PORT, password=settings.REDIS_PASSWORD, db=0) @@ -13,9 +20,13 @@ def put_dict_in_redis(redis_client: redis.StrictRedis, key_format: str, data_to_store: Dict[str, Any], expiration_seconds: int, token_length: int=64) -> str: + key_length = len(key_format) - len('{token}') + token_length + if key_length > MAX_KEY_LENGTH: + error_msg = "Requested key too long in put_dict_in_redis. Key format: %s, token length: %s" + raise ZulipRedisKeyTooLongError(error_msg % (key_format, token_length)) + token = generate_random_token(token_length) + key = key_format.format(token=token) with redis_client.pipeline() as pipeline: - token = generate_random_token(token_length) - key = key_format.format(token=token) pipeline.set(key, ujson.dumps(data_to_store)) pipeline.expire(key, expiration_seconds) pipeline.execute() @@ -23,6 +34,9 @@ def put_dict_in_redis(redis_client: redis.StrictRedis, key_format: str, return key def get_dict_from_redis(redis_client: redis.StrictRedis, key: str) -> Optional[Dict[str, Any]]: + if len(key) > MAX_KEY_LENGTH: + error_msg = "Requested key too long in get_dict_from_redis: %s" + raise ZulipRedisKeyTooLongError(error_msg % (key,)) data = redis_client.get(key) if data is None: return None diff --git a/zerver/tests/test_redis_utils.py b/zerver/tests/test_redis_utils.py new file mode 100644 index 0000000000..7e2a186c9c --- /dev/null +++ b/zerver/tests/test_redis_utils.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- + +import mock + +from zerver.lib.test_classes import ZulipTestCase +from zerver.lib.redis_utils import get_redis_client, get_dict_from_redis, put_dict_in_redis, \ + ZulipRedisKeyTooLongError, MAX_KEY_LENGTH + +class RedisUtilsTest(ZulipTestCase): + key_format = "test_redis_utils_{token}" + expiration_seconds = 60 + + @classmethod + def setUpClass(cls) -> None: + cls.redis_client = get_redis_client() + return super().setUpClass() + + def test_put_and_get_data(self) -> None: + data = { + "a": 1, + "b": "some value" + } + key = put_dict_in_redis(self.redis_client, self.key_format, data, + expiration_seconds=self.expiration_seconds) + retrieved_data = get_dict_from_redis(self.redis_client, key) + self.assertEqual(data, retrieved_data) + + def test_put_data_key_length_check(self) -> None: + data = { + "a": 1, + "b": "some value" + } + + max_valid_token_length = MAX_KEY_LENGTH - (len(self.key_format) - len('{token}')) + key = put_dict_in_redis(self.redis_client, self.key_format, data, + expiration_seconds=self.expiration_seconds, + token_length=max_valid_token_length) + retrieved_data = get_dict_from_redis(self.redis_client, key) + self.assertEqual(data, retrieved_data) + + # Trying to put data under an overly long key should get stopped before even + # generating the random token. + with mock.patch("zerver.lib.redis_utils.generate_random_token") as mock_generate: + with self.assertRaises(ZulipRedisKeyTooLongError): + put_dict_in_redis(self.redis_client, self.key_format, data, + expiration_seconds=self.expiration_seconds, + token_length=max_valid_token_length + 1) + mock_generate.assert_not_called() + + def test_get_data_key_length_check(self) -> None: + with self.assertRaises(ZulipRedisKeyTooLongError): + get_dict_from_redis(self.redis_client, 'A' * (MAX_KEY_LENGTH + 1))