diff --git a/zerver/lib/drafts.py b/zerver/lib/drafts.py index 49b21f2f02..113e26b4d6 100644 --- a/zerver/lib/drafts.py +++ b/zerver/lib/drafts.py @@ -1,11 +1,12 @@ import time from functools import wraps -from typing import Any, Callable, Dict, List, Set +from typing import Any, Callable, Dict, List, Literal, Union from django.core.exceptions import ValidationError from django.http import HttpRequest, HttpResponse from django.utils.translation import gettext as _ -from typing_extensions import Concatenate, ParamSpec +from pydantic import BaseModel, ConfigDict +from typing_extensions import Annotated, Concatenate, ParamSpec from zerver.lib.addressee import get_user_profiles_by_ids from zerver.lib.exceptions import JsonableError, ResourceNotFoundError @@ -13,48 +14,36 @@ from zerver.lib.message import normalize_body, truncate_topic from zerver.lib.recipient_users import recipient_for_user_profiles from zerver.lib.streams import access_stream_by_id from zerver.lib.timestamp import timestamp_to_datetime -from zerver.lib.validator import ( - check_dict_only, - check_float, - check_int, - check_list, - check_required_string, - check_string, - check_string_in, - check_union, -) +from zerver.lib.typed_endpoint import RequiredStringConstraint from zerver.models import Draft, UserProfile from zerver.tornado.django_api import send_event ParamT = ParamSpec("ParamT") -VALID_DRAFT_TYPES: Set[str] = {"", "private", "stream"} -# A validator to verify if the structure (syntax) of a dictionary -# meets the requirements to be a draft dictionary: -draft_dict_validator = check_dict_only( - required_keys=[ - ("type", check_string_in(VALID_DRAFT_TYPES)), - ("to", check_list(check_int)), # The ID of the stream to send to, or a list of user IDs. - ("topic", check_string), # This string can simply be empty for private type messages. - ("content", check_required_string), - ], - optional_keys=[ - ("timestamp", check_union([check_int, check_float])), # A Unix timestamp. - ], -) + +class DraftData(BaseModel): + model_config = ConfigDict(extra="forbid") + + type: Literal["private", "stream", ""] + to: List[int] + topic: str + content: Annotated[str, RequiredStringConstraint()] + timestamp: Union[int, float, None] = None def further_validated_draft_dict( - draft_dict: Dict[str, Any], user_profile: UserProfile + draft_dict: DraftData, user_profile: UserProfile ) -> Dict[str, Any]: """Take a draft_dict that was already validated by draft_dict_validator then further sanitize, validate, and transform it. Ultimately return this "further validated" draft dict. It will have a slightly different set of keys the values for which can be used to directly create a Draft object.""" - content = normalize_body(draft_dict["content"]) + content = normalize_body(draft_dict.content) - timestamp = draft_dict.get("timestamp", time.time()) + timestamp = draft_dict.timestamp + if timestamp is None: + timestamp = time.time() timestamp = round(timestamp, 6) if timestamp < 0: # While it's not exactly an invalid timestamp, it's not something @@ -64,16 +53,16 @@ def further_validated_draft_dict( topic = "" recipient_id = None - to = draft_dict["to"] - if draft_dict["type"] == "stream": - topic = truncate_topic(draft_dict["topic"]) + to = draft_dict.to + if draft_dict.type == "stream": + topic = truncate_topic(draft_dict.topic) if "\0" in topic: raise JsonableError(_("Topic must not contain null bytes")) if len(to) != 1: raise JsonableError(_("Must specify exactly 1 stream ID for stream messages")) stream, sub = access_stream_by_id(user_profile, to[0]) recipient_id = stream.recipient_id - elif draft_dict["type"] == "private" and len(to) != 0: + elif draft_dict.type == "private" and len(to) != 0: to_users = get_user_profiles_by_ids(set(to), user_profile.realm) try: recipient_id = recipient_for_user_profiles(to_users, False, None, user_profile).id @@ -106,14 +95,14 @@ def draft_endpoint( return draft_view_func -def do_create_drafts(draft_dicts: List[Dict[str, Any]], user_profile: UserProfile) -> List[Draft]: +def do_create_drafts(drafts: List[DraftData], user_profile: UserProfile) -> List[Draft]: """Create drafts in bulk for a given user based on the draft dicts. Since currently, the only place this method is being used (apart from tests) is from the create_draft view, we assume that the drafts_dicts are syntactically valid (i.e. they satisfy the draft_dict_validator).""" draft_objects = [] - for draft_dict in draft_dicts: - valid_draft_dict = further_validated_draft_dict(draft_dict, user_profile) + for draft in drafts: + valid_draft_dict = further_validated_draft_dict(draft, user_profile) draft_objects.append( Draft( user_profile=user_profile, @@ -136,7 +125,7 @@ def do_create_drafts(draft_dicts: List[Dict[str, Any]], user_profile: UserProfil return created_draft_objects -def do_edit_draft(draft_id: int, draft_dict: Dict[str, Any], user_profile: UserProfile) -> None: +def do_edit_draft(draft_id: int, draft: DraftData, user_profile: UserProfile) -> None: """Edit/update a single draft for a given user. Since the only place this method is being used from (apart from tests) is the edit_draft view, we assume that the drafts_dict is syntactically valid (i.e. it satisfies the draft_dict_validator).""" @@ -144,7 +133,7 @@ def do_edit_draft(draft_id: int, draft_dict: Dict[str, Any], user_profile: UserP draft_object = Draft.objects.get(id=draft_id, user_profile=user_profile) except Draft.DoesNotExist: raise ResourceNotFoundError(_("Draft does not exist")) - valid_draft_dict = further_validated_draft_dict(draft_dict, user_profile) + valid_draft_dict = further_validated_draft_dict(draft, user_profile) draft_object.content = valid_draft_dict["content"] draft_object.topic = valid_draft_dict["topic"] draft_object.recipient_id = valid_draft_dict["recipient_id"] diff --git a/zerver/tests/test_events.py b/zerver/tests/test_events.py index 59c3c64c0b..c10d5ee430 100644 --- a/zerver/tests/test_events.py +++ b/zerver/tests/test_events.py @@ -126,7 +126,7 @@ from zerver.actions.users import ( do_update_outgoing_webhook_service, ) from zerver.actions.video_calls import do_set_zoom_token -from zerver.lib.drafts import do_create_drafts, do_delete_draft, do_edit_draft +from zerver.lib.drafts import DraftData, do_create_drafts, do_delete_draft, do_edit_draft from zerver.lib.event_schema import ( check_alert_words, check_attachment_add, @@ -3497,39 +3497,39 @@ class DraftActionTest(BaseAction): def test_draft_create_event(self) -> None: self.do_enable_drafts_synchronization(self.user_profile) - dummy_draft = { - "type": "draft", - "to": "", - "topic": "", - "content": "Sample draft content", - "timestamp": 1596820995, - } + dummy_draft = DraftData( + type="", + to=[], + topic="", + content="Sample draft content", + timestamp=1596820995, + ) action = lambda: do_create_drafts([dummy_draft], self.user_profile) self.verify_action(action) def test_draft_edit_event(self) -> None: self.do_enable_drafts_synchronization(self.user_profile) - dummy_draft = { - "type": "draft", - "to": "", - "topic": "", - "content": "Sample draft content", - "timestamp": 1596820995, - } + dummy_draft = DraftData( + type="", + to=[], + topic="", + content="Sample draft content", + timestamp=1596820995, + ) draft_id = do_create_drafts([dummy_draft], self.user_profile)[0].id - dummy_draft["content"] = "Some more sample draft content" + dummy_draft.content = "Some more sample draft content" action = lambda: do_edit_draft(draft_id, dummy_draft, self.user_profile) self.verify_action(action) def test_draft_delete_event(self) -> None: self.do_enable_drafts_synchronization(self.user_profile) - dummy_draft = { - "type": "draft", - "to": "", - "topic": "", - "content": "Sample draft content", - "timestamp": 1596820995, - } + dummy_draft = DraftData( + type="", + to=[], + topic="", + content="Sample draft content", + timestamp=1596820995, + ) draft_id = do_create_drafts([dummy_draft], self.user_profile)[0].id action = lambda: do_delete_draft(draft_id, self.user_profile) self.verify_action(action) diff --git a/zerver/tests/test_openapi.py b/zerver/tests/test_openapi.py index 11713690c8..03f2c71c3c 100644 --- a/zerver/tests/test_openapi.py +++ b/zerver/tests/test_openapi.py @@ -6,6 +6,7 @@ from typing import ( Callable, Dict, List, + Mapping, Optional, Sequence, Set, @@ -57,17 +58,21 @@ VARMAP = { } -def schema_type(schema: Dict[str, Any]) -> Union[type, Tuple[type, object]]: +def schema_type( + schema: Dict[str, Any], defs: Mapping[str, Any] = {} +) -> Union[type, Tuple[type, object]]: if "oneOf" in schema: # Hack: Just use the type of the first value # Ideally, we'd turn this into a Union type. - return schema_type(schema["oneOf"][0]) + return schema_type(schema["oneOf"][0], defs) elif "anyOf" in schema: - return schema_type(schema["anyOf"][0]) + return schema_type(schema["anyOf"][0], defs) elif schema.get("contentMediaType") == "application/json": - return schema_type(schema["contentSchema"]) + return schema_type(schema["contentSchema"], defs) + elif "$ref" in schema: + return schema_type(defs[schema["$ref"]], defs) elif schema["type"] == "array": - return (list, schema_type(schema["items"])) + return (list, schema_type(schema["items"], defs)) else: return VARMAP[schema["type"]] @@ -439,7 +444,10 @@ do not match the types declared in the implementation of {function.__name__}.\n" openapi_params.add((expected_request_var_name, schema_type(expected_param_schema))) for actual_param in parse_view_func_signature(function).parameters: - actual_param_schema = TypeAdapter(actual_param.param_type).json_schema() + actual_param_schema = TypeAdapter(actual_param.param_type).json_schema( + ref_template="{model}" + ) + defs_mapping = actual_param_schema.get("$defs", {}) # The content type of the JSON schema generated from the # function parameter type annotation should have content type # matching that of our OpenAPI spec. If not so, hint that the @@ -467,7 +475,9 @@ do not match the types declared in the implementation of {function.__name__}.\n" (int, bool), f'\nUnexpected content type {actual_param_schema["contentMediaType"]} on function parameter {actual_param.param_name}, which does not match the OpenAPI definition.', ) - function_params.add((actual_param.request_var_name, schema_type(actual_param_schema))) + function_params.add( + (actual_param.request_var_name, schema_type(actual_param_schema, defs_mapping)) + ) diff = openapi_params - function_params if diff: # nocoverage diff --git a/zerver/views/drafts.py b/zerver/views/drafts.py index 33ba4ffc5e..6d46fa45d3 100644 --- a/zerver/views/drafts.py +++ b/zerver/views/drafts.py @@ -1,17 +1,17 @@ -from typing import Any, Dict, List +from typing import List from django.http import HttpRequest, HttpResponse +from pydantic import Json from zerver.lib.drafts import ( + DraftData, do_create_drafts, do_delete_draft, do_edit_draft, - draft_dict_validator, draft_endpoint, ) -from zerver.lib.request import REQ, has_request_variables from zerver.lib.response import json_success -from zerver.lib.validator import check_list +from zerver.lib.typed_endpoint import PathOnly, typed_endpoint from zerver.models import Draft, UserProfile @@ -23,32 +23,32 @@ def fetch_drafts(request: HttpRequest, user_profile: UserProfile) -> HttpRespons @draft_endpoint -@has_request_variables +@typed_endpoint def create_drafts( request: HttpRequest, user_profile: UserProfile, - draft_dicts: List[Dict[str, Any]] = REQ( - "drafts", json_validator=check_list(draft_dict_validator) - ), + *, + drafts: Json[List[DraftData]], ) -> HttpResponse: - created_draft_objects = do_create_drafts(draft_dicts, user_profile) + created_draft_objects = do_create_drafts(drafts, user_profile) draft_ids = [draft_object.id for draft_object in created_draft_objects] return json_success(request, data={"ids": draft_ids}) @draft_endpoint -@has_request_variables +@typed_endpoint def edit_draft( request: HttpRequest, user_profile: UserProfile, - draft_id: int, - draft_dict: Dict[str, Any] = REQ("draft", json_validator=draft_dict_validator), + *, + draft_id: PathOnly[int], + draft: Json[DraftData], ) -> HttpResponse: - do_edit_draft(draft_id, draft_dict, user_profile) + do_edit_draft(draft_id, draft, user_profile) return json_success(request) @draft_endpoint -def delete_draft(request: HttpRequest, user_profile: UserProfile, draft_id: int) -> HttpResponse: +def delete_draft(request: HttpRequest, user_profile: UserProfile, *, draft_id: int) -> HttpResponse: do_delete_draft(draft_id, user_profile) return json_success(request)