2013-05-29 23:58:07 +02:00
|
|
|
|
2017-05-05 12:07:10 +02:00
|
|
|
import os
|
|
|
|
|
2018-05-11 01:40:23 +02:00
|
|
|
from typing import Any, Iterator, List, Optional, Tuple
|
2016-03-27 12:09:54 +02:00
|
|
|
|
2013-05-29 23:58:07 +02:00
|
|
|
from django.conf import settings
|
2014-02-05 00:35:32 +01:00
|
|
|
from zerver.lib.redis_utils import get_redis_client
|
2013-05-29 23:58:07 +02:00
|
|
|
|
2016-03-27 12:09:54 +02:00
|
|
|
from zerver.models import UserProfile
|
|
|
|
|
2013-05-29 23:58:07 +02:00
|
|
|
import redis
|
|
|
|
import time
|
|
|
|
import logging
|
|
|
|
|
|
|
|
# Implement a rate-limiting scheme inspired by the one described here, but heavily modified
|
|
|
|
# http://blog.domaintools.com/2013/04/rate-limiting-with-redis/
|
|
|
|
|
2014-02-05 00:35:32 +01:00
|
|
|
# Module-wide Redis connection shared by every rate-limiting helper below.
client = get_redis_client()
|
2017-05-07 17:09:18 +02:00
|
|
|
# Default (range_seconds, num_requests) rules taken from settings.
# RateLimitedUser.rules() falls back to this list when the user has no custom
# rate_limits. add_ratelimit_rule() appends to it in place (so it remains the
# same object as settings.RATE_LIMITING_RULES), whereas remove_ratelimit_rule()
# rebinds the name to a new list.
rules = settings.RATE_LIMITING_RULES # type: List[Tuple[int, int]]
|
2013-05-29 23:58:07 +02:00
|
|
|
|
2017-11-03 03:12:25 +01:00
|
|
|
# Prefix prepended to every Redis key built by RateLimitedObject.get_keys();
# bounce_redis_key_prefix_for_testing() rebinds it to isolate test runs.
KEY_PREFIX = ''
|
2017-05-05 12:07:10 +02:00
|
|
|
|
2018-12-12 21:13:00 +01:00
|
|
|
class RateLimiterLockingException(Exception):
    """Raised by incr_ratelimit when its Redis transaction keeps being
    invalidated (WatchError) by concurrent writers and it gives up retrying."""
|
|
|
|
|
2017-11-05 11:37:41 +01:00
|
|
|
class RateLimitedObject:
    """Base class for anything the rate limiter can track.

    Subclasses provide ``key_fragment()``, a string identifying the entity
    inside Redis key names, and ``rules()``, the list of
    ``(range_seconds, num_requests)`` limits that apply to it.
    """

    def get_keys(self) -> List[str]:
        """Return this entity's Redis keys in the fixed order list, zset, block."""
        key_base = "{}ratelimit:{}".format(KEY_PREFIX, self.key_fragment())
        return [key_base + ":" + key_kind for key_kind in ("list", "zset", "block")]

    def key_fragment(self) -> str:
        """Return the entity-specific part of the Redis key names (subclasses override)."""
        raise NotImplementedError()

    def rules(self) -> List[Tuple[int, int]]:
        """Return the ``(range_seconds, num_requests)`` rules for this entity (subclasses override)."""
        raise NotImplementedError()
|
2017-07-28 06:40:52 +02:00
|
|
|
|
|
|
|
class RateLimitedUser(RateLimitedObject):
    """Rate-limiting identity for one user within a given domain (default 'all')."""

    def __init__(self, user: UserProfile, domain: str='all') -> None:
        self.user = user
        self.domain = domain

    def key_fragment(self) -> str:
        # The user's class is part of the key; presumably so that different
        # user types sharing an id never collide -- confirm with callers.
        return "{}:{}:{}".format(type(self.user), self.user.id, self.domain)

    def rules(self) -> List[Tuple[int, int]]:
        """Return the ``(range_seconds, num_requests)`` rules for this user.

        When ``user.rate_limits`` is non-empty it is parsed as
        ``"seconds:requests,seconds:requests,..."`` and overrides the
        module-level ``rules``; otherwise the module-level list is returned.

        Parsed overrides are sorted by window length (stably), matching the
        order add_ratelimit_rule() maintains for the global list, because
        max_api_calls()/max_api_window() treat the last rule as the longest
        window.

        Raises ValueError if an entry is not two integers joined by ':'.
        """
        if self.user.rate_limits == "":
            return rules

        result = []  # type: List[Tuple[int, int]]
        for limit in self.user.rate_limits.split(','):
            seconds, requests = limit.split(':', 1)
            result.append((int(seconds), int(requests)))
        # Keep the longest window last, like the global rules list.
        result.sort(key=lambda rule: rule[0])
        return result
|
|
|
|
|
2018-05-11 01:40:23 +02:00
|
|
|
def bounce_redis_key_prefix_for_testing(test_name: str) -> None:
    """Move all rate-limiter keys into a namespace unique to this test and process.

    Rebinds the module-level KEY_PREFIX to ``"<test_name>:<pid>:"`` so keys
    created by different tests or worker processes do not collide.
    """
    global KEY_PREFIX
    KEY_PREFIX = ':'.join([test_name, str(os.getpid()), ''])
