From f4caf9dd79f96adc722165cfe73215a8527dac2d Mon Sep 17 00:00:00 2001 From: Zixuan James Li Date: Fri, 28 Jul 2023 02:34:04 -0400 Subject: [PATCH] api: Add new typed_endpoint decorators. The goal of typed_endpoint is to replicate most features supported by has_request_variables, and to improve on top of it. There are some unresolved issues that we don't plan to work on currently. For example, typed_endpoint does not support ignored_parameters_unsupported for 400 responses, and it does not run validators on path-only arguments. Unlike has_request_variables, typed_endpoint supports error handling by processing validation errors from Pydantic. Most features supported by has_request_variables are supported by typed_endpoint in various ways. To define a function, use a syntax like this with Annotated if there is any metadata you want to associate with a parameter. Do note that parameters that are not keyword-only are ignored from the request: ``` @typed_endpoint def view( request: HttpRequest, user_profile: UserProfile, *, foo: Annotated[int, ApiParamConfig(path_only=True)], bar: Json[int], other: Annotated[ Json[int], ApiParamConfig( whence="lorem", documentation_status=INTENTIONALLY_UNDOCUMENTED ) ] = 10, ) -> HttpResponse: .... ``` There are also some shorthands for the commonly used annotated types, which are encouraged when applicable for better readability and less typing: ``` WebhookPayload = Annotated[Json[T], ApiParamConfig(argument_type_is_body=True)] PathOnly = Annotated[T, ApiParamConfig(path_only=True)] ``` Then the view function above can be rewritten as: ``` @typed_endpoint def view( request: HttpRequest, user_profile: UserProfile, *, foo: PathOnly[int], bar: Json[int], other: Annotated[ Json[int], ApiParamConfig( whence="lorem", documentation_status=INTENTIONALLY_UNDOCUMENTED ) ] = 10, ) -> HttpResponse: .... 
``` There are some intentional restrictions: - A single parameter cannot have more than one ApiParamConfig - Path-only parameters cannot have default values - argument_type_is_body is incompatible with whence - Arguments named "request", "user_profile", "args", "kwargs", etc. are ignored by typed_endpoint. - Positional-only arguments are not supported by typed_endpoint. Only keyword-only parameters are expected to be parsed from the request. - Pydantic's strict mode is always enabled, because we don't want to coerce input parsed from JSON into other types unnecessarily. - Using strict mode all the time also means that we should always use Json[int] instead of int, because it is only possible for the request to have data of type str, and a type annotation of int will always reject such data. typed_endpoint's handling of ignored_parameters_unsupported is mostly identical to that of has_request_variables. --- tools/linter_lib/custom_check.py | 1 + tools/semgrep.yml | 17 + zerver/lib/exceptions.py | 6 + zerver/lib/typed_endpoint.py | 511 ++++++++++++++++++++++++ zerver/lib/validator.py | 13 + zerver/tests/test_typed_endpoint.py | 582 ++++++++++++++++++++++++++++ 6 files changed, 1130 insertions(+) create mode 100644 zerver/lib/typed_endpoint.py create mode 100644 zerver/tests/test_typed_endpoint.py diff --git a/tools/linter_lib/custom_check.py b/tools/linter_lib/custom_check.py index b90a5363fa..47a62d8470 100644 --- a/tools/linter_lib/custom_check.py +++ b/tools/linter_lib/custom_check.py @@ -23,6 +23,7 @@ FILES_WITH_LEGACY_SUBJECT = { "zerver/lib/email_mirror.py", "zerver/lib/email_notifications.py", "zerver/lib/send_email.py", + "zerver/lib/typed_endpoint.py", "zerver/tests/test_new_users.py", "zerver/tests/test_email_mirror.py", "zerver/tests/test_message_notification_emails.py", diff --git a/tools/semgrep.yml b/tools/semgrep.yml index 8fee591531..8b4b32a545 100644 --- a/tools/semgrep.yml +++ b/tools/semgrep.yml @@ -169,3 +169,20 @@ rules: message: 'A 
batched migration should not be atomic. Add "atomic = False" to the Migration class' languages: [python] severity: ERROR + + - id: typed_endpoint_without_keyword_only_param + patterns: + - pattern: | + @typed_endpoint + def $F(...)-> ...: + ... + - pattern-not-inside: | + @typed_endpoint + def $F(..., *, ...)-> ...: + ... + message: | + @typed_endpoint should not be used without keyword-only parameters. + Make parameters to be parsed from the request as keyword-only, + or use @typed_endpoint_without_parameters instead. + languages: [python] + severity: ERROR diff --git a/zerver/lib/exceptions.py b/zerver/lib/exceptions.py index f736052fbc..47e23c168d 100644 --- a/zerver/lib/exceptions.py +++ b/zerver/lib/exceptions.py @@ -527,3 +527,9 @@ class ReactionDoesNotExistError(JsonableError): @staticmethod def msg_format() -> str: return _("Reaction doesn't exist.") + + +class ApiParamValidationError(JsonableError): + def __init__(self, msg: str, error_type: str) -> None: + super().__init__(msg) + self.error_type = error_type diff --git a/zerver/lib/typed_endpoint.py b/zerver/lib/typed_endpoint.py new file mode 100644 index 0000000000..0af5d4d24f --- /dev/null +++ b/zerver/lib/typed_endpoint.py @@ -0,0 +1,511 @@ +import inspect +import json +from dataclasses import dataclass +from enum import Enum, auto +from functools import wraps +from typing import Callable, Generic, List, Optional, Sequence, Tuple, Type, TypeVar, Union + +from django.http import HttpRequest +from django.utils.translation import gettext as _ +from pydantic import Json, StringConstraints, TypeAdapter, ValidationError +from typing_extensions import ( + Annotated, + Concatenate, + ParamSpec, + TypeAlias, + get_args, + get_origin, + get_type_hints, +) + +from zerver.lib.exceptions import ApiParamValidationError, JsonableError +from zerver.lib.request import ( + _REQ, + RequestConfusingParamsError, + RequestNotes, + RequestVariableMissingError, + arguments_map, +) +from zerver.lib.response import 
MutableJsonResponse + +T = TypeVar("T") +ParamT = ParamSpec("ParamT") +ReturnT = TypeVar("ReturnT") + + +class DocumentationStatus(Enum): + DOCUMENTED = auto() + INTENTIONALLY_UNDOCUMENTED = auto() + DOCUMENTATION_PENDING = auto() + + +DOCUMENTED = DocumentationStatus.DOCUMENTED +INTENTIONALLY_UNDOCUMENTED = DocumentationStatus.INTENTIONALLY_UNDOCUMENTED +DOCUMENTATION_PENDING = DocumentationStatus.DOCUMENTATION_PENDING + + +@dataclass(frozen=True) +class ApiParamConfig: + """The metadata associated with a view function parameter as an annotation + to configure how the typed_endpoint decorator should process it. + + It should be used with Annotated as the type annotation of a parameter + in a @typed_endpoint-decorated function: + ``` + @typed_endpoint + def view( + request: HttpRequest, + *, + flag_value: Annotated[Json[bool], ApiParamConfig( + whence="flag", + documentation_status=INTENTIONALLY_UNDOCUMENTED, + )] + ) -> HttpResponse: + ... + ``` + + For a parameter that is not annotated with ApiParamConfig, typed_endpoint + will construct a configuration using the defaults. + + whence: + The name of the request variable that should be used for this parameter. + If None, it is set to the name of the function parameter. + + path_only: + Used for parameters included in the URL. + + argument_type_is_body: + When set to true, the value of the parameter will be extracted from the + request body instead of a single query parameter. + + documentation_status: + The OpenAPI documentation status of this parameter. Unless it is set to + INTENTIONALLY_UNDOCUMENTED or DOCUMENTATION_PENDING, the test suite is + configured to raise an error when its documentation cannot be found. + + aliases: + The names allowed for the request variable other than that specified with + "whence". + """ + + whence: Optional[str] = None + path_only: bool = False + argument_type_is_body: bool = False + documentation_status: DocumentationStatus = DOCUMENTED + aliases: Tuple[str, ...] 
= () + + +# TypeAliases for common Annotated types + +# Commonly used for webhook views, where the payload has a content type of +# application/json. It reads the data from request body and parse it from JSON. +WebhookPayload: TypeAlias = Annotated[Json[T], ApiParamConfig(argument_type_is_body=True)] +# A shorthand to declare path only variables that should not be parsed from the +# request by the @typed_endpoint decorator. +PathOnly: TypeAlias = Annotated[T, ApiParamConfig(path_only=True)] + +# Reusable annotation metadata for Annotated types + +# This disallows strings of length 0 after stripping. +# Example usage: Annotated[T, StringRequiredConstraint()] +RequiredStringConstraint = lambda: StringConstraints(strip_whitespace=True, min_length=1) + +# Implementation + + +class _NotSpecified: + pass + + +NotSpecified = _NotSpecified() + + +# For performance reasons, attributes needed from ApiParamConfig are copied to +# FuncParam. We should use slotted dataclass once the entire codebase is +# switched to Python 3.10+ +@dataclass(frozen=True) +class FuncParam(Generic[T]): + # Default value of the parameter. + default: Union[T, _NotSpecified] + # Name of the function parameter as defined in the original function. + param_name: str + # Inspected the underlying type of the parameter by unwrapping the Annotated + # type if there is one. + param_type: Type[T] + # The Pydantic TypeAdapter used to parse arbitrary input to the desired type. + # We store it on the FuncParam object as soon as the view function is + # decorated because it is expensive to construct. + # See also: https://docs.pydantic.dev/latest/usage/type_adapter/ + type_adapter: TypeAdapter[T] + + # The following group of attributes are computed from the ApiParamConfig + # annotation associated with this param: + # Name of the corresponding variable in the request data to look + # for. When argument_type_is_body is True, this is set to "request". + aliases: Tuple[str, ...] 
+ argument_type_is_body: bool + documentation_status: DocumentationStatus + path_only: bool + request_var_name: str + + +@dataclass(frozen=True) +class ViewFuncInfo: + view_func_full_name: str + parameters: Sequence[FuncParam[object]] + + +def is_annotated(type_annotation: Type[object]) -> bool: + origin = get_origin(type_annotation) + return origin is Annotated + + +def parse_single_parameter( + param_name: str, param_type: Type[T], parameter: inspect.Parameter +) -> FuncParam[T]: + param_default = parameter.default + # inspect._empty is the internal type used by inspect to indicate not + # specified defaults. + if param_default is inspect._empty: + param_default = NotSpecified + + # Defaulting a value to None automatically wraps the type annotation with + # Optional. We explicitly unwrap it for the case of Annotated, which + # otherwise causes undesired behaviors that the annotated metadata gets + # lost. This is fixed in Python 3.11: + # https://github.com/python/cpython/issues/90353 + if param_default is None: + origin = get_origin(param_type) + type_args = get_args(param_type) + if origin is Union and type(None) in type_args and len(type_args) == 2: + inner_type = type_args[0] if type_args[1] is type(None) else type_args[1] + if is_annotated(inner_type): + param_type = inner_type + + param_config: Optional[ApiParamConfig] = None + if is_annotated(param_type): + # The first type is the underlying type of the parameter, the rest are + # metadata attached to Annotated. Note that we do not transform + # param_type to its underlying type because the Annotated metadata might + # still be needed by other parties like Pydantic. 
+ _, *annotations = get_args(param_type) + for annotation in annotations: + if not isinstance(annotation, ApiParamConfig): + continue + assert param_config is None, "ApiParamConfig can only be defined once per parameter" + param_config = annotation + # If param_config is still None at this point, we could not find an instance + # of it in the type annotation of the function parameter. In this case, we + # fallback to the defaults by constructing ApiParamConfig here. + # This is common for simple parameters of type str, Json[int] and etc. + if param_config is None: + param_config = ApiParamConfig() + + # Metadata defines a validator making sure that argument_type_is_body is + # incompatible with whence. + if param_config.argument_type_is_body: + request_var_name = "request" + else: + request_var_name = param_config.whence if param_config.whence is not None else param_name + + return FuncParam( + default=param_default, + param_name=param_name, + param_type=param_type, + type_adapter=TypeAdapter(param_type), + aliases=param_config.aliases, + argument_type_is_body=param_config.argument_type_is_body, + documentation_status=param_config.documentation_status, + path_only=param_config.path_only, + request_var_name=request_var_name, + ) + + +def parse_view_func_signature( + view_func: Callable[Concatenate[HttpRequest, ParamT], object] +) -> ViewFuncInfo: + """This is responsible for inspecting the function signature and getting the + metadata from the parameters. We want to keep this function as pure as + possible not leaking side effects to the global state. Side effects should + be executed separately after the ViewFuncInfo is returned. 
+ """ + type_hints = get_type_hints(view_func, include_extras=True) + parameters = inspect.signature(view_func).parameters + view_func_full_name = f"{view_func.__module__}.{view_func.__name__}" + + process_parameters: List[FuncParam[object]] = [] + + for param_name, parameter in parameters.items(): + assert param_name in type_hints + if parameter.kind != inspect.Parameter.KEYWORD_ONLY: + continue + param_info = parse_single_parameter( + param_name=param_name, param_type=type_hints[param_name], parameter=parameter + ) + process_parameters.append(param_info) + + return ViewFuncInfo( + view_func_full_name=view_func_full_name, + parameters=process_parameters, + ) + + +# TODO: To get coverage data, we should switch to match-case syntax when we +# upgrade to Python 3.10. +# This should be sorted alphabetically. +ERROR_TEMPLATES = { + "bool_parsing": _("{var_name} is not a boolean"), + "bool_type": _("{var_name} is not a boolean"), + "datetime_parsing": _("{var_name} is not a date"), + "datetime_type": _("{var_name} is not a date"), + "dict_type": _("{var_name} is not a dict"), + "extra_forbidden": _('Argument "{argument}" at {var_name} is unexpected'), + "float_parsing": _("{var_name} is not a float"), + "float_type": _("{var_name} is not a float"), + "greater_than": _("{var_name} is too small"), + "int_parsing": _("{var_name} is not an integer"), + "int_type": _("{var_name} is not an integer"), + "json_invalid": _("{var_name} is not valid JSON"), + "json_type": _("{var_name} is not valid JSON"), + "less_than": _("{var_name} is too large"), + "list_type": _("{var_name} is not a list"), + "literal_error": _("Invalid {var_name}"), + "string_too_long": _("{var_name} is too long (limit: {max_length} characters)"), + "string_too_short": _("{var_name} is too short."), + "string_type": _("{var_name} is not a string"), + "unexpected_keyword_argument": _('Argument "{argument}" at {var_name} is unexpected'), +} + + +def parse_value_for_parameter(parameter: FuncParam[T], value: 
object) -> T: + try: + return parameter.type_adapter.validate_python(value, strict=True) + except ValidationError as exc: + # If the validation fails, it is possible to get multiple errors from + # Pydantic. We only send the first error back to the client. + # See also on ValidationError: + # https://docs.pydantic.dev/latest/errors/validation_errors/ + error = exc.errors()[0] + # We require all Pydantic raised error types that we expect to be + # explicitly handled here. The end result should either be a 400 + # error with an translated message or an internal server error. + error_template = ERROR_TEMPLATES.get(error["type"]) + var_name = parameter.request_var_name + "".join( + f"[{json.dumps(loc)}]" for loc in error["loc"] + ) + context = { + "var_name": var_name, + **error.get("ctx", {}), + } + + if error["type"] == "json_invalid" and parameter.argument_type_is_body: + # argument_type_is_body is usually used by webhooks that do not + # require a specific var_name for payload JSON decoding error. + # We override it here. + error_template = _("Malformed JSON") + elif error["type"] in ("unexpected_keyword_argument", "extra_forbidden"): + context["argument"] = error["loc"][-1] + # This condition matches our StringRequiredConstraint + elif error["type"] == "string_too_short" and error["ctx"].get("min_length") == 1: + error_template = _("{var_name} cannot be blank") + + assert error_template is not None, MISSING_ERROR_TEMPLATE.format( + error_type=error["type"], + url=error.get("url", "(documentation unavailable)"), + error=json.dumps(error, indent=4), + ) + raise ApiParamValidationError(error_template.format(**context), error["type"]) + + +MISSING_ERROR_TEMPLATE = f""" + Pydantic validation error of type "{{error_type}}" does not have the + corresponding error message template or is not handled explicitly. We expect + that every validation error is formatted into a client-facing error message. 
+ Consider adding this type to {__package__}.ERROR_TEMPLATES with the appropriate + internationalized error message or handle it in {__package__}.{parse_value_for_parameter.__name__}. + + Documentation for "{{error_type}}" can be found at {{url}}. + + Error information: +{{error}} +""" + + +UNEXPECTEDLY_MISSING_KEYWORD_ONLY_PARAMETERS = """ +Parameters expected to be parsed from the request should be defined as +keyword-only parameters, but there is no keyword-only parameter found in +{view_func_name}. + +Example usage: + +``` +@typed_endpoint +def view( + request: HttpRequest, + *, + flag_value: Annotated[Json[bool], ApiParamConfig( + whence="flag", documentation_status=INTENTIONALLY_UNDOCUMENTED, + )] +) -> HttpResponse: + ... +``` + +This is likely a programming error. See https://peps.python.org/pep-3102/ for details on how +to correctly declare your parameters as keyword-only parameters. +Endpoints that do not accept parameters should use @typed_endpoint_without_parameters. +""" + +UNEXPECTED_KEYWORD_ONLY_PARAMETERS = """ +Unexpected keyword-only parameters found in {view_func_name}. +keyword-only parameters are treated as parameters to be parsed from the request, +but @typed_endpoint_without_parameters does not expect any. + +Use @typed_endpoint instead. +""" + + +def typed_endpoint_without_parameters( + view_func: Callable[Concatenate[HttpRequest, ParamT], ReturnT], +) -> Callable[Concatenate[HttpRequest, ParamT], ReturnT]: + return typed_endpoint(view_func, expect_no_parameters=True) + + +def typed_endpoint( + view_func: Callable[Concatenate[HttpRequest, ParamT], ReturnT], + *, + expect_no_parameters: bool = False, +) -> Callable[Concatenate[HttpRequest, ParamT], ReturnT]: + # Extract all the type information from the view function. 
+ endpoint_info = parse_view_func_signature(view_func) + if expect_no_parameters: + assert len(endpoint_info.parameters) == 0, UNEXPECTED_KEYWORD_ONLY_PARAMETERS.format( + view_func_name=endpoint_info.view_func_full_name + ) + else: + assert ( + len(endpoint_info.parameters) != 0 + ), UNEXPECTEDLY_MISSING_KEYWORD_ONLY_PARAMETERS.format( + view_func_name=endpoint_info.view_func_full_name + ) + for func_param in endpoint_info.parameters: + assert not isinstance( + func_param.default, _REQ + ), f"Unexpected REQ for parameter {func_param.param_name}; REQ is incompatible with typed_endpoint" + if func_param.path_only: + assert ( + func_param.default is NotSpecified + ), f"Path-only parameter {func_param.param_name} should not have a default value" + # Record arguments that should be documented so that our + # automated OpenAPI docs tests can compare these against the code. + if ( + func_param.documentation_status is DocumentationStatus.DOCUMENTED + and not func_param.path_only + ): + # TODO: Move arguments_map to here once zerver.lib.request does not + # need it anymore. + arguments_map[endpoint_info.view_func_full_name].append(func_param.request_var_name) + + @wraps(view_func) + def _wrapped_view_func( + request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs + ) -> ReturnT: + request_notes = RequestNotes.get_notes(request) + for parameter in endpoint_info.parameters: + if parameter.path_only: + # For path_only parameters, they should already have been passed via + # the URL, so there's no need for us to do anything. + # + # TODO: Run validators for path_only parameters for NewType. + assert ( + parameter.param_name in kwargs + ), f"Path-only variable {parameter.param_name} should be passed already" + if parameter.param_name in kwargs: + # Skip parameters that are already supplied by the caller. + continue + + # Extract the value to parse from the request body if specified. 
+ if parameter.argument_type_is_body: + try: + request_notes.processed_parameters.add(parameter.request_var_name) + kwargs[parameter.param_name] = parse_value_for_parameter( + parameter, request.body.decode(request.encoding or "utf-8") + ) + except UnicodeDecodeError: + raise JsonableError(_("Malformed payload")) + # test_typed_endpoint.TestEndpoint.test_argument_type has + # coverage of this, but coverage.py fails to recognize it for + # some reason. + continue # nocoverage + + # Otherwise, try to find the matching request variable in one of the QueryDicts + # This is a view bug, not a user error, and thus should throw a 500. + possible_aliases = [parameter.request_var_name, *parameter.aliases] + alias_used = None + value_to_parse = None + + for current_alias in possible_aliases: + if current_alias in request.POST: + value_to_parse = request.POST[current_alias] + elif current_alias in request.GET: + value_to_parse = request.GET[current_alias] + else: + # This is covered by + # test_typed_endpoint.TestEndpoint.test_aliases, but + # coverage.py fails to recognize this for some reason. + continue # nocoverage + if alias_used is not None: + raise RequestConfusingParamsError(alias_used, current_alias) + alias_used = current_alias + + if alias_used is None: + alias_used = parameter.request_var_name + if parameter.default is NotSpecified: + raise RequestVariableMissingError(alias_used) + # By skipping here, we leave it to Python to use the default value + # of this parameter, because we cannot find the request variable in + # the request. + # This is tested test_typed_endpoint.TestEndpoint.test_json, but + # coverage.py fails to recognize this for some reason. + continue # nocoverage + + # Note that value_to_parse comes from a QueryDict, so it has no chance + # of having a user-provided None value. 
+ assert value_to_parse is not None + request_notes.processed_parameters.add(alias_used) + kwargs[parameter.param_name] = parse_value_for_parameter(parameter, value_to_parse) + return_value = view_func(request, *args, **kwargs) + + if ( + isinstance(return_value, MutableJsonResponse) + # TODO: Move is_webhook_view to the decorator + and not request_notes.is_webhook_view + # Implemented only for 200 responses. + # TODO: Implement returning unsupported ignored parameters for 400 + # JSON error responses. This is complex because typed_endpoint can be + # called multiple times, so when an error response is raised, there + # may be supported parameters that have not yet been processed, + # which could lead to inaccurate output. + and 200 <= return_value.status_code < 300 + ): + ignored_parameters = set( + list(request.POST.keys()) + list(request.GET.keys()) + ).difference(request_notes.processed_parameters) + + # This will be called each time a function decorated with @typed_endpoint + # returns a MutableJsonResponse with a success status_code. Because + # a shared processed_parameters value is checked each time, the + # value for the ignored_parameters_unsupported key is either + # added/updated to the response data or it is removed in the case + # that all of the request parameters have been processed. + if ignored_parameters: + return_value.get_data()["ignored_parameters_unsupported"] = sorted( + ignored_parameters + ) + else: + return_value.get_data().pop("ignored_parameters_unsupported", None) + + return return_value + + # TODO: Remove this once we replace has_request_variables with typed_endpoint. 
+ _wrapped_view_func.use_endpoint = True # type: ignore[attr-defined] # Distinguish functions decorated with @typed_endpoint from those decorated with has_request_variables + return _wrapped_view_func diff --git a/zerver/lib/validator.py b/zerver/lib/validator.py index 08d4977815..e02eecd35a 100644 --- a/zerver/lib/validator.py +++ b/zerver/lib/validator.py @@ -54,6 +54,8 @@ import orjson from django.core.exceptions import ValidationError from django.core.validators import URLValidator, validate_email from django.utils.translation import gettext as _ +from pydantic import ValidationInfo, model_validator +from pydantic.functional_validators import ModelWrapValidatorHandler from zerver.lib.exceptions import InvalidJSONError, JsonableError from zerver.lib.timezone import canonicalize_timezone @@ -632,6 +634,17 @@ class WildValue: var_name: str value: object + @model_validator(mode="wrap") # type: ignore[arg-type] # The upstream's type annotation uses a TypeVar that is incorrectly unbounded. 
+ @classmethod + def to_wild_value( + cls, + value: object, + # We bypass the original WildValue handler to customize it + handler: ModelWrapValidatorHandler["WildValue"], + info: ValidationInfo, + ) -> "WildValue": + return wrap_wild_value("request", value) + def __bool__(self) -> bool: return bool(self.value) diff --git a/zerver/tests/test_typed_endpoint.py b/zerver/tests/test_typed_endpoint.py new file mode 100644 index 0000000000..4b5872dd76 --- /dev/null +++ b/zerver/tests/test_typed_endpoint.py @@ -0,0 +1,582 @@ +from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union + +import orjson +from django.core.exceptions import ValidationError as DjangoValidationError +from django.http import HttpRequest, HttpResponse +from pydantic import BaseModel, ConfigDict, Json, ValidationInfo, WrapValidator +from pydantic.dataclasses import dataclass +from pydantic.functional_validators import ModelWrapValidatorHandler +from typing_extensions import Annotated + +from zerver.lib.exceptions import ApiParamValidationError, JsonableError +from zerver.lib.request import RequestConfusingParamsError, RequestVariableMissingError +from zerver.lib.response import MutableJsonResponse, json_success +from zerver.lib.test_classes import ZulipTestCase +from zerver.lib.test_helpers import HostRequestMock +from zerver.lib.typed_endpoint import ( + ApiParamConfig, + DocumentationStatus, + PathOnly, + RequiredStringConstraint, + WebhookPayload, + typed_endpoint, + typed_endpoint_without_parameters, +) +from zerver.lib.validator import WildValue, check_bool +from zerver.models import UserProfile + +ParamTypes = Literal["none", "json_only", "both"] +T = TypeVar("T") + + +def call_endpoint( + view: Callable[..., T], request: HttpRequest, *args: object, **kwargs: object +) -> T: + """A helper to let us ignore the view function's signature""" + return view(request, *args, **kwargs) + + +class TestEndpoint(ZulipTestCase): + def test_coerce(self) -> None: + @typed_endpoint + def 
view(request: HttpRequest, *, strict_int: int) -> None: + ... + + with self.assertRaisesMessage(JsonableError, "strict_int is not an integer"): + call_endpoint(view, HostRequestMock({"strict_int": orjson.dumps("10").decode()})) + with self.assertRaisesMessage(JsonableError, "strict_int is not an integer"): + self.assertEqual(call_endpoint(view, HostRequestMock({"strict_int": 10})), 20) + + @typed_endpoint + def view2(request: HttpRequest, *, strict_int: Json[int]) -> int: + return strict_int * 2 + + with self.assertRaisesMessage(JsonableError, "strict_int is not an integer"): + call_endpoint(view2, HostRequestMock({"strict_int": orjson.dumps("10").decode()})) + # This is the same as orjson.dumps(10).decode() + self.assertEqual(call_endpoint(view2, HostRequestMock({"strict_int": "10"})), 20) + self.assertEqual(call_endpoint(view2, HostRequestMock({"strict_int": 10})), 20) + + def test_json(self) -> None: + @dataclass(frozen=True) + class Foo: + num1: int + num2: int + + __pydantic_config__ = ConfigDict(extra="forbid") + + @typed_endpoint + def view( + request: HttpRequest, + *, + json_int: Json[int], + json_str: Json[str], + json_data: Json[Foo], + json_optional: Optional[Json[Union[int, None]]] = None, + json_default: Json[Foo] = Foo(10, 10), + non_json: str = "ok", + non_json_optional: Optional[str] = None, + ) -> HttpResponse: + return MutableJsonResponse( + data={ + "result1": json_int * json_data.num1 * json_data.num2, + "result2": json_default.num1 * json_default.num2, + "optional": json_optional, + "str": json_str + non_json, + }, + content_type="application/json", + status=200, + ) + + response = call_endpoint( + view, + HostRequestMock( + post_data={ + "json_int": "2", + "json_str": orjson.dumps("asd").decode(), + "json_data": orjson.dumps({"num1": 5, "num2": 7}).decode(), + } + ), + ) + self.assertDictEqual( + orjson.loads(response.content), + {"result1": 70, "result2": 100, "str": "asdok", "optional": None}, + ) + + data = { + "json_int": "2", + 
"json_str": orjson.dumps("asd").decode(), + "json_data": orjson.dumps({"num1": 5, "num2": 7}).decode(), + "json_default": orjson.dumps({"num1": 3, "num2": 11}).decode(), + "json_optional": "5", + "non_json": "asd", + } + response = call_endpoint( + view, + HostRequestMock(post_data=data), + ) + self.assertDictEqual( + orjson.loads(response.content), + {"result1": 70, "result2": 33, "str": "asdasd", "optional": 5}, + ) + + request = HostRequestMock() + request.GET.update(data) + response = call_endpoint( + view, + request, + ) + self.assertDictEqual( + orjson.loads(response.content), + {"result1": 70, "result2": 33, "str": "asdasd", "optional": 5}, + ) + + with self.assertRaisesMessage(JsonableError, "json_int is not valid JSON"): + call_endpoint( + view, + HostRequestMock( + post_data={ + "json_int": "foo", + "json_str": "asd", + "json_data": orjson.dumps({"num1": 5, "num2": 7}).decode(), + } + ), + ) + with self.assertRaisesMessage(JsonableError, "json_str is not valid JSON"): + call_endpoint( + view, + HostRequestMock( + post_data={ + "json_int": 5, + "json_str": "asd", + "json_data": orjson.dumps({"num1": 5, "num2": 7}).decode(), + } + ), + ) + + with self.assertRaisesMessage(RequestVariableMissingError, "Missing 'json_int' argument"): + call_endpoint(view, HostRequestMock()) + + with self.assertRaisesMessage(JsonableError, "json_int is not an integer"): + call_endpoint( + view, + HostRequestMock( + { + "json_int": orjson.dumps(False).decode(), + "json_str": orjson.dumps("10").decode(), + "json_data": orjson.dumps({"num1": "a", "num2": "b"}).decode(), + } + ), + ) + + with self.assertRaisesMessage(JsonableError, 'json_data["num1"] is not an integer'): + call_endpoint( + view, + HostRequestMock( + { + "json_int": orjson.dumps(0).decode(), + "json_str": orjson.dumps("test").decode(), + "json_data": orjson.dumps({"num1": "10", "num2": 20}).decode(), + } + ), + ) + + response = call_endpoint( + view, + HostRequestMock( + post_data={ + "json_int": 5, + "json_str": 
orjson.dumps("asd").decode(), + "json_data": orjson.dumps({"num1": 5, "num2": 7}).decode(), + "json_optional": orjson.dumps(None).decode(), + "non_json_optional": None, + } + ), + json_optional="asd", + ) + # Note that json_optional is ignored because we have passed it as a kwarg already. + self.assertDictEqual( + orjson.loads(response.content), + { + "result1": 175, + "result2": 100, + "str": "asdok", + "optional": "asd", + "ignored_parameters_unsupported": ["json_optional"], + }, + ) + + with self.assertRaisesMessage( + JsonableError, 'Argument "unknown" at json_data["unknown"] is unexpected' + ): + call_endpoint( + view, + HostRequestMock( + { + "json_int": orjson.dumps(19).decode(), + "json_str": orjson.dumps("10").decode(), + "json_data": orjson.dumps({"num1": 1, "num2": 4, "unknown": "c"}).decode(), + } + ), + ) + + def test_whence(self) -> None: + @typed_endpoint + def whence_view( + request: HttpRequest, *, param: Annotated[str, ApiParamConfig(whence="foo")] + ) -> str: + return param + + with self.assertRaisesMessage(RequestVariableMissingError, "Missing 'foo' argument"): + call_endpoint(whence_view, HostRequestMock({"param": "hi"})) + + result = call_endpoint(whence_view, HostRequestMock({"foo": "hi"})) + self.assertEqual(result, "hi") + + def test_argument_type(self) -> None: + @typed_endpoint + def webhook( + request: HttpRequest, + *, + body: WebhookPayload[WildValue], + foo: Json[int], + bar: Json[int] = 0, + ) -> Dict[str, object]: + status = body["totame"]["status"].tame(check_bool) + return {"status": status, "foo": foo, "bar": bar} + + # Simulate a paylaod that uses JSON encoding. We use the body setter to + # overwrite the request body. The HostRequestMock initializer sets the + # POST QueryDict, which is normally done by Django by parsing the body. 
+ data = {"foo": 15, "totame": {"status": True}} + request = HostRequestMock(data) + request.body = orjson.dumps(data) + result = call_endpoint(webhook, request) + self.assertDictEqual(result, {"status": True, "foo": 15, "bar": 0}) + + request.body = orjson.dumps([]) + with self.assertRaisesRegex(DjangoValidationError, "request is not a dict"): + result = call_endpoint(webhook, request) + + request.body = orjson.dumps(10) + with self.assertRaisesRegex(DjangoValidationError, "request is not a dict"): + result = call_endpoint(webhook, request) + + request = HostRequestMock() + request.GET.update({"foo": "15", "bar": "10"}) + request.body = orjson.dumps(data) + result = call_endpoint(webhook, request) + self.assertDictEqual(result, {"status": True, "foo": 15, "bar": 10}) + + with self.assertRaisesMessage(JsonableError, "Malformed JSON"): + call_endpoint(webhook, HostRequestMock()) + + with self.assertRaisesMessage(JsonableError, "Malformed payload"): + request = HostRequestMock() + # This body triggers UnicodeDecodeError + request.body = b"\x81" + call_endpoint(webhook, request) + + def test_path_only(self) -> None: + @typed_endpoint + def path_only( + request: HttpRequest, + *, + path_var: PathOnly[int], + other: Json[int], + ) -> MutableJsonResponse: + # Return a MutableJsonResponse to see parameters ignored + return json_success(request, data={"val": path_var + other}) + + response = call_endpoint(path_only, HostRequestMock(post_data={"other": 1}), path_var=20) + self.assert_json_success(response) + self.assertEqual(orjson.loads(response.content)["val"], 21) + + with self.assertRaisesMessage( + AssertionError, "Path-only variable path_var should be passed already" + ): + call_endpoint(path_only, HostRequestMock(post_data={"other": 1})) + + # Even if the path-only variable is present in the request data, it + # shouldn't be parsed either. 
+ with self.assertRaisesMessage( + AssertionError, "Path-only variable path_var should be passed already" + ): + call_endpoint(path_only, HostRequestMock(post_data={"path_var": 15, "other": 1})) + + # path_var in the request body is ignored + response = call_endpoint( + path_only, HostRequestMock(post_data={"path_var": 15, "other": 1}), path_var=10 + ) + self.assert_json_success(response, ignored_parameters=["path_var"]) + self.assertEqual(orjson.loads(response.content)["val"], 11) + + def path_only_default( + request: HttpRequest, + *, + path_var_default: PathOnly[str] = "test", + ) -> None: + ... + + with self.assertRaisesMessage( + AssertionError, "Path-only parameter path_var_default should not have a default value" + ): + typed_endpoint(path_only_default) + + def test_documentation_status(self) -> None: + def documentation( + request: HttpRequest, + *, + foo: Annotated[ + str, + ApiParamConfig(documentation_status=DocumentationStatus.INTENTIONALLY_UNDOCUMENTED), + ], + bar: Annotated[ + str, ApiParamConfig(documentation_status=DocumentationStatus.DOCUMENTATION_PENDING) + ], + baz: Annotated[ + str, ApiParamConfig(documentation_status=DocumentationStatus.DOCUMENTED) + ], + paz: PathOnly[int], + other: str, + ) -> None: + ... + + from zerver.lib.request import arguments_map + + view_func_full_name = f"{documentation.__module__}.{documentation.__name__}" + typed_endpoint(documentation) + # Path-only and non DOCUMENTED parameters should not be added + self.assertEqual(arguments_map[view_func_full_name], ["baz", "other"]) + + def test_annotated(self) -> None: + @typed_endpoint + def valid_usage_of_api_param_config( + request: HttpRequest, + *, + foo: Annotated[ + Json[int], + ApiParamConfig(path_only=True), + ], + ) -> None: + ... + + def annotated_with_repeated_api_param_config( + request: HttpRequest, + user_profile: UserProfile, + *, + foo: Annotated[Json[int], ApiParamConfig(), ApiParamConfig()], + ) -> None: + ... 
+ + with self.assertRaisesMessage( + AssertionError, "ApiParamConfig can only be defined once per parameter" + ): + typed_endpoint(annotated_with_repeated_api_param_config) + + @typed_endpoint + def annotated_with_extra_unrelated_metadata( + request: HttpRequest, + user_profile: UserProfile, + *, + foo: Annotated[Json[bool], str, "unrelated"], + ) -> bool: + return foo + + hamlet = self.example_user("hamlet") + result = call_endpoint( + annotated_with_extra_unrelated_metadata, + HostRequestMock({"foo": orjson.dumps(False).decode()}), + hamlet, + ) + self.assertFalse(result) + + def test_aliases(self) -> None: + @typed_endpoint + def view_with_aliased_parameter( + request: HttpRequest, *, topic: Annotated[str, ApiParamConfig(aliases=["legacy_topic"])] + ) -> HttpResponse: + return json_success(request, {"value": topic}) + + result = call_endpoint( + view_with_aliased_parameter, HostRequestMock({"topic": "topic is topic"}) + ) + value = self.assert_json_success(result)["value"] + self.assertEqual(value, "topic is topic") + + req = HostRequestMock({"topic": "topic is topic"}) + req.GET["legacy_topic"] = "topic is" + with self.assertRaisesMessage( + RequestConfusingParamsError, "Can't decide between 'topic' and 'legacy_topic' arguments" + ): + call_endpoint(view_with_aliased_parameter, req) + + with self.assertRaisesMessage( + RequestConfusingParamsError, "Can't decide between 'topic' and 'legacy_topic' arguments" + ): + call_endpoint( + view_with_aliased_parameter, + HostRequestMock({"topic": "test", "legacy_topic": "test2"}), + ) + + result = call_endpoint( + view_with_aliased_parameter, HostRequestMock({"legacy_topic": "legacy_topic is topic"}) + ) + value = self.assert_json_success(result)["value"] + self.assertEqual(value, "legacy_topic is topic") + + result = call_endpoint( + view_with_aliased_parameter, + HostRequestMock( + {"legacy_topic": "legacy_topic is topic", "ignored": "extra parameter"} + ), + ) + value = self.assert_json_success(result, 
ignored_parameters=["ignored"])["value"] + self.assertEqual(value, "legacy_topic is topic") + + # aliases should work in combination with whence + @typed_endpoint + def view_with_aliased_and_whenced_parameter( + request: HttpRequest, + *, + topic: Annotated[str, ApiParamConfig(whence="topic_name", aliases=["legacy_topic"])], + ) -> HttpResponse: + return json_success(request, {"value": topic}) + + result = call_endpoint( + view_with_aliased_and_whenced_parameter, + HostRequestMock({"legacy_topic": "legacy_topic is topic", "topic": "extra parameter"}), + ) + value = self.assert_json_success(result, ignored_parameters=["topic"])["value"] + self.assertEqual(value, "legacy_topic is topic") + + with self.assertRaisesMessage( + RequestConfusingParamsError, + "Can't decide between 'topic_name' and 'legacy_topic' arguments", + ): + call_endpoint( + view_with_aliased_and_whenced_parameter, + HostRequestMock({"topic_name": "test", "legacy_topic": "test2"}), + ) + + def test_expect_no_parameters(self) -> None: + def no_parameter(request: HttpRequest) -> None: + ... + + def has_parameters(request: HttpRequest, *, foo: int, bar: str) -> None: + ... 
+ + with self.assertRaisesRegex(AssertionError, "there is no keyword-only parameter found"): + typed_endpoint(no_parameter) + # No assertion errors expected + typed_endpoint(has_parameters) + + with self.assertRaisesMessage(AssertionError, "Unexpected keyword-only parameters found"): + typed_endpoint_without_parameters(has_parameters) + # No assertion errors expected + typed_endpoint_without_parameters(no_parameter) + + def test_custom_validator(self) -> None: + @dataclass + class CustomType: + val: int + + def validate_custom_type( + value: object, + handler: ModelWrapValidatorHandler[CustomType], + info: ValidationInfo, + ) -> CustomType: + return CustomType(42) + + @typed_endpoint + def test_view( + request: HttpRequest, *, foo: Annotated[CustomType, WrapValidator(validate_custom_type)] + ) -> None: + self.assertEqual(foo.val, 42) + + call_endpoint(test_view, HostRequestMock({"foo": ""})) + + +class ValidationErrorHandlingTest(ZulipTestCase): + def test_special_handling_errors(self) -> None: + """Test for errors that require special handling beyond an ERROR_TEMPLATES lookup. + Not all error types need to be tested here.""" + + @dataclass + class DataFoo: + __pydantic_config__ = ConfigDict(extra="forbid") + message: str + + class DataModel(BaseModel): + model_config = ConfigDict(extra="forbid") + message: str + + @dataclass + class SubTest: + """This describes a parameterized test case + for our handling of Pydantic validation errors""" + + # The type of the error, can be found at + # https://docs.pydantic.dev/latest/errors/validation_errors/ + error_type: str + # The type of the parameter. We set it on a view function decorated + # with @typed_endpoint for a parameter named "input". + param_type: object + # Because QueryDict always converts the data into a str, this + # conversion can be unexpected so we ask the caller to convert + # input_data to str explicitly beforehand. The input data is + # automatically set to POST["input"] in the mock request.
+ input_data: str + # The exact error message we expect from the ApiValidationError + # raised when the view function is called with the provided input + # data. + error_message: str + + def __repr__(self) -> str: + return f"Pydantic error type: {self.error_type}; Parameter type: {self.param_type}; Expected error message: {self.error_message}" + + parameterized_tests: List[SubTest] = [ + SubTest( + error_type="string_too_short", + param_type=Json[List[Annotated[str, RequiredStringConstraint()]]], + input_data=orjson.dumps([""]).decode(), + error_message="input[0] cannot be blank", + ), + SubTest( + error_type="string_too_short", + param_type=Json[List[Annotated[str, RequiredStringConstraint()]]], + input_data=orjson.dumps(["g", " "]).decode(), + error_message="input[1] cannot be blank", + ), + SubTest( + error_type="unexpected_keyword_argument", + param_type=Json[DataFoo], + input_data=orjson.dumps({"message": "asd", "test": ""}).decode(), + error_message='Argument "test" at input["test"] is unexpected', + ), + SubTest( + error_type="extra_forbidden", + param_type=Json[DataModel], + input_data=orjson.dumps({"message": "asd", "test": ""}).decode(), + error_message='Argument "test" at input["test"] is unexpected', + ), + ] + + for index, subtest in enumerate(parameterized_tests): + subtest_title = f"Subtest #{index + 1}: {subtest!r}" + with self.subTest(subtest_title): + # We use Any here so that we don't perform unnecessary type + # checking. + # Without this, mypy crashes with an internal error: + # INTERNAL ERROR: maximum semantic analysis iteration count reached + input_type: Any = subtest.param_type + + @typed_endpoint + def func(request: HttpRequest, *, input: input_type) -> None: + ... + + with self.assertRaises(ApiParamValidationError) as m: + call_endpoint(func, HostRequestMock({"input": subtest.input_data})) + + self.assertEqual(m.exception.msg, subtest.error_message) + self.assertEqual(m.exception.error_type, subtest.error_type)