mirror of https://github.com/zulip/zulip.git
redis_utils: Require key_format argument in get_dict_from_redis.
This commit is contained in:
parent
ad460e6ccb
commit
92c16996fc
|
@ -2,6 +2,7 @@ from django.conf import settings
|
|||
from typing import Any, Dict, Optional
|
||||
from zerver.lib.utils import generate_random_token
|
||||
|
||||
import re
|
||||
import redis
|
||||
import ujson
|
||||
|
||||
|
@ -9,7 +10,13 @@ import ujson
|
|||
# so we want to stay limited to 1024 characters.
|
||||
MAX_KEY_LENGTH = 1024
|
||||
|
||||
class ZulipRedisKeyTooLongError(Exception):
|
||||
class ZulipRedisError(Exception):
|
||||
pass
|
||||
|
||||
class ZulipRedisKeyTooLongError(ZulipRedisError):
|
||||
pass
|
||||
|
||||
class ZulipRedisKeyOfWrongFormatError(ZulipRedisError):
|
||||
pass
|
||||
|
||||
def get_redis_client() -> redis.StrictRedis:
|
||||
|
@ -33,11 +40,25 @@ def put_dict_in_redis(redis_client: redis.StrictRedis, key_format: str,
|
|||
|
||||
return key
|
||||
|
||||
def get_dict_from_redis(redis_client: redis.StrictRedis, key: str) -> Optional[Dict[str, Any]]:
|
||||
def get_dict_from_redis(redis_client: redis.StrictRedis, key_format: str, key: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
# This function requires inputting the intended key_format to validate
|
||||
# that the key fits it, as an additionally security measure. This protects
|
||||
# against bugs where a caller requests a key based on user input and doesn't
|
||||
# validate it - which could potentially allow users to poke around arbitrary redis keys.
|
||||
if len(key) > MAX_KEY_LENGTH:
|
||||
error_msg = "Requested key too long in get_dict_from_redis: %s"
|
||||
raise ZulipRedisKeyTooLongError(error_msg % (key,))
|
||||
validate_key_fits_format(key, key_format)
|
||||
|
||||
data = redis_client.get(key)
|
||||
if data is None:
|
||||
return None
|
||||
return ujson.loads(data)
|
||||
|
||||
def validate_key_fits_format(key: str, key_format: str) -> None:
|
||||
assert "{token}" in key_format
|
||||
regex = key_format.format(token=r"[a-z0-9]+")
|
||||
|
||||
if not re.fullmatch(regex, key):
|
||||
raise ZulipRedisKeyOfWrongFormatError("%s does not match format %s" % (key, key_format))
|
||||
|
|
|
@ -4,7 +4,7 @@ import mock
|
|||
|
||||
from zerver.lib.test_classes import ZulipTestCase
|
||||
from zerver.lib.redis_utils import get_redis_client, get_dict_from_redis, put_dict_in_redis, \
|
||||
ZulipRedisKeyTooLongError, MAX_KEY_LENGTH
|
||||
ZulipRedisKeyTooLongError, ZulipRedisKeyOfWrongFormatError, MAX_KEY_LENGTH
|
||||
|
||||
class RedisUtilsTest(ZulipTestCase):
|
||||
key_format = "test_redis_utils_{token}"
|
||||
|
@ -22,7 +22,7 @@ class RedisUtilsTest(ZulipTestCase):
|
|||
}
|
||||
key = put_dict_in_redis(self.redis_client, self.key_format, data,
|
||||
expiration_seconds=self.expiration_seconds)
|
||||
retrieved_data = get_dict_from_redis(self.redis_client, key)
|
||||
retrieved_data = get_dict_from_redis(self.redis_client, self.key_format, key)
|
||||
self.assertEqual(data, retrieved_data)
|
||||
|
||||
def test_put_data_key_length_check(self) -> None:
|
||||
|
@ -35,7 +35,7 @@ class RedisUtilsTest(ZulipTestCase):
|
|||
key = put_dict_in_redis(self.redis_client, self.key_format, data,
|
||||
expiration_seconds=self.expiration_seconds,
|
||||
token_length=max_valid_token_length)
|
||||
retrieved_data = get_dict_from_redis(self.redis_client, key)
|
||||
retrieved_data = get_dict_from_redis(self.redis_client, self.key_format, key)
|
||||
self.assertEqual(data, retrieved_data)
|
||||
|
||||
# Trying to put data under an overly long key should get stopped before even
|
||||
|
@ -49,4 +49,8 @@ class RedisUtilsTest(ZulipTestCase):
|
|||
|
||||
def test_get_data_key_length_check(self) -> None:
|
||||
with self.assertRaises(ZulipRedisKeyTooLongError):
|
||||
get_dict_from_redis(self.redis_client, 'A' * (MAX_KEY_LENGTH + 1))
|
||||
get_dict_from_redis(self.redis_client, key_format='{token}', key='A' * (MAX_KEY_LENGTH + 1))
|
||||
|
||||
def test_get_data_key_format_validation(self) -> None:
|
||||
with self.assertRaises(ZulipRedisKeyOfWrongFormatError):
|
||||
get_dict_from_redis(self.redis_client, self.key_format, 'nonmatching_format_1111')
|
||||
|
|
|
@ -555,7 +555,7 @@ def store_login_data(data: Dict[str, Any]) -> str:
|
|||
|
||||
def get_login_data(token: str, should_delete: bool=True) -> Optional[Dict[str, Any]]:
|
||||
key = LOGIN_KEY_FORMAT.format(token=token)
|
||||
data = get_dict_from_redis(redis_client, key)
|
||||
data = get_dict_from_redis(redis_client, LOGIN_KEY_FORMAT, key)
|
||||
if data is not None and should_delete:
|
||||
redis_client.delete(key)
|
||||
return data
|
||||
|
|
|
@ -1337,7 +1337,7 @@ class SAMLAuthBackend(SocialAuthMixin, SAMLAuth):
|
|||
data = None
|
||||
if key.startswith('saml_token_'):
|
||||
# Safety if statement, to not allow someone to poke around arbitrary redis keys here.
|
||||
data = get_dict_from_redis(redis_client, key)
|
||||
data = get_dict_from_redis(redis_client, "saml_token_{token}", key)
|
||||
if data is None:
|
||||
# TODO: We will need some sort of user-facing message
|
||||
# about the authentication session having expired here.
|
||||
|
|
Loading…
Reference in New Issue