alert_words: Migrate alert_words to use @typed_endpoint.

This demonstrates some basic use cases of the Json[...] wrapper with
@typed_endpoint.

Along with this change we extend test_openapi so that schema checking
based on function signatures will still work with this new decorator.
Pydantic's TypeAdapter supports dumping the JSON schema of any given type,
which is leveraged here to validate against our own OpenAPI definitions.
Parts of the implementation will be covered in later commits as we
migrate more functions to use @typed_endpoint.

See also:
https://docs.pydantic.dev/latest/api/type_adapter/#pydantic.type_adapter.TypeAdapter.json_schema

For the OpenAPI schema, we preprocess it mostly the same way. For the
parameter types though, we no longer need to use
get_standardized_argument_type to normalize type annotations, because
Pydantic already dumps a JSON schema that is compliant with the OpenAPI
schema, which makes it a lot more convenient for us to compare the types
with our OpenAPI definitions.

Do note that there are some exceptions where our definitions do not match
the generated schema. For example, we use JSON to parse int and bool parameters,
but we don't mark them as using "application/json" in our definitions.
This commit is contained in:
Zixuan James Li 2023-08-11 18:03:37 -04:00 committed by Tim Abbott
parent c336bf0398
commit 9c53995830
3 changed files with 119 additions and 22 deletions

View File

@ -139,6 +139,12 @@ class AlertWordTests(ZulipTestCase):
response_dict = self.assert_json_success(result) response_dict = self.assert_json_success(result)
self.assertEqual(set(response_dict["alert_words"]), {"one", "two", "three"}) self.assertEqual(set(response_dict["alert_words"]), {"one", "two", "three"})
result = self.client_post(
"/json/users/me/alert_words",
{"alert_words": orjson.dumps(["long" * 26]).decode()},
)
self.assert_json_error(result, "alert_words[0] is too long (limit: 100 characters)")
def test_json_list_remove(self) -> None: def test_json_list_remove(self) -> None:
user = self.get_user() user = self.get_user()
self.login_user(user) self.login_user(user)

View File

