mirror of https://github.com/zulip/zulip.git
drafts: Migrate drafts to use @typed_endpoint.
This demonstrates the use of BaseModel to replace a check_dict_only validator. We also add support to referring to $defs in the OpenAPI tests. In the future, we can descend down each object instead of mapping them to dict for more accurate checks.
This commit is contained in:
parent
4701f290f7
commit
910f69465c
|
@ -1,11 +1,12 @@
|
|||
import time
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Set
|
||||
from typing import Any, Callable, Dict, List, Literal, Union
|
||||
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.utils.translation import gettext as _
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing_extensions import Annotated, Concatenate, ParamSpec
|
||||
|
||||
from zerver.lib.addressee import get_user_profiles_by_ids
|
||||
from zerver.lib.exceptions import JsonableError, ResourceNotFoundError
|
||||
|
@ -13,48 +14,36 @@ from zerver.lib.message import normalize_body, truncate_topic
|
|||
from zerver.lib.recipient_users import recipient_for_user_profiles
|
||||
from zerver.lib.streams import access_stream_by_id
|
||||
from zerver.lib.timestamp import timestamp_to_datetime
|
||||
from zerver.lib.validator import (
|
||||
check_dict_only,
|
||||
check_float,
|
||||
check_int,
|
||||
check_list,
|
||||
check_required_string,
|
||||
check_string,
|
||||
check_string_in,
|
||||
check_union,
|
||||
)
|
||||
from zerver.lib.typed_endpoint import RequiredStringConstraint
|
||||
from zerver.models import Draft, UserProfile
|
||||
from zerver.tornado.django_api import send_event
|
||||
|
||||
ParamT = ParamSpec("ParamT")
|
||||
VALID_DRAFT_TYPES: Set[str] = {"", "private", "stream"}
|
||||
|
||||
# A validator to verify if the structure (syntax) of a dictionary
|
||||
# meets the requirements to be a draft dictionary:
|
||||
draft_dict_validator = check_dict_only(
|
||||
required_keys=[
|
||||
("type", check_string_in(VALID_DRAFT_TYPES)),
|
||||
("to", check_list(check_int)), # The ID of the stream to send to, or a list of user IDs.
|
||||
("topic", check_string), # This string can simply be empty for private type messages.
|
||||
("content", check_required_string),
|
||||
],
|
||||
optional_keys=[
|
||||
("timestamp", check_union([check_int, check_float])), # A Unix timestamp.
|
||||
],
|
||||
)
|
||||
|
||||
class DraftData(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
type: Literal["private", "stream", ""]
|
||||
to: List[int]
|
||||
topic: str
|
||||
content: Annotated[str, RequiredStringConstraint()]
|
||||
timestamp: Union[int, float, None] = None
|
||||
|
||||
|
||||
def further_validated_draft_dict(
|
||||
draft_dict: Dict[str, Any], user_profile: UserProfile
|
||||
draft_dict: DraftData, user_profile: UserProfile
|
||||
) -> Dict[str, Any]:
|
||||
"""Take a draft_dict that was already validated by draft_dict_validator then
|
||||
further sanitize, validate, and transform it. Ultimately return this "further
|
||||
validated" draft dict. It will have a slightly different set of keys the values
|
||||
for which can be used to directly create a Draft object."""
|
||||
|
||||
content = normalize_body(draft_dict["content"])
|
||||
content = normalize_body(draft_dict.content)
|
||||
|
||||
timestamp = draft_dict.get("timestamp", time.time())
|
||||
timestamp = draft_dict.timestamp
|
||||
if timestamp is None:
|
||||
timestamp = time.time()
|
||||
timestamp = round(timestamp, 6)
|
||||
if timestamp < 0:
|
||||
# While it's not exactly an invalid timestamp, it's not something
|
||||
|
@ -64,16 +53,16 @@ def further_validated_draft_dict(
|
|||
|
||||
topic = ""
|
||||
recipient_id = None
|
||||
to = draft_dict["to"]
|
||||
if draft_dict["type"] == "stream":
|
||||
topic = truncate_topic(draft_dict["topic"])
|
||||
to = draft_dict.to
|
||||
if draft_dict.type == "stream":
|
||||
topic = truncate_topic(draft_dict.topic)
|
||||
if "\0" in topic:
|
||||
raise JsonableError(_("Topic must not contain null bytes"))
|
||||
if len(to) != 1:
|
||||
raise JsonableError(_("Must specify exactly 1 stream ID for stream messages"))
|
||||
stream, sub = access_stream_by_id(user_profile, to[0])
|
||||
recipient_id = stream.recipient_id
|
||||
elif draft_dict["type"] == "private" and len(to) != 0:
|
||||
elif draft_dict.type == "private" and len(to) != 0:
|
||||
to_users = get_user_profiles_by_ids(set(to), user_profile.realm)
|
||||
try:
|
||||
recipient_id = recipient_for_user_profiles(to_users, False, None, user_profile).id
|
||||
|
@ -106,14 +95,14 @@ def draft_endpoint(
|
|||
return draft_view_func
|
||||
|
||||
|
||||
def do_create_drafts(draft_dicts: List[Dict[str, Any]], user_profile: UserProfile) -> List[Draft]:
|
||||
def do_create_drafts(drafts: List[DraftData], user_profile: UserProfile) -> List[Draft]:
|
||||
"""Create drafts in bulk for a given user based on the draft dicts. Since
|
||||
currently, the only place this method is being used (apart from tests) is from
|
||||
the create_draft view, we assume that the drafts_dicts are syntactically valid
|
||||
(i.e. they satisfy the draft_dict_validator)."""
|
||||
draft_objects = []
|
||||
for draft_dict in draft_dicts:
|
||||
valid_draft_dict = further_validated_draft_dict(draft_dict, user_profile)
|
||||
for draft in drafts:
|
||||
valid_draft_dict = further_validated_draft_dict(draft, user_profile)
|
||||
draft_objects.append(
|
||||
Draft(
|
||||
user_profile=user_profile,
|
||||
|
@ -136,7 +125,7 @@ def do_create_drafts(draft_dicts: List[Dict[str, Any]], user_profile: UserProfil
|
|||
return created_draft_objects
|
||||
|
||||
|
||||
def do_edit_draft(draft_id: int, draft_dict: Dict[str, Any], user_profile: UserProfile) -> None:
|
||||
def do_edit_draft(draft_id: int, draft: DraftData, user_profile: UserProfile) -> None:
|
||||
"""Edit/update a single draft for a given user. Since the only place this method is being
|
||||
used from (apart from tests) is the edit_draft view, we assume that the drafts_dict is
|
||||
syntactically valid (i.e. it satisfies the draft_dict_validator)."""
|
||||
|
@ -144,7 +133,7 @@ def do_edit_draft(draft_id: int, draft_dict: Dict[str, Any], user_profile: UserP
|
|||
draft_object = Draft.objects.get(id=draft_id, user_profile=user_profile)
|
||||
except Draft.DoesNotExist:
|
||||
raise ResourceNotFoundError(_("Draft does not exist"))
|
||||
valid_draft_dict = further_validated_draft_dict(draft_dict, user_profile)
|
||||
valid_draft_dict = further_validated_draft_dict(draft, user_profile)
|
||||
draft_object.content = valid_draft_dict["content"]
|
||||
draft_object.topic = valid_draft_dict["topic"]
|
||||
draft_object.recipient_id = valid_draft_dict["recipient_id"]
|
||||
|
|
|
@ -126,7 +126,7 @@ from zerver.actions.users import (
|
|||
do_update_outgoing_webhook_service,
|
||||
)
|
||||
from zerver.actions.video_calls import do_set_zoom_token
|
||||
from zerver.lib.drafts import do_create_drafts, do_delete_draft, do_edit_draft
|
||||
from zerver.lib.drafts import DraftData, do_create_drafts, do_delete_draft, do_edit_draft
|
||||
from zerver.lib.event_schema import (
|
||||
check_alert_words,
|
||||
check_attachment_add,
|
||||
|
@ -3497,39 +3497,39 @@ class DraftActionTest(BaseAction):
|
|||
|
||||
def test_draft_create_event(self) -> None:
|
||||
self.do_enable_drafts_synchronization(self.user_profile)
|
||||
dummy_draft = {
|
||||
"type": "draft",
|
||||
"to": "",
|
||||
"topic": "",
|
||||
"content": "Sample draft content",
|
||||
"timestamp": 1596820995,
|
||||
}
|
||||
dummy_draft = DraftData(
|
||||
type="",
|
||||
to=[],
|
||||
topic="",
|
||||
content="Sample draft content",
|
||||
timestamp=1596820995,
|
||||
)
|
||||
action = lambda: do_create_drafts([dummy_draft], self.user_profile)
|
||||
self.verify_action(action)
|
||||
|
||||
def test_draft_edit_event(self) -> None:
|
||||
self.do_enable_drafts_synchronization(self.user_profile)
|
||||
dummy_draft = {
|
||||
"type": "draft",
|
||||
"to": "",
|
||||
"topic": "",
|
||||
"content": "Sample draft content",
|
||||
"timestamp": 1596820995,
|
||||
}
|
||||
dummy_draft = DraftData(
|
||||
type="",
|
||||
to=[],
|
||||
topic="",
|
||||
content="Sample draft content",
|
||||
timestamp=1596820995,
|
||||
)
|
||||
draft_id = do_create_drafts([dummy_draft], self.user_profile)[0].id
|
||||
dummy_draft["content"] = "Some more sample draft content"
|
||||
dummy_draft.content = "Some more sample draft content"
|
||||
action = lambda: do_edit_draft(draft_id, dummy_draft, self.user_profile)
|
||||
self.verify_action(action)
|
||||
|
||||
def test_draft_delete_event(self) -> None:
|
||||
self.do_enable_drafts_synchronization(self.user_profile)
|
||||
dummy_draft = {
|
||||
"type": "draft",
|
||||
"to": "",
|
||||
"topic": "",
|
||||
"content": "Sample draft content",
|
||||
"timestamp": 1596820995,
|
||||
}
|
||||
dummy_draft = DraftData(
|
||||
type="",
|
||||
to=[],
|
||||
topic="",
|
||||
content="Sample draft content",
|
||||
timestamp=1596820995,
|
||||
)
|
||||
draft_id = do_create_drafts([dummy_draft], self.user_profile)[0].id
|
||||
action = lambda: do_delete_draft(draft_id, self.user_profile)
|
||||
self.verify_action(action)
|
||||
|
|
|
@ -6,6 +6,7 @@ from typing import (
|
|||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
|
@ -57,17 +58,21 @@ VARMAP = {
|
|||
}
|
||||
|
||||
|
||||
def schema_type(schema: Dict[str, Any]) -> Union[type, Tuple[type, object]]:
|
||||
def schema_type(
|
||||
schema: Dict[str, Any], defs: Mapping[str, Any] = {}
|
||||
) -> Union[type, Tuple[type, object]]:
|
||||
if "oneOf" in schema:
|
||||
# Hack: Just use the type of the first value
|
||||
# Ideally, we'd turn this into a Union type.
|
||||
return schema_type(schema["oneOf"][0])
|
||||
return schema_type(schema["oneOf"][0], defs)
|
||||
elif "anyOf" in schema:
|
||||
return schema_type(schema["anyOf"][0])
|
||||
return schema_type(schema["anyOf"][0], defs)
|
||||
elif schema.get("contentMediaType") == "application/json":
|
||||
return schema_type(schema["contentSchema"])
|
||||
return schema_type(schema["contentSchema"], defs)
|
||||
elif "$ref" in schema:
|
||||
return schema_type(defs[schema["$ref"]], defs)
|
||||
elif schema["type"] == "array":
|
||||
return (list, schema_type(schema["items"]))
|
||||
return (list, schema_type(schema["items"], defs))
|
||||
else:
|
||||
return VARMAP[schema["type"]]
|
||||
|
||||
|
@ -439,7 +444,10 @@ do not match the types declared in the implementation of {function.__name__}.\n"
|
|||
openapi_params.add((expected_request_var_name, schema_type(expected_param_schema)))
|
||||
|
||||
for actual_param in parse_view_func_signature(function).parameters:
|
||||
actual_param_schema = TypeAdapter(actual_param.param_type).json_schema()
|
||||
actual_param_schema = TypeAdapter(actual_param.param_type).json_schema(
|
||||
ref_template="{model}"
|
||||
)
|
||||
defs_mapping = actual_param_schema.get("$defs", {})
|
||||
# The content type of the JSON schema generated from the
|
||||
# function parameter type annotation should have content type
|
||||
# matching that of our OpenAPI spec. If not so, hint that the
|
||||
|
@ -467,7 +475,9 @@ do not match the types declared in the implementation of {function.__name__}.\n"
|
|||
(int, bool),
|
||||
f'\nUnexpected content type {actual_param_schema["contentMediaType"]} on function parameter {actual_param.param_name}, which does not match the OpenAPI definition.',
|
||||
)
|
||||
function_params.add((actual_param.request_var_name, schema_type(actual_param_schema)))
|
||||
function_params.add(
|
||||
(actual_param.request_var_name, schema_type(actual_param_schema, defs_mapping))
|
||||
)
|
||||
|
||||
diff = openapi_params - function_params
|
||||
if diff: # nocoverage
|
||||
|
|
|
@ -1,17 +1,17 @@
|
|||
from typing import Any, Dict, List
|
||||
from typing import List
|
||||
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from pydantic import Json
|
||||
|
||||
from zerver.lib.drafts import (
|
||||
DraftData,
|
||||
do_create_drafts,
|
||||
do_delete_draft,
|
||||
do_edit_draft,
|
||||
draft_dict_validator,
|
||||
draft_endpoint,
|
||||
)
|
||||
from zerver.lib.request import REQ, has_request_variables
|
||||
from zerver.lib.response import json_success
|
||||
from zerver.lib.validator import check_list
|
||||
from zerver.lib.typed_endpoint import PathOnly, typed_endpoint
|
||||
from zerver.models import Draft, UserProfile
|
||||
|
||||
|
||||
|
@ -23,32 +23,32 @@ def fetch_drafts(request: HttpRequest, user_profile: UserProfile) -> HttpRespons
|
|||
|
||||
|
||||
@draft_endpoint
|
||||
@has_request_variables
|
||||
@typed_endpoint
|
||||
def create_drafts(
|
||||
request: HttpRequest,
|
||||
user_profile: UserProfile,
|
||||
draft_dicts: List[Dict[str, Any]] = REQ(
|
||||
"drafts", json_validator=check_list(draft_dict_validator)
|
||||
),
|
||||
*,
|
||||
drafts: Json[List[DraftData]],
|
||||
) -> HttpResponse:
|
||||
created_draft_objects = do_create_drafts(draft_dicts, user_profile)
|
||||
created_draft_objects = do_create_drafts(drafts, user_profile)
|
||||
draft_ids = [draft_object.id for draft_object in created_draft_objects]
|
||||
return json_success(request, data={"ids": draft_ids})
|
||||
|
||||
|
||||
@draft_endpoint
|
||||
@has_request_variables
|
||||
@typed_endpoint
|
||||
def edit_draft(
|
||||
request: HttpRequest,
|
||||
user_profile: UserProfile,
|
||||
draft_id: int,
|
||||
draft_dict: Dict[str, Any] = REQ("draft", json_validator=draft_dict_validator),
|
||||
*,
|
||||
draft_id: PathOnly[int],
|
||||
draft: Json[DraftData],
|
||||
) -> HttpResponse:
|
||||
do_edit_draft(draft_id, draft_dict, user_profile)
|
||||
do_edit_draft(draft_id, draft, user_profile)
|
||||
return json_success(request)
|
||||
|
||||
|
||||
@draft_endpoint
|
||||
def delete_draft(request: HttpRequest, user_profile: UserProfile, draft_id: int) -> HttpResponse:
|
||||
def delete_draft(request: HttpRequest, user_profile: UserProfile, *, draft_id: int) -> HttpResponse:
|
||||
do_delete_draft(draft_id, user_profile)
|
||||
return json_success(request)
|
||||
|
|
Loading…
Reference in New Issue