mirror of https://github.com/zulip/zulip.git
decorator: Strengthen types of signature-preserving decorators.
Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
parent
a7bac82f2e
commit
e582bbea4a
|
@ -4,7 +4,7 @@ import logging
|
|||
import urllib
|
||||
from functools import wraps
|
||||
from io import BytesIO
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union, cast
|
||||
|
||||
import django_otp
|
||||
import ujson
|
||||
|
@ -61,8 +61,6 @@ else: # nocoverage # Hack here basically to make impossible code paths compile
|
|||
get_remote_server_by_uuid = Mock()
|
||||
RemoteZulipServer = Mock() # type: ignore[misc] # https://github.com/JukkaL/mypy/issues/1188
|
||||
|
||||
ReturnT = TypeVar('ReturnT')
|
||||
|
||||
webhook_logger = logging.getLogger("zulip.zerver.webhooks")
|
||||
log_to_file(webhook_logger, settings.API_KEY_ONLY_WEBHOOK_LOG_PATH)
|
||||
|
||||
|
@ -70,17 +68,19 @@ webhook_unexpected_events_logger = logging.getLogger("zulip.zerver.lib.webhooks.
|
|||
log_to_file(webhook_unexpected_events_logger,
|
||||
settings.WEBHOOK_UNEXPECTED_EVENTS_LOG_PATH)
|
||||
|
||||
def cachify(method: Callable[..., ReturnT]) -> Callable[..., ReturnT]:
|
||||
dct: Dict[Tuple[Any, ...], ReturnT] = {}
|
||||
FuncT = TypeVar('FuncT', bound=Callable[..., object])
|
||||
|
||||
def cache_wrapper(*args: Any) -> ReturnT:
|
||||
def cachify(method: FuncT) -> FuncT:
|
||||
dct: Dict[Tuple[object, ...], object] = {}
|
||||
|
||||
def cache_wrapper(*args: object) -> object:
|
||||
tup = tuple(args)
|
||||
if tup in dct:
|
||||
return dct[tup]
|
||||
result = method(*args)
|
||||
dct[tup] = result
|
||||
return result
|
||||
return cache_wrapper
|
||||
return cast(FuncT, cache_wrapper) # https://github.com/python/mypy/issues/1927
|
||||
|
||||
def update_user_activity(request: HttpRequest, user_profile: UserProfile,
|
||||
query: Optional[str]) -> None:
|
||||
|
@ -732,19 +732,18 @@ def internal_notify_view(is_tornado_view: bool) -> Callable[[ViewFuncT], ViewFun
|
|||
def to_utc_datetime(timestamp: str) -> datetime.datetime:
|
||||
return timestamp_to_datetime(float(timestamp))
|
||||
|
||||
def statsd_increment(counter: str, val: int=1,
|
||||
) -> Callable[[Callable[..., ReturnT]], Callable[..., ReturnT]]:
|
||||
def statsd_increment(counter: str, val: int=1) -> Callable[[FuncT], FuncT]:
|
||||
"""Increments a statsd counter on completion of the
|
||||
decorated function.
|
||||
|
||||
Pass the name of the counter to this decorator-returning function."""
|
||||
def wrapper(func: Callable[..., ReturnT]) -> Callable[..., ReturnT]:
|
||||
def wrapper(func: FuncT) -> FuncT:
|
||||
@wraps(func)
|
||||
def wrapped_func(*args: Any, **kwargs: Any) -> ReturnT:
|
||||
def wrapped_func(*args: object, **kwargs: object) -> object:
|
||||
ret = func(*args, **kwargs)
|
||||
statsd.incr(counter, val)
|
||||
return ret
|
||||
return wrapped_func
|
||||
return cast(FuncT, wrapped_func) # https://github.com/python/mypy/issues/1927
|
||||
return wrapper
|
||||
|
||||
def rate_limit_user(request: HttpRequest, user: UserProfile, domain: str) -> None:
|
||||
|
|
|
@ -20,6 +20,7 @@ from typing import (
|
|||
Sequence,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
|
||||
from django.conf import settings
|
||||
|
@ -39,7 +40,7 @@ if TYPE_CHECKING:
|
|||
|
||||
MEMCACHED_MAX_KEY_LENGTH = 250
|
||||
|
||||
ReturnT = TypeVar('ReturnT') # Useful for matching return types via Callable[..., ReturnT]
|
||||
FuncT = TypeVar('FuncT', bound=Callable[..., object])
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
@ -127,15 +128,15 @@ def get_cache_backend(cache_name: Optional[str]) -> BaseCache:
|
|||
def get_cache_with_key(
|
||||
keyfunc: Callable[..., str],
|
||||
cache_name: Optional[str]=None,
|
||||
) -> Callable[[Callable[..., ReturnT]], Callable[..., ReturnT]]:
|
||||
) -> Callable[[FuncT], FuncT]:
|
||||
"""
|
||||
The main goal of this function getting value from the cache like in the "cache_with_key".
|
||||
A cache value can contain any data including the "None", so
|
||||
here used exception for case if value isn't found in the cache.
|
||||
"""
|
||||
def decorator(func: Callable[..., ReturnT]) -> (Callable[..., ReturnT]):
|
||||
def decorator(func: FuncT) -> FuncT:
|
||||
@wraps(func)
|
||||
def func_with_caching(*args: Any, **kwargs: Any) -> Callable[..., ReturnT]:
|
||||
def func_with_caching(*args: object, **kwargs: object) -> object:
|
||||
key = keyfunc(*args, **kwargs)
|
||||
try:
|
||||
val = cache_get(key, cache_name=cache_name)
|
||||
|
@ -148,14 +149,14 @@ def get_cache_with_key(
|
|||
return val[0]
|
||||
raise NotFoundInCache()
|
||||
|
||||
return func_with_caching
|
||||
return cast(FuncT, func_with_caching) # https://github.com/python/mypy/issues/1927
|
||||
|
||||
return decorator
|
||||
|
||||
def cache_with_key(
|
||||
keyfunc: Callable[..., str], cache_name: Optional[str]=None,
|
||||
timeout: Optional[int]=None, with_statsd_key: Optional[str]=None,
|
||||
) -> Callable[[Callable[..., ReturnT]], Callable[..., ReturnT]]:
|
||||
) -> Callable[[FuncT], FuncT]:
|
||||
"""Decorator which applies Django caching to a function.
|
||||
|
||||
Decorator argument is a function which computes a cache key
|
||||
|
@ -163,9 +164,9 @@ def cache_with_key(
|
|||
for avoiding collisions with other uses of this decorator or
|
||||
other uses of caching."""
|
||||
|
||||
def decorator(func: Callable[..., ReturnT]) -> Callable[..., ReturnT]:
|
||||
def decorator(func: FuncT) -> FuncT:
|
||||
@wraps(func)
|
||||
def func_with_caching(*args: Any, **kwargs: Any) -> ReturnT:
|
||||
def func_with_caching(*args: object, **kwargs: object) -> object:
|
||||
key = keyfunc(*args, **kwargs)
|
||||
|
||||
try:
|
||||
|
@ -198,7 +199,7 @@ def cache_with_key(
|
|||
|
||||
return val
|
||||
|
||||
return func_with_caching
|
||||
return cast(FuncT, func_with_caching) # https://github.com/python/mypy/issues/1927
|
||||
|
||||
return decorator
|
||||
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import cProfile
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, TypeVar
|
||||
from typing import Callable, TypeVar, cast
|
||||
|
||||
ReturnT = TypeVar('ReturnT')
|
||||
FuncT = TypeVar('FuncT', bound=Callable[..., object])
|
||||
|
||||
def profiled(func: Callable[..., ReturnT]) -> Callable[..., ReturnT]:
|
||||
def profiled(func: FuncT) -> FuncT:
|
||||
"""
|
||||
This decorator should obviously be used only in a dev environment.
|
||||
It works best when surrounding a function that you expect to be
|
||||
|
@ -21,11 +21,13 @@ def profiled(func: Callable[..., ReturnT]) -> Callable[..., ReturnT]:
|
|||
./tools/show-profile-results test_ratelimit_decrease.profile
|
||||
|
||||
"""
|
||||
func_: Callable[..., object] = func # work around https://github.com/python/mypy/issues/9075
|
||||
|
||||
@wraps(func)
|
||||
def wrapped_func(*args: Any, **kwargs: Any) -> ReturnT:
|
||||
def wrapped_func(*args: object, **kwargs: object) -> object:
|
||||
fn = func.__name__ + ".profile"
|
||||
prof = cProfile.Profile()
|
||||
retval: ReturnT = prof.runcall(func, *args, **kwargs)
|
||||
retval = prof.runcall(func_, *args, **kwargs)
|
||||
prof.dump_stats(fn)
|
||||
return retval
|
||||
return wrapped_func
|
||||
return cast(FuncT, wrapped_func) # https://github.com/python/mypy/issues/1927
|
||||
|
|
|
@ -2,7 +2,7 @@ import json
|
|||
import os
|
||||
import sys
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Set
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, TypeVar, cast
|
||||
|
||||
from zulip import Client
|
||||
|
||||
|
@ -12,27 +12,29 @@ from zerver.openapi.openapi import validate_against_openapi_schema
|
|||
|
||||
ZULIP_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
TEST_FUNCTIONS: Dict[str, Callable[..., None]] = dict()
|
||||
TEST_FUNCTIONS: Dict[str, Callable[..., object]] = dict()
|
||||
REGISTERED_TEST_FUNCTIONS: Set[str] = set()
|
||||
CALLED_TEST_FUNCTIONS: Set[str] = set()
|
||||
|
||||
def openapi_test_function(endpoint: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
FuncT = TypeVar("FuncT", bound=Callable[..., object])
|
||||
|
||||
def openapi_test_function(endpoint: str) -> Callable[[FuncT], FuncT]:
|
||||
"""This decorator is used to register an openapi test function with
|
||||
its endpoint. Example usage:
|
||||
|
||||
@openapi_test_function("/messages/render:post")
|
||||
def ...
|
||||
"""
|
||||
def wrapper(test_func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def wrapper(test_func: FuncT) -> FuncT:
|
||||
@wraps(test_func)
|
||||
def _record_calls_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
def _record_calls_wrapper(*args: object, **kwargs: object) -> object:
|
||||
CALLED_TEST_FUNCTIONS.add(test_func.__name__)
|
||||
return test_func(*args, **kwargs)
|
||||
|
||||
REGISTERED_TEST_FUNCTIONS.add(test_func.__name__)
|
||||
TEST_FUNCTIONS[endpoint] = _record_calls_wrapper
|
||||
|
||||
return _record_calls_wrapper
|
||||
return cast(FuncT, _record_calls_wrapper) # https://github.com/python/mypy/issues/1927
|
||||
return wrapper
|
||||
|
||||
def ensure_users(ids_list: List[int], user_names: List[str]) -> None:
|
||||
|
|
Loading…
Reference in New Issue