From b9e5103d0c86bc638e011626ecf84bd8c2820af0 Mon Sep 17 00:00:00 2001 From: Mateusz Mandera Date: Fri, 6 Mar 2020 13:44:52 +0100 Subject: [PATCH] rate_limit: Refactor RateLimiterBackend to operate on keys and rules. Instead of operating on RateLimitedObjects, and making the classes depend on each too strongly. This also allows getting rid of get_keys() function from RateLimitedObject, which was a redis rate limiter implementation detail. RateLimitedObject should only define their own key() function and the logic forming various necessary redis keys from them should be in RedisRateLimiterBackend. --- zerver/lib/rate_limiter.py | 77 ++++++++++++++++++++------------------ 1 file changed, 40 insertions(+), 37 deletions(-) diff --git a/zerver/lib/rate_limiter.py b/zerver/lib/rate_limiter.py index fd9c47e6b9..78894561db 100644 --- a/zerver/lib/rate_limiter.py +++ b/zerver/lib/rate_limiter.py @@ -32,14 +32,11 @@ class RateLimitedObject(ABC): def __init__(self) -> None: self.backend = RedisRateLimiterBackend - def get_keys(self) -> List[str]: - key = self.key() - return ["{}ratelimit:{}:{}".format(KEY_PREFIX, key, keytype) - for keytype in ['list', 'zset', 'block']] - def rate_limit(self) -> Tuple[bool, float]: # Returns (ratelimited, secs_to_freedom) - return self.backend.rate_limit_entity(self) + return self.backend.rate_limit_entity(self.key(), self.rules(), + self.max_api_calls(), + self.max_api_window()) def rate_limit_request(self, request: HttpRequest) -> None: ratelimited, time = self.rate_limit() @@ -64,17 +61,17 @@ class RateLimitedObject(ABC): def block_access(self, seconds: int) -> None: "Manually blocks an entity for the desired number of seconds" - self.backend.block_access(self, seconds) + self.backend.block_access(self.key(), seconds) def unblock_access(self) -> None: - self.backend.unblock_access(self) + self.backend.unblock_access(self.key()) def clear_history(self) -> None: ''' This is only used by test code now, where it's very helpful in allowing us to run tests quickly, by giving a user a clean slate. ''' - self.backend.clear_history(self) + self.backend.clear_history(self.key()) def max_api_calls(self) -> int: "Returns the API rate limit for the highest limit" @@ -89,7 +86,7 @@ class RateLimitedObject(ABC): the rate-limit will be reset to 0""" max_window = self.max_api_window() max_calls = self.max_api_calls() - return self.backend.get_api_calls_left(self, max_window, max_calls) + return self.backend.get_api_calls_left(self.key(), max_window, max_calls) @abstractmethod def key(self) -> str: @@ -141,17 +138,17 @@ def remove_ratelimit_rule(range_seconds: int, num_requests: int, domain: str='ap class RateLimiterBackend(ABC): @classmethod @abstractmethod - def block_access(cls, entity: RateLimitedObject, seconds: int) -> None: + def block_access(cls, entity_key: str, seconds: int) -> None: "Manually blocks an entity for the desired number of seconds" @classmethod @abstractmethod - def unblock_access(cls, entity: RateLimitedObject) -> None: + def unblock_access(cls, entity_key: str) -> None: pass @classmethod @abstractmethod - def clear_history(cls, entity: RateLimitedObject) -> None: + def clear_history(cls, entity_key: str) -> None: ''' This is only used by test code now, where it's very helpful in allowing us to run tests quickly, by giving a user a clean slate. @@ -159,44 +156,50 @@ class RateLimiterBackend(ABC): @classmethod @abstractmethod - def get_api_calls_left(cls, entity: RateLimitedObject, range_seconds: int, + def get_api_calls_left(cls, entity_key: str, range_seconds: int, max_calls: int) -> Tuple[int, float]: pass @classmethod @abstractmethod - def rate_limit_entity(cls, entity: RateLimitedObject) -> Tuple[bool, float]: + def rate_limit_entity(cls, entity_key: str, rules: List[Tuple[int, int]], + max_api_calls: int, max_api_window: int) -> Tuple[bool, float]: # Returns (ratelimited, secs_to_freedom) pass class RedisRateLimiterBackend(RateLimiterBackend): @classmethod - def block_access(cls, entity: RateLimitedObject, seconds: int) -> None: + def get_keys(cls, entity_key: str) -> List[str]: + return ["{}ratelimit:{}:{}".format(KEY_PREFIX, entity_key, keytype) + for keytype in ['list', 'zset', 'block']] + + @classmethod + def block_access(cls, entity_key: str, seconds: int) -> None: "Manually blocks an entity for the desired number of seconds" - _, _, blocking_key = entity.get_keys() + _, _, blocking_key = cls.get_keys(entity_key) with client.pipeline() as pipe: pipe.set(blocking_key, 1) pipe.expire(blocking_key, seconds) pipe.execute() @classmethod - def unblock_access(cls, entity: RateLimitedObject) -> None: - _, _, blocking_key = entity.get_keys() + def unblock_access(cls, entity_key: str) -> None: + _, _, blocking_key = cls.get_keys(entity_key) client.delete(blocking_key) @classmethod - def clear_history(cls, entity: RateLimitedObject) -> None: + def clear_history(cls, entity_key: str) -> None: ''' This is only used by test code now, where it's very helpful in allowing us to run tests quickly, by giving a user a clean slate. ''' - for key in entity.get_keys(): + for key in cls.get_keys(entity_key): client.delete(key) @classmethod - def get_api_calls_left(cls, entity: RateLimitedObject, range_seconds: int, + def get_api_calls_left(cls, entity_key: str, range_seconds: int, max_calls: int) -> Tuple[int, float]: - list_key, set_key, _ = entity.get_keys() + list_key, set_key, _ = cls.get_keys(entity_key) # Count the number of values in our sorted set # that are between now and the cutoff now = time.time() @@ -223,11 +226,9 @@ class RedisRateLimiterBackend(RateLimiterBackend): return calls_left, time_reset @classmethod - def is_ratelimited(cls, entity: RateLimitedObject) -> Tuple[bool, float]: + def is_ratelimited(cls, entity_key: str, rules: List[Tuple[int, int]]) -> Tuple[bool, float]: "Returns a tuple of (rate_limited, time_till_free)" - list_key, set_key, blocking_key = entity.get_keys() - - rules = entity.rules() + list_key, set_key, blocking_key = cls.get_keys(entity_key) if len(rules) == 0: return False, 0.0 @@ -273,9 +274,10 @@ class RedisRateLimiterBackend(RateLimiterBackend): return False, 0.0 @classmethod - def incr_ratelimit(cls, entity: RateLimitedObject) -> None: + def incr_ratelimit(cls, entity_key: str, rules: List[Tuple[int, int]], + max_api_calls: int, max_api_window: int) -> None: """Increases the rate-limit for the specified entity""" - list_key, set_key, _ = entity.get_keys() + list_key, set_key, _ = cls.get_keys(entity_key) now = time.time() # If we have no rules, we don't store anything @@ -294,7 +296,7 @@ class RedisRateLimiterBackend(RateLimiterBackend): pipe.watch(list_key) # Get the last elem that we'll trim (so we can remove it from our sorted set) - last_val = pipe.lindex(list_key, entity.max_api_calls() - 1) + last_val = pipe.lindex(list_key, max_api_calls - 1) # Restart buffered execution pipe.multi() @@ -303,7 +305,7 @@ class RedisRateLimiterBackend(RateLimiterBackend): pipe.lpush(list_key, now) # Trim our list to the oldest rule we have - pipe.ltrim(list_key, 0, entity.max_api_calls() - 1) + pipe.ltrim(list_key, 0, max_api_calls - 1) # Add our new value to the sorted set that we keep # We need to put the score and val both as timestamp, @@ -315,7 +317,7 @@ class RedisRateLimiterBackend(RateLimiterBackend): pipe.zrem(set_key, last_val) # Set the TTL for our keys as well - api_window = entity.max_api_window() + api_window = max_api_window pipe.expire(list_key, api_window) pipe.expire(set_key, api_window) @@ -331,17 +333,18 @@ class RedisRateLimiterBackend(RateLimiterBackend): continue @classmethod - def rate_limit_entity(cls, entity: RateLimitedObject) -> Tuple[bool, float]: - ratelimited, time = cls.is_ratelimited(entity) + def rate_limit_entity(cls, entity_key: str, rules: List[Tuple[int, int]], + max_api_calls: int, max_api_window: int) -> Tuple[bool, float]: + ratelimited, time = cls.is_ratelimited(entity_key, rules) if ratelimited: - statsd.incr("ratelimiter.limited.%s" % (entity.key(),)) + statsd.incr("ratelimiter.limited.%s" % (entity_key,)) else: try: - cls.incr_ratelimit(entity) + cls.incr_ratelimit(entity_key, rules, max_api_calls, max_api_window) except RateLimiterLockingException: - logger.warning("Deadlock trying to incr_ratelimit for %s" % (entity.key(),)) + logger.warning("Deadlock trying to incr_ratelimit for %s" % (entity_key,)) # rate-limit users who are hitting the API so hard we can't update our stats. ratelimited = True