2017-11-16 00:43:27 +01:00
|
|
|
import logging
|
|
|
|
import time
|
2016-06-04 16:52:18 +02:00
|
|
|
from typing import Any, Callable, Optional
|
|
|
|
|
2013-05-31 19:47:09 +02:00
|
|
|
from django.conf import settings
|
2020-06-11 00:54:34 +02:00
|
|
|
from django.core.management.base import BaseCommand, CommandError, CommandParser
|
2023-04-13 02:05:54 +02:00
|
|
|
from returns.curry import partial
|
2023-10-12 19:43:45 +02:00
|
|
|
from typing_extensions import override
|
2013-05-31 19:47:09 +02:00
|
|
|
|
2020-03-04 14:05:25 +01:00
|
|
|
from zerver.lib.rate_limiter import RateLimitedUser, client
|
2023-12-15 01:16:00 +01:00
|
|
|
from zerver.models.users import get_user_profile_by_id
|
2013-05-31 19:47:09 +02:00
|
|
|
|
2020-01-14 21:59:46 +01:00
|
|
|
|
2013-05-31 19:47:09 +02:00
|
|
|
class Command(BaseCommand):
    help = """Checks Redis to make sure our rate limiting system hasn't grown a bug
and left Redis with a bunch of data

Usage: ./manage.py check_redis [--trim]"""

    @override
    def add_arguments(self, parser: CommandParser) -> None:
        parser.add_argument("-t", "--trim", action="store_true", help="Actually trim excess")

    def _check_within_range(
        self,
        key: bytes,
        count_func: Callable[[], int],
        trim_func: Optional[Callable[[bytes, int], object]] = None,
    ) -> None:
        """Log (and optionally repair) a rate-limit key that holds too many entries.

        The user id is parsed from the third ``:``-separated component of
        *key*, and the limit is that user's ``RateLimitedUser.max_api_calls()``.
        Keys with no expiry are logged as errors as well.

        :param key: Redis key matched by one of the ``ratelimit:*`` wildcards.
        :param count_func: zero-argument callable returning how many entries
            are currently stored under *key*.
        :param trim_func: when not ``None`` and the key is over the limit, the
            key's expiry is reset to the entity's ``max_api_window()`` and
            ``trim_func(key, max_calls)`` is then called. Pass ``None`` to only
            log problems without modifying Redis.
        """
        try:
            user_id = int(key.split(b":")[2])
        except (IndexError, ValueError):
            # The ``ratelimit:*:*:*:...`` wildcards match every rate-limited
            # entity type, not only users; a non-numeric id component presumably
            # means this key belongs to some other entity, whose limits we can't
            # look up here. Skip it instead of aborting the whole health check.
            logging.warning("Skipping rate limit key not associated with a user: %s", key)
            return
        user = get_user_profile_by_id(user_id)
        # NOTE(review): the key's domain component is not passed here, so the
        # limits of RateLimitedUser's default domain are used — confirm intended.
        entity = RateLimitedUser(user)
        max_calls = entity.max_api_calls()

        age = int(client.ttl(key))
        if age < 0:
            logging.error("Found key with age of %s, will never expire: %s", age, key)

        count = count_func()
        if count > max_calls:
            logging.error(
                "Redis health check found key with more elements "
                "than max_api_calls! (trying to trim) %s %s",
                key,
                count,
            )
            if trim_func is not None:
                client.expire(key, entity.max_api_window())
                trim_func(key, max_calls)

    @override
    def handle(self, *args: Any, **options: Any) -> None:
        """Check every rate-limit list and sorted set in Redis for excess size.

        Without ``--trim`` problems are only logged and Redis is not modified.
        With ``--trim``, oversized lists are truncated to ``max_api_calls``
        entries, and oversized lists and sorted sets get their expiry reset.

        :raises CommandError: if rate limiting is disabled in settings.
        """
        if not settings.RATE_LIMITING:
            raise CommandError("This machine is not using Redis or rate limiting, aborting")

        trim = options["trim"]

        def trim_list(key: bytes, max_calls: int) -> object:
            # Keep only the first max_calls elements of the list.
            return client.ltrim(key, 0, max_calls - 1)

        def expire_only(key: bytes, max_calls: int) -> object:
            # We can warn on our zset being too large, but we don't know what
            # elements to trim. We'd have to go through every list item and take
            # the intersection. The best we can do is expire it, which
            # _check_within_range already did before calling this no-op.
            return None

        # Find all keys, and make sure they're all within size constraints.
        # SCAN (scan_iter) is used instead of KEYS so the check does not block
        # the Redis server while it walks the keyspace.
        list_trim_func: Optional[Callable[[bytes, int], object]] = trim_list if trim else None
        for list_name in client.scan_iter(match="ratelimit:*:*:*:list"):
            self._check_within_range(list_name, partial(client.llen, list_name), list_trim_func)

        # Previously the zset no-op was passed unconditionally, so oversized
        # zsets were expired even without --trim; only do so when requested.
        zset_trim_func: Optional[Callable[[bytes, int], object]] = expire_only if trim else None
        for zset in client.scan_iter(match="ratelimit:*:*:*:zset"):
            now = time.time()
            self._check_within_range(
                zset,
                partial(client.zcount, zset, 0, now),
                zset_trim_func,
            )
|