rate_limit: Replace rate_limit with inlined rate limit checks.

This change incorporate should_rate_limit into rate_limit_user and
rate_limit_request_by_ip. Note a slight behavior change to other callers
to rate_limit_request_by_ip is made as we now check if the client is
eligible to be exempted from rate limiting now, which was previously
only done as a part of zerver.lib.rate_limiter.rate_limit.

Now we mock zerver.lib.rate_limiter.RateLimitedUser instead of
zerver.decorator.rate_limit_user in
zerver.tests.test_decorators.RateLimitTestCase, because rate_limit_user
will always be called but rate limit only happens the should_rate_limit
check passes;

we can continue to mock zerver.lib.rate_limiter.rate_limit_ip, because the
decorated view functions call rate_limit_request_by_ip that calls
rate_limit_ip when the should_rate_limit check passes.

We need to mock zerver.decorator.rate_limit_user for SkipRateLimitingTest
now because rate_limit has been removed. We don't need to mock
RateLimitedUser in this case because we are only verifying that
the skip_rate_limiting flag works.

To ensure coverage in add_logging_data, a new test case is added to use
a web_public_view (which decorates the view function with
add_logging_data) with a new flag to check_rate_limit_public_or_user_views.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
Zixuan James Li 2022-08-14 15:14:52 -04:00 committed by Tim Abbott
parent 2aac1dc40a
commit 26a518267a
4 changed files with 53 additions and 36 deletions

View File