|
2017-05-05 12:07:10 +02:00
|
|
|
|
2017-11-05 11:15:10 +01:00
|
|
|
def max_api_calls(entity: RateLimitedObject) -> int:
    """Return the request cap (num_requests) of the entity's last rule.

    Rules are expected to be ordered by window length, so the last rule is
    the one with the longest window. Raises IndexError if the entity has no
    rules.
    """
    widest_rule = entity.rules()[-1]
    return widest_rule[1]
|
2013-05-29 23:58:07 +02:00
|
|
|
|
2017-11-05 11:15:10 +01:00
|
|
|
def max_api_window(entity: RateLimitedObject) -> int:
    """Return the window length (range_seconds) of the entity's last rule.

    Rules are expected to be ordered by window length, so this is the longest
    window. Raises IndexError if the entity has no rules.
    """
    widest_rule = entity.rules()[-1]
    return widest_rule[0]
|
2013-05-29 23:58:07 +02:00
|
|
|
|
2017-11-05 11:15:10 +01:00
|
|
|
def add_ratelimit_rule(range_seconds: int, num_requests: int) -> None:
    """Register a global rate-limiting rule and keep rules ordered by window.

    The module-level ``rules`` list is mutated in place (append, then a
    stable sort on range_seconds), so any other reference to the same list
    object sees the new rule as well.
    """
    global rules
    new_rule = (range_seconds, num_requests)
    rules.append(new_rule)
    rules.sort(key=lambda rule: rule[0])
|
2013-05-29 23:58:07 +02:00
|
|
|
|
2017-11-05 11:15:10 +01:00
|
|
|
def remove_ratelimit_rule(range_seconds: int, num_requests: int) -> None:
    """Remove every global rule equal to ``(range_seconds, num_requests)``.

    A rule is dropped only when BOTH its window and its request count match;
    rules sharing just one of the two values are kept. The module-level
    ``rules`` name is rebound to a new, filtered list.
    """
    global rules
    # Compare field-by-field so rules stored as lists (e.g. from settings)
    # match the same way tuples do.
    rules = [rule for rule in rules
             if not (rule[0] == range_seconds and rule[1] == num_requests)]
|
2013-05-29 23:58:07 +02:00
|
|
|
|
2017-11-05 11:15:10 +01:00
|
|
|
def block_access(entity: RateLimitedObject, seconds: int) -> None:
    """Manually block ``entity`` for ``seconds`` seconds.

    Sets the entity's blocking key and gives it a TTL of ``seconds`` in a
    single Redis pipeline. is_ratelimited() reports the entity as limited
    while that key exists.
    """
    _list_key, _zset_key, block_key = entity.get_keys()
    with client.pipeline() as pipeline:
        pipeline.set(block_key, 1)
        pipeline.expire(block_key, seconds)
        pipeline.execute()
|
|
|
|
|
2017-11-05 11:15:10 +01:00
|
|
|
def unblock_access(entity: RateLimitedObject) -> None:
    """Lift a manual block on ``entity`` by deleting its blocking key."""
    _list_key, _zset_key, block_key = entity.get_keys()
    client.delete(block_key)
|
|
|
|
|
2017-11-05 11:15:10 +01:00
|
|
|
def clear_history(entity: RateLimitedObject) -> None:
    """Delete every rate-limiter key for ``entity``, giving it a clean slate.

    Removes the call list, the sorted set of call timestamps and any manual
    block. This is only used by test code now, where it lets tests run
    quickly without waiting for limits to expire.
    """
    entity_keys = entity.get_keys()
    for redis_key in entity_keys:
        client.delete(redis_key)
|
|
|
|
|
2017-11-05 11:15:10 +01:00
|
|
|
def _get_api_calls_left(entity: RateLimitedObject, range_seconds: int, max_calls: int) -> Tuple[int, float]:
    """Return ``(calls_left, time_reset)`` for a single rate-limit window.

    ``calls_left`` is ``max_calls`` minus the number of recorded calls whose
    timestamp lies within the last ``range_seconds`` seconds. ``time_reset``
    is the absolute time at which the newest recorded call leaves the window,
    or the current time when no call is recorded.
    """
    list_key, set_key, _block_key = entity.get_keys()

    current_time = time.time()
    window_start = current_time - range_seconds

    with client.pipeline() as pipe:
        # Number of recorded calls in [window_start, current_time].
        pipe.zcount(set_key, window_start, current_time)
        # The list is newest-first (incr_ratelimit lpushes), so index 0 is
        # the most recent call; it decides when the window fully resets.
        pipe.lindex(list_key, 0)
        recent_count, newest_call = pipe.execute()  # type: (int, Optional[bytes])

    if newest_call is None:
        time_reset = current_time
    else:
        time_reset = current_time + (range_seconds - (current_time - float(newest_call)))

    return max_calls - recent_count, time_reset
|
|
|
|
|
2017-11-05 11:15:10 +01:00
|
|
|
def api_calls_left(entity: RateLimitedObject) -> Tuple[int, float]:
    """Report how many calls ``entity`` has left and when its limit resets.

    Uses the entity's last (longest-window) rule and returns
    ``(calls_left, time_reset)`` as computed by _get_api_calls_left().
    """
    return _get_api_calls_left(entity, max_api_window(entity), max_api_calls(entity))
|
2013-05-29 23:58:07 +02:00
|
|
|
|
2017-11-05 11:15:10 +01:00
|
|
|
def is_ratelimited(entity: RateLimitedObject) -> Tuple[bool, float]:
    """Return ``(rate_limited, time_till_free)`` for ``entity``.

    A manual block (see block_access) takes precedence: the result is
    ``(True, ttl)``, where ttl is the blocking key's TTL as an int, or 0.5
    when the TTL reply is None.

    Otherwise each rule ``(range_seconds, num_requests)`` is checked in
    order. A rule is violated when the num_requests-th most recent call is
    less than range_seconds old. The first violated rule yields
    ``(True, seconds_until_that_call_leaves_the_window)``.

    Returns ``(False, 0.0)`` when the entity has no rules or no rule is
    violated.
    """
    list_key, _set_key, blocking_key = entity.get_keys()
    entity_rules = entity.rules()

    if not entity_rules:
        return False, 0.0

    with client.pipeline() as pipe:
        # For each rule, fetch the timestamp of its num_requests-th newest
        # call. The list is newest-first and 0-indexed.
        for _window, request_count in entity_rules:
            pipe.lindex(list_key, request_count - 1)
        # Manual block status: the key's value, then its remaining TTL.
        pipe.get(blocking_key)
        pipe.ttl(blocking_key)
        *rule_timestamps, key_blocked, blocking_ttl_b = pipe.execute()

    if key_blocked is not None:
        # Manually blocked: report how much longer the block lasts.
        blocking_ttl = 0.5 if blocking_ttl_b is None else int(blocking_ttl_b)
        return True, blocking_ttl

    now = time.time()
    for nth_call, (range_seconds, _num_requests) in zip(rule_timestamps, entity_rules):
        if nth_call is None:
            # No num_requests-th call is stored, so this rule cannot be violated.
            continue
        boundary = float(nth_call) + range_seconds
        if boundary > now:
            # The nth most recent call is still inside the window.
            return True, boundary - now

    # No rule is currently violated.
    return False, 0.0
|
|
|
|
|
2017-11-05 11:15:10 +01:00
|
|
|
def incr_ratelimit(entity: RateLimitedObject) -> None:
    """Record one API call for ``entity``.

    Steps:
      1. Push the call's timestamp onto the entity's list (newest first).
      2. Trim the list to the request cap of the entity's longest-window rule.
      3. Add the timestamp to the entity's sorted set.
      4. Remove from the set the timestamp the trim evicted, if any.
      5. Give both keys a TTL equal to the longest window.

    Nothing is stored when the entity has no rules.

    Raises RateLimiterLockingException if the WATCHed list is modified
    concurrently on 12 consecutive attempts.
    """
    list_key, set_key, _ = entity.get_keys()
    now = time.time()

    # If we have no rules, we don't store anything. Check the entity's own
    # rules rather than the module-level defaults:
    # - a RateLimitedUser with custom limits must still be tracked;
    # - max_api_calls() below would raise IndexError for an entity with none.
    if len(entity.rules()) == 0:
        return

    # Loop-invariant: compute once rather than on every retry.
    max_calls = max_api_calls(entity)
    api_window = max_api_window(entity)

    # Start redis transaction
    with client.pipeline() as pipe:
        count = 0
        while True:
            try:
                # To avoid a race condition between getting the element we
                # might trim from our list and removing it from our
                # associated set, abort this whole transaction if another
                # agent changes our list out from under us.
                # While watching, the pipeline executes commands immediately.
                pipe.watch(list_key)

                # Get the last elem that we'll trim (so we can remove it
                # from our sorted set).
                last_val = pipe.lindex(list_key, max_calls - 1)

                # Restart buffered execution
                pipe.multi()

                # Add this timestamp to our list
                pipe.lpush(list_key, now)

                # Trim our list to the oldest rule we have
                pipe.ltrim(list_key, 0, max_calls - 1)

                # Add our new value to the sorted set that we keep. Score and
                # member are both the timestamp, as we sort by score but
                # remove by value.
                # NOTE(review): this positional (score, member) zadd form
                # predates redis-py 3.0, which takes a mapping instead --
                # confirm against the installed redis-py version.
                pipe.zadd(set_key, now, now)

                # Remove the trimmed value from our sorted set, if there was one
                if last_val is not None:
                    pipe.zrem(set_key, last_val)

                # Set the TTL for our keys as well
                pipe.expire(list_key, api_window)
                pipe.expire(set_key, api_window)

                pipe.execute()

                # If no exception was raised in the execution, there were no
                # transaction conflicts
                break
            except redis.WatchError:
                if count > 10:
                    raise RateLimiterLockingException()
                count += 1
|