diff --git a/zerver/tests/test_decorators.py b/zerver/tests/test_decorators.py index 0dc843eeb3..a8897d1ed3 100644 --- a/zerver/tests/test_decorators.py +++ b/zerver/tests/test_decorators.py @@ -647,64 +647,43 @@ class RateLimitTestCase(ZulipTestCase): return mock.patch("logging.error", side_effect=TestLoggingErrorException) - def test_internal_local_clients_skip_rate_limiting(self) -> None: - META = {"REMOTE_ADDR": "127.0.0.1"} + def check_rate_limit_public_or_user_views(self, remote_addr: str, client_name: str) -> None: + META = {"REMOTE_ADDR": remote_addr} - request = HostRequestMock(client_name="internal", meta_data=META) + request = HostRequestMock(client_name=client_name, meta_data=META) f = self.get_ratelimited_view() - with self.settings(RATE_LIMITING=True): - with mock.patch( - "zerver.lib.rate_limiter.rate_limit_user" - ) as rate_limit_user_mock, mock.patch( - "zerver.lib.rate_limiter.rate_limit_ip" - ) as rate_limit_ip_mock: - with self.errors_disallowed(): - self.assertEqual(orjson.loads(f(request).content).get("msg"), "some value") + with mock.patch( + "zerver.lib.rate_limiter.rate_limit_user" + ) as rate_limit_user_mock, mock.patch( + "zerver.lib.rate_limiter.rate_limit_ip" + ) as rate_limit_ip_mock: + with self.errors_disallowed(): + self.assertEqual(orjson.loads(f(request).content).get("msg"), "some value") self.assertFalse(rate_limit_ip_mock.called) self.assertFalse(rate_limit_user_mock.called) + def test_internal_local_clients_skip_rate_limiting(self) -> None: + with self.settings(RATE_LIMITING=True): + self.check_rate_limit_public_or_user_views( + remote_addr="127.0.0.1", client_name="internal" + ) + def test_debug_clients_skip_rate_limiting(self) -> None: - META = {"REMOTE_ADDR": "3.3.3.3"} - - req = HostRequestMock(client_name="internal", meta_data=META) - - f = self.get_ratelimited_view() - - with self.settings(RATE_LIMITING=True): - with mock.patch( - "zerver.lib.rate_limiter.rate_limit_user" - ) as rate_limit_user_mock, mock.patch( - "zerver.lib.rate_limiter.rate_limit_ip" - ) as rate_limit_ip_mock: - with self.errors_disallowed(): - with self.settings(DEBUG_RATE_LIMITING=True): - self.assertEqual(orjson.loads(f(req).content).get("msg"), "some value") - - self.assertFalse(rate_limit_ip_mock.called) - self.assertFalse(rate_limit_user_mock.called) + with self.settings(DEBUG_RATE_LIMITING=True, RATE_LIMITING=True): + # Rate limiting is skipped for internal clients with an external address + # when DEBUG_RATE_LIMITING is True. + self.check_rate_limit_public_or_user_views( + remote_addr="3.3.3.3", client_name="internal" + ) def test_rate_limit_setting_of_false_bypasses_rate_limiting(self) -> None: - META = {"REMOTE_ADDR": "3.3.3.3"} - user = self.example_user("hamlet") - - req = HostRequestMock(client_name="external", user_profile=user, meta_data=META) - - f = self.get_ratelimited_view() - with self.settings(RATE_LIMITING=False): - with mock.patch( - "zerver.lib.rate_limiter.rate_limit_user" - ) as rate_limit_user_mock, mock.patch( - "zerver.lib.rate_limiter.rate_limit_ip" - ) as rate_limit_ip_mock: - with self.errors_disallowed(): - self.assertEqual(orjson.loads(f(req).content).get("msg"), "some value") - - self.assertFalse(rate_limit_ip_mock.called) - self.assertFalse(rate_limit_user_mock.called) + self.check_rate_limit_public_or_user_views( + remote_addr="3.3.3.3", client_name="external" + ) def test_rate_limiting_happens_in_normal_case(self) -> None: META = {"REMOTE_ADDR": "3.3.3.3"}