From 232ba4866a308e94e832376f5b988dfbcf75303d Mon Sep 17 00:00:00 2001 From: Zixuan James Li Date: Thu, 28 Jul 2022 07:59:22 -0400 Subject: [PATCH] rate_limit: Stop wrapping rate limited functions. This refactors `rate_limit` so that we no longer use it as a decorator. This is a workaround to https://github.com/python/mypy/issues/12909 as `rate_limit` previous expects different parameters than its callers. Our approach to test logging handlers also needs to be updated because the view function is not decorated by `rate_limit`. Signed-off-by: Zixuan James Li --- zerver/decorator.py | 67 +++++++++------------------ zerver/tests/test_decorators.py | 3 +- zerver/tests/test_logging_handlers.py | 36 +++++++------- 3 files changed, 41 insertions(+), 65 deletions(-) diff --git a/zerver/decorator.py b/zerver/decorator.py index 1739137f5c..2b6d7c2e78 100644 --- a/zerver/decorator.py +++ b/zerver/decorator.py @@ -533,7 +533,8 @@ def add_logging_data( request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs ) -> HttpResponse: process_client(request, request.user, is_browser_view=True, query=view_func.__name__) - return rate_limit()(view_func)(request, *args, **kwargs) + rate_limit(request) + return view_func(request, *args, **kwargs) return _wrapped_view_func @@ -724,10 +725,8 @@ def authenticated_uploads_api_view( ) -> HttpResponse: user_profile = validate_api_key(request, None, api_key, False) if not skip_rate_limiting: - limited_func = rate_limit()(view_func) - else: - limited_func = view_func - return limited_func(request, user_profile, *args, **kwargs) + rate_limit(request) + return view_func(request, user_profile, *args, **kwargs) return _wrapped_func_arguments @@ -788,10 +787,8 @@ def authenticated_rest_api_view( try: if not skip_rate_limiting: # Apply rate limiting - target_view_func = rate_limit()(view_func) - else: - target_view_func = view_func - return target_view_func(request, profile, *args, **kwargs) + rate_limit(request) + return view_func(request, profile, *args, **kwargs) except Exception as err: if not webhook_client_name: raise err @@ -865,9 +862,7 @@ def authenticate_log_and_execute_json( **kwargs: object, ) -> HttpResponse: if not skip_rate_limiting: - limited_view_func = rate_limit()(view_func) - else: - limited_view_func = view_func + rate_limit(request) if not request.user.is_authenticated: if not allow_unauthenticated: @@ -878,7 +873,7 @@ def authenticate_log_and_execute_json( is_browser_view=True, query=view_func.__name__, ) - return limited_view_func(request, request.user, *args, **kwargs) + return view_func(request, request.user, *args, **kwargs) user_profile = request.user validate_account_and_subdomain(request, user_profile) @@ -887,7 +882,7 @@ def authenticate_log_and_execute_json( raise JsonableError(_("Webhook bots can only access webhooks")) process_client(request, user_profile, is_browser_view=True, query=view_func.__name__) - return limited_view_func(request, user_profile, *args, **kwargs) + return view_func(request, user_profile, *args, **kwargs) # Checks if the user is logged in. If not, return an error (the @@ -1072,39 +1067,23 @@ def rate_limit_remote_server( raise e -def rate_limit() -> Callable[[ViewFuncT], ViewFuncT]: - """Rate-limits a view. Returns a decorator""" +def rate_limit(request: HttpRequest) -> None: + if not settings.RATE_LIMITING: + return - def wrapper(func: ViewFuncT) -> ViewFuncT: - @wraps(func) - def wrapped_func(request: HttpRequest, *args: object, **kwargs: object) -> HttpResponse: + if client_is_exempt_from_rate_limiting(request): + return - # It is really tempting to not even wrap our original function - # when settings.RATE_LIMITING is False, but it would make - # for awkward unit testing in some situations. - if not settings.RATE_LIMITING: - return func(request, *args, **kwargs) + user = request.user + remote_server = RequestNotes.get_notes(request).remote_server - if client_is_exempt_from_rate_limiting(request): - return func(request, *args, **kwargs) - - user = request.user - remote_server = RequestNotes.get_notes(request).remote_server - - if settings.ZILENCER_ENABLED and remote_server is not None: - rate_limit_remote_server(request, remote_server, domain="api_by_remote_server") - elif not user.is_authenticated: - rate_limit_request_by_ip(request, domain="api_by_ip") - return func(request, *args, **kwargs) - else: - assert isinstance(user, UserProfile) - rate_limit_user(request, user, domain="api_by_user") - - return func(request, *args, **kwargs) - - return cast(ViewFuncT, wrapped_func) # https://github.com/python/mypy/issues/1927 - - return wrapper + if settings.ZILENCER_ENABLED and remote_server is not None: + rate_limit_remote_server(request, remote_server, domain="api_by_remote_server") + elif not user.is_authenticated: + rate_limit_request_by_ip(request, domain="api_by_ip") + else: + assert isinstance(user, UserProfile) + rate_limit_user(request, user, domain="api_by_user") def return_success_on_head_request( diff --git a/zerver/tests/test_decorators.py b/zerver/tests/test_decorators.py index 3e01da933d..70768681d9 100644 --- a/zerver/tests/test_decorators.py +++ b/zerver/tests/test_decorators.py @@ -630,10 +630,9 @@ class DecoratorLoggingTestCase(ZulipTestCase): class RateLimitTestCase(ZulipTestCase): def get_ratelimited_view(self) -> Callable[..., HttpResponse]: def f(req: Any) -> HttpResponse: + rate_limit(req) return json_response(msg="some value") - f = rate_limit()(f) - return f def errors_disallowed(self) -> Any: diff --git a/zerver/tests/test_logging_handlers.py b/zerver/tests/test_logging_handlers.py index 346d57a043..ab1c20084f 100644 --- a/zerver/tests/test_logging_handlers.py +++ b/zerver/tests/test_logging_handlers.py @@ -2,7 +2,7 @@ import logging import sys from functools import wraps from types import TracebackType -from typing import Callable, Dict, Iterator, NoReturn, Optional, Tuple, Type, Union, cast +from typing import Dict, Iterator, NoReturn, Optional, Tuple, Type, Union, cast from unittest import mock from unittest.mock import MagicMock, patch @@ -22,22 +22,19 @@ captured_exc_info: Optional[ ] = None -def capture_and_throw(domain: Optional[str] = None) -> Callable[[ViewFuncT], ViewFuncT]: - def wrapper(view_func: ViewFuncT) -> ViewFuncT: - @wraps(view_func) - def wrapped_view(request: HttpRequest, *args: object, **kwargs: object) -> NoReturn: - global captured_request - captured_request = request - try: - raise Exception("Request error") - except Exception as e: - global captured_exc_info - captured_exc_info = sys.exc_info() - raise e +def capture_and_throw(view_func: ViewFuncT) -> ViewFuncT: + @wraps(view_func) + def wrapped_view(request: HttpRequest, *args: object, **kwargs: object) -> NoReturn: + global captured_request + captured_request = request + try: + raise Exception("Request error") + except Exception as e: + global captured_exc_info + captured_exc_info = sys.exc_info() + raise e - return cast(ViewFuncT, wrapped_view) # https://github.com/python/mypy/issues/1927 - - return wrapper + return cast(ViewFuncT, wrapped_view) # https://github.com/python/mypy/issues/1927 class AdminNotifyHandlerTest(ZulipTestCase): @@ -78,17 +75,18 @@ class AdminNotifyHandlerTest(ZulipTestCase): def simulate_error(self) -> logging.LogRecord: self.login("hamlet") - with patch("zerver.decorator.rate_limit") as rate_limit_patch, self.assertLogs( + with patch( + "zerver.lib.rest.authenticated_json_view", side_effect=capture_and_throw + ) as view_decorator_patch, self.assertLogs( "django.request", level="ERROR" ) as request_error_log, self.assertLogs( "zerver.middleware.json_error_handler", level="ERROR" ) as json_error_handler_log, self.settings( TEST_SUITE=False ): - rate_limit_patch.side_effect = capture_and_throw result = self.client_get("/json/users") self.assert_json_error(result, "Internal server error", status_code=500) - rate_limit_patch.assert_called_once() + view_decorator_patch.assert_called_once() self.assertEqual( request_error_log.output, ["ERROR:django.request:Internal Server Error: /json/users"] )