zerver: Migrate some files to typed_endpoint.

Migrates `invite.py`, `registration.py` and
`email_mirror.py` to use `typed_endpoint`.
This commit is contained in:
Kenneth Rodrigues 2024-07-14 23:09:20 +05:30 committed by Tim Abbott
parent 16abd82fa5
commit 6815cded83
7 changed files with 131 additions and 99 deletions

View File

@ -1,11 +1,14 @@
import zoneinfo
from collections.abc import Collection
from django.core.exceptions import ValidationError
from django.core.validators import URLValidator
from django.utils.translation import gettext as _
from pydantic import AfterValidator
from pydantic import AfterValidator, BeforeValidator, NonNegativeInt
from pydantic_core import PydanticCustomError
from zerver.lib.timezone import canonicalize_timezone
# The Pydantic.StringConstraints does not have validation for the string to be
# of the specified length. So, we need to create a custom validator for that.
@ -49,3 +52,34 @@ def check_url(val: str) -> str:
return val
except ValidationError:
raise ValueError(_("Not a URL"))
def to_timezone_or_empty(s: str) -> str:
try:
s = canonicalize_timezone(s)
zoneinfo.ZoneInfo(s)
except (ValueError, zoneinfo.ZoneInfoNotFoundError):
return ""
else:
return s
def timezone_or_empty_validator() -> AfterValidator:
return AfterValidator(lambda s: to_timezone_or_empty(s))
def to_non_negative_int_or_none(s: str) -> NonNegativeInt | None:
try:
i = int(s)
if i < 0:
return None
return i
except ValueError:
return None
# We use BeforeValidator, not AfterValidator, here, because the int
# type conversion will raise a ValueError if the string is not a valid
# integer, and we want to return None in that case.
def non_negative_int_or_none_validator() -> BeforeValidator:
return BeforeValidator(lambda s: to_non_negative_int_or_none(s))

View File

@ -30,7 +30,7 @@ for any particular type of object.
import re
import zoneinfo
from collections.abc import Callable, Collection, Container, Iterator
from collections.abc import Collection, Container, Iterator
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, NoReturn, TypeVar, cast, overload
@ -589,28 +589,6 @@ def to_float(var_name: str, s: str) -> float:
return float(s)
def to_timezone_or_empty(var_name: str, s: str) -> str:
try:
s = canonicalize_timezone(s)
zoneinfo.ZoneInfo(s)
except (ValueError, zoneinfo.ZoneInfoNotFoundError):
return ""
else:
return s
def to_converted_or_fallback(
sub_converter: Callable[[str, str], ResultT], default: ResultT
) -> Callable[[str, str], ResultT]:
def converter(var_name: str, s: str) -> ResultT:
try:
return sub_converter(var_name, s)
except ValueError:
return default
return converter
def check_string_or_int_list(var_name: str, val: object) -> str | list[int]:
if isinstance(val, str):
return val

View File

@ -731,7 +731,9 @@ class InviteUserTest(InviteUserBase):
self.login("iago")
invitee = self.nonreg_email("alice")
response = self.invite(invitee, ["Denmark"], invite_as=10)
self.assert_json_error(response, "Invalid invite_as")
self.assert_json_error(
response, "Invalid invite_as: Value error, Not in the list of possible values"
)
def test_successful_invite_user_as_guest_from_normal_account(self) -> None:
self.login("hamlet")
@ -2953,4 +2955,6 @@ class MultiuseInviteTest(ZulipTestCase):
"invite_expires_in_minutes": 2 * 24 * 60,
},
)
self.assert_json_error(result, "Invalid invite_as")
self.assert_json_error(
result, "Invalid invite_as: Value error, Not in the list of possible values"
)

View File

