diff --git a/zerver/lib/request.py b/zerver/lib/request.py index 6d9ea9d48b..c37a377884 100644 --- a/zerver/lib/request.py +++ b/zerver/lib/request.py @@ -324,7 +324,7 @@ def has_request_variables( ) -> Callable[Concatenate[HttpRequest, ParamT], ReturnT]: num_params = req_func.__code__.co_argcount default_param_values = cast(FunctionType, req_func).__defaults__ - if default_param_values is None: + if default_param_values is None: # nocoverage # No users of this path. default_param_values = () num_default_params = len(default_param_values) default_param_names = req_func.__code__.co_varnames[num_params - num_default_params :] @@ -392,7 +392,7 @@ def has_request_variables( if req_var in request.POST: val = request.POST[req_var] request_notes.processed_parameters.add(req_var) - elif req_var in request.GET: + elif req_var in request.GET: # nocoverage # No users of this path val = request.GET[req_var] request_notes.processed_parameters.add(req_var) else: diff --git a/zerver/tests/test_openapi.py b/zerver/tests/test_openapi.py index bc5ddc4acb..b25da69513 100644 --- a/zerver/tests/test_openapi.py +++ b/zerver/tests/test_openapi.py @@ -1,7 +1,6 @@ -import inspect import os from collections.abc import Callable, Mapping -from typing import Any, get_origin +from typing import Any from unittest.mock import MagicMock, patch import yaml @@ -10,7 +9,7 @@ from django.urls import URLPattern from django.utils import regex_helper from pydantic import TypeAdapter -from zerver.lib.request import _REQ, arguments_map +from zerver.lib.request import arguments_map from zerver.lib.rest import rest_dispatch from zerver.lib.test_classes import ZulipTestCase from zerver.lib.typed_endpoint import parse_view_func_signature @@ -307,21 +306,6 @@ so maybe we shouldn't mark it as intentionally undocumented in the URLs. msg += f"\n + {undocumented_path}" raise AssertionError(msg) - def get_standardized_argument_type(self, t: Any) -> type | tuple[type, object]: - """Given a type from the typing module such as List[str] or Union[str, int], - convert it into a corresponding Python type. Unions are mapped to a canonical - choice among the options. - E.g. typing.Union[typing.List[typing.Dict[str, typing.Any]], NoneType] - needs to be mapped to list.""" - - origin = get_origin(t) - - if origin is None: - # Then it's most likely one of the fundamental data types - # I.E. Not one of the data types from the "typing" module. - return t - raise AssertionError(f"Unknown origin {origin}") - def render_openapi_type_exception( self, function: Callable[..., HttpResponse], @@ -451,8 +435,8 @@ do not match the types declared in the implementation of {function.__name__}.\n" OpenAPI data defines a different type than that actually accepted by the function. Otherwise, we print out the exact differences for convenient debugging and raise an AssertionError.""" - # Iterate through the decorators to find the original function, wrapped - # by has_request_variables/typed_endpoint, so we can parse its + # Iterate through the decorators to find the original + # function, wrapped by typed_endpoint, so we can parse its # arguments. use_endpoint_decorator = False while (wrapped := getattr(function, "__wrapped__", None)) is not None: @@ -462,42 +446,10 @@ do not match the types declared in the implementation of {function.__name__}.\n" use_endpoint_decorator = True function = wrapped - if use_endpoint_decorator: + if len(openapi_parameters) > 0: + assert use_endpoint_decorator return self.validate_json_schema(function, openapi_parameters) - openapi_params: set[tuple[str, type | tuple[type, object]]] = set() - json_params: dict[str, type | tuple[type, object]] = {} - for openapi_parameter in openapi_parameters: - name = openapi_parameter.name - # This no longer happens in remaining has_request_variables endpoint. - assert not openapi_parameter.json_encoded - openapi_params.add((name, schema_type(openapi_parameter.value_schema))) - - function_params: set[tuple[str, type | tuple[type, object]]] = set() - - for pname, defval in inspect.signature(function).parameters.items(): - defval = defval.default - if isinstance(defval, _REQ): - # TODO: The below inference logic in cases where - # there's a converter function declared is incorrect. - # Theoretically, we could restructure the converter - # function model so that we can check what type it - # excepts to be passed to make validation here - # possible. - - vtype = self.get_standardized_argument_type(function.__annotations__[pname]) - vname = defval.post_var_name - assert vname is not None - # This no longer happens following typed_endpoint migrations. - assert vname not in json_params - function_params.add((vname, vtype)) - - # After the above operations `json_params` should be empty. - assert len(json_params) == 0 - diff = openapi_params - function_params - if diff: # nocoverage - self.render_openapi_type_exception(function, openapi_params, function_params, diff) - def check_openapi_arguments_for_view( self, pattern: URLPattern, diff --git a/zerver/views/auth.py b/zerver/views/auth.py index 24169a00a8..f46f617fc6 100644 --- a/zerver/views/auth.py +++ b/zerver/views/auth.py @@ -23,6 +23,7 @@ from django.utils.http import url_has_allowed_host_and_scheme from django.utils.translation import gettext as _ from django.views.decorators.csrf import csrf_exempt from django.views.decorators.http import require_safe +from pydantic import Json from social_django.utils import load_backend, load_strategy from two_factor.forms import BackupTokenForm from two_factor.views import LoginView as BaseTwoFactorLoginView @@ -59,15 +60,16 @@ from zerver.lib.mobile_auth_otp import otp_encrypt_api_key from zerver.lib.push_notifications import push_notifications_configured from zerver.lib.pysa import mark_sanitized from zerver.lib.realm_icon import realm_icon_url -from zerver.lib.request import REQ, RequestNotes, has_request_variables +from zerver.lib.request import RequestNotes from zerver.lib.response import json_success from zerver.lib.sessions import set_expirable_session_var from zerver.lib.subdomains import get_subdomain, is_subdomain_root_or_alias +from zerver.lib.typed_endpoint import typed_endpoint from zerver.lib.url_encoding import append_url_query_string from zerver.lib.user_agent import parse_user_agent from zerver.lib.users import get_api_key, get_users_for_api, is_2fa_verified from zerver.lib.utils import has_api_key_format -from zerver.lib.validator import check_bool, validate_login_email +from zerver.lib.validator import validate_login_email from zerver.models import ( MultiuseInvite, PreregistrationRealm, @@ -508,12 +510,13 @@ def create_response_for_otp_flow( @log_view_func -@has_request_variables +@typed_endpoint def remote_user_sso( request: HttpRequest, - mobile_flow_otp: str | None = REQ(default=None), - desktop_flow_otp: str | None = REQ(default=None), - next: str = REQ(default="/"), + *, + mobile_flow_otp: str | None = None, + desktop_flow_otp: str | None = None, + next: str = "/", ) -> HttpResponse: subdomain = get_subdomain(request) try: @@ -562,9 +565,8 @@ def remote_user_sso( return login_or_register_remote_user(request, result) -@has_request_variables def get_email_and_realm_from_jwt_authentication_request( - request: HttpRequest, json_web_token: str + request: HttpRequest, *, json_web_token: str ) -> tuple[str, Realm]: realm = get_realm_from_request(request) if realm is None: @@ -595,9 +597,11 @@ def get_email_and_realm_from_jwt_authentication_request( @csrf_exempt @require_post @log_view_func -@has_request_variables -def remote_user_jwt(request: HttpRequest, token: str = REQ(default="")) -> HttpResponse: - email, realm = get_email_and_realm_from_jwt_authentication_request(request, token) +@typed_endpoint +def remote_user_jwt(request: HttpRequest, *, token: str = "") -> HttpResponse: + email, realm = get_email_and_realm_from_jwt_authentication_request( + request, json_web_token=token + ) user_profile = authenticate(username=email, realm=realm, use_dummy_backend=True) if user_profile is None: @@ -611,17 +615,22 @@ def remote_user_jwt(request: HttpRequest, token: str = REQ(default="")) -> HttpR return login_or_register_remote_user(request, result) -@has_request_variables +@typed_endpoint def oauth_redirect_to_root( request: HttpRequest, url: str, sso_type: str, - is_signup: bool = False, - extra_url_params: Mapping[str, str] = {}, - next: str | None = REQ(default=None), - multiuse_object_key: str = REQ(default=""), - mobile_flow_otp: str | None = REQ(default=None), - desktop_flow_otp: str | None = REQ(default=None), + is_signup: bool, + extra_url_params: Mapping[str, str], + # Protect the above parameters from being processed as kwargs + # provided by @typed_endpoint by marking them as mandatory + # positional parameters. + /, + *, + next: str | None = None, + multiuse_object_key: str = "", + mobile_flow_otp: str | None = None, + desktop_flow_otp: str | None = None, ) -> HttpResponse: main_site_url = settings.ROOT_DOMAIN_URI + url if settings.SOCIAL_AUTH_SUBDOMAIN is not None and sso_type == "social": @@ -716,7 +725,13 @@ def start_social_login( if not (getattr(settings, key_setting) and getattr(settings, secret_setting)): return config_error(request, backend) - return oauth_redirect_to_root(request, backend_url, "social", extra_url_params=extra_url_params) + return oauth_redirect_to_root( + request, + backend_url, + "social", + False, + extra_url_params, + ) @handle_desktop_flow @@ -738,7 +753,11 @@ def start_social_signup( return config_error(request, "saml") extra_url_params = {"idp": extra_arg} return oauth_redirect_to_root( - request, backend_url, "social", is_signup=True, extra_url_params=extra_url_params + request, + backend_url, + "social", + True, + extra_url_params, ) @@ -862,11 +881,12 @@ class TwoFactorLoginView(BaseTwoFactorLoginView): return super().done(form_list, **kwargs) -@has_request_variables +@typed_endpoint def login_page( request: HttpRequest, /, - next: str = REQ(default="/"), + *, + next: str = "/", **kwargs: Any, ) -> HttpResponse: if get_subdomain(request) == settings.SOCIAL_AUTH_SUBDOMAIN: @@ -1021,13 +1041,16 @@ def get_api_key_fetch_authenticate_failure(return_data: dict[str, bool]) -> Json @csrf_exempt @require_post -@has_request_variables +@typed_endpoint def jwt_fetch_api_key( request: HttpRequest, - include_profile: bool = REQ(default=False, json_validator=check_bool), - token: str = REQ(default=""), + *, + include_profile: Json[bool] = False, + token: str = "", ) -> HttpResponse: - remote_email, realm = get_email_and_realm_from_jwt_authentication_request(request, token) + remote_email, realm = get_email_and_realm_from_jwt_authentication_request( + request, json_web_token=token + ) return_data: dict[str, bool] = {} @@ -1062,10 +1085,8 @@ def jwt_fetch_api_key( @csrf_exempt @require_post -@has_request_variables -def api_fetch_api_key( - request: HttpRequest, username: str = REQ(), password: str = REQ() -) -> HttpResponse: +@typed_endpoint +def api_fetch_api_key(request: HttpRequest, *, username: str, password: str) -> HttpResponse: return_data: dict[str, bool] = {} realm = get_realm_from_request(request) @@ -1160,9 +1181,9 @@ def api_get_server_settings(request: HttpRequest) -> HttpResponse: return json_success(request, data=result) -@has_request_variables +@typed_endpoint def json_fetch_api_key( - request: HttpRequest, user_profile: UserProfile, password: str = REQ(default="") + request: HttpRequest, user_profile: UserProfile, *, password: str = "" ) -> HttpResponse: realm = get_realm_from_request(request) if realm is None: