rate_limit: Add the concept of RateLimiterBackend.

This will allow easily swapping and using various implementations of
rate-limiting, and separate the implementation logic from
RateLimitedObjects.
This commit is contained in:
Mateusz Mandera 2020-03-05 13:38:20 +01:00 committed by Tim Abbott
parent 85df6201f6
commit 9c9f8100e7
6 changed files with 206 additions and 142 deletions

View File

@ -303,6 +303,7 @@ class ZulipPasswordResetForm(PasswordResetForm):
class RateLimitedPasswordResetByEmail(RateLimitedObject):
def __init__(self, email: str) -> None:
self.email = email
super().__init__()
def __str__(self) -> str:
return "Email: {}".format(self.email)

View File

@ -442,6 +442,7 @@ def mirror_email_message(data: Dict[str, str]) -> Dict[str, str]:
class RateLimitedRealmMirror(RateLimitedObject):
def __init__(self, realm: Realm) -> None:
self.realm = realm
super().__init__()
def key_fragment(self) -> str:
return "emailmirror:{}:{}".format(type(self.realm), self.realm.id)

View File

@ -29,6 +29,9 @@ class RateLimiterLockingException(Exception):
pass
class RateLimitedObject(ABC):
def __init__(self) -> None:
self.backend = RedisRateLimiterBackend
def get_keys(self) -> List[str]:
key_fragment = self.key_fragment()
return ["{}ratelimit:{}:{}".format(KEY_PREFIX, key_fragment, keytype)
@ -36,21 +39,7 @@ class RateLimitedObject(ABC):
def rate_limit(self) -> Tuple[bool, float]:
# Returns (ratelimited, secs_to_freedom)
ratelimited, time = is_ratelimited(self)
if ratelimited:
statsd.incr("ratelimiter.limited.%s.%s" % (type(self), str(self)))
else:
try:
incr_ratelimit(self)
except RateLimiterLockingException:
logger.warning("Deadlock trying to incr_ratelimit for %s:%s" % (
type(self).__name__, str(self)))
# rate-limit users who are hitting the API so hard we can't update our stats.
ratelimited = True
return ratelimited, time
return self.backend.rate_limit_entity(self)
def rate_limit_request(self, request: HttpRequest) -> None:
ratelimited, time = self.rate_limit()
@ -75,23 +64,17 @@ class RateLimitedObject(ABC):
def block_access(self, seconds: int) -> None:
"Manually blocks an entity for the desired number of seconds"
_, _, blocking_key = self.get_keys()
with client.pipeline() as pipe:
pipe.set(blocking_key, 1)
pipe.expire(blocking_key, seconds)
pipe.execute()
self.backend.block_access(self, seconds)
def unblock_access(self) -> None:
_, _, blocking_key = self.get_keys()
client.delete(blocking_key)
self.backend.unblock_access(self)
def clear_history(self) -> None:
'''
This is only used by test code now, where it's very helpful in
allowing us to run tests quickly, by giving a user a clean slate.
'''
for key in self.get_keys():
client.delete(key)
self.backend.clear_history(self)
def max_api_calls(self) -> int:
"Returns the API rate limit for the highest limit"
@ -106,7 +89,7 @@ class RateLimitedObject(ABC):
the rate-limit will be reset to 0"""
max_window = self.max_api_window()
max_calls = self.max_api_calls()
return _get_api_calls_left(self, max_window, max_calls)
return self.backend.get_api_calls_left(self, max_window, max_calls)
@abstractmethod
def key_fragment(self) -> str:
@ -124,6 +107,7 @@ class RateLimitedUser(RateLimitedObject):
def __init__(self, user: UserProfile, domain: str='api_by_user') -> None:
self.user = user
self.domain = domain
super().__init__()
def __str__(self) -> str:
return "Id: {}".format(self.user.id)
@ -161,138 +145,215 @@ def remove_ratelimit_rule(range_seconds: int, num_requests: int, domain: str='ap
global rules
rules[domain] = [x for x in rules[domain] if x[0] != range_seconds and x[1] != num_requests]
def _get_api_calls_left(entity: RateLimitedObject, range_seconds: int, max_calls: int) -> Tuple[int, float]:
list_key, set_key, _ = entity.get_keys()
# Count the number of values in our sorted set
# that are between now and the cutoff
now = time.time()
boundary = now - range_seconds
class RateLimiterBackend(ABC):
@classmethod
@abstractmethod
def block_access(cls, entity: RateLimitedObject, seconds: int) -> None:
"Manually blocks an entity for the desired number of seconds"
with client.pipeline() as pipe:
# Count how many API calls in our range have already been made
pipe.zcount(set_key, boundary, now)
# Get the newest call so we can calculate when the ratelimit
# will reset to 0
pipe.lindex(list_key, 0)
@classmethod
@abstractmethod
def unblock_access(cls, entity: RateLimitedObject) -> None:
pass
results = pipe.execute()
@classmethod
@abstractmethod
def clear_history(cls, entity: RateLimitedObject) -> None:
'''
This is only used by test code now, where it's very helpful in
allowing us to run tests quickly, by giving a user a clean slate.
'''
count = results[0] # type: int
newest_call = results[1] # type: Optional[bytes]
@classmethod
@abstractmethod
def get_api_calls_left(cls, entity: RateLimitedObject, range_seconds: int,
max_calls: int) -> Tuple[int, float]:
pass
calls_left = max_calls - count
if newest_call is not None:
time_reset = now + (range_seconds - (now - float(newest_call)))
else:
time_reset = now
@classmethod
@abstractmethod
def rate_limit_entity(cls, entity: RateLimitedObject) -> Tuple[bool, float]:
# Returns (ratelimited, secs_to_freedom)
pass
return calls_left, time_reset
class RedisRateLimiterBackend(RateLimiterBackend):
@classmethod
def block_access(cls, entity: RateLimitedObject, seconds: int) -> None:
"Manually blocks an entity for the desired number of seconds"
_, _, blocking_key = entity.get_keys()
with client.pipeline() as pipe:
pipe.set(blocking_key, 1)
pipe.expire(blocking_key, seconds)
pipe.execute()
def is_ratelimited(entity: RateLimitedObject) -> Tuple[bool, float]:
"Returns a tuple of (rate_limited, time_till_free)"
list_key, set_key, blocking_key = entity.get_keys()
@classmethod
def unblock_access(cls, entity: RateLimitedObject) -> None:
_, _, blocking_key = entity.get_keys()
client.delete(blocking_key)
rules = entity.rules()
@classmethod
def clear_history(cls, entity: RateLimitedObject) -> None:
'''
This is only used by test code now, where it's very helpful in
allowing us to run tests quickly, by giving a user a clean slate.
'''
for key in entity.get_keys():
client.delete(key)
if len(rules) == 0:
@classmethod
def get_api_calls_left(cls, entity: RateLimitedObject, range_seconds: int,
max_calls: int) -> Tuple[int, float]:
list_key, set_key, _ = entity.get_keys()
# Count the number of values in our sorted set
# that are between now and the cutoff
now = time.time()
boundary = now - range_seconds
with client.pipeline() as pipe:
# Count how many API calls in our range have already been made
pipe.zcount(set_key, boundary, now)
# Get the newest call so we can calculate when the ratelimit
# will reset to 0
pipe.lindex(list_key, 0)
results = pipe.execute()
count = results[0] # type: int
newest_call = results[1] # type: Optional[bytes]
calls_left = max_calls - count
if newest_call is not None:
time_reset = now + (range_seconds - (now - float(newest_call)))
else:
time_reset = now
return calls_left, time_reset
@classmethod
def is_ratelimited(cls, entity: RateLimitedObject) -> Tuple[bool, float]:
"Returns a tuple of (rate_limited, time_till_free)"
list_key, set_key, blocking_key = entity.get_keys()
rules = entity.rules()
if len(rules) == 0:
return False, 0.0
# Go through the rules from shortest to longest,
# seeing if this user has violated any of them. First
# get the timestamps for each nth items
with client.pipeline() as pipe:
for _, request_count in rules:
pipe.lindex(list_key, request_count - 1) # 0-indexed list
# Get blocking info
pipe.get(blocking_key)
pipe.ttl(blocking_key)
rule_timestamps = pipe.execute() # type: List[Optional[bytes]]
# Check if there is a manual block on this API key
blocking_ttl_b = rule_timestamps.pop()
key_blocked = rule_timestamps.pop()
if key_blocked is not None:
# We are manually blocked. Report for how much longer we will be
if blocking_ttl_b is None:
blocking_ttl = 0.5
else:
blocking_ttl = int(blocking_ttl_b)
return True, blocking_ttl
now = time.time()
for timestamp, (range_seconds, num_requests) in zip(rule_timestamps, rules):
# Check if the nth timestamp is newer than the associated rule. If so,
# it means we've hit our limit for this rule
if timestamp is None:
continue
boundary = float(timestamp) + range_seconds
if boundary > now:
free = boundary - now
return True, free
# No api calls recorded yet
return False, 0.0
# Go through the rules from shortest to longest,
# seeing if this user has violated any of them. First
# get the timestamps for each nth items
with client.pipeline() as pipe:
for _, request_count in rules:
pipe.lindex(list_key, request_count - 1) # 0-indexed list
@classmethod
def incr_ratelimit(cls, entity: RateLimitedObject) -> None:
"""Increases the rate-limit for the specified entity"""
list_key, set_key, _ = entity.get_keys()
now = time.time()
# Get blocking info
pipe.get(blocking_key)
pipe.ttl(blocking_key)
# If we have no rules, we don't store anything
if len(rules) == 0:
return
rule_timestamps = pipe.execute() # type: List[Optional[bytes]]
# Start redis transaction
with client.pipeline() as pipe:
count = 0
while True:
try:
# To avoid a race condition between getting the element we might trim from our list
# and removing it from our associated set, we abort this whole transaction if
# another agent manages to change our list out from under us
# When watching a value, the pipeline is set to Immediate mode
pipe.watch(list_key)
# Check if there is a manual block on this API key
blocking_ttl_b = rule_timestamps.pop()
key_blocked = rule_timestamps.pop()
# Get the last elem that we'll trim (so we can remove it from our sorted set)
last_val = pipe.lindex(list_key, entity.max_api_calls() - 1)
# Restart buffered execution
pipe.multi()
# Add this timestamp to our list
pipe.lpush(list_key, now)
# Trim our list to the oldest rule we have
pipe.ltrim(list_key, 0, entity.max_api_calls() - 1)
# Add our new value to the sorted set that we keep
# We need to put the score and val both as timestamp,
# as we sort by score but remove by value
pipe.zadd(set_key, {str(now): now})
# Remove the trimmed value from our sorted set, if there was one
if last_val is not None:
pipe.zrem(set_key, last_val)
# Set the TTL for our keys as well
api_window = entity.max_api_window()
pipe.expire(list_key, api_window)
pipe.expire(set_key, api_window)
pipe.execute()
# If no exception was raised in the execution, there were no transaction conflicts
break
except redis.WatchError:
if count > 10:
raise RateLimiterLockingException()
count += 1
continue
@classmethod
def rate_limit_entity(cls, entity: RateLimitedObject) -> Tuple[bool, float]:
ratelimited, time = cls.is_ratelimited(entity)
if ratelimited:
statsd.incr("ratelimiter.limited.%s.%s" % (type(entity), str(entity)))
if key_blocked is not None:
# We are manually blocked. Report for how much longer we will be
if blocking_ttl_b is None:
blocking_ttl = 0.5
else:
blocking_ttl = int(blocking_ttl_b)
return True, blocking_ttl
now = time.time()
for timestamp, (range_seconds, num_requests) in zip(rule_timestamps, rules):
# Check if the nth timestamp is newer than the associated rule. If so,
# it means we've hit our limit for this rule
if timestamp is None:
continue
boundary = float(timestamp) + range_seconds
if boundary > now:
free = boundary - now
return True, free
# No api calls recorded yet
return False, 0.0
def incr_ratelimit(entity: RateLimitedObject) -> None:
"""Increases the rate-limit for the specified entity"""
list_key, set_key, _ = entity.get_keys()
now = time.time()
# If we have no rules, we don't store anything
if len(rules) == 0:
return
# Start redis transaction
with client.pipeline() as pipe:
count = 0
while True:
try:
# To avoid a race condition between getting the element we might trim from our list
# and removing it from our associated set, we abort this whole transaction if
# another agent manages to change our list out from under us
# When watching a value, the pipeline is set to Immediate mode
pipe.watch(list_key)
cls.incr_ratelimit(entity)
except RateLimiterLockingException:
logger.warning("Deadlock trying to incr_ratelimit for %s:%s" % (
type(entity).__name__, str(entity)))
# rate-limit users who are hitting the API so hard we can't update our stats.
ratelimited = True
# Get the last elem that we'll trim (so we can remove it from our sorted set)
last_val = pipe.lindex(list_key, entity.max_api_calls() - 1)
# Restart buffered execution
pipe.multi()
# Add this timestamp to our list
pipe.lpush(list_key, now)
# Trim our list to the oldest rule we have
pipe.ltrim(list_key, 0, entity.max_api_calls() - 1)
# Add our new value to the sorted set that we keep
# We need to put the score and val both as timestamp,
# as we sort by score but remove by value
pipe.zadd(set_key, {str(now): now})
# Remove the trimmed value from our sorted set, if there was one
if last_val is not None:
pipe.zrem(set_key, last_val)
# Set the TTL for our keys as well
api_window = entity.max_api_window()
pipe.expire(list_key, api_window)
pipe.expire(set_key, api_window)
pipe.execute()
# If no exception was raised in the execution, there were no transaction conflicts
break
except redis.WatchError:
if count > 10:
raise RateLimiterLockingException()
count += 1
continue
return ratelimited, time
class RateLimitResult:
def __init__(self, entity: RateLimitedObject, secs_to_freedom: float, over_limit: bool,

View File

@ -117,7 +117,7 @@ class RateLimitTests(ZulipTestCase):
user = self.example_user('cordelia')
RateLimitedUser(user).clear_history()
with mock.patch('zerver.lib.rate_limiter.incr_ratelimit',
with mock.patch('zerver.lib.rate_limiter.RedisRateLimiterBackend.incr_ratelimit',
side_effect=RateLimiterLockingException):
result = self.send_api_message(user, "some stuff")
self.assertEqual(result.status_code, 429)

View File

@ -406,7 +406,7 @@ class WorkerTest(ZulipTestCase):
self.assertEqual(mock_mirror_email.call_count, 4)
# If RateLimiterLockingException is thrown, we rate-limit the new message:
with patch('zerver.lib.rate_limiter.incr_ratelimit',
with patch('zerver.lib.rate_limiter.RedisRateLimiterBackend.incr_ratelimit',
side_effect=RateLimiterLockingException):
fake_client.queue.append(('email_mirror', data[0]))
worker.start()

View File

@ -180,6 +180,7 @@ rate_limiting_rules = settings.RATE_LIMITING_RULES['authenticate_by_username']
class RateLimitedAuthenticationByUsername(RateLimitedObject):
def __init__(self, username: str) -> None:
self.username = username
super().__init__()
def __str__(self) -> str:
return "Username: {}".format(self.username)