From a6a2d70320f8f1a3b90058c5f80be8d6e5579241 Mon Sep 17 00:00:00 2001 From: Mateusz Mandera Date: Sat, 28 Dec 2019 20:23:18 +0100 Subject: [PATCH] rate_limiter: Handle multiple types of rate limiting in middleware. As more types of rate limiting of requests are added, one request may end up having various limits applied to it - and the middleware needs to be able to handle that. We implement that through a set_response_headers function, which sets the X-RateLimit-* headers in a sensible way based on all the limits that were applied to the request. --- zerver/lib/rate_limiter.py | 24 ++++++++++++++++++------ zerver/middleware.py | 32 ++++++++++++++++++++++---------- 2 files changed, 40 insertions(+), 16 deletions(-) diff --git a/zerver/lib/rate_limiter.py b/zerver/lib/rate_limiter.py index 44847eacf1..5ac64bbdcb 100644 --- a/zerver/lib/rate_limiter.py +++ b/zerver/lib/rate_limiter.py @@ -275,10 +275,11 @@ def rate_limit_request_by_entity(request: HttpRequest, entity: RateLimitedObject entity_type = type(entity).__name__ if not hasattr(request, '_ratelimit'): request._ratelimit = {} - request._ratelimit[entity_type] = {} - request._ratelimit[entity_type]['applied_limits'] = True - request._ratelimit[entity_type]['secs_to_freedom'] = time - request._ratelimit[entity_type]['over_limit'] = ratelimited + request._ratelimit[entity_type] = RateLimitResult( + entity=entity, + secs_to_freedom=time, + over_limit=ratelimited + ) # Abort this request if the user is over their rate limits if ratelimited: # Pass information about what kind of entity got limited in the exception: @@ -286,5 +287,16 @@ def rate_limit_request_by_entity(request: HttpRequest, entity: RateLimitedObject calls_remaining, time_reset = api_calls_left(entity) - request._ratelimit[entity_type]['remaining'] = calls_remaining - request._ratelimit[entity_type]['secs_to_freedom'] = time_reset + request._ratelimit[entity_type].remaining = calls_remaining + request._ratelimit[entity_type].secs_to_freedom = time_reset + +class RateLimitResult: + def __init__(self, entity: RateLimitedObject, secs_to_freedom: float, over_limit: bool, + remaining: Optional[int]=None) -> None: + if over_limit: + assert not remaining + + self.entity = entity + self.secs_to_freedom = secs_to_freedom + self.over_limit = over_limit + self.remaining = remaining diff --git a/zerver/middleware.py b/zerver/middleware.py index b40c338860..704622097e 100644 --- a/zerver/middleware.py +++ b/zerver/middleware.py @@ -26,6 +26,7 @@ from zerver.lib.db import reset_queries from zerver.lib.exceptions import ErrorCode, JsonableError, RateLimited from zerver.lib.html_to_text import get_content_description from zerver.lib.queue import queue_json_publish +from zerver.lib.rate_limiter import RateLimitResult, max_api_calls from zerver.lib.response import json_error, json_response_from_error from zerver.lib.subdomains import get_subdomain from zerver.lib.utils import statsd @@ -322,20 +323,31 @@ def csrf_failure(request: HttpRequest, reason: str="") -> HttpResponse: return html_csrf_failure(request, reason) class RateLimitMiddleware(MiddlewareMixin): + def set_response_headers(self, response: HttpResponse, + rate_limit_results: List[RateLimitResult]) -> None: + # The limit on the action that was requested is the minimum of the limits that get applied: + limit = min([max_api_calls(result.entity) for result in rate_limit_results]) + response['X-RateLimit-Limit'] = str(limit) + # Same principle applies to remaining api calls: + if all(result.remaining for result in rate_limit_results): + remaining_api_calls = min([result.remaining for result in rate_limit_results]) + response['X-RateLimit-Remaining'] = str(remaining_api_calls) + else: + response['X-RateLimit-Remaining'] = str(0) + + # The full reset time is the maximum of the reset times for the limits that get applied: + reset_time = time.time() + max([result.secs_to_freedom for result in rate_limit_results]) + response['X-RateLimit-Reset'] = str(int(reset_time)) + def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse: if not settings.RATE_LIMITING: return response - from zerver.lib.rate_limiter import max_api_calls, RateLimitedUser # Add X-RateLimit-*** headers if hasattr(request, '_ratelimit'): - # Right now, the only kind of limiting requests is user-based. - ratelimit_user_results = request._ratelimit['RateLimitedUser'] - entity = RateLimitedUser(request.user) - response['X-RateLimit-Limit'] = str(max_api_calls(entity)) - response['X-RateLimit-Reset'] = str(int(time.time() + ratelimit_user_results['secs_to_freedom'])) - if 'remaining' in ratelimit_user_results: - response['X-RateLimit-Remaining'] = str(ratelimit_user_results['remaining']) + rate_limit_results = list(request._ratelimit.values()) + self.set_response_headers(response, rate_limit_results) + return response # TODO: When we have Django stubs, we should be able to fix the @@ -348,10 +360,10 @@ class RateLimitMiddleware(MiddlewareMixin): entity_type = str(exception) # entity type is passed to RateLimited when raising resp = json_error( _("API usage exceeded rate limit"), - data={'retry-after': request._ratelimit[entity_type]['secs_to_freedom']}, + data={'retry-after': request._ratelimit[entity_type].secs_to_freedom}, status=429 ) - resp['Retry-After'] = request._ratelimit[entity_type]['secs_to_freedom'] + resp['Retry-After'] = request._ratelimit[entity_type].secs_to_freedom return resp return None