From 6815cded8354e3d3d2e4533bb5aa7eccf63dfcdc Mon Sep 17 00:00:00 2001 From: Kenneth Rodrigues Date: Sun, 14 Jul 2024 23:09:20 +0530 Subject: [PATCH] zerver: Migrate some files to typed_endpoint. Migrates `invite.py`, `registration.py` and `email_mirror.py` to use `typed_endpoint`. --- zerver/lib/typed_endpoint_validators.py | 36 +++++++++- zerver/lib/validator.py | 24 +------ zerver/tests/test_invite.py | 8 ++- .../tests/test_typed_endpoint_validators.py | 18 ++++- zerver/views/email_mirror.py | 9 +-- zerver/views/invite.py | 69 +++++++++---------- zerver/views/registration.py | 66 +++++++++--------- 7 files changed, 131 insertions(+), 99 deletions(-) diff --git a/zerver/lib/typed_endpoint_validators.py b/zerver/lib/typed_endpoint_validators.py index 1ca380f2fc..a73730bb71 100644 --- a/zerver/lib/typed_endpoint_validators.py +++ b/zerver/lib/typed_endpoint_validators.py @@ -1,11 +1,14 @@ +import zoneinfo from collections.abc import Collection from django.core.exceptions import ValidationError from django.core.validators import URLValidator from django.utils.translation import gettext as _ -from pydantic import AfterValidator +from pydantic import AfterValidator, BeforeValidator, NonNegativeInt from pydantic_core import PydanticCustomError +from zerver.lib.timezone import canonicalize_timezone + # The Pydantic.StringConstraints does not have validation for the string to be # of the specified length. So, we need to create a custom validator for that. @@ -49,3 +52,34 @@ def check_url(val: str) -> str: return val except ValidationError: raise ValueError(_("Not a URL")) + + +def to_timezone_or_empty(s: str) -> str: + try: + s = canonicalize_timezone(s) + zoneinfo.ZoneInfo(s) + except (ValueError, zoneinfo.ZoneInfoNotFoundError): + return "" + else: + return s + + +def timezone_or_empty_validator() -> AfterValidator: + return AfterValidator(lambda s: to_timezone_or_empty(s)) + + +def to_non_negative_int_or_none(s: str) -> NonNegativeInt | None: + try: + i = int(s) + if i < 0: + return None + return i + except ValueError: + return None + + +# We use BeforeValidator, not AfterValidator, here, because the int +# type conversion will raise a ValueError if the string is not a valid +# integer, and we want to return None in that case. +def non_negative_int_or_none_validator() -> BeforeValidator: + return BeforeValidator(lambda s: to_non_negative_int_or_none(s)) diff --git a/zerver/lib/validator.py b/zerver/lib/validator.py index 3139b63d8c..c1c2f47620 100644 --- a/zerver/lib/validator.py +++ b/zerver/lib/validator.py @@ -30,7 +30,7 @@ for any particular type of object. import re import zoneinfo -from collections.abc import Callable, Collection, Container, Iterator +from collections.abc import Collection, Container, Iterator from dataclasses import dataclass from datetime import datetime, timezone from typing import Any, NoReturn, TypeVar, cast, overload @@ -589,28 +589,6 @@ def to_float(var_name: str, s: str) -> float: return float(s) -def to_timezone_or_empty(var_name: str, s: str) -> str: - try: - s = canonicalize_timezone(s) - zoneinfo.ZoneInfo(s) - except (ValueError, zoneinfo.ZoneInfoNotFoundError): - return "" - else: - return s - - -def to_converted_or_fallback( - sub_converter: Callable[[str, str], ResultT], default: ResultT -) -> Callable[[str, str], ResultT]: - def converter(var_name: str, s: str) -> ResultT: - try: - return sub_converter(var_name, s) - except ValueError: - return default - - return converter - - def check_string_or_int_list(var_name: str, val: object) -> str | list[int]: if isinstance(val, str): return val diff --git a/zerver/tests/test_invite.py b/zerver/tests/test_invite.py index 5b2417ed97..24610f0863 100644 --- a/zerver/tests/test_invite.py +++ b/zerver/tests/test_invite.py @@ -731,7 +731,9 @@ class InviteUserTest(InviteUserBase): self.login("iago") invitee = self.nonreg_email("alice") response = self.invite(invitee, ["Denmark"], invite_as=10) - self.assert_json_error(response, "Invalid invite_as") + self.assert_json_error( + response, "Invalid invite_as: Value error, Not in the list of possible values" + ) def test_successful_invite_user_as_guest_from_normal_account(self) -> None: self.login("hamlet") @@ -2953,4 +2955,6 @@ class MultiuseInviteTest(ZulipTestCase): "invite_expires_in_minutes": 2 * 24 * 60, }, ) - self.assert_json_error(result, "Invalid invite_as") + self.assert_json_error( + result, "Invalid invite_as: Value error, Not in the list of possible values" + ) diff --git a/zerver/tests/test_typed_endpoint_validators.py b/zerver/tests/test_typed_endpoint_validators.py index f78cf84bfd..eb6ec37288 100644 --- a/zerver/tests/test_typed_endpoint_validators.py +++ b/zerver/tests/test_typed_endpoint_validators.py @@ -1,5 +1,10 @@ from zerver.lib.test_classes import ZulipTestCase -from zerver.lib.typed_endpoint_validators import check_int_in, check_string_in, check_url +from zerver.lib.typed_endpoint_validators import ( + check_int_in, + check_string_in, + check_url, + to_non_negative_int_or_none, +) class ValidatorTestCase(ZulipTestCase): @@ -17,3 +22,14 @@ class ValidatorTestCase(ZulipTestCase): check_url("https://example.com") with self.assertRaisesRegex(ValueError, "Not a URL"): check_url("https://127.0.0..:5000") + + def test_to_non_negative_int_or_none(self) -> None: + self.assertEqual(to_non_negative_int_or_none("3"), 3) + self.assertEqual(to_non_negative_int_or_none("-3"), None) + self.assertEqual(to_non_negative_int_or_none("a"), None) + self.assertEqual(to_non_negative_int_or_none("3.5"), None) + self.assertEqual(to_non_negative_int_or_none("3.0"), None) + self.assertEqual(to_non_negative_int_or_none("3.1"), None) + self.assertEqual(to_non_negative_int_or_none("3.9"), None) + self.assertEqual(to_non_negative_int_or_none("3.5"), None) + self.assertEqual(to_non_negative_int_or_none("foo"), None) diff --git a/zerver/views/email_mirror.py b/zerver/views/email_mirror.py index 732fcf4fcb..de84d2fe8b 100644 --- a/zerver/views/email_mirror.py +++ b/zerver/views/email_mirror.py @@ -3,16 +3,17 @@ from django.http import HttpRequest, HttpResponse from zerver.decorator import internal_api_view from zerver.lib.email_mirror import mirror_email_message from zerver.lib.exceptions import JsonableError -from zerver.lib.request import REQ, has_request_variables from zerver.lib.response import json_success +from zerver.lib.typed_endpoint import typed_endpoint @internal_api_view(False) -@has_request_variables +@typed_endpoint def email_mirror_message( request: HttpRequest, - rcpt_to: str = REQ(), - msg_base64: str = REQ(), + *, + rcpt_to: str, + msg_base64: str, ) -> HttpResponse: result = mirror_email_message(rcpt_to, msg_base64) if result["status"] == "error": diff --git a/zerver/views/invite.py b/zerver/views/invite.py index 8b65ee8faf..50628fe890 100644 --- a/zerver/views/invite.py +++ b/zerver/views/invite.py @@ -1,10 +1,11 @@ import re -from collections.abc import Sequence +from typing import Annotated from django.conf import settings from django.http import HttpRequest, HttpResponse from django.utils.timezone import now as timezone_now from django.utils.translation import gettext as _ +from pydantic import Json from confirmation import settings as confirmation_settings from zerver.actions.invites import ( @@ -17,10 +18,10 @@ from zerver.actions.invites import ( ) from zerver.decorator import require_member_or_admin from zerver.lib.exceptions import InvitationError, JsonableError, OrganizationOwnerRequiredError -from zerver.lib.request import REQ, has_request_variables from zerver.lib.response import json_success from zerver.lib.streams import access_stream_by_id -from zerver.lib.validator import check_bool, check_int, check_int_in, check_list, check_none_or +from zerver.lib.typed_endpoint import ApiParamConfig, PathOnly, typed_endpoint +from zerver.lib.typed_endpoint_validators import check_int_in_validator from zerver.models import MultiuseInvite, PreregistrationUser, Stream, UserProfile # Convert INVITATION_LINK_VALIDITY_DAYS into minutes. @@ -44,25 +45,20 @@ def check_role_based_permissions( @require_member_or_admin -@has_request_variables +@typed_endpoint def invite_users_backend( request: HttpRequest, user_profile: UserProfile, - invitee_emails_raw: str = REQ("invitee_emails"), - invite_expires_in_minutes: int | None = REQ( - json_validator=check_none_or(check_int), default=INVITATION_LINK_VALIDITY_MINUTES - ), - invite_as: int = REQ( - json_validator=check_int_in( - list(PreregistrationUser.INVITE_AS.values()), - ), - default=PreregistrationUser.INVITE_AS["MEMBER"], - ), - notify_referrer_on_join: bool = REQ( - "notify_referrer_on_join", json_validator=check_bool, default=True - ), - stream_ids: list[int] = REQ(json_validator=check_list(check_int)), - include_realm_default_subscriptions: bool = REQ(json_validator=check_bool, default=False), + *, + invitee_emails_raw: Annotated[str, ApiParamConfig("invitee_emails")], + invite_expires_in_minutes: Json[int | None] = INVITATION_LINK_VALIDITY_MINUTES, + invite_as: Annotated[ + Json[int], + check_int_in_validator(list(PreregistrationUser.INVITE_AS.values())), + ] = PreregistrationUser.INVITE_AS["MEMBER"], + notify_referrer_on_join: Json[bool] = True, + stream_ids: Json[list[int]], + include_realm_default_subscriptions: Json[bool] = False, ) -> HttpResponse: if not user_profile.can_invite_users_by_email(): # Guest users case will not be handled here as it will @@ -140,9 +136,9 @@ def get_user_invites(request: HttpRequest, user_profile: UserProfile) -> HttpRes @require_member_or_admin -@has_request_variables +@typed_endpoint def revoke_user_invite( - request: HttpRequest, user_profile: UserProfile, invite_id: int + request: HttpRequest, user_profile: UserProfile, *, invite_id: PathOnly[int] ) -> HttpResponse: try: prereg_user = PreregistrationUser.objects.get(id=invite_id) @@ -160,9 +156,9 @@ def revoke_user_invite( @require_member_or_admin -@has_request_variables +@typed_endpoint def revoke_multiuse_invite( - request: HttpRequest, user_profile: UserProfile, invite_id: int + request: HttpRequest, user_profile: UserProfile, *, invite_id: PathOnly[int] ) -> HttpResponse: try: invite = MultiuseInvite.objects.get(id=invite_id) @@ -183,9 +179,9 @@ def revoke_multiuse_invite( @require_member_or_admin -@has_request_variables +@typed_endpoint def resend_user_invite_email( - request: HttpRequest, user_profile: UserProfile, invite_id: int + request: HttpRequest, user_profile: UserProfile, *, invite_id: PathOnly[int] ) -> HttpResponse: try: prereg_user = PreregistrationUser.objects.get(id=invite_id) @@ -205,22 +201,21 @@ def resend_user_invite_email( @require_member_or_admin -@has_request_variables +@typed_endpoint def generate_multiuse_invite_backend( request: HttpRequest, user_profile: UserProfile, - invite_expires_in_minutes: int | None = REQ( - json_validator=check_none_or(check_int), default=INVITATION_LINK_VALIDITY_MINUTES - ), - invite_as: int = REQ( - json_validator=check_int_in( - list(PreregistrationUser.INVITE_AS.values()), - ), - default=PreregistrationUser.INVITE_AS["MEMBER"], - ), - stream_ids: Sequence[int] = REQ(json_validator=check_list(check_int), default=[]), - include_realm_default_subscriptions: bool = REQ(json_validator=check_bool, default=False), + *, + invite_expires_in_minutes: Json[int | None] = INVITATION_LINK_VALIDITY_MINUTES, + invite_as: Annotated[ + Json[int], + check_int_in_validator(list(PreregistrationUser.INVITE_AS.values())), + ] = PreregistrationUser.INVITE_AS["MEMBER"], + stream_ids: Json[list[int]] | None = None, + include_realm_default_subscriptions: Json[bool] = False, ) -> HttpResponse: + if stream_ids is None: + stream_ids = [] if not user_profile.can_create_multiuse_invite_to_realm(): # Guest users case will not be handled here as it will # be handled by the decorator above. diff --git a/zerver/views/registration.py b/zerver/views/registration.py index f478242566..25111b0999 100644 --- a/zerver/views/registration.py +++ b/zerver/views/registration.py @@ -1,7 +1,7 @@ import logging from collections.abc import Iterable from contextlib import suppress -from typing import Any +from typing import Annotated, Any from urllib.parse import urlencode, urljoin import orjson @@ -19,6 +19,7 @@ from django.urls import reverse from django.utils.translation import get_language from django.views.defaults import server_error from django_auth_ldap.backend import LDAPBackend, _LDAPUser +from pydantic import Json, NonNegativeInt, StringConstraints from confirmation.models import ( Confirmation, @@ -59,19 +60,22 @@ from zerver.lib.i18n import ( ) from zerver.lib.pysa import mark_sanitized from zerver.lib.rate_limiter import rate_limit_request_by_ip -from zerver.lib.request import REQ, has_request_variables from zerver.lib.send_email import EmailNotDeliveredError, FromAddress, send_email from zerver.lib.sessions import get_expirable_session_var from zerver.lib.subdomains import get_subdomain +from zerver.lib.typed_endpoint import ( + ApiParamConfig, + PathOnly, + typed_endpoint, + typed_endpoint_without_parameters, +) +from zerver.lib.typed_endpoint_validators import ( + check_int_in_validator, + non_negative_int_or_none_validator, + timezone_or_empty_validator, +) from zerver.lib.url_encoding import append_url_query_string from zerver.lib.users import get_accounts_for_email -from zerver.lib.validator import ( - check_capped_string, - check_int_in, - to_converted_or_fallback, - to_non_negative_int, - to_timezone_or_empty, -) from zerver.lib.zephyr import compute_mit_user_fullname from zerver.models import ( MultiuseInvite, @@ -119,9 +123,9 @@ if settings.BILLING_ENABLED: from corporate.lib.stripe import LicenseLimitError -@has_request_variables +@typed_endpoint def get_prereg_key_and_redirect( - request: HttpRequest, confirmation_key: str, full_name: str | None = REQ(default=None) + request: HttpRequest, *, confirmation_key: PathOnly[str], full_name: str | None = None ) -> HttpResponse: """ The purpose of this little endpoint is primarily to take a GET @@ -223,17 +227,16 @@ def accounts_register(*args: Any, **kwargs: Any) -> HttpResponse: return registration_helper(*args, **kwargs) -@has_request_variables +@typed_endpoint def registration_helper( request: HttpRequest, - key: str = REQ(default=""), - timezone: str = REQ(default="", converter=to_timezone_or_empty), - from_confirmation: str | None = REQ(default=None), - form_full_name: str | None = REQ("full_name", default=None), - source_realm_id: int | None = REQ( - default=None, converter=to_converted_or_fallback(to_non_negative_int, None) - ), - form_is_demo_organization: str | None = REQ("is_demo_organization", default=None), + *, + key: str = "", + timezone: Annotated[str, timezone_or_empty_validator()] = "", + from_confirmation: str | None = None, + form_full_name: Annotated[str | None, ApiParamConfig("full_name")] = None, + source_realm_id: Annotated[NonNegativeInt | None, non_negative_int_or_none_validator()] = None, + form_is_demo_organization: Annotated[str | None, ApiParamConfig("is_demo_organization")] = None, ) -> HttpResponse: try: prereg_object, realm_creation = check_prereg_key(request, key) @@ -958,8 +961,8 @@ def create_realm(request: HttpRequest, creation_key: str | None = None) -> HttpR ) -@has_request_variables -def signup_send_confirm(request: HttpRequest, email: str = REQ("email")) -> HttpResponse: +@typed_endpoint +def signup_send_confirm(request: HttpRequest, *, email: str) -> HttpResponse: try: # Because we interpolate the email directly into the template # from the query parameter, do a simple validation that it @@ -980,14 +983,15 @@ def signup_send_confirm(request: HttpRequest, email: str = REQ("email")) -> Http @add_google_analytics -@has_request_variables +@typed_endpoint def new_realm_send_confirm( request: HttpRequest, - email: str = REQ("email"), - realm_name: str = REQ(str_validator=check_capped_string(Realm.MAX_REALM_NAME_LENGTH)), - realm_type: int = REQ(json_validator=check_int_in(Realm.ORG_TYPE_IDS)), - realm_default_language: str = REQ(str_validator=check_capped_string(MAX_LANGUAGE_ID_LENGTH)), - realm_subdomain: str = REQ(str_validator=check_capped_string(Realm.MAX_REALM_SUBDOMAIN_LENGTH)), + *, + email: str, + realm_name: Annotated[str, StringConstraints(max_length=Realm.MAX_REALM_NAME_LENGTH)], + realm_type: Annotated[Json[int], check_int_in_validator(Realm.ORG_TYPE_IDS)], + realm_default_language: Annotated[str, StringConstraints(max_length=MAX_LANGUAGE_ID_LENGTH)], + realm_subdomain: Annotated[str, StringConstraints(max_length=Realm.MAX_REALM_SUBDOMAIN_LENGTH)], ) -> HttpResponse: return TemplateResponse( request, @@ -1113,7 +1117,7 @@ def accounts_home_from_multiuse_invite(request: HttpRequest, confirmation_key: s ) -@has_request_variables +@typed_endpoint_without_parameters def find_account(request: HttpRequest) -> HttpResponse: url = reverse("find_account") form = FindMyTeamForm() @@ -1217,8 +1221,8 @@ def find_account(request: HttpRequest) -> HttpResponse: ) -@has_request_variables -def realm_redirect(request: HttpRequest, next: str = REQ(default="")) -> HttpResponse: +@typed_endpoint +def realm_redirect(request: HttpRequest, *, next: str = "") -> HttpResponse: if request.method == "POST": form = RealmRedirectForm(request.POST) if form.is_valid():