rate_limit: Refactor RateLimiterBackend to operate on keys and rules.

This replaces operating on RateLimitedObjects, which made the classes
depend on each other too strongly. This also allows getting rid of get_keys()
function from RateLimitedObject, which was a redis rate limiter
implementation detail. Each RateLimitedObject should only define its own
key() function, and the logic forming the various necessary redis keys from
it should be in RedisRateLimiterBackend.
This commit is contained in:
Mateusz Mandera 2020-03-06 13:44:52 +01:00 committed by Tim Abbott
parent 8069133f88
commit b9e5103d0c
1 changed files with 40 additions and 37 deletions

View File

@ -32,14 +32,11 @@ class RateLimitedObject(ABC):
def __init__(self) -> None:
self.backend = RedisRateLimiterBackend
def get_keys(self) -> List[str]:
key = self.key()
return ["{}ratelimit:{}:{}".format(KEY_PREFIX, key, keytype)
for keytype in ['list', 'zset', 'block']]
def rate_limit(self) -> Tuple[bool, float]:
# Returns (ratelimited, secs_to_freedom)
return self.backend.rate_limit_entity(self)
return self.backend.rate_limit_entity(self.key(), self.rules(),
self.max_api_calls(),
self.max_api_window())
def rate_limit_request(self, request: HttpRequest) -> None:
ratelimited, time = self.rate_limit()
@ -64,17 +61,17 @@ class RateLimitedObject(ABC):
def block_access(self, seconds: int) -> None:
"Manually blocks an entity for the desired number of seconds"
self.backend.block_access(self, seconds)
self.backend.block_access(self.key(), seconds)
def unblock_access(self) -> None:
self.backend.unblock_access(self)
self.backend.unblock_access(self.key())
def clear_history(self) -> None:
'''
This is only used by test code now, where it's very helpful in
allowing us to run tests quickly, by giving a user a clean slate.
'''
self.backend.clear_history(self)
self.backend.clear_history(self.key())
def max_api_calls(self) -> int:
"Returns the API rate limit for the highest limit"
@ -89,7 +86,7 @@ class RateLimitedObject(ABC):
the rate-limit will be reset to 0"""
max_window = self.max_api_window()
max_calls = self.max_api_calls()
return self.backend.get_api_calls_left(self, max_window, max_calls)
return self.backend.get_api_calls_left(self.key(), max_window, max_calls)
@abstractmethod
def key(self) -> str:
@ -141,17 +138,17 @@ def remove_ratelimit_rule(range_seconds: int, num_requests: int, domain: str='ap
class RateLimiterBackend(ABC):
@classmethod
@abstractmethod
def block_access(cls, entity: RateLimitedObject, seconds: int) -> None:
def block_access(cls, entity_key: str, seconds: int) -> None:
"Manually blocks an entity for the desired number of seconds"
@classmethod
@abstractmethod
def unblock_access(cls, entity: RateLimitedObject) -> None:
def unblock_access(cls, entity_key: str) -> None:
pass
@classmethod
@abstractmethod
def clear_history(cls, entity: RateLimitedObject) -> None:
def clear_history(cls, entity_key: str) -> None:
'''
This is only used by test code now, where it's very helpful in
allowing us to run tests quickly, by giving a user a clean slate.
@ -159,44 +156,50 @@ class RateLimiterBackend(ABC):
@classmethod
@abstractmethod
def get_api_calls_left(cls, entity: RateLimitedObject, range_seconds: int,
def get_api_calls_left(cls, entity_key: str, range_seconds: int,
max_calls: int) -> Tuple[int, float]:
pass
@classmethod
@abstractmethod
def rate_limit_entity(cls, entity: RateLimitedObject) -> Tuple[bool, float]:
def rate_limit_entity(cls, entity_key: str, rules: List[Tuple[int, int]],
max_api_calls: int, max_api_window: int) -> Tuple[bool, float]:
# Returns (ratelimited, secs_to_freedom)
pass
class RedisRateLimiterBackend(RateLimiterBackend):
@classmethod
def block_access(cls, entity: RateLimitedObject, seconds: int) -> None:
def get_keys(cls, entity_key: str) -> List[str]:
return ["{}ratelimit:{}:{}".format(KEY_PREFIX, entity_key, keytype)
for keytype in ['list', 'zset', 'block']]
@classmethod
def block_access(cls, entity_key: str, seconds: int) -> None:
"Manually blocks an entity for the desired number of seconds"
_, _, blocking_key = entity.get_keys()
_, _, blocking_key = cls.get_keys(entity_key)
with client.pipeline() as pipe:
pipe.set(blocking_key, 1)
pipe.expire(blocking_key, seconds)
pipe.execute()
@classmethod
def unblock_access(cls, entity: RateLimitedObject) -> None:
_, _, blocking_key = entity.get_keys()
def unblock_access(cls, entity_key: str) -> None:
_, _, blocking_key = cls.get_keys(entity_key)
client.delete(blocking_key)
@classmethod
def clear_history(cls, entity: RateLimitedObject) -> None:
def clear_history(cls, entity_key: str) -> None:
'''
This is only used by test code now, where it's very helpful in
allowing us to run tests quickly, by giving a user a clean slate.
'''
for key in entity.get_keys():
for key in cls.get_keys(entity_key):
client.delete(key)
@classmethod
def get_api_calls_left(cls, entity: RateLimitedObject, range_seconds: int,
def get_api_calls_left(cls, entity_key: str, range_seconds: int,
max_calls: int) -> Tuple[int, float]:
list_key, set_key, _ = entity.get_keys()
list_key, set_key, _ = cls.get_keys(entity_key)
# Count the number of values in our sorted set
# that are between now and the cutoff
now = time.time()
@ -223,11 +226,9 @@ class RedisRateLimiterBackend(RateLimiterBackend):
return calls_left, time_reset
@classmethod
def is_ratelimited(cls, entity: RateLimitedObject) -> Tuple[bool, float]:
def is_ratelimited(cls, entity_key: str, rules: List[Tuple[int, int]]) -> Tuple[bool, float]:
"Returns a tuple of (rate_limited, time_till_free)"
list_key, set_key, blocking_key = entity.get_keys()
rules = entity.rules()
list_key, set_key, blocking_key = cls.get_keys(entity_key)
if len(rules) == 0:
return False, 0.0
@ -273,9 +274,10 @@ class RedisRateLimiterBackend(RateLimiterBackend):
return False, 0.0
@classmethod
def incr_ratelimit(cls, entity: RateLimitedObject) -> None:
def incr_ratelimit(cls, entity_key: str, rules: List[Tuple[int, int]],
max_api_calls: int, max_api_window: int) -> None:
"""Increases the rate-limit for the specified entity"""
list_key, set_key, _ = entity.get_keys()
list_key, set_key, _ = cls.get_keys(entity_key)
now = time.time()
# If we have no rules, we don't store anything
@ -294,7 +296,7 @@ class RedisRateLimiterBackend(RateLimiterBackend):
pipe.watch(list_key)
# Get the last elem that we'll trim (so we can remove it from our sorted set)
last_val = pipe.lindex(list_key, entity.max_api_calls() - 1)
last_val = pipe.lindex(list_key, max_api_calls - 1)
# Restart buffered execution
pipe.multi()
@ -303,7 +305,7 @@ class RedisRateLimiterBackend(RateLimiterBackend):
pipe.lpush(list_key, now)
# Trim our list to the oldest rule we have
pipe.ltrim(list_key, 0, entity.max_api_calls() - 1)
pipe.ltrim(list_key, 0, max_api_calls - 1)
# Add our new value to the sorted set that we keep
# We need to put the score and val both as timestamp,
@ -315,7 +317,7 @@ class RedisRateLimiterBackend(RateLimiterBackend):
pipe.zrem(set_key, last_val)
# Set the TTL for our keys as well
api_window = entity.max_api_window()
api_window = max_api_window
pipe.expire(list_key, api_window)
pipe.expire(set_key, api_window)
@ -331,17 +333,18 @@ class RedisRateLimiterBackend(RateLimiterBackend):
continue
@classmethod
def rate_limit_entity(cls, entity: RateLimitedObject) -> Tuple[bool, float]:
ratelimited, time = cls.is_ratelimited(entity)
def rate_limit_entity(cls, entity_key: str, rules: List[Tuple[int, int]],
max_api_calls: int, max_api_window: int) -> Tuple[bool, float]:
ratelimited, time = cls.is_ratelimited(entity_key, rules)
if ratelimited:
statsd.incr("ratelimiter.limited.%s" % (entity.key(),))
statsd.incr("ratelimiter.limited.%s" % (entity_key,))
else:
try:
cls.incr_ratelimit(entity)
cls.incr_ratelimit(entity_key, rules, max_api_calls, max_api_window)
except RateLimiterLockingException:
logger.warning("Deadlock trying to incr_ratelimit for %s" % (entity.key(),))
logger.warning("Deadlock trying to incr_ratelimit for %s" % (entity_key,))
# rate-limit users who are hitting the API so hard we can't update our stats.
ratelimited = True