generic_bulk_cached_fetch: Only call query_function if necessary.

This commit is contained in:
Mateusz Mandera 2019-08-10 23:31:14 +02:00 committed by Tim Abbott
parent 898bc52538
commit cb2c9b04b3
4 changed files with 63 additions and 12 deletions

View File

@ -10,7 +10,7 @@ from django.core.cache.backends.base import BaseCache
from django.http import HttpRequest
from typing import Any, Callable, Dict, Iterable, List, \
Optional, TypeVar, Tuple, TYPE_CHECKING
Optional, Sequence, TypeVar, Tuple, TYPE_CHECKING
from zerver.lib.utils import statsd, statsd_key, make_safe_digest
import time
@ -262,12 +262,16 @@ def default_cache_transformer(obj: ItemT) -> CacheItemT:
def generic_bulk_cached_fetch(
cache_key_function: Callable[[ObjKT], str],
query_function: Callable[[List[ObjKT]], Iterable[ItemT]],
object_ids: Iterable[ObjKT],
object_ids: Sequence[ObjKT],
extractor: Callable[[CompressedItemT], CacheItemT] = default_extractor,
setter: Callable[[CacheItemT], CompressedItemT] = default_setter,
id_fetcher: Callable[[ItemT], ObjKT] = default_id_fetcher,
cache_transformer: Callable[[ItemT], CacheItemT] = default_cache_transformer,
) -> Dict[ObjKT, CacheItemT]:
if len(object_ids) == 0:
# Nothing to fetch.
return {}
cache_keys = {} # type: Dict[ObjKT, str]
for object_id in object_ids:
cache_keys[object_id] = cache_key_function(object_id)
@ -278,7 +282,12 @@ def generic_bulk_cached_fetch(
cached_objects[key] = extractor(cached_objects_compressed[key][0])
needed_ids = [object_id for object_id in object_ids if
cache_keys[object_id] not in cached_objects]
db_objects = query_function(needed_ids)
# Only call query_function if there are some ids to fetch from the database:
if len(needed_ids) > 0:
db_objects = query_function(needed_ids)
else:
db_objects = []
items_for_remote_cache = {} # type: Dict[str, Tuple[CompressedItemT]]
for obj in db_objects:

View File

@ -115,9 +115,6 @@ def bulk_get_users(emails: List[str], realm: Optional[Realm],
#
# But chaining __in and __iexact doesn't work with Django's
# ORM, so we have the following hack to construct the relevant where clause.
if len(emails) == 0:
return []
upper_list = ", ".join(["UPPER(%s)"] * len(emails))
where_clause = "UPPER(zerver_userprofile.email::text) IN (%s)" % (upper_list,)
return query.select_related("realm").extra(
@ -141,9 +138,6 @@ def user_ids_to_users(user_ids: List[int], realm: Realm) -> List[UserProfile]:
# users should be included.
def fetch_users_by_id(user_ids: List[int]) -> List[UserProfile]:
    """Return the UserProfile rows whose id is in user_ids, as a list."""
    # An empty id list needs no query at all.
    if not user_ids:
        return []
    matching_profiles = UserProfile.objects.filter(id__in=user_ids).select_related()
    return list(matching_profiles)
user_profiles_by_id = generic_bulk_cached_fetch(

View File

@ -1430,8 +1430,6 @@ def bulk_get_streams(realm: Realm, stream_names: STREAM_NAMES) -> Dict[str, Any]
#
# But chaining __in and __iexact doesn't work with Django's
# ORM, so we have the following hack to construct the relevant where clause.
if len(stream_names) == 0:
return []
upper_list = ", ".join(["UPPER(%s)"] * len(stream_names))
where_clause = "UPPER(zerver_stream.name::text) IN (%s)" % (upper_list,)
return get_active_streams(realm.id).select_related("realm").extra(

View File

@ -1,10 +1,12 @@
from django.conf import settings
from mock import Mock, patch
from typing import List, Dict
from zerver.apps import flush_cache
from zerver.lib.cache import generic_bulk_cached_fetch, user_profile_by_email_cache_key
from zerver.lib.test_classes import ZulipTestCase
from zerver.models import get_system_bot, get_user_profile_by_email
from zerver.models import get_system_bot, get_user_profile_by_email, UserProfile
class AppsTest(ZulipTestCase):
def test_cache_gets_flushed(self) -> None:
@ -34,3 +36,51 @@ class BotCacheKeyTest(ZulipTestCase):
user_profile2 = get_user_profile_by_email(settings.EMAIL_GATEWAY_BOT)
self.assertEqual(user_profile2.is_api_super_user, flipped_setting)
class GenericBulkCachedFetchTest(ZulipTestCase):
    """Checks when generic_bulk_cached_fetch invokes the callbacks it is given."""

    def test_query_function_called_only_if_needed(self) -> None:
        hamlet_email = self.example_email("hamlet")

        # Fetching hamlet by email gets him into the cache.
        hamlet = get_user_profile_by_email(hamlet_email)

        class UnexpectedCall(Exception):
            pass

        def failing_query(emails: List[str]) -> List[UserProfile]:
            raise UnexpectedCall("The query function was called")

        # The single requested object is already cached, so the query
        # callback must not run; if it did, its exception would propagate.
        result = generic_bulk_cached_fetch(
            cache_key_function=user_profile_by_email_cache_key,
            query_function=failing_query,
            object_ids=[hamlet_email],
        )  # type: Dict[str, UserProfile]
        expected = {hamlet.email: hamlet}
        self.assertEqual(result, expected)

        flush_cache(Mock())

        # After flushing the cache, the same request has to fall through
        # to the query callback, which raises.
        with self.assertRaises(UnexpectedCall):
            generic_bulk_cached_fetch(
                cache_key_function=user_profile_by_email_cache_key,
                query_function=failing_query,
                object_ids=[hamlet_email],
            )

    def test_empty_object_ids_list(self) -> None:
        class UnexpectedCall(Exception):
            pass

        def failing_cache_key(email: str) -> str:  # nocoverage -- this is just here to make sure it's not called
            raise UnexpectedCall("The cache key function was called")

        def failing_query(emails: List[str]) -> List[UserProfile]:  # nocoverage -- this is just here to make sure it's not called
            raise UnexpectedCall("The query function was called")

        # With an empty object_ids there is nothing to do, so neither the
        # cache key callback nor the query callback should be called.
        result = generic_bulk_cached_fetch(
            cache_key_function=failing_cache_key,
            query_function=failing_query,
            object_ids=[],
        )  # type: Dict[str, UserProfile]
        self.assertEqual(result, {})