mirror of https://github.com/zulip/zulip.git
rate_limiter: Handle edge case where rules list may be empty.
This commit is contained in:
parent
b577366a05
commit
5f9da3053d
|
@ -39,7 +39,7 @@ class RateLimitedObject(ABC):
|
|||
|
||||
def rate_limit(self) -> Tuple[bool, float]:
|
||||
# Returns (ratelimited, secs_to_freedom)
|
||||
return self.backend.rate_limit_entity(self.key(), self.rules(),
|
||||
return self.backend.rate_limit_entity(self.key(), self.get_rules(),
|
||||
self.max_api_calls(),
|
||||
self.max_api_window())
|
||||
|
||||
|
@ -76,11 +76,11 @@ class RateLimitedObject(ABC):
|
|||
|
||||
def max_api_calls(self) -> int:
|
||||
"Returns the API rate limit for the highest limit"
|
||||
return self.rules()[-1][1]
|
||||
return self.get_rules()[-1][1]
|
||||
|
||||
def max_api_window(self) -> int:
|
||||
"Returns the API time window for the highest limit"
|
||||
return self.rules()[-1][0]
|
||||
return self.get_rules()[-1][0]
|
||||
|
||||
def api_calls_left(self) -> Tuple[int, float]:
|
||||
"""Returns how many API calls in this range this client has, as well as when
|
||||
|
@ -89,6 +89,16 @@ class RateLimitedObject(ABC):
|
|||
max_calls = self.max_api_calls()
|
||||
return self.backend.get_api_calls_left(self.key(), max_window, max_calls)
|
||||
|
||||
def get_rules(self) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
This is a simple wrapper meant to protect against having to deal with
|
||||
an empty list of rules, as it would require fiddling with that special case
|
||||
all around this system. "9999 max request per seconds" should be a good proxy
|
||||
for "no rules".
|
||||
"""
|
||||
rules_list = self.rules()
|
||||
return rules_list or [(1, 9999), ]
|
||||
|
||||
@abstractmethod
|
||||
def key(self) -> str:
|
||||
pass
|
||||
|
@ -270,8 +280,7 @@ class TornadoInMemoryRateLimiterBackend(RateLimiterBackend):
|
|||
else:
|
||||
del cls.timestamps_blocked_until[entity_key]
|
||||
|
||||
if len(rules) == 0:
|
||||
return False, 0
|
||||
assert rules
|
||||
for time_window, max_count in rules:
|
||||
ratelimited, time_till_free = cls.need_to_limit(entity_key, time_window, max_count)
|
||||
|
||||
|
@ -338,6 +347,7 @@ class RedisRateLimiterBackend(RateLimiterBackend):
|
|||
@classmethod
|
||||
def is_ratelimited(cls, entity_key: str, rules: List[Tuple[int, int]]) -> Tuple[bool, float]:
|
||||
"Returns a tuple of (rate_limited, time_till_free)"
|
||||
assert rules
|
||||
list_key, set_key, blocking_key = cls.get_keys(entity_key)
|
||||
|
||||
# Go through the rules from shortest to longest,
|
||||
|
@ -365,9 +375,6 @@ class RedisRateLimiterBackend(RateLimiterBackend):
|
|||
blocking_ttl = int(blocking_ttl_b)
|
||||
return True, blocking_ttl
|
||||
|
||||
if len(rules) == 0:
|
||||
return False, 0.0
|
||||
|
||||
now = time.time()
|
||||
for timestamp, (range_seconds, num_requests) in zip(rule_timestamps, rules):
|
||||
# Check if the nth timestamp is newer than the associated rule. If so,
|
||||
|
@ -383,16 +390,11 @@ class RedisRateLimiterBackend(RateLimiterBackend):
|
|||
return False, 0.0
|
||||
|
||||
@classmethod
|
||||
def incr_ratelimit(cls, entity_key: str, rules: List[Tuple[int, int]],
|
||||
max_api_calls: int, max_api_window: int) -> None:
|
||||
def incr_ratelimit(cls, entity_key: str, max_api_calls: int, max_api_window: int) -> None:
|
||||
"""Increases the rate-limit for the specified entity"""
|
||||
list_key, set_key, _ = cls.get_keys(entity_key)
|
||||
now = time.time()
|
||||
|
||||
# If we have no rules, we don't store anything
|
||||
if len(rules) == 0:
|
||||
return
|
||||
|
||||
# Start redis transaction
|
||||
with client.pipeline() as pipe:
|
||||
count = 0
|
||||
|
@ -451,7 +453,7 @@ class RedisRateLimiterBackend(RateLimiterBackend):
|
|||
|
||||
else:
|
||||
try:
|
||||
cls.incr_ratelimit(entity_key, rules, max_api_calls, max_api_window)
|
||||
cls.incr_ratelimit(entity_key, max_api_calls, max_api_window)
|
||||
except RateLimiterLockingException:
|
||||
logger.warning("Deadlock trying to incr_ratelimit for %s" % (entity_key,))
|
||||
# rate-limit users who are hitting the API so hard we can't update our stats.
|
||||
|
|
|
@ -69,7 +69,7 @@ class RateLimiterBackendBase(ZulipTestCase):
|
|||
self.assertEqual(expected_time_till_reset, time_till_reset)
|
||||
|
||||
def expected_api_calls_left(self, obj: RateLimitedTestObject, now: float) -> Tuple[int, float]:
|
||||
longest_rule = obj.rules()[-1]
|
||||
longest_rule = obj.get_rules()[-1]
|
||||
max_window, max_calls = longest_rule
|
||||
history = self.requests_record.get(obj.key())
|
||||
if history is None:
|
||||
|
@ -198,13 +198,13 @@ class TornadoInMemoryRateLimiterBackendTest(RateLimiterBackendBase):
|
|||
with mock.patch('time.time', return_value=(start_time + 1.01)):
|
||||
self.make_request(obj, expect_ratelimited=False, verify_api_calls_left=False)
|
||||
|
||||
class RateLimitedUserTest(ZulipTestCase):
|
||||
class RateLimitedObjectsTest(ZulipTestCase):
|
||||
def test_user_rate_limits(self) -> None:
|
||||
user_profile = self.example_user("hamlet")
|
||||
user_profile.rate_limits = "1:3,2:4"
|
||||
obj = RateLimitedUser(user_profile)
|
||||
|
||||
self.assertEqual(obj.rules(), [(1, 3), (2, 4)])
|
||||
self.assertEqual(obj.get_rules(), [(1, 3), (2, 4)])
|
||||
|
||||
def test_add_remove_rule(self) -> None:
|
||||
user_profile = self.example_user("hamlet")
|
||||
|
@ -213,9 +213,13 @@ class RateLimitedUserTest(ZulipTestCase):
|
|||
add_ratelimit_rule(10, 100, domain='some_new_domain')
|
||||
obj = RateLimitedUser(user_profile)
|
||||
|
||||
self.assertEqual(obj.rules(), [(1, 2), ])
|
||||
self.assertEqual(obj.get_rules(), [(1, 2), ])
|
||||
obj.domain = 'some_new_domain'
|
||||
self.assertEqual(obj.rules(), [(4, 5), (10, 100)])
|
||||
self.assertEqual(obj.get_rules(), [(4, 5), (10, 100)])
|
||||
|
||||
remove_ratelimit_rule(10, 100, domain='some_new_domain')
|
||||
self.assertEqual(obj.rules(), [(4, 5), ])
|
||||
self.assertEqual(obj.get_rules(), [(4, 5), ])
|
||||
|
||||
def test_empty_rules_edge_case(self) -> None:
|
||||
obj = RateLimitedTestObject("test", rules=[], backend=RedisRateLimiterBackend)
|
||||
self.assertEqual(obj.get_rules(), [(1, 9999), ])
|
||||
|
|
Loading…
Reference in New Issue