From 9c539958302ddfed7f66f722bf8f9eb960eed109 Mon Sep 17 00:00:00 2001 From: Zixuan James Li Date: Fri, 11 Aug 2023 18:03:37 -0400 Subject: [PATCH] alert_words: Migrate alert_words to use @typed_endpoint. This demonstrates some basic use cases of the Json[...] wrapper with @typed_endpoint. Along with this change, we extend test_openapi so that schema checking based on function signatures will still work with this new decorator. Pydantic's TypeAdapter supports dumping the JSON schema of any given type, which is leveraged here to validate against our own OpenAPI definitions. Parts of the implementation will be covered in later commits as we migrate more functions to use @typed_endpoint. See also: https://docs.pydantic.dev/latest/api/type_adapter/#pydantic.type_adapter.TypeAdapter.json_schema For the OpenAPI schema, we preprocess it mostly the same way. For the parameter types though, we no longer need to use get_standardized_argument_type to normalize type annotations, because Pydantic dumps a JSON schema that is compliant with OpenAPI schema already, which makes it a lot more convenient for us to compare the types with our OpenAPI definitions. Do note that there are some exceptions where our definitions do not match the generated one. For example, we use JSON to parse int and bool parameters, but we don't mark them to use "application/json" in our definitions. 
--- zerver/tests/test_alert_words.py | 6 ++ zerver/tests/test_openapi.py | 120 ++++++++++++++++++++++++++----- zerver/views/alert_words.py | 15 ++-- 3 files changed, 119 insertions(+), 22 deletions(-) diff --git a/zerver/tests/test_alert_words.py b/zerver/tests/test_alert_words.py index 3094cb9024..69151c6f56 100644 --- a/zerver/tests/test_alert_words.py +++ b/zerver/tests/test_alert_words.py @@ -139,6 +139,12 @@ class AlertWordTests(ZulipTestCase): response_dict = self.assert_json_success(result) self.assertEqual(set(response_dict["alert_words"]), {"one", "two", "three"}) + result = self.client_post( + "/json/users/me/alert_words", + {"alert_words": orjson.dumps(["long" * 26]).decode()}, + ) + self.assert_json_error(result, "alert_words[0] is too long (limit: 100 characters)") + def test_json_list_remove(self) -> None: user = self.get_user() self.login_user(user) diff --git a/zerver/tests/test_openapi.py b/zerver/tests/test_openapi.py index 549925179b..4c66bc51a5 100644 --- a/zerver/tests/test_openapi.py +++ b/zerver/tests/test_openapi.py @@ -20,10 +20,12 @@ import yaml from django.http import HttpResponse from django.urls import URLPattern from django.utils import regex_helper +from pydantic import TypeAdapter from zerver.lib.request import _REQ, arguments_map from zerver.lib.rest import rest_dispatch from zerver.lib.test_classes import ZulipTestCase +from zerver.lib.typed_endpoint import parse_view_func_signature from zerver.lib.utils import assert_is_not_none from zerver.openapi.markdown_extension import generate_curl_example, render_curl_example from zerver.openapi.openapi import ( @@ -60,6 +62,12 @@ def schema_type(schema: Dict[str, Any]) -> Union[type, Tuple[type, object]]: # Hack: Just use the type of the first value # Ideally, we'd turn this into a Union type. 
return schema_type(schema["oneOf"][0]) + elif "anyOf" in schema: + return schema_type(schema["anyOf"][0]) # nocoverage + elif schema.get("contentMediaType") == "application/json": + return schema_type( + schema["contentSchema"] + ) # nocoverage # Will be covered as more endpoints are migrated elif schema["type"] == "array": return (list, schema_type(schema["items"])) else: @@ -389,6 +397,88 @@ do not match the types declared in the implementation of {function.__name__}.\n" msg += f"{vname:<10}{opvtype!s:^30}{fdvtype!s:>10}\n" raise AssertionError(msg) + def validate_json_schema( + self, function: Callable[..., HttpResponse], openapi_parameters: List[Dict[str, Any]] + ) -> None: + """Validate against the Pydantic generated JSON schema against our OpenAPI definitions""" + USE_JSON_CONTENT_TYPE_HINT = f""" + The view function {{param_name}} should accept JSON input. + Consider wrapping the type annotation of the parameter in Json. + For example: + + from pydantic import Json + ... + @typed_endpoint + def {function.__name__}( + request: HttpRequest, + *, + {{param_name}}: Json[{{param_type}}] = ..., + ) -> ...: +""" + # The set of tuples containing the var name and type pairs extracted + # from the function signature. + function_params = set() + # The set of tuples containing the var name and type pairs extracted + # from OpenAPI. + openapi_params = set() + # The names of request variables that should have a content type of + # application/json according to our OpenAPI definitions. + json_request_var_names = set() + for expected_param_schema in openapi_parameters: + # We differentiate JSON and non-JSON parameters here. Because + # application/json is the only content type to be verify in the API, + # we assume that as long as "content" is present in the OpenAPI + # spec, the content type should be JSON. 
+ expected_request_var_name = expected_param_schema["name"] + if "content" in expected_param_schema: + expected_param_schema = expected_param_schema["content"]["application/json"][ + "schema" + ] + json_request_var_names.add(expected_request_var_name) + else: + expected_param_schema = expected_param_schema[ + "schema" + ] # nocoverage # Will be covered as more endpoints are migrated + + openapi_params.add((expected_request_var_name, schema_type(expected_param_schema))) + + for actual_param in parse_view_func_signature(function).parameters: + actual_param_schema = TypeAdapter(actual_param.param_type).json_schema() + # The content type of the JSON schema generated from the + # function parameter type annotation should have content type + # matching that of our OpenAPI spec. If not so, hint that the + # Json[T] wrapper might be missing from the type annotation. + if actual_param.request_var_name in json_request_var_names: + self.assertEqual( + actual_param_schema.get("contentMediaType"), + "application/json", + USE_JSON_CONTENT_TYPE_HINT.format( + param_name=actual_param.param_name, + param_type=actual_param.param_type, + ), + ) + # actual_param_schema is a json_schema. Reference: + # https://docs.pydantic.dev/latest/api/json_schema/#pydantic.json_schema.GenerateJsonSchema.json_schema + actual_param_schema = actual_param_schema["contentSchema"] + elif ( + "contentMediaType" in actual_param_schema + ): # nocoverage # Will be covered as more endpoints are migrated + function_schema_type = schema_type(actual_param_schema) + # We do not specify that the content type of int or bool + # parameters should be JSON encoded, while our code does expect + # that. In this case, we exempt this parameter from the content + # type check. 
+ self.assertIn( + function_schema_type, + (int, bool), + f'\nUnexpected content type {actual_param_schema["contentMediaType"]} on function parameter {actual_param.param_name}, which does not match the OpenAPI definition.', + ) + function_params.add((actual_param.request_var_name, schema_type(actual_param_schema))) + + diff = openapi_params - function_params + if diff: # nocoverage + self.render_openapi_type_exception(function, openapi_params, function_params, diff) + def check_argument_types( self, function: Callable[..., HttpResponse], openapi_parameters: List[Dict[str, Any]] ) -> None: @@ -397,6 +487,20 @@ do not match the types declared in the implementation of {function.__name__}.\n" OpenAPI data defines a different type than that actually accepted by the function. Otherwise, we print out the exact differences for convenient debugging and raise an AssertionError.""" + # Iterate through the decorators to find the original function, wrapped + # by has_request_variables/typed_endpoint, so we can parse its + # arguments. + use_endpoint_decorator = False + while (wrapped := getattr(function, "__wrapped__", None)) is not None: + # TODO: Remove this check once we replace has_request_variables with + # typed_endpoint. + if getattr(function, "use_endpoint", False): + use_endpoint_decorator = True + function = wrapped + + if use_endpoint_decorator: + return self.validate_json_schema(function, openapi_parameters) + openapi_params: Set[Tuple[str, Union[type, Tuple[type, object]]]] = set() json_params: Dict[str, Union[type, Tuple[type, object]]] = {} for element in openapi_parameters: @@ -424,22 +528,6 @@ do not match the types declared in the implementation of {function.__name__}.\n" function_params: Set[Tuple[str, Union[type, Tuple[type, object]]]] = set() - # Iterate through the decorators to find the original - # function, wrapped by has_request_variables, so we can parse - # its arguments. 
- while (wrapped := getattr(function, "__wrapped__", None)) is not None: - function = wrapped - - # Now, we do inference mapping each REQ parameter's - # declaration details to the Python/mypy types for the - # arguments passed to it. - # - # Because the mypy types are the types used inside the inner - # function (after the original data is processed by any - # validators, converters, etc.), they will not always match - # the API-level argument types. The main case where this - # happens is when a `converter` is used that changes the types - # of its parameters. for pname, defval in inspect.signature(function).parameters.items(): defval = defval.default if isinstance(defval, _REQ): diff --git a/zerver/views/alert_words.py b/zerver/views/alert_words.py index 9dd3abdce5..63cb97119c 100644 --- a/zerver/views/alert_words.py +++ b/zerver/views/alert_words.py @@ -1,12 +1,13 @@ from typing import List from django.http import HttpRequest, HttpResponse +from pydantic import Json, StringConstraints +from typing_extensions import Annotated from zerver.actions.alert_words import do_add_alert_words, do_remove_alert_words from zerver.lib.alert_words import user_alert_words -from zerver.lib.request import REQ, has_request_variables from zerver.lib.response import json_success -from zerver.lib.validator import check_capped_string, check_list, check_string +from zerver.lib.typed_endpoint import typed_endpoint from zerver.models import UserProfile @@ -19,21 +20,23 @@ def clean_alert_words(alert_words: List[str]) -> List[str]: return [w for w in alert_words if w != ""] -@has_request_variables +@typed_endpoint def add_alert_words( request: HttpRequest, user_profile: UserProfile, - alert_words: List[str] = REQ(json_validator=check_list(check_capped_string(100))), + *, + alert_words: Json[List[Annotated[str, StringConstraints(max_length=100)]]], ) -> HttpResponse: do_add_alert_words(user_profile, clean_alert_words(alert_words)) return json_success(request, data={"alert_words": 
user_alert_words(user_profile)}) -@has_request_variables +@typed_endpoint def remove_alert_words( request: HttpRequest, user_profile: UserProfile, - alert_words: List[str] = REQ(json_validator=check_list(check_string)), + *, + alert_words: Json[List[str]], ) -> HttpResponse: do_remove_alert_words(user_profile, alert_words) return json_success(request, data={"alert_words": user_alert_words(user_profile)})