request: Tighten type checking on REQ.

Then, find and fix a predictable number of previous misuses.

With a small change by tabbott to preserve backwards compatibility for
sending `yes` for the `forged` field.

Signed-off-by: Anders Kaseorg <anders@zulipchat.com>
This commit is contained in:
Anders Kaseorg 2019-11-12 23:17:49 -08:00 committed by Tim Abbott
parent b0a7b33f9b
commit cafac83676
15 changed files with 151 additions and 38 deletions

View File

@ -34,7 +34,7 @@ def unsign_seat_count(signed_seat_count: str, salt: str) -> int:
raise BillingError('tampered seat count')
def check_upgrade_parameters(
billing_modality: str, schedule: str, license_management: str, licenses: Optional[int],
billing_modality: str, schedule: str, license_management: Optional[str], licenses: Optional[int],
has_stripe_token: bool, seat_count: int) -> None:
if billing_modality not in ['send_invoice', 'charge_automatically']:
raise BillingError('unknown billing_modality')
@ -74,9 +74,9 @@ def payment_method_string(stripe_customer: stripe.Customer) -> str:
def upgrade(request: HttpRequest, user: UserProfile,
billing_modality: str=REQ(validator=check_string),
schedule: str=REQ(validator=check_string),
license_management: str=REQ(validator=check_string, default=None),
licenses: int=REQ(validator=check_int, default=None),
stripe_token: str=REQ(validator=check_string, default=None),
license_management: Optional[str]=REQ(validator=check_string, default=None),
licenses: Optional[int]=REQ(validator=check_int, default=None),
stripe_token: Optional[str]=REQ(validator=check_string, default=None),
signed_seat_count: str=REQ(validator=check_string),
salt: str=REQ(validator=check_string)) -> HttpResponse:
try:
@ -89,6 +89,7 @@ def upgrade(request: HttpRequest, user: UserProfile,
check_upgrade_parameters(
billing_modality, schedule, license_management, licenses,
stripe_token is not None, seat_count)
assert licenses is not None
automanage_licenses = license_management == 'automatic'
billing_schedule = {'annual': CustomerPlan.ANNUAL,

View File

@ -15,6 +15,8 @@ exclude_lines =
raise UnexpectedWebhookEventType
# Don't require coverage for blocks only run when type-checking
if TYPE_CHECKING:
# PEP 484 overloading syntax
^\s*\.\.\.
[run]
omit =

View File

@ -11,7 +11,8 @@ from zerver.lib.types import Validator, ViewFuncT
from django.http import HttpRequest, HttpResponse
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union, cast
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union, cast, overload
from typing_extensions import Literal
class RequestConfusingParmsError(JsonableError):
code = ErrorCode.REQUEST_CONFUSING_VAR
@ -131,12 +132,89 @@ class _REQ(Generic[ResultT]):
# instance of class _REQ to enable the decorator to scan the parameter
# list for _REQ objects and patch the parameters as the true types.
# Overload 1: converter
@overload
def REQ(
whence: Optional[str] = ...,
*,
type: Type[ResultT] = ...,
converter: Callable[[str], ResultT],
default: ResultT = ...,
intentionally_undocumented: bool = ...,
documentation_pending: bool = ...,
aliases: Optional[List[str]] = ...,
path_only: bool = ...
) -> ResultT:
...
# Overload 2: validator
@overload
def REQ(
whence: Optional[str] = ...,
*,
type: Type[ResultT] = ...,
default: ResultT = ...,
validator: Validator,
intentionally_undocumented: bool = ...,
documentation_pending: bool = ...,
aliases: Optional[List[str]] = ...,
path_only: bool = ...
) -> ResultT:
...
# Overload 3: no converter/validator, default: str or unspecified, argument_type=None
@overload
def REQ(
whence: Optional[str] = ...,
*,
type: Type[str] = ...,
default: str = ...,
str_validator: Optional[Validator] = ...,
intentionally_undocumented: bool = ...,
documentation_pending: bool = ...,
aliases: Optional[List[str]] = ...,
path_only: bool = ...
) -> str:
...
# Overload 4: no converter/validator, default=None, argument_type=None
@overload
def REQ(
whence: Optional[str] = ...,
*,
type: Type[str] = ...,
default: None,
str_validator: Optional[Validator] = ...,
intentionally_undocumented: bool = ...,
documentation_pending: bool = ...,
aliases: Optional[List[str]] = ...,
path_only: bool = ...
) -> Optional[str]:
...
# Overload 5: argument_type="body"
@overload
def REQ(
whence: Optional[str] = ...,
*,
type: Type[ResultT] = ...,
default: ResultT = ...,
str_validator: Optional[Validator] = ...,
argument_type: Literal["body"],
intentionally_undocumented: bool = ...,
documentation_pending: bool = ...,
aliases: Optional[List[str]] = ...,
path_only: bool = ...
) -> ResultT:
...
# Implementation
def REQ(
whence: Optional[str] = None,
*,
type: Type[ResultT] = Type[None],
converter: Optional[Callable[[str], ResultT]] = None,
default: Union[_REQ._NotSpecified, ResultT, None] = _REQ.NotSpecified,
default: Union[_REQ._NotSpecified, ResultT] = _REQ.NotSpecified,
validator: Optional[Validator] = None,
str_validator: Optional[Validator] = None,
argument_type: Optional[str] = None,

View File

@ -181,7 +181,7 @@ class DecoratorTestCase(TestCase):
def test_REQ_converter_and_validator_invalid(self) -> None:
with self.assertRaisesRegex(AssertionError, "converter and validator are mutually exclusive"):
@has_request_variables
def get_total(request: HttpRequest,
def get_total(request: HttpRequest, # type: ignore # The condition being tested is in fact an error.
numbers: Iterable[int]=REQ(validator=check_list(check_int),
converter=lambda x: [])) -> int:
return sum(numbers) # nocoverage -- isn't intended to be run
@ -263,7 +263,7 @@ class DecoratorTestCase(TestCase):
# Test we properly handle an invalid argument_type.
with self.assertRaises(Exception) as cm:
@has_request_variables
def test(request: HttpRequest,
def test(request: HttpRequest, # type: ignore # The condition being tested is in fact an error.
payload: Any=REQ(argument_type="invalid")) -> None:
# Any is ok; exception should occur in decorator:
pass # nocoverage # this function isn't meant to be called

View File

@ -17,7 +17,7 @@ class TestIntegrationsDevPanel(ZulipTestCase):
"url": url,
"body": body,
"custom_headers": "{}",
"is_json": True
"is_json": "true"
}
response = self.client_post(target_url, data)
@ -37,7 +37,7 @@ class TestIntegrationsDevPanel(ZulipTestCase):
"url": url,
"body": body,
"custom_headers": "{}",
"is_json": True
"is_json": "true"
}
response = self.client_post(target_url, data)
@ -64,7 +64,7 @@ class TestIntegrationsDevPanel(ZulipTestCase):
"url": url,
"body": body,
"custom_headers": ujson.dumps({"X_GITHUB_EVENT": "ping"}),
"is_json": True
"is_json": "true"
}
response = self.client_post(target_url, data)
@ -87,7 +87,7 @@ class TestIntegrationsDevPanel(ZulipTestCase):
"url": url,
"body": body,
"custom_headers": ujson.dumps({"Content-Type": "application/x-www-form-urlencoded"}),
"is_json": False,
"is_json": "false",
}
response = self.client_post(target_url, data)

View File

@ -1919,7 +1919,7 @@ class MessagePOSTTest(ZulipTestCase):
"client": "test suite",
"content": "Test message",
"topic": "Test topic",
"forged": True})
"forged": "true"})
self.assert_json_error(result, "User not authorized for this query")
def test_send_message_as_not_superuser_to_different_domain(self) -> None:
@ -2047,6 +2047,23 @@ class MessagePOSTTest(ZulipTestCase):
msg = self.get_last_message()
self.assertEqual(int(datetime_to_timestamp(msg.date_sent)), int(fake_timestamp))
# Now test again using forged=yes
fake_date_sent = timezone_now() - datetime.timedelta(minutes=22)
fake_timestamp = datetime_to_timestamp(fake_date_sent)
result = self.api_post(email, "/api/v1/messages", {"type": "stream",
"forged": "yes",
"time": fake_timestamp,
"sender": "irc-user@irc.zulip.com",
"content": "Test message",
"client": "irc_mirror",
"topic": "from irc",
"to": "IRCLand"})
self.assert_json_success(result)
msg = self.get_last_message()
self.assertEqual(int(datetime_to_timestamp(msg.date_sent)), int(fake_timestamp))
def test_unsubscribed_api_super_user(self) -> None:
cordelia = self.example_user('cordelia')
stream_name = 'private_stream'

