cache: Fix type: ignore issues.

This was hiding an actual type error in test_cache: a mismatch between
the object ID type, which is str, and the default id_fetcher, which
returns int.

Mypy’s insufficient support for default generic arguments basically
means we can’t use them without a lot of overloading, and there are
not enough callers here to justify that.

https://github.com/python/mypy/issues/3737

We avoid this being super messy where the code calls this by adding
some less generic wrappers for generic_bulk_cached_fetch.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2020-06-30 18:29:31 -07:00 committed by Tim Abbott
parent bb8dcb9b1e
commit 1b96af2987
5 changed files with 67 additions and 33 deletions

View File

@ -348,23 +348,10 @@ CacheItemT = TypeVar('CacheItemT')
# serializable objects, will be the object; if encoded, bytes. # serializable objects, will be the object; if encoded, bytes.
CompressedItemT = TypeVar('CompressedItemT') CompressedItemT = TypeVar('CompressedItemT')
def default_extractor(obj: CompressedItemT) -> ItemT:
return obj # type: ignore[return-value] # Need a type assert that ItemT=CompressedItemT
def default_setter(obj: ItemT) -> CompressedItemT:
return obj # type: ignore[return-value] # Need a type assert that ItemT=CompressedItemT
def default_id_fetcher(obj: ItemT) -> ObjKT:
return obj.id # type: ignore[attr-defined] # Need ItemT/CompressedItemT typevars to be a Django protocol
def default_cache_transformer(obj: ItemT) -> CacheItemT:
return obj # type: ignore[return-value] # Need a type assert that ItemT=CacheItemT
# Required Arguments are as follows: # Required Arguments are as follows:
# * object_ids: The list of object ids to look up # * object_ids: The list of object ids to look up
# * cache_key_function: object_id => cache key # * cache_key_function: object_id => cache key
# * query_function: [object_ids] => [objects from database] # * query_function: [object_ids] => [objects from database]
# Optional keyword arguments:
# * setter: Function to call before storing items to cache (e.g. compression) # * setter: Function to call before storing items to cache (e.g. compression)
# * extractor: Function to call on items returned from cache # * extractor: Function to call on items returned from cache
# (e.g. decompression). Should be the inverse of the setter # (e.g. decompression). Should be the inverse of the setter
@ -378,10 +365,11 @@ def generic_bulk_cached_fetch(
cache_key_function: Callable[[ObjKT], str], cache_key_function: Callable[[ObjKT], str],
query_function: Callable[[List[ObjKT]], Iterable[ItemT]], query_function: Callable[[List[ObjKT]], Iterable[ItemT]],
object_ids: Sequence[ObjKT], object_ids: Sequence[ObjKT],
extractor: Callable[[CompressedItemT], CacheItemT] = default_extractor, *,
setter: Callable[[CacheItemT], CompressedItemT] = default_setter, extractor: Callable[[CompressedItemT], CacheItemT],
id_fetcher: Callable[[ItemT], ObjKT] = default_id_fetcher, setter: Callable[[CacheItemT], CompressedItemT],
cache_transformer: Callable[[ItemT], CacheItemT] = default_cache_transformer, id_fetcher: Callable[[ItemT], ObjKT],
cache_transformer: Callable[[ItemT], CacheItemT],
) -> Dict[ObjKT, CacheItemT]: ) -> Dict[ObjKT, CacheItemT]:
if len(object_ids) == 0: if len(object_ids) == 0:
# Nothing to fetch. # Nothing to fetch.
@ -418,6 +406,39 @@ def generic_bulk_cached_fetch(
return {object_id: cached_objects[cache_keys[object_id]] for object_id in object_ids return {object_id: cached_objects[cache_keys[object_id]] for object_id in object_ids
if cache_keys[object_id] in cached_objects} if cache_keys[object_id] in cached_objects}
def transformed_bulk_cached_fetch(
cache_key_function: Callable[[ObjKT], str],
query_function: Callable[[List[ObjKT]], Iterable[ItemT]],
object_ids: Sequence[ObjKT],
*,
id_fetcher: Callable[[ItemT], ObjKT],
cache_transformer: Callable[[ItemT], CacheItemT],
) -> Dict[ObjKT, CacheItemT]:
return generic_bulk_cached_fetch(
cache_key_function,
query_function,
object_ids,
extractor=lambda obj: obj,
setter=lambda obj: obj,
id_fetcher=id_fetcher,
cache_transformer=cache_transformer,
)
def bulk_cached_fetch(
cache_key_function: Callable[[ObjKT], str],
query_function: Callable[[List[ObjKT]], Iterable[ItemT]],
object_ids: Sequence[ObjKT],
*,
id_fetcher: Callable[[ItemT], ObjKT],
) -> Dict[ObjKT, ItemT]:
return transformed_bulk_cached_fetch(
cache_key_function,
query_function,
object_ids,
id_fetcher=id_fetcher,
cache_transformer=lambda obj: obj,
)
def preview_url_cache_key(url: str) -> str: def preview_url_cache_key(url: str) -> str:
return f"preview_url:{make_safe_digest(url)}" return f"preview_url:{make_safe_digest(url)}"

View File

@ -3,10 +3,11 @@ from typing import Dict, List, Optional, Set, Tuple
from typing_extensions import TypedDict from typing_extensions import TypedDict
from zerver.lib.cache import ( from zerver.lib.cache import (
bulk_cached_fetch,
cache_with_key, cache_with_key,
display_recipient_bulk_get_users_by_id_cache_key, display_recipient_bulk_get_users_by_id_cache_key,
display_recipient_cache_key, display_recipient_cache_key,
generic_bulk_cached_fetch, transformed_bulk_cached_fetch,
) )
from zerver.lib.types import DisplayRecipientT, UserDisplayRecipient from zerver.lib.types import DisplayRecipientT, UserDisplayRecipient
from zerver.models import Recipient, Stream, UserProfile, bulk_get_huddle_user_ids from zerver.models import Recipient, Stream, UserProfile, bulk_get_huddle_user_ids
@ -49,7 +50,7 @@ def user_dict_id_fetcher(user_dict: UserDisplayRecipient) -> int:
return user_dict['id'] return user_dict['id']
def bulk_get_user_profile_by_id(uids: List[int]) -> Dict[int, UserDisplayRecipient]: def bulk_get_user_profile_by_id(uids: List[int]) -> Dict[int, UserDisplayRecipient]:
return generic_bulk_cached_fetch( return bulk_cached_fetch(
# Use a separate cache key to protect us from conflicts with # Use a separate cache key to protect us from conflicts with
# the get_user_profile_by_id cache. # the get_user_profile_by_id cache.
# (Since we fetch only several fields here) # (Since we fetch only several fields here)
@ -96,7 +97,7 @@ def bulk_fetch_display_recipients(recipient_tuples: Set[Tuple[int, int, int]],
return stream['name'] return stream['name']
# ItemT = Stream, CacheItemT = str (name), ObjKT = int (recipient_id) # ItemT = Stream, CacheItemT = str (name), ObjKT = int (recipient_id)
stream_display_recipients: Dict[int, str] = generic_bulk_cached_fetch( stream_display_recipients: Dict[int, str] = transformed_bulk_cached_fetch(
cache_key_function=display_recipient_cache_key, cache_key_function=display_recipient_cache_key,
query_function=stream_query_function, query_function=stream_query_function,
object_ids=[recipient[0] for recipient in stream_recipients], object_ids=[recipient[0] for recipient in stream_recipients],
@ -167,7 +168,7 @@ def bulk_fetch_display_recipients(recipient_tuples: Set[Tuple[int, int, int]],
# ItemT = Tuple[int, List[UserDisplayRecipient]] (recipient_id, list of corresponding users) # ItemT = Tuple[int, List[UserDisplayRecipient]] (recipient_id, list of corresponding users)
# CacheItemT = List[UserDisplayRecipient] (display_recipient list) # CacheItemT = List[UserDisplayRecipient] (display_recipient list)
# ObjKT = int (recipient_id) # ObjKT = int (recipient_id)
personal_and_huddle_display_recipients = generic_bulk_cached_fetch( personal_and_huddle_display_recipients: Dict[int, List[UserDisplayRecipient]] = transformed_bulk_cached_fetch(
cache_key_function=display_recipient_cache_key, cache_key_function=display_recipient_cache_key,
query_function=personal_and_huddle_query_function, query_function=personal_and_huddle_query_function,
object_ids=[recipient[0] for recipient in personal_and_huddle_recipients], object_ids=[recipient[0] for recipient in personal_and_huddle_recipients],

View File

@ -12,7 +12,7 @@ from zulip_bots.custom_exceptions import ConfigValidationError
from zerver.lib.avatar import avatar_url, get_avatar_field from zerver.lib.avatar import avatar_url, get_avatar_field
from zerver.lib.cache import ( from zerver.lib.cache import (
generic_bulk_cached_fetch, bulk_cached_fetch,
realm_user_dict_fields, realm_user_dict_fields,
user_profile_by_id_cache_key, user_profile_by_id_cache_key,
user_profile_cache_key_id, user_profile_cache_key_id,
@ -173,7 +173,7 @@ def bulk_get_users(emails: List[str], realm: Optional[Realm],
def user_to_email(user_profile: UserProfile) -> str: def user_to_email(user_profile: UserProfile) -> str:
return user_profile.email.lower() return user_profile.email.lower()
return generic_bulk_cached_fetch( return bulk_cached_fetch(
# Use a separate cache key to protect us from conflicts with # Use a separate cache key to protect us from conflicts with
# the get_user cache. # the get_user cache.
lambda email: 'bulk_get_users:' + user_profile_cache_key_id(email, realm_id), lambda email: 'bulk_get_users:' + user_profile_cache_key_id(email, realm_id),
@ -182,6 +182,9 @@ def bulk_get_users(emails: List[str], realm: Optional[Realm],
id_fetcher=user_to_email, id_fetcher=user_to_email,
) )
def get_user_id(user: UserProfile) -> int:
return user.id
def user_ids_to_users(user_ids: Sequence[int], realm: Realm) -> List[UserProfile]: def user_ids_to_users(user_ids: Sequence[int], realm: Realm) -> List[UserProfile]:
# TODO: Consider adding a flag to control whether deactivated # TODO: Consider adding a flag to control whether deactivated
# users should be included. # users should be included.
@ -189,10 +192,11 @@ def user_ids_to_users(user_ids: Sequence[int], realm: Realm) -> List[UserProfile
def fetch_users_by_id(user_ids: List[int]) -> List[UserProfile]: def fetch_users_by_id(user_ids: List[int]) -> List[UserProfile]:
return list(UserProfile.objects.filter(id__in=user_ids).select_related()) return list(UserProfile.objects.filter(id__in=user_ids).select_related())
user_profiles_by_id: Dict[int, UserProfile] = generic_bulk_cached_fetch( user_profiles_by_id: Dict[int, UserProfile] = bulk_cached_fetch(
cache_key_function=user_profile_by_id_cache_key, cache_key_function=user_profile_by_id_cache_key,
query_function=fetch_users_by_id, query_function=fetch_users_by_id,
object_ids=user_ids, object_ids=user_ids,
id_fetcher=get_user_id,
) )
found_user_ids = user_profiles_by_id.keys() found_user_ids = user_profiles_by_id.keys()

View File

@ -43,6 +43,7 @@ from zerver.lib.cache import (
bot_dict_fields, bot_dict_fields,
bot_dicts_in_realm_cache_key, bot_dicts_in_realm_cache_key,
bot_profile_cache_key, bot_profile_cache_key,
bulk_cached_fetch,
cache_delete, cache_delete,
cache_set, cache_set,
cache_with_key, cache_with_key,
@ -52,7 +53,6 @@ from zerver.lib.cache import (
flush_submessage, flush_submessage,
flush_used_upload_space_cache, flush_used_upload_space_cache,
flush_user_profile, flush_user_profile,
generic_bulk_cached_fetch,
get_realm_used_upload_space_cache_key, get_realm_used_upload_space_cache_key,
get_stream_cache_key, get_stream_cache_key,
realm_alert_words_automaton_cache_key, realm_alert_words_automaton_cache_key,
@ -1653,10 +1653,12 @@ def bulk_get_streams(realm: Realm, stream_names: STREAM_NAMES) -> Dict[str, Any]
def stream_to_lower_name(stream: Stream) -> str: def stream_to_lower_name(stream: Stream) -> str:
return stream.name.lower() return stream.name.lower()
return generic_bulk_cached_fetch(stream_name_to_cache_key, return bulk_cached_fetch(
fetch_streams_by_name, stream_name_to_cache_key,
[stream_name.lower() for stream_name in stream_names], fetch_streams_by_name,
id_fetcher=stream_to_lower_name) [stream_name.lower() for stream_name in stream_names],
id_fetcher=stream_to_lower_name,
)
def get_huddle_recipient(user_profile_ids: Set[int]) -> Recipient: def get_huddle_recipient(user_profile_ids: Set[int]) -> Recipient:

View File

@ -8,6 +8,7 @@ from zerver.lib.cache import (
MEMCACHED_MAX_KEY_LENGTH, MEMCACHED_MAX_KEY_LENGTH,
InvalidCacheKeyException, InvalidCacheKeyException,
NotFoundInCache, NotFoundInCache,
bulk_cached_fetch,
cache_delete, cache_delete,
cache_delete_many, cache_delete_many,
cache_get, cache_get,
@ -15,7 +16,6 @@ from zerver.lib.cache import (
cache_set, cache_set,
cache_set_many, cache_set_many,
cache_with_key, cache_with_key,
generic_bulk_cached_fetch,
get_cache_with_key, get_cache_with_key,
safe_cache_get_many, safe_cache_get_many,
safe_cache_set_many, safe_cache_set_many,
@ -272,6 +272,9 @@ class BotCacheKeyTest(ZulipTestCase):
user_profile2 = get_user_profile_by_email(settings.EMAIL_GATEWAY_BOT) user_profile2 = get_user_profile_by_email(settings.EMAIL_GATEWAY_BOT)
self.assertEqual(user_profile2.is_api_super_user, flipped_setting) self.assertEqual(user_profile2.is_api_super_user, flipped_setting)
def get_user_email(user: UserProfile) -> str:
return user.email # nocoverage
class GenericBulkCachedFetchTest(ZulipTestCase): class GenericBulkCachedFetchTest(ZulipTestCase):
def test_query_function_called_only_if_needed(self) -> None: def test_query_function_called_only_if_needed(self) -> None:
# Get the user cached: # Get the user cached:
@ -285,20 +288,22 @@ class GenericBulkCachedFetchTest(ZulipTestCase):
# query_function shouldn't be called, because the only requested object # query_function shouldn't be called, because the only requested object
# is already cached: # is already cached:
result: Dict[str, UserProfile] = generic_bulk_cached_fetch( result: Dict[str, UserProfile] = bulk_cached_fetch(
cache_key_function=user_profile_by_email_cache_key, cache_key_function=user_profile_by_email_cache_key,
query_function=query_function, query_function=query_function,
object_ids=[self.example_email("hamlet")], object_ids=[self.example_email("hamlet")],
id_fetcher=get_user_email,
) )
self.assertEqual(result, {hamlet.delivery_email: hamlet}) self.assertEqual(result, {hamlet.delivery_email: hamlet})
flush_cache(Mock()) flush_cache(Mock())
# With the cache flushed, the query_function should get called: # With the cache flushed, the query_function should get called:
with self.assertRaises(CustomException): with self.assertRaises(CustomException):
generic_bulk_cached_fetch( result = bulk_cached_fetch(
cache_key_function=user_profile_by_email_cache_key, cache_key_function=user_profile_by_email_cache_key,
query_function=query_function, query_function=query_function,
object_ids=[self.example_email("hamlet")], object_ids=[self.example_email("hamlet")],
id_fetcher=get_user_email,
) )
def test_empty_object_ids_list(self) -> None: def test_empty_object_ids_list(self) -> None:
@ -313,9 +318,10 @@ class GenericBulkCachedFetchTest(ZulipTestCase):
# query_function and cache_key_function shouldn't be called, because # query_function and cache_key_function shouldn't be called, because
# objects_ids is empty, so there's nothing to do. # objects_ids is empty, so there's nothing to do.
result: Dict[str, UserProfile] = generic_bulk_cached_fetch( result: Dict[str, UserProfile] = bulk_cached_fetch(
cache_key_function=cache_key_function, cache_key_function=cache_key_function,
query_function=query_function, query_function=query_function,
object_ids=[], object_ids=[],
id_fetcher=get_user_email,
) )
self.assertEqual(result, {}) self.assertEqual(result, {})