mirror of https://github.com/zulip/zulip.git
validators: Use cleaner syntax for AfterValidator.
Created a function that returns an `AfterValidator` for `check_int_in` and `check_string_in` instead of having to use a `lambda` wraper everytime.
This commit is contained in:
parent
0ec4b0285e
commit
a7da24a36f
|
@ -4,7 +4,7 @@ from typing import Annotated
|
|||
from django.conf import settings
|
||||
from django.http import HttpRequest, HttpResponse, HttpResponseRedirect
|
||||
from django.shortcuts import render
|
||||
from pydantic import AfterValidator, Json
|
||||
from pydantic import Json
|
||||
|
||||
from corporate.lib.decorator import (
|
||||
authenticated_remote_realm_management_endpoint,
|
||||
|
@ -25,7 +25,7 @@ from corporate.models import CustomerPlan
|
|||
from zerver.decorator import require_organization_member, zulip_login_required
|
||||
from zerver.lib.response import json_success
|
||||
from zerver.lib.typed_endpoint import typed_endpoint
|
||||
from zerver.lib.typed_endpoint_validators import check_string_in
|
||||
from zerver.lib.typed_endpoint_validators import check_string_in_validator
|
||||
from zerver.models import UserProfile
|
||||
from zilencer.lib.remote_counts import MissingDataError
|
||||
|
||||
|
@ -38,17 +38,11 @@ def upgrade(
|
|||
request: HttpRequest,
|
||||
user: UserProfile,
|
||||
*,
|
||||
billing_modality: Annotated[
|
||||
str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_MODALITY_VALUES))
|
||||
],
|
||||
schedule: Annotated[
|
||||
str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_SCHEDULE_VALUES))
|
||||
],
|
||||
billing_modality: Annotated[str, check_string_in_validator(VALID_BILLING_MODALITY_VALUES)],
|
||||
schedule: Annotated[str, check_string_in_validator(VALID_BILLING_SCHEDULE_VALUES)],
|
||||
signed_seat_count: str,
|
||||
salt: str,
|
||||
license_management: Annotated[
|
||||
str, AfterValidator(lambda val: check_string_in(val, VALID_LICENSE_MANAGEMENT_VALUES))
|
||||
]
|
||||
license_management: Annotated[str, check_string_in_validator(VALID_LICENSE_MANAGEMENT_VALUES)]
|
||||
| None = None,
|
||||
licenses: Json[int] | None = None,
|
||||
tier: Json[int] = CustomerPlan.TIER_CLOUD_STANDARD,
|
||||
|
@ -94,17 +88,11 @@ def remote_realm_upgrade(
|
|||
request: HttpRequest,
|
||||
billing_session: RemoteRealmBillingSession,
|
||||
*,
|
||||
billing_modality: Annotated[
|
||||
str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_MODALITY_VALUES))
|
||||
],
|
||||
schedule: Annotated[
|
||||
str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_SCHEDULE_VALUES))
|
||||
],
|
||||
billing_modality: Annotated[str, check_string_in_validator(VALID_BILLING_MODALITY_VALUES)],
|
||||
schedule: Annotated[str, check_string_in_validator(VALID_BILLING_SCHEDULE_VALUES)],
|
||||
signed_seat_count: str,
|
||||
salt: str,
|
||||
license_management: Annotated[
|
||||
str, AfterValidator(lambda val: check_string_in(val, VALID_LICENSE_MANAGEMENT_VALUES))
|
||||
]
|
||||
license_management: Annotated[str, check_string_in_validator(VALID_LICENSE_MANAGEMENT_VALUES)]
|
||||
| None = None,
|
||||
licenses: Json[int] | None = None,
|
||||
remote_server_plan_start_date: str | None = None,
|
||||
|
@ -149,17 +137,11 @@ def remote_server_upgrade(
|
|||
request: HttpRequest,
|
||||
billing_session: RemoteServerBillingSession,
|
||||
*,
|
||||
billing_modality: Annotated[
|
||||
str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_MODALITY_VALUES))
|
||||
],
|
||||
schedule: Annotated[
|
||||
str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_SCHEDULE_VALUES))
|
||||
],
|
||||
billing_modality: Annotated[str, check_string_in_validator(VALID_BILLING_MODALITY_VALUES)],
|
||||
schedule: Annotated[str, check_string_in_validator(VALID_BILLING_SCHEDULE_VALUES)],
|
||||
signed_seat_count: str,
|
||||
salt: str,
|
||||
license_management: Annotated[
|
||||
str, AfterValidator(lambda val: check_string_in(val, VALID_LICENSE_MANAGEMENT_VALUES))
|
||||
]
|
||||
license_management: Annotated[str, check_string_in_validator(VALID_LICENSE_MANAGEMENT_VALUES)]
|
||||
| None = None,
|
||||
licenses: Json[int] | None = None,
|
||||
remote_server_plan_start_date: str | None = None,
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
from collections.abc import Collection
|
||||
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.core.validators import URLValidator
|
||||
from django.utils.translation import gettext as _
|
||||
from pydantic import AfterValidator
|
||||
from pydantic_core import PydanticCustomError
|
||||
|
||||
# The Pydantic.StringConstraints does not have validation for the string to be
|
||||
|
@ -19,18 +22,26 @@ def check_string_fixed_length(string: str, length: int) -> str | None:
|
|||
return string
|
||||
|
||||
|
||||
def check_string_in(val: str, possible_values: list[str]) -> str:
|
||||
def check_string_in(val: str, possible_values: Collection[str]) -> str:
|
||||
if val not in possible_values:
|
||||
raise ValueError(_("Not in the list of possible values"))
|
||||
return val
|
||||
|
||||
|
||||
def check_int_in(val: int, possible_values: list[int]) -> int:
|
||||
def check_int_in(val: int, possible_values: Collection[int]) -> int:
|
||||
if val not in possible_values:
|
||||
raise ValueError(_("Not in the list of possible values"))
|
||||
return val
|
||||
|
||||
|
||||
def check_int_in_validator(possible_values: Collection[int]) -> AfterValidator:
|
||||
return AfterValidator(lambda val: check_int_in(val, possible_values))
|
||||
|
||||
|
||||
def check_string_in_validator(possible_values: Collection[str]) -> AfterValidator:
|
||||
return AfterValidator(lambda val: check_string_in(val, possible_values))
|
||||
|
||||
|
||||
def check_url(val: str) -> str:
|
||||
validate = URLValidator()
|
||||
try:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from zerver.lib.test_classes import ZulipTestCase
|
||||
from zerver.lib.typed_endpoint_validators import check_int_in, check_url
|
||||
from zerver.lib.typed_endpoint_validators import check_int_in, check_string_in, check_url
|
||||
|
||||
|
||||
class ValidatorTestCase(ZulipTestCase):
|
||||
|
@ -8,6 +8,11 @@ class ValidatorTestCase(ZulipTestCase):
|
|||
with self.assertRaisesRegex(ValueError, "Not in the list of possible values"):
|
||||
check_int_in(3, [1, 2])
|
||||
|
||||
def test_check_string_in(self) -> None:
|
||||
check_string_in("foo", ["foo", "bar"])
|
||||
with self.assertRaisesRegex(ValueError, "Not in the list of possible values"):
|
||||
check_string_in("foo", ["bar"])
|
||||
|
||||
def test_check_url(self) -> None:
|
||||
check_url("https://example.com")
|
||||
with self.assertRaisesRegex(ValueError, "Not a URL"):
|
||||
|
|
|
@ -688,6 +688,26 @@ class UnmutedTopicsTests(ZulipTestCase):
|
|||
self.assert_json_error(result, "Invalid channel ID")
|
||||
|
||||
|
||||
class UserTopicsTests(ZulipTestCase):
|
||||
def test_invalid_visibility_policy(self) -> None:
|
||||
user = self.example_user("hamlet")
|
||||
self.login_user(user)
|
||||
|
||||
stream = get_stream("Verona", user.realm)
|
||||
|
||||
url = "/api/v1/user_topics"
|
||||
data = {
|
||||
"stream_id": stream.id,
|
||||
"topic": "Verona3",
|
||||
"visibility_policy": 999,
|
||||
}
|
||||
|
||||
result = self.api_post(user, url, data)
|
||||
self.assert_json_error(
|
||||
result, "Invalid visibility_policy: Value error, Not in the list of possible values"
|
||||
)
|
||||
|
||||
|
||||
class AutomaticallyFollowTopicsTests(ZulipTestCase):
|
||||
def test_automatically_follow_topic_on_initiation(self) -> None:
|
||||
hamlet = self.example_user("hamlet")
|
||||
|
|
|
@ -461,6 +461,14 @@ class PermissionTest(ZulipTestCase):
|
|||
result = self.client_patch("/json/users/{}".format(self.example_user("hamlet").id), req)
|
||||
self.assert_json_error(result, "Invalid format!")
|
||||
|
||||
def test_invalid_role(self) -> None:
|
||||
self.login("iago")
|
||||
req = dict(role=1000)
|
||||
result = self.client_patch("/json/users/{}".format(self.example_user("hamlet").id), req)
|
||||
self.assert_json_error(
|
||||
result, "Invalid role: Value error, Not in the list of possible values"
|
||||
)
|
||||
|
||||
def test_admin_cannot_set_full_name_with_invalid_characters(self) -> None:
|
||||
new_name = "Opheli*"
|
||||
self.login("iago")
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import Annotated, Literal
|
|||
from django.http import HttpRequest, HttpResponse
|
||||
from django.utils.timezone import now as timezone_now
|
||||
from django.utils.translation import gettext as _
|
||||
from pydantic import AfterValidator, Json, StringConstraints
|
||||
from pydantic import Json, StringConstraints
|
||||
|
||||
from zerver.actions.user_topics import do_set_user_topic_visibility_policy
|
||||
from zerver.lib.response import json_success
|
||||
|
@ -16,7 +16,7 @@ from zerver.lib.streams import (
|
|||
check_for_exactly_one_stream_arg,
|
||||
)
|
||||
from zerver.lib.typed_endpoint import typed_endpoint
|
||||
from zerver.lib.typed_endpoint_validators import check_int_in
|
||||
from zerver.lib.typed_endpoint_validators import check_int_in_validator
|
||||
from zerver.models import UserProfile, UserTopic
|
||||
from zerver.models.constants import MAX_TOPIC_NAME_LENGTH
|
||||
|
||||
|
@ -100,10 +100,7 @@ def update_user_topic(
|
|||
stream_id: Json[int],
|
||||
topic: Annotated[str, StringConstraints(max_length=MAX_TOPIC_NAME_LENGTH)],
|
||||
visibility_policy: Json[
|
||||
Annotated[
|
||||
int,
|
||||
AfterValidator(lambda x: check_int_in(x, UserTopic.VisibilityPolicy.values)),
|
||||
]
|
||||
Annotated[int, check_int_in_validator(UserTopic.VisibilityPolicy.values)]
|
||||
],
|
||||
) -> HttpResponse:
|
||||
if visibility_policy == UserTopic.VisibilityPolicy.INHERIT:
|
||||
|
|
|
@ -57,7 +57,7 @@ from zerver.lib.typed_endpoint import (
|
|||
typed_endpoint,
|
||||
typed_endpoint_without_parameters,
|
||||
)
|
||||
from zerver.lib.typed_endpoint_validators import check_int_in, check_url
|
||||
from zerver.lib.typed_endpoint_validators import check_int_in_validator, check_url
|
||||
from zerver.lib.types import ProfileDataElementUpdateDict
|
||||
from zerver.lib.upload import upload_avatar_image
|
||||
from zerver.lib.url_encoding import append_url_query_string
|
||||
|
@ -98,11 +98,8 @@ from zproject.backends import check_password_strength
|
|||
|
||||
RoleParamType: TypeAlias = Annotated[
|
||||
int,
|
||||
AfterValidator(
|
||||
lambda x: check_int_in(
|
||||
x,
|
||||
UserProfile.ROLE_TYPES,
|
||||
)
|
||||
check_int_in_validator(
|
||||
UserProfile.ROLE_TYPES,
|
||||
),
|
||||
]
|
||||
|
||||
|
|
Loading…
Reference in New Issue