mirror of https://github.com/zulip/zulip.git
rate_limit: Replace rate_limit with inlined rate limit checks.
This change incorporate should_rate_limit into rate_limit_user and rate_limit_request_by_ip. Note a slight behavior change to other callers to rate_limit_request_by_ip is made as we now check if the client is eligible to be exempted from rate limiting now, which was previously only done as a part of zerver.lib.rate_limiter.rate_limit. Now we mock zerver.lib.rate_limiter.RateLimitedUser instead of zerver.decorator.rate_limit_user in zerver.tests.test_decorators.RateLimitTestCase, because rate_limit_user will always be called but rate limit only happens the should_rate_limit check passes; we can continue to mock zerver.lib.rate_limiter.rate_limit_ip, because the decorated view functions call rate_limit_request_by_ip that calls rate_limit_ip when the should_rate_limit check passes. We need to mock zerver.decorator.rate_limit_user for SkipRateLimitingTest now because rate_limit has been removed. We don't need to mock RateLimitedUser in this case because we are only verifying that the skip_rate_limiting flag works. To ensure coverage in add_logging_data, a new test case is added to use a web_public_view (which decorates the view function with add_logging_data) with a new flag to check_rate_limit_public_or_user_views. Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
parent
2aac1dc40a
commit
26a518267a
|
@ -53,7 +53,7 @@ from zerver.lib.exceptions import (
|
||||||
WebhookError,
|
WebhookError,
|
||||||
)
|
)
|
||||||
from zerver.lib.queue import queue_json_publish
|
from zerver.lib.queue import queue_json_publish
|
||||||
from zerver.lib.rate_limiter import is_local_addr, rate_limit, rate_limit_user
|
from zerver.lib.rate_limiter import is_local_addr, rate_limit_request_by_ip, rate_limit_user
|
||||||
from zerver.lib.request import REQ, RequestNotes, has_request_variables
|
from zerver.lib.request import REQ, RequestNotes, has_request_variables
|
||||||
from zerver.lib.response import json_method_not_allowed, json_success
|
from zerver.lib.response import json_method_not_allowed, json_success
|
||||||
from zerver.lib.subdomains import get_subdomain, user_matches_subdomain
|
from zerver.lib.subdomains import get_subdomain, user_matches_subdomain
|
||||||
|
@ -347,8 +347,7 @@ def webhook_view(
|
||||||
client_name=full_webhook_client_name(webhook_client_name),
|
client_name=full_webhook_client_name(webhook_client_name),
|
||||||
)
|
)
|
||||||
|
|
||||||
if settings.RATE_LIMITING:
|
rate_limit_user(request, user_profile, domain="api_by_user")
|
||||||
rate_limit_user(request, user_profile, domain="api_by_user")
|
|
||||||
try:
|
try:
|
||||||
return view_func(request, user_profile, *args, **kwargs)
|
return view_func(request, user_profile, *args, **kwargs)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
|
@ -481,7 +480,12 @@ def add_logging_data(
|
||||||
request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs
|
request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs
|
||||||
) -> HttpResponse:
|
) -> HttpResponse:
|
||||||
process_client(request, request.user, is_browser_view=True, query=view_func.__name__)
|
process_client(request, request.user, is_browser_view=True, query=view_func.__name__)
|
||||||
rate_limit(request)
|
|
||||||
|
if request.user.is_authenticated:
|
||||||
|
rate_limit_user(request, request.user, domain="api_by_user")
|
||||||
|
else:
|
||||||
|
rate_limit_request_by_ip(request, domain="api_by_ip")
|
||||||
|
|
||||||
return view_func(request, *args, **kwargs)
|
return view_func(request, *args, **kwargs)
|
||||||
|
|
||||||
return _wrapped_view_func
|
return _wrapped_view_func
|
||||||
|
@ -673,7 +677,7 @@ def authenticated_uploads_api_view(
|
||||||
) -> HttpResponse:
|
) -> HttpResponse:
|
||||||
user_profile = validate_api_key(request, None, api_key, False)
|
user_profile = validate_api_key(request, None, api_key, False)
|
||||||
if not skip_rate_limiting:
|
if not skip_rate_limiting:
|
||||||
rate_limit(request)
|
rate_limit_user(request, user_profile, domain="api_by_user")
|
||||||
return view_func(request, user_profile, *args, **kwargs)
|
return view_func(request, user_profile, *args, **kwargs)
|
||||||
|
|
||||||
return _wrapped_func_arguments
|
return _wrapped_func_arguments
|
||||||
|
@ -751,8 +755,7 @@ def authenticated_rest_api_view(
|
||||||
raise UnauthorizedError(e.msg)
|
raise UnauthorizedError(e.msg)
|
||||||
try:
|
try:
|
||||||
if not skip_rate_limiting:
|
if not skip_rate_limiting:
|
||||||
# Apply rate limiting
|
rate_limit_user(request, profile, domain="api_by_user")
|
||||||
rate_limit(request)
|
|
||||||
return view_func(request, profile, *args, **kwargs)
|
return view_func(request, profile, *args, **kwargs)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
if not webhook_client_name:
|
if not webhook_client_name:
|
||||||
|
@ -839,7 +842,7 @@ def public_json_view(
|
||||||
|
|
||||||
# Otherwise, process the request for a logged-out visitor.
|
# Otherwise, process the request for a logged-out visitor.
|
||||||
if not skip_rate_limiting:
|
if not skip_rate_limiting:
|
||||||
rate_limit(request)
|
rate_limit_request_by_ip(request, domain="api_by_ip")
|
||||||
|
|
||||||
process_client(
|
process_client(
|
||||||
request,
|
request,
|
||||||
|
@ -870,7 +873,7 @@ def authenticated_json_view(
|
||||||
|
|
||||||
user_profile = request.user
|
user_profile = request.user
|
||||||
if not skip_rate_limiting:
|
if not skip_rate_limiting:
|
||||||
rate_limit(request)
|
rate_limit_user(request, user_profile, domain="api_by_user")
|
||||||
|
|
||||||
validate_account_and_subdomain(request, user_profile)
|
validate_account_and_subdomain(request, user_profile)
|
||||||
|
|
||||||
|
|
|
@ -582,6 +582,8 @@ def rate_limit_user(request: HttpRequest, user: UserProfile, domain: str) -> Non
|
||||||
"""Returns whether or not a user was rate limited. Will raise a RateLimited exception
|
"""Returns whether or not a user was rate limited. Will raise a RateLimited exception
|
||||||
if the user has been rate limited, otherwise returns and modifies request to contain
|
if the user has been rate limited, otherwise returns and modifies request to contain
|
||||||
the rate limit information"""
|
the rate limit information"""
|
||||||
|
if not should_rate_limit(request):
|
||||||
|
return
|
||||||
|
|
||||||
RateLimitedUser(user, domain=domain).rate_limit_request(request)
|
RateLimitedUser(user, domain=domain).rate_limit_request(request)
|
||||||
|
|
||||||
|
@ -591,6 +593,9 @@ def rate_limit_ip(request: HttpRequest, ip_addr: str, domain: str) -> None:
|
||||||
|
|
||||||
|
|
||||||
def rate_limit_request_by_ip(request: HttpRequest, domain: str) -> None:
|
def rate_limit_request_by_ip(request: HttpRequest, domain: str) -> None:
|
||||||
|
if not should_rate_limit(request):
|
||||||
|
return
|
||||||
|
|
||||||
# REMOTE_ADDR is set by SetRemoteAddrFromRealIpHeader in conjunction
|
# REMOTE_ADDR is set by SetRemoteAddrFromRealIpHeader in conjunction
|
||||||
# with the nginx configuration to guarantee this to be *the* correct
|
# with the nginx configuration to guarantee this to be *the* correct
|
||||||
# IP address to use - without worrying we'll grab the IP of a proxy.
|
# IP address to use - without worrying we'll grab the IP of a proxy.
|
||||||
|
@ -625,16 +630,3 @@ def should_rate_limit(request: HttpRequest) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def rate_limit(request: HttpRequest) -> None:
|
|
||||||
if not should_rate_limit(request):
|
|
||||||
return
|
|
||||||
|
|
||||||
user = request.user
|
|
||||||
|
|
||||||
if not user.is_authenticated:
|
|
||||||
rate_limit_request_by_ip(request, domain="api_by_ip")
|
|
||||||
else:
|
|
||||||
assert isinstance(user, UserProfile)
|
|
||||||
rate_limit_user(request, user, domain="api_by_user")
|
|
||||||
|
|
|
@ -31,6 +31,7 @@ from zerver.decorator import (
|
||||||
public_json_view,
|
public_json_view,
|
||||||
return_success_on_head_request,
|
return_success_on_head_request,
|
||||||
validate_api_key,
|
validate_api_key,
|
||||||
|
web_public_view,
|
||||||
webhook_view,
|
webhook_view,
|
||||||
zulip_login_required,
|
zulip_login_required,
|
||||||
zulip_otp_required_if_logged_in,
|
zulip_otp_required_if_logged_in,
|
||||||
|
@ -467,7 +468,7 @@ class SkipRateLimitingTest(ZulipTestCase):
|
||||||
request = HostRequestMock(host="zulip.testserver")
|
request = HostRequestMock(host="zulip.testserver")
|
||||||
request.META["HTTP_AUTHORIZATION"] = self.encode_email(self.example_email("hamlet"))
|
request.META["HTTP_AUTHORIZATION"] = self.encode_email(self.example_email("hamlet"))
|
||||||
request.method = "POST"
|
request.method = "POST"
|
||||||
with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock:
|
with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock:
|
||||||
result = my_unlimited_view(request)
|
result = my_unlimited_view(request)
|
||||||
|
|
||||||
self.assert_json_success(result)
|
self.assert_json_success(result)
|
||||||
|
@ -476,7 +477,7 @@ class SkipRateLimitingTest(ZulipTestCase):
|
||||||
request = HostRequestMock(host="zulip.testserver")
|
request = HostRequestMock(host="zulip.testserver")
|
||||||
request.META["HTTP_AUTHORIZATION"] = self.encode_email(self.example_email("hamlet"))
|
request.META["HTTP_AUTHORIZATION"] = self.encode_email(self.example_email("hamlet"))
|
||||||
request.method = "POST"
|
request.method = "POST"
|
||||||
with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock:
|
with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock:
|
||||||
result = my_rate_limited_view(request)
|
result = my_rate_limited_view(request)
|
||||||
|
|
||||||
# Don't assert json_success, since it'll be the rate_limit mock object
|
# Don't assert json_success, since it'll be the rate_limit mock object
|
||||||
|
@ -494,7 +495,7 @@ class SkipRateLimitingTest(ZulipTestCase):
|
||||||
request = HostRequestMock(host="zulip.testserver")
|
request = HostRequestMock(host="zulip.testserver")
|
||||||
request.method = "POST"
|
request.method = "POST"
|
||||||
request.POST["api_key"] = get_api_key(self.example_user("hamlet"))
|
request.POST["api_key"] = get_api_key(self.example_user("hamlet"))
|
||||||
with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock:
|
with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock:
|
||||||
result = my_unlimited_view(request)
|
result = my_unlimited_view(request)
|
||||||
|
|
||||||
self.assert_json_success(result)
|
self.assert_json_success(result)
|
||||||
|
@ -503,7 +504,7 @@ class SkipRateLimitingTest(ZulipTestCase):
|
||||||
request = HostRequestMock(host="zulip.testserver")
|
request = HostRequestMock(host="zulip.testserver")
|
||||||
request.method = "POST"
|
request.method = "POST"
|
||||||
request.POST["api_key"] = get_api_key(self.example_user("hamlet"))
|
request.POST["api_key"] = get_api_key(self.example_user("hamlet"))
|
||||||
with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock:
|
with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock:
|
||||||
result = my_rate_limited_view(request)
|
result = my_rate_limited_view(request)
|
||||||
|
|
||||||
# Don't assert json_success, since it'll be the rate_limit mock object
|
# Don't assert json_success, since it'll be the rate_limit mock object
|
||||||
|
@ -519,7 +520,7 @@ class SkipRateLimitingTest(ZulipTestCase):
|
||||||
request = HostRequestMock(host="zulip.testserver")
|
request = HostRequestMock(host="zulip.testserver")
|
||||||
request.method = "POST"
|
request.method = "POST"
|
||||||
request.user = self.example_user("hamlet")
|
request.user = self.example_user("hamlet")
|
||||||
with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock:
|
with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock:
|
||||||
result = my_unlimited_view(request)
|
result = my_unlimited_view(request)
|
||||||
|
|
||||||
self.assert_json_success(result)
|
self.assert_json_success(result)
|
||||||
|
@ -528,7 +529,7 @@ class SkipRateLimitingTest(ZulipTestCase):
|
||||||
request = HostRequestMock(host="zulip.testserver")
|
request = HostRequestMock(host="zulip.testserver")
|
||||||
request.method = "POST"
|
request.method = "POST"
|
||||||
request.user = self.example_user("hamlet")
|
request.user = self.example_user("hamlet")
|
||||||
with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock:
|
with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock:
|
||||||
result = my_rate_limited_view(request)
|
result = my_rate_limited_view(request)
|
||||||
|
|
||||||
# Don't assert json_success, since it'll be the rate_limit mock object
|
# Don't assert json_success, since it'll be the rate_limit mock object
|
||||||
|
@ -631,11 +632,16 @@ class DecoratorLoggingTestCase(ZulipTestCase):
|
||||||
class RateLimitTestCase(ZulipTestCase):
|
class RateLimitTestCase(ZulipTestCase):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@public_json_view
|
@public_json_view
|
||||||
def public_view(
|
def ratelimited_json_view(
|
||||||
req: HttpRequest, maybe_user_profile: Union[AnonymousUser, UserProfile], /
|
req: HttpRequest, maybe_user_profile: Union[AnonymousUser, UserProfile], /
|
||||||
) -> HttpResponse:
|
) -> HttpResponse:
|
||||||
return HttpResponse("some value")
|
return HttpResponse("some value")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@web_public_view
|
||||||
|
def ratelimited_web_view(req: HttpRequest) -> HttpResponse:
|
||||||
|
return HttpResponse("some value")
|
||||||
|
|
||||||
def errors_disallowed(self) -> Any:
|
def errors_disallowed(self) -> Any:
|
||||||
# Due to what is probably a hack in rate_limit(),
|
# Due to what is probably a hack in rate_limit(),
|
||||||
# some tests will give a false positive (or succeed
|
# some tests will give a false positive (or succeed
|
||||||
|
@ -648,18 +654,23 @@ class RateLimitTestCase(ZulipTestCase):
|
||||||
return mock.patch("logging.error", side_effect=TestLoggingErrorException)
|
return mock.patch("logging.error", side_effect=TestLoggingErrorException)
|
||||||
|
|
||||||
def check_rate_limit_public_or_user_views(
|
def check_rate_limit_public_or_user_views(
|
||||||
self, remote_addr: str, client_name: str, expect_rate_limit: bool
|
self,
|
||||||
|
remote_addr: str,
|
||||||
|
client_name: str,
|
||||||
|
expect_rate_limit: bool,
|
||||||
|
check_web_view: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
META = {"REMOTE_ADDR": remote_addr, "PATH_INFO": "test"}
|
META = {"REMOTE_ADDR": remote_addr, "PATH_INFO": "test"}
|
||||||
|
|
||||||
request = HostRequestMock(host="zulip.testserver", client_name=client_name, meta_data=META)
|
request = HostRequestMock(host="zulip.testserver", client_name=client_name, meta_data=META)
|
||||||
|
view_func = self.ratelimited_web_view if check_web_view else self.ratelimited_json_view
|
||||||
|
|
||||||
with mock.patch(
|
with mock.patch(
|
||||||
"zerver.lib.rate_limiter.rate_limit_user"
|
"zerver.lib.rate_limiter.RateLimitedUser"
|
||||||
) as rate_limit_user_mock, mock.patch(
|
) as rate_limit_user_mock, mock.patch(
|
||||||
"zerver.lib.rate_limiter.rate_limit_ip"
|
"zerver.lib.rate_limiter.rate_limit_ip"
|
||||||
) as rate_limit_ip_mock, self.errors_disallowed():
|
) as rate_limit_ip_mock, self.errors_disallowed():
|
||||||
self.assert_in_success_response(["some value"], self.public_view(request))
|
self.assert_in_success_response(["some value"], view_func(request))
|
||||||
self.assertEqual(rate_limit_ip_mock.called, expect_rate_limit)
|
self.assertEqual(rate_limit_ip_mock.called, expect_rate_limit)
|
||||||
self.assertFalse(rate_limit_user_mock.called)
|
self.assertFalse(rate_limit_user_mock.called)
|
||||||
|
|
||||||
|
@ -671,11 +682,11 @@ class RateLimitTestCase(ZulipTestCase):
|
||||||
user_profile=user, host="zulip.testserver", client_name=client_name, meta_data=META
|
user_profile=user, host="zulip.testserver", client_name=client_name, meta_data=META
|
||||||
)
|
)
|
||||||
with mock.patch(
|
with mock.patch(
|
||||||
"zerver.lib.rate_limiter.rate_limit_user"
|
"zerver.lib.rate_limiter.RateLimitedUser"
|
||||||
) as rate_limit_user_mock, mock.patch(
|
) as rate_limit_user_mock, mock.patch(
|
||||||
"zerver.lib.rate_limiter.rate_limit_ip"
|
"zerver.lib.rate_limiter.rate_limit_ip"
|
||||||
) as rate_limit_ip_mock, self.errors_disallowed():
|
) as rate_limit_ip_mock, self.errors_disallowed():
|
||||||
self.assert_in_success_response(["some value"], self.public_view(request))
|
self.assert_in_success_response(["some value"], view_func(request))
|
||||||
self.assertEqual(rate_limit_user_mock.called, expect_rate_limit)
|
self.assertEqual(rate_limit_user_mock.called, expect_rate_limit)
|
||||||
self.assertFalse(rate_limit_ip_mock.called)
|
self.assertFalse(rate_limit_ip_mock.called)
|
||||||
|
|
||||||
|
@ -705,6 +716,15 @@ class RateLimitTestCase(ZulipTestCase):
|
||||||
remote_addr="3.3.3.3", client_name="external", expect_rate_limit=True
|
remote_addr="3.3.3.3", client_name="external", expect_rate_limit=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_rate_limiting_web_public_views(self) -> None:
|
||||||
|
with self.settings(RATE_LIMITING=True):
|
||||||
|
self.check_rate_limit_public_or_user_views(
|
||||||
|
remote_addr="3.3.3.3",
|
||||||
|
client_name="external",
|
||||||
|
expect_rate_limit=True,
|
||||||
|
check_web_view=True,
|
||||||
|
)
|
||||||
|
|
||||||
@skipUnless(settings.ZILENCER_ENABLED, "requires zilencer")
|
@skipUnless(settings.ZILENCER_ENABLED, "requires zilencer")
|
||||||
def test_rate_limiting_happens_if_remote_server(self) -> None:
|
def test_rate_limiting_happens_if_remote_server(self) -> None:
|
||||||
user = self.example_user("hamlet")
|
user = self.example_user("hamlet")
|
||||||
|
|
|
@ -54,6 +54,9 @@ class InvalidZulipServerKeyError(InvalidZulipServerError):
|
||||||
def rate_limit_remote_server(
|
def rate_limit_remote_server(
|
||||||
request: HttpRequest, remote_server: RemoteZulipServer, domain: str
|
request: HttpRequest, remote_server: RemoteZulipServer, domain: str
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if not should_rate_limit(request):
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
RateLimitedRemoteZulipServer(remote_server, domain=domain).rate_limit_request(request)
|
RateLimitedRemoteZulipServer(remote_server, domain=domain).rate_limit_request(request)
|
||||||
except RateLimited as e:
|
except RateLimited as e:
|
||||||
|
@ -98,8 +101,7 @@ def authenticated_remote_server_view(
|
||||||
except JsonableError as e:
|
except JsonableError as e:
|
||||||
raise UnauthorizedError(e.msg)
|
raise UnauthorizedError(e.msg)
|
||||||
|
|
||||||
if should_rate_limit(request):
|
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
|
||||||
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
|
|
||||||
return view_func(request, remote_server, *args, **kwargs)
|
return view_func(request, remote_server, *args, **kwargs)
|
||||||
|
|
||||||
return _wrapped_view_func
|
return _wrapped_view_func
|
||||||
|
|
Loading…
Reference in New Issue