validators: Use cleaner syntax for AfterValidator.

Created a function that returns an `AfterValidator` for `check_int_in`
and `check_string_in` instead of having to use a
`lambda` wraper everytime.
This commit is contained in:
Kenneth Rodrigues 2024-07-05 00:08:59 +05:30 committed by Tim Abbott
parent 0ec4b0285e
commit a7da24a36f
7 changed files with 64 additions and 44 deletions

View File

@ -4,7 +4,7 @@ from typing import Annotated
from django.conf import settings from django.conf import settings
from django.http import HttpRequest, HttpResponse, HttpResponseRedirect from django.http import HttpRequest, HttpResponse, HttpResponseRedirect
from django.shortcuts import render from django.shortcuts import render
from pydantic import AfterValidator, Json from pydantic import Json
from corporate.lib.decorator import ( from corporate.lib.decorator import (
authenticated_remote_realm_management_endpoint, authenticated_remote_realm_management_endpoint,
@ -25,7 +25,7 @@ from corporate.models import CustomerPlan
from zerver.decorator import require_organization_member, zulip_login_required from zerver.decorator import require_organization_member, zulip_login_required
from zerver.lib.response import json_success from zerver.lib.response import json_success
from zerver.lib.typed_endpoint import typed_endpoint from zerver.lib.typed_endpoint import typed_endpoint
from zerver.lib.typed_endpoint_validators import check_string_in from zerver.lib.typed_endpoint_validators import check_string_in_validator
from zerver.models import UserProfile from zerver.models import UserProfile
from zilencer.lib.remote_counts import MissingDataError from zilencer.lib.remote_counts import MissingDataError
@ -38,17 +38,11 @@ def upgrade(
request: HttpRequest, request: HttpRequest,
user: UserProfile, user: UserProfile,
*, *,
billing_modality: Annotated[ billing_modality: Annotated[str, check_string_in_validator(VALID_BILLING_MODALITY_VALUES)],
str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_MODALITY_VALUES)) schedule: Annotated[str, check_string_in_validator(VALID_BILLING_SCHEDULE_VALUES)],
],
schedule: Annotated[
str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_SCHEDULE_VALUES))
],
signed_seat_count: str, signed_seat_count: str,
salt: str, salt: str,
license_management: Annotated[ license_management: Annotated[str, check_string_in_validator(VALID_LICENSE_MANAGEMENT_VALUES)]
str, AfterValidator(lambda val: check_string_in(val, VALID_LICENSE_MANAGEMENT_VALUES))
]
| None = None, | None = None,
licenses: Json[int] | None = None, licenses: Json[int] | None = None,
tier: Json[int] = CustomerPlan.TIER_CLOUD_STANDARD, tier: Json[int] = CustomerPlan.TIER_CLOUD_STANDARD,
@ -94,17 +88,11 @@ def remote_realm_upgrade(
request: HttpRequest, request: HttpRequest,
billing_session: RemoteRealmBillingSession, billing_session: RemoteRealmBillingSession,
*, *,
billing_modality: Annotated[ billing_modality: Annotated[str, check_string_in_validator(VALID_BILLING_MODALITY_VALUES)],
str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_MODALITY_VALUES)) schedule: Annotated[str, check_string_in_validator(VALID_BILLING_SCHEDULE_VALUES)],
],
schedule: Annotated[
str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_SCHEDULE_VALUES))
],
signed_seat_count: str, signed_seat_count: str,
salt: str, salt: str,
license_management: Annotated[ license_management: Annotated[str, check_string_in_validator(VALID_LICENSE_MANAGEMENT_VALUES)]
str, AfterValidator(lambda val: check_string_in(val, VALID_LICENSE_MANAGEMENT_VALUES))
]
| None = None, | None = None,
licenses: Json[int] | None = None, licenses: Json[int] | None = None,
remote_server_plan_start_date: str | None = None, remote_server_plan_start_date: str | None = None,
@ -149,17 +137,11 @@ def remote_server_upgrade(
request: HttpRequest, request: HttpRequest,
billing_session: RemoteServerBillingSession, billing_session: RemoteServerBillingSession,
*, *,
billing_modality: Annotated[ billing_modality: Annotated[str, check_string_in_validator(VALID_BILLING_MODALITY_VALUES)],
str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_MODALITY_VALUES)) schedule: Annotated[str, check_string_in_validator(VALID_BILLING_SCHEDULE_VALUES)],
],
schedule: Annotated[
str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_SCHEDULE_VALUES))
],
signed_seat_count: str, signed_seat_count: str,
salt: str, salt: str,
license_management: Annotated[ license_management: Annotated[str, check_string_in_validator(VALID_LICENSE_MANAGEMENT_VALUES)]
str, AfterValidator(lambda val: check_string_in(val, VALID_LICENSE_MANAGEMENT_VALUES))
]
| None = None, | None = None,
licenses: Json[int] | None = None, licenses: Json[int] | None = None,
remote_server_plan_start_date: str | None = None, remote_server_plan_start_date: str | None = None,