@ -1,5 +1,10 @@
from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.typed_endpoint_validators import check_int_in, check_string_in, check_url
from zerver.lib.typed_endpoint_validators import (
check_int_in,
check_string_in,
check_url,
to_non_negative_int_or_none,
)
class ValidatorTestCase(ZulipTestCase):
@ -17,3 +22,14 @@ class ValidatorTestCase(ZulipTestCase):
check_url("https://example.com")
with self.assertRaisesRegex(ValueError, "Not a URL"):
check_url("https://127.0.0..:5000")
def test_to_non_negative_int_or_none(self) -> None:
self.assertEqual(to_non_negative_int_or_none("3"), 3)
self.assertEqual(to_non_negative_int_or_none("-3"), None)
self.assertEqual(to_non_negative_int_or_none("a"), None)
self.assertEqual(to_non_negative_int_or_none("3.5"), None)
self.assertEqual(to_non_negative_int_or_none("3.0"), None)
self.assertEqual(to_non_negative_int_or_none("3.1"), None)
self.assertEqual(to_non_negative_int_or_none("3.9"), None)
self.assertEqual(to_non_negative_int_or_none("3.5"), None)
self.assertEqual(to_non_negative_int_or_none("foo"), None)

View File

@ -3,16 +3,17 @@ from django.http import HttpRequest, HttpResponse
from zerver.decorator import internal_api_view
from zerver.lib.email_mirror import mirror_email_message
from zerver.lib.exceptions import JsonableError
from zerver.lib.request import REQ, has_request_variables
from zerver.lib.response import json_success
from zerver.lib.typed_endpoint import typed_endpoint
@internal_api_view(False)
@has_request_variables
@typed_endpoint
def email_mirror_message(
request: HttpRequest,
rcpt_to: str = REQ(),
msg_base64: str = REQ(),
*,
rcpt_to: str,
msg_base64: str,
) -> HttpResponse:
result = mirror_email_message(rcpt_to, msg_base64)
if result["status"] == "error":

View File

