mirror of https://github.com/zulip/zulip.git
request: Tighten type checking on REQ.
Then, find and fix a predictable number of previous misuses. With a small change by tabbott to preserve backwards compatibility for sending `yes` for the `forged` field. Signed-off-by: Anders Kaseorg <anders@zulipchat.com>
This commit is contained in:
parent
b0a7b33f9b
commit
cafac83676
|
@ -34,7 +34,7 @@ def unsign_seat_count(signed_seat_count: str, salt: str) -> int:
|
|||
raise BillingError('tampered seat count')
|
||||
|
||||
def check_upgrade_parameters(
|
||||
billing_modality: str, schedule: str, license_management: str, licenses: Optional[int],
|
||||
billing_modality: str, schedule: str, license_management: Optional[str], licenses: Optional[int],
|
||||
has_stripe_token: bool, seat_count: int) -> None:
|
||||
if billing_modality not in ['send_invoice', 'charge_automatically']:
|
||||
raise BillingError('unknown billing_modality')
|
||||
|
@ -74,9 +74,9 @@ def payment_method_string(stripe_customer: stripe.Customer) -> str:
|
|||
def upgrade(request: HttpRequest, user: UserProfile,
|
||||
billing_modality: str=REQ(validator=check_string),
|
||||
schedule: str=REQ(validator=check_string),
|
||||
license_management: str=REQ(validator=check_string, default=None),
|
||||
licenses: int=REQ(validator=check_int, default=None),
|
||||
stripe_token: str=REQ(validator=check_string, default=None),
|
||||
license_management: Optional[str]=REQ(validator=check_string, default=None),
|
||||
licenses: Optional[int]=REQ(validator=check_int, default=None),
|
||||
stripe_token: Optional[str]=REQ(validator=check_string, default=None),
|
||||
signed_seat_count: str=REQ(validator=check_string),
|
||||
salt: str=REQ(validator=check_string)) -> HttpResponse:
|
||||
try:
|
||||
|
@ -89,6 +89,7 @@ def upgrade(request: HttpRequest, user: UserProfile,
|
|||
check_upgrade_parameters(
|
||||
billing_modality, schedule, license_management, licenses,
|
||||
stripe_token is not None, seat_count)
|
||||
assert licenses is not None
|
||||
automanage_licenses = license_management == 'automatic'
|
||||
|
||||
billing_schedule = {'annual': CustomerPlan.ANNUAL,
|
||||
|
|
|
@ -15,6 +15,8 @@ exclude_lines =
|
|||
raise UnexpectedWebhookEventType
|
||||
# Don't require coverage for blocks only run when type-checking
|
||||
if TYPE_CHECKING:
|
||||
# PEP 484 overloading syntax
|
||||
^\s*\.\.\.
|
||||
|
||||
[run]
|
||||
omit =
|
||||
|
|
|
@ -11,7 +11,8 @@ from zerver.lib.types import Validator, ViewFuncT
|
|||
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
|
||||
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union, cast
|
||||
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union, cast, overload
|
||||
from typing_extensions import Literal
|
||||
|
||||
class RequestConfusingParmsError(JsonableError):
|
||||
code = ErrorCode.REQUEST_CONFUSING_VAR
|
||||
|
@ -131,12 +132,89 @@ class _REQ(Generic[ResultT]):
|
|||
# instance of class _REQ to enable the decorator to scan the parameter
|
||||
# list for _REQ objects and patch the parameters as the true types.
|
||||
|
||||
# Overload 1: converter
|
||||
@overload
|
||||
def REQ(
|
||||
whence: Optional[str] = ...,
|
||||
*,
|
||||
type: Type[ResultT] = ...,
|
||||
converter: Callable[[str], ResultT],
|
||||
default: ResultT = ...,
|
||||
intentionally_undocumented: bool = ...,
|
||||
documentation_pending: bool = ...,
|
||||
aliases: Optional[List[str]] = ...,
|
||||
path_only: bool = ...
|
||||
) -> ResultT:
|
||||
...
|
||||
|
||||
# Overload 2: validator
|
||||
@overload
|
||||
def REQ(
|
||||
whence: Optional[str] = ...,
|
||||
*,
|
||||
type: Type[ResultT] = ...,
|
||||
default: ResultT = ...,
|
||||
validator: Validator,
|
||||
intentionally_undocumented: bool = ...,
|
||||
documentation_pending: bool = ...,
|
||||
aliases: Optional[List[str]] = ...,
|
||||
path_only: bool = ...
|
||||
) -> ResultT:
|
||||
...
|
||||
|
||||
# Overload 3: no converter/validator, default: str or unspecified, argument_type=None
|
||||
@overload
|
||||
def REQ(
|
||||
whence: Optional[str] = ...,
|
||||
*,
|
||||
type: Type[str] = ...,
|
||||
default: str = ...,
|
||||
str_validator: Optional[Validator] = ...,
|
||||
intentionally_undocumented: bool = ...,
|
||||
documentation_pending: bool = ...,
|
||||
aliases: Optional[List[str]] = ...,
|
||||
path_only: bool = ...
|
||||
) -> str:
|
||||
...
|
||||
|
||||
# Overload 4: no converter/validator, default=None, argument_type=None
|
||||
@overload
|
||||
def REQ(
|
||||
whence: Optional[str] = ...,
|
||||
*,
|
||||
type: Type[str] = ...,
|
||||
default: None,
|
||||
str_validator: Optional[Validator] = ...,
|
||||
intentionally_undocumented: bool = ...,
|
||||
documentation_pending: bool = ...,
|
||||
aliases: Optional[List[str]] = ...,
|
||||
path_only: bool = ...
|
||||
) -> Optional[str]:
|
||||
...
|
||||
|
||||
# Overload 5: argument_type="body"
|
||||
@overload
|
||||
def REQ(
|
||||
whence: Optional[str] = ...,
|
||||
*,
|
||||
type: Type[ResultT] = ...,
|
||||
default: ResultT = ...,
|
||||
str_validator: Optional[Validator] = ...,
|
||||
argument_type: Literal["body"],
|
||||
intentionally_undocumented: bool = ...,
|
||||
documentation_pending: bool = ...,
|
||||
aliases: Optional[List[str]] = ...,
|
||||
path_only: bool = ...
|
||||
) -> ResultT:
|
||||
...
|
||||
|
||||
# Implementation
|
||||
def REQ(
|
||||
whence: Optional[str] = None,
|
||||
*,
|
||||
type: Type[ResultT] = Type[None],
|
||||
converter: Optional[Callable[[str], ResultT]] = None,
|
||||
default: Union[_REQ._NotSpecified, ResultT, None] = _REQ.NotSpecified,
|
||||
default: Union[_REQ._NotSpecified, ResultT] = _REQ.NotSpecified,
|
||||
validator: Optional[Validator] = None,
|
||||
str_validator: Optional[Validator] = None,
|
||||
argument_type: Optional[str] = None,
|
||||
|
|
|
@ -181,7 +181,7 @@ class DecoratorTestCase(TestCase):
|
|||
def test_REQ_converter_and_validator_invalid(self) -> None:
|
||||
with self.assertRaisesRegex(AssertionError, "converter and validator are mutually exclusive"):
|
||||
@has_request_variables
|
||||
def get_total(request: HttpRequest,
|
||||
def get_total(request: HttpRequest, # type: ignore # The condition being tested is in fact an error.
|
||||
numbers: Iterable[int]=REQ(validator=check_list(check_int),
|
||||
converter=lambda x: [])) -> int:
|
||||
return sum(numbers) # nocoverage -- isn't intended to be run
|
||||
|
@ -263,7 +263,7 @@ class DecoratorTestCase(TestCase):
|
|||
# Test we properly handle an invalid argument_type.
|
||||
with self.assertRaises(Exception) as cm:
|
||||
@has_request_variables
|
||||
def test(request: HttpRequest,
|
||||
def test(request: HttpRequest, # type: ignore # The condition being tested is in fact an error.
|
||||
payload: Any=REQ(argument_type="invalid")) -> None:
|
||||
# Any is ok; exception should occur in decorator:
|
||||
pass # nocoverage # this function isn't meant to be called
|
||||
|
|
|
@ -17,7 +17,7 @@ class TestIntegrationsDevPanel(ZulipTestCase):
|
|||
"url": url,
|
||||
"body": body,
|
||||
"custom_headers": "{}",
|
||||
"is_json": True
|
||||
"is_json": "true"
|
||||
}
|
||||
|
||||
response = self.client_post(target_url, data)
|
||||
|
@ -37,7 +37,7 @@ class TestIntegrationsDevPanel(ZulipTestCase):
|
|||
"url": url,
|
||||
"body": body,
|
||||
"custom_headers": "{}",
|
||||
"is_json": True
|
||||
"is_json": "true"
|
||||
}
|
||||
|
||||
response = self.client_post(target_url, data)
|
||||
|
@ -64,7 +64,7 @@ class TestIntegrationsDevPanel(ZulipTestCase):
|
|||
"url": url,
|
||||
"body": body,
|
||||
"custom_headers": ujson.dumps({"X_GITHUB_EVENT": "ping"}),
|
||||
"is_json": True
|
||||
"is_json": "true"
|
||||
}
|
||||
|
||||
response = self.client_post(target_url, data)
|
||||
|
@ -87,7 +87,7 @@ class TestIntegrationsDevPanel(ZulipTestCase):
|
|||
"url": url,
|
||||
"body": body,
|
||||
"custom_headers": ujson.dumps({"Content-Type": "application/x-www-form-urlencoded"}),
|
||||
"is_json": False,
|
||||
"is_json": "false",
|
||||
}
|
||||
|
||||
response = self.client_post(target_url, data)
|
||||
|
|
|
@ -1919,7 +1919,7 @@ class MessagePOSTTest(ZulipTestCase):
|
|||
"client": "test suite",
|
||||
"content": "Test message",
|
||||
"topic": "Test topic",
|
||||
"forged": True})
|
||||
"forged": "true"})
|
||||
self.assert_json_error(result, "User not authorized for this query")
|
||||
|
||||
def test_send_message_as_not_superuser_to_different_domain(self) -> None:
|
||||
|
@ -2047,6 +2047,23 @@ class MessagePOSTTest(ZulipTestCase):
|
|||
msg = self.get_last_message()
|
||||
self.assertEqual(int(datetime_to_timestamp(msg.date_sent)), int(fake_timestamp))
|
||||
|
||||
# Now test again using forged=yes
|
||||
fake_date_sent = timezone_now() - datetime.timedelta(minutes=22)
|
||||
fake_timestamp = datetime_to_timestamp(fake_date_sent)
|
||||
|
||||
result = self.api_post(email, "/api/v1/messages", {"type": "stream",
|
||||
"forged": "yes",
|
||||
"time": fake_timestamp,
|
||||
"sender": "irc-user@irc.zulip.com",
|
||||
"content": "Test message",
|
||||
"client": "irc_mirror",
|
||||
"topic": "from irc",
|
||||
"to": "IRCLand"})
|
||||
self.assert_json_success(result)
|
||||
|
||||
msg = self.get_last_message()
|
||||
self.assertEqual(int(datetime_to_timestamp(msg.date_sent)), int(fake_timestamp))
|
||||
|
||||
def test_unsubscribed_api_super_user(self) -> None:
|
||||
cordelia = self.example_user('cordelia')
|
||||
stream_name = 'private_stream'
|
||||
|
|
|
@ -10,7 +10,7 @@ from zerver.decorator import REQ, RespondAsynchronously, \
|
|||
_RespondAsynchronously, asynchronous, to_non_negative_int, \
|
||||
has_request_variables, internal_notify_view, process_client
|
||||
from zerver.lib.response import json_error, json_success
|
||||
from zerver.lib.validator import check_bool, check_list, check_string
|
||||
from zerver.lib.validator import check_bool, check_int, check_list, check_string
|
||||
from zerver.models import Client, UserProfile, get_client, get_user_profile_by_id
|
||||
from zerver.tornado.event_queue import fetch_events, \
|
||||
get_client_descriptor, process_notification
|
||||
|
@ -36,8 +36,11 @@ def cleanup_event_queue(request: HttpRequest, user_profile: UserProfile,
|
|||
@asynchronous
|
||||
@internal_notify_view(True)
|
||||
@has_request_variables
|
||||
def get_events_internal(request: HttpRequest, handler: BaseHandler,
|
||||
user_profile_id: int=REQ()) -> Union[HttpResponse, _RespondAsynchronously]:
|
||||
def get_events_internal(
|
||||
request: HttpRequest,
|
||||
handler: BaseHandler,
|
||||
user_profile_id: int = REQ(validator=check_int),
|
||||
) -> Union[HttpResponse, _RespondAsynchronously]:
|
||||
user_profile = get_user_profile_by_id(user_profile_id)
|
||||
request._email = user_profile.email
|
||||
process_client(request, user_profile, client_name="internal")
|
||||
|
|
|
@ -10,6 +10,7 @@ from zerver.lib.integrations import WEBHOOK_INTEGRATIONS
|
|||
from zerver.lib.request import has_request_variables, REQ
|
||||
from zerver.lib.response import json_success, json_error
|
||||
from zerver.models import UserProfile, get_realm
|
||||
from zerver.lib.validator import check_bool
|
||||
from zerver.lib.webhooks.common import get_fixture_http_headers, \
|
||||
standardize_headers
|
||||
|
||||
|
@ -27,10 +28,10 @@ def dev_panel(request: HttpRequest) -> HttpResponse:
|
|||
context = {"integrations": integrations, "bots": bots}
|
||||
return render(request, "zerver/integrations/development/dev_panel.html", context)
|
||||
|
||||
def send_webhook_fixture_message(url: str=REQ(),
|
||||
body: str=REQ(),
|
||||
is_json: bool=REQ(),
|
||||
custom_headers: Dict[str, Any]=REQ()) -> HttpResponse:
|
||||
def send_webhook_fixture_message(url: str,
|
||||
body: str,
|
||||
is_json: bool,
|
||||
custom_headers: Dict[str, Any]) -> HttpResponse:
|
||||
client = Client()
|
||||
realm = get_realm("zulip")
|
||||
standardized_headers = standardize_headers(custom_headers)
|
||||
|
@ -85,7 +86,7 @@ def get_fixtures(request: HttpResponse,
|
|||
def check_send_webhook_fixture_message(request: HttpRequest,
|
||||
url: str=REQ(),
|
||||
body: str=REQ(),
|
||||
is_json: bool=REQ(),
|
||||
is_json: bool=REQ(validator=check_bool),
|
||||
custom_headers: str=REQ()) -> HttpResponse:
|
||||
try:
|
||||
custom_headers_dict = ujson.loads(custom_headers)
|
||||
|
|
|
@ -174,7 +174,7 @@ class IntegrationView(ApiURLView):
|
|||
|
||||
|
||||
@has_request_variables
|
||||
def integration_doc(request: HttpRequest, integration_name: str=REQ(default=None)) -> HttpResponse:
|
||||
def integration_doc(request: HttpRequest, integration_name: str=REQ()) -> HttpResponse:
|
||||
if not request.is_ajax():
|
||||
return HttpResponseNotFound()
|
||||
try:
|
||||
|
|
|
@ -1276,7 +1276,8 @@ def send_message_backend(request: HttpRequest, user_profile: UserProfile,
|
|||
message_to: Union[Sequence[int], Sequence[str]]=REQ(
|
||||
'to', type=Union[List[int], List[str]],
|
||||
converter=extract_recipients, default=[]),
|
||||
forged: bool=REQ(default=False,
|
||||
forged_str: Optional[str]=REQ("forged",
|
||||
default=None,
|
||||
documentation_pending=True),
|
||||
topic_name: Optional[str]=REQ_topic(),
|
||||
message_content: str=REQ('content'),
|
||||
|
@ -1295,6 +1296,10 @@ def send_message_backend(request: HttpRequest, user_profile: UserProfile,
|
|||
tz_guess: Optional[str]=REQ('tz_guess', default=None,
|
||||
documentation_pending=True)
|
||||
) -> HttpResponse:
|
||||
# Temporary hack: We're transitioning `forged` from accepting
|
||||
# `yes` to accepting `true` like all of our normal booleans.
|
||||
forged = forged_str is not None and forged_str in ["yes", "true"]
|
||||
|
||||
client = request.client
|
||||
is_super_user = request.user.is_api_super_user
|
||||
if forged and not is_super_user:
|
||||
|
|
|
@ -90,8 +90,8 @@ def create_default_stream_group(request: HttpRequest, user_profile: UserProfile,
|
|||
@require_realm_admin
|
||||
@has_request_variables
|
||||
def update_default_stream_group_info(request: HttpRequest, user_profile: UserProfile, group_id: int,
|
||||
new_group_name: str=REQ(validator=check_string, default=None),
|
||||
new_description: str=REQ(validator=check_string,
|
||||
new_group_name: Optional[str]=REQ(validator=check_string, default=None),
|
||||
new_description: Optional[str]=REQ(validator=check_string,
|
||||
default=None)) -> None:
|
||||
if not new_group_name and not new_description:
|
||||
return json_error(_('You must pass "new_description" or "new_group_name".'))
|
||||
|
@ -544,7 +544,7 @@ def json_get_stream_id(request: HttpRequest,
|
|||
@has_request_variables
|
||||
def update_subscriptions_property(request: HttpRequest,
|
||||
user_profile: UserProfile,
|
||||
stream_id: int=REQ(),
|
||||
stream_id: int=REQ(validator=check_int),
|
||||
property: str=REQ(),
|
||||
value: str=REQ()) -> HttpResponse:
|
||||
subscription_data = [{"property": property,
|
||||
|
|
|
@ -83,7 +83,7 @@ def update_user_backend(request: HttpRequest, user_profile: UserProfile, user_id
|
|||
full_name: Optional[str]=REQ(default="", validator=check_string),
|
||||
is_admin: Optional[bool]=REQ(default=None, validator=check_bool),
|
||||
is_guest: Optional[bool]=REQ(default=None, validator=check_bool),
|
||||
profile_data: List[Dict[str, Union[int, str, List[int]]]]=
|
||||
profile_data: Optional[List[Dict[str, Union[int, str, List[int]]]]]=
|
||||
REQ(default=None,
|
||||
validator=check_list(check_dict([('id', check_int)])))) -> HttpResponse:
|
||||
target = access_user_by_id(user_profile, user_id, allow_deactivated=True, allow_bots=True)
|
||||
|
@ -165,7 +165,7 @@ def get_stream_name(stream: Optional[Stream]) -> Optional[str]:
|
|||
def patch_bot_backend(
|
||||
request: HttpRequest, user_profile: UserProfile, bot_id: int,
|
||||
full_name: Optional[str]=REQ(default=None),
|
||||
bot_owner_id: Optional[int]=REQ(default=None),
|
||||
bot_owner_id: Optional[int]=REQ(validator=check_int, default=None),
|
||||
config_data: Optional[Dict[str, str]]=REQ(default=None,
|
||||
validator=check_dict(value_validator=check_string)),
|
||||
service_payload_url: Optional[str]=REQ(validator=check_url, default=None),
|
||||
|
|
|
@ -473,7 +473,7 @@ IGNORED_EVENTS = [
|
|||
def api_github_webhook(
|
||||
request: HttpRequest, user_profile: UserProfile,
|
||||
payload: Dict[str, Any]=REQ(argument_type='body'),
|
||||
branches: str=REQ(default=None),
|
||||
branches: Optional[str]=REQ(default=None),
|
||||
user_specified_topic: Optional[str]=REQ("topic", default=None)) -> HttpResponse:
|
||||
event = get_event(request, payload, branches)
|
||||
if event is not None:
|
||||
|
@ -489,7 +489,7 @@ def api_github_webhook(
|
|||
check_send_webhook_message(request, user_profile, subject, body)
|
||||
return json_success()
|
||||
|
||||
def get_event(request: HttpRequest, payload: Dict[str, Any], branches: str) -> Optional[str]:
|
||||
def get_event(request: HttpRequest, payload: Dict[str, Any], branches: Optional[str]) -> Optional[str]:
|
||||
event = validate_extract_webhook_http_header(request, 'X_GITHUB_EVENT', 'GitHub')
|
||||
if event == 'pull_request':
|
||||
action = payload['action']
|
||||
|
|
|
@ -6,16 +6,22 @@ from django.http import HttpRequest, HttpResponse
|
|||
from zerver.decorator import api_key_only_webhook_view
|
||||
from zerver.lib.request import REQ, has_request_variables
|
||||
from zerver.lib.response import json_success
|
||||
from zerver.lib.validator import check_int
|
||||
from zerver.lib.webhooks.common import check_send_webhook_message, \
|
||||
UnexpectedWebhookEventType
|
||||
from zerver.models import UserProfile
|
||||
|
||||
@api_key_only_webhook_view('Transifex', notify_bot_owner_on_invalid_json=False)
|
||||
@has_request_variables
|
||||
def api_transifex_webhook(request: HttpRequest, user_profile: UserProfile,
|
||||
project: str=REQ(), resource: str=REQ(),
|
||||
language: str=REQ(), translated: Optional[int]=REQ(default=None),
|
||||
reviewed: Optional[int]=REQ(default=None)) -> HttpResponse:
|
||||
def api_transifex_webhook(
|
||||
request: HttpRequest,
|
||||
user_profile: UserProfile,
|
||||
project: str = REQ(),
|
||||
resource: str = REQ(),
|
||||
language: str = REQ(),
|
||||
translated: Optional[int] = REQ(validator=check_int, default=None),
|
||||
reviewed: Optional[int] = REQ(validator=check_int, default=None),
|
||||
) -> HttpResponse:
|
||||
subject = "{} in {}".format(project, language)
|
||||
if translated:
|
||||
body = "Resource {} fully translated.".format(resource)
|
||||
|
|
|
@ -84,7 +84,7 @@ def register_remote_server(
|
|||
|
||||
@has_request_variables
|
||||
def register_remote_push_device(request: HttpRequest, entity: Union[UserProfile, RemoteZulipServer],
|
||||
user_id: int=REQ(), token: str=REQ(),
|
||||
user_id: int=REQ(validator=check_int), token: str=REQ(),
|
||||
token_kind: int=REQ(validator=check_int),
|
||||
ios_app_id: Optional[str]=None) -> HttpResponse:
|
||||
server = validate_bouncer_token_request(entity, token, token_kind)
|
||||
|
@ -108,7 +108,7 @@ def register_remote_push_device(request: HttpRequest, entity: Union[UserProfile,
|
|||
def unregister_remote_push_device(request: HttpRequest, entity: Union[UserProfile, RemoteZulipServer],
|
||||
token: str=REQ(),
|
||||
token_kind: int=REQ(validator=check_int),
|
||||
user_id: int=REQ(),
|
||||
user_id: int=REQ(validator=check_int),
|
||||
ios_app_id: Optional[str]=None) -> HttpResponse:
|
||||
server = validate_bouncer_token_request(entity, token, token_kind)
|
||||
deleted = RemotePushDeviceToken.objects.filter(token=token,
|
||||
|
|
Loading…
Reference in New Issue