mypy: Improve typing of request.pyi (REQ).

This expands request.pyi to type-check the arguments passed into REQ.

Tweaked by tabbott to fix the report.py annotations.
This commit is contained in:
neiljp (Neil Pilgrim) 2017-11-02 21:13:04 -07:00 committed by Tim Abbott
parent dd1920c811
commit 42f5eea61f
4 changed files with 41 additions and 17 deletions

View File

@ -1,14 +1,31 @@
# This mypy stubs file ensures that mypy can correctly analyze REQ.
from typing import Any, Callable, Text, TypeVar
#
# Note that here REQ is claimed to be a function, with a return type to match
# that of the parameter of which it is the default value, allowing type
# checking. However, in request.py, REQ is a class to enable the decorator to
# scan the parameter list for REQ objects and patch the parameters as the true
# types.
from typing import Any, Callable, Text, TypeVar, Optional, Union
from django.http import HttpResponse
from zerver.lib.exceptions import JsonableError as JsonableError
Validator = Callable[[str, Any], Optional[str]]
ResultT = TypeVar('ResultT')
ViewFuncT = TypeVar('ViewFuncT', bound=Callable[..., HttpResponse])
class RequestVariableMissingError(JsonableError): ...
class RequestVariableConversionError(JsonableError): ...
def REQ(*args: Any, **kwargs: Any) -> Any: ...
class _NotSpecified: ...
NotSpecified = _NotSpecified()
def REQ(whence: Optional[str] = None,
converter: Optional[Callable[[str], ResultT]] = None,
default: Union[_NotSpecified, ResultT] = NotSpecified,
validator: Optional[Validator] = None,
argument_type: Optional[str] = None) -> ResultT: ...
def has_request_variables(view_func: ViewFuncT) -> ViewFuncT: ...

View File

@ -102,13 +102,13 @@ class DecoratorTestCase(TestCase):
# type: () -> None
def my_converter(data):
# type: (str) -> List[str]
# type: (str) -> List[int]
lst = ujson.loads(data)
if not isinstance(lst, list):
raise ValueError('not a list')
if 13 in lst:
raise JsonableError('13 is an unlucky number!')
return lst
return [int(elem) for elem in lst]
@has_request_variables
def get_total(request, numbers=REQ(converter=my_converter)):
@ -148,7 +148,7 @@ class DecoratorTestCase(TestCase):
with self.assertRaisesRegex(AssertionError, "converter and validator are mutually exclusive"):
@has_request_variables
def get_total(request, numbers=REQ(validator=check_list(check_int),
converter=lambda: None)):
converter=lambda x: [])):
# type: (HttpRequest, Iterable[int]) -> int
return sum(numbers) # nocoverage -- isn't intended to be run

View File

@ -1,6 +1,6 @@
# System documented in https://zulip.readthedocs.io/en/latest/logging.html
from typing import Any, Dict, Optional, Text
from typing import Any, Dict, Optional, Text, Union
from django.conf import settings
from django.http import HttpRequest, HttpResponse
@ -33,20 +33,27 @@ def get_js_source_map() -> Optional[SourceMap]:
@human_users_only
@has_request_variables
def report_send_times(request, user_profile,
time=REQ(converter=to_non_negative_int),
received=REQ(converter=to_non_negative_int, default="(unknown)"),
displayed=REQ(converter=to_non_negative_int, default="(unknown)"),
locally_echoed=REQ(validator=check_bool, default=False),
rendered_content_disparity=REQ(validator=check_bool, default=False)):
# type: (HttpRequest, UserProfile, int, int, int, bool, bool) -> HttpResponse
def report_send_times(request: HttpRequest, user_profile: UserProfile,
time: int=REQ(converter=to_non_negative_int),
received: int=REQ(converter=to_non_negative_int, default=-1),
displayed: int=REQ(converter=to_non_negative_int, default=-1),
locally_echoed: bool=REQ(validator=check_bool, default=False),
rendered_content_disparity: bool=REQ(validator=check_bool, default=False)) -> HttpResponse:
received_str = "(unknown)"
if received > 0:
received_str = str(received)
displayed_str = "(unknown)"
if displayed > 0:
displayed_str = str(displayed)
request._log_data["extra"] = "[%sms/%sms/%sms/echo:%s/diff:%s]" \
% (time, received, displayed, locally_echoed, rendered_content_disparity)
% (time, received_str, displayed_str, locally_echoed, rendered_content_disparity)
base_key = statsd_key(user_profile.realm.string_id, clean_periods=True)
statsd.timing("endtoend.send_time.%s" % (base_key,), time)
if received != "(unknown)":
if received > 0:
statsd.timing("endtoend.receive_time.%s" % (base_key,), received)
if displayed != "(unknown)":
if displayed > 0:
statsd.timing("endtoend.displayed_time.%s" % (base_key,), displayed)
if locally_echoed:
statsd.incr('locally_echoed')

View File

@ -32,7 +32,7 @@ def api_travis_webhook(request, user_profile,
('status_message', check_string),
('compare_url', check_string),
]))):
# type: (HttpRequest, UserProfile, str, str, str, Dict[str, str]) -> HttpResponse
# type: (HttpRequest, UserProfile, str, str, bool, Dict[str, str]) -> HttpResponse
message_status = message['status_message']
if ignore_pull_requests and message['type'] == 'pull_request':