rate_limiter: Add more detailed automated tests.

Extracted by tabbott from the original commit to support testing
without the Tornado version merged yet.
This commit is contained in:
Mateusz Mandera 2020-04-02 20:40:10 +02:00 committed by Tim Abbott
parent 46a02e70b0
commit 218be002f1
2 changed files with 194 additions and 3 deletions

View File

@ -1,7 +1,7 @@
import os
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Type
from django.conf import settings
from django.http import HttpRequest
@ -29,8 +29,11 @@ class RateLimiterLockingException(Exception):
pass
class RateLimitedObject(ABC):
def __init__(self) -> None:
self.backend = RedisRateLimiterBackend
def __init__(self, backend: Optional['Type[RateLimiterBackend]']=None) -> None:
if backend is not None:
self.backend = backend # type: Type[RateLimiterBackend]
else:
self.backend = RedisRateLimiterBackend
def rate_limit(self) -> Tuple[bool, float]:
# Returns (ratelimited, secs_to_freedom)

View File

@ -0,0 +1,188 @@
from zerver.lib.rate_limiter import (
add_ratelimit_rule,
remove_ratelimit_rule,
RateLimitedObject,
RateLimitedUser,
RateLimiterBackend,
RedisRateLimiterBackend,
)
from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.utils import generate_random_token
from typing import Dict, List, Tuple, Type
import mock
import time
RANDOM_KEY_PREFIX = generate_random_token(32)
class RateLimitedTestObject(RateLimitedObject):
def __init__(self, name: str, rules: List[Tuple[int, int]],
backend: Type[RateLimiterBackend]) -> None:
self.name = name
self._rules = rules
self._rules.sort(key=lambda x: x[0])
super().__init__(backend)
def key(self) -> str:
return RANDOM_KEY_PREFIX + self.name
def rules(self) -> List[Tuple[int, int]]:
return self._rules
class RateLimiterBackendBase(ZulipTestCase):
__unittest_skip__ = True
def setUp(self) -> None:
self.requests_record = {} # type: Dict[str, List[float]]
def create_object(self, name: str, rules: List[Tuple[int, int]]) -> RateLimitedTestObject:
obj = RateLimitedTestObject(name, rules, self.backend)
obj.clear_history()
return obj
def make_request(self, obj: RateLimitedTestObject, expect_ratelimited: bool=False,
verify_api_calls_left: bool=True) -> None:
key = obj.key()
if key not in self.requests_record:
self.requests_record[key] = []
ratelimited, secs_to_freedom = obj.rate_limit()
if not ratelimited:
self.requests_record[key].append(time.time())
self.assertEqual(ratelimited, expect_ratelimited)
if verify_api_calls_left:
self.verify_api_calls_left(obj)
def verify_api_calls_left(self, obj: RateLimitedTestObject) -> None:
now = time.time()
with mock.patch('time.time', return_value=now):
calls_remaining, time_till_reset = obj.api_calls_left()
expected_calls_remaining, expected_time_till_reset = self.expected_api_calls_left(obj, now)
self.assertEqual(expected_calls_remaining, calls_remaining)
self.assertEqual(expected_time_till_reset, time_till_reset)
def expected_api_calls_left(self, obj: RateLimitedTestObject, now: float) -> Tuple[int, float]:
longest_rule = obj.rules()[-1]
max_window, max_calls = longest_rule
history = self.requests_record.get(obj.key())
if history is None:
return max_calls, 0
history.sort()
return self.api_calls_left_from_history(history, max_window, max_calls, now)
def api_calls_left_from_history(self, history: List[float], max_window: int,
max_calls: int, now: float) -> Tuple[int, float]:
"""
This depends on the algorithm used in the backend, and should be defined by the test class.
"""
raise NotImplementedError # nocoverage
def test_hit_ratelimits(self) -> None:
obj = self.create_object('test', [(2, 3), ])
start_time = time.time()
for i in range(3):
with mock.patch('time.time', return_value=(start_time + i * 0.1)):
self.make_request(obj, expect_ratelimited=False)
with mock.patch('time.time', return_value=(start_time + 0.4)):
self.make_request(obj, expect_ratelimited=True)
with mock.patch('time.time', return_value=(start_time + 2.01)):
self.make_request(obj, expect_ratelimited=False)
def test_clear_history(self) -> None:
obj = self.create_object('test', [(2, 3), ])
start_time = time.time()
for i in range(3):
with mock.patch('time.time', return_value=(start_time + i * 0.1)):
self.make_request(obj, expect_ratelimited=False)
with mock.patch('time.time', return_value=(start_time + 0.4)):
self.make_request(obj, expect_ratelimited=True)
obj.clear_history()
self.requests_record[obj.key()] = []
for i in range(3):
with mock.patch('time.time', return_value=(start_time + i * 0.1)):
self.make_request(obj, expect_ratelimited=False)
def test_block_unblock_access(self) -> None:
obj = self.create_object('test', [(2, 5), ])
start_time = time.time()
obj.block_access(1)
with mock.patch('time.time', return_value=(start_time)):
self.make_request(obj, expect_ratelimited=True, verify_api_calls_left=False)
obj.unblock_access()
with mock.patch('time.time', return_value=(start_time)):
self.make_request(obj, expect_ratelimited=False, verify_api_calls_left=False)
def test_api_calls_left(self) -> None:
obj = self.create_object('test', [(2, 5), (3, 6)])
start_time = time.time()
# Check the edge case when no requests have been made yet.
with mock.patch('time.time', return_value=(start_time)):
self.verify_api_calls_left(obj)
with mock.patch('time.time', return_value=(start_time)):
self.make_request(obj)
# Check the correct default values again, after the reset has happened on the first rule,
# but not the other.
with mock.patch('time.time', return_value=(start_time + 2.1)):
self.make_request(obj)
class RedisRateLimiterBackendTest(RateLimiterBackendBase):
__unittest_skip__ = False
backend = RedisRateLimiterBackend
def api_calls_left_from_history(self, history: List[float], max_window: int,
max_calls: int, now: float) -> Tuple[int, float]:
latest_timestamp = history[-1]
relevant_requests = [t for t in history if (t >= now - max_window)]
relevant_requests_amount = len(relevant_requests)
return max_calls - relevant_requests_amount, latest_timestamp + max_window - now
def test_block_access(self) -> None:
"""
This test cannot verify that the user will get unblocked
after the correct amount of time, because that event happens
inside redis, so we're not able to mock the timer. Making the test
sleep for 1s is also too costly to be worth it.
"""
obj = self.create_object('test', [(2, 5), ])
obj.block_access(1)
self.make_request(obj, expect_ratelimited=True, verify_api_calls_left=False)
class RateLimitedUserTest(ZulipTestCase):
def test_user_rate_limits(self) -> None:
user_profile = self.example_user("hamlet")
user_profile.rate_limits = "1:3,2:4"
obj = RateLimitedUser(user_profile)
self.assertEqual(obj.rules(), [(1, 3), (2, 4)])
def test_add_remove_rule(self) -> None:
user_profile = self.example_user("hamlet")
add_ratelimit_rule(1, 2)
add_ratelimit_rule(4, 5, domain='some_new_domain')
add_ratelimit_rule(10, 100, domain='some_new_domain')
obj = RateLimitedUser(user_profile)
self.assertEqual(obj.rules(), [(1, 2), ])
obj.domain = 'some_new_domain'
self.assertEqual(obj.rules(), [(4, 5), (10, 100)])
remove_ratelimit_rule(10, 100, domain='some_new_domain')
self.assertEqual(obj.rules(), [(4, 5), ])