From 3f9a5e1e17f0e5f50430300dfcf6e0cce11e899b Mon Sep 17 00:00:00 2001 From: PIG208 <359101898@qq.com> Date: Fri, 9 Jul 2021 19:15:19 +0800 Subject: [PATCH] request: Refactor to record rate limit data using ZulipRequestNotes. We will no longer use the HttpRequest to store the rate limit data. Using ZulipRequestNotes, we can access rate_limit and ratelimits_applied with type hints support. We also save the process of initializing ratelimits_applied by giving it a default value. --- zerver/lib/rate_limiter.py | 13 +++++++------ zerver/middleware.py | 5 +++-- zerver/tornado/handlers.py | 4 ++-- zproject/backends.py | 7 +++---- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/zerver/lib/rate_limiter.py b/zerver/lib/rate_limiter.py index 5c4214d13d..4abf77858c 100644 --- a/zerver/lib/rate_limiter.py +++ b/zerver/lib/rate_limiter.py @@ -42,11 +42,12 @@ class RateLimitedObject(ABC): ) def rate_limit_request(self, request: HttpRequest) -> None: - ratelimited, time = self.rate_limit() + from zerver.lib.request import get_request_notes - if not hasattr(request, "_ratelimits_applied"): - request._ratelimits_applied = [] - request._ratelimits_applied.append( + ratelimited, time = self.rate_limit() + request_notes = get_request_notes(request) + + request_notes.ratelimits_applied.append( RateLimitResult( entity=self, secs_to_freedom=time, @@ -61,8 +62,8 @@ class RateLimitedObject(ABC): calls_remaining, seconds_until_reset = self.api_calls_left() - request._ratelimits_applied[-1].remaining = calls_remaining - request._ratelimits_applied[-1].secs_to_freedom = seconds_until_reset + request_notes.ratelimits_applied[-1].remaining = calls_remaining + request_notes.ratelimits_applied[-1].secs_to_freedom = seconds_until_reset def block_access(self, seconds: int) -> None: "Manually blocks an entity for the desired number of seconds" diff --git a/zerver/middleware.py b/zerver/middleware.py index 693bee59ef..70f6a458b8 100644 --- a/zerver/middleware.py +++ b/zerver/middleware.py @@ -544,8 +544,9 @@ class RateLimitMiddleware(MiddlewareMixin): return response # Add X-RateLimit-*** headers - if hasattr(request, "_ratelimits_applied"): - self.set_response_headers(response, request._ratelimits_applied) + ratelimits_applied = get_request_notes(request).ratelimits_applied + if len(ratelimits_applied) > 0: + self.set_response_headers(response, ratelimits_applied) return response diff --git a/zerver/tornado/handlers.py b/zerver/tornado/handlers.py index 5d2d457b79..c8ac0999ec 100644 --- a/zerver/tornado/handlers.py +++ b/zerver/tornado/handlers.py @@ -235,8 +235,8 @@ class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler): # Add to this new HttpRequest logging data from the processing of # the original request; we will need these for logging. request_notes.log_data = old_request_notes.log_data - if hasattr(request, "_rate_limit"): - request._rate_limit = old_request._rate_limit + if request_notes.rate_limit is not None: + request_notes.rate_limit = old_request_notes.rate_limit if hasattr(request, "_requestor_for_logs"): request._requestor_for_logs = old_request._requestor_for_logs request.user = old_request.user diff --git a/zproject/backends.py b/zproject/backends.py index 8451ff27b6..43c4bc6c1b 100644 --- a/zproject/backends.py +++ b/zproject/backends.py @@ -70,7 +70,7 @@ from zerver.lib.email_validation import email_allowed_for_realm, validate_email_ from zerver.lib.mobile_auth_otp import is_valid_otp from zerver.lib.rate_limiter import RateLimitedObject from zerver.lib.redis_utils import get_dict_from_redis, get_redis_client, put_dict_in_redis -from zerver.lib.request import JsonableError +from zerver.lib.request import JsonableError, get_request_notes from zerver.lib.subdomains import get_subdomain from zerver.lib.users import check_full_name, validate_user_custom_profile_field from zerver.models import ( @@ -259,12 +259,11 @@ def rate_limit_authentication_by_username(request: HttpRequest, username: str) - def auth_rate_limiting_already_applied(request: HttpRequest) -> bool: - if not hasattr(request, "_ratelimits_applied"): - return False + request_notes = get_request_notes(request) return any( isinstance(r.entity, RateLimitedAuthenticationByUsername) - for r in request._ratelimits_applied + for r in request_notes.ratelimits_applied )