rate_limiter: Handle edge case where rules list may be empty.

This commit is contained in:
Mateusz Mandera 2020-04-02 22:23:20 +02:00 committed by Tim Abbott
parent b577366a05
commit 5f9da3053d
2 changed files with 27 additions and 21 deletions

View File

@ -39,7 +39,7 @@ class RateLimitedObject(ABC):
def rate_limit(self) -> Tuple[bool, float]:
# Returns (ratelimited, secs_to_freedom)
return self.backend.rate_limit_entity(self.key(), self.rules(),
return self.backend.rate_limit_entity(self.key(), self.get_rules(),
self.max_api_calls(),
self.max_api_window())
@ -76,11 +76,11 @@ class RateLimitedObject(ABC):
def max_api_calls(self) -> int:
"Returns the API rate limit for the highest limit"
return self.rules()[-1][1]
return self.get_rules()[-1][1]
def max_api_window(self) -> int:
"Returns the API time window for the highest limit"
return self.rules()[-1][0]
return self.get_rules()[-1][0]
def api_calls_left(self) -> Tuple[int, float]:
"""Returns how many API calls in this range this client has, as well as when
@ -89,6 +89,16 @@ class RateLimitedObject(ABC):
max_calls = self.max_api_calls()
return self.backend.get_api_calls_left(self.key(), max_window, max_calls)
def get_rules(self) -> List[Tuple[int, int]]:
"""
This is a simple wrapper meant to protect against having to deal with
an empty list of rules, as it would require fiddling with that special case
all around this system. "9999 max request per seconds" should be a good proxy
for "no rules".
"""
rules_list = self.rules()
return rules_list or [(1, 9999), ]
@abstractmethod
def key(self) -> str:
pass
@ -270,8 +280,7 @@ class TornadoInMemoryRateLimiterBackend(RateLimiterBackend):
else:
del cls.timestamps_blocked_until[entity_key]
if len(rules) == 0:
return False, 0
assert rules
for time_window, max_count in rules:
ratelimited, time_till_free = cls.need_to_limit(entity_key, time_window, max_count)
@ -338,6 +347,7 @@ class RedisRateLimiterBackend(RateLimiterBackend):
@classmethod
def is_ratelimited(cls, entity_key: str, rules: List[Tuple[int, int]]) -> Tuple[bool, float]:
"Returns a tuple of (rate_limited, time_till_free)"
assert rules
list_key, set_key, blocking_key = cls.get_keys(entity_key)
# Go through the rules from shortest to longest,
@ -365,9 +375,6 @@ class RedisRateLimiterBackend(RateLimiterBackend):
blocking_ttl = int(blocking_ttl_b)
return True, blocking_ttl
if len(rules) == 0:
return False, 0.0
now = time.time()
for timestamp, (range_seconds, num_requests) in zip(rule_timestamps, rules):
# Check if the nth timestamp is newer than the associated rule. If so,
@ -383,16 +390,11 @@ class RedisRateLimiterBackend(RateLimiterBackend):
return False, 0.0
@classmethod
def incr_ratelimit(cls, entity_key: str, rules: List[Tuple[int, int]],
max_api_calls: int, max_api_window: int) -> None:
def incr_ratelimit(cls, entity_key: str, max_api_calls: int, max_api_window: int) -> None:
"""Increases the rate-limit for the specified entity"""
list_key, set_key, _ = cls.get_keys(entity_key)
now = time.time()
# If we have no rules, we don't store anything
if len(rules) == 0:
return
# Start redis transaction
with client.pipeline() as pipe:
count = 0
@ -451,7 +453,7 @@ class RedisRateLimiterBackend(RateLimiterBackend):
else:
try:
cls.incr_ratelimit(entity_key, rules, max_api_calls, max_api_window)
cls.incr_ratelimit(entity_key, max_api_calls, max_api_window)
except RateLimiterLockingException:
logger.warning("Deadlock trying to incr_ratelimit for %s" % (entity_key,))
# rate-limit users who are hitting the API so hard we can't update our stats.

View File

@ -69,7 +69,7 @@ class RateLimiterBackendBase(ZulipTestCase):
self.assertEqual(expected_time_till_reset, time_till_reset)
def expected_api_calls_left(self, obj: RateLimitedTestObject, now: float) -> Tuple[int, float]:
longest_rule = obj.rules()[-1]
longest_rule = obj.get_rules()[-1]
max_window, max_calls = longest_rule
history = self.requests_record.get(obj.key())
if history is None:
@ -198,13 +198,13 @@ class TornadoInMemoryRateLimiterBackendTest(RateLimiterBackendBase):
with mock.patch('time.time', return_value=(start_time + 1.01)):
self.make_request(obj, expect_ratelimited=False, verify_api_calls_left=False)
class RateLimitedUserTest(ZulipTestCase):
class RateLimitedObjectsTest(ZulipTestCase):
def test_user_rate_limits(self) -> None:
user_profile = self.example_user("hamlet")
user_profile.rate_limits = "1:3,2:4"
obj = RateLimitedUser(user_profile)
self.assertEqual(obj.rules(), [(1, 3), (2, 4)])
self.assertEqual(obj.get_rules(), [(1, 3), (2, 4)])
def test_add_remove_rule(self) -> None:
user_profile = self.example_user("hamlet")
@ -213,9 +213,13 @@ class RateLimitedUserTest(ZulipTestCase):
add_ratelimit_rule(10, 100, domain='some_new_domain')
obj = RateLimitedUser(user_profile)
self.assertEqual(obj.rules(), [(1, 2), ])
self.assertEqual(obj.get_rules(), [(1, 2), ])
obj.domain = 'some_new_domain'
self.assertEqual(obj.rules(), [(4, 5), (10, 100)])
self.assertEqual(obj.get_rules(), [(4, 5), (10, 100)])
remove_ratelimit_rule(10, 100, domain='some_new_domain')
self.assertEqual(obj.rules(), [(4, 5), ])
self.assertEqual(obj.get_rules(), [(4, 5), ])
def test_empty_rules_edge_case(self) -> None:
obj = RateLimitedTestObject("test", rules=[], backend=RedisRateLimiterBackend)
self.assertEqual(obj.get_rules(), [(1, 9999), ])