mirror of https://github.com/zulip/zulip.git
redis: Extract put_dict_in_redis and get_dict_from_redis helpers.
This commit is contained in:
parent
eafdefc5c9
commit
af2c4a9735
|
@ -1,7 +1,28 @@
|
|||
from django.conf import settings
|
||||
from typing import Any, Dict, Optional
|
||||
from zerver.lib.utils import generate_random_token
|
||||
|
||||
import redis
|
||||
import ujson
|
||||
|
||||
def get_redis_client() -> redis.StrictRedis:
|
||||
return redis.StrictRedis(host=settings.REDIS_HOST, port=settings.REDIS_PORT,
|
||||
password=settings.REDIS_PASSWORD, db=0)
|
||||
|
||||
def put_dict_in_redis(redis_client: redis.StrictRedis, key_format: str,
|
||||
data_to_store: Dict[str, Any],
|
||||
expiration_seconds: int) -> str:
|
||||
with redis_client.pipeline() as pipeline:
|
||||
token = generate_random_token(64)
|
||||
key = key_format.format(token=token)
|
||||
pipeline.set(key, ujson.dumps(data_to_store))
|
||||
pipeline.expire(key, expiration_seconds)
|
||||
pipeline.execute()
|
||||
|
||||
return key
|
||||
|
||||
def get_dict_from_redis(redis_client: redis.StrictRedis, key: str) -> Optional[Dict[str, Any]]:
|
||||
data = redis_client.get(key)
|
||||
if data is None:
|
||||
return None
|
||||
return ujson.loads(data)
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
import copy
|
||||
import logging
|
||||
import magic
|
||||
import ujson
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
from typing_extensions import TypedDict
|
||||
|
@ -50,8 +49,7 @@ from zerver.lib.avatar_hash import user_avatar_content_hash
|
|||
from zerver.lib.dev_ldap_directory import init_fakeldap
|
||||
from zerver.lib.request import JsonableError
|
||||
from zerver.lib.users import check_full_name, validate_user_custom_profile_field
|
||||
from zerver.lib.utils import generate_random_token
|
||||
from zerver.lib.redis_utils import get_redis_client
|
||||
from zerver.lib.redis_utils import get_redis_client, get_dict_from_redis, put_dict_in_redis
|
||||
from zerver.models import CustomProfileField, DisposableEmailError, DomainNotAllowedForRealmError, \
|
||||
EmailContainsPlusError, PreregistrationUser, UserProfile, Realm, custom_profile_fields_for_realm, \
|
||||
email_allowed_for_realm, get_user_profile_by_id, remote_user_to_email, \
|
||||
|
@ -1318,28 +1316,22 @@ class SAMLAuthBackend(SocialAuthMixin, SAMLAuth):
|
|||
|
||||
@classmethod
|
||||
def put_data_in_redis(cls, data_to_relay: Dict[str, Any]) -> str:
|
||||
with redis_client.pipeline() as pipeline:
|
||||
token = generate_random_token(64)
|
||||
key = "saml_token_{}".format(token)
|
||||
pipeline.set(key, ujson.dumps(data_to_relay))
|
||||
pipeline.expire(key, cls.REDIS_EXPIRATION_SECONDS)
|
||||
pipeline.execute()
|
||||
|
||||
return key
|
||||
return put_dict_in_redis(redis_client, "saml_token_{token}",
|
||||
data_to_store=data_to_relay,
|
||||
expiration_seconds=cls.REDIS_EXPIRATION_SECONDS)
|
||||
|
||||
@classmethod
|
||||
def get_data_from_redis(cls, key: str) -> Optional[Dict[str, Any]]:
|
||||
redis_data = None
|
||||
data = None
|
||||
if key.startswith('saml_token_'):
|
||||
# Safety if statement, to not allow someone to poke around arbitrary redis keys here.
|
||||
redis_data = redis_client.get(key)
|
||||
if redis_data is None:
|
||||
data = get_dict_from_redis(redis_client, key)
|
||||
if data is None:
|
||||
# TODO: We will need some sort of user-facing message
|
||||
# about the authentication session having expired here.
|
||||
logging.info("SAML authentication failed: bad RelayState token.")
|
||||
return None
|
||||
|
||||
return ujson.loads(redis_data)
|
||||
return data
|
||||
|
||||
def auth_complete(self, *args: Any, **kwargs: Any) -> Optional[HttpResponse]:
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue