From c9f54766c3b45dbca5252fb6d39d3fa5465007d6 Mon Sep 17 00:00:00 2001 From: Zixuan James Li Date: Fri, 5 Aug 2022 11:40:03 -0400 Subject: [PATCH] rate_limiter: Extract rate limit related functions. This refactors rate limit related functions from `zerver.decorator` to `zerver.lib.rate_limiter`. We conditionally import `RemoteZulipServer`, `RequestNotes`, and `RateLimitedRemoteZulipServer` to avoid a circular dependency. Most instances of importing these functions from `zerver.decorator` got updated, with a few exceptions in `zerver.tests.test_decorators`, where we do want to mock the rate limiting functions imported in `zerver.decorator`. The same goes for the mocking example in the "testing-with-django" documentation. Signed-off-by: Zixuan James Li --- zerver/decorator.py | 135 +------------------------------- zerver/forms.py | 3 +- zerver/lib/rate_limiter.py | 119 +++++++++++++++++++++++++++- zerver/tests/test_decorators.py | 27 ++++--- zerver/tests/test_external.py | 4 +- zerver/views/registration.py | 3 +- zproject/backends.py | 3 +- 7 files changed, 143 insertions(+), 151 deletions(-) diff --git a/zerver/decorator.py b/zerver/decorator.py index 2b6d7c2e78..ecb881b2a6 100644 --- a/zerver/decorator.py +++ b/zerver/decorator.py @@ -4,22 +4,9 @@ import logging import urllib from functools import wraps from io import BytesIO -from typing import ( - TYPE_CHECKING, - Callable, - Dict, - Optional, - Sequence, - Set, - TypeVar, - Union, - cast, - overload, -) +from typing import TYPE_CHECKING, Callable, Dict, Optional, Sequence, TypeVar, Union, cast, overload import django_otp -import orjson -from circuitbreaker import CircuitBreakerError, circuit from django.conf import settings from django.contrib.auth import REDIRECT_FIELD_NAME from django.contrib.auth import login as django_login @@ -38,7 +25,6 @@ from django_otp import user_has_device from two_factor.utils import default_device from typing_extensions import Concatenate, ParamSpec -from zerver.lib.cache import 
cache_with_key from zerver.lib.exceptions import ( AccessDeniedError, AnomalousWebhookPayload, @@ -50,7 +36,6 @@ from zerver.lib.exceptions import ( OrganizationAdministratorRequired, OrganizationMemberRequired, OrganizationOwnerRequired, - RateLimited, RealmDeactivatedError, RemoteServerDeactivatedError, UnauthorizedError, @@ -59,7 +44,7 @@ from zerver.lib.exceptions import ( WebhookError, ) from zerver.lib.queue import queue_json_publish -from zerver.lib.rate_limiter import RateLimitedIPAddr, RateLimitedUser +from zerver.lib.rate_limiter import is_local_addr, rate_limit, rate_limit_user from zerver.lib.request import REQ, RequestNotes, has_request_variables from zerver.lib.response import json_method_not_allowed, json_success from zerver.lib.subdomains import get_subdomain, user_matches_subdomain @@ -70,17 +55,11 @@ from zerver.lib.utils import has_api_key_format, statsd from zerver.models import Realm, UserProfile, get_client, get_user_profile_by_api_key if settings.ZILENCER_ENABLED: - from zilencer.models import ( - RateLimitedRemoteZulipServer, - RemoteZulipServer, - get_remote_server_by_uuid, - ) + from zilencer.models import RemoteZulipServer, get_remote_server_by_uuid if TYPE_CHECKING: from django.http.request import _ImmutableQueryDict -rate_limiter_logger = logging.getLogger("zerver.lib.rate_limiter") - webhook_logger = logging.getLogger("zulip.zerver.webhooks") webhook_unsupported_events_logger = logging.getLogger("zulip.zerver.webhooks.unsupported") webhook_anomalous_payloads_logger = logging.getLogger("zulip.zerver.webhooks.anomalous") @@ -907,10 +886,6 @@ def authenticated_json_view( return _wrapped_view_func -def is_local_addr(addr: str) -> bool: - return addr in ("127.0.0.1", "::1") - - # These views are used by the main Django server to notify the Tornado server # of events. We protect them from the outside world by checking a shared # secret, and also the originating IP (for now). 
@@ -921,16 +896,6 @@ def authenticate_notify(request: HttpRequest, secret: str = REQ("secret")) -> bo ) -def client_is_exempt_from_rate_limiting(request: HttpRequest) -> bool: - - # Don't rate limit requests from Django that come from our own servers, - # and don't rate-limit dev instances - client = RequestNotes.get_notes(request).client - return (client is not None and client.name.lower() == "internal") and ( - is_local_addr(request.META["REMOTE_ADDR"]) or settings.DEBUG_RATE_LIMITING - ) - - def internal_notify_view( is_tornado_view: bool, ) -> Callable[ @@ -992,100 +957,6 @@ def statsd_increment( return wrapper -def rate_limit_user(request: HttpRequest, user: UserProfile, domain: str) -> None: - """Returns whether or not a user was rate limited. Will raise a RateLimited exception - if the user has been rate limited, otherwise returns and modifies request to contain - the rate limit information""" - - RateLimitedUser(user, domain=domain).rate_limit_request(request) - - -@cache_with_key(lambda: "tor_ip_addresses:", timeout=60 * 60) -@circuit(failure_threshold=2, recovery_timeout=60 * 10) -def get_tor_ips() -> Set[str]: - if not settings.RATE_LIMIT_TOR_TOGETHER: - return set() - - # Cron job in /etc/cron.d/fetch-tor-exit-nodes fetches this - # hourly; we cache it in memcached to prevent going to disk on - # every unauth'd request. In case of failures to read, we - # circuit-break so 2 failures cause a 10-minute backoff. - - with open(settings.TOR_EXIT_NODE_FILE_PATH, "rb") as f: - exit_node_list = orjson.loads(f.read()) - - # This should always be non-empty; if it's empty, assume something - # went wrong with writing and treat it as a non-existent file. - # Circuit-breaking will ensure that we back off on re-reading the - # file. 
- if len(exit_node_list) == 0: - raise OSError("File is empty") - - return set(exit_node_list) - - -def rate_limit_ip(request: HttpRequest, ip_addr: str, domain: str) -> None: - RateLimitedIPAddr(ip_addr, domain=domain).rate_limit_request(request) - - -def rate_limit_request_by_ip(request: HttpRequest, domain: str) -> None: - # REMOTE_ADDR is set by SetRemoteAddrFromRealIpHeader in conjunction - # with the nginx configuration to guarantee this to be *the* correct - # IP address to use - without worrying we'll grab the IP of a proxy. - ip_addr = request.META["REMOTE_ADDR"] - assert ip_addr - - try: - # We lump all TOR exit nodes into one bucket; this prevents - # abuse from TOR, while still allowing some access to these - # endpoints for legitimate users. Checking for local - # addresses is a shortcut somewhat for ease of testing without - # mocking the TOR endpoint in every test. - if is_local_addr(ip_addr): - pass - elif ip_addr in get_tor_ips(): - ip_addr = "tor-exit-node" - except (OSError, CircuitBreakerError) as err: - # In the event that we can't get an updated list of TOR exit - # nodes, assume the IP is _not_ one, and leave it unchanged. - # We log a warning so that this endpoint being taken out of - # service doesn't silently remove this functionality. 
- rate_limiter_logger.warning("Failed to fetch TOR exit node list: %s", err) - pass - rate_limit_ip(request, ip_addr, domain=domain) - - -def rate_limit_remote_server( - request: HttpRequest, remote_server: "RemoteZulipServer", domain: str -) -> None: - try: - RateLimitedRemoteZulipServer(remote_server, domain=domain).rate_limit_request(request) - except RateLimited as e: - rate_limiter_logger.warning( - "Remote server %s exceeded rate limits on domain %s", remote_server, domain - ) - raise e - - -def rate_limit(request: HttpRequest) -> None: - if not settings.RATE_LIMITING: - return - - if client_is_exempt_from_rate_limiting(request): - return - - user = request.user - remote_server = RequestNotes.get_notes(request).remote_server - - if settings.ZILENCER_ENABLED and remote_server is not None: - rate_limit_remote_server(request, remote_server, domain="api_by_remote_server") - elif not user.is_authenticated: - rate_limit_request_by_ip(request, domain="api_by_ip") - else: - assert isinstance(user, UserProfile) - rate_limit_user(request, user, domain="api_by_user") - - def return_success_on_head_request( view_func: Callable[Concatenate[HttpRequest, ParamT], HttpResponse] ) -> Callable[Concatenate[HttpRequest, ParamT], HttpResponse]: diff --git a/zerver/forms.py b/zerver/forms.py index cabf928647..36f46fada4 100644 --- a/zerver/forms.py +++ b/zerver/forms.py @@ -22,14 +22,13 @@ from two_factor.forms import AuthenticationTokenForm as TwoFactorAuthenticationT from two_factor.utils import totp_digits from zerver.actions.user_settings import do_change_password -from zerver.decorator import rate_limit_request_by_ip from zerver.lib.email_validation import ( email_allowed_for_realm, email_reserved_for_system_bots_error, ) from zerver.lib.exceptions import JsonableError, RateLimited from zerver.lib.name_restrictions import is_disposable_domain, is_reserved_subdomain -from zerver.lib.rate_limiter import RateLimitedObject +from zerver.lib.rate_limiter import RateLimitedObject, 
rate_limit_request_by_ip from zerver.lib.send_email import FromAddress, send_email from zerver.lib.soft_deactivation import queue_soft_reactivation from zerver.lib.subdomains import get_subdomain, is_root_domain_available diff --git a/zerver/lib/rate_limiter.py b/zerver/lib/rate_limiter.py index 6921f9279f..3ddf7ef741 100644 --- a/zerver/lib/rate_limiter.py +++ b/zerver/lib/rate_limiter.py @@ -2,17 +2,23 @@ import logging import os import time from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple, Type, cast +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, cast +import orjson import redis +from circuitbreaker import CircuitBreakerError, circuit from django.conf import settings from django.http import HttpRequest +from zerver.lib.cache import cache_with_key from zerver.lib.exceptions import RateLimited from zerver.lib.redis_utils import get_redis_client from zerver.lib.utils import statsd from zerver.models import UserProfile +if TYPE_CHECKING: + from zilencer.models import RemoteZulipServer + # Implement a rate-limiting scheme inspired by the one described here, but heavily modified # https://www.domaintools.com/resources/blog/rate-limiting-with-redis @@ -534,3 +540,114 @@ def rate_limit_spectator_attachment_access_by_file(path_id: str) -> None: ratelimited, _ = RateLimitedSpectatorAttachmentAccessByFile(path_id).rate_limit() if ratelimited: raise RateLimited + + +def is_local_addr(addr: str) -> bool: + return addr in ("127.0.0.1", "::1") + + +@cache_with_key(lambda: "tor_ip_addresses:", timeout=60 * 60) +@circuit(failure_threshold=2, recovery_timeout=60 * 10) +def get_tor_ips() -> Set[str]: + if not settings.RATE_LIMIT_TOR_TOGETHER: + return set() + + # Cron job in /etc/cron.d/fetch-tor-exit-nodes fetches this + # hourly; we cache it in memcached to prevent going to disk on + # every unauth'd request. In case of failures to read, we + # circuit-break so 2 failures cause a 10-minute backoff. 
+ + with open(settings.TOR_EXIT_NODE_FILE_PATH, "rb") as f: + exit_node_list = orjson.loads(f.read()) + + # This should always be non-empty; if it's empty, assume something + # went wrong with writing and treat it as a non-existent file. + # Circuit-breaking will ensure that we back off on re-reading the + # file. + if len(exit_node_list) == 0: + raise OSError("File is empty") + + return set(exit_node_list) + + +def client_is_exempt_from_rate_limiting(request: HttpRequest) -> bool: + from zerver.lib.request import RequestNotes + + # Don't rate limit requests from Django that come from our own servers, + # and don't rate-limit dev instances + client = RequestNotes.get_notes(request).client + return (client is not None and client.name.lower() == "internal") and ( + is_local_addr(request.META["REMOTE_ADDR"]) or settings.DEBUG_RATE_LIMITING + ) + + +def rate_limit_user(request: HttpRequest, user: UserProfile, domain: str) -> None: + """Returns whether or not a user was rate limited. Will raise a RateLimited exception + if the user has been rate limited, otherwise returns and modifies request to contain + the rate limit information""" + + RateLimitedUser(user, domain=domain).rate_limit_request(request) + + +def rate_limit_ip(request: HttpRequest, ip_addr: str, domain: str) -> None: + RateLimitedIPAddr(ip_addr, domain=domain).rate_limit_request(request) + + +def rate_limit_request_by_ip(request: HttpRequest, domain: str) -> None: + # REMOTE_ADDR is set by SetRemoteAddrFromRealIpHeader in conjunction + # with the nginx configuration to guarantee this to be *the* correct + # IP address to use - without worrying we'll grab the IP of a proxy. + ip_addr = request.META["REMOTE_ADDR"] + assert ip_addr + + try: + # We lump all TOR exit nodes into one bucket; this prevents + # abuse from TOR, while still allowing some access to these + # endpoints for legitimate users. 
Checking for local + # addresses is a shortcut somewhat for ease of testing without + # mocking the TOR endpoint in every test. + if is_local_addr(ip_addr): + pass + elif ip_addr in get_tor_ips(): + ip_addr = "tor-exit-node" + except (OSError, CircuitBreakerError) as err: + # In the event that we can't get an updated list of TOR exit + # nodes, assume the IP is _not_ one, and leave it unchanged. + # We log a warning so that this endpoint being taken out of + # service doesn't silently remove this functionality. + logger.warning("Failed to fetch TOR exit node list: %s", err) + pass + rate_limit_ip(request, ip_addr, domain=domain) + + +def rate_limit_remote_server( + request: HttpRequest, remote_server: "RemoteZulipServer", domain: str +) -> None: + if settings.ZILENCER_ENABLED: + from zilencer.models import RateLimitedRemoteZulipServer + try: + RateLimitedRemoteZulipServer(remote_server, domain=domain).rate_limit_request(request) + except RateLimited as e: + logger.warning("Remote server %s exceeded rate limits on domain %s", remote_server, domain) + raise e + + +def rate_limit(request: HttpRequest) -> None: + if not settings.RATE_LIMITING: + return + + if client_is_exempt_from_rate_limiting(request): + return + + from zerver.lib.request import RequestNotes + + user = request.user + remote_server = RequestNotes.get_notes(request).remote_server + + if settings.ZILENCER_ENABLED and remote_server is not None: + rate_limit_remote_server(request, remote_server, domain="api_by_remote_server") + elif not user.is_authenticated: + rate_limit_request_by_ip(request, domain="api_by_ip") + else: + assert isinstance(user, UserProfile) + rate_limit_user(request, user, domain="api_by_user") diff --git a/zerver/tests/test_decorators.py b/zerver/tests/test_decorators.py index 70768681d9..b47abf5847 100644 --- a/zerver/tests/test_decorators.py +++ b/zerver/tests/test_decorators.py @@ -27,8 +27,6 @@ from zerver.decorator import ( authenticated_rest_api_view, 
authenticated_uploads_api_view, internal_notify_view, - is_local_addr, - rate_limit, return_success_on_head_request, validate_api_key, webhook_view, @@ -46,6 +44,7 @@ from zerver.lib.exceptions import ( UnsupportedWebhookEventType, ) from zerver.lib.initial_password import initial_password +from zerver.lib.rate_limiter import is_local_addr, rate_limit from zerver.lib.request import ( REQ, RequestConfusingParmsError, @@ -654,8 +653,10 @@ class RateLimitTestCase(ZulipTestCase): f = self.get_ratelimited_view() with self.settings(RATE_LIMITING=True): - with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_user_mock, mock.patch( - "zerver.decorator.rate_limit_ip" + with mock.patch( + "zerver.lib.rate_limiter.rate_limit_user" + ) as rate_limit_user_mock, mock.patch( + "zerver.lib.rate_limiter.rate_limit_ip" ) as rate_limit_ip_mock: with self.errors_disallowed(): self.assertEqual(orjson.loads(f(request).content).get("msg"), "some value") @@ -671,8 +672,10 @@ class RateLimitTestCase(ZulipTestCase): f = self.get_ratelimited_view() with self.settings(RATE_LIMITING=True): - with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_user_mock, mock.patch( - "zerver.decorator.rate_limit_ip" + with mock.patch( + "zerver.lib.rate_limiter.rate_limit_user" + ) as rate_limit_user_mock, mock.patch( + "zerver.lib.rate_limiter.rate_limit_ip" ) as rate_limit_ip_mock: with self.errors_disallowed(): with self.settings(DEBUG_RATE_LIMITING=True): @@ -690,8 +693,10 @@ class RateLimitTestCase(ZulipTestCase): f = self.get_ratelimited_view() with self.settings(RATE_LIMITING=False): - with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_user_mock, mock.patch( - "zerver.decorator.rate_limit_ip" + with mock.patch( + "zerver.lib.rate_limiter.rate_limit_user" + ) as rate_limit_user_mock, mock.patch( + "zerver.lib.rate_limiter.rate_limit_ip" ) as rate_limit_ip_mock: with self.errors_disallowed(): self.assertEqual(orjson.loads(f(req).content).get("msg"), "some value") 
@@ -708,7 +713,7 @@ class RateLimitTestCase(ZulipTestCase): f = self.get_ratelimited_view() with self.settings(RATE_LIMITING=True): - with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock: + with mock.patch("zerver.lib.rate_limiter.rate_limit_user") as rate_limit_mock: with self.errors_disallowed(): self.assertEqual(orjson.loads(f(req).content).get("msg"), "some value") @@ -730,7 +735,7 @@ class RateLimitTestCase(ZulipTestCase): f = self.get_ratelimited_view() with self.settings(RATE_LIMITING=True): - with mock.patch("zerver.decorator.rate_limit_remote_server") as rate_limit_mock: + with mock.patch("zerver.lib.rate_limiter.rate_limit_remote_server") as rate_limit_mock: with self.errors_disallowed(): self.assertEqual(orjson.loads(f(req).content).get("msg"), "some value") @@ -744,7 +749,7 @@ class RateLimitTestCase(ZulipTestCase): f = self.get_ratelimited_view() with self.settings(RATE_LIMITING=True): - with mock.patch("zerver.decorator.rate_limit_ip") as rate_limit_mock: + with mock.patch("zerver.lib.rate_limiter.rate_limit_ip") as rate_limit_mock: with self.errors_disallowed(): self.assertEqual(orjson.loads(f(req).content).get("msg"), "some value") diff --git a/zerver/tests/test_external.py b/zerver/tests/test_external.py index 243a6b828f..41c526ef21 100644 --- a/zerver/tests/test_external.py +++ b/zerver/tests/test_external.py @@ -12,7 +12,6 @@ from django.core.exceptions import ValidationError from django.test import override_settings from django.utils.timezone import now as timezone_now -from zerver import decorator from zerver.forms import email_is_not_mit_mailing_list from zerver.lib.cache import cache_delete from zerver.lib.rate_limiter import ( @@ -20,6 +19,7 @@ from zerver.lib.rate_limiter import ( RateLimitedUser, RateLimiterLockingException, add_ratelimit_rule, + get_tor_ips, remove_ratelimit_rule, ) from zerver.lib.test_classes import ZulipTestCase @@ -308,7 +308,7 @@ class RateLimitTests(ZulipTestCase): 
"circuitbreaker.CircuitBreaker.opened", new_callable=mock.PropertyMock ) as mock_opened: mock_opened.return_value = False - decorator.get_tor_ips() + get_tor_ips() # Having closed it, it's now cached. Clear the cache. assert CircuitBreakerMonitor.get("get_tor_ips").closed diff --git a/zerver/views/registration.py b/zerver/views/registration.py index b86cb3eeab..f5d2541e55 100644 --- a/zerver/views/registration.py +++ b/zerver/views/registration.py @@ -34,7 +34,7 @@ from zerver.actions.user_settings import ( do_change_user_setting, ) from zerver.context_processors import get_realm_from_request, login_context -from zerver.decorator import do_login, rate_limit_request_by_ip, require_post +from zerver.decorator import do_login, require_post from zerver.forms import ( FindMyTeamForm, HomepageForm, @@ -46,6 +46,7 @@ from zerver.lib.email_validation import email_allowed_for_realm, validate_email_ from zerver.lib.exceptions import RateLimited from zerver.lib.i18n import get_default_language_for_new_user from zerver.lib.pysa import mark_sanitized +from zerver.lib.rate_limiter import rate_limit_request_by_ip from zerver.lib.request import REQ, has_request_variables from zerver.lib.send_email import EmailNotDeliveredException, FromAddress, send_email from zerver.lib.sessions import get_expirable_session_var diff --git a/zproject/backends.py b/zproject/backends.py index 38abf64b3e..3efaf6422a 100644 --- a/zproject/backends.py +++ b/zproject/backends.py @@ -75,14 +75,13 @@ from zerver.actions.create_user import do_create_user, do_reactivate_user from zerver.actions.custom_profile_fields import do_update_user_custom_profile_data_if_changed from zerver.actions.user_settings import do_regenerate_api_key from zerver.actions.users import do_deactivate_user -from zerver.decorator import client_is_exempt_from_rate_limiting from zerver.lib.avatar import avatar_url, is_avatar_new from zerver.lib.avatar_hash import user_avatar_content_hash from zerver.lib.dev_ldap_directory import 
init_fakeldap from zerver.lib.email_validation import email_allowed_for_realm, validate_email_not_already_in_realm from zerver.lib.exceptions import JsonableError from zerver.lib.mobile_auth_otp import is_valid_otp -from zerver.lib.rate_limiter import RateLimitedObject +from zerver.lib.rate_limiter import RateLimitedObject, client_is_exempt_from_rate_limiting from zerver.lib.redis_utils import get_dict_from_redis, get_redis_client, put_dict_in_redis from zerver.lib.request import RequestNotes from zerver.lib.sessions import delete_user_sessions