auth: Migrate to @typed_endpoint.

Since this is the last has_request_variables endpoint outside tests,
more test_openapi code needs to be deleted in this transition.
This commit is contained in:
bedo 2024-07-17 01:56:17 +03:00 committed by Tim Abbott
parent a8ecba8ab8
commit 3da91e951c
3 changed files with 61 additions and 88 deletions

View File

@ -324,7 +324,7 @@ def has_request_variables(
) -> Callable[Concatenate[HttpRequest, ParamT], ReturnT]: ) -> Callable[Concatenate[HttpRequest, ParamT], ReturnT]:
num_params = req_func.__code__.co_argcount num_params = req_func.__code__.co_argcount
default_param_values = cast(FunctionType, req_func).__defaults__ default_param_values = cast(FunctionType, req_func).__defaults__
if default_param_values is None: if default_param_values is None: # nocoverage # No users of this path.
default_param_values = () default_param_values = ()
num_default_params = len(default_param_values) num_default_params = len(default_param_values)
default_param_names = req_func.__code__.co_varnames[num_params - num_default_params :] default_param_names = req_func.__code__.co_varnames[num_params - num_default_params :]
@ -392,7 +392,7 @@ def has_request_variables(
if req_var in request.POST: if req_var in request.POST:
val = request.POST[req_var] val = request.POST[req_var]
request_notes.processed_parameters.add(req_var) request_notes.processed_parameters.add(req_var)
elif req_var in request.GET: elif req_var in request.GET: # nocoverage # No users of this path
val = request.GET[req_var] val = request.GET[req_var]
request_notes.processed_parameters.add(req_var) request_notes.processed_parameters.add(req_var)
else: else:

View File

@ -1,7 +1,6 @@
import inspect
import os import os
from collections.abc import Callable, Mapping from collections.abc import Callable, Mapping
from typing import Any, get_origin from typing import Any
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import yaml import yaml
@ -10,7 +9,7 @@ from django.urls import URLPattern
from django.utils import regex_helper from django.utils import regex_helper
from pydantic import TypeAdapter from pydantic import TypeAdapter
from zerver.lib.request import _REQ, arguments_map from zerver.lib.request import arguments_map
from zerver.lib.rest import rest_dispatch from zerver.lib.rest import rest_dispatch
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.typed_endpoint import parse_view_func_signature from zerver.lib.typed_endpoint import parse_view_func_signature
@ -307,21 +306,6 @@ so maybe we shouldn't mark it as intentionally undocumented in the URLs.
msg += f"\n + {undocumented_path}" msg += f"\n + {undocumented_path}"
raise AssertionError(msg) raise AssertionError(msg)
def get_standardized_argument_type(self, t: Any) -> type | tuple[type, object]:
"""Given a type from the typing module such as List[str] or Union[str, int],
convert it into a corresponding Python type. Unions are mapped to a canonical
choice among the options.
E.g. typing.Union[typing.List[typing.Dict[str, typing.Any]], NoneType]
needs to be mapped to list."""
origin = get_origin(t)
if origin is None:
# Then it's most likely one of the fundamental data types
# I.E. Not one of the data types from the "typing" module.
return t
raise AssertionError(f"Unknown origin {origin}")
def render_openapi_type_exception( def render_openapi_type_exception(
self, self,
function: Callable[..., HttpResponse], function: Callable[..., HttpResponse],
@ -451,8 +435,8 @@ do not match the types declared in the implementation of {function.__name__}.\n"
OpenAPI data defines a different type than that actually accepted by the function. OpenAPI data defines a different type than that actually accepted by the function.
Otherwise, we print out the exact differences for convenient debugging and raise an Otherwise, we print out the exact differences for convenient debugging and raise an
AssertionError.""" AssertionError."""
# Iterate through the decorators to find the original function, wrapped # Iterate through the decorators to find the original
# by has_request_variables/typed_endpoint, so we can parse its # function, wrapped by typed_endpoint, so we can parse its
# arguments. # arguments.
use_endpoint_decorator = False use_endpoint_decorator = False
while (wrapped := getattr(function, "__wrapped__", None)) is not None: while (wrapped := getattr(function, "__wrapped__", None)) is not None:
@ -462,42 +446,10 @@ do not match the types declared in the implementation of {function.__name__}.\n"
use_endpoint_decorator = True use_endpoint_decorator = True
function = wrapped function = wrapped
if use_endpoint_decorator: if len(openapi_parameters) > 0:
assert use_endpoint_decorator
return self.validate_json_schema(function, openapi_parameters) return self.validate_json_schema(function, openapi_parameters)
openapi_params: set[tuple[str, type | tuple[type, object]]] = set()
json_params: dict[str, type | tuple[type, object]] = {}
for openapi_parameter in openapi_parameters:
name = openapi_parameter.name
# This no longer happens in remaining has_request_variables endpoint.
assert not openapi_parameter.json_encoded
openapi_params.add((name, schema_type(openapi_parameter.value_schema)))
function_params: set[tuple[str, type | tuple[type, object]]] = set()
for pname, defval in inspect.signature(function).parameters.items():
defval = defval.default
if isinstance(defval, _REQ):
# TODO: The below inference logic in cases where
# there's a converter function declared is incorrect.
# Theoretically, we could restructure the converter
# function model so that we can check what type it
# excepts to be passed to make validation here
# possible.
vtype = self.get_standardized_argument_type(function.__annotations__[pname])
vname = defval.post_var_name
assert vname is not None
# This no longer happens following typed_endpoint migrations.
assert vname not in json_params
function_params.add((vname, vtype))
# After the above operations `json_params` should be empty.
assert len(json_params) == 0
diff = openapi_params - function_params
if diff: # nocoverage
self.render_openapi_type_exception(function, openapi_params, function_params, diff)
def check_openapi_arguments_for_view( def check_openapi_arguments_for_view(
self, self,
pattern: URLPattern, pattern: URLPattern,

View File

@ -23,6 +23,7 @@ from django.utils.http import url_has_allowed_host_and_scheme
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
from django.views.decorators.http import require_safe from django.views.decorators.http import require_safe
from pydantic import Json
from social_django.utils import load_backend, load_strategy from social_django.utils import load_backend, load_strategy
from two_factor.forms import BackupTokenForm from two_factor.forms import BackupTokenForm
from two_factor.views import LoginView as BaseTwoFactorLoginView from two_factor.views import LoginView as BaseTwoFactorLoginView
@ -59,15 +60,16 @@ from zerver.lib.mobile_auth_otp import otp_encrypt_api_key
from zerver.lib.push_notifications import push_notifications_configured from zerver.lib.push_notifications import push_notifications_configured
from zerver.lib.pysa import mark_sanitized from zerver.lib.pysa import mark_sanitized
from zerver.lib.realm_icon import realm_icon_url from zerver.lib.realm_icon import realm_icon_url
from zerver.lib.request import REQ, RequestNotes, has_request_variables from zerver.lib.request import RequestNotes
from zerver.lib.response import json_success from zerver.lib.response import json_success
from zerver.lib.sessions import set_expirable_session_var from zerver.lib.sessions import set_expirable_session_var
from zerver.lib.subdomains import get_subdomain, is_subdomain_root_or_alias from zerver.lib.subdomains import get_subdomain, is_subdomain_root_or_alias
from zerver.lib.typed_endpoint import typed_endpoint
from zerver.lib.url_encoding import append_url_query_string from zerver.lib.url_encoding import append_url_query_string
from zerver.lib.user_agent import parse_user_agent from zerver.lib.user_agent import parse_user_agent
from zerver.lib.users import get_api_key, get_users_for_api, is_2fa_verified from zerver.lib.users import get_api_key, get_users_for_api, is_2fa_verified
from zerver.lib.utils import has_api_key_format from zerver.lib.utils import has_api_key_format
from zerver.lib.validator import check_bool, validate_login_email from zerver.lib.validator import validate_login_email
from zerver.models import ( from zerver.models import (
MultiuseInvite, MultiuseInvite,
PreregistrationRealm, PreregistrationRealm,
@ -508,12 +510,13 @@ def create_response_for_otp_flow(
@log_view_func @log_view_func
@has_request_variables @typed_endpoint
def remote_user_sso( def remote_user_sso(
request: HttpRequest, request: HttpRequest,
mobile_flow_otp: str | None = REQ(default=None), *,
desktop_flow_otp: str | None = REQ(default=None), mobile_flow_otp: str | None = None,
next: str = REQ(default="/"), desktop_flow_otp: str | None = None,
next: str = "/",
) -> HttpResponse: ) -> HttpResponse:
subdomain = get_subdomain(request) subdomain = get_subdomain(request)
try: try:
@ -562,9 +565,8 @@ def remote_user_sso(
return login_or_register_remote_user(request, result) return login_or_register_remote_user(request, result)
@has_request_variables
def get_email_and_realm_from_jwt_authentication_request( def get_email_and_realm_from_jwt_authentication_request(
request: HttpRequest, json_web_token: str request: HttpRequest, *, json_web_token: str
) -> tuple[str, Realm]: ) -> tuple[str, Realm]:
realm = get_realm_from_request(request) realm = get_realm_from_request(request)
if realm is None: if realm is None:
@ -595,9 +597,11 @@ def get_email_and_realm_from_jwt_authentication_request(
@csrf_exempt @csrf_exempt
@require_post @require_post
@log_view_func @log_view_func
@has_request_variables @typed_endpoint
def remote_user_jwt(request: HttpRequest, token: str = REQ(default="")) -> HttpResponse: def remote_user_jwt(request: HttpRequest, *, token: str = "") -> HttpResponse:
email, realm = get_email_and_realm_from_jwt_authentication_request(request, token) email, realm = get_email_and_realm_from_jwt_authentication_request(
request, json_web_token=token
)
user_profile = authenticate(username=email, realm=realm, use_dummy_backend=True) user_profile = authenticate(username=email, realm=realm, use_dummy_backend=True)
if user_profile is None: if user_profile is None:
@ -611,17 +615,22 @@ def remote_user_jwt(request: HttpRequest, token: str = REQ(default="")) -> HttpR
return login_or_register_remote_user(request, result) return login_or_register_remote_user(request, result)
@has_request_variables @typed_endpoint
def oauth_redirect_to_root( def oauth_redirect_to_root(
request: HttpRequest, request: HttpRequest,
url: str, url: str,
sso_type: str, sso_type: str,
is_signup: bool = False, is_signup: bool,
extra_url_params: Mapping[str, str] = {}, extra_url_params: Mapping[str, str],
next: str | None = REQ(default=None), # Protect the above parameters from being processed as kwargs
multiuse_object_key: str = REQ(default=""), # provided by @typed_endpoint by marking them as mandatory
mobile_flow_otp: str | None = REQ(default=None), # positional parameters.
desktop_flow_otp: str | None = REQ(default=None), /,
*,
next: str | None = None,
multiuse_object_key: str = "",
mobile_flow_otp: str | None = None,
desktop_flow_otp: str | None = None,
) -> HttpResponse: ) -> HttpResponse:
main_site_url = settings.ROOT_DOMAIN_URI + url main_site_url = settings.ROOT_DOMAIN_URI + url
if settings.SOCIAL_AUTH_SUBDOMAIN is not None and sso_type == "social": if settings.SOCIAL_AUTH_SUBDOMAIN is not None and sso_type == "social":
@ -716,7 +725,13 @@ def start_social_login(
if not (getattr(settings, key_setting) and getattr(settings, secret_setting)): if not (getattr(settings, key_setting) and getattr(settings, secret_setting)):
return config_error(request, backend) return config_error(request, backend)
return oauth_redirect_to_root(request, backend_url, "social", extra_url_params=extra_url_params) return oauth_redirect_to_root(
request,
backend_url,
"social",
False,
extra_url_params,
)
@handle_desktop_flow @handle_desktop_flow
@ -738,7 +753,11 @@ def start_social_signup(
return config_error(request, "saml") return config_error(request, "saml")
extra_url_params = {"idp": extra_arg} extra_url_params = {"idp": extra_arg}
return oauth_redirect_to_root( return oauth_redirect_to_root(
request, backend_url, "social", is_signup=True, extra_url_params=extra_url_params request,
backend_url,
"social",
True,
extra_url_params,
) )
@ -862,11 +881,12 @@ class TwoFactorLoginView(BaseTwoFactorLoginView):
return super().done(form_list, **kwargs) return super().done(form_list, **kwargs)
@has_request_variables @typed_endpoint
def login_page( def login_page(
request: HttpRequest, request: HttpRequest,
/, /,
next: str = REQ(default="/"), *,
next: str = "/",
**kwargs: Any, **kwargs: Any,
) -> HttpResponse: ) -> HttpResponse:
if get_subdomain(request) == settings.SOCIAL_AUTH_SUBDOMAIN: if get_subdomain(request) == settings.SOCIAL_AUTH_SUBDOMAIN:
@ -1021,13 +1041,16 @@ def get_api_key_fetch_authenticate_failure(return_data: dict[str, bool]) -> Json
@csrf_exempt @csrf_exempt
@require_post @require_post
@has_request_variables @typed_endpoint
def jwt_fetch_api_key( def jwt_fetch_api_key(
request: HttpRequest, request: HttpRequest,
include_profile: bool = REQ(default=False, json_validator=check_bool), *,
token: str = REQ(default=""), include_profile: Json[bool] = False,
token: str = "",
) -> HttpResponse: ) -> HttpResponse:
remote_email, realm = get_email_and_realm_from_jwt_authentication_request(request, token) remote_email, realm = get_email_and_realm_from_jwt_authentication_request(
request, json_web_token=token
)
return_data: dict[str, bool] = {} return_data: dict[str, bool] = {}
@ -1062,10 +1085,8 @@ def jwt_fetch_api_key(
@csrf_exempt @csrf_exempt
@require_post @require_post
@has_request_variables @typed_endpoint
def api_fetch_api_key( def api_fetch_api_key(request: HttpRequest, *, username: str, password: str) -> HttpResponse:
request: HttpRequest, username: str = REQ(), password: str = REQ()
) -> HttpResponse:
return_data: dict[str, bool] = {} return_data: dict[str, bool] = {}
realm = get_realm_from_request(request) realm = get_realm_from_request(request)
@ -1160,9 +1181,9 @@ def api_get_server_settings(request: HttpRequest) -> HttpResponse:
return json_success(request, data=result) return json_success(request, data=result)
@has_request_variables @typed_endpoint
def json_fetch_api_key( def json_fetch_api_key(
request: HttpRequest, user_profile: UserProfile, password: str = REQ(default="") request: HttpRequest, user_profile: UserProfile, *, password: str = ""
) -> HttpResponse: ) -> HttpResponse:
realm = get_realm_from_request(request) realm = get_realm_from_request(request)
if realm is None: if realm is None: