redis_utils: Require key_format argument in get_dict_from_redis.

This commit is contained in:
Mateusz Mandera 2020-01-26 19:01:56 +01:00 committed by Tim Abbott
parent ad460e6ccb
commit 92c16996fc
4 changed files with 33 additions and 8 deletions

View File

@ -2,6 +2,7 @@ from django.conf import settings
from typing import Any, Dict, Optional
from zerver.lib.utils import generate_random_token
import re
import redis
import ujson
@ -9,7 +10,13 @@ import ujson
# so we want to stay limited to 1024 characters.
MAX_KEY_LENGTH = 1024
class ZulipRedisKeyTooLongError(Exception):
class ZulipRedisError(Exception):
pass
class ZulipRedisKeyTooLongError(ZulipRedisError):
pass
class ZulipRedisKeyOfWrongFormatError(ZulipRedisError):
pass
def get_redis_client() -> redis.StrictRedis:
@ -33,11 +40,25 @@ def put_dict_in_redis(redis_client: redis.StrictRedis, key_format: str,
return key
def get_dict_from_redis(redis_client: redis.StrictRedis, key: str) -> Optional[Dict[str, Any]]:
def get_dict_from_redis(redis_client: redis.StrictRedis, key_format: str, key: str
) -> Optional[Dict[str, Any]]:
# This function requires inputting the intended key_format to validate
# that the key fits it, as an additionally security measure. This protects
# against bugs where a caller requests a key based on user input and doesn't
# validate it - which could potentially allow users to poke around arbitrary redis keys.
if len(key) > MAX_KEY_LENGTH:
error_msg = "Requested key too long in get_dict_from_redis: %s"
raise ZulipRedisKeyTooLongError(error_msg % (key,))
validate_key_fits_format(key, key_format)
data = redis_client.get(key)
if data is None:
return None
return ujson.loads(data)
def validate_key_fits_format(key: str, key_format: str) -> None:
assert "{token}" in key_format
regex = key_format.format(token=r"[a-z0-9]+")
if not re.fullmatch(regex, key):
raise ZulipRedisKeyOfWrongFormatError("%s does not match format %s" % (key, key_format))

View File

@ -4,7 +4,7 @@ import mock
from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.redis_utils import get_redis_client, get_dict_from_redis, put_dict_in_redis, \
ZulipRedisKeyTooLongError, MAX_KEY_LENGTH
ZulipRedisKeyTooLongError, ZulipRedisKeyOfWrongFormatError, MAX_KEY_LENGTH
class RedisUtilsTest(ZulipTestCase):
key_format = "test_redis_utils_{token}"
@ -22,7 +22,7 @@ class RedisUtilsTest(ZulipTestCase):
}
key = put_dict_in_redis(self.redis_client, self.key_format, data,
expiration_seconds=self.expiration_seconds)
retrieved_data = get_dict_from_redis(self.redis_client, key)
retrieved_data = get_dict_from_redis(self.redis_client, self.key_format, key)
self.assertEqual(data, retrieved_data)
def test_put_data_key_length_check(self) -> None:
@ -35,7 +35,7 @@ class RedisUtilsTest(ZulipTestCase):
key = put_dict_in_redis(self.redis_client, self.key_format, data,
expiration_seconds=self.expiration_seconds,
token_length=max_valid_token_length)
retrieved_data = get_dict_from_redis(self.redis_client, key)
retrieved_data = get_dict_from_redis(self.redis_client, self.key_format, key)
self.assertEqual(data, retrieved_data)
# Trying to put data under an overly long key should get stopped before even
@ -49,4 +49,8 @@ class RedisUtilsTest(ZulipTestCase):
def test_get_data_key_length_check(self) -> None:
with self.assertRaises(ZulipRedisKeyTooLongError):
get_dict_from_redis(self.redis_client, 'A' * (MAX_KEY_LENGTH + 1))
get_dict_from_redis(self.redis_client, key_format='{token}', key='A' * (MAX_KEY_LENGTH + 1))
def test_get_data_key_format_validation(self) -> None:
with self.assertRaises(ZulipRedisKeyOfWrongFormatError):
get_dict_from_redis(self.redis_client, self.key_format, 'nonmatching_format_1111')

View File

@ -555,7 +555,7 @@ def store_login_data(data: Dict[str, Any]) -> str:
def get_login_data(token: str, should_delete: bool=True) -> Optional[Dict[str, Any]]:
key = LOGIN_KEY_FORMAT.format(token=token)
data = get_dict_from_redis(redis_client, key)
data = get_dict_from_redis(redis_client, LOGIN_KEY_FORMAT, key)
if data is not None and should_delete:
redis_client.delete(key)
return data

View File

@ -1337,7 +1337,7 @@ class SAMLAuthBackend(SocialAuthMixin, SAMLAuth):
data = None
if key.startswith('saml_token_'):
# Safety if statement, to not allow someone to poke around arbitrary redis keys here.
data = get_dict_from_redis(redis_client, key)
data = get_dict_from_redis(redis_client, "saml_token_{token}", key)
if data is None:
# TODO: We will need some sort of user-facing message
# about the authentication session having expired here.