diff --git a/zerver/lib/rate_limiter.py b/zerver/lib/rate_limiter.py index 335dd5486c..b6b8cbfd3e 100644 --- a/zerver/lib/rate_limiter.py +++ b/zerver/lib/rate_limiter.py @@ -39,7 +39,7 @@ class RateLimitedObject(ABC): def rate_limit(self) -> Tuple[bool, float]: # Returns (ratelimited, secs_to_freedom) - return self.backend.rate_limit_entity(self.key(), self.rules(), + return self.backend.rate_limit_entity(self.key(), self.get_rules(), self.max_api_calls(), self.max_api_window()) @@ -76,11 +76,11 @@ class RateLimitedObject(ABC): def max_api_calls(self) -> int: "Returns the API rate limit for the highest limit" - return self.rules()[-1][1] + return self.get_rules()[-1][1] def max_api_window(self) -> int: "Returns the API time window for the highest limit" - return self.rules()[-1][0] + return self.get_rules()[-1][0] def api_calls_left(self) -> Tuple[int, float]: """Returns how many API calls in this range this client has, as well as when @@ -89,6 +89,16 @@ class RateLimitedObject(ABC): max_calls = self.max_api_calls() return self.backend.get_api_calls_left(self.key(), max_window, max_calls) + def get_rules(self) -> List[Tuple[int, int]]: + """ + This is a simple wrapper meant to protect against having to deal with + an empty list of rules, as it would require fiddling with that special case + all around this system. "9999 max request per seconds" should be a good proxy + for "no rules". + """ + rules_list = self.rules() + return rules_list or [(1, 9999), ] + @abstractmethod def key(self) -> str: pass @@ -270,8 +280,7 @@ class TornadoInMemoryRateLimiterBackend(RateLimiterBackend): else: del cls.timestamps_blocked_until[entity_key] - if len(rules) == 0: - return False, 0 + assert rules for time_window, max_count in rules: ratelimited, time_till_free = cls.need_to_limit(entity_key, time_window, max_count) @@ -338,6 +347,7 @@ class RedisRateLimiterBackend(RateLimiterBackend): @classmethod def is_ratelimited(cls, entity_key: str, rules: List[Tuple[int, int]]) -> Tuple[bool, float]: "Returns a tuple of (rate_limited, time_till_free)" + assert rules list_key, set_key, blocking_key = cls.get_keys(entity_key) # Go through the rules from shortest to longest, @@ -365,9 +375,6 @@ class RedisRateLimiterBackend(RateLimiterBackend): blocking_ttl = int(blocking_ttl_b) return True, blocking_ttl - if len(rules) == 0: - return False, 0.0 - now = time.time() for timestamp, (range_seconds, num_requests) in zip(rule_timestamps, rules): # Check if the nth timestamp is newer than the associated rule. If so, @@ -383,16 +390,11 @@ class RedisRateLimiterBackend(RateLimiterBackend): return False, 0.0 @classmethod - def incr_ratelimit(cls, entity_key: str, rules: List[Tuple[int, int]], - max_api_calls: int, max_api_window: int) -> None: + def incr_ratelimit(cls, entity_key: str, max_api_calls: int, max_api_window: int) -> None: """Increases the rate-limit for the specified entity""" list_key, set_key, _ = cls.get_keys(entity_key) now = time.time() - # If we have no rules, we don't store anything - if len(rules) == 0: - return - # Start redis transaction with client.pipeline() as pipe: count = 0 @@ -451,7 +453,7 @@ class RedisRateLimiterBackend(RateLimiterBackend): else: try: - cls.incr_ratelimit(entity_key, rules, max_api_calls, max_api_window) + cls.incr_ratelimit(entity_key, max_api_calls, max_api_window) except RateLimiterLockingException: logger.warning("Deadlock trying to incr_ratelimit for %s" % (entity_key,)) # rate-limit users who are hitting the API so hard we can't update our stats. diff --git a/zerver/tests/test_rate_limiter.py b/zerver/tests/test_rate_limiter.py index 6a9f93743f..91b59b75dc 100644 --- a/zerver/tests/test_rate_limiter.py +++ b/zerver/tests/test_rate_limiter.py @@ -69,7 +69,7 @@ class RateLimiterBackendBase(ZulipTestCase): self.assertEqual(expected_time_till_reset, time_till_reset) def expected_api_calls_left(self, obj: RateLimitedTestObject, now: float) -> Tuple[int, float]: - longest_rule = obj.rules()[-1] + longest_rule = obj.get_rules()[-1] max_window, max_calls = longest_rule history = self.requests_record.get(obj.key()) if history is None: @@ -198,13 +198,13 @@ class TornadoInMemoryRateLimiterBackendTest(RateLimiterBackendBase): with mock.patch('time.time', return_value=(start_time + 1.01)): self.make_request(obj, expect_ratelimited=False, verify_api_calls_left=False) -class RateLimitedUserTest(ZulipTestCase): +class RateLimitedObjectsTest(ZulipTestCase): def test_user_rate_limits(self) -> None: user_profile = self.example_user("hamlet") user_profile.rate_limits = "1:3,2:4" obj = RateLimitedUser(user_profile) - self.assertEqual(obj.rules(), [(1, 3), (2, 4)]) + self.assertEqual(obj.get_rules(), [(1, 3), (2, 4)]) def test_add_remove_rule(self) -> None: user_profile = self.example_user("hamlet") @@ -213,9 +213,13 @@ class RateLimitedUserTest(ZulipTestCase): add_ratelimit_rule(10, 100, domain='some_new_domain') obj = RateLimitedUser(user_profile) - self.assertEqual(obj.rules(), [(1, 2), ]) + self.assertEqual(obj.get_rules(), [(1, 2), ]) obj.domain = 'some_new_domain' - self.assertEqual(obj.rules(), [(4, 5), (10, 100)]) + self.assertEqual(obj.get_rules(), [(4, 5), (10, 100)]) remove_ratelimit_rule(10, 100, domain='some_new_domain') - self.assertEqual(obj.rules(), [(4, 5), ]) + self.assertEqual(obj.get_rules(), [(4, 5), ]) + + def test_empty_rules_edge_case(self) -> None: + obj = RateLimitedTestObject("test", rules=[], backend=RedisRateLimiterBackend) + self.assertEqual(obj.get_rules(), [(1, 9999), ])