validators: Use cleaner syntax for AfterValidator.

Created a function that returns an `AfterValidator` for `check_int_in`
and `check_string_in` instead of having to use a
`lambda` wraper everytime.
This commit is contained in:
Kenneth Rodrigues 2024-07-05 00:08:59 +05:30 committed by Tim Abbott
parent 0ec4b0285e
commit a7da24a36f
7 changed files with 64 additions and 44 deletions

View File

@ -4,7 +4,7 @@ from typing import Annotated
from django.conf import settings
from django.http import HttpRequest, HttpResponse, HttpResponseRedirect
from django.shortcuts import render
from pydantic import AfterValidator, Json
from pydantic import Json
from corporate.lib.decorator import (
authenticated_remote_realm_management_endpoint,
@ -25,7 +25,7 @@ from corporate.models import CustomerPlan
from zerver.decorator import require_organization_member, zulip_login_required
from zerver.lib.response import json_success
from zerver.lib.typed_endpoint import typed_endpoint
from zerver.lib.typed_endpoint_validators import check_string_in
from zerver.lib.typed_endpoint_validators import check_string_in_validator
from zerver.models import UserProfile
from zilencer.lib.remote_counts import MissingDataError
@ -38,17 +38,11 @@ def upgrade(
request: HttpRequest,
user: UserProfile,
*,
billing_modality: Annotated[
str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_MODALITY_VALUES))
],
schedule: Annotated[
str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_SCHEDULE_VALUES))
],
billing_modality: Annotated[str, check_string_in_validator(VALID_BILLING_MODALITY_VALUES)],
schedule: Annotated[str, check_string_in_validator(VALID_BILLING_SCHEDULE_VALUES)],
signed_seat_count: str,
salt: str,
license_management: Annotated[
str, AfterValidator(lambda val: check_string_in(val, VALID_LICENSE_MANAGEMENT_VALUES))
]
license_management: Annotated[str, check_string_in_validator(VALID_LICENSE_MANAGEMENT_VALUES)]
| None = None,
licenses: Json[int] | None = None,
tier: Json[int] = CustomerPlan.TIER_CLOUD_STANDARD,
@ -94,17 +88,11 @@ def remote_realm_upgrade(
request: HttpRequest,
billing_session: RemoteRealmBillingSession,
*,
billing_modality: Annotated[
str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_MODALITY_VALUES))
],
schedule: Annotated[
str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_SCHEDULE_VALUES))
],
billing_modality: Annotated[str, check_string_in_validator(VALID_BILLING_MODALITY_VALUES)],
schedule: Annotated[str, check_string_in_validator(VALID_BILLING_SCHEDULE_VALUES)],
signed_seat_count: str,
salt: str,
license_management: Annotated[
str, AfterValidator(lambda val: check_string_in(val, VALID_LICENSE_MANAGEMENT_VALUES))
]
license_management: Annotated[str, check_string_in_validator(VALID_LICENSE_MANAGEMENT_VALUES)]
| None = None,
licenses: Json[int] | None = None,
remote_server_plan_start_date: str | None = None,
@ -149,17 +137,11 @@ def remote_server_upgrade(
request: HttpRequest,
billing_session: RemoteServerBillingSession,
*,
billing_modality: Annotated[
str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_MODALITY_VALUES))
],
schedule: Annotated[
str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_SCHEDULE_VALUES))
],
billing_modality: Annotated[str, check_string_in_validator(VALID_BILLING_MODALITY_VALUES)],
schedule: Annotated[str, check_string_in_validator(VALID_BILLING_SCHEDULE_VALUES)],
signed_seat_count: str,
salt: str,
license_management: Annotated[
str, AfterValidator(lambda val: check_string_in(val, VALID_LICENSE_MANAGEMENT_VALUES))
]
license_management: Annotated[str, check_string_in_validator(VALID_LICENSE_MANAGEMENT_VALUES)]
| None = None,
licenses: Json[int] | None = None,
remote_server_plan_start_date: str | None = None,

View File

@ -1,6 +1,9 @@
from collections.abc import Collection
from django.core.exceptions import ValidationError
from django.core.validators import URLValidator
from django.utils.translation import gettext as _
from pydantic import AfterValidator
from pydantic_core import PydanticCustomError
# The Pydantic.StringConstraints does not have validation for the string to be
@ -19,18 +22,26 @@ def check_string_fixed_length(string: str, length: int) -> str | None:
return string
def check_string_in(val: str, possible_values: list[str]) -> str:
def check_string_in(val: str, possible_values: Collection[str]) -> str:
if val not in possible_values:
raise ValueError(_("Not in the list of possible values"))
return val
def check_int_in(val: int, possible_values: list[int]) -> int:
def check_int_in(val: int, possible_values: Collection[int]) -> int:
if val not in possible_values:
raise ValueError(_("Not in the list of possible values"))
return val
def check_int_in_validator(possible_values: Collection[int]) -> AfterValidator:
return AfterValidator(lambda val: check_int_in(val, possible_values))
def check_string_in_validator(possible_values: Collection[str]) -> AfterValidator:
return AfterValidator(lambda val: check_string_in(val, possible_values))
def check_url(val: str) -> str:
validate = URLValidator()
try:

