mirror of https://github.com/zulip/zulip.git
rate_limiter: Add more detailed automated tests.
Extracted by tabbott from the original commit to support testing without the Tornado version merged yet.
This commit is contained in:
parent
46a02e70b0
commit
218be002f1
|
@ -1,7 +1,7 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
|
@ -29,7 +29,10 @@ class RateLimiterLockingException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class RateLimitedObject(ABC):
|
class RateLimitedObject(ABC):
|
||||||
def __init__(self) -> None:
|
def __init__(self, backend: Optional['Type[RateLimiterBackend]']=None) -> None:
|
||||||
|
if backend is not None:
|
||||||
|
self.backend = backend # type: Type[RateLimiterBackend]
|
||||||
|
else:
|
||||||
self.backend = RedisRateLimiterBackend
|
self.backend = RedisRateLimiterBackend
|
||||||
|
|
||||||
def rate_limit(self) -> Tuple[bool, float]:
|
def rate_limit(self) -> Tuple[bool, float]:
|
||||||
|
|
|
@ -0,0 +1,188 @@
|
||||||
|
from zerver.lib.rate_limiter import (
|
||||||
|
add_ratelimit_rule,
|
||||||
|
remove_ratelimit_rule,
|
||||||
|
RateLimitedObject,
|
||||||
|
RateLimitedUser,
|
||||||
|
RateLimiterBackend,
|
||||||
|
RedisRateLimiterBackend,
|
||||||
|
)
|
||||||
|
|
||||||
|
from zerver.lib.test_classes import ZulipTestCase
|
||||||
|
from zerver.lib.utils import generate_random_token
|
||||||
|
|
||||||
|
from typing import Dict, List, Tuple, Type
|
||||||
|
|
||||||
|
import mock
|
||||||
|
import time
|
||||||
|
|
||||||
|
RANDOM_KEY_PREFIX = generate_random_token(32)
|
||||||
|
|
||||||
|
class RateLimitedTestObject(RateLimitedObject):
|
||||||
|
def __init__(self, name: str, rules: List[Tuple[int, int]],
|
||||||
|
backend: Type[RateLimiterBackend]) -> None:
|
||||||
|
self.name = name
|
||||||
|
self._rules = rules
|
||||||
|
self._rules.sort(key=lambda x: x[0])
|
||||||
|
super().__init__(backend)
|
||||||
|
|
||||||
|
def key(self) -> str:
|
||||||
|
return RANDOM_KEY_PREFIX + self.name
|
||||||
|
|
||||||
|
def rules(self) -> List[Tuple[int, int]]:
|
||||||
|
return self._rules
|
||||||
|
|
||||||
|
class RateLimiterBackendBase(ZulipTestCase):
|
||||||
|
__unittest_skip__ = True
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.requests_record = {} # type: Dict[str, List[float]]
|
||||||
|
|
||||||
|
def create_object(self, name: str, rules: List[Tuple[int, int]]) -> RateLimitedTestObject:
|
||||||
|
obj = RateLimitedTestObject(name, rules, self.backend)
|
||||||
|
obj.clear_history()
|
||||||
|
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def make_request(self, obj: RateLimitedTestObject, expect_ratelimited: bool=False,
|
||||||
|
verify_api_calls_left: bool=True) -> None:
|
||||||
|
key = obj.key()
|
||||||
|
if key not in self.requests_record:
|
||||||
|
self.requests_record[key] = []
|
||||||
|
|
||||||
|
ratelimited, secs_to_freedom = obj.rate_limit()
|
||||||
|
if not ratelimited:
|
||||||
|
self.requests_record[key].append(time.time())
|
||||||
|
|
||||||
|
self.assertEqual(ratelimited, expect_ratelimited)
|
||||||
|
|
||||||
|
if verify_api_calls_left:
|
||||||
|
self.verify_api_calls_left(obj)
|
||||||
|
|
||||||
|
def verify_api_calls_left(self, obj: RateLimitedTestObject) -> None:
|
||||||
|
now = time.time()
|
||||||
|
with mock.patch('time.time', return_value=now):
|
||||||
|
calls_remaining, time_till_reset = obj.api_calls_left()
|
||||||
|
|
||||||
|
expected_calls_remaining, expected_time_till_reset = self.expected_api_calls_left(obj, now)
|
||||||
|
self.assertEqual(expected_calls_remaining, calls_remaining)
|
||||||
|
self.assertEqual(expected_time_till_reset, time_till_reset)
|
||||||
|
|
||||||
|
def expected_api_calls_left(self, obj: RateLimitedTestObject, now: float) -> Tuple[int, float]:
|
||||||
|
longest_rule = obj.rules()[-1]
|
||||||
|
max_window, max_calls = longest_rule
|
||||||
|
history = self.requests_record.get(obj.key())
|
||||||
|
if history is None:
|
||||||
|
return max_calls, 0
|
||||||
|
history.sort()
|
||||||
|
|
||||||
|
return self.api_calls_left_from_history(history, max_window, max_calls, now)
|
||||||
|
|
||||||
|
def api_calls_left_from_history(self, history: List[float], max_window: int,
|
||||||
|
max_calls: int, now: float) -> Tuple[int, float]:
|
||||||
|
"""
|
||||||
|
This depends on the algorithm used in the backend, and should be defined by the test class.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError # nocoverage
|
||||||
|
|
||||||
|
def test_hit_ratelimits(self) -> None:
|
||||||
|
obj = self.create_object('test', [(2, 3), ])
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
for i in range(3):
|
||||||
|
with mock.patch('time.time', return_value=(start_time + i * 0.1)):
|
||||||
|
self.make_request(obj, expect_ratelimited=False)
|
||||||
|
|
||||||
|
with mock.patch('time.time', return_value=(start_time + 0.4)):
|
||||||
|
self.make_request(obj, expect_ratelimited=True)
|
||||||
|
|
||||||
|
with mock.patch('time.time', return_value=(start_time + 2.01)):
|
||||||
|
self.make_request(obj, expect_ratelimited=False)
|
||||||
|
|
||||||
|
def test_clear_history(self) -> None:
|
||||||
|
obj = self.create_object('test', [(2, 3), ])
|
||||||
|
start_time = time.time()
|
||||||
|
for i in range(3):
|
||||||
|
with mock.patch('time.time', return_value=(start_time + i * 0.1)):
|
||||||
|
self.make_request(obj, expect_ratelimited=False)
|
||||||
|
with mock.patch('time.time', return_value=(start_time + 0.4)):
|
||||||
|
self.make_request(obj, expect_ratelimited=True)
|
||||||
|
|
||||||
|
obj.clear_history()
|
||||||
|
self.requests_record[obj.key()] = []
|
||||||
|
for i in range(3):
|
||||||
|
with mock.patch('time.time', return_value=(start_time + i * 0.1)):
|
||||||
|
self.make_request(obj, expect_ratelimited=False)
|
||||||
|
|
||||||
|
def test_block_unblock_access(self) -> None:
|
||||||
|
obj = self.create_object('test', [(2, 5), ])
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
obj.block_access(1)
|
||||||
|
with mock.patch('time.time', return_value=(start_time)):
|
||||||
|
self.make_request(obj, expect_ratelimited=True, verify_api_calls_left=False)
|
||||||
|
|
||||||
|
obj.unblock_access()
|
||||||
|
with mock.patch('time.time', return_value=(start_time)):
|
||||||
|
self.make_request(obj, expect_ratelimited=False, verify_api_calls_left=False)
|
||||||
|
|
||||||
|
def test_api_calls_left(self) -> None:
|
||||||
|
obj = self.create_object('test', [(2, 5), (3, 6)])
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Check the edge case when no requests have been made yet.
|
||||||
|
with mock.patch('time.time', return_value=(start_time)):
|
||||||
|
self.verify_api_calls_left(obj)
|
||||||
|
|
||||||
|
with mock.patch('time.time', return_value=(start_time)):
|
||||||
|
self.make_request(obj)
|
||||||
|
|
||||||
|
# Check the correct default values again, after the reset has happened on the first rule,
|
||||||
|
# but not the other.
|
||||||
|
with mock.patch('time.time', return_value=(start_time + 2.1)):
|
||||||
|
self.make_request(obj)
|
||||||
|
|
||||||
|
class RedisRateLimiterBackendTest(RateLimiterBackendBase):
|
||||||
|
__unittest_skip__ = False
|
||||||
|
backend = RedisRateLimiterBackend
|
||||||
|
|
||||||
|
def api_calls_left_from_history(self, history: List[float], max_window: int,
|
||||||
|
max_calls: int, now: float) -> Tuple[int, float]:
|
||||||
|
latest_timestamp = history[-1]
|
||||||
|
relevant_requests = [t for t in history if (t >= now - max_window)]
|
||||||
|
relevant_requests_amount = len(relevant_requests)
|
||||||
|
|
||||||
|
return max_calls - relevant_requests_amount, latest_timestamp + max_window - now
|
||||||
|
|
||||||
|
def test_block_access(self) -> None:
|
||||||
|
"""
|
||||||
|
This test cannot verify that the user will get unblocked
|
||||||
|
after the correct amount of time, because that event happens
|
||||||
|
inside redis, so we're not able to mock the timer. Making the test
|
||||||
|
sleep for 1s is also too costly to be worth it.
|
||||||
|
"""
|
||||||
|
obj = self.create_object('test', [(2, 5), ])
|
||||||
|
|
||||||
|
obj.block_access(1)
|
||||||
|
self.make_request(obj, expect_ratelimited=True, verify_api_calls_left=False)
|
||||||
|
|
||||||
|
class RateLimitedUserTest(ZulipTestCase):
|
||||||
|
def test_user_rate_limits(self) -> None:
|
||||||
|
user_profile = self.example_user("hamlet")
|
||||||
|
user_profile.rate_limits = "1:3,2:4"
|
||||||
|
obj = RateLimitedUser(user_profile)
|
||||||
|
|
||||||
|
self.assertEqual(obj.rules(), [(1, 3), (2, 4)])
|
||||||
|
|
||||||
|
def test_add_remove_rule(self) -> None:
|
||||||
|
user_profile = self.example_user("hamlet")
|
||||||
|
add_ratelimit_rule(1, 2)
|
||||||
|
add_ratelimit_rule(4, 5, domain='some_new_domain')
|
||||||
|
add_ratelimit_rule(10, 100, domain='some_new_domain')
|
||||||
|
obj = RateLimitedUser(user_profile)
|
||||||
|
|
||||||
|
self.assertEqual(obj.rules(), [(1, 2), ])
|
||||||
|
obj.domain = 'some_new_domain'
|
||||||
|
self.assertEqual(obj.rules(), [(4, 5), (10, 100)])
|
||||||
|
|
||||||
|
remove_ratelimit_rule(10, 100, domain='some_new_domain')
|
||||||
|
self.assertEqual(obj.rules(), [(4, 5), ])
|
Loading…
Reference in New Issue