@ -1,10 +1,11 @@
import re
from collections.abc import Sequence
from typing import Annotated
from django.conf import settings
from django.http import HttpRequest, HttpResponse
from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _
from pydantic import Json
from confirmation import settings as confirmation_settings
from zerver.actions.invites import (
@ -17,10 +18,10 @@ from zerver.actions.invites import (
)
from zerver.decorator import require_member_or_admin
from zerver.lib.exceptions import InvitationError, JsonableError, OrganizationOwnerRequiredError
from zerver.lib.request import REQ, has_request_variables
from zerver.lib.response import json_success
from zerver.lib.streams import access_stream_by_id
from zerver.lib.validator import check_bool, check_int, check_int_in, check_list, check_none_or
from zerver.lib.typed_endpoint import ApiParamConfig, PathOnly, typed_endpoint
from zerver.lib.typed_endpoint_validators import check_int_in_validator
from zerver.models import MultiuseInvite, PreregistrationUser, Stream, UserProfile
# Convert INVITATION_LINK_VALIDITY_DAYS into minutes.
@ -44,25 +45,20 @@ def check_role_based_permissions(
@require_member_or_admin
@has_request_variables
@typed_endpoint
def invite_users_backend(
request: HttpRequest,
user_profile: UserProfile,
invitee_emails_raw: str = REQ("invitee_emails"),
invite_expires_in_minutes: int | None = REQ(
json_validator=check_none_or(check_int), default=INVITATION_LINK_VALIDITY_MINUTES
),
invite_as: int = REQ(
json_validator=check_int_in(
list(PreregistrationUser.INVITE_AS.values()),
),
default=PreregistrationUser.INVITE_AS["MEMBER"],
),
notify_referrer_on_join: bool = REQ(
"notify_referrer_on_join", json_validator=check_bool, default=True
),
stream_ids: list[int] = REQ(json_validator=check_list(check_int)),
include_realm_default_subscriptions: bool = REQ(json_validator=check_bool, default=False),
*,
invitee_emails_raw: Annotated[str, ApiParamConfig("invitee_emails")],
invite_expires_in_minutes: Json[int | None] = INVITATION_LINK_VALIDITY_MINUTES,
invite_as: Annotated[
Json[int],
check_int_in_validator(list(PreregistrationUser.INVITE_AS.values())),
] = PreregistrationUser.INVITE_AS["MEMBER"],
notify_referrer_on_join: Json[bool] = True,
stream_ids: Json[list[int]],
include_realm_default_subscriptions: Json[bool] = False,
) -> HttpResponse:
if not user_profile.can_invite_users_by_email():
# Guest users case will not be handled here as it will
@ -140,9 +136,9 @@ def get_user_invites(request: HttpRequest, user_profile: UserProfile) -> HttpRes
@require_member_or_admin
@has_request_variables
@typed_endpoint
def revoke_user_invite(
request: HttpRequest, user_profile: UserProfile, invite_id: int
request: HttpRequest, user_profile: UserProfile, *, invite_id: PathOnly[int]
) -> HttpResponse:
try:
prereg_user = PreregistrationUser.objects.get(id=invite_id)
@ -160,9 +156,9 @@ def revoke_user_invite(
@require_member_or_admin
@has_request_variables
@typed_endpoint
def revoke_multiuse_invite(
request: HttpRequest, user_profile: UserProfile, invite_id: int
request: HttpRequest, user_profile: UserProfile, *, invite_id: PathOnly[int]
) -> HttpResponse:
try:
invite = MultiuseInvite.objects.get(id=invite_id)
@ -183,9 +179,9 @@ def revoke_multiuse_invite(
@require_member_or_admin
@has_request_variables
@typed_endpoint
def resend_user_invite_email(
request: HttpRequest, user_profile: UserProfile, invite_id: int
request: HttpRequest, user_profile: UserProfile, *, invite_id: PathOnly[int]
) -> HttpResponse:
try:
prereg_user = PreregistrationUser.objects.get(id=invite_id)
@ -205,22 +201,21 @@ def resend_user_invite_email(
@require_member_or_admin
@has_request_variables
@typed_endpoint
def generate_multiuse_invite_backend(
request: HttpRequest,
user_profile: UserProfile,
invite_expires_in_minutes: int | None = REQ(
json_validator=check_none_or(check_int), default=INVITATION_LINK_VALIDITY_MINUTES
),
invite_as: int = REQ(
json_validator=check_int_in(
list(PreregistrationUser.INVITE_AS.values()),
),
default=PreregistrationUser.INVITE_AS["MEMBER"],
),
stream_ids: Sequence[int] = REQ(json_validator=check_list(check_int), default=[]),
include_realm_default_subscriptions: bool = REQ(json_validator=check_bool, default=False),
*,
invite_expires_in_minutes: Json[int | None] = INVITATION_LINK_VALIDITY_MINUTES,
invite_as: Annotated[
Json[int],
check_int_in_validator(list(PreregistrationUser.INVITE_AS.values())),
] = PreregistrationUser.INVITE_AS["MEMBER"],
stream_ids: Json[list[int]] | None = None,
include_realm_default_subscriptions: Json[bool] = False,
) -> HttpResponse:
if stream_ids is None:
stream_ids = []
if not user_profile.can_create_multiuse_invite_to_realm():
# Guest users case will not be handled here as it will
# be handled by the decorator above.

View File

@ -1,7 +1,7 @@
import logging
from collections.abc import Iterable
from contextlib import suppress
from typing import Any
from typing import Annotated, Any
from urllib.parse import urlencode, urljoin
import orjson
@ -19,6 +19,7 @@ from django.urls import reverse
from django.utils.translation import get_language
from django.views.defaults import server_error
from django_auth_ldap.backend import LDAPBackend, _LDAPUser
from pydantic import Json, NonNegativeInt, StringConstraints
from confirmation.models import (
Confirmation,
@ -59,19 +60,22 @@ from zerver.lib.i18n import (
)
from zerver.lib.pysa import mark_sanitized
from zerver.lib.rate_limiter import rate_limit_request_by_ip
from zerver.lib.request import REQ, has_request_variables
from zerver.lib.send_email import EmailNotDeliveredError, FromAddress, send_email
from zerver.lib.sessions import get_expirable_session_var
from zerver.lib.subdomains import get_subdomain
from zerver.lib.typed_endpoint import (
ApiParamConfig,
PathOnly,
typed_endpoint,
typed_endpoint_without_parameters,
)
from zerver.lib.typed_endpoint_validators import (
check_int_in_validator,
non_negative_int_or_none_validator,
timezone_or_empty_validator,
)
from zerver.lib.url_encoding import append_url_query_string
from zerver.lib.users import get_accounts_for_email
from zerver.lib.validator import (
check_capped_string,
check_int_in,
to_converted_or_fallback,
to_non_negative_int,
to_timezone_or_empty,
)
from zerver.lib.zephyr import compute_mit_user_fullname
from zerver.models import (
MultiuseInvite,
@ -119,9 +123,9 @@ if settings.BILLING_ENABLED:
from corporate.lib.stripe import LicenseLimitError
@has_request_variables
@typed_endpoint
def get_prereg_key_and_redirect(
request: HttpRequest, confirmation_key: str, full_name: str | None = REQ(default=None)
request: HttpRequest, *, confirmation_key: PathOnly[str], full_name: str | None = None
) -> HttpResponse:
"""
The purpose of this little endpoint is primarily to take a GET
@ -223,17 +227,16 @@ def accounts_register(*args: Any, **kwargs: Any) -> HttpResponse:
return registration_helper(*args, **kwargs)
@has_request_variables
@typed_endpoint
def registration_helper(
request: HttpRequest,
key: str = REQ(default=""),
timezone: str = REQ(default="", converter=to_timezone_or_empty),
from_confirmation: str | None = REQ(default=None),
form_full_name: str | None = REQ("full_name", default=None),
source_realm_id: int | None = REQ(
default=None, converter=to_converted_or_fallback(to_non_negative_int, None)
),
form_is_demo_organization: str | None = REQ("is_demo_organization", default=None),
*,
key: str = "",
timezone: Annotated[str, timezone_or_empty_validator()] = "",
from_confirmation: str | None = None,
form_full_name: Annotated[str | None, ApiParamConfig("full_name")] = None,
source_realm_id: Annotated[NonNegativeInt | None, non_negative_int_or_none_validator()] = None,
form_is_demo_organization: Annotated[str | None, ApiParamConfig("is_demo_organization")] = None,
) -> HttpResponse:
try:
prereg_object, realm_creation = check_prereg_key(request, key)
@ -958,8 +961,8 @@ def create_realm(request: HttpRequest, creation_key: str | None = None) -> HttpR
)
@has_request_variables
def signup_send_confirm(request: HttpRequest, email: str = REQ("email")) -> HttpResponse:
@typed_endpoint
def signup_send_confirm(request: HttpRequest, *, email: str) -> HttpResponse:
try:
# Because we interpolate the email directly into the template
# from the query parameter, do a simple validation that it
@ -980,14 +983,15 @@ def signup_send_confirm(request: HttpRequest, email: str = REQ("email")) -> Http
@add_google_analytics
@has_request_variables
@typed_endpoint
def new_realm_send_confirm(
request: HttpRequest,
email: str = REQ("email"),
realm_name: str = REQ(str_validator=check_capped_string(Realm.MAX_REALM_NAME_LENGTH)),
realm_type: int = REQ(json_validator=check_int_in(Realm.ORG_TYPE_IDS)),
realm_default_language: str = REQ(str_validator=check_capped_string(MAX_LANGUAGE_ID_LENGTH)),
realm_subdomain: str = REQ(str_validator=check_capped_string(Realm.MAX_REALM_SUBDOMAIN_LENGTH)),
*,
email: str,
realm_name: Annotated[str, StringConstraints(max_length=Realm.MAX_REALM_NAME_LENGTH)],
realm_type: Annotated[Json[int], check_int_in_validator(Realm.ORG_TYPE_IDS)],
realm_default_language: Annotated[str, StringConstraints(max_length=MAX_LANGUAGE_ID_LENGTH)],
realm_subdomain: Annotated[str, StringConstraints(max_length=Realm.MAX_REALM_SUBDOMAIN_LENGTH)],
) -> HttpResponse:
return TemplateResponse(
request,
@ -1113,7 +1117,7 @@ def accounts_home_from_multiuse_invite(request: HttpRequest, confirmation_key: s
)
@has_request_variables
@typed_endpoint_without_parameters
def find_account(request: HttpRequest) -> HttpResponse:
url = reverse("find_account")
form = FindMyTeamForm()
@ -1217,8 +1221,8 @@ def find_account(request: HttpRequest) -> HttpResponse:
)
@has_request_variables
def realm_redirect(request: HttpRequest, next: str = REQ(default="")) -> HttpResponse:
@typed_endpoint
def realm_redirect(request: HttpRequest, *, next: str = "") -> HttpResponse:
if request.method == "POST":
form = RealmRedirectForm(request.POST)
if form.is_valid():