mirror of https://github.com/zulip/zulip.git
request: Refactor has_request_variables with ParamSpec.
This makes `has_request_variables` more generic, in the sense of the return value, and also makes it more accurate, in the sense of requiring the first parameter of the decorated function to be `HttpRequest`, and preserving the function signature without using `cast`. This affects some callers of `has_request_variables` or the callers of its decoratedfunctions in the following manners: - Decorated non-view functions called directly in other functions cannot use `request` as a keyword argument. Becasue `Concatenate` turns the concatenated parameters (`request: HttpRequest` in this case) into positional-only parameters. Callers of `get_chart_data` are thus refactored. - Functions to be decorated that accept variadic keyword arguments must define `request: HttpRequest` as positional-only. Mypy in strict mode rejects such functions otherwise because it is possible for the caller to pass a keyword argument that has the same name as `request` for `**kwargs`. No defining `request: HttpRequest` as positional-only breaks type safety because function with positional-or-keyword parameters cannot be considered a subtype of a function with the same parameters in which some of them are positional-only. Consider `f(x: int, /, **kwargs: object) -> int` and `g(x: int, **kwargs: object) -> int`. `f(12, x="asd")` is valid but `g(12, x="asd")` is not. Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
parent
9f99e6c43c
commit
adae8b6d42
|
@ -131,7 +131,7 @@ def get_chart_data_for_realm(
|
|||
except Realm.DoesNotExist:
|
||||
raise JsonableError(_("Invalid organization"))
|
||||
|
||||
return get_chart_data(request=request, user_profile=user_profile, realm=realm, **kwargs)
|
||||
return get_chart_data(request, user_profile=user_profile, realm=realm, **kwargs)
|
||||
|
||||
|
||||
@require_server_admin_api
|
||||
|
@ -147,7 +147,7 @@ def get_chart_data_for_remote_realm(
|
|||
assert settings.ZILENCER_ENABLED
|
||||
server = RemoteZulipServer.objects.get(id=remote_server_id)
|
||||
return get_chart_data(
|
||||
request=request,
|
||||
request,
|
||||
user_profile=user_profile,
|
||||
server=server,
|
||||
remote=True,
|
||||
|
@ -179,9 +179,7 @@ def stats_for_remote_installation(request: HttpRequest, remote_server_id: int) -
|
|||
def get_chart_data_for_installation(
|
||||
request: HttpRequest, /, user_profile: UserProfile, chart_name: str = REQ(), **kwargs: Any
|
||||
) -> HttpResponse:
|
||||
return get_chart_data(
|
||||
request=request, user_profile=user_profile, for_installation=True, **kwargs
|
||||
)
|
||||
return get_chart_data(request, user_profile=user_profile, for_installation=True, **kwargs)
|
||||
|
||||
|
||||
@require_server_admin_api
|
||||
|
@ -197,7 +195,7 @@ def get_chart_data_for_remote_installation(
|
|||
assert settings.ZILENCER_ENABLED
|
||||
server = RemoteZulipServer.objects.get(id=remote_server_id)
|
||||
return get_chart_data(
|
||||
request=request,
|
||||
request,
|
||||
user_profile=user_profile,
|
||||
for_installation=True,
|
||||
remote=True,
|
||||
|
|
|
@ -364,7 +364,7 @@ def webhook_view(
|
|||
@has_request_variables
|
||||
@wraps(view_func)
|
||||
def _wrapped_func_arguments(
|
||||
request: HttpRequest, api_key: str = REQ(), *args: object, **kwargs: object
|
||||
request: HttpRequest, /, api_key: str = REQ(), *args: object, **kwargs: object
|
||||
) -> HttpResponse:
|
||||
user_profile = validate_api_key(
|
||||
request,
|
||||
|
@ -677,7 +677,7 @@ def authenticated_uploads_api_view(
|
|||
@has_request_variables
|
||||
@wraps(view_func)
|
||||
def _wrapped_func_arguments(
|
||||
request: HttpRequest, api_key: str = REQ(), *args: object, **kwargs: object
|
||||
request: HttpRequest, /, api_key: str = REQ(), *args: object, **kwargs: object
|
||||
) -> HttpResponse:
|
||||
user_profile = validate_api_key(request, None, api_key, False)
|
||||
if not skip_rate_limiting:
|
||||
|
|
|
@ -25,11 +25,12 @@ from django.conf import settings
|
|||
from django.core.exceptions import ValidationError
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.utils.translation import gettext as _
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
import zerver.lib.rate_limiter as rate_limiter
|
||||
from zerver.lib.exceptions import ErrorCode, InvalidJSONError, JsonableError
|
||||
from zerver.lib.notes import BaseNotes
|
||||
from zerver.lib.types import Validator, ViewFuncT
|
||||
from zerver.lib.types import Validator
|
||||
from zerver.lib.validator import check_anything
|
||||
from zerver.models import Client, Realm
|
||||
|
||||
|
@ -314,6 +315,9 @@ def REQ(
|
|||
|
||||
|
||||
arguments_map: Dict[str, List[str]] = defaultdict(list)
|
||||
ParamT = ParamSpec("ParamT")
|
||||
ReturnT = TypeVar("ReturnT")
|
||||
|
||||
|
||||
# Extracts variables from the request object and passes them as
|
||||
# named function arguments. The request object must be the first
|
||||
|
@ -331,17 +335,19 @@ arguments_map: Dict[str, List[str]] = defaultdict(list)
|
|||
# Note that this can't be used in helper functions which are not
|
||||
# expected to call json_success or raise JsonableError, as it uses JsonableError
|
||||
# internally when it encounters an error
|
||||
def has_request_variables(view_func: ViewFuncT) -> ViewFuncT:
|
||||
num_params = view_func.__code__.co_argcount
|
||||
default_param_values = cast(FunctionType, view_func).__defaults__
|
||||
def has_request_variables(
|
||||
req_func: Callable[Concatenate[HttpRequest, ParamT], ReturnT]
|
||||
) -> Callable[Concatenate[HttpRequest, ParamT], ReturnT]:
|
||||
num_params = req_func.__code__.co_argcount
|
||||
default_param_values = cast(FunctionType, req_func).__defaults__
|
||||
if default_param_values is None:
|
||||
default_param_values = ()
|
||||
num_default_params = len(default_param_values)
|
||||
default_param_names = view_func.__code__.co_varnames[num_params - num_default_params :]
|
||||
default_param_names = req_func.__code__.co_varnames[num_params - num_default_params :]
|
||||
|
||||
post_params = []
|
||||
|
||||
view_func_full_name = ".".join([view_func.__module__, view_func.__name__])
|
||||
view_func_full_name = ".".join([req_func.__module__, req_func.__name__])
|
||||
|
||||
for (name, value) in zip(default_param_names, default_param_values):
|
||||
if isinstance(value, _REQ):
|
||||
|
@ -359,8 +365,10 @@ def has_request_variables(view_func: ViewFuncT) -> ViewFuncT:
|
|||
):
|
||||
arguments_map[view_func_full_name].append(value.post_var_name)
|
||||
|
||||
@wraps(view_func)
|
||||
def _wrapped_view_func(request: HttpRequest, *args: object, **kwargs: object) -> HttpResponse:
|
||||
@wraps(req_func)
|
||||
def _wrapped_req_func(
|
||||
request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs
|
||||
) -> ReturnT:
|
||||
request_notes = RequestNotes.get_notes(request)
|
||||
for param in post_params:
|
||||
func_var_name = param.func_var_name
|
||||
|
@ -447,9 +455,9 @@ def has_request_variables(view_func: ViewFuncT) -> ViewFuncT:
|
|||
|
||||
kwargs[func_var_name] = val
|
||||
|
||||
return view_func(request, *args, **kwargs)
|
||||
return req_func(request, *args, **kwargs)
|
||||
|
||||
return cast(ViewFuncT, _wrapped_view_func) # https://github.com/python/mypy/issues/1927
|
||||
return _wrapped_req_func
|
||||
|
||||
|
||||
local = threading.local()
|
||||
|
|
|
@ -754,6 +754,7 @@ class TwoFactorLoginView(BaseTwoFactorLoginView):
|
|||
@has_request_variables
|
||||
def login_page(
|
||||
request: HttpRequest,
|
||||
/,
|
||||
next: str = REQ(default="/"),
|
||||
**kwargs: Any,
|
||||
) -> HttpResponse:
|
||||
|
|
Loading…
Reference in New Issue