test_helpers: Switch add/remove_ratelimit to a contextmanager.

Failing to remove all of the rules which were added causes action at a
distance with other tests.  The two methods were also only used by
test code, making their existence in zerver.lib.rate_limiter clearly
misplaced.

This fixes one instance of a mis-balanced add/remove, which caused
tests to start failing if run non-parallel and one more anonymous
request was added within a rate-limit-enabled block.
This commit is contained in:
Alex Vandiver 2023-06-07 21:01:42 +00:00 committed by Tim Abbott
parent da09e003f6
commit 0dbe111ab3
10 changed files with 145 additions and 179 deletions

View File

@ -159,23 +159,6 @@ def bounce_redis_key_prefix_for_testing(test_name: str) -> None:
KEY_PREFIX = test_name + ":" + str(os.getpid()) + ":" KEY_PREFIX = test_name + ":" + str(os.getpid()) + ":"
def add_ratelimit_rule(range_seconds: int, num_requests: int, domain: str = "api_by_user") -> None:
"""Add a rate-limiting rule to the ratelimiter"""
if domain not in rules:
# If we don't have any rules for domain yet, the domain key needs to be
# added to the rules dictionary.
rules[domain] = []
rules[domain].append((range_seconds, num_requests))
rules[domain].sort(key=lambda x: x[0])
def remove_ratelimit_rule(
range_seconds: int, num_requests: int, domain: str = "api_by_user"
) -> None:
rules[domain] = [x for x in rules[domain] if x[0] != range_seconds and x[1] != num_requests]
class RateLimiterBackend(ABC): class RateLimiterBackend(ABC):
@classmethod @classmethod
@abstractmethod @abstractmethod

View File

@ -22,6 +22,7 @@ from typing import (
cast, cast,
) )
from unittest import mock from unittest import mock
from unittest.mock import patch
import boto3.session import boto3.session
import fakeldap import fakeldap
@ -45,6 +46,7 @@ from zerver.lib.avatar import avatar_url
from zerver.lib.cache import get_cache_backend from zerver.lib.cache import get_cache_backend
from zerver.lib.db import Params, ParamsT, Query, TimeTrackingCursor from zerver.lib.db import Params, ParamsT, Query, TimeTrackingCursor
from zerver.lib.integrations import WEBHOOK_INTEGRATIONS from zerver.lib.integrations import WEBHOOK_INTEGRATIONS
from zerver.lib.rate_limiter import RateLimitedIPAddr, rules
from zerver.lib.request import RequestNotes from zerver.lib.request import RequestNotes
from zerver.lib.upload.s3 import S3UploadBackend from zerver.lib.upload.s3 import S3UploadBackend
from zerver.models import ( from zerver.models import (
@ -743,3 +745,20 @@ def timeout_mock(mock_path: str) -> Iterator[None]:
with mock.patch(f"{mock_path}.timeout", new=mock_timeout): with mock.patch(f"{mock_path}.timeout", new=mock_timeout):
yield yield
@contextmanager
def ratelimit_rule(
range_seconds: int,
num_requests: int,
domain: str = "api_by_user",
) -> Iterator[None]:
"""Temporarily add a rate-limiting rule to the ratelimiter"""
RateLimitedIPAddr("127.0.0.1", domain=domain).clear_history()
domain_rules = rules.get(domain, []).copy()
domain_rules.append((range_seconds, num_requests))
domain_rules.sort(key=lambda x: x[0])
with patch.dict(rules, {domain: domain_rules}), override_settings(RATE_LIMITING=True):
yield

View File

@ -76,7 +76,6 @@ from zerver.lib.email_validation import (
from zerver.lib.exceptions import JsonableError, RateLimitedError from zerver.lib.exceptions import JsonableError, RateLimitedError
from zerver.lib.initial_password import initial_password from zerver.lib.initial_password import initial_password
from zerver.lib.mobile_auth_otp import otp_decrypt_api_key from zerver.lib.mobile_auth_otp import otp_decrypt_api_key
from zerver.lib.rate_limiter import add_ratelimit_rule, remove_ratelimit_rule
from zerver.lib.storage import static_path from zerver.lib.storage import static_path
from zerver.lib.streams import ensure_stream from zerver.lib.streams import ensure_stream
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
@ -84,6 +83,7 @@ from zerver.lib.test_helpers import (
HostRequestMock, HostRequestMock,
create_s3_buckets, create_s3_buckets,
load_subdomain_token, load_subdomain_token,
ratelimit_rule,
read_test_image_file, read_test_image_file,
use_s3_backend, use_s3_backend,
) )
@ -671,8 +671,9 @@ class RateLimitAuthenticationTests(ZulipTestCase):
request.session = mock.MagicMock() request.session = mock.MagicMock()
return attempt_authentication_func(request, username, password) return attempt_authentication_func(request, username, password)
add_ratelimit_rule(10, 2, domain="authenticate_by_username") with mock.patch.object(
with mock.patch.object(RateLimitedAuthenticationByUsername, "key", new=_mock_key): RateLimitedAuthenticationByUsername, "key", new=_mock_key
), ratelimit_rule(10, 2, domain="authenticate_by_username"):
try: try:
start_time = time.time() start_time = time.time()
with mock.patch("time.time", return_value=start_time): with mock.patch("time.time", return_value=start_time):
@ -708,7 +709,6 @@ class RateLimitAuthenticationTests(ZulipTestCase):
finally: finally:
# Clean up to avoid affecting other tests. # Clean up to avoid affecting other tests.
RateLimitedAuthenticationByUsername(username).clear_history() RateLimitedAuthenticationByUsername(username).clear_history()
remove_ratelimit_rule(10, 2, domain="authenticate_by_username")
def test_email_auth_backend_user_based_rate_limiting(self) -> None: def test_email_auth_backend_user_based_rate_limiting(self) -> None:
user_profile = self.example_user("hamlet") user_profile = self.example_user("hamlet")

View File

@ -18,11 +18,10 @@ from zerver.lib.rate_limiter import (
RateLimitedIPAddr, RateLimitedIPAddr,
RateLimitedUser, RateLimitedUser,
RateLimiterLockingError, RateLimiterLockingError,
add_ratelimit_rule,
get_tor_ips, get_tor_ips,
remove_ratelimit_rule,
) )
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.test_helpers import ratelimit_rule
from zerver.lib.zephyr import compute_mit_user_fullname from zerver.lib.zephyr import compute_mit_user_fullname
from zerver.models import PushDeviceToken, UserProfile from zerver.models import PushDeviceToken, UserProfile
@ -76,18 +75,6 @@ class MITNameTest(ZulipTestCase):
email_is_not_mit_mailing_list("sipbexch@mit.edu") email_is_not_mit_mailing_list("sipbexch@mit.edu")
@contextmanager
def rate_limit_rule(range_seconds: int, num_requests: int, domain: str) -> Iterator[None]:
RateLimitedIPAddr("127.0.0.1", domain=domain).clear_history()
add_ratelimit_rule(range_seconds, num_requests, domain=domain)
try:
yield
finally:
# We need this in a finally block to ensure the test cleans up after itself
# even in case of failure, to avoid polluting the rules state.
remove_ratelimit_rule(range_seconds, num_requests, domain=domain)
class RateLimitTests(ZulipTestCase): class RateLimitTests(ZulipTestCase):
def setUp(self) -> None: def setUp(self) -> None:
super().setUp() super().setUp()
@ -194,14 +181,14 @@ class RateLimitTests(ZulipTestCase):
self.assertNotEqual(result.status_code, 429) self.assertNotEqual(result.status_code, 429)
@rate_limit_rule(1, 5, domain="api_by_user") @ratelimit_rule(1, 5, domain="api_by_user")
def test_hit_ratelimits_as_user(self) -> None: def test_hit_ratelimits_as_user(self) -> None:
user = self.example_user("cordelia") user = self.example_user("cordelia")
RateLimitedUser(user).clear_history() RateLimitedUser(user).clear_history()
self.do_test_hit_ratelimits(lambda: self.send_api_message(user, "some stuff")) self.do_test_hit_ratelimits(lambda: self.send_api_message(user, "some stuff"))
@rate_limit_rule(1, 5, domain="email_change_by_user") @ratelimit_rule(1, 5, domain="email_change_by_user")
def test_hit_change_email_ratelimit_as_user(self) -> None: def test_hit_change_email_ratelimit_as_user(self) -> None:
user = self.example_user("cordelia") user = self.example_user("cordelia")
RateLimitedUser(user).clear_history() RateLimitedUser(user).clear_history()
@ -211,7 +198,7 @@ class RateLimitTests(ZulipTestCase):
lambda: self.api_patch(user, "/api/v1/settings", {"email": emails.pop()}), lambda: self.api_patch(user, "/api/v1/settings", {"email": emails.pop()}),
) )
@rate_limit_rule(1, 5, domain="api_by_ip") @ratelimit_rule(1, 5, domain="api_by_ip")
def test_hit_ratelimits_as_ip(self) -> None: def test_hit_ratelimits_as_ip(self) -> None:
self.do_test_hit_ratelimits(self.send_unauthed_api_request) self.do_test_hit_ratelimits(self.send_unauthed_api_request)
@ -219,7 +206,7 @@ class RateLimitTests(ZulipTestCase):
resp = self.send_unauthed_api_request(REMOTE_ADDR="127.0.0.2") resp = self.send_unauthed_api_request(REMOTE_ADDR="127.0.0.2")
self.assertNotEqual(resp.status_code, 429) self.assertNotEqual(resp.status_code, 429)
@rate_limit_rule(1, 5, domain="sends_email_by_ip") @ratelimit_rule(1, 5, domain="sends_email_by_ip")
def test_create_realm_rate_limiting(self) -> None: def test_create_realm_rate_limiting(self) -> None:
with self.settings(OPEN_REALM_CREATION=True): with self.settings(OPEN_REALM_CREATION=True):
self.do_test_hit_ratelimits( self.do_test_hit_ratelimits(
@ -229,14 +216,14 @@ class RateLimitTests(ZulipTestCase):
is_json=False, is_json=False,
) )
@rate_limit_rule(1, 5, domain="sends_email_by_ip") @ratelimit_rule(1, 5, domain="sends_email_by_ip")
def test_find_account_rate_limiting(self) -> None: def test_find_account_rate_limiting(self) -> None:
self.do_test_hit_ratelimits( self.do_test_hit_ratelimits(
lambda: self.client_post("/accounts/find/", {"emails": "new@zulip.com"}), lambda: self.client_post("/accounts/find/", {"emails": "new@zulip.com"}),
is_json=False, is_json=False,
) )
@rate_limit_rule(1, 5, domain="sends_email_by_ip") @ratelimit_rule(1, 5, domain="sends_email_by_ip")
def test_password_reset_rate_limiting(self) -> None: def test_password_reset_rate_limiting(self) -> None:
with self.assertLogs(level="INFO") as m: with self.assertLogs(level="INFO") as m:
self.do_test_hit_ratelimits( self.do_test_hit_ratelimits(
@ -251,7 +238,7 @@ class RateLimitTests(ZulipTestCase):
# Test whether submitting multiple emails is handled correctly. # Test whether submitting multiple emails is handled correctly.
# The limit is set to 10 per second, so 5 requests with 2 emails # The limit is set to 10 per second, so 5 requests with 2 emails
# submitted in each should be allowed. # submitted in each should be allowed.
@rate_limit_rule(1, 10, domain="sends_email_by_ip") @ratelimit_rule(1, 10, domain="sends_email_by_ip")
def test_find_account_rate_limiting_multiple(self) -> None: def test_find_account_rate_limiting_multiple(self) -> None:
self.do_test_hit_ratelimits( self.do_test_hit_ratelimits(
lambda: self.client_post("/accounts/find/", {"emails": "new@zulip.com,new2@zulip.com"}), lambda: self.client_post("/accounts/find/", {"emails": "new@zulip.com,new2@zulip.com"}),
@ -260,7 +247,7 @@ class RateLimitTests(ZulipTestCase):
# If I submit with 3 emails and the rate-limit is 2, I should get # If I submit with 3 emails and the rate-limit is 2, I should get
# a 429 and not send any emails. # a 429 and not send any emails.
@rate_limit_rule(1, 2, domain="sends_email_by_ip") @ratelimit_rule(1, 2, domain="sends_email_by_ip")
def test_find_account_rate_limiting_multiple_one_request(self) -> None: def test_find_account_rate_limiting_multiple_one_request(self) -> None:
emails = [ emails = [
"iago@zulip.com", "iago@zulip.com",
@ -274,14 +261,14 @@ class RateLimitTests(ZulipTestCase):
self.assert_length(outbox, 0) self.assert_length(outbox, 0)
@rate_limit_rule(1, 5, domain="sends_email_by_ip") @ratelimit_rule(1, 5, domain="sends_email_by_ip")
def test_register_account_rate_limiting(self) -> None: def test_register_account_rate_limiting(self) -> None:
self.do_test_hit_ratelimits( self.do_test_hit_ratelimits(
lambda: self.client_post("/register/", {"email": "new@zulip.com"}), lambda: self.client_post("/register/", {"email": "new@zulip.com"}),
is_json=False, is_json=False,
) )
@rate_limit_rule(1, 5, domain="sends_email_by_ip") @ratelimit_rule(1, 5, domain="sends_email_by_ip")
def test_combined_ip_limits(self) -> None: def test_combined_ip_limits(self) -> None:
# Alternate requests to /new/ and /accounts/find/ # Alternate requests to /new/ and /accounts/find/
request_count = 0 request_count = 0
@ -332,7 +319,7 @@ class RateLimitTests(ZulipTestCase):
with mock.patch("builtins.open", selective_mock_open): with mock.patch("builtins.open", selective_mock_open):
yield tor_open yield tor_open
@rate_limit_rule(1, 5, domain="api_by_ip") @ratelimit_rule(1, 5, domain="api_by_ip")
@override_settings(RATE_LIMIT_TOR_TOGETHER=True) @override_settings(RATE_LIMIT_TOR_TOGETHER=True)
def test_tor_ip_limits(self) -> None: def test_tor_ip_limits(self) -> None:
request_count = 0 request_count = 0
@ -354,7 +341,7 @@ class RateLimitTests(ZulipTestCase):
tor_open.assert_called_once_with(settings.TOR_EXIT_NODE_FILE_PATH, "rb") tor_open.assert_called_once_with(settings.TOR_EXIT_NODE_FILE_PATH, "rb")
tor_open().read.assert_called_once() tor_open().read.assert_called_once()
@rate_limit_rule(1, 5, domain="api_by_ip") @ratelimit_rule(1, 5, domain="api_by_ip")
@override_settings(RATE_LIMIT_TOR_TOGETHER=True) @override_settings(RATE_LIMIT_TOR_TOGETHER=True)
def test_tor_file_empty(self) -> None: def test_tor_file_empty(self) -> None:
for ip in ["1.2.3.4", "5.6.7.8", "tor-exit-node"]: for ip in ["1.2.3.4", "5.6.7.8", "tor-exit-node"]:
@ -375,7 +362,7 @@ class RateLimitTests(ZulipTestCase):
# circuit-breaker, and stopping trying # circuit-breaker, and stopping trying
tor_open().read.assert_has_calls([mock.call(), mock.call()]) tor_open().read.assert_has_calls([mock.call(), mock.call()])
@rate_limit_rule(1, 5, domain="api_by_ip") @ratelimit_rule(1, 5, domain="api_by_ip")
@override_settings(RATE_LIMIT_TOR_TOGETHER=True) @override_settings(RATE_LIMIT_TOR_TOGETHER=True)
def test_tor_file_not_found(self) -> None: def test_tor_file_not_found(self) -> None:
for ip in ["1.2.3.4", "5.6.7.8", "tor-exit-node"]: for ip in ["1.2.3.4", "5.6.7.8", "tor-exit-node"]:
@ -415,7 +402,7 @@ class RateLimitTests(ZulipTestCase):
) )
@skipUnless(settings.ZILENCER_ENABLED, "requires zilencer") @skipUnless(settings.ZILENCER_ENABLED, "requires zilencer")
@rate_limit_rule(1, 5, domain="api_by_remote_server") @ratelimit_rule(1, 5, domain="api_by_remote_server")
def test_hit_ratelimits_as_remote_server(self) -> None: def test_hit_ratelimits_as_remote_server(self) -> None:
server_uuid = str(uuid.uuid4()) server_uuid = str(uuid.uuid4())
server = RemoteZulipServer( server = RemoteZulipServer(

View File

@ -11,10 +11,9 @@ from zerver.lib.rate_limiter import (
RateLimiterBackend, RateLimiterBackend,
RedisRateLimiterBackend, RedisRateLimiterBackend,
TornadoInMemoryRateLimiterBackend, TornadoInMemoryRateLimiterBackend,
add_ratelimit_rule,
remove_ratelimit_rule,
) )
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.test_helpers import ratelimit_rule
RANDOM_KEY_PREFIX = secrets.token_hex(16) RANDOM_KEY_PREFIX = secrets.token_hex(16)
@ -227,27 +226,18 @@ class RateLimitedObjectsTest(ZulipTestCase):
self.assertEqual(obj.get_rules(), [(1, 3), (2, 4)]) self.assertEqual(obj.get_rules(), [(1, 3), (2, 4)])
def test_add_remove_rule(self) -> None: def test_ratelimit_rule(self) -> None:
user_profile = self.example_user("hamlet") user_profile = self.example_user("hamlet")
try: with ratelimit_rule(1, 2), ratelimit_rule(4, 5, domain="some_new_domain"):
add_ratelimit_rule(1, 2) with ratelimit_rule(10, 100, domain="some_new_domain"):
add_ratelimit_rule(4, 5, domain="some_new_domain") obj = RateLimitedUser(user_profile)
add_ratelimit_rule(10, 100, domain="some_new_domain")
obj = RateLimitedUser(user_profile)
self.assertEqual(obj.get_rules(), [(1, 2)]) self.assertEqual(obj.get_rules(), [(1, 2)])
obj.domain = "some_new_domain" obj.domain = "some_new_domain"
self.assertEqual(obj.get_rules(), [(4, 5), (10, 100)]) self.assertEqual(obj.get_rules(), [(4, 5), (10, 100)])
remove_ratelimit_rule(10, 100, domain="some_new_domain")
self.assertEqual(obj.get_rules(), [(4, 5)]) self.assertEqual(obj.get_rules(), [(4, 5)])
finally:
# Ensure all the rules get cleaned up.
remove_ratelimit_rule(1, 2)
remove_ratelimit_rule(4, 5, domain="some_new_domain")
remove_ratelimit_rule(10, 100, domain="some_new_domain")
def test_empty_rules_edge_case(self) -> None: def test_empty_rules_edge_case(self) -> None:
obj = RateLimitedTestObject("test", rules=[], backend=RedisRateLimiterBackend) obj = RateLimitedTestObject("test", rules=[], backend=RedisRateLimiterBackend)
self.assertEqual(obj.get_rules(), [(1, 9999)]) self.assertEqual(obj.get_rules(), [(1, 9999)])

View File

@ -8,9 +8,8 @@ from django.http import HttpRequest
from django.test import override_settings from django.test import override_settings
from zerver.lib.initial_password import initial_password from zerver.lib.initial_password import initial_password
from zerver.lib.rate_limiter import add_ratelimit_rule, remove_ratelimit_rule
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.test_helpers import get_test_image_file from zerver.lib.test_helpers import get_test_image_file, ratelimit_rule
from zerver.lib.users import get_all_api_keys from zerver.lib.users import get_all_api_keys
from zerver.models import ( from zerver.models import (
Draft, Draft,
@ -239,53 +238,51 @@ class ChangeSettingsTest(ZulipTestCase):
) )
self.assert_json_error(result, "Wrong password!") self.assert_json_error(result, "Wrong password!")
@override_settings(RATE_LIMITING_AUTHENTICATE=True)
@ratelimit_rule(10, 2, domain="authenticate_by_username")
def test_wrong_old_password_rate_limiter(self) -> None: def test_wrong_old_password_rate_limiter(self) -> None:
self.login("hamlet") self.login("hamlet")
with self.settings(RATE_LIMITING_AUTHENTICATE=True): start_time = time.time()
add_ratelimit_rule(10, 2, domain="authenticate_by_username") with mock.patch("time.time", return_value=start_time):
start_time = time.time() result = self.client_patch(
with mock.patch("time.time", return_value=start_time): "/json/settings",
result = self.client_patch( dict(
"/json/settings", old_password="bad_password",
dict( new_password="ignored",
old_password="bad_password", ),
new_password="ignored", )
), self.assert_json_error(result, "Wrong password!")
) result = self.client_patch(
self.assert_json_error(result, "Wrong password!") "/json/settings",
result = self.client_patch( dict(
"/json/settings", old_password="bad_password",
dict( new_password="ignored",
old_password="bad_password", ),
new_password="ignored", )
), self.assert_json_error(result, "Wrong password!")
)
self.assert_json_error(result, "Wrong password!")
# We're over the limit, so we'll get blocked even with the correct password. # We're over the limit, so we'll get blocked even with the correct password.
result = self.client_patch( result = self.client_patch(
"/json/settings", "/json/settings",
dict( dict(
old_password=initial_password(self.example_email("hamlet")), old_password=initial_password(self.example_email("hamlet")),
new_password="ignored", new_password="ignored",
), ),
) )
self.assert_json_error( self.assert_json_error(
result, "You're making too many attempts! Try again in 10 seconds." result, "You're making too many attempts! Try again in 10 seconds."
) )
# After time passes, we should be able to succeed if we give the correct password. # After time passes, we should be able to succeed if we give the correct password.
with mock.patch("time.time", return_value=start_time + 11): with mock.patch("time.time", return_value=start_time + 11):
json_result = self.client_patch( json_result = self.client_patch(
"/json/settings", "/json/settings",
dict( dict(
old_password=initial_password(self.example_email("hamlet")), old_password=initial_password(self.example_email("hamlet")),
new_password="foobar1", new_password="foobar1",
), ),
) )
self.assert_json_success(json_result) self.assert_json_success(json_result)
remove_ratelimit_rule(10, 2, domain="authenticate_by_username")
@override_settings( @override_settings(
AUTHENTICATION_BACKENDS=( AUTHENTICATION_BACKENDS=(

View File

@ -50,7 +50,6 @@ from zerver.lib.mobile_auth_otp import (
xor_hex_strings, xor_hex_strings,
) )
from zerver.lib.name_restrictions import is_disposable_domain from zerver.lib.name_restrictions import is_disposable_domain
from zerver.lib.rate_limiter import add_ratelimit_rule, remove_ratelimit_rule
from zerver.lib.send_email import ( from zerver.lib.send_email import (
EmailNotDeliveredError, EmailNotDeliveredError,
FromAddress, FromAddress,
@ -70,6 +69,7 @@ from zerver.lib.test_helpers import (
message_stream_count, message_stream_count,
most_recent_message, most_recent_message,
most_recent_usermessage, most_recent_usermessage,
ratelimit_rule,
reset_email_visibility_to_everyone_in_zulip_realm, reset_email_visibility_to_everyone_in_zulip_realm,
) )
from zerver.models import ( from zerver.models import (
@ -586,13 +586,12 @@ class PasswordResetTest(ZulipTestCase):
self.assert_length(outbox, 0) self.assert_length(outbox, 0)
@override_settings(RATE_LIMITING=True) @ratelimit_rule(10, 2, domain="password_reset_form_by_email")
def test_rate_limiting(self) -> None: def test_rate_limiting(self) -> None:
user_profile = self.example_user("hamlet") user_profile = self.example_user("hamlet")
email = user_profile.delivery_email email = user_profile.delivery_email
from django.core.mail import outbox from django.core.mail import outbox
add_ratelimit_rule(10, 2, domain="password_reset_form_by_email")
start_time = time.time() start_time = time.time()
with patch("time.time", return_value=start_time): with patch("time.time", return_value=start_time):
self.client_post("/accounts/password/reset/", {"email": email}) self.client_post("/accounts/password/reset/", {"email": email})
@ -622,8 +621,6 @@ class PasswordResetTest(ZulipTestCase):
self.client_post("/accounts/password/reset/", {"email": email}) self.client_post("/accounts/password/reset/", {"email": email})
self.assert_length(outbox, 6) self.assert_length(outbox, 6)
remove_ratelimit_rule(10, 2, domain="password_reset_form_by_email")
def test_wrong_subdomain(self) -> None: def test_wrong_subdomain(self) -> None:
email = self.example_email("hamlet") email = self.example_email("hamlet")
@ -837,10 +834,10 @@ class LoginTest(ZulipTestCase):
self.assert_logged_in_user_id(user.id) self.assert_logged_in_user_id(user.id)
@override_settings(RATE_LIMITING_AUTHENTICATE=True) @override_settings(RATE_LIMITING_AUTHENTICATE=True)
@ratelimit_rule(10, 2, domain="authenticate_by_username")
def test_login_bad_password_rate_limiter(self) -> None: def test_login_bad_password_rate_limiter(self) -> None:
user_profile = self.example_user("hamlet") user_profile = self.example_user("hamlet")
email = user_profile.delivery_email email = user_profile.delivery_email
add_ratelimit_rule(10, 2, domain="authenticate_by_username")
start_time = time.time() start_time = time.time()
with patch("time.time", return_value=start_time): with patch("time.time", return_value=start_time):
@ -859,8 +856,6 @@ class LoginTest(ZulipTestCase):
self.login_with_return(email) self.login_with_return(email)
self.assert_logged_in_user_id(user_profile.id) self.assert_logged_in_user_id(user_profile.id)
remove_ratelimit_rule(10, 2, domain="authenticate_by_username")
def test_login_with_old_weak_password_after_hasher_change(self) -> None: def test_login_with_old_weak_password_after_hasher_change(self) -> None:
user_profile = self.example_user("hamlet") user_profile = self.example_user("hamlet")
password = "a_password_of_22_chars" password = "a_password_of_22_chars"

View File

@ -3,8 +3,8 @@ from io import StringIO
import orjson import orjson
from django.test import override_settings from django.test import override_settings
from zerver.lib.rate_limiter import add_ratelimit_rule, remove_ratelimit_rule
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.test_helpers import ratelimit_rule
class ThumbnailTest(ZulipTestCase): class ThumbnailTest(ZulipTestCase):
@ -59,34 +59,32 @@ class ThumbnailTest(ZulipTestCase):
json = orjson.loads(result.content) json = orjson.loads(result.content)
url = json["uri"] url = json["uri"]
add_ratelimit_rule(86400, 1000, domain="spectator_attachment_access_by_file") with ratelimit_rule(86400, 1000, domain="spectator_attachment_access_by_file"):
# Deny file access for non-web-public stream # Deny file access for non-web-public stream
self.subscribe(self.example_user("hamlet"), "Denmark") self.subscribe(self.example_user("hamlet"), "Denmark")
host = self.example_user("hamlet").realm.host host = self.example_user("hamlet").realm.host
body = f"First message ...[zulip.txt](http://{host}" + url + ")" body = f"First message ...[zulip.txt](http://{host}" + url + ")"
self.send_stream_message(self.example_user("hamlet"), "Denmark", body, "test") self.send_stream_message(self.example_user("hamlet"), "Denmark", body, "test")
self.logout() self.logout()
response = self.client_get("/thumbnail", {"url": url[1:], "size": "full"}) response = self.client_get("/thumbnail", {"url": url[1:], "size": "full"})
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
# Allow file access for web-public stream # Allow file access for web-public stream
self.login("hamlet") self.login("hamlet")
self.make_stream("web-public-stream", is_web_public=True) self.make_stream("web-public-stream", is_web_public=True)
self.subscribe(self.example_user("hamlet"), "web-public-stream") self.subscribe(self.example_user("hamlet"), "web-public-stream")
body = f"First message ...[zulip.txt](http://{host}" + url + ")" body = f"First message ...[zulip.txt](http://{host}" + url + ")"
self.send_stream_message(self.example_user("hamlet"), "web-public-stream", body, "test") self.send_stream_message(self.example_user("hamlet"), "web-public-stream", body, "test")
self.logout() self.logout()
response = self.client_get("/thumbnail", {"url": url[1:], "size": "full"}) response = self.client_get("/thumbnail", {"url": url[1:], "size": "full"})
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
remove_ratelimit_rule(86400, 1000, domain="spectator_attachment_access_by_file")
# Deny file access since rate limited # Deny file access since rate limited
add_ratelimit_rule(86400, 0, domain="spectator_attachment_access_by_file") with ratelimit_rule(86400, 0, domain="spectator_attachment_access_by_file"):
response = self.client_get("/thumbnail", {"url": url[1:], "size": "full"}) response = self.client_get("/thumbnail", {"url": url[1:], "size": "full"})
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
remove_ratelimit_rule(86400, 0, domain="spectator_attachment_access_by_file")
# Deny random file access # Deny random file access
response = self.client_get( response = self.client_get(

View File

@ -11,7 +11,6 @@ from unittest.mock import patch
import orjson import orjson
from django.conf import settings from django.conf import settings
from django.http.response import StreamingHttpResponse from django.http.response import StreamingHttpResponse
from django.test import override_settings
from django.utils.timezone import now as timezone_now from django.utils.timezone import now as timezone_now
from PIL import Image from PIL import Image
from urllib3 import encode_multipart_formdata from urllib3 import encode_multipart_formdata
@ -31,12 +30,16 @@ from zerver.lib.avatar import avatar_url, get_avatar_field
from zerver.lib.cache import cache_get, get_realm_used_upload_space_cache_key from zerver.lib.cache import cache_get, get_realm_used_upload_space_cache_key
from zerver.lib.create_user import copy_default_settings from zerver.lib.create_user import copy_default_settings
from zerver.lib.initial_password import initial_password from zerver.lib.initial_password import initial_password
from zerver.lib.rate_limiter import add_ratelimit_rule, remove_ratelimit_rule
from zerver.lib.realm_icon import realm_icon_url from zerver.lib.realm_icon import realm_icon_url
from zerver.lib.realm_logo import get_realm_logo_url from zerver.lib.realm_logo import get_realm_logo_url
from zerver.lib.retention import clean_archived_data from zerver.lib.retention import clean_archived_data
from zerver.lib.test_classes import UploadSerializeMixin, ZulipTestCase from zerver.lib.test_classes import UploadSerializeMixin, ZulipTestCase
from zerver.lib.test_helpers import avatar_disk_path, get_test_image_file, read_test_image_file from zerver.lib.test_helpers import (
avatar_disk_path,
get_test_image_file,
ratelimit_rule,
read_test_image_file,
)
from zerver.lib.upload import delete_message_attachment, upload_message_attachment from zerver.lib.upload import delete_message_attachment, upload_message_attachment
from zerver.lib.upload.base import BadImageError, ZulipUploadBackend, resize_emoji, sanitize_name from zerver.lib.upload.base import BadImageError, ZulipUploadBackend, resize_emoji, sanitize_name
from zerver.lib.upload.local import LocalUploadBackend from zerver.lib.upload.local import LocalUploadBackend
@ -225,7 +228,6 @@ class FileUploadTest(UploadSerializeMixin, ZulipTestCase):
result = self.client_get(url) result = self.client_get(url)
self.assertEqual(result.status_code, 403) self.assertEqual(result.status_code, 403)
@override_settings(RATE_LIMITING=True)
def test_serve_file_unauthed(self) -> None: def test_serve_file_unauthed(self) -> None:
self.login("hamlet") self.login("hamlet")
fp = StringIO("zulip!") fp = StringIO("zulip!")
@ -234,34 +236,32 @@ class FileUploadTest(UploadSerializeMixin, ZulipTestCase):
result = self.client_post("/json/user_uploads", {"file": fp}) result = self.client_post("/json/user_uploads", {"file": fp})
url = self.assert_json_success(result)["uri"] url = self.assert_json_success(result)["uri"]
add_ratelimit_rule(86400, 1000, domain="spectator_attachment_access_by_file") with ratelimit_rule(86400, 1000, domain="spectator_attachment_access_by_file"):
# Deny file access for non-web-public stream # Deny file access for non-web-public stream
self.subscribe(self.example_user("hamlet"), "Denmark") self.subscribe(self.example_user("hamlet"), "Denmark")
host = self.example_user("hamlet").realm.host host = self.example_user("hamlet").realm.host
body = f"First message ...[zulip.txt](http://{host}" + url + ")" body = f"First message ...[zulip.txt](http://{host}" + url + ")"
self.send_stream_message(self.example_user("hamlet"), "Denmark", body, "test") self.send_stream_message(self.example_user("hamlet"), "Denmark", body, "test")
self.logout() self.logout()
response = self.client_get(url) response = self.client_get(url)
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
# Allow file access for web-public stream # Allow file access for web-public stream
self.login("hamlet") self.login("hamlet")
self.make_stream("web-public-stream", is_web_public=True) self.make_stream("web-public-stream", is_web_public=True)
self.subscribe(self.example_user("hamlet"), "web-public-stream") self.subscribe(self.example_user("hamlet"), "web-public-stream")
body = f"First message ...[zulip.txt](http://{host}" + url + ")" body = f"First message ...[zulip.txt](http://{host}" + url + ")"
self.send_stream_message(self.example_user("hamlet"), "web-public-stream", body, "test") self.send_stream_message(self.example_user("hamlet"), "web-public-stream", body, "test")
self.logout() self.logout()
response = self.client_get(url) response = self.client_get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
remove_ratelimit_rule(86400, 1000, domain="spectator_attachment_access_by_file")
# Deny file access since rate limited # Deny file access since rate limited
add_ratelimit_rule(86400, 0, domain="spectator_attachment_access_by_file") with ratelimit_rule(86400, 0, domain="spectator_attachment_access_by_file"):
response = self.client_get(url) response = self.client_get(url)
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
remove_ratelimit_rule(86400, 0, domain="spectator_attachment_access_by_file")
# Deny random file access # Deny random file access
response = self.client_get( response = self.client_get(
@ -1261,15 +1261,13 @@ class AvatarTest(UploadSerializeMixin, ZulipTestCase):
status_code=401, status_code=401,
) )
with self.settings(RATE_LIMITING=True): # Allow unauthenticated/spectator requests by ID for a reasonable number of requests.
# Allow unauthenticated/spectator requests by ID for a reasonable number of requests. with ratelimit_rule(86400, 1000, domain="spectator_attachment_access_by_file"):
add_ratelimit_rule(86400, 1000, domain="spectator_attachment_access_by_file")
response = self.client_get(f"/avatar/{cordelia.id}/medium", {"foo": "bar"}) response = self.client_get(f"/avatar/{cordelia.id}/medium", {"foo": "bar"})
self.assertEqual(302, response.status_code) self.assertEqual(302, response.status_code)
remove_ratelimit_rule(86400, 1000, domain="spectator_attachment_access_by_file")
# Deny file access since rate limited # Deny file access since rate limited
add_ratelimit_rule(86400, 0, domain="spectator_attachment_access_by_file") with ratelimit_rule(86400, 0, domain="spectator_attachment_access_by_file"):
response = self.client_get(f"/avatar/{cordelia.id}/medium", {"foo": "bar"}) response = self.client_get(f"/avatar/{cordelia.id}/medium", {"foo": "bar"})
self.assertEqual(429, response.status_code) self.assertEqual(429, response.status_code)

View File

@ -293,7 +293,6 @@ def is_subdomain_in_allowed_subdomains_list(subdomain: str, allowed_subdomains:
AuthFuncT = TypeVar("AuthFuncT", bound=Callable[..., Optional[UserProfile]]) AuthFuncT = TypeVar("AuthFuncT", bound=Callable[..., Optional[UserProfile]])
rate_limiting_rules = settings.RATE_LIMITING_RULES["authenticate_by_username"]
class RateLimitedAuthenticationByUsername(RateLimitedObject): class RateLimitedAuthenticationByUsername(RateLimitedObject):
@ -305,7 +304,7 @@ class RateLimitedAuthenticationByUsername(RateLimitedObject):
return f"{type(self).__name__}:{self.username}" return f"{type(self).__name__}:{self.username}"
def rules(self) -> List[Tuple[int, int]]: def rules(self) -> List[Tuple[int, int]]:
return rate_limiting_rules return settings.RATE_LIMITING_RULES["authenticate_by_username"]
def rate_limit_authentication_by_username(request: HttpRequest, username: str) -> None: def rate_limit_authentication_by_username(request: HttpRequest, username: str) -> None: