mirror of https://github.com/zulip/zulip.git
auth: Migrate to @typed_endpoint.
Since this is the last has_request_variables endpoint outside tests, more test_openapi code needs to be deleted in this transition.
This commit is contained in:
parent
a8ecba8ab8
commit
3da91e951c
|
@ -324,7 +324,7 @@ def has_request_variables(
|
||||||
) -> Callable[Concatenate[HttpRequest, ParamT], ReturnT]:
|
) -> Callable[Concatenate[HttpRequest, ParamT], ReturnT]:
|
||||||
num_params = req_func.__code__.co_argcount
|
num_params = req_func.__code__.co_argcount
|
||||||
default_param_values = cast(FunctionType, req_func).__defaults__
|
default_param_values = cast(FunctionType, req_func).__defaults__
|
||||||
if default_param_values is None:
|
if default_param_values is None: # nocoverage # No users of this path.
|
||||||
default_param_values = ()
|
default_param_values = ()
|
||||||
num_default_params = len(default_param_values)
|
num_default_params = len(default_param_values)
|
||||||
default_param_names = req_func.__code__.co_varnames[num_params - num_default_params :]
|
default_param_names = req_func.__code__.co_varnames[num_params - num_default_params :]
|
||||||
|
@ -392,7 +392,7 @@ def has_request_variables(
|
||||||
if req_var in request.POST:
|
if req_var in request.POST:
|
||||||
val = request.POST[req_var]
|
val = request.POST[req_var]
|
||||||
request_notes.processed_parameters.add(req_var)
|
request_notes.processed_parameters.add(req_var)
|
||||||
elif req_var in request.GET:
|
elif req_var in request.GET: # nocoverage # No users of this path
|
||||||
val = request.GET[req_var]
|
val = request.GET[req_var]
|
||||||
request_notes.processed_parameters.add(req_var)
|
request_notes.processed_parameters.add(req_var)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
import inspect
|
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable, Mapping
|
from collections.abc import Callable, Mapping
|
||||||
from typing import Any, get_origin
|
from typing import Any
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
@ -10,7 +9,7 @@ from django.urls import URLPattern
|
||||||
from django.utils import regex_helper
|
from django.utils import regex_helper
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
from zerver.lib.request import _REQ, arguments_map
|
from zerver.lib.request import arguments_map
|
||||||
from zerver.lib.rest import rest_dispatch
|
from zerver.lib.rest import rest_dispatch
|
||||||
from zerver.lib.test_classes import ZulipTestCase
|
from zerver.lib.test_classes import ZulipTestCase
|
||||||
from zerver.lib.typed_endpoint import parse_view_func_signature
|
from zerver.lib.typed_endpoint import parse_view_func_signature
|
||||||
|
@ -307,21 +306,6 @@ so maybe we shouldn't mark it as intentionally undocumented in the URLs.
|
||||||
msg += f"\n + {undocumented_path}"
|
msg += f"\n + {undocumented_path}"
|
||||||
raise AssertionError(msg)
|
raise AssertionError(msg)
|
||||||
|
|
||||||
def get_standardized_argument_type(self, t: Any) -> type | tuple[type, object]:
|
|
||||||
"""Given a type from the typing module such as List[str] or Union[str, int],
|
|
||||||
convert it into a corresponding Python type. Unions are mapped to a canonical
|
|
||||||
choice among the options.
|
|
||||||
E.g. typing.Union[typing.List[typing.Dict[str, typing.Any]], NoneType]
|
|
||||||
needs to be mapped to list."""
|
|
||||||
|
|
||||||
origin = get_origin(t)
|
|
||||||
|
|
||||||
if origin is None:
|
|
||||||
# Then it's most likely one of the fundamental data types
|
|
||||||
# I.E. Not one of the data types from the "typing" module.
|
|
||||||
return t
|
|
||||||
raise AssertionError(f"Unknown origin {origin}")
|
|
||||||
|
|
||||||
def render_openapi_type_exception(
|
def render_openapi_type_exception(
|
||||||
self,
|
self,
|
||||||
function: Callable[..., HttpResponse],
|
function: Callable[..., HttpResponse],
|
||||||
|
@ -451,8 +435,8 @@ do not match the types declared in the implementation of {function.__name__}.\n"
|
||||||
OpenAPI data defines a different type than that actually accepted by the function.
|
OpenAPI data defines a different type than that actually accepted by the function.
|
||||||
Otherwise, we print out the exact differences for convenient debugging and raise an
|
Otherwise, we print out the exact differences for convenient debugging and raise an
|
||||||
AssertionError."""
|
AssertionError."""
|
||||||
# Iterate through the decorators to find the original function, wrapped
|
# Iterate through the decorators to find the original
|
||||||
# by has_request_variables/typed_endpoint, so we can parse its
|
# function, wrapped by typed_endpoint, so we can parse its
|
||||||
# arguments.
|
# arguments.
|
||||||
use_endpoint_decorator = False
|
use_endpoint_decorator = False
|
||||||
while (wrapped := getattr(function, "__wrapped__", None)) is not None:
|
while (wrapped := getattr(function, "__wrapped__", None)) is not None:
|
||||||
|
@ -462,42 +446,10 @@ do not match the types declared in the implementation of {function.__name__}.\n"
|
||||||
use_endpoint_decorator = True
|
use_endpoint_decorator = True
|
||||||
function = wrapped
|
function = wrapped
|
||||||
|
|
||||||
if use_endpoint_decorator:
|
if len(openapi_parameters) > 0:
|
||||||
|
assert use_endpoint_decorator
|
||||||
return self.validate_json_schema(function, openapi_parameters)
|
return self.validate_json_schema(function, openapi_parameters)
|
||||||
|
|
||||||
openapi_params: set[tuple[str, type | tuple[type, object]]] = set()
|
|
||||||
json_params: dict[str, type | tuple[type, object]] = {}
|
|
||||||
for openapi_parameter in openapi_parameters:
|
|
||||||
name = openapi_parameter.name
|
|
||||||
# This no longer happens in remaining has_request_variables endpoint.
|
|
||||||
assert not openapi_parameter.json_encoded
|
|
||||||
openapi_params.add((name, schema_type(openapi_parameter.value_schema)))
|
|
||||||
|
|
||||||
function_params: set[tuple[str, type | tuple[type, object]]] = set()
|
|
||||||
|
|
||||||
for pname, defval in inspect.signature(function).parameters.items():
|
|
||||||
defval = defval.default
|
|
||||||
if isinstance(defval, _REQ):
|
|
||||||
# TODO: The below inference logic in cases where
|
|
||||||
# there's a converter function declared is incorrect.
|
|
||||||
# Theoretically, we could restructure the converter
|
|
||||||
# function model so that we can check what type it
|
|
||||||
# excepts to be passed to make validation here
|
|
||||||
# possible.
|
|
||||||
|
|
||||||
vtype = self.get_standardized_argument_type(function.__annotations__[pname])
|
|
||||||
vname = defval.post_var_name
|
|
||||||
assert vname is not None
|
|
||||||
# This no longer happens following typed_endpoint migrations.
|
|
||||||
assert vname not in json_params
|
|
||||||
function_params.add((vname, vtype))
|
|
||||||
|
|
||||||
# After the above operations `json_params` should be empty.
|
|
||||||
assert len(json_params) == 0
|
|
||||||
diff = openapi_params - function_params
|
|
||||||
if diff: # nocoverage
|
|
||||||
self.render_openapi_type_exception(function, openapi_params, function_params, diff)
|
|
||||||
|
|
||||||
def check_openapi_arguments_for_view(
|
def check_openapi_arguments_for_view(
|
||||||
self,
|
self,
|
||||||
pattern: URLPattern,
|
pattern: URLPattern,
|
||||||
|
|
|
@ -23,6 +23,7 @@ from django.utils.http import url_has_allowed_host_and_scheme
|
||||||
from django.utils.translation import gettext as _
|
from django.utils.translation import gettext as _
|
||||||
from django.views.decorators.csrf import csrf_exempt
|
from django.views.decorators.csrf import csrf_exempt
|
||||||
from django.views.decorators.http import require_safe
|
from django.views.decorators.http import require_safe
|
||||||
|
from pydantic import Json
|
||||||
from social_django.utils import load_backend, load_strategy
|
from social_django.utils import load_backend, load_strategy
|
||||||
from two_factor.forms import BackupTokenForm
|
from two_factor.forms import BackupTokenForm
|
||||||
from two_factor.views import LoginView as BaseTwoFactorLoginView
|
from two_factor.views import LoginView as BaseTwoFactorLoginView
|
||||||
|
@ -59,15 +60,16 @@ from zerver.lib.mobile_auth_otp import otp_encrypt_api_key
|
||||||
from zerver.lib.push_notifications import push_notifications_configured
|
from zerver.lib.push_notifications import push_notifications_configured
|
||||||
from zerver.lib.pysa import mark_sanitized
|
from zerver.lib.pysa import mark_sanitized
|
||||||
from zerver.lib.realm_icon import realm_icon_url
|
from zerver.lib.realm_icon import realm_icon_url
|
||||||
from zerver.lib.request import REQ, RequestNotes, has_request_variables
|
from zerver.lib.request import RequestNotes
|
||||||
from zerver.lib.response import json_success
|
from zerver.lib.response import json_success
|
||||||
from zerver.lib.sessions import set_expirable_session_var
|
from zerver.lib.sessions import set_expirable_session_var
|
||||||
from zerver.lib.subdomains import get_subdomain, is_subdomain_root_or_alias
|
from zerver.lib.subdomains import get_subdomain, is_subdomain_root_or_alias
|
||||||
|
from zerver.lib.typed_endpoint import typed_endpoint
|
||||||
from zerver.lib.url_encoding import append_url_query_string
|
from zerver.lib.url_encoding import append_url_query_string
|
||||||
from zerver.lib.user_agent import parse_user_agent
|
from zerver.lib.user_agent import parse_user_agent
|
||||||
from zerver.lib.users import get_api_key, get_users_for_api, is_2fa_verified
|
from zerver.lib.users import get_api_key, get_users_for_api, is_2fa_verified
|
||||||
from zerver.lib.utils import has_api_key_format
|
from zerver.lib.utils import has_api_key_format
|
||||||
from zerver.lib.validator import check_bool, validate_login_email
|
from zerver.lib.validator import validate_login_email
|
||||||
from zerver.models import (
|
from zerver.models import (
|
||||||
MultiuseInvite,
|
MultiuseInvite,
|
||||||
PreregistrationRealm,
|
PreregistrationRealm,
|
||||||
|
@ -508,12 +510,13 @@ def create_response_for_otp_flow(
|
||||||
|
|
||||||
|
|
||||||
@log_view_func
|
@log_view_func
|
||||||
@has_request_variables
|
@typed_endpoint
|
||||||
def remote_user_sso(
|
def remote_user_sso(
|
||||||
request: HttpRequest,
|
request: HttpRequest,
|
||||||
mobile_flow_otp: str | None = REQ(default=None),
|
*,
|
||||||
desktop_flow_otp: str | None = REQ(default=None),
|
mobile_flow_otp: str | None = None,
|
||||||
next: str = REQ(default="/"),
|
desktop_flow_otp: str | None = None,
|
||||||
|
next: str = "/",
|
||||||
) -> HttpResponse:
|
) -> HttpResponse:
|
||||||
subdomain = get_subdomain(request)
|
subdomain = get_subdomain(request)
|
||||||
try:
|
try:
|
||||||
|
@ -562,9 +565,8 @@ def remote_user_sso(
|
||||||
return login_or_register_remote_user(request, result)
|
return login_or_register_remote_user(request, result)
|
||||||
|
|
||||||
|
|
||||||
@has_request_variables
|
|
||||||
def get_email_and_realm_from_jwt_authentication_request(
|
def get_email_and_realm_from_jwt_authentication_request(
|
||||||
request: HttpRequest, json_web_token: str
|
request: HttpRequest, *, json_web_token: str
|
||||||
) -> tuple[str, Realm]:
|
) -> tuple[str, Realm]:
|
||||||
realm = get_realm_from_request(request)
|
realm = get_realm_from_request(request)
|
||||||
if realm is None:
|
if realm is None:
|
||||||
|
@ -595,9 +597,11 @@ def get_email_and_realm_from_jwt_authentication_request(
|
||||||
@csrf_exempt
|
@csrf_exempt
|
||||||
@require_post
|
@require_post
|
||||||
@log_view_func
|
@log_view_func
|
||||||
@has_request_variables
|
@typed_endpoint
|
||||||
def remote_user_jwt(request: HttpRequest, token: str = REQ(default="")) -> HttpResponse:
|
def remote_user_jwt(request: HttpRequest, *, token: str = "") -> HttpResponse:
|
||||||
email, realm = get_email_and_realm_from_jwt_authentication_request(request, token)
|
email, realm = get_email_and_realm_from_jwt_authentication_request(
|
||||||
|
request, json_web_token=token
|
||||||
|
)
|
||||||
|
|
||||||
user_profile = authenticate(username=email, realm=realm, use_dummy_backend=True)
|
user_profile = authenticate(username=email, realm=realm, use_dummy_backend=True)
|
||||||
if user_profile is None:
|
if user_profile is None:
|
||||||
|
@ -611,17 +615,22 @@ def remote_user_jwt(request: HttpRequest, token: str = REQ(default="")) -> HttpR
|
||||||
return login_or_register_remote_user(request, result)
|
return login_or_register_remote_user(request, result)
|
||||||
|
|
||||||
|
|
||||||
@has_request_variables
|
@typed_endpoint
|
||||||
def oauth_redirect_to_root(
|
def oauth_redirect_to_root(
|
||||||
request: HttpRequest,
|
request: HttpRequest,
|
||||||
url: str,
|
url: str,
|
||||||
sso_type: str,
|
sso_type: str,
|
||||||
is_signup: bool = False,
|
is_signup: bool,
|
||||||
extra_url_params: Mapping[str, str] = {},
|
extra_url_params: Mapping[str, str],
|
||||||
next: str | None = REQ(default=None),
|
# Protect the above parameters from being processed as kwargs
|
||||||
multiuse_object_key: str = REQ(default=""),
|
# provided by @typed_endpoint by marking them as mandatory
|
||||||
mobile_flow_otp: str | None = REQ(default=None),
|
# positional parameters.
|
||||||
desktop_flow_otp: str | None = REQ(default=None),
|
/,
|
||||||
|
*,
|
||||||
|
next: str | None = None,
|
||||||
|
multiuse_object_key: str = "",
|
||||||
|
mobile_flow_otp: str | None = None,
|
||||||
|
desktop_flow_otp: str | None = None,
|
||||||
) -> HttpResponse:
|
) -> HttpResponse:
|
||||||
main_site_url = settings.ROOT_DOMAIN_URI + url
|
main_site_url = settings.ROOT_DOMAIN_URI + url
|
||||||
if settings.SOCIAL_AUTH_SUBDOMAIN is not None and sso_type == "social":
|
if settings.SOCIAL_AUTH_SUBDOMAIN is not None and sso_type == "social":
|
||||||
|
@ -716,7 +725,13 @@ def start_social_login(
|
||||||
if not (getattr(settings, key_setting) and getattr(settings, secret_setting)):
|
if not (getattr(settings, key_setting) and getattr(settings, secret_setting)):
|
||||||
return config_error(request, backend)
|
return config_error(request, backend)
|
||||||
|
|
||||||
return oauth_redirect_to_root(request, backend_url, "social", extra_url_params=extra_url_params)
|
return oauth_redirect_to_root(
|
||||||
|
request,
|
||||||
|
backend_url,
|
||||||
|
"social",
|
||||||
|
False,
|
||||||
|
extra_url_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@handle_desktop_flow
|
@handle_desktop_flow
|
||||||
|
@ -738,7 +753,11 @@ def start_social_signup(
|
||||||
return config_error(request, "saml")
|
return config_error(request, "saml")
|
||||||
extra_url_params = {"idp": extra_arg}
|
extra_url_params = {"idp": extra_arg}
|
||||||
return oauth_redirect_to_root(
|
return oauth_redirect_to_root(
|
||||||
request, backend_url, "social", is_signup=True, extra_url_params=extra_url_params
|
request,
|
||||||
|
backend_url,
|
||||||
|
"social",
|
||||||
|
True,
|
||||||
|
extra_url_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -862,11 +881,12 @@ class TwoFactorLoginView(BaseTwoFactorLoginView):
|
||||||
return super().done(form_list, **kwargs)
|
return super().done(form_list, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@has_request_variables
|
@typed_endpoint
|
||||||
def login_page(
|
def login_page(
|
||||||
request: HttpRequest,
|
request: HttpRequest,
|
||||||
/,
|
/,
|
||||||
next: str = REQ(default="/"),
|
*,
|
||||||
|
next: str = "/",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> HttpResponse:
|
) -> HttpResponse:
|
||||||
if get_subdomain(request) == settings.SOCIAL_AUTH_SUBDOMAIN:
|
if get_subdomain(request) == settings.SOCIAL_AUTH_SUBDOMAIN:
|
||||||
|
@ -1021,13 +1041,16 @@ def get_api_key_fetch_authenticate_failure(return_data: dict[str, bool]) -> Json
|
||||||
|
|
||||||
@csrf_exempt
|
@csrf_exempt
|
||||||
@require_post
|
@require_post
|
||||||
@has_request_variables
|
@typed_endpoint
|
||||||
def jwt_fetch_api_key(
|
def jwt_fetch_api_key(
|
||||||
request: HttpRequest,
|
request: HttpRequest,
|
||||||
include_profile: bool = REQ(default=False, json_validator=check_bool),
|
*,
|
||||||
token: str = REQ(default=""),
|
include_profile: Json[bool] = False,
|
||||||
|
token: str = "",
|
||||||
) -> HttpResponse:
|
) -> HttpResponse:
|
||||||
remote_email, realm = get_email_and_realm_from_jwt_authentication_request(request, token)
|
remote_email, realm = get_email_and_realm_from_jwt_authentication_request(
|
||||||
|
request, json_web_token=token
|
||||||
|
)
|
||||||
|
|
||||||
return_data: dict[str, bool] = {}
|
return_data: dict[str, bool] = {}
|
||||||
|
|
||||||
|
@ -1062,10 +1085,8 @@ def jwt_fetch_api_key(
|
||||||
|
|
||||||
@csrf_exempt
|
@csrf_exempt
|
||||||
@require_post
|
@require_post
|
||||||
@has_request_variables
|
@typed_endpoint
|
||||||
def api_fetch_api_key(
|
def api_fetch_api_key(request: HttpRequest, *, username: str, password: str) -> HttpResponse:
|
||||||
request: HttpRequest, username: str = REQ(), password: str = REQ()
|
|
||||||
) -> HttpResponse:
|
|
||||||
return_data: dict[str, bool] = {}
|
return_data: dict[str, bool] = {}
|
||||||
|
|
||||||
realm = get_realm_from_request(request)
|
realm = get_realm_from_request(request)
|
||||||
|
@ -1160,9 +1181,9 @@ def api_get_server_settings(request: HttpRequest) -> HttpResponse:
|
||||||
return json_success(request, data=result)
|
return json_success(request, data=result)
|
||||||
|
|
||||||
|
|
||||||
@has_request_variables
|
@typed_endpoint
|
||||||
def json_fetch_api_key(
|
def json_fetch_api_key(
|
||||||
request: HttpRequest, user_profile: UserProfile, password: str = REQ(default="")
|
request: HttpRequest, user_profile: UserProfile, *, password: str = ""
|
||||||
) -> HttpResponse:
|
) -> HttpResponse:
|
||||||
realm = get_realm_from_request(request)
|
realm = get_realm_from_request(request)
|
||||||
if realm is None:
|
if realm is None:
|
||||||
|
|
Loading…
Reference in New Issue