View File

@ -1,6 +1,9 @@
from collections.abc import Collection
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.core.validators import URLValidator from django.core.validators import URLValidator
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from pydantic import AfterValidator
from pydantic_core import PydanticCustomError from pydantic_core import PydanticCustomError
# The Pydantic.StringConstraints does not have validation for the string to be # The Pydantic.StringConstraints does not have validation for the string to be
@ -19,18 +22,26 @@ def check_string_fixed_length(string: str, length: int) -> str | None:
return string return string
def check_string_in(val: str, possible_values: list[str]) -> str: def check_string_in(val: str, possible_values: Collection[str]) -> str:
if val not in possible_values: if val not in possible_values:
raise ValueError(_("Not in the list of possible values")) raise ValueError(_("Not in the list of possible values"))
return val return val
def check_int_in(val: int, possible_values: list[int]) -> int: def check_int_in(val: int, possible_values: Collection[int]) -> int:
if val not in possible_values: if val not in possible_values:
raise ValueError(_("Not in the list of possible values")) raise ValueError(_("Not in the list of possible values"))
return val return val
def check_int_in_validator(possible_values: Collection[int]) -> AfterValidator:
return AfterValidator(lambda val: check_int_in(val, possible_values))
def check_string_in_validator(possible_values: Collection[str]) -> AfterValidator:
return AfterValidator(lambda val: check_string_in(val, possible_values))
def check_url(val: str) -> str: def check_url(val: str) -> str:
validate = URLValidator() validate = URLValidator()
try: try:

View File

@ -1,5 +1,5 @@
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.typed_endpoint_validators import check_int_in, check_url from zerver.lib.typed_endpoint_validators import check_int_in, check_string_in, check_url
class ValidatorTestCase(ZulipTestCase): class ValidatorTestCase(ZulipTestCase):
@ -8,6 +8,11 @@ class ValidatorTestCase(ZulipTestCase):
with self.assertRaisesRegex(ValueError, "Not in the list of possible values"): with self.assertRaisesRegex(ValueError, "Not in the list of possible values"):
check_int_in(3, [1, 2]) check_int_in(3, [1, 2])
def test_check_string_in(self) -> None:
check_string_in("foo", ["foo", "bar"])
with self.assertRaisesRegex(ValueError, "Not in the list of possible values"):
check_string_in("foo", ["bar"])
def test_check_url(self) -> None: def test_check_url(self) -> None:
check_url("https://example.com") check_url("https://example.com")
with self.assertRaisesRegex(ValueError, "Not a URL"): with self.assertRaisesRegex(ValueError, "Not a URL"):

View File

@ -688,6 +688,26 @@ class UnmutedTopicsTests(ZulipTestCase):
self.assert_json_error(result, "Invalid channel ID") self.assert_json_error(result, "Invalid channel ID")
class UserTopicsTests(ZulipTestCase):
def test_invalid_visibility_policy(self) -> None:
user = self.example_user("hamlet")
self.login_user(user)
stream = get_stream("Verona", user.realm)
url = "/api/v1/user_topics"
data = {
"stream_id": stream.id,
"topic": "Verona3",
"visibility_policy": 999,
}
result = self.api_post(user, url, data)
self.assert_json_error(
result, "Invalid visibility_policy: Value error, Not in the list of possible values"
)
class AutomaticallyFollowTopicsTests(ZulipTestCase): class AutomaticallyFollowTopicsTests(ZulipTestCase):
def test_automatically_follow_topic_on_initiation(self) -> None: def test_automatically_follow_topic_on_initiation(self) -> None:
hamlet = self.example_user("hamlet") hamlet = self.example_user("hamlet")

