rate_limit: Stop wrapping rate limited functions.

This refactors `rate_limit` so that we no longer use it as a decorator.
This is a workaround to https://github.com/python/mypy/issues/12909 as
`rate_limit` previous expects different parameters than its callers.

Our approach to test logging handlers also needs to be updated because
the view function is not decorated by `rate_limit`.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
Zixuan James Li 2022-07-28 07:59:22 -04:00 committed by Tim Abbott
parent cfa4973441
commit 232ba4866a
3 changed files with 41 additions and 65 deletions

View File

@ -533,7 +533,8 @@ def add_logging_data(
request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs
) -> HttpResponse: ) -> HttpResponse:
process_client(request, request.user, is_browser_view=True, query=view_func.__name__) process_client(request, request.user, is_browser_view=True, query=view_func.__name__)
return rate_limit()(view_func)(request, *args, **kwargs) rate_limit(request)
return view_func(request, *args, **kwargs)
return _wrapped_view_func return _wrapped_view_func
@ -724,10 +725,8 @@ def authenticated_uploads_api_view(
) -> HttpResponse: ) -> HttpResponse:
user_profile = validate_api_key(request, None, api_key, False) user_profile = validate_api_key(request, None, api_key, False)
if not skip_rate_limiting: if not skip_rate_limiting:
limited_func = rate_limit()(view_func) rate_limit(request)
else: return view_func(request, user_profile, *args, **kwargs)
limited_func = view_func
return limited_func(request, user_profile, *args, **kwargs)
return _wrapped_func_arguments return _wrapped_func_arguments
@ -788,10 +787,8 @@ def authenticated_rest_api_view(
try: try:
if not skip_rate_limiting: if not skip_rate_limiting:
# Apply rate limiting # Apply rate limiting
target_view_func = rate_limit()(view_func) rate_limit(request)
else: return view_func(request, profile, *args, **kwargs)
target_view_func = view_func
return target_view_func(request, profile, *args, **kwargs)
except Exception as err: except Exception as err:
if not webhook_client_name: if not webhook_client_name:
raise err raise err
@ -865,9 +862,7 @@ def authenticate_log_and_execute_json(
**kwargs: object, **kwargs: object,
) -> HttpResponse: ) -> HttpResponse:
if not skip_rate_limiting: if not skip_rate_limiting:
limited_view_func = rate_limit()(view_func) rate_limit(request)
else:
limited_view_func = view_func
if not request.user.is_authenticated: if not request.user.is_authenticated:
if not allow_unauthenticated: if not allow_unauthenticated:
@ -878,7 +873,7 @@ def authenticate_log_and_execute_json(
is_browser_view=True, is_browser_view=True,
query=view_func.__name__, query=view_func.__name__,
) )
return limited_view_func(request, request.user, *args, **kwargs) return view_func(request, request.user, *args, **kwargs)
user_profile = request.user user_profile = request.user
validate_account_and_subdomain(request, user_profile) validate_account_and_subdomain(request, user_profile)
@ -887,7 +882,7 @@ def authenticate_log_and_execute_json(
raise JsonableError(_("Webhook bots can only access webhooks")) raise JsonableError(_("Webhook bots can only access webhooks"))
process_client(request, user_profile, is_browser_view=True, query=view_func.__name__) process_client(request, user_profile, is_browser_view=True, query=view_func.__name__)
return limited_view_func(request, user_profile, *args, **kwargs) return view_func(request, user_profile, *args, **kwargs)
# Checks if the user is logged in. If not, return an error (the # Checks if the user is logged in. If not, return an error (the
@ -1072,39 +1067,23 @@ def rate_limit_remote_server(
raise e raise e
def rate_limit() -> Callable[[ViewFuncT], ViewFuncT]: def rate_limit(request: HttpRequest) -> None:
"""Rate-limits a view. Returns a decorator""" if not settings.RATE_LIMITING:
return
def wrapper(func: ViewFuncT) -> ViewFuncT: if client_is_exempt_from_rate_limiting(request):
@wraps(func) return
def wrapped_func(request: HttpRequest, *args: object, **kwargs: object) -> HttpResponse:
# It is really tempting to not even wrap our original function user = request.user
# when settings.RATE_LIMITING is False, but it would make remote_server = RequestNotes.get_notes(request).remote_server
# for awkward unit testing in some situations.
if not settings.RATE_LIMITING:
return func(request, *args, **kwargs)
if client_is_exempt_from_rate_limiting(request): if settings.ZILENCER_ENABLED and remote_server is not None:
return func(request, *args, **kwargs) rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
elif not user.is_authenticated:
user = request.user rate_limit_request_by_ip(request, domain="api_by_ip")
remote_server = RequestNotes.get_notes(request).remote_server else:
assert isinstance(user, UserProfile)
if settings.ZILENCER_ENABLED and remote_server is not None: rate_limit_user(request, user, domain="api_by_user")
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
elif not user.is_authenticated:
rate_limit_request_by_ip(request, domain="api_by_ip")
return func(request, *args, **kwargs)
else:
assert isinstance(user, UserProfile)
rate_limit_user(request, user, domain="api_by_user")
return func(request, *args, **kwargs)
return cast(ViewFuncT, wrapped_func) # https://github.com/python/mypy/issues/1927
return wrapper
def return_success_on_head_request( def return_success_on_head_request(

View File

@ -630,10 +630,9 @@ class DecoratorLoggingTestCase(ZulipTestCase):
class RateLimitTestCase(ZulipTestCase): class RateLimitTestCase(ZulipTestCase):
def get_ratelimited_view(self) -> Callable[..., HttpResponse]: def get_ratelimited_view(self) -> Callable[..., HttpResponse]:
def f(req: Any) -> HttpResponse: def f(req: Any) -> HttpResponse:
rate_limit(req)
return json_response(msg="some value") return json_response(msg="some value")
f = rate_limit()(f)
return f return f
def errors_disallowed(self) -> Any: def errors_disallowed(self) -> Any:

View File

@ -2,7 +2,7 @@ import logging
import sys import sys
from functools import wraps from functools import wraps
from types import TracebackType from types import TracebackType
from typing import Callable, Dict, Iterator, NoReturn, Optional, Tuple, Type, Union, cast from typing import Dict, Iterator, NoReturn, Optional, Tuple, Type, Union, cast
from unittest import mock from unittest import mock
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
@ -22,22 +22,19 @@ captured_exc_info: Optional[
] = None ] = None
def capture_and_throw(domain: Optional[str] = None) -> Callable[[ViewFuncT], ViewFuncT]: def capture_and_throw(view_func: ViewFuncT) -> ViewFuncT:
def wrapper(view_func: ViewFuncT) -> ViewFuncT: @wraps(view_func)
@wraps(view_func) def wrapped_view(request: HttpRequest, *args: object, **kwargs: object) -> NoReturn:
def wrapped_view(request: HttpRequest, *args: object, **kwargs: object) -> NoReturn: global captured_request
global captured_request captured_request = request
captured_request = request try:
try: raise Exception("Request error")
raise Exception("Request error") except Exception as e:
except Exception as e: global captured_exc_info
global captured_exc_info captured_exc_info = sys.exc_info()
captured_exc_info = sys.exc_info() raise e
raise e
return cast(ViewFuncT, wrapped_view) # https://github.com/python/mypy/issues/1927 return cast(ViewFuncT, wrapped_view) # https://github.com/python/mypy/issues/1927
return wrapper
class AdminNotifyHandlerTest(ZulipTestCase): class AdminNotifyHandlerTest(ZulipTestCase):
@ -78,17 +75,18 @@ class AdminNotifyHandlerTest(ZulipTestCase):
def simulate_error(self) -> logging.LogRecord: def simulate_error(self) -> logging.LogRecord:
self.login("hamlet") self.login("hamlet")
with patch("zerver.decorator.rate_limit") as rate_limit_patch, self.assertLogs( with patch(
"zerver.lib.rest.authenticated_json_view", side_effect=capture_and_throw
) as view_decorator_patch, self.assertLogs(
"django.request", level="ERROR" "django.request", level="ERROR"
) as request_error_log, self.assertLogs( ) as request_error_log, self.assertLogs(
"zerver.middleware.json_error_handler", level="ERROR" "zerver.middleware.json_error_handler", level="ERROR"
) as json_error_handler_log, self.settings( ) as json_error_handler_log, self.settings(
TEST_SUITE=False TEST_SUITE=False
): ):
rate_limit_patch.side_effect = capture_and_throw
result = self.client_get("/json/users") result = self.client_get("/json/users")
self.assert_json_error(result, "Internal server error", status_code=500) self.assert_json_error(result, "Internal server error", status_code=500)
rate_limit_patch.assert_called_once() view_decorator_patch.assert_called_once()
self.assertEqual( self.assertEqual(
request_error_log.output, ["ERROR:django.request:Internal Server Error: /json/users"] request_error_log.output, ["ERROR:django.request:Internal Server Error: /json/users"]
) )