decorator: Strengthen types of signature-preserving decorators.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2020-06-23 16:52:37 -07:00
parent a7bac82f2e
commit e582bbea4a
4 changed files with 37 additions and 33 deletions

View File

@ -4,7 +4,7 @@ import logging
import urllib
from functools import wraps
from io import BytesIO
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union, cast
import django_otp
import ujson
@ -61,8 +61,6 @@ else: # nocoverage # Hack here basically to make impossible code paths compile
get_remote_server_by_uuid = Mock()
RemoteZulipServer = Mock() # type: ignore[misc] # https://github.com/JukkaL/mypy/issues/1188
ReturnT = TypeVar('ReturnT')
webhook_logger = logging.getLogger("zulip.zerver.webhooks")
log_to_file(webhook_logger, settings.API_KEY_ONLY_WEBHOOK_LOG_PATH)
@ -70,17 +68,19 @@ webhook_unexpected_events_logger = logging.getLogger("zulip.zerver.lib.webhooks.
log_to_file(webhook_unexpected_events_logger,
settings.WEBHOOK_UNEXPECTED_EVENTS_LOG_PATH)
def cachify(method: Callable[..., ReturnT]) -> Callable[..., ReturnT]:
dct: Dict[Tuple[Any, ...], ReturnT] = {}
FuncT = TypeVar('FuncT', bound=Callable[..., object])
def cache_wrapper(*args: Any) -> ReturnT:
def cachify(method: FuncT) -> FuncT:
dct: Dict[Tuple[object, ...], object] = {}
def cache_wrapper(*args: object) -> object:
tup = tuple(args)
if tup in dct:
return dct[tup]
result = method(*args)
dct[tup] = result
return result
return cache_wrapper
return cast(FuncT, cache_wrapper) # https://github.com/python/mypy/issues/1927
def update_user_activity(request: HttpRequest, user_profile: UserProfile,
query: Optional[str]) -> None:
@ -732,19 +732,18 @@ def internal_notify_view(is_tornado_view: bool) -> Callable[[ViewFuncT], ViewFun
def to_utc_datetime(timestamp: str) -> datetime.datetime:
return timestamp_to_datetime(float(timestamp))
def statsd_increment(counter: str, val: int=1,
) -> Callable[[Callable[..., ReturnT]], Callable[..., ReturnT]]:
def statsd_increment(counter: str, val: int=1) -> Callable[[FuncT], FuncT]:
"""Increments a statsd counter on completion of the
decorated function.
Pass the name of the counter to this decorator-returning function."""
def wrapper(func: Callable[..., ReturnT]) -> Callable[..., ReturnT]:
def wrapper(func: FuncT) -> FuncT:
@wraps(func)
def wrapped_func(*args: Any, **kwargs: Any) -> ReturnT:
def wrapped_func(*args: object, **kwargs: object) -> object:
ret = func(*args, **kwargs)
statsd.incr(counter, val)
return ret
return wrapped_func
return cast(FuncT, wrapped_func) # https://github.com/python/mypy/issues/1927
return wrapper
def rate_limit_user(request: HttpRequest, user: UserProfile, domain: str) -> None:

View File

@ -20,6 +20,7 @@ from typing import (
Sequence,
Tuple,
TypeVar,
cast,
)
from django.conf import settings
@ -39,7 +40,7 @@ if TYPE_CHECKING:
MEMCACHED_MAX_KEY_LENGTH = 250
ReturnT = TypeVar('ReturnT') # Useful for matching return types via Callable[..., ReturnT]
FuncT = TypeVar('FuncT', bound=Callable[..., object])
logger = logging.getLogger()
@ -127,15 +128,15 @@ def get_cache_backend(cache_name: Optional[str]) -> BaseCache:
def get_cache_with_key(
keyfunc: Callable[..., str],
cache_name: Optional[str]=None,
) -> Callable[[Callable[..., ReturnT]], Callable[..., ReturnT]]:
) -> Callable[[FuncT], FuncT]:
"""
The main goal of this function getting value from the cache like in the "cache_with_key".
A cache value can contain any data including the "None", so
here used exception for case if value isn't found in the cache.
"""
def decorator(func: Callable[..., ReturnT]) -> (Callable[..., ReturnT]):
def decorator(func: FuncT) -> FuncT:
@wraps(func)
def func_with_caching(*args: Any, **kwargs: Any) -> Callable[..., ReturnT]:
def func_with_caching(*args: object, **kwargs: object) -> object:
key = keyfunc(*args, **kwargs)
try:
val = cache_get(key, cache_name=cache_name)
@ -148,14 +149,14 @@ def get_cache_with_key(
return val[0]
raise NotFoundInCache()
return func_with_caching
return cast(FuncT, func_with_caching) # https://github.com/python/mypy/issues/1927
return decorator
def cache_with_key(
keyfunc: Callable[..., str], cache_name: Optional[str]=None,
timeout: Optional[int]=None, with_statsd_key: Optional[str]=None,
) -> Callable[[Callable[..., ReturnT]], Callable[..., ReturnT]]:
) -> Callable[[FuncT], FuncT]:
"""Decorator which applies Django caching to a function.
Decorator argument is a function which computes a cache key
@ -163,9 +164,9 @@ def cache_with_key(
for avoiding collisions with other uses of this decorator or
other uses of caching."""
def decorator(func: Callable[..., ReturnT]) -> Callable[..., ReturnT]:
def decorator(func: FuncT) -> FuncT:
@wraps(func)
def func_with_caching(*args: Any, **kwargs: Any) -> ReturnT:
def func_with_caching(*args: object, **kwargs: object) -> object:
key = keyfunc(*args, **kwargs)
try:
@ -198,7 +199,7 @@ def cache_with_key(
return val
return func_with_caching
return cast(FuncT, func_with_caching) # https://github.com/python/mypy/issues/1927
return decorator

View File

@ -1,10 +1,10 @@
import cProfile
from functools import wraps
from typing import Any, Callable, TypeVar
from typing import Callable, TypeVar, cast
ReturnT = TypeVar('ReturnT')
FuncT = TypeVar('FuncT', bound=Callable[..., object])
def profiled(func: Callable[..., ReturnT]) -> Callable[..., ReturnT]:
def profiled(func: FuncT) -> FuncT:
"""
This decorator should obviously be used only in a dev environment.
It works best when surrounding a function that you expect to be
@ -21,11 +21,13 @@ def profiled(func: Callable[..., ReturnT]) -> Callable[..., ReturnT]:
./tools/show-profile-results test_ratelimit_decrease.profile
"""
func_: Callable[..., object] = func # work around https://github.com/python/mypy/issues/9075
@wraps(func)
def wrapped_func(*args: Any, **kwargs: Any) -> ReturnT:
def wrapped_func(*args: object, **kwargs: object) -> object:
fn = func.__name__ + ".profile"
prof = cProfile.Profile()
retval: ReturnT = prof.runcall(func, *args, **kwargs)
retval = prof.runcall(func_, *args, **kwargs)
prof.dump_stats(fn)
return retval
return wrapped_func
return cast(FuncT, wrapped_func) # https://github.com/python/mypy/issues/1927

View File

@ -2,7 +2,7 @@ import json
import os
import sys
from functools import wraps
from typing import Any, Callable, Dict, Iterable, List, Optional, Set
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, TypeVar, cast
from zulip import Client
@ -12,27 +12,29 @@ from zerver.openapi.openapi import validate_against_openapi_schema
ZULIP_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
TEST_FUNCTIONS: Dict[str, Callable[..., None]] = dict()
TEST_FUNCTIONS: Dict[str, Callable[..., object]] = dict()
REGISTERED_TEST_FUNCTIONS: Set[str] = set()
CALLED_TEST_FUNCTIONS: Set[str] = set()
def openapi_test_function(endpoint: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
FuncT = TypeVar("FuncT", bound=Callable[..., object])
def openapi_test_function(endpoint: str) -> Callable[[FuncT], FuncT]:
"""This decorator is used to register an openapi test function with
its endpoint. Example usage:
@openapi_test_function("/messages/render:post")
def ...
"""
def wrapper(test_func: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(test_func: FuncT) -> FuncT:
@wraps(test_func)
def _record_calls_wrapper(*args: Any, **kwargs: Any) -> Any:
def _record_calls_wrapper(*args: object, **kwargs: object) -> object:
CALLED_TEST_FUNCTIONS.add(test_func.__name__)
return test_func(*args, **kwargs)
REGISTERED_TEST_FUNCTIONS.add(test_func.__name__)
TEST_FUNCTIONS[endpoint] = _record_calls_wrapper
return _record_calls_wrapper
return cast(FuncT, _record_calls_wrapper) # https://github.com/python/mypy/issues/1927
return wrapper
def ensure_users(ids_list: List[int], user_names: List[str]) -> None: