drafts: Migrate drafts to use @typed_endpoint.

This demonstrates the use of BaseModel to replace a check_dict_only
validator.

We also add support to referring to $defs in the OpenAPI tests. In the
future, we can descend down each object instead of mapping them to dict
for more accurate checks.
This commit is contained in:
Zixuan James Li 2023-08-16 20:34:42 -04:00 committed by Tim Abbott
parent 4701f290f7
commit 910f69465c
4 changed files with 81 additions and 82 deletions

View File

@ -1,11 +1,12 @@
import time
from functools import wraps
from typing import Any, Callable, Dict, List, Set
from typing import Any, Callable, Dict, List, Literal, Union
from django.core.exceptions import ValidationError
from django.http import HttpRequest, HttpResponse
from django.utils.translation import gettext as _
from typing_extensions import Concatenate, ParamSpec
from pydantic import BaseModel, ConfigDict
from typing_extensions import Annotated, Concatenate, ParamSpec
from zerver.lib.addressee import get_user_profiles_by_ids
from zerver.lib.exceptions import JsonableError, ResourceNotFoundError
@ -13,48 +14,36 @@ from zerver.lib.message import normalize_body, truncate_topic
from zerver.lib.recipient_users import recipient_for_user_profiles
from zerver.lib.streams import access_stream_by_id
from zerver.lib.timestamp import timestamp_to_datetime
from zerver.lib.validator import (
check_dict_only,
check_float,
check_int,
check_list,
check_required_string,
check_string,
check_string_in,
check_union,
)
from zerver.lib.typed_endpoint import RequiredStringConstraint
from zerver.models import Draft, UserProfile
from zerver.tornado.django_api import send_event
ParamT = ParamSpec("ParamT")
VALID_DRAFT_TYPES: Set[str] = {"", "private", "stream"}
# A validator to verify if the structure (syntax) of a dictionary
# meets the requirements to be a draft dictionary:
draft_dict_validator = check_dict_only(
required_keys=[
("type", check_string_in(VALID_DRAFT_TYPES)),
("to", check_list(check_int)), # The ID of the stream to send to, or a list of user IDs.
("topic", check_string), # This string can simply be empty for private type messages.
("content", check_required_string),
],
optional_keys=[
("timestamp", check_union([check_int, check_float])), # A Unix timestamp.
],
)
class DraftData(BaseModel):
model_config = ConfigDict(extra="forbid")
type: Literal["private", "stream", ""]
to: List[int]
topic: str
content: Annotated[str, RequiredStringConstraint()]
timestamp: Union[int, float, None] = None
def further_validated_draft_dict(
draft_dict: Dict[str, Any], user_profile: UserProfile
draft_dict: DraftData, user_profile: UserProfile
) -> Dict[str, Any]:
"""Take a draft_dict that was already validated by draft_dict_validator then
further sanitize, validate, and transform it. Ultimately return this "further
validated" draft dict. It will have a slightly different set of keys the values
for which can be used to directly create a Draft object."""
content = normalize_body(draft_dict["content"])
content = normalize_body(draft_dict.content)
timestamp = draft_dict.get("timestamp", time.time())
timestamp = draft_dict.timestamp
if timestamp is None:
timestamp = time.time()
timestamp = round(timestamp, 6)
if timestamp < 0:
# While it's not exactly an invalid timestamp, it's not something
@ -64,16 +53,16 @@ def further_validated_draft_dict(
topic = ""
recipient_id = None
to = draft_dict["to"]
if draft_dict["type"] == "stream":
topic = truncate_topic(draft_dict["topic"])
to = draft_dict.to
if draft_dict.type == "stream":
topic = truncate_topic(draft_dict.topic)
if "\0" in topic:
raise JsonableError(_("Topic must not contain null bytes"))
if len(to) != 1:
raise JsonableError(_("Must specify exactly 1 stream ID for stream messages"))
stream, sub = access_stream_by_id(user_profile, to[0])
recipient_id = stream.recipient_id
elif draft_dict["type"] == "private" and len(to) != 0:
elif draft_dict.type == "private" and len(to) != 0:
to_users = get_user_profiles_by_ids(set(to), user_profile.realm)
try:
recipient_id = recipient_for_user_profiles(to_users, False, None, user_profile).id
@ -106,14 +95,14 @@ def draft_endpoint(
return draft_view_func
def do_create_drafts(draft_dicts: List[Dict[str, Any]], user_profile: UserProfile) -> List[Draft]:
def do_create_drafts(drafts: List[DraftData], user_profile: UserProfile) -> List[Draft]:
"""Create drafts in bulk for a given user based on the draft dicts. Since
currently, the only place this method is being used (apart from tests) is from
the create_draft view, we assume that the drafts_dicts are syntactically valid
(i.e. they satisfy the draft_dict_validator)."""
draft_objects = []
for draft_dict in draft_dicts:
valid_draft_dict = further_validated_draft_dict(draft_dict, user_profile)
for draft in drafts:
valid_draft_dict = further_validated_draft_dict(draft, user_profile)
draft_objects.append(
Draft(
user_profile=user_profile,
@ -136,7 +125,7 @@ def do_create_drafts(draft_dicts: List[Dict[str, Any]], user_profile: UserProfil
return created_draft_objects
def do_edit_draft(draft_id: int, draft_dict: Dict[str, Any], user_profile: UserProfile) -> None:
def do_edit_draft(draft_id: int, draft: DraftData, user_profile: UserProfile) -> None:
"""Edit/update a single draft for a given user. Since the only place this method is being
used from (apart from tests) is the edit_draft view, we assume that the drafts_dict is
syntactically valid (i.e. it satisfies the draft_dict_validator)."""
@ -144,7 +133,7 @@ def do_edit_draft(draft_id: int, draft_dict: Dict[str, Any], user_profile: UserP
draft_object = Draft.objects.get(id=draft_id, user_profile=user_profile)
except Draft.DoesNotExist:
raise ResourceNotFoundError(_("Draft does not exist"))
valid_draft_dict = further_validated_draft_dict(draft_dict, user_profile)
valid_draft_dict = further_validated_draft_dict(draft, user_profile)
draft_object.content = valid_draft_dict["content"]
draft_object.topic = valid_draft_dict["topic"]
draft_object.recipient_id = valid_draft_dict["recipient_id"]

View File

@ -126,7 +126,7 @@ from zerver.actions.users import (
do_update_outgoing_webhook_service,
)
from zerver.actions.video_calls import do_set_zoom_token
from zerver.lib.drafts import do_create_drafts, do_delete_draft, do_edit_draft
from zerver.lib.drafts import DraftData, do_create_drafts, do_delete_draft, do_edit_draft
from zerver.lib.event_schema import (
check_alert_words,
check_attachment_add,
@ -3497,39 +3497,39 @@ class DraftActionTest(BaseAction):
def test_draft_create_event(self) -> None:
self.do_enable_drafts_synchronization(self.user_profile)
dummy_draft = {
"type": "draft",
"to": "",
"topic": "",
"content": "Sample draft content",
"timestamp": 1596820995,
}
dummy_draft = DraftData(
type="",
to=[],
topic="",
content="Sample draft content",
timestamp=1596820995,
)
action = lambda: do_create_drafts([dummy_draft], self.user_profile)
self.verify_action(action)
def test_draft_edit_event(self) -> None:
self.do_enable_drafts_synchronization(self.user_profile)
dummy_draft = {
"type": "draft",
"to": "",
"topic": "",
"content": "Sample draft content",
"timestamp": 1596820995,
}
dummy_draft = DraftData(
type="",
to=[],
topic="",
content="Sample draft content",
timestamp=1596820995,
)
draft_id = do_create_drafts([dummy_draft], self.user_profile)[0].id
dummy_draft["content"] = "Some more sample draft content"
dummy_draft.content = "Some more sample draft content"
action = lambda: do_edit_draft(draft_id, dummy_draft, self.user_profile)
self.verify_action(action)
def test_draft_delete_event(self) -> None:
self.do_enable_drafts_synchronization(self.user_profile)
dummy_draft = {
"type": "draft",
"to": "",
"topic": "",
"content": "Sample draft content",
"timestamp": 1596820995,
}
dummy_draft = DraftData(
type="",
to=[],
topic="",
content="Sample draft content",
timestamp=1596820995,
)
draft_id = do_create_drafts([dummy_draft], self.user_profile)[0].id
action = lambda: do_delete_draft(draft_id, self.user_profile)
self.verify_action(action)

View File

@ -6,6 +6,7 @@ from typing import (
Callable,
Dict,
List,
Mapping,
Optional,
Sequence,
Set,
@ -57,17 +58,21 @@ VARMAP = {
}
def schema_type(schema: Dict[str, Any]) -> Union[type, Tuple[type, object]]:
def schema_type(
schema: Dict[str, Any], defs: Mapping[str, Any] = {}
) -> Union[type, Tuple[type, object]]:
if "oneOf" in schema:
# Hack: Just use the type of the first value
# Ideally, we'd turn this into a Union type.
return schema_type(schema["oneOf"][0])
return schema_type(schema["oneOf"][0], defs)
elif "anyOf" in schema:
return schema_type(schema["anyOf"][0])
return schema_type(schema["anyOf"][0], defs)
elif schema.get("contentMediaType") == "application/json":
return schema_type(schema["contentSchema"])
return schema_type(schema["contentSchema"], defs)
elif "$ref" in schema:
return schema_type(defs[schema["$ref"]], defs)
elif schema["type"] == "array":
return (list, schema_type(schema["items"]))
return (list, schema_type(schema["items"], defs))
else:
return VARMAP[schema["type"]]
@ -439,7 +444,10 @@ do not match the types declared in the implementation of {function.__name__}.\n"
openapi_params.add((expected_request_var_name, schema_type(expected_param_schema)))
for actual_param in parse_view_func_signature(function).parameters:
actual_param_schema = TypeAdapter(actual_param.param_type).json_schema()
actual_param_schema = TypeAdapter(actual_param.param_type).json_schema(
ref_template="{model}"
)
defs_mapping = actual_param_schema.get("$defs", {})
# The content type of the JSON schema generated from the
# function parameter type annotation should have content type
# matching that of our OpenAPI spec. If not so, hint that the
@ -467,7 +475,9 @@ do not match the types declared in the implementation of {function.__name__}.\n"
(int, bool),
f'\nUnexpected content type {actual_param_schema["contentMediaType"]} on function parameter {actual_param.param_name}, which does not match the OpenAPI definition.',
)
function_params.add((actual_param.request_var_name, schema_type(actual_param_schema)))
function_params.add(
(actual_param.request_var_name, schema_type(actual_param_schema, defs_mapping))
)
diff = openapi_params - function_params
if diff: # nocoverage

View File

@ -1,17 +1,17 @@
from typing import Any, Dict, List
from typing import List
from django.http import HttpRequest, HttpResponse
from pydantic import Json
from zerver.lib.drafts import (
DraftData,
do_create_drafts,
do_delete_draft,
do_edit_draft,
draft_dict_validator,
draft_endpoint,
)
from zerver.lib.request import REQ, has_request_variables
from zerver.lib.response import json_success
from zerver.lib.validator import check_list
from zerver.lib.typed_endpoint import PathOnly, typed_endpoint
from zerver.models import Draft, UserProfile
@ -23,32 +23,32 @@ def fetch_drafts(request: HttpRequest, user_profile: UserProfile) -> HttpRespons
@draft_endpoint
@has_request_variables
@typed_endpoint
def create_drafts(
request: HttpRequest,
user_profile: UserProfile,
draft_dicts: List[Dict[str, Any]] = REQ(
"drafts", json_validator=check_list(draft_dict_validator)
),
*,
drafts: Json[List[DraftData]],
) -> HttpResponse:
created_draft_objects = do_create_drafts(draft_dicts, user_profile)
created_draft_objects = do_create_drafts(drafts, user_profile)
draft_ids = [draft_object.id for draft_object in created_draft_objects]
return json_success(request, data={"ids": draft_ids})
@draft_endpoint
@has_request_variables
@typed_endpoint
def edit_draft(
request: HttpRequest,
user_profile: UserProfile,
draft_id: int,
draft_dict: Dict[str, Any] = REQ("draft", json_validator=draft_dict_validator),
*,
draft_id: PathOnly[int],
draft: Json[DraftData],
) -> HttpResponse:
do_edit_draft(draft_id, draft_dict, user_profile)
do_edit_draft(draft_id, draft, user_profile)
return json_success(request)
@draft_endpoint
def delete_draft(request: HttpRequest, user_profile: UserProfile, draft_id: int) -> HttpResponse:
def delete_draft(request: HttpRequest, user_profile: UserProfile, *, draft_id: int) -> HttpResponse:
do_delete_draft(draft_id, user_profile)
return json_success(request)