rate_limit: Replace rate_limit with inlined rate limit checks.

This change incorporate should_rate_limit into rate_limit_user and
rate_limit_request_by_ip. Note a slight behavior change to other callers
to rate_limit_request_by_ip is made as we now check if the client is
eligible to be exempted from rate limiting now, which was previously
only done as a part of zerver.lib.rate_limiter.rate_limit.

Now we mock zerver.lib.rate_limiter.RateLimitedUser instead of
zerver.decorator.rate_limit_user in
zerver.tests.test_decorators.RateLimitTestCase, because rate_limit_user
will always be called but rate limit only happens the should_rate_limit
check passes;

we can continue to mock zerver.lib.rate_limiter.rate_limit_ip, because the
decorated view functions call rate_limit_request_by_ip that calls
rate_limit_ip when the should_rate_limit check passes.

We need to mock zerver.decorator.rate_limit_user for SkipRateLimitingTest
now because rate_limit has been removed. We don't need to mock
RateLimitedUser in this case because we are only verifying that
the skip_rate_limiting flag works.

To ensure coverage in add_logging_data, a new test case is added to use
a web_public_view (which decorates the view function with
add_logging_data) with a new flag to check_rate_limit_public_or_user_views.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
Zixuan James Li 2022-08-14 15:14:52 -04:00 committed by Tim Abbott
parent 2aac1dc40a
commit 26a518267a
4 changed files with 53 additions and 36 deletions

View File

@ -53,7 +53,7 @@ from zerver.lib.exceptions import (
WebhookError, WebhookError,
) )
from zerver.lib.queue import queue_json_publish from zerver.lib.queue import queue_json_publish
from zerver.lib.rate_limiter import is_local_addr, rate_limit, rate_limit_user from zerver.lib.rate_limiter import is_local_addr, rate_limit_request_by_ip, rate_limit_user
from zerver.lib.request import REQ, RequestNotes, has_request_variables from zerver.lib.request import REQ, RequestNotes, has_request_variables
from zerver.lib.response import json_method_not_allowed, json_success from zerver.lib.response import json_method_not_allowed, json_success
from zerver.lib.subdomains import get_subdomain, user_matches_subdomain from zerver.lib.subdomains import get_subdomain, user_matches_subdomain
@ -347,8 +347,7 @@ def webhook_view(
client_name=full_webhook_client_name(webhook_client_name), client_name=full_webhook_client_name(webhook_client_name),
) )
if settings.RATE_LIMITING: rate_limit_user(request, user_profile, domain="api_by_user")
rate_limit_user(request, user_profile, domain="api_by_user")
try: try:
return view_func(request, user_profile, *args, **kwargs) return view_func(request, user_profile, *args, **kwargs)
except Exception as err: except Exception as err:
@ -481,7 +480,12 @@ def add_logging_data(
request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs
) -> HttpResponse: ) -> HttpResponse:
process_client(request, request.user, is_browser_view=True, query=view_func.__name__) process_client(request, request.user, is_browser_view=True, query=view_func.__name__)
rate_limit(request)
if request.user.is_authenticated:
rate_limit_user(request, request.user, domain="api_by_user")
else:
rate_limit_request_by_ip(request, domain="api_by_ip")
return view_func(request, *args, **kwargs) return view_func(request, *args, **kwargs)
return _wrapped_view_func return _wrapped_view_func
@ -673,7 +677,7 @@ def authenticated_uploads_api_view(
) -> HttpResponse: ) -> HttpResponse:
user_profile = validate_api_key(request, None, api_key, False) user_profile = validate_api_key(request, None, api_key, False)
if not skip_rate_limiting: if not skip_rate_limiting:
rate_limit(request) rate_limit_user(request, user_profile, domain="api_by_user")
return view_func(request, user_profile, *args, **kwargs) return view_func(request, user_profile, *args, **kwargs)
return _wrapped_func_arguments return _wrapped_func_arguments
@ -751,8 +755,7 @@ def authenticated_rest_api_view(
raise UnauthorizedError(e.msg) raise UnauthorizedError(e.msg)
try: try:
if not skip_rate_limiting: if not skip_rate_limiting:
# Apply rate limiting rate_limit_user(request, profile, domain="api_by_user")
rate_limit(request)
return view_func(request, profile, *args, **kwargs) return view_func(request, profile, *args, **kwargs)
except Exception as err: except Exception as err:
if not webhook_client_name: if not webhook_client_name:
@ -839,7 +842,7 @@ def public_json_view(
# Otherwise, process the request for a logged-out visitor. # Otherwise, process the request for a logged-out visitor.
if not skip_rate_limiting: if not skip_rate_limiting:
rate_limit(request) rate_limit_request_by_ip(request, domain="api_by_ip")
process_client( process_client(
request, request,
@ -870,7 +873,7 @@ def authenticated_json_view(
user_profile = request.user user_profile = request.user
if not skip_rate_limiting: if not skip_rate_limiting:
rate_limit(request) rate_limit_user(request, user_profile, domain="api_by_user")
validate_account_and_subdomain(request, user_profile) validate_account_and_subdomain(request, user_profile)

View File

@ -582,6 +582,8 @@ def rate_limit_user(request: HttpRequest, user: UserProfile, domain: str) -> Non
"""Returns whether or not a user was rate limited. Will raise a RateLimited exception """Returns whether or not a user was rate limited. Will raise a RateLimited exception
if the user has been rate limited, otherwise returns and modifies request to contain if the user has been rate limited, otherwise returns and modifies request to contain
the rate limit information""" the rate limit information"""
if not should_rate_limit(request):
return
RateLimitedUser(user, domain=domain).rate_limit_request(request) RateLimitedUser(user, domain=domain).rate_limit_request(request)
@ -591,6 +593,9 @@ def rate_limit_ip(request: HttpRequest, ip_addr: str, domain: str) -> None:
def rate_limit_request_by_ip(request: HttpRequest, domain: str) -> None: def rate_limit_request_by_ip(request: HttpRequest, domain: str) -> None:
if not should_rate_limit(request):
return
# REMOTE_ADDR is set by SetRemoteAddrFromRealIpHeader in conjunction # REMOTE_ADDR is set by SetRemoteAddrFromRealIpHeader in conjunction
# with the nginx configuration to guarantee this to be *the* correct # with the nginx configuration to guarantee this to be *the* correct
# IP address to use - without worrying we'll grab the IP of a proxy. # IP address to use - without worrying we'll grab the IP of a proxy.
@ -625,16 +630,3 @@ def should_rate_limit(request: HttpRequest) -> bool:
return False return False
return True return True
def rate_limit(request: HttpRequest) -> None:
if not should_rate_limit(request):
return
user = request.user
if not user.is_authenticated:
rate_limit_request_by_ip(request, domain="api_by_ip")
else:
assert isinstance(user, UserProfile)
rate_limit_user(request, user, domain="api_by_user")

View File

@ -31,6 +31,7 @@ from zerver.decorator import (
public_json_view, public_json_view,
return_success_on_head_request, return_success_on_head_request,
validate_api_key, validate_api_key,
web_public_view,
webhook_view, webhook_view,
zulip_login_required, zulip_login_required,
zulip_otp_required_if_logged_in, zulip_otp_required_if_logged_in,
@ -467,7 +468,7 @@ class SkipRateLimitingTest(ZulipTestCase):
request = HostRequestMock(host="zulip.testserver") request = HostRequestMock(host="zulip.testserver")
request.META["HTTP_AUTHORIZATION"] = self.encode_email(self.example_email("hamlet")) request.META["HTTP_AUTHORIZATION"] = self.encode_email(self.example_email("hamlet"))
request.method = "POST" request.method = "POST"
with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock: with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock:
result = my_unlimited_view(request) result = my_unlimited_view(request)
self.assert_json_success(result) self.assert_json_success(result)
@ -476,7 +477,7 @@ class SkipRateLimitingTest(ZulipTestCase):
request = HostRequestMock(host="zulip.testserver") request = HostRequestMock(host="zulip.testserver")
request.META["HTTP_AUTHORIZATION"] = self.encode_email(self.example_email("hamlet")) request.META["HTTP_AUTHORIZATION"] = self.encode_email(self.example_email("hamlet"))
request.method = "POST" request.method = "POST"
with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock: with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock:
result = my_rate_limited_view(request) result = my_rate_limited_view(request)
# Don't assert json_success, since it'll be the rate_limit mock object # Don't assert json_success, since it'll be the rate_limit mock object
@ -494,7 +495,7 @@ class SkipRateLimitingTest(ZulipTestCase):
request = HostRequestMock(host="zulip.testserver") request = HostRequestMock(host="zulip.testserver")
request.method = "POST" request.method = "POST"
request.POST["api_key"] = get_api_key(self.example_user("hamlet")) request.POST["api_key"] = get_api_key(self.example_user("hamlet"))
with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock: with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock:
result = my_unlimited_view(request) result = my_unlimited_view(request)
self.assert_json_success(result) self.assert_json_success(result)
@ -503,7 +504,7 @@ class SkipRateLimitingTest(ZulipTestCase):
request = HostRequestMock(host="zulip.testserver") request = HostRequestMock(host="zulip.testserver")
request.method = "POST" request.method = "POST"
request.POST["api_key"] = get_api_key(self.example_user("hamlet")) request.POST["api_key"] = get_api_key(self.example_user("hamlet"))
with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock: with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock:
result = my_rate_limited_view(request) result = my_rate_limited_view(request)
# Don't assert json_success, since it'll be the rate_limit mock object # Don't assert json_success, since it'll be the rate_limit mock object
@ -519,7 +520,7 @@ class SkipRateLimitingTest(ZulipTestCase):
request = HostRequestMock(host="zulip.testserver") request = HostRequestMock(host="zulip.testserver")
request.method = "POST" request.method = "POST"
request.user = self.example_user("hamlet") request.user = self.example_user("hamlet")
with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock: with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock:
result = my_unlimited_view(request) result = my_unlimited_view(request)
self.assert_json_success(result) self.assert_json_success(result)
@ -528,7 +529,7 @@ class SkipRateLimitingTest(ZulipTestCase):
request = HostRequestMock(host="zulip.testserver") request = HostRequestMock(host="zulip.testserver")
request.method = "POST" request.method = "POST"
request.user = self.example_user("hamlet") request.user = self.example_user("hamlet")
with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock: with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock:
result = my_rate_limited_view(request) result = my_rate_limited_view(request)
# Don't assert json_success, since it'll be the rate_limit mock object # Don't assert json_success, since it'll be the rate_limit mock object
@ -631,11 +632,16 @@ class DecoratorLoggingTestCase(ZulipTestCase):
class RateLimitTestCase(ZulipTestCase): class RateLimitTestCase(ZulipTestCase):
@staticmethod @staticmethod
@public_json_view @public_json_view
def public_view( def ratelimited_json_view(
req: HttpRequest, maybe_user_profile: Union[AnonymousUser, UserProfile], / req: HttpRequest, maybe_user_profile: Union[AnonymousUser, UserProfile], /
) -> HttpResponse: ) -> HttpResponse:
return HttpResponse("some value") return HttpResponse("some value")
@staticmethod
@web_public_view
def ratelimited_web_view(req: HttpRequest) -> HttpResponse:
return HttpResponse("some value")
def errors_disallowed(self) -> Any: def errors_disallowed(self) -> Any:
# Due to what is probably a hack in rate_limit(), # Due to what is probably a hack in rate_limit(),
# some tests will give a false positive (or succeed # some tests will give a false positive (or succeed
@ -648,18 +654,23 @@ class RateLimitTestCase(ZulipTestCase):
return mock.patch("logging.error", side_effect=TestLoggingErrorException) return mock.patch("logging.error", side_effect=TestLoggingErrorException)
def check_rate_limit_public_or_user_views( def check_rate_limit_public_or_user_views(
self, remote_addr: str, client_name: str, expect_rate_limit: bool self,
remote_addr: str,
client_name: str,
expect_rate_limit: bool,
check_web_view: bool = False,
) -> None: ) -> None:
META = {"REMOTE_ADDR": remote_addr, "PATH_INFO": "test"} META = {"REMOTE_ADDR": remote_addr, "PATH_INFO": "test"}
request = HostRequestMock(host="zulip.testserver", client_name=client_name, meta_data=META) request = HostRequestMock(host="zulip.testserver", client_name=client_name, meta_data=META)
view_func = self.ratelimited_web_view if check_web_view else self.ratelimited_json_view
with mock.patch( with mock.patch(
"zerver.lib.rate_limiter.rate_limit_user" "zerver.lib.rate_limiter.RateLimitedUser"
) as rate_limit_user_mock, mock.patch( ) as rate_limit_user_mock, mock.patch(
"zerver.lib.rate_limiter.rate_limit_ip" "zerver.lib.rate_limiter.rate_limit_ip"
) as rate_limit_ip_mock, self.errors_disallowed(): ) as rate_limit_ip_mock, self.errors_disallowed():
self.assert_in_success_response(["some value"], self.public_view(request)) self.assert_in_success_response(["some value"], view_func(request))
self.assertEqual(rate_limit_ip_mock.called, expect_rate_limit) self.assertEqual(rate_limit_ip_mock.called, expect_rate_limit)
self.assertFalse(rate_limit_user_mock.called) self.assertFalse(rate_limit_user_mock.called)
@ -671,11 +682,11 @@ class RateLimitTestCase(ZulipTestCase):
user_profile=user, host="zulip.testserver", client_name=client_name, meta_data=META user_profile=user, host="zulip.testserver", client_name=client_name, meta_data=META
) )
with mock.patch( with mock.patch(
"zerver.lib.rate_limiter.rate_limit_user" "zerver.lib.rate_limiter.RateLimitedUser"
) as rate_limit_user_mock, mock.patch( ) as rate_limit_user_mock, mock.patch(
"zerver.lib.rate_limiter.rate_limit_ip" "zerver.lib.rate_limiter.rate_limit_ip"
) as rate_limit_ip_mock, self.errors_disallowed(): ) as rate_limit_ip_mock, self.errors_disallowed():
self.assert_in_success_response(["some value"], self.public_view(request)) self.assert_in_success_response(["some value"], view_func(request))
self.assertEqual(rate_limit_user_mock.called, expect_rate_limit) self.assertEqual(rate_limit_user_mock.called, expect_rate_limit)
self.assertFalse(rate_limit_ip_mock.called) self.assertFalse(rate_limit_ip_mock.called)
@ -705,6 +716,15 @@ class RateLimitTestCase(ZulipTestCase):
remote_addr="3.3.3.3", client_name="external", expect_rate_limit=True remote_addr="3.3.3.3", client_name="external", expect_rate_limit=True
) )
def test_rate_limiting_web_public_views(self) -> None:
with self.settings(RATE_LIMITING=True):
self.check_rate_limit_public_or_user_views(
remote_addr="3.3.3.3",
client_name="external",
expect_rate_limit=True,
check_web_view=True,
)
@skipUnless(settings.ZILENCER_ENABLED, "requires zilencer") @skipUnless(settings.ZILENCER_ENABLED, "requires zilencer")
def test_rate_limiting_happens_if_remote_server(self) -> None: def test_rate_limiting_happens_if_remote_server(self) -> None:
user = self.example_user("hamlet") user = self.example_user("hamlet")

View File

@ -54,6 +54,9 @@ class InvalidZulipServerKeyError(InvalidZulipServerError):
def rate_limit_remote_server( def rate_limit_remote_server(
request: HttpRequest, remote_server: RemoteZulipServer, domain: str request: HttpRequest, remote_server: RemoteZulipServer, domain: str
) -> None: ) -> None:
if not should_rate_limit(request):
return
try: try:
RateLimitedRemoteZulipServer(remote_server, domain=domain).rate_limit_request(request) RateLimitedRemoteZulipServer(remote_server, domain=domain).rate_limit_request(request)
except RateLimited as e: except RateLimited as e:
@ -98,8 +101,7 @@ def authenticated_remote_server_view(
except JsonableError as e: except JsonableError as e:
raise UnauthorizedError(e.msg) raise UnauthorizedError(e.msg)
if should_rate_limit(request): rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
return view_func(request, remote_server, *args, **kwargs) return view_func(request, remote_server, *args, **kwargs)
return _wrapped_view_func return _wrapped_view_func