mirror of https://github.com/zulip/zulip.git
test_helpers: Switch add/remove_ratelimit to a contextmanager.
Failing to remove all of the rules which were added causes action at a distance with other tests. The two methods were also only used by test code, making their existence in zerver.lib.rate_limiter clearly misplaced. This fixes one instance of a mis-balanced add/remove, which caused tests to start failing if run non-parallel and one more anonymous request was added within a rate-limit-enabled block.
This commit is contained in:
parent
da09e003f6
commit
0dbe111ab3
|
@ -159,23 +159,6 @@ def bounce_redis_key_prefix_for_testing(test_name: str) -> None:
|
||||||
KEY_PREFIX = test_name + ":" + str(os.getpid()) + ":"
|
KEY_PREFIX = test_name + ":" + str(os.getpid()) + ":"
|
||||||
|
|
||||||
|
|
||||||
def add_ratelimit_rule(range_seconds: int, num_requests: int, domain: str = "api_by_user") -> None:
|
|
||||||
"""Add a rate-limiting rule to the ratelimiter"""
|
|
||||||
if domain not in rules:
|
|
||||||
# If we don't have any rules for domain yet, the domain key needs to be
|
|
||||||
# added to the rules dictionary.
|
|
||||||
rules[domain] = []
|
|
||||||
|
|
||||||
rules[domain].append((range_seconds, num_requests))
|
|
||||||
rules[domain].sort(key=lambda x: x[0])
|
|
||||||
|
|
||||||
|
|
||||||
def remove_ratelimit_rule(
|
|
||||||
range_seconds: int, num_requests: int, domain: str = "api_by_user"
|
|
||||||
) -> None:
|
|
||||||
rules[domain] = [x for x in rules[domain] if x[0] != range_seconds and x[1] != num_requests]
|
|
||||||
|
|
||||||
|
|
||||||
class RateLimiterBackend(ABC):
|
class RateLimiterBackend(ABC):
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
|
@ -22,6 +22,7 @@ from typing import (
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import boto3.session
|
import boto3.session
|
||||||
import fakeldap
|
import fakeldap
|
||||||
|
@ -45,6 +46,7 @@ from zerver.lib.avatar import avatar_url
|
||||||
from zerver.lib.cache import get_cache_backend
|
from zerver.lib.cache import get_cache_backend
|
||||||
from zerver.lib.db import Params, ParamsT, Query, TimeTrackingCursor
|
from zerver.lib.db import Params, ParamsT, Query, TimeTrackingCursor
|
||||||
from zerver.lib.integrations import WEBHOOK_INTEGRATIONS
|
from zerver.lib.integrations import WEBHOOK_INTEGRATIONS
|
||||||
|
from zerver.lib.rate_limiter import RateLimitedIPAddr, rules
|
||||||
from zerver.lib.request import RequestNotes
|
from zerver.lib.request import RequestNotes
|
||||||
from zerver.lib.upload.s3 import S3UploadBackend
|
from zerver.lib.upload.s3 import S3UploadBackend
|
||||||
from zerver.models import (
|
from zerver.models import (
|
||||||
|
@ -743,3 +745,20 @@ def timeout_mock(mock_path: str) -> Iterator[None]:
|
||||||
|
|
||||||
with mock.patch(f"{mock_path}.timeout", new=mock_timeout):
|
with mock.patch(f"{mock_path}.timeout", new=mock_timeout):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def ratelimit_rule(
|
||||||
|
range_seconds: int,
|
||||||
|
num_requests: int,
|
||||||
|
domain: str = "api_by_user",
|
||||||
|
) -> Iterator[None]:
|
||||||
|
"""Temporarily add a rate-limiting rule to the ratelimiter"""
|
||||||
|
RateLimitedIPAddr("127.0.0.1", domain=domain).clear_history()
|
||||||
|
|
||||||
|
domain_rules = rules.get(domain, []).copy()
|
||||||
|
domain_rules.append((range_seconds, num_requests))
|
||||||
|
domain_rules.sort(key=lambda x: x[0])
|
||||||
|
|
||||||
|
with patch.dict(rules, {domain: domain_rules}), override_settings(RATE_LIMITING=True):
|
||||||
|
yield
|
||||||
|
|
|
@ -76,7 +76,6 @@ from zerver.lib.email_validation import (
|
||||||
from zerver.lib.exceptions import JsonableError, RateLimitedError
|
from zerver.lib.exceptions import JsonableError, RateLimitedError
|
||||||
from zerver.lib.initial_password import initial_password
|
from zerver.lib.initial_password import initial_password
|
||||||
from zerver.lib.mobile_auth_otp import otp_decrypt_api_key
|
from zerver.lib.mobile_auth_otp import otp_decrypt_api_key
|
||||||
from zerver.lib.rate_limiter import add_ratelimit_rule, remove_ratelimit_rule
|
|
||||||
from zerver.lib.storage import static_path
|
from zerver.lib.storage import static_path
|
||||||
from zerver.lib.streams import ensure_stream
|
from zerver.lib.streams import ensure_stream
|
||||||
from zerver.lib.test_classes import ZulipTestCase
|
from zerver.lib.test_classes import ZulipTestCase
|
||||||
|
@ -84,6 +83,7 @@ from zerver.lib.test_helpers import (
|
||||||
HostRequestMock,
|
HostRequestMock,
|
||||||
create_s3_buckets,
|
create_s3_buckets,
|
||||||
load_subdomain_token,
|
load_subdomain_token,
|
||||||
|
ratelimit_rule,
|
||||||
read_test_image_file,
|
read_test_image_file,
|
||||||
use_s3_backend,
|
use_s3_backend,
|
||||||
)
|
)
|
||||||
|
@ -671,8 +671,9 @@ class RateLimitAuthenticationTests(ZulipTestCase):
|
||||||
request.session = mock.MagicMock()
|
request.session = mock.MagicMock()
|
||||||
return attempt_authentication_func(request, username, password)
|
return attempt_authentication_func(request, username, password)
|
||||||
|
|
||||||
add_ratelimit_rule(10, 2, domain="authenticate_by_username")
|
with mock.patch.object(
|
||||||
with mock.patch.object(RateLimitedAuthenticationByUsername, "key", new=_mock_key):
|
RateLimitedAuthenticationByUsername, "key", new=_mock_key
|
||||||
|
), ratelimit_rule(10, 2, domain="authenticate_by_username"):
|
||||||
try:
|
try:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
with mock.patch("time.time", return_value=start_time):
|
with mock.patch("time.time", return_value=start_time):
|
||||||
|
@ -708,7 +709,6 @@ class RateLimitAuthenticationTests(ZulipTestCase):
|
||||||
finally:
|
finally:
|
||||||
# Clean up to avoid affecting other tests.
|
# Clean up to avoid affecting other tests.
|
||||||
RateLimitedAuthenticationByUsername(username).clear_history()
|
RateLimitedAuthenticationByUsername(username).clear_history()
|
||||||
remove_ratelimit_rule(10, 2, domain="authenticate_by_username")
|
|
||||||
|
|
||||||
def test_email_auth_backend_user_based_rate_limiting(self) -> None:
|
def test_email_auth_backend_user_based_rate_limiting(self) -> None:
|
||||||
user_profile = self.example_user("hamlet")
|
user_profile = self.example_user("hamlet")
|
||||||
|
|
|
@ -18,11 +18,10 @@ from zerver.lib.rate_limiter import (
|
||||||
RateLimitedIPAddr,
|
RateLimitedIPAddr,
|
||||||
RateLimitedUser,
|
RateLimitedUser,
|
||||||
RateLimiterLockingError,
|
RateLimiterLockingError,
|
||||||
add_ratelimit_rule,
|
|
||||||
get_tor_ips,
|
get_tor_ips,
|
||||||
remove_ratelimit_rule,
|
|
||||||
)
|
)
|
||||||
from zerver.lib.test_classes import ZulipTestCase
|
from zerver.lib.test_classes import ZulipTestCase
|
||||||
|
from zerver.lib.test_helpers import ratelimit_rule
|
||||||
from zerver.lib.zephyr import compute_mit_user_fullname
|
from zerver.lib.zephyr import compute_mit_user_fullname
|
||||||
from zerver.models import PushDeviceToken, UserProfile
|
from zerver.models import PushDeviceToken, UserProfile
|
||||||
|
|
||||||
|
@ -76,18 +75,6 @@ class MITNameTest(ZulipTestCase):
|
||||||
email_is_not_mit_mailing_list("sipbexch@mit.edu")
|
email_is_not_mit_mailing_list("sipbexch@mit.edu")
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def rate_limit_rule(range_seconds: int, num_requests: int, domain: str) -> Iterator[None]:
|
|
||||||
RateLimitedIPAddr("127.0.0.1", domain=domain).clear_history()
|
|
||||||
add_ratelimit_rule(range_seconds, num_requests, domain=domain)
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
# We need this in a finally block to ensure the test cleans up after itself
|
|
||||||
# even in case of failure, to avoid polluting the rules state.
|
|
||||||
remove_ratelimit_rule(range_seconds, num_requests, domain=domain)
|
|
||||||
|
|
||||||
|
|
||||||
class RateLimitTests(ZulipTestCase):
|
class RateLimitTests(ZulipTestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
@ -194,14 +181,14 @@ class RateLimitTests(ZulipTestCase):
|
||||||
|
|
||||||
self.assertNotEqual(result.status_code, 429)
|
self.assertNotEqual(result.status_code, 429)
|
||||||
|
|
||||||
@rate_limit_rule(1, 5, domain="api_by_user")
|
@ratelimit_rule(1, 5, domain="api_by_user")
|
||||||
def test_hit_ratelimits_as_user(self) -> None:
|
def test_hit_ratelimits_as_user(self) -> None:
|
||||||
user = self.example_user("cordelia")
|
user = self.example_user("cordelia")
|
||||||
RateLimitedUser(user).clear_history()
|
RateLimitedUser(user).clear_history()
|
||||||
|
|
||||||
self.do_test_hit_ratelimits(lambda: self.send_api_message(user, "some stuff"))
|
self.do_test_hit_ratelimits(lambda: self.send_api_message(user, "some stuff"))
|
||||||
|
|
||||||
@rate_limit_rule(1, 5, domain="email_change_by_user")
|
@ratelimit_rule(1, 5, domain="email_change_by_user")
|
||||||
def test_hit_change_email_ratelimit_as_user(self) -> None:
|
def test_hit_change_email_ratelimit_as_user(self) -> None:
|
||||||
user = self.example_user("cordelia")
|
user = self.example_user("cordelia")
|
||||||
RateLimitedUser(user).clear_history()
|
RateLimitedUser(user).clear_history()
|
||||||
|
@ -211,7 +198,7 @@ class RateLimitTests(ZulipTestCase):
|
||||||
lambda: self.api_patch(user, "/api/v1/settings", {"email": emails.pop()}),
|
lambda: self.api_patch(user, "/api/v1/settings", {"email": emails.pop()}),
|
||||||
)
|
)
|
||||||
|
|
||||||
@rate_limit_rule(1, 5, domain="api_by_ip")
|
@ratelimit_rule(1, 5, domain="api_by_ip")
|
||||||
def test_hit_ratelimits_as_ip(self) -> None:
|
def test_hit_ratelimits_as_ip(self) -> None:
|
||||||
self.do_test_hit_ratelimits(self.send_unauthed_api_request)
|
self.do_test_hit_ratelimits(self.send_unauthed_api_request)
|
||||||
|
|
||||||
|
@ -219,7 +206,7 @@ class RateLimitTests(ZulipTestCase):
|
||||||
resp = self.send_unauthed_api_request(REMOTE_ADDR="127.0.0.2")
|
resp = self.send_unauthed_api_request(REMOTE_ADDR="127.0.0.2")
|
||||||
self.assertNotEqual(resp.status_code, 429)
|
self.assertNotEqual(resp.status_code, 429)
|
||||||
|
|
||||||
@rate_limit_rule(1, 5, domain="sends_email_by_ip")
|
@ratelimit_rule(1, 5, domain="sends_email_by_ip")
|
||||||
def test_create_realm_rate_limiting(self) -> None:
|
def test_create_realm_rate_limiting(self) -> None:
|
||||||
with self.settings(OPEN_REALM_CREATION=True):
|
with self.settings(OPEN_REALM_CREATION=True):
|
||||||
self.do_test_hit_ratelimits(
|
self.do_test_hit_ratelimits(
|
||||||
|
@ -229,14 +216,14 @@ class RateLimitTests(ZulipTestCase):
|
||||||
is_json=False,
|
is_json=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@rate_limit_rule(1, 5, domain="sends_email_by_ip")
|
@ratelimit_rule(1, 5, domain="sends_email_by_ip")
|
||||||
def test_find_account_rate_limiting(self) -> None:
|
def test_find_account_rate_limiting(self) -> None:
|
||||||
self.do_test_hit_ratelimits(
|
self.do_test_hit_ratelimits(
|
||||||
lambda: self.client_post("/accounts/find/", {"emails": "new@zulip.com"}),
|
lambda: self.client_post("/accounts/find/", {"emails": "new@zulip.com"}),
|
||||||
is_json=False,
|
is_json=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@rate_limit_rule(1, 5, domain="sends_email_by_ip")
|
@ratelimit_rule(1, 5, domain="sends_email_by_ip")
|
||||||
def test_password_reset_rate_limiting(self) -> None:
|
def test_password_reset_rate_limiting(self) -> None:
|
||||||
with self.assertLogs(level="INFO") as m:
|
with self.assertLogs(level="INFO") as m:
|
||||||
self.do_test_hit_ratelimits(
|
self.do_test_hit_ratelimits(
|
||||||
|
@ -251,7 +238,7 @@ class RateLimitTests(ZulipTestCase):
|
||||||
# Test whether submitting multiple emails is handled correctly.
|
# Test whether submitting multiple emails is handled correctly.
|
||||||
# The limit is set to 10 per second, so 5 requests with 2 emails
|
# The limit is set to 10 per second, so 5 requests with 2 emails
|
||||||
# submitted in each should be allowed.
|
# submitted in each should be allowed.
|
||||||
@rate_limit_rule(1, 10, domain="sends_email_by_ip")
|
@ratelimit_rule(1, 10, domain="sends_email_by_ip")
|
||||||
def test_find_account_rate_limiting_multiple(self) -> None:
|
def test_find_account_rate_limiting_multiple(self) -> None:
|
||||||
self.do_test_hit_ratelimits(
|
self.do_test_hit_ratelimits(
|
||||||
lambda: self.client_post("/accounts/find/", {"emails": "new@zulip.com,new2@zulip.com"}),
|
lambda: self.client_post("/accounts/find/", {"emails": "new@zulip.com,new2@zulip.com"}),
|
||||||
|
@ -260,7 +247,7 @@ class RateLimitTests(ZulipTestCase):
|
||||||
|
|
||||||
# If I submit with 3 emails and the rate-limit is 2, I should get
|
# If I submit with 3 emails and the rate-limit is 2, I should get
|
||||||
# a 429 and not send any emails.
|
# a 429 and not send any emails.
|
||||||
@rate_limit_rule(1, 2, domain="sends_email_by_ip")
|
@ratelimit_rule(1, 2, domain="sends_email_by_ip")
|
||||||
def test_find_account_rate_limiting_multiple_one_request(self) -> None:
|
def test_find_account_rate_limiting_multiple_one_request(self) -> None:
|
||||||
emails = [
|
emails = [
|
||||||
"iago@zulip.com",
|
"iago@zulip.com",
|
||||||
|
@ -274,14 +261,14 @@ class RateLimitTests(ZulipTestCase):
|
||||||
|
|
||||||
self.assert_length(outbox, 0)
|
self.assert_length(outbox, 0)
|
||||||
|
|
||||||
@rate_limit_rule(1, 5, domain="sends_email_by_ip")
|
@ratelimit_rule(1, 5, domain="sends_email_by_ip")
|
||||||
def test_register_account_rate_limiting(self) -> None:
|
def test_register_account_rate_limiting(self) -> None:
|
||||||
self.do_test_hit_ratelimits(
|
self.do_test_hit_ratelimits(
|
||||||
lambda: self.client_post("/register/", {"email": "new@zulip.com"}),
|
lambda: self.client_post("/register/", {"email": "new@zulip.com"}),
|
||||||
is_json=False,
|
is_json=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@rate_limit_rule(1, 5, domain="sends_email_by_ip")
|
@ratelimit_rule(1, 5, domain="sends_email_by_ip")
|
||||||
def test_combined_ip_limits(self) -> None:
|
def test_combined_ip_limits(self) -> None:
|
||||||
# Alternate requests to /new/ and /accounts/find/
|
# Alternate requests to /new/ and /accounts/find/
|
||||||
request_count = 0
|
request_count = 0
|
||||||
|
@ -332,7 +319,7 @@ class RateLimitTests(ZulipTestCase):
|
||||||
with mock.patch("builtins.open", selective_mock_open):
|
with mock.patch("builtins.open", selective_mock_open):
|
||||||
yield tor_open
|
yield tor_open
|
||||||
|
|
||||||
@rate_limit_rule(1, 5, domain="api_by_ip")
|
@ratelimit_rule(1, 5, domain="api_by_ip")
|
||||||
@override_settings(RATE_LIMIT_TOR_TOGETHER=True)
|
@override_settings(RATE_LIMIT_TOR_TOGETHER=True)
|
||||||
def test_tor_ip_limits(self) -> None:
|
def test_tor_ip_limits(self) -> None:
|
||||||
request_count = 0
|
request_count = 0
|
||||||
|
@ -354,7 +341,7 @@ class RateLimitTests(ZulipTestCase):
|
||||||
tor_open.assert_called_once_with(settings.TOR_EXIT_NODE_FILE_PATH, "rb")
|
tor_open.assert_called_once_with(settings.TOR_EXIT_NODE_FILE_PATH, "rb")
|
||||||
tor_open().read.assert_called_once()
|
tor_open().read.assert_called_once()
|
||||||
|
|
||||||
@rate_limit_rule(1, 5, domain="api_by_ip")
|
@ratelimit_rule(1, 5, domain="api_by_ip")
|
||||||
@override_settings(RATE_LIMIT_TOR_TOGETHER=True)
|
@override_settings(RATE_LIMIT_TOR_TOGETHER=True)
|
||||||
def test_tor_file_empty(self) -> None:
|
def test_tor_file_empty(self) -> None:
|
||||||
for ip in ["1.2.3.4", "5.6.7.8", "tor-exit-node"]:
|
for ip in ["1.2.3.4", "5.6.7.8", "tor-exit-node"]:
|
||||||
|
@ -375,7 +362,7 @@ class RateLimitTests(ZulipTestCase):
|
||||||
# circuit-breaker, and stopping trying
|
# circuit-breaker, and stopping trying
|
||||||
tor_open().read.assert_has_calls([mock.call(), mock.call()])
|
tor_open().read.assert_has_calls([mock.call(), mock.call()])
|
||||||
|
|
||||||
@rate_limit_rule(1, 5, domain="api_by_ip")
|
@ratelimit_rule(1, 5, domain="api_by_ip")
|
||||||
@override_settings(RATE_LIMIT_TOR_TOGETHER=True)
|
@override_settings(RATE_LIMIT_TOR_TOGETHER=True)
|
||||||
def test_tor_file_not_found(self) -> None:
|
def test_tor_file_not_found(self) -> None:
|
||||||
for ip in ["1.2.3.4", "5.6.7.8", "tor-exit-node"]:
|
for ip in ["1.2.3.4", "5.6.7.8", "tor-exit-node"]:
|
||||||
|
@ -415,7 +402,7 @@ class RateLimitTests(ZulipTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@skipUnless(settings.ZILENCER_ENABLED, "requires zilencer")
|
@skipUnless(settings.ZILENCER_ENABLED, "requires zilencer")
|
||||||
@rate_limit_rule(1, 5, domain="api_by_remote_server")
|
@ratelimit_rule(1, 5, domain="api_by_remote_server")
|
||||||
def test_hit_ratelimits_as_remote_server(self) -> None:
|
def test_hit_ratelimits_as_remote_server(self) -> None:
|
||||||
server_uuid = str(uuid.uuid4())
|
server_uuid = str(uuid.uuid4())
|
||||||
server = RemoteZulipServer(
|
server = RemoteZulipServer(
|
||||||
|
|
|
@ -11,10 +11,9 @@ from zerver.lib.rate_limiter import (
|
||||||
RateLimiterBackend,
|
RateLimiterBackend,
|
||||||
RedisRateLimiterBackend,
|
RedisRateLimiterBackend,
|
||||||
TornadoInMemoryRateLimiterBackend,
|
TornadoInMemoryRateLimiterBackend,
|
||||||
add_ratelimit_rule,
|
|
||||||
remove_ratelimit_rule,
|
|
||||||
)
|
)
|
||||||
from zerver.lib.test_classes import ZulipTestCase
|
from zerver.lib.test_classes import ZulipTestCase
|
||||||
|
from zerver.lib.test_helpers import ratelimit_rule
|
||||||
|
|
||||||
RANDOM_KEY_PREFIX = secrets.token_hex(16)
|
RANDOM_KEY_PREFIX = secrets.token_hex(16)
|
||||||
|
|
||||||
|
@ -227,27 +226,18 @@ class RateLimitedObjectsTest(ZulipTestCase):
|
||||||
|
|
||||||
self.assertEqual(obj.get_rules(), [(1, 3), (2, 4)])
|
self.assertEqual(obj.get_rules(), [(1, 3), (2, 4)])
|
||||||
|
|
||||||
def test_add_remove_rule(self) -> None:
|
def test_ratelimit_rule(self) -> None:
|
||||||
user_profile = self.example_user("hamlet")
|
user_profile = self.example_user("hamlet")
|
||||||
try:
|
with ratelimit_rule(1, 2), ratelimit_rule(4, 5, domain="some_new_domain"):
|
||||||
add_ratelimit_rule(1, 2)
|
with ratelimit_rule(10, 100, domain="some_new_domain"):
|
||||||
add_ratelimit_rule(4, 5, domain="some_new_domain")
|
obj = RateLimitedUser(user_profile)
|
||||||
add_ratelimit_rule(10, 100, domain="some_new_domain")
|
|
||||||
obj = RateLimitedUser(user_profile)
|
|
||||||
|
|
||||||
self.assertEqual(obj.get_rules(), [(1, 2)])
|
self.assertEqual(obj.get_rules(), [(1, 2)])
|
||||||
obj.domain = "some_new_domain"
|
obj.domain = "some_new_domain"
|
||||||
self.assertEqual(obj.get_rules(), [(4, 5), (10, 100)])
|
self.assertEqual(obj.get_rules(), [(4, 5), (10, 100)])
|
||||||
|
|
||||||
remove_ratelimit_rule(10, 100, domain="some_new_domain")
|
|
||||||
self.assertEqual(obj.get_rules(), [(4, 5)])
|
self.assertEqual(obj.get_rules(), [(4, 5)])
|
||||||
|
|
||||||
finally:
|
|
||||||
# Ensure all the rules get cleaned up.
|
|
||||||
remove_ratelimit_rule(1, 2)
|
|
||||||
remove_ratelimit_rule(4, 5, domain="some_new_domain")
|
|
||||||
remove_ratelimit_rule(10, 100, domain="some_new_domain")
|
|
||||||
|
|
||||||
def test_empty_rules_edge_case(self) -> None:
|
def test_empty_rules_edge_case(self) -> None:
|
||||||
obj = RateLimitedTestObject("test", rules=[], backend=RedisRateLimiterBackend)
|
obj = RateLimitedTestObject("test", rules=[], backend=RedisRateLimiterBackend)
|
||||||
self.assertEqual(obj.get_rules(), [(1, 9999)])
|
self.assertEqual(obj.get_rules(), [(1, 9999)])
|
||||||
|
|
|
@ -8,9 +8,8 @@ from django.http import HttpRequest
|
||||||
from django.test import override_settings
|
from django.test import override_settings
|
||||||
|
|
||||||
from zerver.lib.initial_password import initial_password
|
from zerver.lib.initial_password import initial_password
|
||||||
from zerver.lib.rate_limiter import add_ratelimit_rule, remove_ratelimit_rule
|
|
||||||
from zerver.lib.test_classes import ZulipTestCase
|
from zerver.lib.test_classes import ZulipTestCase
|
||||||
from zerver.lib.test_helpers import get_test_image_file
|
from zerver.lib.test_helpers import get_test_image_file, ratelimit_rule
|
||||||
from zerver.lib.users import get_all_api_keys
|
from zerver.lib.users import get_all_api_keys
|
||||||
from zerver.models import (
|
from zerver.models import (
|
||||||
Draft,
|
Draft,
|
||||||
|
@ -239,53 +238,51 @@ class ChangeSettingsTest(ZulipTestCase):
|
||||||
)
|
)
|
||||||
self.assert_json_error(result, "Wrong password!")
|
self.assert_json_error(result, "Wrong password!")
|
||||||
|
|
||||||
|
@override_settings(RATE_LIMITING_AUTHENTICATE=True)
|
||||||
|
@ratelimit_rule(10, 2, domain="authenticate_by_username")
|
||||||
def test_wrong_old_password_rate_limiter(self) -> None:
|
def test_wrong_old_password_rate_limiter(self) -> None:
|
||||||
self.login("hamlet")
|
self.login("hamlet")
|
||||||
with self.settings(RATE_LIMITING_AUTHENTICATE=True):
|
start_time = time.time()
|
||||||
add_ratelimit_rule(10, 2, domain="authenticate_by_username")
|
with mock.patch("time.time", return_value=start_time):
|
||||||
start_time = time.time()
|
result = self.client_patch(
|
||||||
with mock.patch("time.time", return_value=start_time):
|
"/json/settings",
|
||||||
result = self.client_patch(
|
dict(
|
||||||
"/json/settings",
|
old_password="bad_password",
|
||||||
dict(
|
new_password="ignored",
|
||||||
old_password="bad_password",
|
),
|
||||||
new_password="ignored",
|
)
|
||||||
),
|
self.assert_json_error(result, "Wrong password!")
|
||||||
)
|
result = self.client_patch(
|
||||||
self.assert_json_error(result, "Wrong password!")
|
"/json/settings",
|
||||||
result = self.client_patch(
|
dict(
|
||||||
"/json/settings",
|
old_password="bad_password",
|
||||||
dict(
|
new_password="ignored",
|
||||||
old_password="bad_password",
|
),
|
||||||
new_password="ignored",
|
)
|
||||||
),
|
self.assert_json_error(result, "Wrong password!")
|
||||||
)
|
|
||||||
self.assert_json_error(result, "Wrong password!")
|
|
||||||
|
|
||||||
# We're over the limit, so we'll get blocked even with the correct password.
|
# We're over the limit, so we'll get blocked even with the correct password.
|
||||||
result = self.client_patch(
|
result = self.client_patch(
|
||||||
"/json/settings",
|
"/json/settings",
|
||||||
dict(
|
dict(
|
||||||
old_password=initial_password(self.example_email("hamlet")),
|
old_password=initial_password(self.example_email("hamlet")),
|
||||||
new_password="ignored",
|
new_password="ignored",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.assert_json_error(
|
self.assert_json_error(
|
||||||
result, "You're making too many attempts! Try again in 10 seconds."
|
result, "You're making too many attempts! Try again in 10 seconds."
|
||||||
)
|
)
|
||||||
|
|
||||||
# After time passes, we should be able to succeed if we give the correct password.
|
# After time passes, we should be able to succeed if we give the correct password.
|
||||||
with mock.patch("time.time", return_value=start_time + 11):
|
with mock.patch("time.time", return_value=start_time + 11):
|
||||||
json_result = self.client_patch(
|
json_result = self.client_patch(
|
||||||
"/json/settings",
|
"/json/settings",
|
||||||
dict(
|
dict(
|
||||||
old_password=initial_password(self.example_email("hamlet")),
|
old_password=initial_password(self.example_email("hamlet")),
|
||||||
new_password="foobar1",
|
new_password="foobar1",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.assert_json_success(json_result)
|
self.assert_json_success(json_result)
|
||||||
|
|
||||||
remove_ratelimit_rule(10, 2, domain="authenticate_by_username")
|
|
||||||
|
|
||||||
@override_settings(
|
@override_settings(
|
||||||
AUTHENTICATION_BACKENDS=(
|
AUTHENTICATION_BACKENDS=(
|
||||||
|
|
|
@ -50,7 +50,6 @@ from zerver.lib.mobile_auth_otp import (
|
||||||
xor_hex_strings,
|
xor_hex_strings,
|
||||||
)
|
)
|
||||||
from zerver.lib.name_restrictions import is_disposable_domain
|
from zerver.lib.name_restrictions import is_disposable_domain
|
||||||
from zerver.lib.rate_limiter import add_ratelimit_rule, remove_ratelimit_rule
|
|
||||||
from zerver.lib.send_email import (
|
from zerver.lib.send_email import (
|
||||||
EmailNotDeliveredError,
|
EmailNotDeliveredError,
|
||||||
FromAddress,
|
FromAddress,
|
||||||
|
@ -70,6 +69,7 @@ from zerver.lib.test_helpers import (
|
||||||
message_stream_count,
|
message_stream_count,
|
||||||
most_recent_message,
|
most_recent_message,
|
||||||
most_recent_usermessage,
|
most_recent_usermessage,
|
||||||
|
ratelimit_rule,
|
||||||
reset_email_visibility_to_everyone_in_zulip_realm,
|
reset_email_visibility_to_everyone_in_zulip_realm,
|
||||||
)
|
)
|
||||||
from zerver.models import (
|
from zerver.models import (
|
||||||
|
@ -586,13 +586,12 @@ class PasswordResetTest(ZulipTestCase):
|
||||||
|
|
||||||
self.assert_length(outbox, 0)
|
self.assert_length(outbox, 0)
|
||||||
|
|
||||||
@override_settings(RATE_LIMITING=True)
|
@ratelimit_rule(10, 2, domain="password_reset_form_by_email")
|
||||||
def test_rate_limiting(self) -> None:
|
def test_rate_limiting(self) -> None:
|
||||||
user_profile = self.example_user("hamlet")
|
user_profile = self.example_user("hamlet")
|
||||||
email = user_profile.delivery_email
|
email = user_profile.delivery_email
|
||||||
from django.core.mail import outbox
|
from django.core.mail import outbox
|
||||||
|
|
||||||
add_ratelimit_rule(10, 2, domain="password_reset_form_by_email")
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
with patch("time.time", return_value=start_time):
|
with patch("time.time", return_value=start_time):
|
||||||
self.client_post("/accounts/password/reset/", {"email": email})
|
self.client_post("/accounts/password/reset/", {"email": email})
|
||||||
|
@ -622,8 +621,6 @@ class PasswordResetTest(ZulipTestCase):
|
||||||
self.client_post("/accounts/password/reset/", {"email": email})
|
self.client_post("/accounts/password/reset/", {"email": email})
|
||||||
self.assert_length(outbox, 6)
|
self.assert_length(outbox, 6)
|
||||||
|
|
||||||
remove_ratelimit_rule(10, 2, domain="password_reset_form_by_email")
|
|
||||||
|
|
||||||
def test_wrong_subdomain(self) -> None:
|
def test_wrong_subdomain(self) -> None:
|
||||||
email = self.example_email("hamlet")
|
email = self.example_email("hamlet")
|
||||||
|
|
||||||
|
@ -837,10 +834,10 @@ class LoginTest(ZulipTestCase):
|
||||||
self.assert_logged_in_user_id(user.id)
|
self.assert_logged_in_user_id(user.id)
|
||||||
|
|
||||||
@override_settings(RATE_LIMITING_AUTHENTICATE=True)
|
@override_settings(RATE_LIMITING_AUTHENTICATE=True)
|
||||||
|
@ratelimit_rule(10, 2, domain="authenticate_by_username")
|
||||||
def test_login_bad_password_rate_limiter(self) -> None:
|
def test_login_bad_password_rate_limiter(self) -> None:
|
||||||
user_profile = self.example_user("hamlet")
|
user_profile = self.example_user("hamlet")
|
||||||
email = user_profile.delivery_email
|
email = user_profile.delivery_email
|
||||||
add_ratelimit_rule(10, 2, domain="authenticate_by_username")
|
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
with patch("time.time", return_value=start_time):
|
with patch("time.time", return_value=start_time):
|
||||||
|
@ -859,8 +856,6 @@ class LoginTest(ZulipTestCase):
|
||||||
self.login_with_return(email)
|
self.login_with_return(email)
|
||||||
self.assert_logged_in_user_id(user_profile.id)
|
self.assert_logged_in_user_id(user_profile.id)
|
||||||
|
|
||||||
remove_ratelimit_rule(10, 2, domain="authenticate_by_username")
|
|
||||||
|
|
||||||
def test_login_with_old_weak_password_after_hasher_change(self) -> None:
|
def test_login_with_old_weak_password_after_hasher_change(self) -> None:
|
||||||
user_profile = self.example_user("hamlet")
|
user_profile = self.example_user("hamlet")
|
||||||
password = "a_password_of_22_chars"
|
password = "a_password_of_22_chars"
|
||||||
|
|
|
@ -3,8 +3,8 @@ from io import StringIO
|
||||||
import orjson
|
import orjson
|
||||||
from django.test import override_settings
|
from django.test import override_settings
|
||||||
|
|
||||||
from zerver.lib.rate_limiter import add_ratelimit_rule, remove_ratelimit_rule
|
|
||||||
from zerver.lib.test_classes import ZulipTestCase
|
from zerver.lib.test_classes import ZulipTestCase
|
||||||
|
from zerver.lib.test_helpers import ratelimit_rule
|
||||||
|
|
||||||
|
|
||||||
class ThumbnailTest(ZulipTestCase):
|
class ThumbnailTest(ZulipTestCase):
|
||||||
|
@ -59,34 +59,32 @@ class ThumbnailTest(ZulipTestCase):
|
||||||
json = orjson.loads(result.content)
|
json = orjson.loads(result.content)
|
||||||
url = json["uri"]
|
url = json["uri"]
|
||||||
|
|
||||||
add_ratelimit_rule(86400, 1000, domain="spectator_attachment_access_by_file")
|
with ratelimit_rule(86400, 1000, domain="spectator_attachment_access_by_file"):
|
||||||
# Deny file access for non-web-public stream
|
# Deny file access for non-web-public stream
|
||||||
self.subscribe(self.example_user("hamlet"), "Denmark")
|
self.subscribe(self.example_user("hamlet"), "Denmark")
|
||||||
host = self.example_user("hamlet").realm.host
|
host = self.example_user("hamlet").realm.host
|
||||||
body = f"First message ...[zulip.txt](http://{host}" + url + ")"
|
body = f"First message ...[zulip.txt](http://{host}" + url + ")"
|
||||||
self.send_stream_message(self.example_user("hamlet"), "Denmark", body, "test")
|
self.send_stream_message(self.example_user("hamlet"), "Denmark", body, "test")
|
||||||
|
|
||||||
self.logout()
|
self.logout()
|
||||||
response = self.client_get("/thumbnail", {"url": url[1:], "size": "full"})
|
response = self.client_get("/thumbnail", {"url": url[1:], "size": "full"})
|
||||||
self.assertEqual(response.status_code, 403)
|
self.assertEqual(response.status_code, 403)
|
||||||
|
|
||||||
# Allow file access for web-public stream
|
# Allow file access for web-public stream
|
||||||
self.login("hamlet")
|
self.login("hamlet")
|
||||||
self.make_stream("web-public-stream", is_web_public=True)
|
self.make_stream("web-public-stream", is_web_public=True)
|
||||||
self.subscribe(self.example_user("hamlet"), "web-public-stream")
|
self.subscribe(self.example_user("hamlet"), "web-public-stream")
|
||||||
body = f"First message ...[zulip.txt](http://{host}" + url + ")"
|
body = f"First message ...[zulip.txt](http://{host}" + url + ")"
|
||||||
self.send_stream_message(self.example_user("hamlet"), "web-public-stream", body, "test")
|
self.send_stream_message(self.example_user("hamlet"), "web-public-stream", body, "test")
|
||||||
|
|
||||||
self.logout()
|
self.logout()
|
||||||
response = self.client_get("/thumbnail", {"url": url[1:], "size": "full"})
|
response = self.client_get("/thumbnail", {"url": url[1:], "size": "full"})
|
||||||
self.assertEqual(response.status_code, 302)
|
self.assertEqual(response.status_code, 302)
|
||||||
remove_ratelimit_rule(86400, 1000, domain="spectator_attachment_access_by_file")
|
|
||||||
|
|
||||||
# Deny file access since rate limited
|
# Deny file access since rate limited
|
||||||
add_ratelimit_rule(86400, 0, domain="spectator_attachment_access_by_file")
|
with ratelimit_rule(86400, 0, domain="spectator_attachment_access_by_file"):
|
||||||
response = self.client_get("/thumbnail", {"url": url[1:], "size": "full"})
|
response = self.client_get("/thumbnail", {"url": url[1:], "size": "full"})
|
||||||
self.assertEqual(response.status_code, 403)
|
self.assertEqual(response.status_code, 403)
|
||||||
remove_ratelimit_rule(86400, 0, domain="spectator_attachment_access_by_file")
|
|
||||||
|
|
||||||
# Deny random file access
|
# Deny random file access
|
||||||
response = self.client_get(
|
response = self.client_get(
|
||||||
|
|
|
@ -11,7 +11,6 @@ from unittest.mock import patch
|
||||||
import orjson
|
import orjson
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.http.response import StreamingHttpResponse
|
from django.http.response import StreamingHttpResponse
|
||||||
from django.test import override_settings
|
|
||||||
from django.utils.timezone import now as timezone_now
|
from django.utils.timezone import now as timezone_now
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from urllib3 import encode_multipart_formdata
|
from urllib3 import encode_multipart_formdata
|
||||||
|
@ -31,12 +30,16 @@ from zerver.lib.avatar import avatar_url, get_avatar_field
|
||||||
from zerver.lib.cache import cache_get, get_realm_used_upload_space_cache_key
|
from zerver.lib.cache import cache_get, get_realm_used_upload_space_cache_key
|
||||||
from zerver.lib.create_user import copy_default_settings
|
from zerver.lib.create_user import copy_default_settings
|
||||||
from zerver.lib.initial_password import initial_password
|
from zerver.lib.initial_password import initial_password
|
||||||
from zerver.lib.rate_limiter import add_ratelimit_rule, remove_ratelimit_rule
|
|
||||||
from zerver.lib.realm_icon import realm_icon_url
|
from zerver.lib.realm_icon import realm_icon_url
|
||||||
from zerver.lib.realm_logo import get_realm_logo_url
|
from zerver.lib.realm_logo import get_realm_logo_url
|
||||||
from zerver.lib.retention import clean_archived_data
|
from zerver.lib.retention import clean_archived_data
|
||||||
from zerver.lib.test_classes import UploadSerializeMixin, ZulipTestCase
|
from zerver.lib.test_classes import UploadSerializeMixin, ZulipTestCase
|
||||||
from zerver.lib.test_helpers import avatar_disk_path, get_test_image_file, read_test_image_file
|
from zerver.lib.test_helpers import (
|
||||||
|
avatar_disk_path,
|
||||||
|
get_test_image_file,
|
||||||
|
ratelimit_rule,
|
||||||
|
read_test_image_file,
|
||||||
|
)
|
||||||
from zerver.lib.upload import delete_message_attachment, upload_message_attachment
|
from zerver.lib.upload import delete_message_attachment, upload_message_attachment
|
||||||
from zerver.lib.upload.base import BadImageError, ZulipUploadBackend, resize_emoji, sanitize_name
|
from zerver.lib.upload.base import BadImageError, ZulipUploadBackend, resize_emoji, sanitize_name
|
||||||
from zerver.lib.upload.local import LocalUploadBackend
|
from zerver.lib.upload.local import LocalUploadBackend
|
||||||
|
@ -225,7 +228,6 @@ class FileUploadTest(UploadSerializeMixin, ZulipTestCase):
|
||||||
result = self.client_get(url)
|
result = self.client_get(url)
|
||||||
self.assertEqual(result.status_code, 403)
|
self.assertEqual(result.status_code, 403)
|
||||||
|
|
||||||
@override_settings(RATE_LIMITING=True)
|
|
||||||
def test_serve_file_unauthed(self) -> None:
|
def test_serve_file_unauthed(self) -> None:
|
||||||
self.login("hamlet")
|
self.login("hamlet")
|
||||||
fp = StringIO("zulip!")
|
fp = StringIO("zulip!")
|
||||||
|
@ -234,34 +236,32 @@ class FileUploadTest(UploadSerializeMixin, ZulipTestCase):
|
||||||
result = self.client_post("/json/user_uploads", {"file": fp})
|
result = self.client_post("/json/user_uploads", {"file": fp})
|
||||||
url = self.assert_json_success(result)["uri"]
|
url = self.assert_json_success(result)["uri"]
|
||||||
|
|
||||||
add_ratelimit_rule(86400, 1000, domain="spectator_attachment_access_by_file")
|
with ratelimit_rule(86400, 1000, domain="spectator_attachment_access_by_file"):
|
||||||
# Deny file access for non-web-public stream
|
# Deny file access for non-web-public stream
|
||||||
self.subscribe(self.example_user("hamlet"), "Denmark")
|
self.subscribe(self.example_user("hamlet"), "Denmark")
|
||||||
host = self.example_user("hamlet").realm.host
|
host = self.example_user("hamlet").realm.host
|
||||||
body = f"First message ...[zulip.txt](http://{host}" + url + ")"
|
body = f"First message ...[zulip.txt](http://{host}" + url + ")"
|
||||||
self.send_stream_message(self.example_user("hamlet"), "Denmark", body, "test")
|
self.send_stream_message(self.example_user("hamlet"), "Denmark", body, "test")
|
||||||
|
|
||||||
self.logout()
|
self.logout()
|
||||||
response = self.client_get(url)
|
response = self.client_get(url)
|
||||||
self.assertEqual(response.status_code, 403)
|
self.assertEqual(response.status_code, 403)
|
||||||
|
|
||||||
# Allow file access for web-public stream
|
# Allow file access for web-public stream
|
||||||
self.login("hamlet")
|
self.login("hamlet")
|
||||||
self.make_stream("web-public-stream", is_web_public=True)
|
self.make_stream("web-public-stream", is_web_public=True)
|
||||||
self.subscribe(self.example_user("hamlet"), "web-public-stream")
|
self.subscribe(self.example_user("hamlet"), "web-public-stream")
|
||||||
body = f"First message ...[zulip.txt](http://{host}" + url + ")"
|
body = f"First message ...[zulip.txt](http://{host}" + url + ")"
|
||||||
self.send_stream_message(self.example_user("hamlet"), "web-public-stream", body, "test")
|
self.send_stream_message(self.example_user("hamlet"), "web-public-stream", body, "test")
|
||||||
|
|
||||||
self.logout()
|
self.logout()
|
||||||
response = self.client_get(url)
|
response = self.client_get(url)
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
remove_ratelimit_rule(86400, 1000, domain="spectator_attachment_access_by_file")
|
|
||||||
|
|
||||||
# Deny file access since rate limited
|
# Deny file access since rate limited
|
||||||
add_ratelimit_rule(86400, 0, domain="spectator_attachment_access_by_file")
|
with ratelimit_rule(86400, 0, domain="spectator_attachment_access_by_file"):
|
||||||
response = self.client_get(url)
|
response = self.client_get(url)
|
||||||
self.assertEqual(response.status_code, 403)
|
self.assertEqual(response.status_code, 403)
|
||||||
remove_ratelimit_rule(86400, 0, domain="spectator_attachment_access_by_file")
|
|
||||||
|
|
||||||
# Deny random file access
|
# Deny random file access
|
||||||
response = self.client_get(
|
response = self.client_get(
|
||||||
|
@ -1261,15 +1261,13 @@ class AvatarTest(UploadSerializeMixin, ZulipTestCase):
|
||||||
status_code=401,
|
status_code=401,
|
||||||
)
|
)
|
||||||
|
|
||||||
with self.settings(RATE_LIMITING=True):
|
# Allow unauthenticated/spectator requests by ID for a reasonable number of requests.
|
||||||
# Allow unauthenticated/spectator requests by ID for a reasonable number of requests.
|
with ratelimit_rule(86400, 1000, domain="spectator_attachment_access_by_file"):
|
||||||
add_ratelimit_rule(86400, 1000, domain="spectator_attachment_access_by_file")
|
|
||||||
response = self.client_get(f"/avatar/{cordelia.id}/medium", {"foo": "bar"})
|
response = self.client_get(f"/avatar/{cordelia.id}/medium", {"foo": "bar"})
|
||||||
self.assertEqual(302, response.status_code)
|
self.assertEqual(302, response.status_code)
|
||||||
remove_ratelimit_rule(86400, 1000, domain="spectator_attachment_access_by_file")
|
|
||||||
|
|
||||||
# Deny file access since rate limited
|
# Deny file access since rate limited
|
||||||
add_ratelimit_rule(86400, 0, domain="spectator_attachment_access_by_file")
|
with ratelimit_rule(86400, 0, domain="spectator_attachment_access_by_file"):
|
||||||
response = self.client_get(f"/avatar/{cordelia.id}/medium", {"foo": "bar"})
|
response = self.client_get(f"/avatar/{cordelia.id}/medium", {"foo": "bar"})
|
||||||
self.assertEqual(429, response.status_code)
|
self.assertEqual(429, response.status_code)
|
||||||
|
|
||||||
|
|
|
@ -293,7 +293,6 @@ def is_subdomain_in_allowed_subdomains_list(subdomain: str, allowed_subdomains:
|
||||||
|
|
||||||
|
|
||||||
AuthFuncT = TypeVar("AuthFuncT", bound=Callable[..., Optional[UserProfile]])
|
AuthFuncT = TypeVar("AuthFuncT", bound=Callable[..., Optional[UserProfile]])
|
||||||
rate_limiting_rules = settings.RATE_LIMITING_RULES["authenticate_by_username"]
|
|
||||||
|
|
||||||
|
|
||||||
class RateLimitedAuthenticationByUsername(RateLimitedObject):
|
class RateLimitedAuthenticationByUsername(RateLimitedObject):
|
||||||
|
@ -305,7 +304,7 @@ class RateLimitedAuthenticationByUsername(RateLimitedObject):
|
||||||
return f"{type(self).__name__}:{self.username}"
|
return f"{type(self).__name__}:{self.username}"
|
||||||
|
|
||||||
def rules(self) -> List[Tuple[int, int]]:
|
def rules(self) -> List[Tuple[int, int]]:
|
||||||
return rate_limiting_rules
|
return settings.RATE_LIMITING_RULES["authenticate_by_username"]
|
||||||
|
|
||||||
|
|
||||||
def rate_limit_authentication_by_username(request: HttpRequest, username: str) -> None:
|
def rate_limit_authentication_by_username(request: HttpRequest, username: str) -> None:
|
||||||
|
|
Loading…
Reference in New Issue