@ -53,7 +53,7 @@ from zerver.lib.exceptions import (
WebhookError,
)
from zerver.lib.queue import queue_json_publish
from zerver.lib.rate_limiter import is_local_addr, rate_limit, rate_limit_user
from zerver.lib.rate_limiter import is_local_addr, rate_limit_request_by_ip, rate_limit_user
from zerver.lib.request import REQ, RequestNotes, has_request_variables
from zerver.lib.response import json_method_not_allowed, json_success
from zerver.lib.subdomains import get_subdomain, user_matches_subdomain
@ -347,7 +347,6 @@ def webhook_view(
client_name=full_webhook_client_name(webhook_client_name),
)
if settings.RATE_LIMITING:
rate_limit_user(request, user_profile, domain="api_by_user")
try:
return view_func(request, user_profile, *args, **kwargs)
@ -481,7 +480,12 @@ def add_logging_data(
request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs
) -> HttpResponse:
process_client(request, request.user, is_browser_view=True, query=view_func.__name__)
rate_limit(request)
if request.user.is_authenticated:
rate_limit_user(request, request.user, domain="api_by_user")
else:
rate_limit_request_by_ip(request, domain="api_by_ip")
return view_func(request, *args, **kwargs)
return _wrapped_view_func
@ -673,7 +677,7 @@ def authenticated_uploads_api_view(
) -> HttpResponse:
user_profile = validate_api_key(request, None, api_key, False)
if not skip_rate_limiting:
rate_limit(request)
rate_limit_user(request, user_profile, domain="api_by_user")
return view_func(request, user_profile, *args, **kwargs)
return _wrapped_func_arguments
@ -751,8 +755,7 @@ def authenticated_rest_api_view(
raise UnauthorizedError(e.msg)
try:
if not skip_rate_limiting:
# Apply rate limiting
rate_limit(request)
rate_limit_user(request, profile, domain="api_by_user")
return view_func(request, profile, *args, **kwargs)
except Exception as err:
if not webhook_client_name:
@ -839,7 +842,7 @@ def public_json_view(
# Otherwise, process the request for a logged-out visitor.
if not skip_rate_limiting:
rate_limit(request)
rate_limit_request_by_ip(request, domain="api_by_ip")
process_client(
request,
@ -870,7 +873,7 @@ def authenticated_json_view(
user_profile = request.user
if not skip_rate_limiting:
rate_limit(request)
rate_limit_user(request, user_profile, domain="api_by_user")
validate_account_and_subdomain(request, user_profile)

View File

@ -582,6 +582,8 @@ def rate_limit_user(request: HttpRequest, user: UserProfile, domain: str) -> Non
"""Returns whether or not a user was rate limited. Will raise a RateLimited exception
if the user has been rate limited, otherwise returns and modifies request to contain
the rate limit information"""
if not should_rate_limit(request):
return
RateLimitedUser(user, domain=domain).rate_limit_request(request)
@ -591,6 +593,9 @@ def rate_limit_ip(request: HttpRequest, ip_addr: str, domain: str) -> None:
def rate_limit_request_by_ip(request: HttpRequest, domain: str) -> None:
if not should_rate_limit(request):
return
# REMOTE_ADDR is set by SetRemoteAddrFromRealIpHeader in conjunction
# with the nginx configuration to guarantee this to be *the* correct
# IP address to use - without worrying we'll grab the IP of a proxy.
@ -625,16 +630,3 @@ def should_rate_limit(request: HttpRequest) -> bool:
return False
return True
def rate_limit(request: HttpRequest) -> None:
if not should_rate_limit(request):
return
user = request.user
if not user.is_authenticated:
rate_limit_request_by_ip(request, domain="api_by_ip")
else:
assert isinstance(user, UserProfile)
rate_limit_user(request, user, domain="api_by_user")

View File

@ -31,6 +31,7 @@ from zerver.decorator import (
public_json_view,
return_success_on_head_request,
validate_api_key,
web_public_view,
webhook_view,
zulip_login_required,
zulip_otp_required_if_logged_in,
@ -467,7 +468,7 @@ class SkipRateLimitingTest(ZulipTestCase):
request = HostRequestMock(host="zulip.testserver")
request.META["HTTP_AUTHORIZATION"] = self.encode_email(self.example_email("hamlet"))
request.method = "POST"
with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock:
with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock:
result = my_unlimited_view(request)
self.assert_json_success(result)
@ -476,7 +477,7 @@ class SkipRateLimitingTest(ZulipTestCase):
request = HostRequestMock(host="zulip.testserver")
request.META["HTTP_AUTHORIZATION"] = self.encode_email(self.example_email("hamlet"))
request.method = "POST"
with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock:
with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock:
result = my_rate_limited_view(request)
# Don't assert json_success, since it'll be the rate_limit mock object
@ -494,7 +495,7 @@ class SkipRateLimitingTest(ZulipTestCase):
request = HostRequestMock(host="zulip.testserver")
request.method = "POST"
request.POST["api_key"] = get_api_key(self.example_user("hamlet"))
with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock:
with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock:
result = my_unlimited_view(request)
self.assert_json_success(result)
@ -503,7 +504,7 @@ class SkipRateLimitingTest(ZulipTestCase):
request = HostRequestMock(host="zulip.testserver")
request.method = "POST"
request.POST["api_key"] = get_api_key(self.example_user("hamlet"))
with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock:
with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock:
result = my_rate_limited_view(request)
# Don't assert json_success, since it'll be the rate_limit mock object
@ -519,7 +520,7 @@ class SkipRateLimitingTest(ZulipTestCase):
request = HostRequestMock(host="zulip.testserver")
request.method = "POST"
request.user = self.example_user("hamlet")
with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock:
with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock:
result = my_unlimited_view(request)
self.assert_json_success(result)
@ -528,7 +529,7 @@ class SkipRateLimitingTest(ZulipTestCase):
request = HostRequestMock(host="zulip.testserver")
request.method = "POST"
request.user = self.example_user("hamlet")
with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock:
with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock:
result = my_rate_limited_view(request)
# Don't assert json_success, since it'll be the rate_limit mock object
@ -631,11 +632,16 @@ class DecoratorLoggingTestCase(ZulipTestCase):
class RateLimitTestCase(ZulipTestCase):
@staticmethod
@public_json_view
def public_view(
def ratelimited_json_view(
req: HttpRequest, maybe_user_profile: Union[AnonymousUser, UserProfile], /
) -> HttpResponse:
return HttpResponse("some value")
@staticmethod
@web_public_view
def ratelimited_web_view(req: HttpRequest) -> HttpResponse:
return HttpResponse("some value")
def errors_disallowed(self) -> Any:
# Due to what is probably a hack in rate_limit(),
# some tests will give a false positive (or succeed
@ -648,18 +654,23 @@ class RateLimitTestCase(ZulipTestCase):
return mock.patch("logging.error", side_effect=TestLoggingErrorException)
def check_rate_limit_public_or_user_views(
self, remote_addr: str, client_name: str, expect_rate_limit: bool
self,
remote_addr: str,
client_name: str,
expect_rate_limit: bool,
check_web_view: bool = False,
) -> None:
META = {"REMOTE_ADDR": remote_addr, "PATH_INFO": "test"}
request = HostRequestMock(host="zulip.testserver", client_name=client_name, meta_data=META)
view_func = self.ratelimited_web_view if check_web_view else self.ratelimited_json_view
with mock.patch(
"zerver.lib.rate_limiter.rate_limit_user"
"zerver.lib.rate_limiter.RateLimitedUser"
) as rate_limit_user_mock, mock.patch(
"zerver.lib.rate_limiter.rate_limit_ip"
) as rate_limit_ip_mock, self.errors_disallowed():
self.assert_in_success_response(["some value"], self.public_view(request))
self.assert_in_success_response(["some value"], view_func(request))
self.assertEqual(rate_limit_ip_mock.called, expect_rate_limit)
self.assertFalse(rate_limit_user_mock.called)
@ -671,11 +682,11 @@ class RateLimitTestCase(ZulipTestCase):
user_profile=user, host="zulip.testserver", client_name=client_name, meta_data=META
)
with mock.patch(
"zerver.lib.rate_limiter.rate_limit_user"
"zerver.lib.rate_limiter.RateLimitedUser"
) as rate_limit_user_mock, mock.patch(
"zerver.lib.rate_limiter.rate_limit_ip"
) as rate_limit_ip_mock, self.errors_disallowed():
self.assert_in_success_response(["some value"], self.public_view(request))
self.assert_in_success_response(["some value"], view_func(request))
self.assertEqual(rate_limit_user_mock.called, expect_rate_limit)
self.assertFalse(rate_limit_ip_mock.called)
@ -705,6 +716,15 @@ class RateLimitTestCase(ZulipTestCase):
remote_addr="3.3.3.3", client_name="external", expect_rate_limit=True
)
def test_rate_limiting_web_public_views(self) -> None:
with self.settings(RATE_LIMITING=True):
self.check_rate_limit_public_or_user_views(
remote_addr="3.3.3.3",
client_name="external",
expect_rate_limit=True,
check_web_view=True,
)
@skipUnless(settings.ZILENCER_ENABLED, "requires zilencer")
def test_rate_limiting_happens_if_remote_server(self) -> None:
user = self.example_user("hamlet")

View File

@ -54,6 +54,9 @@ class InvalidZulipServerKeyError(InvalidZulipServerError):
def rate_limit_remote_server(
request: HttpRequest, remote_server: RemoteZulipServer, domain: str
) -> None:
if not should_rate_limit(request):
return
try:
RateLimitedRemoteZulipServer(remote_server, domain=domain).rate_limit_request(request)
except RateLimited as e:
@ -98,7 +101,6 @@ def authenticated_remote_server_view(
except JsonableError as e:
raise UnauthorizedError(e.msg)
if should_rate_limit(request):
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
return view_func(request, remote_server, *args, **kwargs)