redis: Extract put_dict_in_redis and get_dict_from_redis helpers.

This commit is contained in:
Mateusz Mandera 2020-01-20 14:17:53 +01:00 committed by Tim Abbott
parent eafdefc5c9
commit af2c4a9735
2 changed files with 29 additions and 16 deletions

View File

@ -1,7 +1,28 @@
from django.conf import settings
from typing import Any, Dict, Optional
from zerver.lib.utils import generate_random_token
import redis
import ujson
def get_redis_client() -> redis.StrictRedis:
return redis.StrictRedis(host=settings.REDIS_HOST, port=settings.REDIS_PORT,
password=settings.REDIS_PASSWORD, db=0)
def put_dict_in_redis(redis_client: redis.StrictRedis, key_format: str,
data_to_store: Dict[str, Any],
expiration_seconds: int) -> str:
with redis_client.pipeline() as pipeline:
token = generate_random_token(64)
key = key_format.format(token=token)
pipeline.set(key, ujson.dumps(data_to_store))
pipeline.expire(key, expiration_seconds)
pipeline.execute()
return key
def get_dict_from_redis(redis_client: redis.StrictRedis, key: str) -> Optional[Dict[str, Any]]:
data = redis_client.get(key)
if data is None:
return None
return ujson.loads(data)

View File

@ -15,7 +15,6 @@
import copy
import logging
import magic
import ujson
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
from typing_extensions import TypedDict
@ -50,8 +49,7 @@ from zerver.lib.avatar_hash import user_avatar_content_hash
from zerver.lib.dev_ldap_directory import init_fakeldap
from zerver.lib.request import JsonableError
from zerver.lib.users import check_full_name, validate_user_custom_profile_field
from zerver.lib.utils import generate_random_token
from zerver.lib.redis_utils import get_redis_client
from zerver.lib.redis_utils import get_redis_client, get_dict_from_redis, put_dict_in_redis
from zerver.models import CustomProfileField, DisposableEmailError, DomainNotAllowedForRealmError, \
EmailContainsPlusError, PreregistrationUser, UserProfile, Realm, custom_profile_fields_for_realm, \
email_allowed_for_realm, get_user_profile_by_id, remote_user_to_email, \
@ -1318,28 +1316,22 @@ class SAMLAuthBackend(SocialAuthMixin, SAMLAuth):
@classmethod
def put_data_in_redis(cls, data_to_relay: Dict[str, Any]) -> str:
with redis_client.pipeline() as pipeline:
token = generate_random_token(64)
key = "saml_token_{}".format(token)
pipeline.set(key, ujson.dumps(data_to_relay))
pipeline.expire(key, cls.REDIS_EXPIRATION_SECONDS)
pipeline.execute()
return key
return put_dict_in_redis(redis_client, "saml_token_{token}",
data_to_store=data_to_relay,
expiration_seconds=cls.REDIS_EXPIRATION_SECONDS)
@classmethod
def get_data_from_redis(cls, key: str) -> Optional[Dict[str, Any]]:
redis_data = None
data = None
if key.startswith('saml_token_'):
# Safety if statement, to not allow someone to poke around arbitrary redis keys here.
redis_data = redis_client.get(key)
if redis_data is None:
data = get_dict_from_redis(redis_client, key)
if data is None:
# TODO: We will need some sort of user-facing message
# about the authentication session having expired here.
logging.info("SAML authentication failed: bad RelayState token.")
return None
return ujson.loads(redis_data)
return data
def auth_complete(self, *args: Any, **kwargs: Any) -> Optional[HttpResponse]:
"""