View File

@ -10,7 +10,7 @@ from zerver.decorator import REQ, RespondAsynchronously, \
_RespondAsynchronously, asynchronous, to_non_negative_int, \
has_request_variables, internal_notify_view, process_client
from zerver.lib.response import json_error, json_success
from zerver.lib.validator import check_bool, check_list, check_string
from zerver.lib.validator import check_bool, check_int, check_list, check_string
from zerver.models import Client, UserProfile, get_client, get_user_profile_by_id
from zerver.tornado.event_queue import fetch_events, \
get_client_descriptor, process_notification
@ -36,8 +36,11 @@ def cleanup_event_queue(request: HttpRequest, user_profile: UserProfile,
@asynchronous
@internal_notify_view(True)
@has_request_variables
def get_events_internal(request: HttpRequest, handler: BaseHandler,
user_profile_id: int=REQ()) -> Union[HttpResponse, _RespondAsynchronously]:
def get_events_internal(
request: HttpRequest,
handler: BaseHandler,
user_profile_id: int = REQ(validator=check_int),
) -> Union[HttpResponse, _RespondAsynchronously]:
user_profile = get_user_profile_by_id(user_profile_id)
request._email = user_profile.email
process_client(request, user_profile, client_name="internal")

View File

@ -10,6 +10,7 @@ from zerver.lib.integrations import WEBHOOK_INTEGRATIONS
from zerver.lib.request import has_request_variables, REQ
from zerver.lib.response import json_success, json_error
from zerver.models import UserProfile, get_realm
from zerver.lib.validator import check_bool
from zerver.lib.webhooks.common import get_fixture_http_headers, \
standardize_headers
@ -27,10 +28,10 @@ def dev_panel(request: HttpRequest) -> HttpResponse:
context = {"integrations": integrations, "bots": bots}
return render(request, "zerver/integrations/development/dev_panel.html", context)
def send_webhook_fixture_message(url: str=REQ(),
body: str=REQ(),
is_json: bool=REQ(),
custom_headers: Dict[str, Any]=REQ()) -> HttpResponse:
def send_webhook_fixture_message(url: str,
body: str,
is_json: bool,
custom_headers: Dict[str, Any]) -> HttpResponse:
client = Client()
realm = get_realm("zulip")
standardized_headers = standardize_headers(custom_headers)
@ -85,7 +86,7 @@ def get_fixtures(request: HttpResponse,
def check_send_webhook_fixture_message(request: HttpRequest,
url: str=REQ(),
body: str=REQ(),
is_json: bool=REQ(),
is_json: bool=REQ(validator=check_bool),
custom_headers: str=REQ()) -> HttpResponse:
try:
custom_headers_dict = ujson.loads(custom_headers)

View File

@ -174,7 +174,7 @@ class IntegrationView(ApiURLView):
@has_request_variables
def integration_doc(request: HttpRequest, integration_name: str=REQ(default=None)) -> HttpResponse:
def integration_doc(request: HttpRequest, integration_name: str=REQ()) -> HttpResponse:
if not request.is_ajax():
return HttpResponseNotFound()
try:

View File

@ -1276,8 +1276,9 @@ def send_message_backend(request: HttpRequest, user_profile: UserProfile,
message_to: Union[Sequence[int], Sequence[str]]=REQ(
'to', type=Union[List[int], List[str]],
converter=extract_recipients, default=[]),
forged: bool=REQ(default=False,
documentation_pending=True),
forged_str: Optional[str]=REQ("forged",
default=None,
documentation_pending=True),
topic_name: Optional[str]=REQ_topic(),
message_content: str=REQ('content'),
widget_content: Optional[str]=REQ(default=None,
@ -1295,6 +1296,10 @@ def send_message_backend(request: HttpRequest, user_profile: UserProfile,
tz_guess: Optional[str]=REQ('tz_guess', default=None,
documentation_pending=True)
) -> HttpResponse:
# Temporary hack: We're transitioning `forged` from accepting
# `yes` to accepting `true` like all of our normal booleans.
forged = forged_str is not None and forged_str in ["yes", "true"]
client = request.client
is_super_user = request.user.is_api_super_user
if forged and not is_super_user:

View File

@ -90,9 +90,9 @@ def create_default_stream_group(request: HttpRequest, user_profile: UserProfile,
@require_realm_admin
@has_request_variables
def update_default_stream_group_info(request: HttpRequest, user_profile: UserProfile, group_id: int,
new_group_name: str=REQ(validator=check_string, default=None),
new_description: str=REQ(validator=check_string,
default=None)) -> None:
new_group_name: Optional[str]=REQ(validator=check_string, default=None),
new_description: Optional[str]=REQ(validator=check_string,
default=None)) -> None:
if not new_group_name and not new_description:
return json_error(_('You must pass "new_description" or "new_group_name".'))
@ -544,7 +544,7 @@ def json_get_stream_id(request: HttpRequest,
@has_request_variables
def update_subscriptions_property(request: HttpRequest,
user_profile: UserProfile,
stream_id: int=REQ(),
stream_id: int=REQ(validator=check_int),
property: str=REQ(),
value: str=REQ()) -> HttpResponse:
subscription_data = [{"property": property,

View File

@ -83,7 +83,7 @@ def update_user_backend(request: HttpRequest, user_profile: UserProfile, user_id
full_name: Optional[str]=REQ(default="", validator=check_string),
is_admin: Optional[bool]=REQ(default=None, validator=check_bool),
is_guest: Optional[bool]=REQ(default=None, validator=check_bool),
profile_data: List[Dict[str, Union[int, str, List[int]]]]=
profile_data: Optional[List[Dict[str, Union[int, str, List[int]]]]]=
REQ(default=None,
validator=check_list(check_dict([('id', check_int)])))) -> HttpResponse:
target = access_user_by_id(user_profile, user_id, allow_deactivated=True, allow_bots=True)
@ -165,7 +165,7 @@ def get_stream_name(stream: Optional[Stream]) -> Optional[str]:
def patch_bot_backend(
request: HttpRequest, user_profile: UserProfile, bot_id: int,
full_name: Optional[str]=REQ(default=None),
bot_owner_id: Optional[int]=REQ(default=None),
bot_owner_id: Optional[int]=REQ(validator=check_int, default=None),
config_data: Optional[Dict[str, str]]=REQ(default=None,
validator=check_dict(value_validator=check_string)),
service_payload_url: Optional[str]=REQ(validator=check_url, default=None),

View File

@ -473,7 +473,7 @@ IGNORED_EVENTS = [
def api_github_webhook(
request: HttpRequest, user_profile: UserProfile,
payload: Dict[str, Any]=REQ(argument_type='body'),
branches: str=REQ(default=None),
branches: Optional[str]=REQ(default=None),
user_specified_topic: Optional[str]=REQ("topic", default=None)) -> HttpResponse:
event = get_event(request, payload, branches)
if event is not None:
@ -489,7 +489,7 @@ def api_github_webhook(
check_send_webhook_message(request, user_profile, subject, body)
return json_success()
def get_event(request: HttpRequest, payload: Dict[str, Any], branches: str) -> Optional[str]:
def get_event(request: HttpRequest, payload: Dict[str, Any], branches: Optional[str]) -> Optional[str]:
event = validate_extract_webhook_http_header(request, 'X_GITHUB_EVENT', 'GitHub')
if event == 'pull_request':
action = payload['action']

View File

@ -6,16 +6,22 @@ from django.http import HttpRequest, HttpResponse
from zerver.decorator import api_key_only_webhook_view
from zerver.lib.request import REQ, has_request_variables
from zerver.lib.response import json_success
from zerver.lib.validator import check_int
from zerver.lib.webhooks.common import check_send_webhook_message, \
UnexpectedWebhookEventType
from zerver.models import UserProfile
@api_key_only_webhook_view('Transifex', notify_bot_owner_on_invalid_json=False)
@has_request_variables
def api_transifex_webhook(request: HttpRequest, user_profile: UserProfile,
project: str=REQ(), resource: str=REQ(),
language: str=REQ(), translated: Optional[int]=REQ(default=None),
reviewed: Optional[int]=REQ(default=None)) -> HttpResponse:
def api_transifex_webhook(
request: HttpRequest,
user_profile: UserProfile,
project: str = REQ(),
resource: str = REQ(),
language: str = REQ(),
translated: Optional[int] = REQ(validator=check_int, default=None),
reviewed: Optional[int] = REQ(validator=check_int, default=None),
) -> HttpResponse:
subject = "{} in {}".format(project, language)
if translated:
body = "Resource {} fully translated.".format(resource)

View File

@ -84,7 +84,7 @@ def register_remote_server(
@has_request_variables
def register_remote_push_device(request: HttpRequest, entity: Union[UserProfile, RemoteZulipServer],
user_id: int=REQ(), token: str=REQ(),
user_id: int=REQ(validator=check_int), token: str=REQ(),
token_kind: int=REQ(validator=check_int),
ios_app_id: Optional[str]=None) -> HttpResponse:
server = validate_bouncer_token_request(entity, token, token_kind)
@ -108,7 +108,7 @@ def register_remote_push_device(request: HttpRequest, entity: Union[UserProfile,
def unregister_remote_push_device(request: HttpRequest, entity: Union[UserProfile, RemoteZulipServer],
token: str=REQ(),
token_kind: int=REQ(validator=check_int),
user_id: int=REQ(),
user_id: int=REQ(validator=check_int),
ios_app_id: Optional[str]=None) -> HttpResponse:
server = validate_bouncer_token_request(entity, token, token_kind)
deleted = RemotePushDeviceToken.objects.filter(token=token,