request: Refactor to record rate limit data using ZulipRequestNotes.

We will no longer use the HttpRequest to store the rate limit data.
Using ZulipRequestNotes, we can access rate_limit and ratelimits_applied
with type hints support. We also save the process of initializing
ratelimits_applied by giving it a default value.
This commit is contained in:
PIG208 2021-07-09 19:15:19 +08:00 committed by Tim Abbott
parent da6e5ddcae
commit 3f9a5e1e17
4 changed files with 15 additions and 14 deletions

View File

@ -42,11 +42,12 @@ class RateLimitedObject(ABC):
) )
def rate_limit_request(self, request: HttpRequest) -> None: def rate_limit_request(self, request: HttpRequest) -> None:
ratelimited, time = self.rate_limit() from zerver.lib.request import get_request_notes
if not hasattr(request, "_ratelimits_applied"): ratelimited, time = self.rate_limit()
request._ratelimits_applied = [] request_notes = get_request_notes(request)
request._ratelimits_applied.append(
request_notes.ratelimits_applied.append(
RateLimitResult( RateLimitResult(
entity=self, entity=self,
secs_to_freedom=time, secs_to_freedom=time,
@ -61,8 +62,8 @@ class RateLimitedObject(ABC):
calls_remaining, seconds_until_reset = self.api_calls_left() calls_remaining, seconds_until_reset = self.api_calls_left()
request._ratelimits_applied[-1].remaining = calls_remaining request_notes.ratelimits_applied[-1].remaining = calls_remaining
request._ratelimits_applied[-1].secs_to_freedom = seconds_until_reset request_notes.ratelimits_applied[-1].secs_to_freedom = seconds_until_reset
def block_access(self, seconds: int) -> None: def block_access(self, seconds: int) -> None:
"Manually blocks an entity for the desired number of seconds" "Manually blocks an entity for the desired number of seconds"

View File

@ -544,8 +544,9 @@ class RateLimitMiddleware(MiddlewareMixin):
return response return response
# Add X-RateLimit-*** headers # Add X-RateLimit-*** headers
if hasattr(request, "_ratelimits_applied"): ratelimits_applied = get_request_notes(request).ratelimits_applied
self.set_response_headers(response, request._ratelimits_applied) if len(ratelimits_applied) > 0:
self.set_response_headers(response, ratelimits_applied)
return response return response

View File

@ -235,8 +235,8 @@ class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler):
# Add to this new HttpRequest logging data from the processing of # Add to this new HttpRequest logging data from the processing of
# the original request; we will need these for logging. # the original request; we will need these for logging.
request_notes.log_data = old_request_notes.log_data request_notes.log_data = old_request_notes.log_data
if hasattr(request, "_rate_limit"): if request_notes.rate_limit is not None:
request._rate_limit = old_request._rate_limit request_notes.rate_limit = old_request_notes.rate_limit
if hasattr(request, "_requestor_for_logs"): if hasattr(request, "_requestor_for_logs"):
request._requestor_for_logs = old_request._requestor_for_logs request._requestor_for_logs = old_request._requestor_for_logs
request.user = old_request.user request.user = old_request.user

View File

@ -70,7 +70,7 @@ from zerver.lib.email_validation import email_allowed_for_realm, validate_email_
from zerver.lib.mobile_auth_otp import is_valid_otp from zerver.lib.mobile_auth_otp import is_valid_otp
from zerver.lib.rate_limiter import RateLimitedObject from zerver.lib.rate_limiter import RateLimitedObject
from zerver.lib.redis_utils import get_dict_from_redis, get_redis_client, put_dict_in_redis from zerver.lib.redis_utils import get_dict_from_redis, get_redis_client, put_dict_in_redis
from zerver.lib.request import JsonableError from zerver.lib.request import JsonableError, get_request_notes
from zerver.lib.subdomains import get_subdomain from zerver.lib.subdomains import get_subdomain
from zerver.lib.users import check_full_name, validate_user_custom_profile_field from zerver.lib.users import check_full_name, validate_user_custom_profile_field
from zerver.models import ( from zerver.models import (
@ -259,12 +259,11 @@ def rate_limit_authentication_by_username(request: HttpRequest, username: str) -
def auth_rate_limiting_already_applied(request: HttpRequest) -> bool: def auth_rate_limiting_already_applied(request: HttpRequest) -> bool:
if not hasattr(request, "_ratelimits_applied"): request_notes = get_request_notes(request)
return False
return any( return any(
isinstance(r.entity, RateLimitedAuthenticationByUsername) isinstance(r.entity, RateLimitedAuthenticationByUsername)
for r in request._ratelimits_applied for r in request_notes.ratelimits_applied
) )