@ -20,10 +20,12 @@ import yaml
from django.http import HttpResponse from django.http import HttpResponse
from django.urls import URLPattern from django.urls import URLPattern
from django.utils import regex_helper from django.utils import regex_helper
from pydantic import TypeAdapter
from zerver.lib.request import _REQ, arguments_map from zerver.lib.request import _REQ, arguments_map
from zerver.lib.rest import rest_dispatch from zerver.lib.rest import rest_dispatch
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.typed_endpoint import parse_view_func_signature
from zerver.lib.utils import assert_is_not_none from zerver.lib.utils import assert_is_not_none
from zerver.openapi.markdown_extension import generate_curl_example, render_curl_example from zerver.openapi.markdown_extension import generate_curl_example, render_curl_example
from zerver.openapi.openapi import ( from zerver.openapi.openapi import (
@ -60,6 +62,12 @@ def schema_type(schema: Dict[str, Any]) -> Union[type, Tuple[type, object]]:
# Hack: Just use the type of the first value # Hack: Just use the type of the first value
# Ideally, we'd turn this into a Union type. # Ideally, we'd turn this into a Union type.
return schema_type(schema["oneOf"][0]) return schema_type(schema["oneOf"][0])
elif "anyOf" in schema:
return schema_type(schema["anyOf"][0]) # nocoverage
elif schema.get("contentMediaType") == "application/json":
return schema_type(
schema["contentSchema"]
) # nocoverage # Will be covered as more endpoints are migrated
elif schema["type"] == "array": elif schema["type"] == "array":
return (list, schema_type(schema["items"])) return (list, schema_type(schema["items"]))
else: else:
@ -389,6 +397,88 @@ do not match the types declared in the implementation of {function.__name__}.\n"
msg += f"{vname:<10}{opvtype!s:^30}{fdvtype!s:>10}\n" msg += f"{vname:<10}{opvtype!s:^30}{fdvtype!s:>10}\n"
raise AssertionError(msg) raise AssertionError(msg)
def validate_json_schema(
self, function: Callable[..., HttpResponse], openapi_parameters: List[Dict[str, Any]]
) -> None:
"""Validate the Pydantic-generated JSON schema of the parameters of
`function` against our OpenAPI definitions.

For every entry in `openapi_parameters`, a (request variable name, type)
pair is derived with schema_type(). The same is done for every parameter
returned by parse_view_func_signature(function), using the JSON schema
that Pydantic's TypeAdapter generates from its type annotation. If any
pair derived from OpenAPI has no equal pair on the function side,
render_openapi_type_exception() is called to report the difference.

Parameters that OpenAPI declares with "content" must produce a schema
whose contentMediaType is "application/json"; otherwise the assertion
fails with a hint to wrap the annotation in Json[...].
"""
# The doubled braces survive this f-string as literal {param_name} and
# {param_type} placeholders; they are filled in later with .format()
# once the offending parameter is known. Only function.__name__ is
# interpolated here.
USE_JSON_CONTENT_TYPE_HINT = f"""
The view function {{param_name}} should accept JSON input.
Consider wrapping the type annotation of the parameter in Json.
For example:
from pydantic import Json
...
@typed_endpoint
def {function.__name__}(
request: HttpRequest,
*,
{{param_name}}: Json[{{param_type}}] = ...,
) -> ...:
"""
# The set of tuples containing the var name and type pairs extracted
# from the function signature.
function_params = set()
# The set of tuples containing the var name and type pairs extracted
# from OpenAPI.
openapi_params = set()
# The names of request variables that should have a content type of
# application/json according to our OpenAPI definitions.
json_request_var_names = set()
for expected_param_schema in openapi_parameters:
# We differentiate JSON and non-JSON parameters here. Because
# application/json is the only content type to be verified in the API,
# we assume that as long as "content" is present in the OpenAPI
# spec, the content type should be JSON.
# Grab the name first: expected_param_schema is rebound to the nested
# schema dict below.
expected_request_var_name = expected_param_schema["name"]
if "content" in expected_param_schema:
expected_param_schema = expected_param_schema["content"]["application/json"][
"schema"
]
json_request_var_names.add(expected_request_var_name)
else:
expected_param_schema = expected_param_schema[
"schema"
] # nocoverage # Will be covered as more endpoints are migrated
openapi_params.add((expected_request_var_name, schema_type(expected_param_schema)))
for actual_param in parse_view_func_signature(function).parameters:
actual_param_schema = TypeAdapter(actual_param.param_type).json_schema()
# The content type of the JSON schema generated from the
# function parameter type annotation should have content type
# matching that of our OpenAPI spec. If not so, hint that the
# Json[T] wrapper might be missing from the type annotation.
if actual_param.request_var_name in json_request_var_names:
self.assertEqual(
actual_param_schema.get("contentMediaType"),
"application/json",
USE_JSON_CONTENT_TYPE_HINT.format(
param_name=actual_param.param_name,
param_type=actual_param.param_type,
),
)
# actual_param_schema is a json_schema. Reference:
# https://docs.pydantic.dev/latest/api/json_schema/#pydantic.json_schema.GenerateJsonSchema.json_schema
# Unwrap the JSON wrapper so that the inner schema is what gets
# compared with the OpenAPI type below.
actual_param_schema = actual_param_schema["contentSchema"]
elif (
"contentMediaType" in actual_param_schema
): # nocoverage # Will be covered as more endpoints are migrated
function_schema_type = schema_type(actual_param_schema)
# We do not specify that the content type of int or bool
# parameters should be JSON encoded, while our code does expect
# that. In this case, we exempt this parameter from the content
# type check.
self.assertIn(
function_schema_type,
(int, bool),
f'\nUnexpected content type {actual_param_schema["contentMediaType"]} on function parameter {actual_param.param_name}, which does not match the OpenAPI definition.',
)
function_params.add((actual_param.request_var_name, schema_type(actual_param_schema)))
# One-directional check: only pairs present in OpenAPI without an equal
# (name, type) pair on the function side are reported. Function
# parameters that are absent from OpenAPI are not flagged here.
diff = openapi_params - function_params
if diff: # nocoverage
self.render_openapi_type_exception(function, openapi_params, function_params, diff)
def check_argument_types( def check_argument_types(
self, function: Callable[..., HttpResponse], openapi_parameters: List[Dict[str, Any]] self, function: Callable[..., HttpResponse], openapi_parameters: List[Dict[str, Any]]
) -> None: ) -> None:
@ -397,6 +487,20 @@ do not match the types declared in the implementation of {function.__name__}.\n"
OpenAPI data defines a different type than that actually accepted by the function. OpenAPI data defines a different type than that actually accepted by the function.
Otherwise, we print out the exact differences for convenient debugging and raise an Otherwise, we print out the exact differences for convenient debugging and raise an
AssertionError.""" AssertionError."""
# Iterate through the decorators to find the original function, wrapped
# by has_request_variables/typed_endpoint, so we can parse its
# arguments.
use_endpoint_decorator = False
while (wrapped := getattr(function, "__wrapped__", None)) is not None:
# TODO: Remove this check once we replace has_request_variables with
# typed_endpoint.
if getattr(function, "use_endpoint", False):
use_endpoint_decorator = True
function = wrapped
if use_endpoint_decorator:
return self.validate_json_schema(function, openapi_parameters)
openapi_params: Set[Tuple[str, Union[type, Tuple[type, object]]]] = set() openapi_params: Set[Tuple[str, Union[type, Tuple[type, object]]]] = set()
json_params: Dict[str, Union[type, Tuple[type, object]]] = {} json_params: Dict[str, Union[type, Tuple[type, object]]] = {}
for element in openapi_parameters: for element in openapi_parameters:
@ -424,22 +528,6 @@ do not match the types declared in the implementation of {function.__name__}.\n"
function_params: Set[Tuple[str, Union[type, Tuple[type, object]]]] = set() function_params: Set[Tuple[str, Union[type, Tuple[type, object]]]] = set()
# Iterate through the decorators to find the original
# function, wrapped by has_request_variables, so we can parse
# its arguments.
while (wrapped := getattr(function, "__wrapped__", None)) is not None:
function = wrapped
# Now, we do inference mapping each REQ parameter's
# declaration details to the Python/mypy types for the
# arguments passed to it.
#
# Because the mypy types are the types used inside the inner
# function (after the original data is processed by any
# validators, converters, etc.), they will not always match
# the API-level argument types. The main case where this
# happens is when a `converter` is used that changes the types
# of its parameters.
for pname, defval in inspect.signature(function).parameters.items(): for pname, defval in inspect.signature(function).parameters.items():
defval = defval.default defval = defval.default
if isinstance(defval, _REQ): if isinstance(defval, _REQ):

View File

@ -1,12 +1,13 @@
from typing import List from typing import List
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from pydantic import Json, StringConstraints
from typing_extensions import Annotated
from zerver.actions.alert_words import do_add_alert_words, do_remove_alert_words from zerver.actions.alert_words import do_add_alert_words, do_remove_alert_words
from zerver.lib.alert_words import user_alert_words from zerver.lib.alert_words import user_alert_words
from zerver.lib.request import REQ, has_request_variables
from zerver.lib.response import json_success from zerver.lib.response import json_success
from zerver.lib.validator import check_capped_string, check_list, check_string from zerver.lib.typed_endpoint import typed_endpoint
from zerver.models import UserProfile from zerver.models import UserProfile
@ -19,21 +20,23 @@ def clean_alert_words(alert_words: List[str]) -> List[str]:
return [w for w in alert_words if w != ""] return [w for w in alert_words if w != ""]
@has_request_variables @typed_endpoint
def add_alert_words( def add_alert_words(
request: HttpRequest, request: HttpRequest,
user_profile: UserProfile, user_profile: UserProfile,
alert_words: List[str] = REQ(json_validator=check_list(check_capped_string(100))), *,
alert_words: Json[List[Annotated[str, StringConstraints(max_length=100)]]],
) -> HttpResponse: ) -> HttpResponse:
do_add_alert_words(user_profile, clean_alert_words(alert_words)) do_add_alert_words(user_profile, clean_alert_words(alert_words))
return json_success(request, data={"alert_words": user_alert_words(user_profile)}) return json_success(request, data={"alert_words": user_alert_words(user_profile)})
@has_request_variables @typed_endpoint
def remove_alert_words( def remove_alert_words(
request: HttpRequest, request: HttpRequest,
user_profile: UserProfile, user_profile: UserProfile,
alert_words: List[str] = REQ(json_validator=check_list(check_string)), *,
alert_words: Json[List[str]],
) -> HttpResponse: ) -> HttpResponse:
do_remove_alert_words(user_profile, alert_words) do_remove_alert_words(user_profile, alert_words)
return json_success(request, data={"alert_words": user_alert_words(user_profile)}) return json_success(request, data={"alert_words": user_alert_words(user_profile)})