diff --git a/corporate/views.py b/corporate/views.py index bb7e8b9185..3fd8a2eff0 100644 --- a/corporate/views.py +++ b/corporate/views.py @@ -34,7 +34,7 @@ def unsign_seat_count(signed_seat_count: str, salt: str) -> int: raise BillingError('tampered seat count') def check_upgrade_parameters( - billing_modality: str, schedule: str, license_management: str, licenses: Optional[int], + billing_modality: str, schedule: str, license_management: Optional[str], licenses: Optional[int], has_stripe_token: bool, seat_count: int) -> None: if billing_modality not in ['send_invoice', 'charge_automatically']: raise BillingError('unknown billing_modality') @@ -74,9 +74,9 @@ def payment_method_string(stripe_customer: stripe.Customer) -> str: def upgrade(request: HttpRequest, user: UserProfile, billing_modality: str=REQ(validator=check_string), schedule: str=REQ(validator=check_string), - license_management: str=REQ(validator=check_string, default=None), - licenses: int=REQ(validator=check_int, default=None), - stripe_token: str=REQ(validator=check_string, default=None), + license_management: Optional[str]=REQ(validator=check_string, default=None), + licenses: Optional[int]=REQ(validator=check_int, default=None), + stripe_token: Optional[str]=REQ(validator=check_string, default=None), signed_seat_count: str=REQ(validator=check_string), salt: str=REQ(validator=check_string)) -> HttpResponse: try: @@ -89,6 +89,7 @@ def upgrade(request: HttpRequest, user: UserProfile, check_upgrade_parameters( billing_modality, schedule, license_management, licenses, stripe_token is not None, seat_count) + assert licenses is not None automanage_licenses = license_management == 'automatic' billing_schedule = {'annual': CustomerPlan.ANNUAL, diff --git a/tools/coveragerc b/tools/coveragerc index ae533f8f9e..95dbebc751 100644 --- a/tools/coveragerc +++ b/tools/coveragerc @@ -15,6 +15,8 @@ exclude_lines = raise UnexpectedWebhookEventType # Don't require coverage for blocks only run when type-checking if TYPE_CHECKING: + # PEP 484 overloading syntax + ^\s*\.\.\. [run] omit = diff --git a/zerver/lib/request.py b/zerver/lib/request.py index 951d1225e2..c7300b174c 100644 --- a/zerver/lib/request.py +++ b/zerver/lib/request.py @@ -11,7 +11,8 @@ from zerver.lib.types import Validator, ViewFuncT from django.http import HttpRequest, HttpResponse -from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union, cast +from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union, cast, overload +from typing_extensions import Literal class RequestConfusingParmsError(JsonableError): code = ErrorCode.REQUEST_CONFUSING_VAR @@ -131,12 +132,89 @@ class _REQ(Generic[ResultT]): # instance of class _REQ to enable the decorator to scan the parameter # list for _REQ objects and patch the parameters as the true types. +# Overload 1: converter +@overload +def REQ( + whence: Optional[str] = ..., + *, + type: Type[ResultT] = ..., + converter: Callable[[str], ResultT], + default: ResultT = ..., + intentionally_undocumented: bool = ..., + documentation_pending: bool = ..., + aliases: Optional[List[str]] = ..., + path_only: bool = ... +) -> ResultT: + ... + +# Overload 2: validator +@overload +def REQ( + whence: Optional[str] = ..., + *, + type: Type[ResultT] = ..., + default: ResultT = ..., + validator: Validator, + intentionally_undocumented: bool = ..., + documentation_pending: bool = ..., + aliases: Optional[List[str]] = ..., + path_only: bool = ... +) -> ResultT: + ... + +# Overload 3: no converter/validator, default: str or unspecified, argument_type=None +@overload +def REQ( + whence: Optional[str] = ..., + *, + type: Type[str] = ..., + default: str = ..., + str_validator: Optional[Validator] = ..., + intentionally_undocumented: bool = ..., + documentation_pending: bool = ..., + aliases: Optional[List[str]] = ..., + path_only: bool = ... +) -> str: + ... + +# Overload 4: no converter/validator, default=None, argument_type=None +@overload +def REQ( + whence: Optional[str] = ..., + *, + type: Type[str] = ..., + default: None, + str_validator: Optional[Validator] = ..., + intentionally_undocumented: bool = ..., + documentation_pending: bool = ..., + aliases: Optional[List[str]] = ..., + path_only: bool = ... +) -> Optional[str]: + ... + +# Overload 5: argument_type="body" +@overload +def REQ( + whence: Optional[str] = ..., + *, + type: Type[ResultT] = ..., + default: ResultT = ..., + str_validator: Optional[Validator] = ..., + argument_type: Literal["body"], + intentionally_undocumented: bool = ..., + documentation_pending: bool = ..., + aliases: Optional[List[str]] = ..., + path_only: bool = ... +) -> ResultT: + ... + +# Implementation def REQ( whence: Optional[str] = None, *, type: Type[ResultT] = Type[None], converter: Optional[Callable[[str], ResultT]] = None, - default: Union[_REQ._NotSpecified, ResultT, None] = _REQ.NotSpecified, + default: Union[_REQ._NotSpecified, ResultT] = _REQ.NotSpecified, validator: Optional[Validator] = None, str_validator: Optional[Validator] = None, argument_type: Optional[str] = None, diff --git a/zerver/tests/test_decorators.py b/zerver/tests/test_decorators.py index f4641c3ce4..4b32f675c2 100644 --- a/zerver/tests/test_decorators.py +++ b/zerver/tests/test_decorators.py @@ -181,7 +181,7 @@ class DecoratorTestCase(TestCase): def test_REQ_converter_and_validator_invalid(self) -> None: with self.assertRaisesRegex(AssertionError, "converter and validator are mutually exclusive"): @has_request_variables - def get_total(request: HttpRequest, + def get_total(request: HttpRequest, # type: ignore # The condition being tested is in fact an error. numbers: Iterable[int]=REQ(validator=check_list(check_int), converter=lambda x: [])) -> int: return sum(numbers) # nocoverage -- isn't intended to be run @@ -263,7 +263,7 @@ class DecoratorTestCase(TestCase): # Test we properly handle an invalid argument_type. with self.assertRaises(Exception) as cm: @has_request_variables - def test(request: HttpRequest, + def test(request: HttpRequest, # type: ignore # The condition being tested is in fact an error. payload: Any=REQ(argument_type="invalid")) -> None: # Any is ok; exception should occur in decorator: pass # nocoverage # this function isn't meant to be called diff --git a/zerver/tests/test_integrations_dev_panel.py b/zerver/tests/test_integrations_dev_panel.py index e081d4af62..ca647caacd 100644 --- a/zerver/tests/test_integrations_dev_panel.py +++ b/zerver/tests/test_integrations_dev_panel.py @@ -17,7 +17,7 @@ class TestIntegrationsDevPanel(ZulipTestCase): "url": url, "body": body, "custom_headers": "{}", - "is_json": True + "is_json": "true" } response = self.client_post(target_url, data) @@ -37,7 +37,7 @@ class TestIntegrationsDevPanel(ZulipTestCase): "url": url, "body": body, "custom_headers": "{}", - "is_json": True + "is_json": "true" } response = self.client_post(target_url, data) @@ -64,7 +64,7 @@ class TestIntegrationsDevPanel(ZulipTestCase): "url": url, "body": body, "custom_headers": ujson.dumps({"X_GITHUB_EVENT": "ping"}), - "is_json": True + "is_json": "true" } response = self.client_post(target_url, data) @@ -87,7 +87,7 @@ class TestIntegrationsDevPanel(ZulipTestCase): "url": url, "body": body, "custom_headers": ujson.dumps({"Content-Type": "application/x-www-form-urlencoded"}), - "is_json": False, + "is_json": "false", } response = self.client_post(target_url, data) diff --git a/zerver/tests/test_messages.py b/zerver/tests/test_messages.py index 2ea79b3816..7fe60f3cb6 100644 --- a/zerver/tests/test_messages.py +++ b/zerver/tests/test_messages.py @@ -1919,7 +1919,7 @@ class MessagePOSTTest(ZulipTestCase): "client": "test suite", "content": "Test message", "topic": "Test topic", - "forged": True}) + "forged": "true"}) self.assert_json_error(result, "User not authorized for this query") def test_send_message_as_not_superuser_to_different_domain(self) -> None: @@ -2047,6 +2047,23 @@ class MessagePOSTTest(ZulipTestCase): msg = self.get_last_message() self.assertEqual(int(datetime_to_timestamp(msg.date_sent)), int(fake_timestamp)) + # Now test again using forged=yes + fake_date_sent = timezone_now() - datetime.timedelta(minutes=22) + fake_timestamp = datetime_to_timestamp(fake_date_sent) + + result = self.api_post(email, "/api/v1/messages", {"type": "stream", + "forged": "yes", + "time": fake_timestamp, + "sender": "irc-user@irc.zulip.com", + "content": "Test message", + "client": "irc_mirror", + "topic": "from irc", + "to": "IRCLand"}) + self.assert_json_success(result) + + msg = self.get_last_message() + self.assertEqual(int(datetime_to_timestamp(msg.date_sent)), int(fake_timestamp)) + def test_unsubscribed_api_super_user(self) -> None: cordelia = self.example_user('cordelia') stream_name = 'private_stream' diff --git a/zerver/tornado/views.py b/zerver/tornado/views.py index 66c12f1b60..9ceca3f95d 100644 --- a/zerver/tornado/views.py +++ b/zerver/tornado/views.py @@ -10,7 +10,7 @@ from zerver.decorator import REQ, RespondAsynchronously, \ _RespondAsynchronously, asynchronous, to_non_negative_int, \ has_request_variables, internal_notify_view, process_client from zerver.lib.response import json_error, json_success -from zerver.lib.validator import check_bool, check_list, check_string +from zerver.lib.validator import check_bool, check_int, check_list, check_string from zerver.models import Client, UserProfile, get_client, get_user_profile_by_id from zerver.tornado.event_queue import fetch_events, \ get_client_descriptor, process_notification @@ -36,8 +36,11 @@ def cleanup_event_queue(request: HttpRequest, user_profile: UserProfile, @asynchronous @internal_notify_view(True) @has_request_variables -def get_events_internal(request: HttpRequest, handler: BaseHandler, - user_profile_id: int=REQ()) -> Union[HttpResponse, _RespondAsynchronously]: +def get_events_internal( + request: HttpRequest, + handler: BaseHandler, + user_profile_id: int = REQ(validator=check_int), +) -> Union[HttpResponse, _RespondAsynchronously]: user_profile = get_user_profile_by_id(user_profile_id) request._email = user_profile.email process_client(request, user_profile, client_name="internal") diff --git a/zerver/views/development/integrations.py b/zerver/views/development/integrations.py index fba4308be9..e5ea2568b9 100644 --- a/zerver/views/development/integrations.py +++ b/zerver/views/development/integrations.py @@ -10,6 +10,7 @@ from zerver.lib.integrations import WEBHOOK_INTEGRATIONS from zerver.lib.request import has_request_variables, REQ from zerver.lib.response import json_success, json_error from zerver.models import UserProfile, get_realm +from zerver.lib.validator import check_bool from zerver.lib.webhooks.common import get_fixture_http_headers, \ standardize_headers @@ -27,10 +28,10 @@ def dev_panel(request: HttpRequest) -> HttpResponse: context = {"integrations": integrations, "bots": bots} return render(request, "zerver/integrations/development/dev_panel.html", context) -def send_webhook_fixture_message(url: str=REQ(), - body: str=REQ(), - is_json: bool=REQ(), - custom_headers: Dict[str, Any]=REQ()) -> HttpResponse: +def send_webhook_fixture_message(url: str, + body: str, + is_json: bool, + custom_headers: Dict[str, Any]) -> HttpResponse: client = Client() realm = get_realm("zulip") standardized_headers = standardize_headers(custom_headers) @@ -85,7 +86,7 @@ def get_fixtures(request: HttpResponse, def check_send_webhook_fixture_message(request: HttpRequest, url: str=REQ(), body: str=REQ(), - is_json: bool=REQ(), + is_json: bool=REQ(validator=check_bool), custom_headers: str=REQ()) -> HttpResponse: try: custom_headers_dict = ujson.loads(custom_headers) diff --git a/zerver/views/documentation.py b/zerver/views/documentation.py index bb5d3fc8f1..89b2cc59c5 100644 --- a/zerver/views/documentation.py +++ b/zerver/views/documentation.py @@ -174,7 +174,7 @@ class IntegrationView(ApiURLView): @has_request_variables -def integration_doc(request: HttpRequest, integration_name: str=REQ(default=None)) -> HttpResponse: +def integration_doc(request: HttpRequest, integration_name: str=REQ()) -> HttpResponse: if not request.is_ajax(): return HttpResponseNotFound() try: diff --git a/zerver/views/messages.py b/zerver/views/messages.py index b585e9a45d..0e7f9df721 100644 --- a/zerver/views/messages.py +++ b/zerver/views/messages.py @@ -1276,8 +1276,9 @@ def send_message_backend(request: HttpRequest, user_profile: UserProfile, message_to: Union[Sequence[int], Sequence[str]]=REQ( 'to', type=Union[List[int], List[str]], converter=extract_recipients, default=[]), - forged: bool=REQ(default=False, - documentation_pending=True), + forged_str: Optional[str]=REQ("forged", + default=None, + documentation_pending=True), topic_name: Optional[str]=REQ_topic(), message_content: str=REQ('content'), widget_content: Optional[str]=REQ(default=None, @@ -1295,6 +1296,10 @@ def send_message_backend(request: HttpRequest, user_profile: UserProfile, tz_guess: Optional[str]=REQ('tz_guess', default=None, documentation_pending=True) ) -> HttpResponse: + # Temporary hack: We're transitioning `forged` from accepting + # `yes` to accepting `true` like all of our normal booleans. + forged = forged_str is not None and forged_str in ["yes", "true"] + client = request.client is_super_user = request.user.is_api_super_user if forged and not is_super_user: diff --git a/zerver/views/streams.py b/zerver/views/streams.py index 6844401ce6..4e513a7cee 100644 --- a/zerver/views/streams.py +++ b/zerver/views/streams.py @@ -90,9 +90,9 @@ def create_default_stream_group(request: HttpRequest, user_profile: UserProfile, @require_realm_admin @has_request_variables def update_default_stream_group_info(request: HttpRequest, user_profile: UserProfile, group_id: int, - new_group_name: str=REQ(validator=check_string, default=None), - new_description: str=REQ(validator=check_string, - default=None)) -> None: + new_group_name: Optional[str]=REQ(validator=check_string, default=None), + new_description: Optional[str]=REQ(validator=check_string, + default=None)) -> None: if not new_group_name and not new_description: return json_error(_('You must pass "new_description" or "new_group_name".')) @@ -544,7 +544,7 @@ def json_get_stream_id(request: HttpRequest, @has_request_variables def update_subscriptions_property(request: HttpRequest, user_profile: UserProfile, - stream_id: int=REQ(), + stream_id: int=REQ(validator=check_int), property: str=REQ(), value: str=REQ()) -> HttpResponse: subscription_data = [{"property": property, diff --git a/zerver/views/users.py b/zerver/views/users.py index a6bdd108e4..4fd63d701c 100644 --- a/zerver/views/users.py +++ b/zerver/views/users.py @@ -83,7 +83,7 @@ def update_user_backend(request: HttpRequest, user_profile: UserProfile, user_id full_name: Optional[str]=REQ(default="", validator=check_string), is_admin: Optional[bool]=REQ(default=None, validator=check_bool), is_guest: Optional[bool]=REQ(default=None, validator=check_bool), - profile_data: List[Dict[str, Union[int, str, List[int]]]]= + profile_data: Optional[List[Dict[str, Union[int, str, List[int]]]]]= REQ(default=None, validator=check_list(check_dict([('id', check_int)])))) -> HttpResponse: target = access_user_by_id(user_profile, user_id, allow_deactivated=True, allow_bots=True) @@ -165,7 +165,7 @@ def get_stream_name(stream: Optional[Stream]) -> Optional[str]: def patch_bot_backend( request: HttpRequest, user_profile: UserProfile, bot_id: int, full_name: Optional[str]=REQ(default=None), - bot_owner_id: Optional[int]=REQ(default=None), + bot_owner_id: Optional[int]=REQ(validator=check_int, default=None), config_data: Optional[Dict[str, str]]=REQ(default=None, validator=check_dict(value_validator=check_string)), service_payload_url: Optional[str]=REQ(validator=check_url, default=None), diff --git a/zerver/webhooks/github/view.py b/zerver/webhooks/github/view.py index d27f2f12a8..142ed71fdd 100644 --- a/zerver/webhooks/github/view.py +++ b/zerver/webhooks/github/view.py @@ -473,7 +473,7 @@ IGNORED_EVENTS = [ def api_github_webhook( request: HttpRequest, user_profile: UserProfile, payload: Dict[str, Any]=REQ(argument_type='body'), - branches: str=REQ(default=None), + branches: Optional[str]=REQ(default=None), user_specified_topic: Optional[str]=REQ("topic", default=None)) -> HttpResponse: event = get_event(request, payload, branches) if event is not None: @@ -489,7 +489,7 @@ def api_github_webhook( check_send_webhook_message(request, user_profile, subject, body) return json_success() -def get_event(request: HttpRequest, payload: Dict[str, Any], branches: str) -> Optional[str]: +def get_event(request: HttpRequest, payload: Dict[str, Any], branches: Optional[str]) -> Optional[str]: event = validate_extract_webhook_http_header(request, 'X_GITHUB_EVENT', 'GitHub') if event == 'pull_request': action = payload['action'] diff --git a/zerver/webhooks/transifex/view.py b/zerver/webhooks/transifex/view.py index 0afd537eac..4862649937 100644 --- a/zerver/webhooks/transifex/view.py +++ b/zerver/webhooks/transifex/view.py @@ -6,16 +6,22 @@ from django.http import HttpRequest, HttpResponse from zerver.decorator import api_key_only_webhook_view from zerver.lib.request import REQ, has_request_variables from zerver.lib.response import json_success +from zerver.lib.validator import check_int from zerver.lib.webhooks.common import check_send_webhook_message, \ UnexpectedWebhookEventType from zerver.models import UserProfile @api_key_only_webhook_view('Transifex', notify_bot_owner_on_invalid_json=False) @has_request_variables -def api_transifex_webhook(request: HttpRequest, user_profile: UserProfile, - project: str=REQ(), resource: str=REQ(), - language: str=REQ(), translated: Optional[int]=REQ(default=None), - reviewed: Optional[int]=REQ(default=None)) -> HttpResponse: +def api_transifex_webhook( + request: HttpRequest, + user_profile: UserProfile, + project: str = REQ(), + resource: str = REQ(), + language: str = REQ(), + translated: Optional[int] = REQ(validator=check_int, default=None), + reviewed: Optional[int] = REQ(validator=check_int, default=None), +) -> HttpResponse: subject = "{} in {}".format(project, language) if translated: body = "Resource {} fully translated.".format(resource) diff --git a/zilencer/views.py b/zilencer/views.py index 7cf78f6c40..7a896ffd77 100644 --- a/zilencer/views.py +++ b/zilencer/views.py @@ -84,7 +84,7 @@ def register_remote_server( @has_request_variables def register_remote_push_device(request: HttpRequest, entity: Union[UserProfile, RemoteZulipServer], - user_id: int=REQ(), token: str=REQ(), + user_id: int=REQ(validator=check_int), token: str=REQ(), token_kind: int=REQ(validator=check_int), ios_app_id: Optional[str]=None) -> HttpResponse: server = validate_bouncer_token_request(entity, token, token_kind) @@ -108,7 +108,7 @@ def register_remote_push_device(request: HttpRequest, entity: Union[UserProfile, def unregister_remote_push_device(request: HttpRequest, entity: Union[UserProfile, RemoteZulipServer], token: str=REQ(), token_kind: int=REQ(validator=check_int), - user_id: int=REQ(), + user_id: int=REQ(validator=check_int), ios_app_id: Optional[str]=None) -> HttpResponse: server = validate_bouncer_token_request(entity, token, token_kind) deleted = RemotePushDeviceToken.objects.filter(token=token,