View File

@ -461,6 +461,14 @@ class PermissionTest(ZulipTestCase):
result = self.client_patch("/json/users/{}".format(self.example_user("hamlet").id), req) result = self.client_patch("/json/users/{}".format(self.example_user("hamlet").id), req)
self.assert_json_error(result, "Invalid format!") self.assert_json_error(result, "Invalid format!")
def test_invalid_role(self) -> None:
self.login("iago")
req = dict(role=1000)
result = self.client_patch("/json/users/{}".format(self.example_user("hamlet").id), req)
self.assert_json_error(
result, "Invalid role: Value error, Not in the list of possible values"
)
def test_admin_cannot_set_full_name_with_invalid_characters(self) -> None: def test_admin_cannot_set_full_name_with_invalid_characters(self) -> None:
new_name = "Opheli*" new_name = "Opheli*"
self.login("iago") self.login("iago")

View File

@ -4,7 +4,7 @@ from typing import Annotated, Literal
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.utils.timezone import now as timezone_now from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from pydantic import AfterValidator, Json, StringConstraints from pydantic import Json, StringConstraints
from zerver.actions.user_topics import do_set_user_topic_visibility_policy from zerver.actions.user_topics import do_set_user_topic_visibility_policy
from zerver.lib.response import json_success from zerver.lib.response import json_success
@ -16,7 +16,7 @@ from zerver.lib.streams import (
check_for_exactly_one_stream_arg, check_for_exactly_one_stream_arg,
) )
from zerver.lib.typed_endpoint import typed_endpoint from zerver.lib.typed_endpoint import typed_endpoint
from zerver.lib.typed_endpoint_validators import check_int_in from zerver.lib.typed_endpoint_validators import check_int_in_validator
from zerver.models import UserProfile, UserTopic from zerver.models import UserProfile, UserTopic
from zerver.models.constants import MAX_TOPIC_NAME_LENGTH from zerver.models.constants import MAX_TOPIC_NAME_LENGTH
@ -100,10 +100,7 @@ def update_user_topic(
stream_id: Json[int], stream_id: Json[int],
topic: Annotated[str, StringConstraints(max_length=MAX_TOPIC_NAME_LENGTH)], topic: Annotated[str, StringConstraints(max_length=MAX_TOPIC_NAME_LENGTH)],
visibility_policy: Json[ visibility_policy: Json[
Annotated[ Annotated[int, check_int_in_validator(UserTopic.VisibilityPolicy.values)]
int,
AfterValidator(lambda x: check_int_in(x, UserTopic.VisibilityPolicy.values)),
]
], ],
) -> HttpResponse: ) -> HttpResponse:
if visibility_policy == UserTopic.VisibilityPolicy.INHERIT: if visibility_policy == UserTopic.VisibilityPolicy.INHERIT:

View File

@ -57,7 +57,7 @@ from zerver.lib.typed_endpoint import (
typed_endpoint, typed_endpoint,
typed_endpoint_without_parameters, typed_endpoint_without_parameters,
) )
from zerver.lib.typed_endpoint_validators import check_int_in, check_url from zerver.lib.typed_endpoint_validators import check_int_in_validator, check_url
from zerver.lib.types import ProfileDataElementUpdateDict from zerver.lib.types import ProfileDataElementUpdateDict
from zerver.lib.upload import upload_avatar_image from zerver.lib.upload import upload_avatar_image
from zerver.lib.url_encoding import append_url_query_string from zerver.lib.url_encoding import append_url_query_string
@ -98,11 +98,8 @@ from zproject.backends import check_password_strength
RoleParamType: TypeAlias = Annotated[ RoleParamType: TypeAlias = Annotated[
int, int,
AfterValidator( check_int_in_validator(
lambda x: check_int_in(
x,
UserProfile.ROLE_TYPES, UserProfile.ROLE_TYPES,
)
), ),
] ]