diff --git a/corporate/views/upgrade.py b/corporate/views/upgrade.py index 7b74a34dd9..6ebe61ad4f 100644 --- a/corporate/views/upgrade.py +++ b/corporate/views/upgrade.py @@ -4,7 +4,7 @@ from typing import Annotated from django.conf import settings from django.http import HttpRequest, HttpResponse, HttpResponseRedirect from django.shortcuts import render -from pydantic import AfterValidator, Json +from pydantic import Json from corporate.lib.decorator import ( authenticated_remote_realm_management_endpoint, @@ -25,7 +25,7 @@ from corporate.models import CustomerPlan from zerver.decorator import require_organization_member, zulip_login_required from zerver.lib.response import json_success from zerver.lib.typed_endpoint import typed_endpoint -from zerver.lib.typed_endpoint_validators import check_string_in +from zerver.lib.typed_endpoint_validators import check_string_in_validator from zerver.models import UserProfile from zilencer.lib.remote_counts import MissingDataError @@ -38,17 +38,11 @@ def upgrade( request: HttpRequest, user: UserProfile, *, - billing_modality: Annotated[ - str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_MODALITY_VALUES)) - ], - schedule: Annotated[ - str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_SCHEDULE_VALUES)) - ], + billing_modality: Annotated[str, check_string_in_validator(VALID_BILLING_MODALITY_VALUES)], + schedule: Annotated[str, check_string_in_validator(VALID_BILLING_SCHEDULE_VALUES)], signed_seat_count: str, salt: str, - license_management: Annotated[ - str, AfterValidator(lambda val: check_string_in(val, VALID_LICENSE_MANAGEMENT_VALUES)) - ] + license_management: Annotated[str, check_string_in_validator(VALID_LICENSE_MANAGEMENT_VALUES)] | None = None, licenses: Json[int] | None = None, tier: Json[int] = CustomerPlan.TIER_CLOUD_STANDARD, @@ -94,17 +88,11 @@ def remote_realm_upgrade( request: HttpRequest, billing_session: RemoteRealmBillingSession, *, - billing_modality: Annotated[ - str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_MODALITY_VALUES)) - ], - schedule: Annotated[ - str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_SCHEDULE_VALUES)) - ], + billing_modality: Annotated[str, check_string_in_validator(VALID_BILLING_MODALITY_VALUES)], + schedule: Annotated[str, check_string_in_validator(VALID_BILLING_SCHEDULE_VALUES)], signed_seat_count: str, salt: str, - license_management: Annotated[ - str, AfterValidator(lambda val: check_string_in(val, VALID_LICENSE_MANAGEMENT_VALUES)) - ] + license_management: Annotated[str, check_string_in_validator(VALID_LICENSE_MANAGEMENT_VALUES)] | None = None, licenses: Json[int] | None = None, remote_server_plan_start_date: str | None = None, @@ -149,17 +137,11 @@ def remote_server_upgrade( request: HttpRequest, billing_session: RemoteServerBillingSession, *, - billing_modality: Annotated[ - str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_MODALITY_VALUES)) - ], - schedule: Annotated[ - str, AfterValidator(lambda val: check_string_in(val, VALID_BILLING_SCHEDULE_VALUES)) - ], + billing_modality: Annotated[str, check_string_in_validator(VALID_BILLING_MODALITY_VALUES)], + schedule: Annotated[str, check_string_in_validator(VALID_BILLING_SCHEDULE_VALUES)], signed_seat_count: str, salt: str, - license_management: Annotated[ - str, AfterValidator(lambda val: check_string_in(val, VALID_LICENSE_MANAGEMENT_VALUES)) - ] + license_management: Annotated[str, check_string_in_validator(VALID_LICENSE_MANAGEMENT_VALUES)] | None = None, licenses: Json[int] | None = None, remote_server_plan_start_date: str | None = None, diff --git a/zerver/lib/typed_endpoint_validators.py b/zerver/lib/typed_endpoint_validators.py index f5e37b3b7d..1ca380f2fc 100644 --- a/zerver/lib/typed_endpoint_validators.py +++ b/zerver/lib/typed_endpoint_validators.py @@ -1,6 +1,9 @@ +from collections.abc import Collection + from django.core.exceptions import ValidationError from django.core.validators import URLValidator from django.utils.translation import gettext as _ +from pydantic import AfterValidator from pydantic_core import PydanticCustomError # The Pydantic.StringConstraints does not have validation for the string to be @@ -19,18 +22,26 @@ def check_string_fixed_length(string: str, length: int) -> str | None: return string -def check_string_in(val: str, possible_values: list[str]) -> str: +def check_string_in(val: str, possible_values: Collection[str]) -> str: if val not in possible_values: raise ValueError(_("Not in the list of possible values")) return val -def check_int_in(val: int, possible_values: list[int]) -> int: +def check_int_in(val: int, possible_values: Collection[int]) -> int: if val not in possible_values: raise ValueError(_("Not in the list of possible values")) return val +def check_int_in_validator(possible_values: Collection[int]) -> AfterValidator: + return AfterValidator(lambda val: check_int_in(val, possible_values)) + + +def check_string_in_validator(possible_values: Collection[str]) -> AfterValidator: + return AfterValidator(lambda val: check_string_in(val, possible_values)) + + def check_url(val: str) -> str: validate = URLValidator() try: diff --git a/zerver/tests/test_typed_endpoint_validators.py b/zerver/tests/test_typed_endpoint_validators.py index b8f44bf0ab..f78cf84bfd 100644 --- a/zerver/tests/test_typed_endpoint_validators.py +++ b/zerver/tests/test_typed_endpoint_validators.py @@ -1,5 +1,5 @@ from zerver.lib.test_classes import ZulipTestCase -from zerver.lib.typed_endpoint_validators import check_int_in, check_url +from zerver.lib.typed_endpoint_validators import check_int_in, check_string_in, check_url class ValidatorTestCase(ZulipTestCase): @@ -8,6 +8,11 @@ class ValidatorTestCase(ZulipTestCase): with self.assertRaisesRegex(ValueError, "Not in the list of possible values"): check_int_in(3, [1, 2]) + def test_check_string_in(self) -> None: + check_string_in("foo", ["foo", "bar"]) + with self.assertRaisesRegex(ValueError, "Not in the list of possible values"): + check_string_in("foo", ["bar"]) + def test_check_url(self) -> None: check_url("https://example.com") with self.assertRaisesRegex(ValueError, "Not a URL"): diff --git a/zerver/tests/test_user_topics.py b/zerver/tests/test_user_topics.py index 364662c414..9d316d6a87 100644 --- a/zerver/tests/test_user_topics.py +++ b/zerver/tests/test_user_topics.py @@ -688,6 +688,26 @@ class UnmutedTopicsTests(ZulipTestCase): self.assert_json_error(result, "Invalid channel ID") +class UserTopicsTests(ZulipTestCase): + def test_invalid_visibility_policy(self) -> None: + user = self.example_user("hamlet") + self.login_user(user) + + stream = get_stream("Verona", user.realm) + + url = "/api/v1/user_topics" + data = { + "stream_id": stream.id, + "topic": "Verona3", + "visibility_policy": 999, + } + + result = self.api_post(user, url, data) + self.assert_json_error( + result, "Invalid visibility_policy: Value error, Not in the list of possible values" + ) + + class AutomaticallyFollowTopicsTests(ZulipTestCase): def test_automatically_follow_topic_on_initiation(self) -> None: hamlet = self.example_user("hamlet") diff --git a/zerver/tests/test_users.py b/zerver/tests/test_users.py index c828a72f0e..d8fa3a0eeb 100644 --- a/zerver/tests/test_users.py +++ b/zerver/tests/test_users.py @@ -461,6 +461,14 @@ class PermissionTest(ZulipTestCase): result = self.client_patch("/json/users/{}".format(self.example_user("hamlet").id), req) self.assert_json_error(result, "Invalid format!") + def test_invalid_role(self) -> None: + self.login("iago") + req = dict(role=1000) + result = self.client_patch("/json/users/{}".format(self.example_user("hamlet").id), req) + self.assert_json_error( + result, "Invalid role: Value error, Not in the list of possible values" + ) + def test_admin_cannot_set_full_name_with_invalid_characters(self) -> None: new_name = "Opheli*" self.login("iago") diff --git a/zerver/views/user_topics.py b/zerver/views/user_topics.py index 2bf913d122..6045862a1b 100644 --- a/zerver/views/user_topics.py +++ b/zerver/views/user_topics.py @@ -4,7 +4,7 @@ from typing import Annotated, Literal from django.http import HttpRequest, HttpResponse from django.utils.timezone import now as timezone_now from django.utils.translation import gettext as _ -from pydantic import AfterValidator, Json, StringConstraints +from pydantic import Json, StringConstraints from zerver.actions.user_topics import do_set_user_topic_visibility_policy from zerver.lib.response import json_success @@ -16,7 +16,7 @@ from zerver.lib.streams import ( check_for_exactly_one_stream_arg, ) from zerver.lib.typed_endpoint import typed_endpoint -from zerver.lib.typed_endpoint_validators import check_int_in +from zerver.lib.typed_endpoint_validators import check_int_in_validator from zerver.models import UserProfile, UserTopic from zerver.models.constants import MAX_TOPIC_NAME_LENGTH @@ -100,10 +100,7 @@ def update_user_topic( stream_id: Json[int], topic: Annotated[str, StringConstraints(max_length=MAX_TOPIC_NAME_LENGTH)], visibility_policy: Json[ - Annotated[ - int, - AfterValidator(lambda x: check_int_in(x, UserTopic.VisibilityPolicy.values)), - ] + Annotated[int, check_int_in_validator(UserTopic.VisibilityPolicy.values)] ], ) -> HttpResponse: if visibility_policy == UserTopic.VisibilityPolicy.INHERIT: diff --git a/zerver/views/users.py b/zerver/views/users.py index 419035baff..4e4838bcab 100644 --- a/zerver/views/users.py +++ b/zerver/views/users.py @@ -57,7 +57,7 @@ from zerver.lib.typed_endpoint import ( typed_endpoint, typed_endpoint_without_parameters, ) -from zerver.lib.typed_endpoint_validators import check_int_in, check_url +from zerver.lib.typed_endpoint_validators import check_int_in_validator, check_url from zerver.lib.types import ProfileDataElementUpdateDict from zerver.lib.upload import upload_avatar_image from zerver.lib.url_encoding import append_url_query_string @@ -98,11 +98,8 @@ from zproject.backends import check_password_strength RoleParamType: TypeAlias = Annotated[ int, - AfterValidator( - lambda x: check_int_in( - x, - UserProfile.ROLE_TYPES, - ) + check_int_in_validator( + UserProfile.ROLE_TYPES, ), ]