Use appropriate string types and correctly encode/decode them.

This commit is contained in:
Eklavya Sharma 2016-06-10 02:10:04 +05:30
parent f18493f922
commit d3b80d94a2
2 changed files with 18 additions and 15 deletions

View File

@ -19,6 +19,7 @@ import os
import os.path
import hashlib
import six
from six import text_type
remote_cache_time_start = 0.0
remote_cache_total_time = 0.0
@ -46,16 +47,17 @@ def remote_cache_stats_finish():
remote_cache_total_time += (time.time() - remote_cache_time_start)
def get_or_create_key_prefix():
# type: () -> str
# type: () -> text_type
if settings.TEST_SUITE:
# This sets the prefix mostly for the benefit of the JS tests.
# The Python tests overwrite KEY_PREFIX on each test.
return 'test_suite:' + str(os.getpid()) + ':'
return u'test_suite:%s:' % (text_type(os.getpid()),)
filename = os.path.join(settings.DEPLOY_ROOT, "remote_cache_prefix")
try:
fd = os.open(filename, os.O_CREAT | os.O_EXCL | os.O_RDWR, 0o444)
prefix = base64.b16encode(hashlib.sha256(str(random.getrandbits(256))).digest())[:32].lower() + ':'
random_hash = hashlib.sha256(text_type(random.getrandbits(256)).encode('utf-8')).digest()
prefix = base64.b16encode(random_hash)[:32].decode('utf-8').lower() + ':'
# This does close the underlying file
with os.fdopen(fd, 'w') as f:
f.write(prefix + "\n")
@ -77,12 +79,12 @@ def get_or_create_key_prefix():
return prefix
KEY_PREFIX = get_or_create_key_prefix() # type: str
KEY_PREFIX = get_or_create_key_prefix() # type: text_type
def bounce_key_prefix_for_testing(test_name):
# type: (str) -> None
# type: (text_type) -> None
global KEY_PREFIX
KEY_PREFIX = test_name + ':' + str(os.getpid()) + ':'
KEY_PREFIX = test_name + u':' + text_type(os.getpid()) + u':'
def get_cache_backend(cache_name):
# type: (str) -> get_cache
@ -152,7 +154,7 @@ def cache_get(key, cache_name=None):
def cache_get_many(keys, cache_name=None):
# type: (List[str], Optional[str]) -> Dict[str, Any]
keys = [KEY_PREFIX + key for key in keys]
keys = [KEY_PREFIX + key for key in keys] # type: ignore # temporary
remote_cache_stats_start()
ret = get_cache_backend(cache_name).get_many(keys)
remote_cache_stats_finish()
@ -163,7 +165,7 @@ def cache_set_many(items, cache_name=None, timeout=None):
new_items = {}
for key in items:
new_items[KEY_PREFIX + key] = items[key]
items = new_items
items = new_items # type: ignore # temporary
remote_cache_stats_start()
get_cache_backend(cache_name).set_many(items, timeout=timeout)
remote_cache_stats_finish()

View File

@ -17,6 +17,7 @@ from zerver.lib.cache import cache_with_key, flush_user_profile, flush_realm, \
active_bot_dicts_in_realm_cache_key, active_user_dict_fields, \
active_bot_dict_fields
from zerver.lib.utils import make_safe_digest, generate_random_token
from zerver.lib.str_utils import force_bytes, dict_with_str_keys
from django.db import transaction
from zerver.lib.avatar import gravatar_hash, get_avatar_url
from zerver.lib.camo import get_camo_url
@ -35,7 +36,7 @@ import pylibmc
import re
import ujson
import logging
from six import text_type
from six import binary_type, text_type
import time
import datetime
@ -748,13 +749,13 @@ def linebreak(string):
# type: (str) -> str
return string.replace('\n\n', '<p/>').replace('\n', '<br/>')
def extract_message_dict(message_str):
# type: (str) -> Dict[str, Any]
return ujson.loads(zlib.decompress(message_str).decode("utf-8"))
def extract_message_dict(message_bytes):
# type: (binary_type) -> Dict[str, Any]
return dict_with_str_keys(ujson.loads(zlib.decompress(message_bytes).decode("utf-8")))
def stringify_message_dict(message_dict):
# type: (Dict[Any, Any]) -> str
return zlib.compress(ujson.dumps(message_dict).encode("utf-8"))
# type: (Dict[str, Any]) -> binary_type
return zlib.compress(force_bytes(ujson.dumps(message_dict)))
def to_dict_cache_key_id(message_id, apply_markdown):
# type: (int, bool) -> str
@ -869,7 +870,7 @@ class Message(models.Model):
@cache_with_key(to_dict_cache_key, timeout=3600*24)
def to_dict_json(self, apply_markdown):
# type: (bool) -> str
# type: (bool) -> binary_type
return stringify_message_dict(self.to_dict_uncached(apply_markdown))
def to_dict_uncached(self, apply_markdown):