rate_limit: Move rate_limit_remote_server to zilencer.auth.

This allows us to avoid importing from zilencer conditionally in
zerver.lib.rate_limiter, as we make rate limiting self-contained now.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
Zixuan James Li 2022-08-14 10:19:44 -04:00 committed by Tim Abbott
parent f158c86ae1
commit 2aac1dc40a
4 changed files with 24 additions and 21 deletions

View File

@ -2,7 +2,7 @@ import logging
import os
import time
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, cast
from typing import Dict, List, Optional, Set, Tuple, Type, cast
import orjson
import redis
@ -16,9 +16,6 @@ from zerver.lib.redis_utils import get_redis_client
from zerver.lib.utils import statsd
from zerver.models import UserProfile
if TYPE_CHECKING:
from zilencer.models import RemoteZulipServer
# Implement a rate-limiting scheme inspired by the one described here, but heavily modified
# https://www.domaintools.com/resources/blog/rate-limiting-with-redis
@ -620,18 +617,6 @@ def rate_limit_request_by_ip(request: HttpRequest, domain: str) -> None:
rate_limit_ip(request, ip_addr, domain=domain)
def rate_limit_remote_server(
request: HttpRequest, remote_server: "RemoteZulipServer", domain: str
) -> None:
if settings.ZILENCER_ENABLED:
from zilencer.models import RateLimitedRemoteZulipServer
try:
RateLimitedRemoteZulipServer(remote_server, domain=domain).rate_limit_request(request)
except RateLimited as e:
logger.warning("Remote server %s exceeded rate limits on domain %s", remote_server, domain)
raise e
def should_rate_limit(request: HttpRequest) -> bool:
if not settings.RATE_LIMITING:
return False

View File

@ -718,7 +718,7 @@ class RateLimitTestCase(ZulipTestCase):
server.save()
with self.settings(RATE_LIMITING=True), mock.patch(
"zerver.lib.rate_limiter.rate_limit_remote_server"
"zilencer.auth.rate_limit_remote_server"
) as rate_limit_mock:
result = self.uuid_post(
server_uuid,

View File

@ -430,12 +430,12 @@ class RateLimitTests(ZulipTestCase):
self.DEFAULT_SUBDOMAIN = ""
RateLimitedRemoteZulipServer(server).clear_history()
with self.assertLogs("zerver.lib.rate_limiter", level="WARNING") as m:
with self.assertLogs("zilencer.auth", level="WARNING") as m:
self.do_test_hit_ratelimits(lambda: self.uuid_post(server_uuid, endpoint, payload))
self.assertEqual(
m.output,
[
f"WARNING:zerver.lib.rate_limiter:Remote server <RemoteZulipServer demo.example.com {server_uuid[:12]}> exceeded rate limits on domain api_by_remote_server"
f"WARNING:zilencer.auth:Remote server <RemoteZulipServer demo.example.com {server_uuid[:12]}> exceeded rate limits on domain api_by_remote_server"
],
)
finally:

View File

@ -1,3 +1,4 @@
import logging
from functools import wraps
from typing import Any, Callable
@ -12,15 +13,22 @@ from zerver.decorator import get_basic_credentials, process_client
from zerver.lib.exceptions import (
ErrorCode,
JsonableError,
RateLimited,
RemoteServerDeactivatedError,
UnauthorizedError,
)
from zerver.lib.rate_limiter import rate_limit_remote_server, should_rate_limit
from zerver.lib.rate_limiter import should_rate_limit
from zerver.lib.request import RequestNotes
from zerver.lib.rest import get_target_view_function_or_response
from zerver.lib.subdomains import get_subdomain
from zerver.models import Realm
from zilencer.models import RemoteZulipServer, get_remote_server_by_uuid
from zilencer.models import (
RateLimitedRemoteZulipServer,
RemoteZulipServer,
get_remote_server_by_uuid,
)
logger = logging.getLogger(__name__)
ParamT = ParamSpec("ParamT")
@ -43,6 +51,16 @@ class InvalidZulipServerKeyError(InvalidZulipServerError):
return "Zulip server auth failure: key does not match role {role}"
def rate_limit_remote_server(
request: HttpRequest, remote_server: RemoteZulipServer, domain: str
) -> None:
try:
RateLimitedRemoteZulipServer(remote_server, domain=domain).rate_limit_request(request)
except RateLimited as e:
logger.warning("Remote server %s exceeded rate limits on domain %s", remote_server, domain)
raise e
def validate_remote_server(
request: HttpRequest,
role: str,