View File

@ -1,5 +1,5 @@
from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.typed_endpoint_validators import check_int_in, check_url
from zerver.lib.typed_endpoint_validators import check_int_in, check_string_in, check_url
class ValidatorTestCase(ZulipTestCase):
@ -8,6 +8,11 @@ class ValidatorTestCase(ZulipTestCase):
with self.assertRaisesRegex(ValueError, "Not in the list of possible values"):
check_int_in(3, [1, 2])
def test_check_string_in(self) -> None:
check_string_in("foo", ["foo", "bar"])
with self.assertRaisesRegex(ValueError, "Not in the list of possible values"):
check_string_in("foo", ["bar"])
def test_check_url(self) -> None:
check_url("https://example.com")
with self.assertRaisesRegex(ValueError, "Not a URL"):

View File

@ -688,6 +688,26 @@ class UnmutedTopicsTests(ZulipTestCase):
self.assert_json_error(result, "Invalid channel ID")
class UserTopicsTests(ZulipTestCase):
def test_invalid_visibility_policy(self) -> None:
user = self.example_user("hamlet")
self.login_user(user)
stream = get_stream("Verona", user.realm)
url = "/api/v1/user_topics"
data = {
"stream_id": stream.id,
"topic": "Verona3",
"visibility_policy": 999,
}
result = self.api_post(user, url, data)
self.assert_json_error(
result, "Invalid visibility_policy: Value error, Not in the list of possible values"
)
class AutomaticallyFollowTopicsTests(ZulipTestCase):
def test_automatically_follow_topic_on_initiation(self) -> None:
hamlet = self.example_user("hamlet")

View File

@ -461,6 +461,14 @@ class PermissionTest(ZulipTestCase):
result = self.client_patch("/json/users/{}".format(self.example_user("hamlet").id), req)
self.assert_json_error(result, "Invalid format!")
def test_invalid_role(self) -> None:
self.login("iago")
req = dict(role=1000)
result = self.client_patch("/json/users/{}".format(self.example_user("hamlet").id), req)
self.assert_json_error(
result, "Invalid role: Value error, Not in the list of possible values"
)
def test_admin_cannot_set_full_name_with_invalid_characters(self) -> None:
new_name = "Opheli*"
self.login("iago")

View File

@ -4,7 +4,7 @@ from typing import Annotated, Literal
from django.http import HttpRequest, HttpResponse
from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _
from pydantic import AfterValidator, Json, StringConstraints
from pydantic import Json, StringConstraints
from zerver.actions.user_topics import do_set_user_topic_visibility_policy
from zerver.lib.response import json_success
@ -16,7 +16,7 @@ from zerver.lib.streams import (
check_for_exactly_one_stream_arg,
)
from zerver.lib.typed_endpoint import typed_endpoint
from zerver.lib.typed_endpoint_validators import check_int_in
from zerver.lib.typed_endpoint_validators import check_int_in_validator
from zerver.models import UserProfile, UserTopic
from zerver.models.constants import MAX_TOPIC_NAME_LENGTH
@ -100,10 +100,7 @@ def update_user_topic(
stream_id: Json[int],
topic: Annotated[str, StringConstraints(max_length=MAX_TOPIC_NAME_LENGTH)],
visibility_policy: Json[
Annotated[
int,
AfterValidator(lambda x: check_int_in(x, UserTopic.VisibilityPolicy.values)),
]
Annotated[int, check_int_in_validator(UserTopic.VisibilityPolicy.values)]
],
) -> HttpResponse:
if visibility_policy == UserTopic.VisibilityPolicy.INHERIT:

View File

@ -57,7 +57,7 @@ from zerver.lib.typed_endpoint import (
typed_endpoint,
typed_endpoint_without_parameters,
)
from zerver.lib.typed_endpoint_validators import check_int_in, check_url
from zerver.lib.typed_endpoint_validators import check_int_in_validator, check_url
from zerver.lib.types import ProfileDataElementUpdateDict
from zerver.lib.upload import upload_avatar_image
from zerver.lib.url_encoding import append_url_query_string
@ -98,11 +98,8 @@ from zproject.backends import check_password_strength
RoleParamType: TypeAlias = Annotated[
int,
AfterValidator(
lambda x: check_int_in(
x,
UserProfile.ROLE_TYPES,
)
check_int_in_validator(
UserProfile.ROLE_TYPES,
),
]