rate_limiter: Extract rate limit related functions.

This refactors rate limit related functions from `zerver.decorator` to
zerver.lib.rate_limiter.

We conditionally import `RemoteZulipServer`, `RequestNotes`, and
`RateLimitedRemoteZulipServer` to avoid circular dependency.

Most instances of importing these functions from `zerver.decorator` got
updated, with a few exceptions in `zerver.tests.test_decorators`, where
we do want to mock the rate limiting functions imported in
`zerver.decorator`. The same goes with the mocking example in the
"testing-with-django" documentation.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
Zixuan James Li 2022-08-05 11:40:03 -04:00 committed by Tim Abbott
parent 232ba4866a
commit c9f54766c3
7 changed files with 143 additions and 151 deletions

View File

@ -4,22 +4,9 @@ import logging
import urllib
from functools import wraps
from io import BytesIO
from typing import (
TYPE_CHECKING,
Callable,
Dict,
Optional,
Sequence,
Set,
TypeVar,
Union,
cast,
overload,
)
from typing import TYPE_CHECKING, Callable, Dict, Optional, Sequence, TypeVar, Union, cast, overload
import django_otp
import orjson
from circuitbreaker import CircuitBreakerError, circuit
from django.conf import settings
from django.contrib.auth import REDIRECT_FIELD_NAME
from django.contrib.auth import login as django_login
@ -38,7 +25,6 @@ from django_otp import user_has_device
from two_factor.utils import default_device
from typing_extensions import Concatenate, ParamSpec
from zerver.lib.cache import cache_with_key
from zerver.lib.exceptions import (
AccessDeniedError,
AnomalousWebhookPayload,
@ -50,7 +36,6 @@ from zerver.lib.exceptions import (
OrganizationAdministratorRequired,
OrganizationMemberRequired,
OrganizationOwnerRequired,
RateLimited,
RealmDeactivatedError,
RemoteServerDeactivatedError,
UnauthorizedError,
@ -59,7 +44,7 @@ from zerver.lib.exceptions import (
WebhookError,
)
from zerver.lib.queue import queue_json_publish
from zerver.lib.rate_limiter import RateLimitedIPAddr, RateLimitedUser
from zerver.lib.rate_limiter import is_local_addr, rate_limit, rate_limit_user
from zerver.lib.request import REQ, RequestNotes, has_request_variables
from zerver.lib.response import json_method_not_allowed, json_success
from zerver.lib.subdomains import get_subdomain, user_matches_subdomain
@ -70,17 +55,11 @@ from zerver.lib.utils import has_api_key_format, statsd
from zerver.models import Realm, UserProfile, get_client, get_user_profile_by_api_key
if settings.ZILENCER_ENABLED:
from zilencer.models import (
RateLimitedRemoteZulipServer,
RemoteZulipServer,
get_remote_server_by_uuid,
)
from zilencer.models import RemoteZulipServer, get_remote_server_by_uuid
if TYPE_CHECKING:
from django.http.request import _ImmutableQueryDict
rate_limiter_logger = logging.getLogger("zerver.lib.rate_limiter")
webhook_logger = logging.getLogger("zulip.zerver.webhooks")
webhook_unsupported_events_logger = logging.getLogger("zulip.zerver.webhooks.unsupported")
webhook_anomalous_payloads_logger = logging.getLogger("zulip.zerver.webhooks.anomalous")
@ -907,10 +886,6 @@ def authenticated_json_view(
return _wrapped_view_func
def is_local_addr(addr: str) -> bool:
return addr in ("127.0.0.1", "::1")
# These views are used by the main Django server to notify the Tornado server
# of events. We protect them from the outside world by checking a shared
# secret, and also the originating IP (for now).
@ -921,16 +896,6 @@ def authenticate_notify(request: HttpRequest, secret: str = REQ("secret")) -> bo
)
def client_is_exempt_from_rate_limiting(request: HttpRequest) -> bool:
# Don't rate limit requests from Django that come from our own servers,
# and don't rate-limit dev instances
client = RequestNotes.get_notes(request).client
return (client is not None and client.name.lower() == "internal") and (
is_local_addr(request.META["REMOTE_ADDR"]) or settings.DEBUG_RATE_LIMITING
)
def internal_notify_view(
is_tornado_view: bool,
) -> Callable[
@ -992,100 +957,6 @@ def statsd_increment(
return wrapper
def rate_limit_user(request: HttpRequest, user: UserProfile, domain: str) -> None:
"""Returns whether or not a user was rate limited. Will raise a RateLimited exception
if the user has been rate limited, otherwise returns and modifies request to contain
the rate limit information"""
RateLimitedUser(user, domain=domain).rate_limit_request(request)
@cache_with_key(lambda: "tor_ip_addresses:", timeout=60 * 60)
@circuit(failure_threshold=2, recovery_timeout=60 * 10)
def get_tor_ips() -> Set[str]:
if not settings.RATE_LIMIT_TOR_TOGETHER:
return set()
# Cron job in /etc/cron.d/fetch-tor-exit-nodes fetches this
# hourly; we cache it in memcached to prevent going to disk on
# every unauth'd request. In case of failures to read, we
# circuit-break so 2 failures cause a 10-minute backoff.
with open(settings.TOR_EXIT_NODE_FILE_PATH, "rb") as f:
exit_node_list = orjson.loads(f.read())
# This should always be non-empty; if it's empty, assume something
# went wrong with writing and treat it as a non-existent file.
# Circuit-breaking will ensure that we back off on re-reading the
# file.
if len(exit_node_list) == 0:
raise OSError("File is empty")
return set(exit_node_list)
def rate_limit_ip(request: HttpRequest, ip_addr: str, domain: str) -> None:
RateLimitedIPAddr(ip_addr, domain=domain).rate_limit_request(request)
def rate_limit_request_by_ip(request: HttpRequest, domain: str) -> None:
# REMOTE_ADDR is set by SetRemoteAddrFromRealIpHeader in conjunction
# with the nginx configuration to guarantee this to be *the* correct
# IP address to use - without worrying we'll grab the IP of a proxy.
ip_addr = request.META["REMOTE_ADDR"]
assert ip_addr
try:
# We lump all TOR exit nodes into one bucket; this prevents
# abuse from TOR, while still allowing some access to these
# endpoints for legitimate users. Checking for local
# addresses is a shortcut somewhat for ease of testing without
# mocking the TOR endpoint in every test.
if is_local_addr(ip_addr):
pass
elif ip_addr in get_tor_ips():
ip_addr = "tor-exit-node"
except (OSError, CircuitBreakerError) as err:
# In the event that we can't get an updated list of TOR exit
# nodes, assume the IP is _not_ one, and leave it unchanged.
# We log a warning so that this endpoint being taken out of
# service doesn't silently remove this functionality.
rate_limiter_logger.warning("Failed to fetch TOR exit node list: %s", err)
pass
rate_limit_ip(request, ip_addr, domain=domain)
def rate_limit_remote_server(
request: HttpRequest, remote_server: "RemoteZulipServer", domain: str
) -> None:
try:
RateLimitedRemoteZulipServer(remote_server, domain=domain).rate_limit_request(request)
except RateLimited as e:
rate_limiter_logger.warning(
"Remote server %s exceeded rate limits on domain %s", remote_server, domain
)
raise e
def rate_limit(request: HttpRequest) -> None:
if not settings.RATE_LIMITING:
return
if client_is_exempt_from_rate_limiting(request):
return
user = request.user
remote_server = RequestNotes.get_notes(request).remote_server
if settings.ZILENCER_ENABLED and remote_server is not None:
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
elif not user.is_authenticated:
rate_limit_request_by_ip(request, domain="api_by_ip")
else:
assert isinstance(user, UserProfile)
rate_limit_user(request, user, domain="api_by_user")
def return_success_on_head_request(
view_func: Callable[Concatenate[HttpRequest, ParamT], HttpResponse]
) -> Callable[Concatenate[HttpRequest, ParamT], HttpResponse]:

View File

@ -22,14 +22,13 @@ from two_factor.forms import AuthenticationTokenForm as TwoFactorAuthenticationT
from two_factor.utils import totp_digits
from zerver.actions.user_settings import do_change_password
from zerver.decorator import rate_limit_request_by_ip
from zerver.lib.email_validation import (
email_allowed_for_realm,
email_reserved_for_system_bots_error,
)
from zerver.lib.exceptions import JsonableError, RateLimited
from zerver.lib.name_restrictions import is_disposable_domain, is_reserved_subdomain
from zerver.lib.rate_limiter import RateLimitedObject
from zerver.lib.rate_limiter import RateLimitedObject, rate_limit_request_by_ip
from zerver.lib.send_email import FromAddress, send_email
from zerver.lib.soft_deactivation import queue_soft_reactivation
from zerver.lib.subdomains import get_subdomain, is_root_domain_available

View File

@ -2,17 +2,23 @@ import logging
import os
import time
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Type, cast
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, cast
import orjson
import redis
from circuitbreaker import CircuitBreakerError, circuit
from django.conf import settings
from django.http import HttpRequest
from zerver.lib.cache import cache_with_key
from zerver.lib.exceptions import RateLimited
from zerver.lib.redis_utils import get_redis_client
from zerver.lib.utils import statsd
from zerver.models import UserProfile
if TYPE_CHECKING:
from zilencer.models import RemoteZulipServer
# Implement a rate-limiting scheme inspired by the one described here, but heavily modified
# https://www.domaintools.com/resources/blog/rate-limiting-with-redis
@ -534,3 +540,114 @@ def rate_limit_spectator_attachment_access_by_file(path_id: str) -> None:
ratelimited, _ = RateLimitedSpectatorAttachmentAccessByFile(path_id).rate_limit()
if ratelimited:
raise RateLimited
def is_local_addr(addr: str) -> bool:
return addr in ("127.0.0.1", "::1")
@cache_with_key(lambda: "tor_ip_addresses:", timeout=60 * 60)
@circuit(failure_threshold=2, recovery_timeout=60 * 10)
def get_tor_ips() -> Set[str]:
if not settings.RATE_LIMIT_TOR_TOGETHER:
return set()
# Cron job in /etc/cron.d/fetch-tor-exit-nodes fetches this
# hourly; we cache it in memcached to prevent going to disk on
# every unauth'd request. In case of failures to read, we
# circuit-break so 2 failures cause a 10-minute backoff.
with open(settings.TOR_EXIT_NODE_FILE_PATH, "rb") as f:
exit_node_list = orjson.loads(f.read())
# This should always be non-empty; if it's empty, assume something
# went wrong with writing and treat it as a non-existent file.
# Circuit-breaking will ensure that we back off on re-reading the
# file.
if len(exit_node_list) == 0:
raise OSError("File is empty")
return set(exit_node_list)
def client_is_exempt_from_rate_limiting(request: HttpRequest) -> bool:
from zerver.lib.request import RequestNotes
# Don't rate limit requests from Django that come from our own servers,
# and don't rate-limit dev instances
client = RequestNotes.get_notes(request).client
return (client is not None and client.name.lower() == "internal") and (
is_local_addr(request.META["REMOTE_ADDR"]) or settings.DEBUG_RATE_LIMITING
)
def rate_limit_user(request: HttpRequest, user: UserProfile, domain: str) -> None:
"""Returns whether or not a user was rate limited. Will raise a RateLimited exception
if the user has been rate limited, otherwise returns and modifies request to contain
the rate limit information"""
RateLimitedUser(user, domain=domain).rate_limit_request(request)
def rate_limit_ip(request: HttpRequest, ip_addr: str, domain: str) -> None:
RateLimitedIPAddr(ip_addr, domain=domain).rate_limit_request(request)
def rate_limit_request_by_ip(request: HttpRequest, domain: str) -> None:
# REMOTE_ADDR is set by SetRemoteAddrFromRealIpHeader in conjunction
# with the nginx configuration to guarantee this to be *the* correct
# IP address to use - without worrying we'll grab the IP of a proxy.
ip_addr = request.META["REMOTE_ADDR"]
assert ip_addr
try:
# We lump all TOR exit nodes into one bucket; this prevents
# abuse from TOR, while still allowing some access to these
# endpoints for legitimate users. Checking for local
# addresses is a shortcut somewhat for ease of testing without
# mocking the TOR endpoint in every test.
if is_local_addr(ip_addr):
pass
elif ip_addr in get_tor_ips():
ip_addr = "tor-exit-node"
except (OSError, CircuitBreakerError) as err:
# In the event that we can't get an updated list of TOR exit
# nodes, assume the IP is _not_ one, and leave it unchanged.
# We log a warning so that this endpoint being taken out of
# service doesn't silently remove this functionality.
logger.warning("Failed to fetch TOR exit node list: %s", err)
pass
rate_limit_ip(request, ip_addr, domain=domain)
def rate_limit_remote_server(
request: HttpRequest, remote_server: "RemoteZulipServer", domain: str
) -> None:
if settings.ZILENCER_ENABLED:
from zilencer.models import RateLimitedRemoteZulipServer
try:
RateLimitedRemoteZulipServer(remote_server, domain=domain).rate_limit_request(request)
except RateLimited as e:
logger.warning("Remote server %s exceeded rate limits on domain %s", remote_server, domain)
raise e
def rate_limit(request: HttpRequest) -> None:
if not settings.RATE_LIMITING:
return
if client_is_exempt_from_rate_limiting(request):
return
from zerver.lib.request import RequestNotes
user = request.user
remote_server = RequestNotes.get_notes(request).remote_server
if settings.ZILENCER_ENABLED and remote_server is not None:
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
elif not user.is_authenticated:
rate_limit_request_by_ip(request, domain="api_by_ip")
else:
assert isinstance(user, UserProfile)
rate_limit_user(request, user, domain="api_by_user")

View File

@ -27,8 +27,6 @@ from zerver.decorator import (
authenticated_rest_api_view,
authenticated_uploads_api_view,
internal_notify_view,
is_local_addr,
rate_limit,
return_success_on_head_request,
validate_api_key,
webhook_view,
@ -46,6 +44,7 @@ from zerver.lib.exceptions import (
UnsupportedWebhookEventType,
)
from zerver.lib.initial_password import initial_password
from zerver.lib.rate_limiter import is_local_addr, rate_limit
from zerver.lib.request import (
REQ,
RequestConfusingParmsError,
@ -654,8 +653,10 @@ class RateLimitTestCase(ZulipTestCase):
f = self.get_ratelimited_view()
with self.settings(RATE_LIMITING=True):
with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_user_mock, mock.patch(
"zerver.decorator.rate_limit_ip"
with mock.patch(
"zerver.lib.rate_limiter.rate_limit_user"
) as rate_limit_user_mock, mock.patch(
"zerver.lib.rate_limiter.rate_limit_ip"
) as rate_limit_ip_mock:
with self.errors_disallowed():
self.assertEqual(orjson.loads(f(request).content).get("msg"), "some value")
@ -671,8 +672,10 @@ class RateLimitTestCase(ZulipTestCase):
f = self.get_ratelimited_view()
with self.settings(RATE_LIMITING=True):
with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_user_mock, mock.patch(
"zerver.decorator.rate_limit_ip"
with mock.patch(
"zerver.lib.rate_limiter.rate_limit_user"
) as rate_limit_user_mock, mock.patch(
"zerver.lib.rate_limiter.rate_limit_ip"
) as rate_limit_ip_mock:
with self.errors_disallowed():
with self.settings(DEBUG_RATE_LIMITING=True):
@ -690,8 +693,10 @@ class RateLimitTestCase(ZulipTestCase):
f = self.get_ratelimited_view()
with self.settings(RATE_LIMITING=False):
with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_user_mock, mock.patch(
"zerver.decorator.rate_limit_ip"
with mock.patch(
"zerver.lib.rate_limiter.rate_limit_user"
) as rate_limit_user_mock, mock.patch(
"zerver.lib.rate_limiter.rate_limit_ip"
) as rate_limit_ip_mock:
with self.errors_disallowed():
self.assertEqual(orjson.loads(f(req).content).get("msg"), "some value")
@ -708,7 +713,7 @@ class RateLimitTestCase(ZulipTestCase):
f = self.get_ratelimited_view()
with self.settings(RATE_LIMITING=True):
with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock:
with mock.patch("zerver.lib.rate_limiter.rate_limit_user") as rate_limit_mock:
with self.errors_disallowed():
self.assertEqual(orjson.loads(f(req).content).get("msg"), "some value")
@ -730,7 +735,7 @@ class RateLimitTestCase(ZulipTestCase):
f = self.get_ratelimited_view()
with self.settings(RATE_LIMITING=True):
with mock.patch("zerver.decorator.rate_limit_remote_server") as rate_limit_mock:
with mock.patch("zerver.lib.rate_limiter.rate_limit_remote_server") as rate_limit_mock:
with self.errors_disallowed():
self.assertEqual(orjson.loads(f(req).content).get("msg"), "some value")
@ -744,7 +749,7 @@ class RateLimitTestCase(ZulipTestCase):
f = self.get_ratelimited_view()
with self.settings(RATE_LIMITING=True):
with mock.patch("zerver.decorator.rate_limit_ip") as rate_limit_mock:
with mock.patch("zerver.lib.rate_limiter.rate_limit_ip") as rate_limit_mock:
with self.errors_disallowed():
self.assertEqual(orjson.loads(f(req).content).get("msg"), "some value")

View File

@ -12,7 +12,6 @@ from django.core.exceptions import ValidationError
from django.test import override_settings
from django.utils.timezone import now as timezone_now
from zerver import decorator
from zerver.forms import email_is_not_mit_mailing_list
from zerver.lib.cache import cache_delete
from zerver.lib.rate_limiter import (
@ -20,6 +19,7 @@ from zerver.lib.rate_limiter import (
RateLimitedUser,
RateLimiterLockingException,
add_ratelimit_rule,
get_tor_ips,
remove_ratelimit_rule,
)
from zerver.lib.test_classes import ZulipTestCase
@ -308,7 +308,7 @@ class RateLimitTests(ZulipTestCase):
"circuitbreaker.CircuitBreaker.opened", new_callable=mock.PropertyMock
) as mock_opened:
mock_opened.return_value = False
decorator.get_tor_ips()
get_tor_ips()
# Having closed it, it's now cached. Clear the cache.
assert CircuitBreakerMonitor.get("get_tor_ips").closed

View File

@ -34,7 +34,7 @@ from zerver.actions.user_settings import (
do_change_user_setting,
)
from zerver.context_processors import get_realm_from_request, login_context
from zerver.decorator import do_login, rate_limit_request_by_ip, require_post
from zerver.decorator import do_login, require_post
from zerver.forms import (
FindMyTeamForm,
HomepageForm,
@ -46,6 +46,7 @@ from zerver.lib.email_validation import email_allowed_for_realm, validate_email_
from zerver.lib.exceptions import RateLimited
from zerver.lib.i18n import get_default_language_for_new_user
from zerver.lib.pysa import mark_sanitized
from zerver.lib.rate_limiter import rate_limit_request_by_ip
from zerver.lib.request import REQ, has_request_variables
from zerver.lib.send_email import EmailNotDeliveredException, FromAddress, send_email
from zerver.lib.sessions import get_expirable_session_var

View File

@ -75,14 +75,13 @@ from zerver.actions.create_user import do_create_user, do_reactivate_user
from zerver.actions.custom_profile_fields import do_update_user_custom_profile_data_if_changed
from zerver.actions.user_settings import do_regenerate_api_key
from zerver.actions.users import do_deactivate_user
from zerver.decorator import client_is_exempt_from_rate_limiting
from zerver.lib.avatar import avatar_url, is_avatar_new
from zerver.lib.avatar_hash import user_avatar_content_hash
from zerver.lib.dev_ldap_directory import init_fakeldap
from zerver.lib.email_validation import email_allowed_for_realm, validate_email_not_already_in_realm
from zerver.lib.exceptions import JsonableError
from zerver.lib.mobile_auth_otp import is_valid_otp
from zerver.lib.rate_limiter import RateLimitedObject
from zerver.lib.rate_limiter import RateLimitedObject, client_is_exempt_from_rate_limiting
from zerver.lib.redis_utils import get_dict_from_redis, get_redis_client, put_dict_in_redis
from zerver.lib.request import RequestNotes
from zerver.lib.sessions import delete_user_sessions