diff --git a/zerver/decorator.py b/zerver/decorator.py index ecb881b2a6..8e6d0a1ee0 100644 --- a/zerver/decorator.py +++ b/zerver/decorator.py @@ -49,7 +49,6 @@ from zerver.lib.request import REQ, RequestNotes, has_request_variables from zerver.lib.response import json_method_not_allowed, json_success from zerver.lib.subdomains import get_subdomain, user_matches_subdomain from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime -from zerver.lib.types import ViewFuncT from zerver.lib.users import is_2fa_verified from zerver.lib.utils import has_api_key_format, statsd from zerver.models import Realm, UserProfile, get_client, get_user_profile_by_api_key @@ -723,15 +722,20 @@ def authenticated_rest_api_view( allow_webhook_access: bool = False, skip_rate_limiting: bool = False, beanstalk_email_decode: bool = False, -) -> Callable[[Callable[..., HttpResponse]], Callable[..., HttpResponse]]: +) -> Callable[ + [Callable[Concatenate[HttpRequest, UserProfile, ParamT], HttpResponse]], + Callable[Concatenate[HttpRequest, ParamT], HttpResponse], +]: if webhook_client_name is not None: allow_webhook_access = True - def _wrapped_view_func(view_func: Callable[..., HttpResponse]) -> Callable[..., HttpResponse]: + def _wrapped_view_func( + view_func: Callable[Concatenate[HttpRequest, UserProfile, ParamT], HttpResponse] + ) -> Callable[Concatenate[HttpRequest, ParamT], HttpResponse]: @csrf_exempt @wraps(view_func) def _wrapped_func_arguments( - request: HttpRequest, *args: object, **kwargs: object + request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs ) -> HttpResponse: # First try block attempts to get the credentials we need to do authentication try: @@ -832,56 +836,43 @@ def process_as_post( return _wrapped_view_func -def authenticate_log_and_execute_json( - request: HttpRequest, - view_func: ViewFuncT, - *args: object, - skip_rate_limiting: bool = False, - allow_unauthenticated: bool = False, - **kwargs: object, -) -> HttpResponse: - if not skip_rate_limiting: - rate_limit(request) - - if not request.user.is_authenticated: - if not allow_unauthenticated: - raise UnauthorizedError() - - process_client( - request, - is_browser_view=True, - query=view_func.__name__, - ) - return view_func(request, request.user, *args, **kwargs) - - user_profile = request.user - validate_account_and_subdomain(request, user_profile) - - if user_profile.is_incoming_webhook: - raise JsonableError(_("Webhook bots can only access webhooks")) - - process_client(request, user_profile, is_browser_view=True, query=view_func.__name__) - return view_func(request, user_profile, *args, **kwargs) - - # Checks if the user is logged in. If not, return an error (the # @login_required behavior of redirecting to a login page doesn't make # sense for json views) def authenticated_json_view( - view_func: Callable[..., HttpResponse], + view_func: Callable[Concatenate[HttpRequest, UserProfile, ParamT], HttpResponse], skip_rate_limiting: bool = False, allow_unauthenticated: bool = False, -) -> Callable[..., HttpResponse]: +) -> Callable[Concatenate[HttpRequest, ParamT], HttpResponse]: @wraps(view_func) - def _wrapped_view_func(request: HttpRequest, *args: object, **kwargs: object) -> HttpResponse: - return authenticate_log_and_execute_json( - request, - view_func, - *args, - skip_rate_limiting=skip_rate_limiting, - allow_unauthenticated=allow_unauthenticated, - **kwargs, - ) + def _wrapped_view_func( + request: HttpRequest, + /, + *args: ParamT.args, + **kwargs: ParamT.kwargs, + ) -> HttpResponse: + if not skip_rate_limiting: + rate_limit(request) + + if not request.user.is_authenticated: + if not allow_unauthenticated: + raise UnauthorizedError() + + process_client( + request, + is_browser_view=True, + query=view_func.__name__, + ) + return view_func(request, request.user, *args, **kwargs) + + user_profile = request.user + validate_account_and_subdomain(request, user_profile) + + if user_profile.is_incoming_webhook: + raise JsonableError(_("Webhook bots can only access webhooks")) + + process_client(request, user_profile, is_browser_view=True, query=view_func.__name__) + return view_func(request, user_profile, *args, **kwargs) return _wrapped_view_func