mirror of https://github.com/zulip/zulip.git
rate_limit: Stop wrapping rate limited functions.
This refactors `rate_limit` so that we no longer use it as a decorator. This is a workaround to https://github.com/python/mypy/issues/12909 as `rate_limit` previous expects different parameters than its callers. Our approach to test logging handlers also needs to be updated because the view function is not decorated by `rate_limit`. Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
parent
cfa4973441
commit
232ba4866a
|
@ -533,7 +533,8 @@ def add_logging_data(
|
|||
request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs
|
||||
) -> HttpResponse:
|
||||
process_client(request, request.user, is_browser_view=True, query=view_func.__name__)
|
||||
return rate_limit()(view_func)(request, *args, **kwargs)
|
||||
rate_limit(request)
|
||||
return view_func(request, *args, **kwargs)
|
||||
|
||||
return _wrapped_view_func
|
||||
|
||||
|
@ -724,10 +725,8 @@ def authenticated_uploads_api_view(
|
|||
) -> HttpResponse:
|
||||
user_profile = validate_api_key(request, None, api_key, False)
|
||||
if not skip_rate_limiting:
|
||||
limited_func = rate_limit()(view_func)
|
||||
else:
|
||||
limited_func = view_func
|
||||
return limited_func(request, user_profile, *args, **kwargs)
|
||||
rate_limit(request)
|
||||
return view_func(request, user_profile, *args, **kwargs)
|
||||
|
||||
return _wrapped_func_arguments
|
||||
|
||||
|
@ -788,10 +787,8 @@ def authenticated_rest_api_view(
|
|||
try:
|
||||
if not skip_rate_limiting:
|
||||
# Apply rate limiting
|
||||
target_view_func = rate_limit()(view_func)
|
||||
else:
|
||||
target_view_func = view_func
|
||||
return target_view_func(request, profile, *args, **kwargs)
|
||||
rate_limit(request)
|
||||
return view_func(request, profile, *args, **kwargs)
|
||||
except Exception as err:
|
||||
if not webhook_client_name:
|
||||
raise err
|
||||
|
@ -865,9 +862,7 @@ def authenticate_log_and_execute_json(
|
|||
**kwargs: object,
|
||||
) -> HttpResponse:
|
||||
if not skip_rate_limiting:
|
||||
limited_view_func = rate_limit()(view_func)
|
||||
else:
|
||||
limited_view_func = view_func
|
||||
rate_limit(request)
|
||||
|
||||
if not request.user.is_authenticated:
|
||||
if not allow_unauthenticated:
|
||||
|
@ -878,7 +873,7 @@ def authenticate_log_and_execute_json(
|
|||
is_browser_view=True,
|
||||
query=view_func.__name__,
|
||||
)
|
||||
return limited_view_func(request, request.user, *args, **kwargs)
|
||||
return view_func(request, request.user, *args, **kwargs)
|
||||
|
||||
user_profile = request.user
|
||||
validate_account_and_subdomain(request, user_profile)
|
||||
|
@ -887,7 +882,7 @@ def authenticate_log_and_execute_json(
|
|||
raise JsonableError(_("Webhook bots can only access webhooks"))
|
||||
|
||||
process_client(request, user_profile, is_browser_view=True, query=view_func.__name__)
|
||||
return limited_view_func(request, user_profile, *args, **kwargs)
|
||||
return view_func(request, user_profile, *args, **kwargs)
|
||||
|
||||
|
||||
# Checks if the user is logged in. If not, return an error (the
|
||||
|
@ -1072,39 +1067,23 @@ def rate_limit_remote_server(
|
|||
raise e
|
||||
|
||||
|
||||
def rate_limit() -> Callable[[ViewFuncT], ViewFuncT]:
|
||||
"""Rate-limits a view. Returns a decorator"""
|
||||
def rate_limit(request: HttpRequest) -> None:
|
||||
if not settings.RATE_LIMITING:
|
||||
return
|
||||
|
||||
def wrapper(func: ViewFuncT) -> ViewFuncT:
|
||||
@wraps(func)
|
||||
def wrapped_func(request: HttpRequest, *args: object, **kwargs: object) -> HttpResponse:
|
||||
if client_is_exempt_from_rate_limiting(request):
|
||||
return
|
||||
|
||||
# It is really tempting to not even wrap our original function
|
||||
# when settings.RATE_LIMITING is False, but it would make
|
||||
# for awkward unit testing in some situations.
|
||||
if not settings.RATE_LIMITING:
|
||||
return func(request, *args, **kwargs)
|
||||
user = request.user
|
||||
remote_server = RequestNotes.get_notes(request).remote_server
|
||||
|
||||
if client_is_exempt_from_rate_limiting(request):
|
||||
return func(request, *args, **kwargs)
|
||||
|
||||
user = request.user
|
||||
remote_server = RequestNotes.get_notes(request).remote_server
|
||||
|
||||
if settings.ZILENCER_ENABLED and remote_server is not None:
|
||||
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
|
||||
elif not user.is_authenticated:
|
||||
rate_limit_request_by_ip(request, domain="api_by_ip")
|
||||
return func(request, *args, **kwargs)
|
||||
else:
|
||||
assert isinstance(user, UserProfile)
|
||||
rate_limit_user(request, user, domain="api_by_user")
|
||||
|
||||
return func(request, *args, **kwargs)
|
||||
|
||||
return cast(ViewFuncT, wrapped_func) # https://github.com/python/mypy/issues/1927
|
||||
|
||||
return wrapper
|
||||
if settings.ZILENCER_ENABLED and remote_server is not None:
|
||||
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
|
||||
elif not user.is_authenticated:
|
||||
rate_limit_request_by_ip(request, domain="api_by_ip")
|
||||
else:
|
||||
assert isinstance(user, UserProfile)
|
||||
rate_limit_user(request, user, domain="api_by_user")
|
||||
|
||||
|
||||
def return_success_on_head_request(
|
||||
|
|
|
@ -630,10 +630,9 @@ class DecoratorLoggingTestCase(ZulipTestCase):
|
|||
class RateLimitTestCase(ZulipTestCase):
|
||||
def get_ratelimited_view(self) -> Callable[..., HttpResponse]:
|
||||
def f(req: Any) -> HttpResponse:
|
||||
rate_limit(req)
|
||||
return json_response(msg="some value")
|
||||
|
||||
f = rate_limit()(f)
|
||||
|
||||
return f
|
||||
|
||||
def errors_disallowed(self) -> Any:
|
||||
|
|
|
@ -2,7 +2,7 @@ import logging
|
|||
import sys
|
||||
from functools import wraps
|
||||
from types import TracebackType
|
||||
from typing import Callable, Dict, Iterator, NoReturn, Optional, Tuple, Type, Union, cast
|
||||
from typing import Dict, Iterator, NoReturn, Optional, Tuple, Type, Union, cast
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
@ -22,22 +22,19 @@ captured_exc_info: Optional[
|
|||
] = None
|
||||
|
||||
|
||||
def capture_and_throw(domain: Optional[str] = None) -> Callable[[ViewFuncT], ViewFuncT]:
|
||||
def wrapper(view_func: ViewFuncT) -> ViewFuncT:
|
||||
@wraps(view_func)
|
||||
def wrapped_view(request: HttpRequest, *args: object, **kwargs: object) -> NoReturn:
|
||||
global captured_request
|
||||
captured_request = request
|
||||
try:
|
||||
raise Exception("Request error")
|
||||
except Exception as e:
|
||||
global captured_exc_info
|
||||
captured_exc_info = sys.exc_info()
|
||||
raise e
|
||||
def capture_and_throw(view_func: ViewFuncT) -> ViewFuncT:
|
||||
@wraps(view_func)
|
||||
def wrapped_view(request: HttpRequest, *args: object, **kwargs: object) -> NoReturn:
|
||||
global captured_request
|
||||
captured_request = request
|
||||
try:
|
||||
raise Exception("Request error")
|
||||
except Exception as e:
|
||||
global captured_exc_info
|
||||
captured_exc_info = sys.exc_info()
|
||||
raise e
|
||||
|
||||
return cast(ViewFuncT, wrapped_view) # https://github.com/python/mypy/issues/1927
|
||||
|
||||
return wrapper
|
||||
return cast(ViewFuncT, wrapped_view) # https://github.com/python/mypy/issues/1927
|
||||
|
||||
|
||||
class AdminNotifyHandlerTest(ZulipTestCase):
|
||||
|
@ -78,17 +75,18 @@ class AdminNotifyHandlerTest(ZulipTestCase):
|
|||
|
||||
def simulate_error(self) -> logging.LogRecord:
|
||||
self.login("hamlet")
|
||||
with patch("zerver.decorator.rate_limit") as rate_limit_patch, self.assertLogs(
|
||||
with patch(
|
||||
"zerver.lib.rest.authenticated_json_view", side_effect=capture_and_throw
|
||||
) as view_decorator_patch, self.assertLogs(
|
||||
"django.request", level="ERROR"
|
||||
) as request_error_log, self.assertLogs(
|
||||
"zerver.middleware.json_error_handler", level="ERROR"
|
||||
) as json_error_handler_log, self.settings(
|
||||
TEST_SUITE=False
|
||||
):
|
||||
rate_limit_patch.side_effect = capture_and_throw
|
||||
result = self.client_get("/json/users")
|
||||
self.assert_json_error(result, "Internal server error", status_code=500)
|
||||
rate_limit_patch.assert_called_once()
|
||||
view_decorator_patch.assert_called_once()
|
||||
self.assertEqual(
|
||||
request_error_log.output, ["ERROR:django.request:Internal Server Error: /json/users"]
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue