diff --git a/zerver/lib/cache.py b/zerver/lib/cache.py index d3ef103c91..1362f6d7c5 100644 --- a/zerver/lib/cache.py +++ b/zerver/lib/cache.py @@ -10,7 +10,7 @@ from django.core.cache.backends.base import BaseCache from django.http import HttpRequest from typing import Any, Callable, Dict, Iterable, List, \ - Optional, TypeVar, Tuple, TYPE_CHECKING + Optional, Sequence, TypeVar, Tuple, TYPE_CHECKING from zerver.lib.utils import statsd, statsd_key, make_safe_digest import time @@ -262,12 +262,16 @@ def default_cache_transformer(obj: ItemT) -> CacheItemT: def generic_bulk_cached_fetch( cache_key_function: Callable[[ObjKT], str], query_function: Callable[[List[ObjKT]], Iterable[ItemT]], - object_ids: Iterable[ObjKT], + object_ids: Sequence[ObjKT], extractor: Callable[[CompressedItemT], CacheItemT] = default_extractor, setter: Callable[[CacheItemT], CompressedItemT] = default_setter, id_fetcher: Callable[[ItemT], ObjKT] = default_id_fetcher, cache_transformer: Callable[[ItemT], CacheItemT] = default_cache_transformer, ) -> Dict[ObjKT, CacheItemT]: + if len(object_ids) == 0: + # Nothing to fetch. + return {} + cache_keys = {} # type: Dict[ObjKT, str] for object_id in object_ids: cache_keys[object_id] = cache_key_function(object_id) @@ -278,7 +282,12 @@ def generic_bulk_cached_fetch( cached_objects[key] = extractor(cached_objects_compressed[key][0]) needed_ids = [object_id for object_id in object_ids if cache_keys[object_id] not in cached_objects] - db_objects = query_function(needed_ids) + + # Only call query_function if there are some ids to fetch from the database: + if len(needed_ids) > 0: + db_objects = query_function(needed_ids) + else: + db_objects = [] items_for_remote_cache = {} # type: Dict[str, Tuple[CompressedItemT]] for obj in db_objects: diff --git a/zerver/lib/users.py b/zerver/lib/users.py index 9f0a02a358..c261bbabee 100644 --- a/zerver/lib/users.py +++ b/zerver/lib/users.py @@ -115,9 +115,6 @@ def bulk_get_users(emails: List[str], realm: Optional[Realm], # # But chaining __in and __iexact doesn't work with Django's # ORM, so we have the following hack to construct the relevant where clause - if len(emails) == 0: - return [] - upper_list = ", ".join(["UPPER(%s)"] * len(emails)) where_clause = "UPPER(zerver_userprofile.email::text) IN (%s)" % (upper_list,) return query.select_related("realm").extra( @@ -141,9 +138,6 @@ def user_ids_to_users(user_ids: List[int], realm: Realm) -> List[UserProfile]: # users should be included. def fetch_users_by_id(user_ids: List[int]) -> List[UserProfile]: - if len(user_ids) == 0: - return [] - return list(UserProfile.objects.filter(id__in=user_ids).select_related()) user_profiles_by_id = generic_bulk_cached_fetch( diff --git a/zerver/models.py b/zerver/models.py index a32518a24e..7a6ae5bce5 100644 --- a/zerver/models.py +++ b/zerver/models.py @@ -1430,8 +1430,6 @@ def bulk_get_streams(realm: Realm, stream_names: STREAM_NAMES) -> Dict[str, Any] # # But chaining __in and __iexact doesn't work with Django's # ORM, so we have the following hack to construct the relevant where clause - if len(stream_names) == 0: - return [] upper_list = ", ".join(["UPPER(%s)"] * len(stream_names)) where_clause = "UPPER(zerver_stream.name::text) IN (%s)" % (upper_list,) return get_active_streams(realm.id).select_related("realm").extra( diff --git a/zerver/tests/test_cache.py b/zerver/tests/test_cache.py index 2d64f2dc03..53ab7de981 100644 --- a/zerver/tests/test_cache.py +++ b/zerver/tests/test_cache.py @@ -1,10 +1,12 @@ from django.conf import settings from mock import Mock, patch +from typing import List, Dict from zerver.apps import flush_cache +from zerver.lib.cache import generic_bulk_cached_fetch, user_profile_by_email_cache_key from zerver.lib.test_classes import ZulipTestCase -from zerver.models import get_system_bot, get_user_profile_by_email +from zerver.models import get_system_bot, get_user_profile_by_email, UserProfile class AppsTest(ZulipTestCase): def test_cache_gets_flushed(self) -> None: @@ -34,3 +36,51 @@ class BotCacheKeyTest(ZulipTestCase): user_profile2 = get_user_profile_by_email(settings.EMAIL_GATEWAY_BOT) self.assertEqual(user_profile2.is_api_super_user, flipped_setting) + +class GenericBulkCachedFetchTest(ZulipTestCase): + def test_query_function_called_only_if_needed(self) -> None: + # Get the user cached: + hamlet = get_user_profile_by_email(self.example_email("hamlet")) + + class CustomException(Exception): + pass + + def query_function(emails: List[str]) -> List[UserProfile]: + raise CustomException("The query function was called") + + # query_function shouldn't be called, because the only requested object + # is already cached: + result = generic_bulk_cached_fetch( + cache_key_function=user_profile_by_email_cache_key, + query_function=query_function, + object_ids=[self.example_email("hamlet")] + ) # type: Dict[str, UserProfile] + self.assertEqual(result, {hamlet.email: hamlet}) + + flush_cache(Mock()) + # With the cache flushed, the query_function should get called: + with self.assertRaises(CustomException): + generic_bulk_cached_fetch( + cache_key_function=user_profile_by_email_cache_key, + query_function=query_function, + object_ids=[self.example_email("hamlet")] + ) + + def test_empty_object_ids_list(self) -> None: + class CustomException(Exception): + pass + + def cache_key_function(email: str) -> str: # nocoverage -- this is just here to make sure it's not called + raise CustomException("The cache key function was called") + + def query_function(emails: List[str]) -> List[UserProfile]: # nocoverage -- this is just here to make sure it's not called + raise CustomException("The query function was called") + + # query_function and cache_key_function shouldn't be called, because + # objects_ids is empty, so there's nothing to do. + result = generic_bulk_cached_fetch( + cache_key_function=cache_key_function, + query_function=query_function, + object_ids=[] + ) # type: Dict[str, UserProfile] + self.assertEqual(result, {})