From 7e92ff9d0a5b6909010c1fcd6cf25a409cb5909b Mon Sep 17 00:00:00 2001 From: Anders Kaseorg Date: Wed, 7 Aug 2019 02:15:46 -0700 Subject: [PATCH] request: Delete request.pyi and make request.py type check. Signed-off-by: Anders Kaseorg --- zerver/lib/request.py | 102 ++++++++++++++++++++++++----------- zerver/lib/request.pyi | 35 ------------ zerver/tests/test_openapi.py | 6 +-- 3 files changed, 73 insertions(+), 70 deletions(-) delete mode 100644 zerver/lib/request.pyi diff --git a/zerver/lib/request.py b/zerver/lib/request.py index f9480ef065..c846220091 100644 --- a/zerver/lib/request.py +++ b/zerver/lib/request.py @@ -1,23 +1,17 @@ -# When adding new functions/classes to this file, you need to also add -# their types to request.pyi in this directory (a mypy stubs file that -# we use to ensure mypy does correct type inference with REQ, which it -# can't do by default due to the dynamic nature of REQ). -# -# Because request.pyi exists, the type annotations in this file are -# mostly not processed by mypy. from collections import defaultdict from functools import wraps +from types import FunctionType import ujson from django.utils.translation import ugettext as _ from zerver.lib.exceptions import JsonableError, ErrorCode, \ InvalidJSONError -from zerver.lib.types import ViewFuncT +from zerver.lib.types import Validator, ViewFuncT from django.http import HttpRequest, HttpResponse -from typing import Any, Callable, Dict, List, Optional, Type +from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union, cast class RequestConfusingParmsError(JsonableError): code = ErrorCode.REQUEST_CONFUSING_VAR @@ -55,7 +49,9 @@ class RequestVariableConversionError(JsonableError): return _("Bad value for '{var_name}': {bad_value}") # Used in conjunction with @has_request_variables, below -class REQ: +ResultT = TypeVar('ResultT') + +class _REQ(Generic[ResultT]): # NotSpecified is a sentinel value for determining whether a # default value was specified for a request variable. We can't # use None because that could be a valid, user-specified default @@ -63,13 +59,20 @@ class REQ: pass NotSpecified = _NotSpecified() - def __init__(self, whence: str=None, *, converter: Callable[[Any], Any]=None, - default: Any=NotSpecified, validator: Callable[[Any], Any]=None, - str_validator: Callable[[Any], Any]=None, - argument_type: str=None, type: Type=None, - intentionally_undocumented=False, - documentation_pending=False, - aliases: Optional[List[str]]=None) -> None: + def __init__( + self, + whence: Optional[str] = None, + *, + type: Type[ResultT] = Type[None], + converter: Optional[Callable[[str], ResultT]] = None, + default: Union[_NotSpecified, ResultT, None] = NotSpecified, + validator: Optional[Validator] = None, + str_validator: Optional[Validator] = None, + argument_type: Optional[str] = None, + intentionally_undocumented: bool=False, + documentation_pending: bool=False, + aliases: Optional[List[str]] = None + ) -> None: """whence: the name of the request variable that should be used for this parameter. Defaults to a request variable of the same name as the parameter. @@ -98,7 +101,7 @@ class REQ: """ self.post_var_name = whence - self.func_var_name = None # type: str + self.func_var_name = None # type: Optional[str] self.converter = converter self.validator = validator self.str_validator = str_validator @@ -115,6 +118,40 @@ class REQ: # Not user-facing, so shouldn't be tagged for translation raise AssertionError('validator and str_validator are mutually exclusive') +# This factory function ensures that mypy can correctly analyze REQ. +# +# Note that REQ claims to return a type matching that of the parameter +# of which it is the default value, allowing type checking of view +# functions using has_request_variables. In reality, REQ returns an +# instance of class _REQ to enable the decorator to scan the parameter +# list for _REQ objects and patch the parameters as the true types. + +def REQ( + whence: Optional[str] = None, + *, + type: Type[ResultT] = Type[None], + converter: Optional[Callable[[str], ResultT]] = None, + default: Union[_REQ._NotSpecified, ResultT, None] = _REQ.NotSpecified, + validator: Optional[Validator] = None, + str_validator: Optional[Validator] = None, + argument_type: Optional[str] = None, + intentionally_undocumented: bool=False, + documentation_pending: bool=False, + aliases: Optional[List[str]] = None +) -> ResultT: + return cast(ResultT, _REQ( + whence, + type=type, + converter=converter, + default=default, + validator=validator, + str_validator=str_validator, + argument_type=argument_type, + intentionally_undocumented=intentionally_undocumented, + documentation_pending=documentation_pending, + aliases=aliases, + )) + arguments_map = defaultdict(list) # type: Dict[str, List[str]] # Extracts variables from the request object and passes them as @@ -122,7 +159,7 @@ arguments_map = defaultdict(list) # type: Dict[str, List[str]] # argument to the function. # # To use, assign a function parameter a default value that is an -# instance of the REQ class. That parameter will then be automatically +# instance of the _REQ class. That parameter will then be automatically # populated from the HTTP request. The request object must be the # first argument to the decorated function. # @@ -135,21 +172,18 @@ arguments_map = defaultdict(list) # type: Dict[str, List[str]] # internally when it encounters an error def has_request_variables(view_func: ViewFuncT) -> ViewFuncT: num_params = view_func.__code__.co_argcount - if view_func.__defaults__ is None: - num_default_params = 0 - else: - num_default_params = len(view_func.__defaults__) - default_param_names = view_func.__code__.co_varnames[num_params - num_default_params:] - default_param_values = view_func.__defaults__ + default_param_values = cast(FunctionType, view_func).__defaults__ if default_param_values is None: - default_param_values = [] + default_param_values = () + num_default_params = len(default_param_values) + default_param_names = view_func.__code__.co_varnames[num_params - num_default_params:] post_params = [] view_func_full_name = '.'.join([view_func.__module__, view_func.__name__]) for (name, value) in zip(default_param_names, default_param_values): - if isinstance(value, REQ): + if isinstance(value, _REQ): value.func_var_name = name if value.post_var_name is None: value.post_var_name = name @@ -163,15 +197,17 @@ def has_request_variables(view_func: ViewFuncT) -> ViewFuncT: @wraps(view_func) def _wrapped_view_func(request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: for param in post_params: - if param.func_var_name in kwargs: + func_var_name = param.func_var_name + if func_var_name in kwargs: continue + assert func_var_name is not None if param.argument_type == 'body': try: val = ujson.loads(request.body) except ValueError: raise InvalidJSONError(_("Malformed JSON")) - kwargs[param.func_var_name] = val + kwargs[func_var_name] = val continue elif param.argument_type is not None: # This is a view bug, not a user error, and thus should throw a 500. @@ -194,12 +230,14 @@ def has_request_variables(view_func: ViewFuncT) -> ViewFuncT: except KeyError: continue if post_var_name is not None: + assert req_var is not None raise RequestConfusingParmsError(post_var_name, req_var) post_var_name = req_var if post_var_name is None: post_var_name = param.post_var_name - if param.default is REQ.NotSpecified: + assert post_var_name is not None + if param.default is _REQ.NotSpecified: raise RequestVariableMissingError(post_var_name) val = param.default default_assigned = True @@ -229,8 +267,8 @@ def has_request_variables(view_func: ViewFuncT) -> ViewFuncT: if error: raise JsonableError(error) - kwargs[param.func_var_name] = val + kwargs[func_var_name] = val return view_func(request, *args, **kwargs) - return _wrapped_view_func + return cast(ViewFuncT, _wrapped_view_func) diff --git a/zerver/lib/request.pyi b/zerver/lib/request.pyi deleted file mode 100644 index c0b7d210d2..0000000000 --- a/zerver/lib/request.pyi +++ /dev/null @@ -1,35 +0,0 @@ -# This mypy stubs file ensures that mypy can correctly analyze REQ. -# -# Note that here REQ is claimed to be a function, with a return type to match -# that of the parameter of which it is the default value, allowing type -# checking. However, in request.py, REQ is a class to enable the decorator to -# scan the parameter list for REQ objects and patch the parameters as the true -# types. - -from typing import Dict, Callable, List, TypeVar, Optional, Union, Type -from zerver.lib.types import ViewFuncT, Validator -from zerver.lib.exceptions import JsonableError as JsonableError - -ResultT = TypeVar('ResultT') - -class RequestConfusingParmsError(JsonableError): ... -class RequestVariableMissingError(JsonableError): ... -class RequestVariableConversionError(JsonableError): ... - -class _NotSpecified: ... -NotSpecified = _NotSpecified() - -def REQ(whence: Optional[str] = None, - *, - type: Type[ResultT] = Type[None], - converter: Optional[Callable[[str], ResultT]] = None, - default: Union[_NotSpecified, ResultT, None] = NotSpecified, - validator: Optional[Validator] = None, - str_validator: Optional[Validator] = None, - argument_type: Optional[str] = None, - intentionally_undocumented: bool=False, - documentation_pending: bool=False, - aliases: Optional[List[str]] = None) -> ResultT: ... - -def has_request_variables(view_func: ViewFuncT) -> ViewFuncT: ... -arguments_map = ... # type: Dict[str, List[str]] diff --git a/zerver/tests/test_openapi.py b/zerver/tests/test_openapi.py index e3320944c2..0754e1cf91 100644 --- a/zerver/tests/test_openapi.py +++ b/zerver/tests/test_openapi.py @@ -14,7 +14,7 @@ from django.http import HttpResponse import zerver.lib.openapi as openapi from zerver.lib.bugdown.api_code_examples import generate_curl_example, \ render_curl_example, parse_language_and_options -from zerver.lib.request import REQ +from zerver.lib.request import _REQ from zerver.lib.test_classes import ZulipTestCase from zerver.lib.openapi import ( get_openapi_fixture, get_openapi_parameters, @@ -414,7 +414,7 @@ do not match the types declared in the implementation of {}.\n""".format(functio # of its parameters. for vname, defval in inspect.signature(function).parameters.items(): defval = defval.default - if defval.__class__ == REQ: + if defval.__class__ is _REQ: # TODO: The below inference logic in cases where # there's a converter function declared is incorrect. # Theoretically, we could restructure the converter @@ -423,7 +423,7 @@ do not match the types declared in the implementation of {}.\n""".format(functio # possible. vtype = self.get_standardized_argument_type(function.__annotations__[vname]) - vname = defval.post_var_name # type: ignore # See zerver/lib/request.pyi + vname = defval.post_var_name # type: ignore # See zerver/lib/request.py function_params.add((vname, vtype)) diff = openapi_params - function_params