diff --git a/zerver/tests/test_alert_words.py b/zerver/tests/test_alert_words.py index 3094cb9024..69151c6f56 100644 --- a/zerver/tests/test_alert_words.py +++ b/zerver/tests/test_alert_words.py @@ -139,6 +139,12 @@ class AlertWordTests(ZulipTestCase): response_dict = self.assert_json_success(result) self.assertEqual(set(response_dict["alert_words"]), {"one", "two", "three"}) + result = self.client_post( + "/json/users/me/alert_words", + {"alert_words": orjson.dumps(["long" * 26]).decode()}, + ) + self.assert_json_error(result, "alert_words[0] is too long (limit: 100 characters)") + def test_json_list_remove(self) -> None: user = self.get_user() self.login_user(user) diff --git a/zerver/tests/test_openapi.py b/zerver/tests/test_openapi.py index 549925179b..4c66bc51a5 100644 --- a/zerver/tests/test_openapi.py +++ b/zerver/tests/test_openapi.py @@ -20,10 +20,12 @@ import yaml from django.http import HttpResponse from django.urls import URLPattern from django.utils import regex_helper +from pydantic import TypeAdapter from zerver.lib.request import _REQ, arguments_map from zerver.lib.rest import rest_dispatch from zerver.lib.test_classes import ZulipTestCase +from zerver.lib.typed_endpoint import parse_view_func_signature from zerver.lib.utils import assert_is_not_none from zerver.openapi.markdown_extension import generate_curl_example, render_curl_example from zerver.openapi.openapi import ( @@ -60,6 +62,12 @@ def schema_type(schema: Dict[str, Any]) -> Union[type, Tuple[type, object]]: # Hack: Just use the type of the first value # Ideally, we'd turn this into a Union type. return schema_type(schema["oneOf"][0]) + elif "anyOf" in schema: + return schema_type(schema["anyOf"][0]) # nocoverage + elif schema.get("contentMediaType") == "application/json": + return schema_type( + schema["contentSchema"] + ) # nocoverage # Will be covered as more endpoints are migrated elif schema["type"] == "array": return (list, schema_type(schema["items"])) else: @@ -389,6 +397,88 @@ do not match the types declared in the implementation of {function.__name__}.\n" msg += f"{vname:<10}{opvtype!s:^30}{fdvtype!s:>10}\n" raise AssertionError(msg) + def validate_json_schema( + self, function: Callable[..., HttpResponse], openapi_parameters: List[Dict[str, Any]] + ) -> None: + """Validate against the Pydantic generated JSON schema against our OpenAPI definitions""" + USE_JSON_CONTENT_TYPE_HINT = f""" + The view function {{param_name}} should accept JSON input. + Consider wrapping the type annotation of the parameter in Json. + For example: + + from pydantic import Json + ... + @typed_endpoint + def {function.__name__}( + request: HttpRequest, + *, + {{param_name}}: Json[{{param_type}}] = ..., + ) -> ...: +""" + # The set of tuples containing the var name and type pairs extracted + # from the function signature. + function_params = set() + # The set of tuples containing the var name and type pairs extracted + # from OpenAPI. + openapi_params = set() + # The names of request variables that should have a content type of + # application/json according to our OpenAPI definitions. + json_request_var_names = set() + for expected_param_schema in openapi_parameters: + # We differentiate JSON and non-JSON parameters here. Because + # application/json is the only content type to be verify in the API, + # we assume that as long as "content" is present in the OpenAPI + # spec, the content type should be JSON. + expected_request_var_name = expected_param_schema["name"] + if "content" in expected_param_schema: + expected_param_schema = expected_param_schema["content"]["application/json"][ + "schema" + ] + json_request_var_names.add(expected_request_var_name) + else: + expected_param_schema = expected_param_schema[ + "schema" + ] # nocoverage # Will be covered as more endpoints are migrated + + openapi_params.add((expected_request_var_name, schema_type(expected_param_schema))) + + for actual_param in parse_view_func_signature(function).parameters: + actual_param_schema = TypeAdapter(actual_param.param_type).json_schema() + # The content type of the JSON schema generated from the + # function parameter type annotation should have content type + # matching that of our OpenAPI spec. If not so, hint that the + # Json[T] wrapper might be missing from the type annotation. + if actual_param.request_var_name in json_request_var_names: + self.assertEqual( + actual_param_schema.get("contentMediaType"), + "application/json", + USE_JSON_CONTENT_TYPE_HINT.format( + param_name=actual_param.param_name, + param_type=actual_param.param_type, + ), + ) + # actual_param_schema is a json_schema. Reference: + # https://docs.pydantic.dev/latest/api/json_schema/#pydantic.json_schema.GenerateJsonSchema.json_schema + actual_param_schema = actual_param_schema["contentSchema"] + elif ( + "contentMediaType" in actual_param_schema + ): # nocoverage # Will be covered as more endpoints are migrated + function_schema_type = schema_type(actual_param_schema) + # We do not specify that the content type of int or bool + # parameters should be JSON encoded, while our code does expect + # that. In this case, we exempt this parameter from the content + # type check. + self.assertIn( + function_schema_type, + (int, bool), + f'\nUnexpected content type {actual_param_schema["contentMediaType"]} on function parameter {actual_param.param_name}, which does not match the OpenAPI definition.', + ) + function_params.add((actual_param.request_var_name, schema_type(actual_param_schema))) + + diff = openapi_params - function_params + if diff: # nocoverage + self.render_openapi_type_exception(function, openapi_params, function_params, diff) + def check_argument_types( self, function: Callable[..., HttpResponse], openapi_parameters: List[Dict[str, Any]] ) -> None: @@ -397,6 +487,20 @@ do not match the types declared in the implementation of {function.__name__}.\n" OpenAPI data defines a different type than that actually accepted by the function. Otherwise, we print out the exact differences for convenient debugging and raise an AssertionError.""" + # Iterate through the decorators to find the original function, wrapped + # by has_request_variables/typed_endpoint, so we can parse its + # arguments. + use_endpoint_decorator = False + while (wrapped := getattr(function, "__wrapped__", None)) is not None: + # TODO: Remove this check once we replace has_request_variables with + # typed_endpoint. + if getattr(function, "use_endpoint", False): + use_endpoint_decorator = True + function = wrapped + + if use_endpoint_decorator: + return self.validate_json_schema(function, openapi_parameters) + openapi_params: Set[Tuple[str, Union[type, Tuple[type, object]]]] = set() json_params: Dict[str, Union[type, Tuple[type, object]]] = {} for element in openapi_parameters: @@ -424,22 +528,6 @@ do not match the types declared in the implementation of {function.__name__}.\n" function_params: Set[Tuple[str, Union[type, Tuple[type, object]]]] = set() - # Iterate through the decorators to find the original - # function, wrapped by has_request_variables, so we can parse - # its arguments. - while (wrapped := getattr(function, "__wrapped__", None)) is not None: - function = wrapped - - # Now, we do inference mapping each REQ parameter's - # declaration details to the Python/mypy types for the - # arguments passed to it. - # - # Because the mypy types are the types used inside the inner - # function (after the original data is processed by any - # validators, converters, etc.), they will not always match - # the API-level argument types. The main case where this - # happens is when a `converter` is used that changes the types - # of its parameters. for pname, defval in inspect.signature(function).parameters.items(): defval = defval.default if isinstance(defval, _REQ): diff --git a/zerver/views/alert_words.py b/zerver/views/alert_words.py index 9dd3abdce5..63cb97119c 100644 --- a/zerver/views/alert_words.py +++ b/zerver/views/alert_words.py @@ -1,12 +1,13 @@ from typing import List from django.http import HttpRequest, HttpResponse +from pydantic import Json, StringConstraints +from typing_extensions import Annotated from zerver.actions.alert_words import do_add_alert_words, do_remove_alert_words from zerver.lib.alert_words import user_alert_words -from zerver.lib.request import REQ, has_request_variables from zerver.lib.response import json_success -from zerver.lib.validator import check_capped_string, check_list, check_string +from zerver.lib.typed_endpoint import typed_endpoint from zerver.models import UserProfile @@ -19,21 +20,23 @@ def clean_alert_words(alert_words: List[str]) -> List[str]: return [w for w in alert_words if w != ""] -@has_request_variables +@typed_endpoint def add_alert_words( request: HttpRequest, user_profile: UserProfile, - alert_words: List[str] = REQ(json_validator=check_list(check_capped_string(100))), + *, + alert_words: Json[List[Annotated[str, StringConstraints(max_length=100)]]], ) -> HttpResponse: do_add_alert_words(user_profile, clean_alert_words(alert_words)) return json_success(request, data={"alert_words": user_alert_words(user_profile)}) -@has_request_variables +@typed_endpoint def remove_alert_words( request: HttpRequest, user_profile: UserProfile, - alert_words: List[str] = REQ(json_validator=check_list(check_string)), + *, + alert_words: Json[List[str]], ) -> HttpResponse: do_remove_alert_words(user_profile, alert_words) return json_success(request, data={"alert_words": user_alert_words(user_profile)})