diff --git a/zerver/forms.py b/zerver/forms.py index 2dcc0fd1f9..5257fbf6da 100644 --- a/zerver/forms.py +++ b/zerver/forms.py @@ -303,6 +303,7 @@ class ZulipPasswordResetForm(PasswordResetForm): class RateLimitedPasswordResetByEmail(RateLimitedObject): def __init__(self, email: str) -> None: self.email = email + super().__init__() def __str__(self) -> str: return "Email: {}".format(self.email) diff --git a/zerver/lib/email_mirror.py b/zerver/lib/email_mirror.py index aa554f64ff..55704857bf 100644 --- a/zerver/lib/email_mirror.py +++ b/zerver/lib/email_mirror.py @@ -442,6 +442,7 @@ def mirror_email_message(data: Dict[str, str]) -> Dict[str, str]: class RateLimitedRealmMirror(RateLimitedObject): def __init__(self, realm: Realm) -> None: self.realm = realm + super().__init__() def key_fragment(self) -> str: return "emailmirror:{}:{}".format(type(self.realm), self.realm.id) diff --git a/zerver/lib/rate_limiter.py b/zerver/lib/rate_limiter.py index 907af13968..ff47010f50 100644 --- a/zerver/lib/rate_limiter.py +++ b/zerver/lib/rate_limiter.py @@ -29,6 +29,9 @@ class RateLimiterLockingException(Exception): pass class RateLimitedObject(ABC): + def __init__(self) -> None: + self.backend = RedisRateLimiterBackend + def get_keys(self) -> List[str]: key_fragment = self.key_fragment() return ["{}ratelimit:{}:{}".format(KEY_PREFIX, key_fragment, keytype) @@ -36,21 +39,7 @@ class RateLimitedObject(ABC): def rate_limit(self) -> Tuple[bool, float]: # Returns (ratelimited, secs_to_freedom) - ratelimited, time = is_ratelimited(self) - - if ratelimited: - statsd.incr("ratelimiter.limited.%s.%s" % (type(self), str(self))) - - else: - try: - incr_ratelimit(self) - except RateLimiterLockingException: - logger.warning("Deadlock trying to incr_ratelimit for %s:%s" % ( - type(self).__name__, str(self))) - # rate-limit users who are hitting the API so hard we can't update our stats. - ratelimited = True - - return ratelimited, time + return self.backend.rate_limit_entity(self) def rate_limit_request(self, request: HttpRequest) -> None: ratelimited, time = self.rate_limit() @@ -75,23 +64,17 @@ class RateLimitedObject(ABC): def block_access(self, seconds: int) -> None: "Manually blocks an entity for the desired number of seconds" - _, _, blocking_key = self.get_keys() - with client.pipeline() as pipe: - pipe.set(blocking_key, 1) - pipe.expire(blocking_key, seconds) - pipe.execute() + self.backend.block_access(self, seconds) def unblock_access(self) -> None: - _, _, blocking_key = self.get_keys() - client.delete(blocking_key) + self.backend.unblock_access(self) def clear_history(self) -> None: ''' This is only used by test code now, where it's very helpful in allowing us to run tests quickly, by giving a user a clean slate. ''' - for key in self.get_keys(): - client.delete(key) + self.backend.clear_history(self) def max_api_calls(self) -> int: "Returns the API rate limit for the highest limit" @@ -106,7 +89,7 @@ class RateLimitedObject(ABC): the rate-limit will be reset to 0""" max_window = self.max_api_window() max_calls = self.max_api_calls() - return _get_api_calls_left(self, max_window, max_calls) + return self.backend.get_api_calls_left(self, max_window, max_calls) @abstractmethod def key_fragment(self) -> str: @@ -124,6 +107,7 @@ class RateLimitedUser(RateLimitedObject): def __init__(self, user: UserProfile, domain: str='api_by_user') -> None: self.user = user self.domain = domain + super().__init__() def __str__(self) -> str: return "Id: {}".format(self.user.id) @@ -161,138 +145,215 @@ def remove_ratelimit_rule(range_seconds: int, num_requests: int, domain: str='ap global rules rules[domain] = [x for x in rules[domain] if x[0] != range_seconds and x[1] != num_requests] -def _get_api_calls_left(entity: RateLimitedObject, range_seconds: int, max_calls: int) -> Tuple[int, float]: - list_key, set_key, _ = entity.get_keys() - # Count the number of values in our sorted set - # that are between now and the cutoff - now = time.time() - boundary = now - range_seconds +class RateLimiterBackend(ABC): + @classmethod + @abstractmethod + def block_access(cls, entity: RateLimitedObject, seconds: int) -> None: + "Manually blocks an entity for the desired number of seconds" - with client.pipeline() as pipe: - # Count how many API calls in our range have already been made - pipe.zcount(set_key, boundary, now) - # Get the newest call so we can calculate when the ratelimit - # will reset to 0 - pipe.lindex(list_key, 0) + @classmethod + @abstractmethod + def unblock_access(cls, entity: RateLimitedObject) -> None: + pass - results = pipe.execute() + @classmethod + @abstractmethod + def clear_history(cls, entity: RateLimitedObject) -> None: + ''' + This is only used by test code now, where it's very helpful in + allowing us to run tests quickly, by giving a user a clean slate. + ''' - count = results[0] # type: int - newest_call = results[1] # type: Optional[bytes] + @classmethod + @abstractmethod + def get_api_calls_left(cls, entity: RateLimitedObject, range_seconds: int, + max_calls: int) -> Tuple[int, float]: + pass - calls_left = max_calls - count - if newest_call is not None: - time_reset = now + (range_seconds - (now - float(newest_call))) - else: - time_reset = now + @classmethod + @abstractmethod + def rate_limit_entity(cls, entity: RateLimitedObject) -> Tuple[bool, float]: + # Returns (ratelimited, secs_to_freedom) + pass - return calls_left, time_reset +class RedisRateLimiterBackend(RateLimiterBackend): + @classmethod + def block_access(cls, entity: RateLimitedObject, seconds: int) -> None: + "Manually blocks an entity for the desired number of seconds" + _, _, blocking_key = entity.get_keys() + with client.pipeline() as pipe: + pipe.set(blocking_key, 1) + pipe.expire(blocking_key, seconds) + pipe.execute() -def is_ratelimited(entity: RateLimitedObject) -> Tuple[bool, float]: - "Returns a tuple of (rate_limited, time_till_free)" - list_key, set_key, blocking_key = entity.get_keys() + @classmethod + def unblock_access(cls, entity: RateLimitedObject) -> None: + _, _, blocking_key = entity.get_keys() + client.delete(blocking_key) - rules = entity.rules() + @classmethod + def clear_history(cls, entity: RateLimitedObject) -> None: + ''' + This is only used by test code now, where it's very helpful in + allowing us to run tests quickly, by giving a user a clean slate. + ''' + for key in entity.get_keys(): + client.delete(key) - if len(rules) == 0: + @classmethod + def get_api_calls_left(cls, entity: RateLimitedObject, range_seconds: int, + max_calls: int) -> Tuple[int, float]: + list_key, set_key, _ = entity.get_keys() + # Count the number of values in our sorted set + # that are between now and the cutoff + now = time.time() + boundary = now - range_seconds + + with client.pipeline() as pipe: + # Count how many API calls in our range have already been made + pipe.zcount(set_key, boundary, now) + # Get the newest call so we can calculate when the ratelimit + # will reset to 0 + pipe.lindex(list_key, 0) + + results = pipe.execute() + + count = results[0] # type: int + newest_call = results[1] # type: Optional[bytes] + + calls_left = max_calls - count + if newest_call is not None: + time_reset = now + (range_seconds - (now - float(newest_call))) + else: + time_reset = now + + return calls_left, time_reset + + @classmethod + def is_ratelimited(cls, entity: RateLimitedObject) -> Tuple[bool, float]: + "Returns a tuple of (rate_limited, time_till_free)" + list_key, set_key, blocking_key = entity.get_keys() + + rules = entity.rules() + + if len(rules) == 0: + return False, 0.0 + + # Go through the rules from shortest to longest, + # seeing if this user has violated any of them. First + # get the timestamps for each nth items + with client.pipeline() as pipe: + for _, request_count in rules: + pipe.lindex(list_key, request_count - 1) # 0-indexed list + + # Get blocking info + pipe.get(blocking_key) + pipe.ttl(blocking_key) + + rule_timestamps = pipe.execute() # type: List[Optional[bytes]] + + # Check if there is a manual block on this API key + blocking_ttl_b = rule_timestamps.pop() + key_blocked = rule_timestamps.pop() + + if key_blocked is not None: + # We are manually blocked. Report for how much longer we will be + if blocking_ttl_b is None: + blocking_ttl = 0.5 + else: + blocking_ttl = int(blocking_ttl_b) + return True, blocking_ttl + + now = time.time() + for timestamp, (range_seconds, num_requests) in zip(rule_timestamps, rules): + # Check if the nth timestamp is newer than the associated rule. If so, + # it means we've hit our limit for this rule + if timestamp is None: + continue + + boundary = float(timestamp) + range_seconds + if boundary > now: + free = boundary - now + return True, free + + # No api calls recorded yet return False, 0.0 - # Go through the rules from shortest to longest, - # seeing if this user has violated any of them. First - # get the timestamps for each nth items - with client.pipeline() as pipe: - for _, request_count in rules: - pipe.lindex(list_key, request_count - 1) # 0-indexed list + @classmethod + def incr_ratelimit(cls, entity: RateLimitedObject) -> None: + """Increases the rate-limit for the specified entity""" + list_key, set_key, _ = entity.get_keys() + now = time.time() - # Get blocking info - pipe.get(blocking_key) - pipe.ttl(blocking_key) + # If we have no rules, we don't store anything + if len(rules) == 0: + return - rule_timestamps = pipe.execute() # type: List[Optional[bytes]] + # Start redis transaction + with client.pipeline() as pipe: + count = 0 + while True: + try: + # To avoid a race condition between getting the element we might trim from our list + # and removing it from our associated set, we abort this whole transaction if + # another agent manages to change our list out from under us + # When watching a value, the pipeline is set to Immediate mode + pipe.watch(list_key) - # Check if there is a manual block on this API key - blocking_ttl_b = rule_timestamps.pop() - key_blocked = rule_timestamps.pop() + # Get the last elem that we'll trim (so we can remove it from our sorted set) + last_val = pipe.lindex(list_key, entity.max_api_calls() - 1) + + # Restart buffered execution + pipe.multi() + + # Add this timestamp to our list + pipe.lpush(list_key, now) + + # Trim our list to the oldest rule we have + pipe.ltrim(list_key, 0, entity.max_api_calls() - 1) + + # Add our new value to the sorted set that we keep + # We need to put the score and val both as timestamp, + # as we sort by score but remove by value + pipe.zadd(set_key, {str(now): now}) + + # Remove the trimmed value from our sorted set, if there was one + if last_val is not None: + pipe.zrem(set_key, last_val) + + # Set the TTL for our keys as well + api_window = entity.max_api_window() + pipe.expire(list_key, api_window) + pipe.expire(set_key, api_window) + + pipe.execute() + + # If no exception was raised in the execution, there were no transaction conflicts + break + except redis.WatchError: + if count > 10: + raise RateLimiterLockingException() + count += 1 + + continue + + @classmethod + def rate_limit_entity(cls, entity: RateLimitedObject) -> Tuple[bool, float]: + ratelimited, time = cls.is_ratelimited(entity) + + if ratelimited: + statsd.incr("ratelimiter.limited.%s.%s" % (type(entity), str(entity))) - if key_blocked is not None: - # We are manually blocked. Report for how much longer we will be - if blocking_ttl_b is None: - blocking_ttl = 0.5 else: - blocking_ttl = int(blocking_ttl_b) - return True, blocking_ttl - - now = time.time() - for timestamp, (range_seconds, num_requests) in zip(rule_timestamps, rules): - # Check if the nth timestamp is newer than the associated rule. If so, - # it means we've hit our limit for this rule - if timestamp is None: - continue - - boundary = float(timestamp) + range_seconds - if boundary > now: - free = boundary - now - return True, free - - # No api calls recorded yet - return False, 0.0 - -def incr_ratelimit(entity: RateLimitedObject) -> None: - """Increases the rate-limit for the specified entity""" - list_key, set_key, _ = entity.get_keys() - now = time.time() - - # If we have no rules, we don't store anything - if len(rules) == 0: - return - - # Start redis transaction - with client.pipeline() as pipe: - count = 0 - while True: try: - # To avoid a race condition between getting the element we might trim from our list - # and removing it from our associated set, we abort this whole transaction if - # another agent manages to change our list out from under us - # When watching a value, the pipeline is set to Immediate mode - pipe.watch(list_key) + cls.incr_ratelimit(entity) + except RateLimiterLockingException: + logger.warning("Deadlock trying to incr_ratelimit for %s:%s" % ( + type(entity).__name__, str(entity))) + # rate-limit users who are hitting the API so hard we can't update our stats. + ratelimited = True - # Get the last elem that we'll trim (so we can remove it from our sorted set) - last_val = pipe.lindex(list_key, entity.max_api_calls() - 1) - - # Restart buffered execution - pipe.multi() - - # Add this timestamp to our list - pipe.lpush(list_key, now) - - # Trim our list to the oldest rule we have - pipe.ltrim(list_key, 0, entity.max_api_calls() - 1) - - # Add our new value to the sorted set that we keep - # We need to put the score and val both as timestamp, - # as we sort by score but remove by value - pipe.zadd(set_key, {str(now): now}) - - # Remove the trimmed value from our sorted set, if there was one - if last_val is not None: - pipe.zrem(set_key, last_val) - - # Set the TTL for our keys as well - api_window = entity.max_api_window() - pipe.expire(list_key, api_window) - pipe.expire(set_key, api_window) - - pipe.execute() - - # If no exception was raised in the execution, there were no transaction conflicts - break - except redis.WatchError: - if count > 10: - raise RateLimiterLockingException() - count += 1 - - continue + return ratelimited, time class RateLimitResult: def __init__(self, entity: RateLimitedObject, secs_to_freedom: float, over_limit: bool, diff --git a/zerver/tests/test_external.py b/zerver/tests/test_external.py index c4ed717fce..c2e29d10ff 100644 --- a/zerver/tests/test_external.py +++ b/zerver/tests/test_external.py @@ -117,7 +117,7 @@ class RateLimitTests(ZulipTestCase): user = self.example_user('cordelia') RateLimitedUser(user).clear_history() - with mock.patch('zerver.lib.rate_limiter.incr_ratelimit', + with mock.patch('zerver.lib.rate_limiter.RedisRateLimiterBackend.incr_ratelimit', side_effect=RateLimiterLockingException): result = self.send_api_message(user, "some stuff") self.assertEqual(result.status_code, 429) diff --git a/zerver/tests/test_queue_worker.py b/zerver/tests/test_queue_worker.py index 3a9d3dc406..1715210494 100644 --- a/zerver/tests/test_queue_worker.py +++ b/zerver/tests/test_queue_worker.py @@ -406,7 +406,7 @@ class WorkerTest(ZulipTestCase): self.assertEqual(mock_mirror_email.call_count, 4) # If RateLimiterLockingException is thrown, we rate-limit the new message: - with patch('zerver.lib.rate_limiter.incr_ratelimit', + with patch('zerver.lib.rate_limiter.RedisRateLimiterBackend.incr_ratelimit', side_effect=RateLimiterLockingException): fake_client.queue.append(('email_mirror', data[0])) worker.start() diff --git a/zproject/backends.py b/zproject/backends.py index 69e8e08c47..144a24780a 100644 --- a/zproject/backends.py +++ b/zproject/backends.py @@ -180,6 +180,7 @@ rate_limiting_rules = settings.RATE_LIMITING_RULES['authenticate_by_username'] class RateLimitedAuthenticationByUsername(RateLimitedObject): def __init__(self, username: str) -> None: self.username = username + super().__init__() def __str__(self) -> str: return "Username: {}".format(self.username)