mirror of https://github.com/zulip/zulip.git
rate_limit: Add the concept of RateLimiterBackend.
This will allow easily swapping and using various implementations of rate-limiting, and separate the implementation logic from RateLimitedObjects.
This commit is contained in:
parent
85df6201f6
commit
9c9f8100e7
|
@ -303,6 +303,7 @@ class ZulipPasswordResetForm(PasswordResetForm):
|
|||
class RateLimitedPasswordResetByEmail(RateLimitedObject):
|
||||
def __init__(self, email: str) -> None:
|
||||
self.email = email
|
||||
super().__init__()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "Email: {}".format(self.email)
|
||||
|
|
|
@ -442,6 +442,7 @@ def mirror_email_message(data: Dict[str, str]) -> Dict[str, str]:
|
|||
class RateLimitedRealmMirror(RateLimitedObject):
|
||||
def __init__(self, realm: Realm) -> None:
|
||||
self.realm = realm
|
||||
super().__init__()
|
||||
|
||||
def key_fragment(self) -> str:
|
||||
return "emailmirror:{}:{}".format(type(self.realm), self.realm.id)
|
||||
|
|
|
@ -29,6 +29,9 @@ class RateLimiterLockingException(Exception):
|
|||
pass
|
||||
|
||||
class RateLimitedObject(ABC):
|
||||
def __init__(self) -> None:
|
||||
self.backend = RedisRateLimiterBackend
|
||||
|
||||
def get_keys(self) -> List[str]:
|
||||
key_fragment = self.key_fragment()
|
||||
return ["{}ratelimit:{}:{}".format(KEY_PREFIX, key_fragment, keytype)
|
||||
|
@ -36,21 +39,7 @@ class RateLimitedObject(ABC):
|
|||
|
||||
def rate_limit(self) -> Tuple[bool, float]:
|
||||
# Returns (ratelimited, secs_to_freedom)
|
||||
ratelimited, time = is_ratelimited(self)
|
||||
|
||||
if ratelimited:
|
||||
statsd.incr("ratelimiter.limited.%s.%s" % (type(self), str(self)))
|
||||
|
||||
else:
|
||||
try:
|
||||
incr_ratelimit(self)
|
||||
except RateLimiterLockingException:
|
||||
logger.warning("Deadlock trying to incr_ratelimit for %s:%s" % (
|
||||
type(self).__name__, str(self)))
|
||||
# rate-limit users who are hitting the API so hard we can't update our stats.
|
||||
ratelimited = True
|
||||
|
||||
return ratelimited, time
|
||||
return self.backend.rate_limit_entity(self)
|
||||
|
||||
def rate_limit_request(self, request: HttpRequest) -> None:
|
||||
ratelimited, time = self.rate_limit()
|
||||
|
@ -75,23 +64,17 @@ class RateLimitedObject(ABC):
|
|||
|
||||
def block_access(self, seconds: int) -> None:
|
||||
"Manually blocks an entity for the desired number of seconds"
|
||||
_, _, blocking_key = self.get_keys()
|
||||
with client.pipeline() as pipe:
|
||||
pipe.set(blocking_key, 1)
|
||||
pipe.expire(blocking_key, seconds)
|
||||
pipe.execute()
|
||||
self.backend.block_access(self, seconds)
|
||||
|
||||
def unblock_access(self) -> None:
|
||||
_, _, blocking_key = self.get_keys()
|
||||
client.delete(blocking_key)
|
||||
self.backend.unblock_access(self)
|
||||
|
||||
def clear_history(self) -> None:
|
||||
'''
|
||||
This is only used by test code now, where it's very helpful in
|
||||
allowing us to run tests quickly, by giving a user a clean slate.
|
||||
'''
|
||||
for key in self.get_keys():
|
||||
client.delete(key)
|
||||
self.backend.clear_history(self)
|
||||
|
||||
def max_api_calls(self) -> int:
|
||||
"Returns the API rate limit for the highest limit"
|
||||
|
@ -106,7 +89,7 @@ class RateLimitedObject(ABC):
|
|||
the rate-limit will be reset to 0"""
|
||||
max_window = self.max_api_window()
|
||||
max_calls = self.max_api_calls()
|
||||
return _get_api_calls_left(self, max_window, max_calls)
|
||||
return self.backend.get_api_calls_left(self, max_window, max_calls)
|
||||
|
||||
@abstractmethod
|
||||
def key_fragment(self) -> str:
|
||||
|
@ -124,6 +107,7 @@ class RateLimitedUser(RateLimitedObject):
|
|||
def __init__(self, user: UserProfile, domain: str='api_by_user') -> None:
|
||||
self.user = user
|
||||
self.domain = domain
|
||||
super().__init__()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "Id: {}".format(self.user.id)
|
||||
|
@ -161,138 +145,215 @@ def remove_ratelimit_rule(range_seconds: int, num_requests: int, domain: str='ap
|
|||
global rules
|
||||
rules[domain] = [x for x in rules[domain] if x[0] != range_seconds and x[1] != num_requests]
|
||||
|
||||
def _get_api_calls_left(entity: RateLimitedObject, range_seconds: int, max_calls: int) -> Tuple[int, float]:
|
||||
list_key, set_key, _ = entity.get_keys()
|
||||
# Count the number of values in our sorted set
|
||||
# that are between now and the cutoff
|
||||
now = time.time()
|
||||
boundary = now - range_seconds
|
||||
class RateLimiterBackend(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def block_access(cls, entity: RateLimitedObject, seconds: int) -> None:
|
||||
"Manually blocks an entity for the desired number of seconds"
|
||||
|
||||
with client.pipeline() as pipe:
|
||||
# Count how many API calls in our range have already been made
|
||||
pipe.zcount(set_key, boundary, now)
|
||||
# Get the newest call so we can calculate when the ratelimit
|
||||
# will reset to 0
|
||||
pipe.lindex(list_key, 0)
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def unblock_access(cls, entity: RateLimitedObject) -> None:
|
||||
pass
|
||||
|
||||
results = pipe.execute()
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def clear_history(cls, entity: RateLimitedObject) -> None:
|
||||
'''
|
||||
This is only used by test code now, where it's very helpful in
|
||||
allowing us to run tests quickly, by giving a user a clean slate.
|
||||
'''
|
||||
|
||||
count = results[0] # type: int
|
||||
newest_call = results[1] # type: Optional[bytes]
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_api_calls_left(cls, entity: RateLimitedObject, range_seconds: int,
|
||||
max_calls: int) -> Tuple[int, float]:
|
||||
pass
|
||||
|
||||
calls_left = max_calls - count
|
||||
if newest_call is not None:
|
||||
time_reset = now + (range_seconds - (now - float(newest_call)))
|
||||
else:
|
||||
time_reset = now
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def rate_limit_entity(cls, entity: RateLimitedObject) -> Tuple[bool, float]:
|
||||
# Returns (ratelimited, secs_to_freedom)
|
||||
pass
|
||||
|
||||
return calls_left, time_reset
|
||||
class RedisRateLimiterBackend(RateLimiterBackend):
|
||||
@classmethod
|
||||
def block_access(cls, entity: RateLimitedObject, seconds: int) -> None:
|
||||
"Manually blocks an entity for the desired number of seconds"
|
||||
_, _, blocking_key = entity.get_keys()
|
||||
with client.pipeline() as pipe:
|
||||
pipe.set(blocking_key, 1)
|
||||
pipe.expire(blocking_key, seconds)
|
||||
pipe.execute()
|
||||
|
||||
def is_ratelimited(entity: RateLimitedObject) -> Tuple[bool, float]:
|
||||
"Returns a tuple of (rate_limited, time_till_free)"
|
||||
list_key, set_key, blocking_key = entity.get_keys()
|
||||
@classmethod
|
||||
def unblock_access(cls, entity: RateLimitedObject) -> None:
|
||||
_, _, blocking_key = entity.get_keys()
|
||||
client.delete(blocking_key)
|
||||
|
||||
rules = entity.rules()
|
||||
@classmethod
|
||||
def clear_history(cls, entity: RateLimitedObject) -> None:
|
||||
'''
|
||||
This is only used by test code now, where it's very helpful in
|
||||
allowing us to run tests quickly, by giving a user a clean slate.
|
||||
'''
|
||||
for key in entity.get_keys():
|
||||
client.delete(key)
|
||||
|
||||
if len(rules) == 0:
|
||||
@classmethod
|
||||
def get_api_calls_left(cls, entity: RateLimitedObject, range_seconds: int,
|
||||
max_calls: int) -> Tuple[int, float]:
|
||||
list_key, set_key, _ = entity.get_keys()
|
||||
# Count the number of values in our sorted set
|
||||
# that are between now and the cutoff
|
||||
now = time.time()
|
||||
boundary = now - range_seconds
|
||||
|
||||
with client.pipeline() as pipe:
|
||||
# Count how many API calls in our range have already been made
|
||||
pipe.zcount(set_key, boundary, now)
|
||||
# Get the newest call so we can calculate when the ratelimit
|
||||
# will reset to 0
|
||||
pipe.lindex(list_key, 0)
|
||||
|
||||
results = pipe.execute()
|
||||
|
||||
count = results[0] # type: int
|
||||
newest_call = results[1] # type: Optional[bytes]
|
||||
|
||||
calls_left = max_calls - count
|
||||
if newest_call is not None:
|
||||
time_reset = now + (range_seconds - (now - float(newest_call)))
|
||||
else:
|
||||
time_reset = now
|
||||
|
||||
return calls_left, time_reset
|
||||
|
||||
@classmethod
|
||||
def is_ratelimited(cls, entity: RateLimitedObject) -> Tuple[bool, float]:
|
||||
"Returns a tuple of (rate_limited, time_till_free)"
|
||||
list_key, set_key, blocking_key = entity.get_keys()
|
||||
|
||||
rules = entity.rules()
|
||||
|
||||
if len(rules) == 0:
|
||||
return False, 0.0
|
||||
|
||||
# Go through the rules from shortest to longest,
|
||||
# seeing if this user has violated any of them. First
|
||||
# get the timestamps for each nth items
|
||||
with client.pipeline() as pipe:
|
||||
for _, request_count in rules:
|
||||
pipe.lindex(list_key, request_count - 1) # 0-indexed list
|
||||
|
||||
# Get blocking info
|
||||
pipe.get(blocking_key)
|
||||
pipe.ttl(blocking_key)
|
||||
|
||||
rule_timestamps = pipe.execute() # type: List[Optional[bytes]]
|
||||
|
||||
# Check if there is a manual block on this API key
|
||||
blocking_ttl_b = rule_timestamps.pop()
|
||||
key_blocked = rule_timestamps.pop()
|
||||
|
||||
if key_blocked is not None:
|
||||
# We are manually blocked. Report for how much longer we will be
|
||||
if blocking_ttl_b is None:
|
||||
blocking_ttl = 0.5
|
||||
else:
|
||||
blocking_ttl = int(blocking_ttl_b)
|
||||
return True, blocking_ttl
|
||||
|
||||
now = time.time()
|
||||
for timestamp, (range_seconds, num_requests) in zip(rule_timestamps, rules):
|
||||
# Check if the nth timestamp is newer than the associated rule. If so,
|
||||
# it means we've hit our limit for this rule
|
||||
if timestamp is None:
|
||||
continue
|
||||
|
||||
boundary = float(timestamp) + range_seconds
|
||||
if boundary > now:
|
||||
free = boundary - now
|
||||
return True, free
|
||||
|
||||
# No api calls recorded yet
|
||||
return False, 0.0
|
||||
|
||||
# Go through the rules from shortest to longest,
|
||||
# seeing if this user has violated any of them. First
|
||||
# get the timestamps for each nth items
|
||||
with client.pipeline() as pipe:
|
||||
for _, request_count in rules:
|
||||
pipe.lindex(list_key, request_count - 1) # 0-indexed list
|
||||
@classmethod
|
||||
def incr_ratelimit(cls, entity: RateLimitedObject) -> None:
|
||||
"""Increases the rate-limit for the specified entity"""
|
||||
list_key, set_key, _ = entity.get_keys()
|
||||
now = time.time()
|
||||
|
||||
# Get blocking info
|
||||
pipe.get(blocking_key)
|
||||
pipe.ttl(blocking_key)
|
||||
# If we have no rules, we don't store anything
|
||||
if len(rules) == 0:
|
||||
return
|
||||
|
||||
rule_timestamps = pipe.execute() # type: List[Optional[bytes]]
|
||||
# Start redis transaction
|
||||
with client.pipeline() as pipe:
|
||||
count = 0
|
||||
while True:
|
||||
try:
|
||||
# To avoid a race condition between getting the element we might trim from our list
|
||||
# and removing it from our associated set, we abort this whole transaction if
|
||||
# another agent manages to change our list out from under us
|
||||
# When watching a value, the pipeline is set to Immediate mode
|
||||
pipe.watch(list_key)
|
||||
|
||||
# Check if there is a manual block on this API key
|
||||
blocking_ttl_b = rule_timestamps.pop()
|
||||
key_blocked = rule_timestamps.pop()
|
||||
# Get the last elem that we'll trim (so we can remove it from our sorted set)
|
||||
last_val = pipe.lindex(list_key, entity.max_api_calls() - 1)
|
||||
|
||||
# Restart buffered execution
|
||||
pipe.multi()
|
||||
|
||||
# Add this timestamp to our list
|
||||
pipe.lpush(list_key, now)
|
||||
|
||||
# Trim our list to the oldest rule we have
|
||||
pipe.ltrim(list_key, 0, entity.max_api_calls() - 1)
|
||||
|
||||
# Add our new value to the sorted set that we keep
|
||||
# We need to put the score and val both as timestamp,
|
||||
# as we sort by score but remove by value
|
||||
pipe.zadd(set_key, {str(now): now})
|
||||
|
||||
# Remove the trimmed value from our sorted set, if there was one
|
||||
if last_val is not None:
|
||||
pipe.zrem(set_key, last_val)
|
||||
|
||||
# Set the TTL for our keys as well
|
||||
api_window = entity.max_api_window()
|
||||
pipe.expire(list_key, api_window)
|
||||
pipe.expire(set_key, api_window)
|
||||
|
||||
pipe.execute()
|
||||
|
||||
# If no exception was raised in the execution, there were no transaction conflicts
|
||||
break
|
||||
except redis.WatchError:
|
||||
if count > 10:
|
||||
raise RateLimiterLockingException()
|
||||
count += 1
|
||||
|
||||
continue
|
||||
|
||||
@classmethod
|
||||
def rate_limit_entity(cls, entity: RateLimitedObject) -> Tuple[bool, float]:
|
||||
ratelimited, time = cls.is_ratelimited(entity)
|
||||
|
||||
if ratelimited:
|
||||
statsd.incr("ratelimiter.limited.%s.%s" % (type(entity), str(entity)))
|
||||
|
||||
if key_blocked is not None:
|
||||
# We are manually blocked. Report for how much longer we will be
|
||||
if blocking_ttl_b is None:
|
||||
blocking_ttl = 0.5
|
||||
else:
|
||||
blocking_ttl = int(blocking_ttl_b)
|
||||
return True, blocking_ttl
|
||||
|
||||
now = time.time()
|
||||
for timestamp, (range_seconds, num_requests) in zip(rule_timestamps, rules):
|
||||
# Check if the nth timestamp is newer than the associated rule. If so,
|
||||
# it means we've hit our limit for this rule
|
||||
if timestamp is None:
|
||||
continue
|
||||
|
||||
boundary = float(timestamp) + range_seconds
|
||||
if boundary > now:
|
||||
free = boundary - now
|
||||
return True, free
|
||||
|
||||
# No api calls recorded yet
|
||||
return False, 0.0
|
||||
|
||||
def incr_ratelimit(entity: RateLimitedObject) -> None:
|
||||
"""Increases the rate-limit for the specified entity"""
|
||||
list_key, set_key, _ = entity.get_keys()
|
||||
now = time.time()
|
||||
|
||||
# If we have no rules, we don't store anything
|
||||
if len(rules) == 0:
|
||||
return
|
||||
|
||||
# Start redis transaction
|
||||
with client.pipeline() as pipe:
|
||||
count = 0
|
||||
while True:
|
||||
try:
|
||||
# To avoid a race condition between getting the element we might trim from our list
|
||||
# and removing it from our associated set, we abort this whole transaction if
|
||||
# another agent manages to change our list out from under us
|
||||
# When watching a value, the pipeline is set to Immediate mode
|
||||
pipe.watch(list_key)
|
||||
cls.incr_ratelimit(entity)
|
||||
except RateLimiterLockingException:
|
||||
logger.warning("Deadlock trying to incr_ratelimit for %s:%s" % (
|
||||
type(entity).__name__, str(entity)))
|
||||
# rate-limit users who are hitting the API so hard we can't update our stats.
|
||||
ratelimited = True
|
||||
|
||||
# Get the last elem that we'll trim (so we can remove it from our sorted set)
|
||||
last_val = pipe.lindex(list_key, entity.max_api_calls() - 1)
|
||||
|
||||
# Restart buffered execution
|
||||
pipe.multi()
|
||||
|
||||
# Add this timestamp to our list
|
||||
pipe.lpush(list_key, now)
|
||||
|
||||
# Trim our list to the oldest rule we have
|
||||
pipe.ltrim(list_key, 0, entity.max_api_calls() - 1)
|
||||
|
||||
# Add our new value to the sorted set that we keep
|
||||
# We need to put the score and val both as timestamp,
|
||||
# as we sort by score but remove by value
|
||||
pipe.zadd(set_key, {str(now): now})
|
||||
|
||||
# Remove the trimmed value from our sorted set, if there was one
|
||||
if last_val is not None:
|
||||
pipe.zrem(set_key, last_val)
|
||||
|
||||
# Set the TTL for our keys as well
|
||||
api_window = entity.max_api_window()
|
||||
pipe.expire(list_key, api_window)
|
||||
pipe.expire(set_key, api_window)
|
||||
|
||||
pipe.execute()
|
||||
|
||||
# If no exception was raised in the execution, there were no transaction conflicts
|
||||
break
|
||||
except redis.WatchError:
|
||||
if count > 10:
|
||||
raise RateLimiterLockingException()
|
||||
count += 1
|
||||
|
||||
continue
|
||||
return ratelimited, time
|
||||
|
||||
class RateLimitResult:
|
||||
def __init__(self, entity: RateLimitedObject, secs_to_freedom: float, over_limit: bool,
|
||||
|
|
|
@ -117,7 +117,7 @@ class RateLimitTests(ZulipTestCase):
|
|||
user = self.example_user('cordelia')
|
||||
RateLimitedUser(user).clear_history()
|
||||
|
||||
with mock.patch('zerver.lib.rate_limiter.incr_ratelimit',
|
||||
with mock.patch('zerver.lib.rate_limiter.RedisRateLimiterBackend.incr_ratelimit',
|
||||
side_effect=RateLimiterLockingException):
|
||||
result = self.send_api_message(user, "some stuff")
|
||||
self.assertEqual(result.status_code, 429)
|
||||
|
|
|
@ -406,7 +406,7 @@ class WorkerTest(ZulipTestCase):
|
|||
self.assertEqual(mock_mirror_email.call_count, 4)
|
||||
|
||||
# If RateLimiterLockingException is thrown, we rate-limit the new message:
|
||||
with patch('zerver.lib.rate_limiter.incr_ratelimit',
|
||||
with patch('zerver.lib.rate_limiter.RedisRateLimiterBackend.incr_ratelimit',
|
||||
side_effect=RateLimiterLockingException):
|
||||
fake_client.queue.append(('email_mirror', data[0]))
|
||||
worker.start()
|
||||
|
|
|
@ -180,6 +180,7 @@ rate_limiting_rules = settings.RATE_LIMITING_RULES['authenticate_by_username']
|
|||
class RateLimitedAuthenticationByUsername(RateLimitedObject):
|
||||
def __init__(self, username: str) -> None:
|
||||
self.username = username
|
||||
super().__init__()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "Username: {}".format(self.username)
|
||||
|
|
Loading…
Reference in New Issue