request: Refactor has_request_variables with ParamSpec.

This makes `has_request_variables` more generic, in the sense of the return
value, and also makes it more accurate, in the sense of requiring the
first parameter of the decorated function to be `HttpRequest`, and
preserving the function signature without using `cast`.

This affects some callers of `has_request_variables` or the callers of its
decoratedfunctions in the following manners:

- Decorated non-view functions called directly in other functions cannot
use `request` as a keyword argument. Becasue `Concatenate` turns the
concatenated parameters (`request: HttpRequest` in this case) into
positional-only parameters. Callers of `get_chart_data` are thus
refactored.

- Functions to be decorated that accept variadic keyword arguments must
define `request: HttpRequest` as positional-only. Mypy in strict mode
rejects such functions otherwise because it is possible for the caller to
pass a keyword argument that has the same name as `request` for `**kwargs`.
No defining `request: HttpRequest` as positional-only breaks type safety
because function with positional-or-keyword parameters cannot be considered
a subtype of a function with the same parameters in which some of them are
positional-only.

Consider `f(x: int, /, **kwargs: object) -> int` and `g(x: int,
**kwargs: object) -> int`. `f(12, x="asd")` is valid but `g(12, x="asd")`
is not.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
Zixuan James Li 2022-07-27 19:11:55 -04:00 committed by Tim Abbott
parent 9f99e6c43c
commit adae8b6d42
4 changed files with 25 additions and 18 deletions

View File

@ -131,7 +131,7 @@ def get_chart_data_for_realm(
except Realm.DoesNotExist:
raise JsonableError(_("Invalid organization"))
return get_chart_data(request=request, user_profile=user_profile, realm=realm, **kwargs)
return get_chart_data(request, user_profile=user_profile, realm=realm, **kwargs)
@require_server_admin_api
@ -147,7 +147,7 @@ def get_chart_data_for_remote_realm(
assert settings.ZILENCER_ENABLED
server = RemoteZulipServer.objects.get(id=remote_server_id)
return get_chart_data(
request=request,
request,
user_profile=user_profile,
server=server,
remote=True,
@ -179,9 +179,7 @@ def stats_for_remote_installation(request: HttpRequest, remote_server_id: int) -
def get_chart_data_for_installation(
request: HttpRequest, /, user_profile: UserProfile, chart_name: str = REQ(), **kwargs: Any
) -> HttpResponse:
return get_chart_data(
request=request, user_profile=user_profile, for_installation=True, **kwargs
)
return get_chart_data(request, user_profile=user_profile, for_installation=True, **kwargs)
@require_server_admin_api
@ -197,7 +195,7 @@ def get_chart_data_for_remote_installation(
assert settings.ZILENCER_ENABLED
server = RemoteZulipServer.objects.get(id=remote_server_id)
return get_chart_data(
request=request,
request,
user_profile=user_profile,
for_installation=True,
remote=True,

View File

@ -364,7 +364,7 @@ def webhook_view(
@has_request_variables
@wraps(view_func)
def _wrapped_func_arguments(
request: HttpRequest, api_key: str = REQ(), *args: object, **kwargs: object
request: HttpRequest, /, api_key: str = REQ(), *args: object, **kwargs: object
) -> HttpResponse:
user_profile = validate_api_key(
request,
@ -677,7 +677,7 @@ def authenticated_uploads_api_view(
@has_request_variables
@wraps(view_func)
def _wrapped_func_arguments(
request: HttpRequest, api_key: str = REQ(), *args: object, **kwargs: object
request: HttpRequest, /, api_key: str = REQ(), *args: object, **kwargs: object
) -> HttpResponse:
user_profile = validate_api_key(request, None, api_key, False)
if not skip_rate_limiting:

View File

@ -25,11 +25,12 @@ from django.conf import settings
from django.core.exceptions import ValidationError
from django.http import HttpRequest, HttpResponse
from django.utils.translation import gettext as _
from typing_extensions import Concatenate, ParamSpec
import zerver.lib.rate_limiter as rate_limiter
from zerver.lib.exceptions import ErrorCode, InvalidJSONError, JsonableError
from zerver.lib.notes import BaseNotes
from zerver.lib.types import Validator, ViewFuncT
from zerver.lib.types import Validator
from zerver.lib.validator import check_anything
from zerver.models import Client, Realm
@ -314,6 +315,9 @@ def REQ(
arguments_map: Dict[str, List[str]] = defaultdict(list)
ParamT = ParamSpec("ParamT")
ReturnT = TypeVar("ReturnT")
# Extracts variables from the request object and passes them as
# named function arguments. The request object must be the first
@ -331,17 +335,19 @@ arguments_map: Dict[str, List[str]] = defaultdict(list)
# Note that this can't be used in helper functions which are not
# expected to call json_success or raise JsonableError, as it uses JsonableError
# internally when it encounters an error
def has_request_variables(view_func: ViewFuncT) -> ViewFuncT:
num_params = view_func.__code__.co_argcount
default_param_values = cast(FunctionType, view_func).__defaults__
def has_request_variables(
req_func: Callable[Concatenate[HttpRequest, ParamT], ReturnT]
) -> Callable[Concatenate[HttpRequest, ParamT], ReturnT]:
num_params = req_func.__code__.co_argcount
default_param_values = cast(FunctionType, req_func).__defaults__
if default_param_values is None:
default_param_values = ()
num_default_params = len(default_param_values)
default_param_names = view_func.__code__.co_varnames[num_params - num_default_params :]
default_param_names = req_func.__code__.co_varnames[num_params - num_default_params :]
post_params = []
view_func_full_name = ".".join([view_func.__module__, view_func.__name__])
view_func_full_name = ".".join([req_func.__module__, req_func.__name__])
for (name, value) in zip(default_param_names, default_param_values):
if isinstance(value, _REQ):
@ -359,8 +365,10 @@ def has_request_variables(view_func: ViewFuncT) -> ViewFuncT:
):
arguments_map[view_func_full_name].append(value.post_var_name)
@wraps(view_func)
def _wrapped_view_func(request: HttpRequest, *args: object, **kwargs: object) -> HttpResponse:
@wraps(req_func)
def _wrapped_req_func(
request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs
) -> ReturnT:
request_notes = RequestNotes.get_notes(request)
for param in post_params:
func_var_name = param.func_var_name
@ -447,9 +455,9 @@ def has_request_variables(view_func: ViewFuncT) -> ViewFuncT:
kwargs[func_var_name] = val
return view_func(request, *args, **kwargs)
return req_func(request, *args, **kwargs)
return cast(ViewFuncT, _wrapped_view_func) # https://github.com/python/mypy/issues/1927
return _wrapped_req_func
local = threading.local()

View File

@ -754,6 +754,7 @@ class TwoFactorLoginView(BaseTwoFactorLoginView):
@has_request_variables
def login_page(
request: HttpRequest,
/,
next: str = REQ(default="/"),
**kwargs: Any,
) -> HttpResponse: