From 26a518267a6b389840c622b4bddd843855ee0925 Mon Sep 17 00:00:00 2001 From: Zixuan James Li Date: Sun, 14 Aug 2022 15:14:52 -0400 Subject: [PATCH] rate_limit: Replace rate_limit with inlined rate limit checks. This change incorporate should_rate_limit into rate_limit_user and rate_limit_request_by_ip. Note a slight behavior change to other callers to rate_limit_request_by_ip is made as we now check if the client is eligible to be exempted from rate limiting now, which was previously only done as a part of zerver.lib.rate_limiter.rate_limit. Now we mock zerver.lib.rate_limiter.RateLimitedUser instead of zerver.decorator.rate_limit_user in zerver.tests.test_decorators.RateLimitTestCase, because rate_limit_user will always be called but rate limit only happens the should_rate_limit check passes; we can continue to mock zerver.lib.rate_limiter.rate_limit_ip, because the decorated view functions call rate_limit_request_by_ip that calls rate_limit_ip when the should_rate_limit check passes. We need to mock zerver.decorator.rate_limit_user for SkipRateLimitingTest now because rate_limit has been removed. We don't need to mock RateLimitedUser in this case because we are only verifying that the skip_rate_limiting flag works. To ensure coverage in add_logging_data, a new test case is added to use a web_public_view (which decorates the view function with add_logging_data) with a new flag to check_rate_limit_public_or_user_views. Signed-off-by: Zixuan James Li --- zerver/decorator.py | 21 +++++++++------- zerver/lib/rate_limiter.py | 18 ++++---------- zerver/tests/test_decorators.py | 44 ++++++++++++++++++++++++--------- zilencer/auth.py | 6 +++-- 4 files changed, 53 insertions(+), 36 deletions(-) diff --git a/zerver/decorator.py b/zerver/decorator.py index 9193981fe4..85676d93e2 100644 --- a/zerver/decorator.py +++ b/zerver/decorator.py @@ -53,7 +53,7 @@ from zerver.lib.exceptions import ( WebhookError, ) from zerver.lib.queue import queue_json_publish -from zerver.lib.rate_limiter import is_local_addr, rate_limit, rate_limit_user +from zerver.lib.rate_limiter import is_local_addr, rate_limit_request_by_ip, rate_limit_user from zerver.lib.request import REQ, RequestNotes, has_request_variables from zerver.lib.response import json_method_not_allowed, json_success from zerver.lib.subdomains import get_subdomain, user_matches_subdomain @@ -347,8 +347,7 @@ def webhook_view( client_name=full_webhook_client_name(webhook_client_name), ) - if settings.RATE_LIMITING: - rate_limit_user(request, user_profile, domain="api_by_user") + rate_limit_user(request, user_profile, domain="api_by_user") try: return view_func(request, user_profile, *args, **kwargs) except Exception as err: @@ -481,7 +480,12 @@ def add_logging_data( request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs ) -> HttpResponse: process_client(request, request.user, is_browser_view=True, query=view_func.__name__) - rate_limit(request) + + if request.user.is_authenticated: + rate_limit_user(request, request.user, domain="api_by_user") + else: + rate_limit_request_by_ip(request, domain="api_by_ip") + return view_func(request, *args, **kwargs) return _wrapped_view_func @@ -673,7 +677,7 @@ def authenticated_uploads_api_view( ) -> HttpResponse: user_profile = validate_api_key(request, None, api_key, False) if not skip_rate_limiting: - rate_limit(request) + rate_limit_user(request, user_profile, domain="api_by_user") return view_func(request, user_profile, *args, **kwargs) return _wrapped_func_arguments @@ -751,8 +755,7 @@ def authenticated_rest_api_view( raise UnauthorizedError(e.msg) try: if not skip_rate_limiting: - # Apply rate limiting - rate_limit(request) + rate_limit_user(request, profile, domain="api_by_user") return view_func(request, profile, *args, **kwargs) except Exception as err: if not webhook_client_name: @@ -839,7 +842,7 @@ def public_json_view( # Otherwise, process the request for a logged-out visitor. if not skip_rate_limiting: - rate_limit(request) + rate_limit_request_by_ip(request, domain="api_by_ip") process_client( request, @@ -870,7 +873,7 @@ def authenticated_json_view( user_profile = request.user if not skip_rate_limiting: - rate_limit(request) + rate_limit_user(request, user_profile, domain="api_by_user") validate_account_and_subdomain(request, user_profile) diff --git a/zerver/lib/rate_limiter.py b/zerver/lib/rate_limiter.py index afc2ba3558..866b58821c 100644 --- a/zerver/lib/rate_limiter.py +++ b/zerver/lib/rate_limiter.py @@ -582,6 +582,8 @@ def rate_limit_user(request: HttpRequest, user: UserProfile, domain: str) -> Non """Returns whether or not a user was rate limited. Will raise a RateLimited exception if the user has been rate limited, otherwise returns and modifies request to contain the rate limit information""" + if not should_rate_limit(request): + return RateLimitedUser(user, domain=domain).rate_limit_request(request) @@ -591,6 +593,9 @@ def rate_limit_ip(request: HttpRequest, ip_addr: str, domain: str) -> None: def rate_limit_request_by_ip(request: HttpRequest, domain: str) -> None: + if not should_rate_limit(request): + return + # REMOTE_ADDR is set by SetRemoteAddrFromRealIpHeader in conjunction # with the nginx configuration to guarantee this to be *the* correct # IP address to use - without worrying we'll grab the IP of a proxy. @@ -625,16 +630,3 @@ def should_rate_limit(request: HttpRequest) -> bool: return False return True - - -def rate_limit(request: HttpRequest) -> None: - if not should_rate_limit(request): - return - - user = request.user - - if not user.is_authenticated: - rate_limit_request_by_ip(request, domain="api_by_ip") - else: - assert isinstance(user, UserProfile) - rate_limit_user(request, user, domain="api_by_user") diff --git a/zerver/tests/test_decorators.py b/zerver/tests/test_decorators.py index e362f4c246..2475bb6645 100644 --- a/zerver/tests/test_decorators.py +++ b/zerver/tests/test_decorators.py @@ -31,6 +31,7 @@ from zerver.decorator import ( public_json_view, return_success_on_head_request, validate_api_key, + web_public_view, webhook_view, zulip_login_required, zulip_otp_required_if_logged_in, @@ -467,7 +468,7 @@ class SkipRateLimitingTest(ZulipTestCase): request = HostRequestMock(host="zulip.testserver") request.META["HTTP_AUTHORIZATION"] = self.encode_email(self.example_email("hamlet")) request.method = "POST" - with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock: + with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock: result = my_unlimited_view(request) self.assert_json_success(result) @@ -476,7 +477,7 @@ class SkipRateLimitingTest(ZulipTestCase): request = HostRequestMock(host="zulip.testserver") request.META["HTTP_AUTHORIZATION"] = self.encode_email(self.example_email("hamlet")) request.method = "POST" - with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock: + with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock: result = my_rate_limited_view(request) # Don't assert json_success, since it'll be the rate_limit mock object @@ -494,7 +495,7 @@ class SkipRateLimitingTest(ZulipTestCase): request = HostRequestMock(host="zulip.testserver") request.method = "POST" request.POST["api_key"] = get_api_key(self.example_user("hamlet")) - with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock: + with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock: result = my_unlimited_view(request) self.assert_json_success(result) @@ -503,7 +504,7 @@ class SkipRateLimitingTest(ZulipTestCase): request = HostRequestMock(host="zulip.testserver") request.method = "POST" request.POST["api_key"] = get_api_key(self.example_user("hamlet")) - with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock: + with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock: result = my_rate_limited_view(request) # Don't assert json_success, since it'll be the rate_limit mock object @@ -519,7 +520,7 @@ class SkipRateLimitingTest(ZulipTestCase): request = HostRequestMock(host="zulip.testserver") request.method = "POST" request.user = self.example_user("hamlet") - with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock: + with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock: result = my_unlimited_view(request) self.assert_json_success(result) @@ -528,7 +529,7 @@ class SkipRateLimitingTest(ZulipTestCase): request = HostRequestMock(host="zulip.testserver") request.method = "POST" request.user = self.example_user("hamlet") - with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock: + with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock: result = my_rate_limited_view(request) # Don't assert json_success, since it'll be the rate_limit mock object @@ -631,11 +632,16 @@ class DecoratorLoggingTestCase(ZulipTestCase): class RateLimitTestCase(ZulipTestCase): @staticmethod @public_json_view - def public_view( + def ratelimited_json_view( req: HttpRequest, maybe_user_profile: Union[AnonymousUser, UserProfile], / ) -> HttpResponse: return HttpResponse("some value") + @staticmethod + @web_public_view + def ratelimited_web_view(req: HttpRequest) -> HttpResponse: + return HttpResponse("some value") + def errors_disallowed(self) -> Any: # Due to what is probably a hack in rate_limit(), # some tests will give a false positive (or succeed @@ -648,18 +654,23 @@ class RateLimitTestCase(ZulipTestCase): return mock.patch("logging.error", side_effect=TestLoggingErrorException) def check_rate_limit_public_or_user_views( - self, remote_addr: str, client_name: str, expect_rate_limit: bool + self, + remote_addr: str, + client_name: str, + expect_rate_limit: bool, + check_web_view: bool = False, ) -> None: META = {"REMOTE_ADDR": remote_addr, "PATH_INFO": "test"} request = HostRequestMock(host="zulip.testserver", client_name=client_name, meta_data=META) + view_func = self.ratelimited_web_view if check_web_view else self.ratelimited_json_view with mock.patch( - "zerver.lib.rate_limiter.rate_limit_user" + "zerver.lib.rate_limiter.RateLimitedUser" ) as rate_limit_user_mock, mock.patch( "zerver.lib.rate_limiter.rate_limit_ip" ) as rate_limit_ip_mock, self.errors_disallowed(): - self.assert_in_success_response(["some value"], self.public_view(request)) + self.assert_in_success_response(["some value"], view_func(request)) self.assertEqual(rate_limit_ip_mock.called, expect_rate_limit) self.assertFalse(rate_limit_user_mock.called) @@ -671,11 +682,11 @@ class RateLimitTestCase(ZulipTestCase): user_profile=user, host="zulip.testserver", client_name=client_name, meta_data=META ) with mock.patch( - "zerver.lib.rate_limiter.rate_limit_user" + "zerver.lib.rate_limiter.RateLimitedUser" ) as rate_limit_user_mock, mock.patch( "zerver.lib.rate_limiter.rate_limit_ip" ) as rate_limit_ip_mock, self.errors_disallowed(): - self.assert_in_success_response(["some value"], self.public_view(request)) + self.assert_in_success_response(["some value"], view_func(request)) self.assertEqual(rate_limit_user_mock.called, expect_rate_limit) self.assertFalse(rate_limit_ip_mock.called) @@ -705,6 +716,15 @@ class RateLimitTestCase(ZulipTestCase): remote_addr="3.3.3.3", client_name="external", expect_rate_limit=True ) + def test_rate_limiting_web_public_views(self) -> None: + with self.settings(RATE_LIMITING=True): + self.check_rate_limit_public_or_user_views( + remote_addr="3.3.3.3", + client_name="external", + expect_rate_limit=True, + check_web_view=True, + ) + @skipUnless(settings.ZILENCER_ENABLED, "requires zilencer") def test_rate_limiting_happens_if_remote_server(self) -> None: user = self.example_user("hamlet") diff --git a/zilencer/auth.py b/zilencer/auth.py index 04282bddd7..b4fa30d90e 100644 --- a/zilencer/auth.py +++ b/zilencer/auth.py @@ -54,6 +54,9 @@ class InvalidZulipServerKeyError(InvalidZulipServerError): def rate_limit_remote_server( request: HttpRequest, remote_server: RemoteZulipServer, domain: str ) -> None: + if not should_rate_limit(request): + return + try: RateLimitedRemoteZulipServer(remote_server, domain=domain).rate_limit_request(request) except RateLimited as e: @@ -98,8 +101,7 @@ def authenticated_remote_server_view( except JsonableError as e: raise UnauthorizedError(e.msg) - if should_rate_limit(request): - rate_limit_remote_server(request, remote_server, domain="api_by_remote_server") + rate_limit_remote_server(request, remote_server, domain="api_by_remote_server") return view_func(request, remote_server, *args, **kwargs) return _wrapped_view_func