mirror of https://github.com/zulip/zulip.git
rate_limit: Stop wrapping rate limited functions.
This refactors `rate_limit` so that we no longer use it as a decorator. This is a workaround to https://github.com/python/mypy/issues/12909 as `rate_limit` previous expects different parameters than its callers. Our approach to test logging handlers also needs to be updated because the view function is not decorated by `rate_limit`. Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
parent
cfa4973441
commit
232ba4866a
|
@ -533,7 +533,8 @@ def add_logging_data(
|
||||||
request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs
|
request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs
|
||||||
) -> HttpResponse:
|
) -> HttpResponse:
|
||||||
process_client(request, request.user, is_browser_view=True, query=view_func.__name__)
|
process_client(request, request.user, is_browser_view=True, query=view_func.__name__)
|
||||||
return rate_limit()(view_func)(request, *args, **kwargs)
|
rate_limit(request)
|
||||||
|
return view_func(request, *args, **kwargs)
|
||||||
|
|
||||||
return _wrapped_view_func
|
return _wrapped_view_func
|
||||||
|
|
||||||
|
@ -724,10 +725,8 @@ def authenticated_uploads_api_view(
|
||||||
) -> HttpResponse:
|
) -> HttpResponse:
|
||||||
user_profile = validate_api_key(request, None, api_key, False)
|
user_profile = validate_api_key(request, None, api_key, False)
|
||||||
if not skip_rate_limiting:
|
if not skip_rate_limiting:
|
||||||
limited_func = rate_limit()(view_func)
|
rate_limit(request)
|
||||||
else:
|
return view_func(request, user_profile, *args, **kwargs)
|
||||||
limited_func = view_func
|
|
||||||
return limited_func(request, user_profile, *args, **kwargs)
|
|
||||||
|
|
||||||
return _wrapped_func_arguments
|
return _wrapped_func_arguments
|
||||||
|
|
||||||
|
@ -788,10 +787,8 @@ def authenticated_rest_api_view(
|
||||||
try:
|
try:
|
||||||
if not skip_rate_limiting:
|
if not skip_rate_limiting:
|
||||||
# Apply rate limiting
|
# Apply rate limiting
|
||||||
target_view_func = rate_limit()(view_func)
|
rate_limit(request)
|
||||||
else:
|
return view_func(request, profile, *args, **kwargs)
|
||||||
target_view_func = view_func
|
|
||||||
return target_view_func(request, profile, *args, **kwargs)
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
if not webhook_client_name:
|
if not webhook_client_name:
|
||||||
raise err
|
raise err
|
||||||
|
@ -865,9 +862,7 @@ def authenticate_log_and_execute_json(
|
||||||
**kwargs: object,
|
**kwargs: object,
|
||||||
) -> HttpResponse:
|
) -> HttpResponse:
|
||||||
if not skip_rate_limiting:
|
if not skip_rate_limiting:
|
||||||
limited_view_func = rate_limit()(view_func)
|
rate_limit(request)
|
||||||
else:
|
|
||||||
limited_view_func = view_func
|
|
||||||
|
|
||||||
if not request.user.is_authenticated:
|
if not request.user.is_authenticated:
|
||||||
if not allow_unauthenticated:
|
if not allow_unauthenticated:
|
||||||
|
@ -878,7 +873,7 @@ def authenticate_log_and_execute_json(
|
||||||
is_browser_view=True,
|
is_browser_view=True,
|
||||||
query=view_func.__name__,
|
query=view_func.__name__,
|
||||||
)
|
)
|
||||||
return limited_view_func(request, request.user, *args, **kwargs)
|
return view_func(request, request.user, *args, **kwargs)
|
||||||
|
|
||||||
user_profile = request.user
|
user_profile = request.user
|
||||||
validate_account_and_subdomain(request, user_profile)
|
validate_account_and_subdomain(request, user_profile)
|
||||||
|
@ -887,7 +882,7 @@ def authenticate_log_and_execute_json(
|
||||||
raise JsonableError(_("Webhook bots can only access webhooks"))
|
raise JsonableError(_("Webhook bots can only access webhooks"))
|
||||||
|
|
||||||
process_client(request, user_profile, is_browser_view=True, query=view_func.__name__)
|
process_client(request, user_profile, is_browser_view=True, query=view_func.__name__)
|
||||||
return limited_view_func(request, user_profile, *args, **kwargs)
|
return view_func(request, user_profile, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
# Checks if the user is logged in. If not, return an error (the
|
# Checks if the user is logged in. If not, return an error (the
|
||||||
|
@ -1072,39 +1067,23 @@ def rate_limit_remote_server(
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def rate_limit() -> Callable[[ViewFuncT], ViewFuncT]:
|
def rate_limit(request: HttpRequest) -> None:
|
||||||
"""Rate-limits a view. Returns a decorator"""
|
if not settings.RATE_LIMITING:
|
||||||
|
return
|
||||||
|
|
||||||
def wrapper(func: ViewFuncT) -> ViewFuncT:
|
if client_is_exempt_from_rate_limiting(request):
|
||||||
@wraps(func)
|
return
|
||||||
def wrapped_func(request: HttpRequest, *args: object, **kwargs: object) -> HttpResponse:
|
|
||||||
|
|
||||||
# It is really tempting to not even wrap our original function
|
user = request.user
|
||||||
# when settings.RATE_LIMITING is False, but it would make
|
remote_server = RequestNotes.get_notes(request).remote_server
|
||||||
# for awkward unit testing in some situations.
|
|
||||||
if not settings.RATE_LIMITING:
|
|
||||||
return func(request, *args, **kwargs)
|
|
||||||
|
|
||||||
if client_is_exempt_from_rate_limiting(request):
|
if settings.ZILENCER_ENABLED and remote_server is not None:
|
||||||
return func(request, *args, **kwargs)
|
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
|
||||||
|
elif not user.is_authenticated:
|
||||||
user = request.user
|
rate_limit_request_by_ip(request, domain="api_by_ip")
|
||||||
remote_server = RequestNotes.get_notes(request).remote_server
|
else:
|
||||||
|
assert isinstance(user, UserProfile)
|
||||||
if settings.ZILENCER_ENABLED and remote_server is not None:
|
rate_limit_user(request, user, domain="api_by_user")
|
||||||
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
|
|
||||||
elif not user.is_authenticated:
|
|
||||||
rate_limit_request_by_ip(request, domain="api_by_ip")
|
|
||||||
return func(request, *args, **kwargs)
|
|
||||||
else:
|
|
||||||
assert isinstance(user, UserProfile)
|
|
||||||
rate_limit_user(request, user, domain="api_by_user")
|
|
||||||
|
|
||||||
return func(request, *args, **kwargs)
|
|
||||||
|
|
||||||
return cast(ViewFuncT, wrapped_func) # https://github.com/python/mypy/issues/1927
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
def return_success_on_head_request(
|
def return_success_on_head_request(
|
||||||
|
|
|
@ -630,10 +630,9 @@ class DecoratorLoggingTestCase(ZulipTestCase):
|
||||||
class RateLimitTestCase(ZulipTestCase):
|
class RateLimitTestCase(ZulipTestCase):
|
||||||
def get_ratelimited_view(self) -> Callable[..., HttpResponse]:
|
def get_ratelimited_view(self) -> Callable[..., HttpResponse]:
|
||||||
def f(req: Any) -> HttpResponse:
|
def f(req: Any) -> HttpResponse:
|
||||||
|
rate_limit(req)
|
||||||
return json_response(msg="some value")
|
return json_response(msg="some value")
|
||||||
|
|
||||||
f = rate_limit()(f)
|
|
||||||
|
|
||||||
return f
|
return f
|
||||||
|
|
||||||
def errors_disallowed(self) -> Any:
|
def errors_disallowed(self) -> Any:
|
||||||
|
|
|
@ -2,7 +2,7 @@ import logging
|
||||||
import sys
|
import sys
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import Callable, Dict, Iterator, NoReturn, Optional, Tuple, Type, Union, cast
|
from typing import Dict, Iterator, NoReturn, Optional, Tuple, Type, Union, cast
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
@ -22,22 +22,19 @@ captured_exc_info: Optional[
|
||||||
] = None
|
] = None
|
||||||
|
|
||||||
|
|
||||||
def capture_and_throw(domain: Optional[str] = None) -> Callable[[ViewFuncT], ViewFuncT]:
|
def capture_and_throw(view_func: ViewFuncT) -> ViewFuncT:
|
||||||
def wrapper(view_func: ViewFuncT) -> ViewFuncT:
|
@wraps(view_func)
|
||||||
@wraps(view_func)
|
def wrapped_view(request: HttpRequest, *args: object, **kwargs: object) -> NoReturn:
|
||||||
def wrapped_view(request: HttpRequest, *args: object, **kwargs: object) -> NoReturn:
|
global captured_request
|
||||||
global captured_request
|
captured_request = request
|
||||||
captured_request = request
|
try:
|
||||||
try:
|
raise Exception("Request error")
|
||||||
raise Exception("Request error")
|
except Exception as e:
|
||||||
except Exception as e:
|
global captured_exc_info
|
||||||
global captured_exc_info
|
captured_exc_info = sys.exc_info()
|
||||||
captured_exc_info = sys.exc_info()
|
raise e
|
||||||
raise e
|
|
||||||
|
|
||||||
return cast(ViewFuncT, wrapped_view) # https://github.com/python/mypy/issues/1927
|
return cast(ViewFuncT, wrapped_view) # https://github.com/python/mypy/issues/1927
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
class AdminNotifyHandlerTest(ZulipTestCase):
|
class AdminNotifyHandlerTest(ZulipTestCase):
|
||||||
|
@ -78,17 +75,18 @@ class AdminNotifyHandlerTest(ZulipTestCase):
|
||||||
|
|
||||||
def simulate_error(self) -> logging.LogRecord:
|
def simulate_error(self) -> logging.LogRecord:
|
||||||
self.login("hamlet")
|
self.login("hamlet")
|
||||||
with patch("zerver.decorator.rate_limit") as rate_limit_patch, self.assertLogs(
|
with patch(
|
||||||
|
"zerver.lib.rest.authenticated_json_view", side_effect=capture_and_throw
|
||||||
|
) as view_decorator_patch, self.assertLogs(
|
||||||
"django.request", level="ERROR"
|
"django.request", level="ERROR"
|
||||||
) as request_error_log, self.assertLogs(
|
) as request_error_log, self.assertLogs(
|
||||||
"zerver.middleware.json_error_handler", level="ERROR"
|
"zerver.middleware.json_error_handler", level="ERROR"
|
||||||
) as json_error_handler_log, self.settings(
|
) as json_error_handler_log, self.settings(
|
||||||
TEST_SUITE=False
|
TEST_SUITE=False
|
||||||
):
|
):
|
||||||
rate_limit_patch.side_effect = capture_and_throw
|
|
||||||
result = self.client_get("/json/users")
|
result = self.client_get("/json/users")
|
||||||
self.assert_json_error(result, "Internal server error", status_code=500)
|
self.assert_json_error(result, "Internal server error", status_code=500)
|
||||||
rate_limit_patch.assert_called_once()
|
view_decorator_patch.assert_called_once()
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
request_error_log.output, ["ERROR:django.request:Internal Server Error: /json/users"]
|
request_error_log.output, ["ERROR:django.request:Internal Server Error: /json/users"]
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue