mirror of https://github.com/zulip/zulip.git
rate_limit: Refactor RateLimiterBackend to operate on keys and rules.
Instead of operating on RateLimitedObjects, and making the classes depend on each too strongly. This also allows getting rid of get_keys() function from RateLimitedObject, which was a redis rate limiter implementation detail. RateLimitedObject should only define their own key() function and the logic forming various necessary redis keys from them should be in RedisRateLimiterBackend.
This commit is contained in:
parent
8069133f88
commit
b9e5103d0c
|
@ -32,14 +32,11 @@ class RateLimitedObject(ABC):
|
|||
def __init__(self) -> None:
|
||||
self.backend = RedisRateLimiterBackend
|
||||
|
||||
def get_keys(self) -> List[str]:
|
||||
key = self.key()
|
||||
return ["{}ratelimit:{}:{}".format(KEY_PREFIX, key, keytype)
|
||||
for keytype in ['list', 'zset', 'block']]
|
||||
|
||||
def rate_limit(self) -> Tuple[bool, float]:
|
||||
# Returns (ratelimited, secs_to_freedom)
|
||||
return self.backend.rate_limit_entity(self)
|
||||
return self.backend.rate_limit_entity(self.key(), self.rules(),
|
||||
self.max_api_calls(),
|
||||
self.max_api_window())
|
||||
|
||||
def rate_limit_request(self, request: HttpRequest) -> None:
|
||||
ratelimited, time = self.rate_limit()
|
||||
|
@ -64,17 +61,17 @@ class RateLimitedObject(ABC):
|
|||
|
||||
def block_access(self, seconds: int) -> None:
|
||||
"Manually blocks an entity for the desired number of seconds"
|
||||
self.backend.block_access(self, seconds)
|
||||
self.backend.block_access(self.key(), seconds)
|
||||
|
||||
def unblock_access(self) -> None:
|
||||
self.backend.unblock_access(self)
|
||||
self.backend.unblock_access(self.key())
|
||||
|
||||
def clear_history(self) -> None:
|
||||
'''
|
||||
This is only used by test code now, where it's very helpful in
|
||||
allowing us to run tests quickly, by giving a user a clean slate.
|
||||
'''
|
||||
self.backend.clear_history(self)
|
||||
self.backend.clear_history(self.key())
|
||||
|
||||
def max_api_calls(self) -> int:
|
||||
"Returns the API rate limit for the highest limit"
|
||||
|
@ -89,7 +86,7 @@ class RateLimitedObject(ABC):
|
|||
the rate-limit will be reset to 0"""
|
||||
max_window = self.max_api_window()
|
||||
max_calls = self.max_api_calls()
|
||||
return self.backend.get_api_calls_left(self, max_window, max_calls)
|
||||
return self.backend.get_api_calls_left(self.key(), max_window, max_calls)
|
||||
|
||||
@abstractmethod
|
||||
def key(self) -> str:
|
||||
|
@ -141,17 +138,17 @@ def remove_ratelimit_rule(range_seconds: int, num_requests: int, domain: str='ap
|
|||
class RateLimiterBackend(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def block_access(cls, entity: RateLimitedObject, seconds: int) -> None:
|
||||
def block_access(cls, entity_key: str, seconds: int) -> None:
|
||||
"Manually blocks an entity for the desired number of seconds"
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def unblock_access(cls, entity: RateLimitedObject) -> None:
|
||||
def unblock_access(cls, entity_key: str) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def clear_history(cls, entity: RateLimitedObject) -> None:
|
||||
def clear_history(cls, entity_key: str) -> None:
|
||||
'''
|
||||
This is only used by test code now, where it's very helpful in
|
||||
allowing us to run tests quickly, by giving a user a clean slate.
|
||||
|
@ -159,44 +156,50 @@ class RateLimiterBackend(ABC):
|
|||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_api_calls_left(cls, entity: RateLimitedObject, range_seconds: int,
|
||||
def get_api_calls_left(cls, entity_key: str, range_seconds: int,
|
||||
max_calls: int) -> Tuple[int, float]:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def rate_limit_entity(cls, entity: RateLimitedObject) -> Tuple[bool, float]:
|
||||
def rate_limit_entity(cls, entity_key: str, rules: List[Tuple[int, int]],
|
||||
max_api_calls: int, max_api_window: int) -> Tuple[bool, float]:
|
||||
# Returns (ratelimited, secs_to_freedom)
|
||||
pass
|
||||
|
||||
class RedisRateLimiterBackend(RateLimiterBackend):
|
||||
@classmethod
|
||||
def block_access(cls, entity: RateLimitedObject, seconds: int) -> None:
|
||||
def get_keys(cls, entity_key: str) -> List[str]:
|
||||
return ["{}ratelimit:{}:{}".format(KEY_PREFIX, entity_key, keytype)
|
||||
for keytype in ['list', 'zset', 'block']]
|
||||
|
||||
@classmethod
|
||||
def block_access(cls, entity_key: str, seconds: int) -> None:
|
||||
"Manually blocks an entity for the desired number of seconds"
|
||||
_, _, blocking_key = entity.get_keys()
|
||||
_, _, blocking_key = cls.get_keys(entity_key)
|
||||
with client.pipeline() as pipe:
|
||||
pipe.set(blocking_key, 1)
|
||||
pipe.expire(blocking_key, seconds)
|
||||
pipe.execute()
|
||||
|
||||
@classmethod
|
||||
def unblock_access(cls, entity: RateLimitedObject) -> None:
|
||||
_, _, blocking_key = entity.get_keys()
|
||||
def unblock_access(cls, entity_key: str) -> None:
|
||||
_, _, blocking_key = cls.get_keys(entity_key)
|
||||
client.delete(blocking_key)
|
||||
|
||||
@classmethod
|
||||
def clear_history(cls, entity: RateLimitedObject) -> None:
|
||||
def clear_history(cls, entity_key: str) -> None:
|
||||
'''
|
||||
This is only used by test code now, where it's very helpful in
|
||||
allowing us to run tests quickly, by giving a user a clean slate.
|
||||
'''
|
||||
for key in entity.get_keys():
|
||||
for key in cls.get_keys(entity_key):
|
||||
client.delete(key)
|
||||
|
||||
@classmethod
|
||||
def get_api_calls_left(cls, entity: RateLimitedObject, range_seconds: int,
|
||||
def get_api_calls_left(cls, entity_key: str, range_seconds: int,
|
||||
max_calls: int) -> Tuple[int, float]:
|
||||
list_key, set_key, _ = entity.get_keys()
|
||||
list_key, set_key, _ = cls.get_keys(entity_key)
|
||||
# Count the number of values in our sorted set
|
||||
# that are between now and the cutoff
|
||||
now = time.time()
|
||||
|
@ -223,11 +226,9 @@ class RedisRateLimiterBackend(RateLimiterBackend):
|
|||
return calls_left, time_reset
|
||||
|
||||
@classmethod
|
||||
def is_ratelimited(cls, entity: RateLimitedObject) -> Tuple[bool, float]:
|
||||
def is_ratelimited(cls, entity_key: str, rules: List[Tuple[int, int]]) -> Tuple[bool, float]:
|
||||
"Returns a tuple of (rate_limited, time_till_free)"
|
||||
list_key, set_key, blocking_key = entity.get_keys()
|
||||
|
||||
rules = entity.rules()
|
||||
list_key, set_key, blocking_key = cls.get_keys(entity_key)
|
||||
|
||||
if len(rules) == 0:
|
||||
return False, 0.0
|
||||
|
@ -273,9 +274,10 @@ class RedisRateLimiterBackend(RateLimiterBackend):
|
|||
return False, 0.0
|
||||
|
||||
@classmethod
|
||||
def incr_ratelimit(cls, entity: RateLimitedObject) -> None:
|
||||
def incr_ratelimit(cls, entity_key: str, rules: List[Tuple[int, int]],
|
||||
max_api_calls: int, max_api_window: int) -> None:
|
||||
"""Increases the rate-limit for the specified entity"""
|
||||
list_key, set_key, _ = entity.get_keys()
|
||||
list_key, set_key, _ = cls.get_keys(entity_key)
|
||||
now = time.time()
|
||||
|
||||
# If we have no rules, we don't store anything
|
||||
|
@ -294,7 +296,7 @@ class RedisRateLimiterBackend(RateLimiterBackend):
|
|||
pipe.watch(list_key)
|
||||
|
||||
# Get the last elem that we'll trim (so we can remove it from our sorted set)
|
||||
last_val = pipe.lindex(list_key, entity.max_api_calls() - 1)
|
||||
last_val = pipe.lindex(list_key, max_api_calls - 1)
|
||||
|
||||
# Restart buffered execution
|
||||
pipe.multi()
|
||||
|
@ -303,7 +305,7 @@ class RedisRateLimiterBackend(RateLimiterBackend):
|
|||
pipe.lpush(list_key, now)
|
||||
|
||||
# Trim our list to the oldest rule we have
|
||||
pipe.ltrim(list_key, 0, entity.max_api_calls() - 1)
|
||||
pipe.ltrim(list_key, 0, max_api_calls - 1)
|
||||
|
||||
# Add our new value to the sorted set that we keep
|
||||
# We need to put the score and val both as timestamp,
|
||||
|
@ -315,7 +317,7 @@ class RedisRateLimiterBackend(RateLimiterBackend):
|
|||
pipe.zrem(set_key, last_val)
|
||||
|
||||
# Set the TTL for our keys as well
|
||||
api_window = entity.max_api_window()
|
||||
api_window = max_api_window
|
||||
pipe.expire(list_key, api_window)
|
||||
pipe.expire(set_key, api_window)
|
||||
|
||||
|
@ -331,17 +333,18 @@ class RedisRateLimiterBackend(RateLimiterBackend):
|
|||
continue
|
||||
|
||||
@classmethod
|
||||
def rate_limit_entity(cls, entity: RateLimitedObject) -> Tuple[bool, float]:
|
||||
ratelimited, time = cls.is_ratelimited(entity)
|
||||
def rate_limit_entity(cls, entity_key: str, rules: List[Tuple[int, int]],
|
||||
max_api_calls: int, max_api_window: int) -> Tuple[bool, float]:
|
||||
ratelimited, time = cls.is_ratelimited(entity_key, rules)
|
||||
|
||||
if ratelimited:
|
||||
statsd.incr("ratelimiter.limited.%s" % (entity.key(),))
|
||||
statsd.incr("ratelimiter.limited.%s" % (entity_key,))
|
||||
|
||||
else:
|
||||
try:
|
||||
cls.incr_ratelimit(entity)
|
||||
cls.incr_ratelimit(entity_key, rules, max_api_calls, max_api_window)
|
||||
except RateLimiterLockingException:
|
||||
logger.warning("Deadlock trying to incr_ratelimit for %s" % (entity.key(),))
|
||||
logger.warning("Deadlock trying to incr_ratelimit for %s" % (entity_key,))
|
||||
# rate-limit users who are hitting the API so hard we can't update our stats.
|
||||
ratelimited = True
|
||||
|
||||
|
|
Loading…
Reference in New Issue