From adae8b6d422c224613fa747adddc5028d2ea90cd Mon Sep 17 00:00:00 2001 From: Zixuan James Li Date: Wed, 27 Jul 2022 19:11:55 -0400 Subject: [PATCH] request: Refactor has_request_variables with ParamSpec. This makes `has_request_variables` more generic, in the sense of the return value, and also makes it more accurate, in the sense of requiring the first parameter of the decorated function to be `HttpRequest`, and preserving the function signature without using `cast`. This affects some callers of `has_request_variables` or the callers of its decoratedfunctions in the following manners: - Decorated non-view functions called directly in other functions cannot use `request` as a keyword argument. Becasue `Concatenate` turns the concatenated parameters (`request: HttpRequest` in this case) into positional-only parameters. Callers of `get_chart_data` are thus refactored. - Functions to be decorated that accept variadic keyword arguments must define `request: HttpRequest` as positional-only. Mypy in strict mode rejects such functions otherwise because it is possible for the caller to pass a keyword argument that has the same name as `request` for `**kwargs`. No defining `request: HttpRequest` as positional-only breaks type safety because function with positional-or-keyword parameters cannot be considered a subtype of a function with the same parameters in which some of them are positional-only. Consider `f(x: int, /, **kwargs: object) -> int` and `g(x: int, **kwargs: object) -> int`. `f(12, x="asd")` is valid but `g(12, x="asd")` is not. Signed-off-by: Zixuan James Li --- analytics/views/stats.py | 10 ++++------ zerver/decorator.py | 4 ++-- zerver/lib/request.py | 28 ++++++++++++++++++---------- zerver/views/auth.py | 1 + 4 files changed, 25 insertions(+), 18 deletions(-) diff --git a/analytics/views/stats.py b/analytics/views/stats.py index d5f878e181..e57e751c1f 100644 --- a/analytics/views/stats.py +++ b/analytics/views/stats.py @@ -131,7 +131,7 @@ def get_chart_data_for_realm( except Realm.DoesNotExist: raise JsonableError(_("Invalid organization")) - return get_chart_data(request=request, user_profile=user_profile, realm=realm, **kwargs) + return get_chart_data(request, user_profile=user_profile, realm=realm, **kwargs) @require_server_admin_api @@ -147,7 +147,7 @@ def get_chart_data_for_remote_realm( assert settings.ZILENCER_ENABLED server = RemoteZulipServer.objects.get(id=remote_server_id) return get_chart_data( - request=request, + request, user_profile=user_profile, server=server, remote=True, @@ -179,9 +179,7 @@ def stats_for_remote_installation(request: HttpRequest, remote_server_id: int) - def get_chart_data_for_installation( request: HttpRequest, /, user_profile: UserProfile, chart_name: str = REQ(), **kwargs: Any ) -> HttpResponse: - return get_chart_data( - request=request, user_profile=user_profile, for_installation=True, **kwargs - ) + return get_chart_data(request, user_profile=user_profile, for_installation=True, **kwargs) @require_server_admin_api @@ -197,7 +195,7 @@ def get_chart_data_for_remote_installation( assert settings.ZILENCER_ENABLED server = RemoteZulipServer.objects.get(id=remote_server_id) return get_chart_data( - request=request, + request, user_profile=user_profile, for_installation=True, remote=True, diff --git a/zerver/decorator.py b/zerver/decorator.py index c604a8013f..0baafddb6c 100644 --- a/zerver/decorator.py +++ b/zerver/decorator.py @@ -364,7 +364,7 @@ def webhook_view( @has_request_variables @wraps(view_func) def _wrapped_func_arguments( - request: HttpRequest, api_key: str = REQ(), *args: object, **kwargs: object + request: HttpRequest, /, api_key: str = REQ(), *args: object, **kwargs: object ) -> HttpResponse: user_profile = validate_api_key( request, @@ -677,7 +677,7 @@ def authenticated_uploads_api_view( @has_request_variables @wraps(view_func) def _wrapped_func_arguments( - request: HttpRequest, api_key: str = REQ(), *args: object, **kwargs: object + request: HttpRequest, /, api_key: str = REQ(), *args: object, **kwargs: object ) -> HttpResponse: user_profile = validate_api_key(request, None, api_key, False) if not skip_rate_limiting: diff --git a/zerver/lib/request.py b/zerver/lib/request.py index 1bd834782f..0da32bd694 100644 --- a/zerver/lib/request.py +++ b/zerver/lib/request.py @@ -25,11 +25,12 @@ from django.conf import settings from django.core.exceptions import ValidationError from django.http import HttpRequest, HttpResponse from django.utils.translation import gettext as _ +from typing_extensions import Concatenate, ParamSpec import zerver.lib.rate_limiter as rate_limiter from zerver.lib.exceptions import ErrorCode, InvalidJSONError, JsonableError from zerver.lib.notes import BaseNotes -from zerver.lib.types import Validator, ViewFuncT +from zerver.lib.types import Validator from zerver.lib.validator import check_anything from zerver.models import Client, Realm @@ -314,6 +315,9 @@ def REQ( arguments_map: Dict[str, List[str]] = defaultdict(list) +ParamT = ParamSpec("ParamT") +ReturnT = TypeVar("ReturnT") + # Extracts variables from the request object and passes them as # named function arguments. The request object must be the first @@ -331,17 +335,19 @@ arguments_map: Dict[str, List[str]] = defaultdict(list) # Note that this can't be used in helper functions which are not # expected to call json_success or raise JsonableError, as it uses JsonableError # internally when it encounters an error -def has_request_variables(view_func: ViewFuncT) -> ViewFuncT: - num_params = view_func.__code__.co_argcount - default_param_values = cast(FunctionType, view_func).__defaults__ +def has_request_variables( + req_func: Callable[Concatenate[HttpRequest, ParamT], ReturnT] +) -> Callable[Concatenate[HttpRequest, ParamT], ReturnT]: + num_params = req_func.__code__.co_argcount + default_param_values = cast(FunctionType, req_func).__defaults__ if default_param_values is None: default_param_values = () num_default_params = len(default_param_values) - default_param_names = view_func.__code__.co_varnames[num_params - num_default_params :] + default_param_names = req_func.__code__.co_varnames[num_params - num_default_params :] post_params = [] - view_func_full_name = ".".join([view_func.__module__, view_func.__name__]) + view_func_full_name = ".".join([req_func.__module__, req_func.__name__]) for (name, value) in zip(default_param_names, default_param_values): if isinstance(value, _REQ): @@ -359,8 +365,10 @@ def has_request_variables(view_func: ViewFuncT) -> ViewFuncT: ): arguments_map[view_func_full_name].append(value.post_var_name) - @wraps(view_func) - def _wrapped_view_func(request: HttpRequest, *args: object, **kwargs: object) -> HttpResponse: + @wraps(req_func) + def _wrapped_req_func( + request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs + ) -> ReturnT: request_notes = RequestNotes.get_notes(request) for param in post_params: func_var_name = param.func_var_name @@ -447,9 +455,9 @@ def has_request_variables(view_func: ViewFuncT) -> ViewFuncT: kwargs[func_var_name] = val - return view_func(request, *args, **kwargs) + return req_func(request, *args, **kwargs) - return cast(ViewFuncT, _wrapped_view_func) # https://github.com/python/mypy/issues/1927 + return _wrapped_req_func local = threading.local() diff --git a/zerver/views/auth.py b/zerver/views/auth.py index f642546179..f0b3dc85bb 100644 --- a/zerver/views/auth.py +++ b/zerver/views/auth.py @@ -754,6 +754,7 @@ class TwoFactorLoginView(BaseTwoFactorLoginView): @has_request_variables def login_page( request: HttpRequest, + /, next: str = REQ(default="/"), **kwargs: Any, ) -> HttpResponse: