rate_limiter: Handle multiple types of rate limiting in middleware.

As more types of rate limiting of requests are added, one request may
end up having various limits applied to it - and the middleware needs to
be able to handle that. We implement that through a set_response_headers
function, which sets the X-RateLimit-* headers in a sensible way based
on all the limits that were applied to the request.
This commit is contained in:
Mateusz Mandera 2019-12-28 20:23:18 +01:00 committed by Tim Abbott
parent 677764d9ca
commit a6a2d70320
2 changed files with 40 additions and 16 deletions

View File

@ -275,10 +275,11 @@ def rate_limit_request_by_entity(request: HttpRequest, entity: RateLimitedObject
entity_type = type(entity).__name__ entity_type = type(entity).__name__
if not hasattr(request, '_ratelimit'): if not hasattr(request, '_ratelimit'):
request._ratelimit = {} request._ratelimit = {}
request._ratelimit[entity_type] = {} request._ratelimit[entity_type] = RateLimitResult(
request._ratelimit[entity_type]['applied_limits'] = True entity=entity,
request._ratelimit[entity_type]['secs_to_freedom'] = time secs_to_freedom=time,
request._ratelimit[entity_type]['over_limit'] = ratelimited over_limit=ratelimited
)
# Abort this request if the user is over their rate limits # Abort this request if the user is over their rate limits
if ratelimited: if ratelimited:
# Pass information about what kind of entity got limited in the exception: # Pass information about what kind of entity got limited in the exception:
@ -286,5 +287,16 @@ def rate_limit_request_by_entity(request: HttpRequest, entity: RateLimitedObject
calls_remaining, time_reset = api_calls_left(entity) calls_remaining, time_reset = api_calls_left(entity)
request._ratelimit[entity_type]['remaining'] = calls_remaining request._ratelimit[entity_type].remaining = calls_remaining
request._ratelimit[entity_type]['secs_to_freedom'] = time_reset request._ratelimit[entity_type].secs_to_freedom = time_reset
class RateLimitResult:
def __init__(self, entity: RateLimitedObject, secs_to_freedom: float, over_limit: bool,
remaining: Optional[int]=None) -> None:
if over_limit:
assert not remaining
self.entity = entity
self.secs_to_freedom = secs_to_freedom
self.over_limit = over_limit
self.remaining = remaining

View File

@ -26,6 +26,7 @@ from zerver.lib.db import reset_queries
from zerver.lib.exceptions import ErrorCode, JsonableError, RateLimited from zerver.lib.exceptions import ErrorCode, JsonableError, RateLimited
from zerver.lib.html_to_text import get_content_description from zerver.lib.html_to_text import get_content_description
from zerver.lib.queue import queue_json_publish from zerver.lib.queue import queue_json_publish
from zerver.lib.rate_limiter import RateLimitResult, max_api_calls
from zerver.lib.response import json_error, json_response_from_error from zerver.lib.response import json_error, json_response_from_error
from zerver.lib.subdomains import get_subdomain from zerver.lib.subdomains import get_subdomain
from zerver.lib.utils import statsd from zerver.lib.utils import statsd
@ -322,20 +323,31 @@ def csrf_failure(request: HttpRequest, reason: str="") -> HttpResponse:
return html_csrf_failure(request, reason) return html_csrf_failure(request, reason)
class RateLimitMiddleware(MiddlewareMixin): class RateLimitMiddleware(MiddlewareMixin):
def set_response_headers(self, response: HttpResponse,
rate_limit_results: List[RateLimitResult]) -> None:
# The limit on the action that was requested is the minimum of the limits that get applied:
limit = min([max_api_calls(result.entity) for result in rate_limit_results])
response['X-RateLimit-Limit'] = str(limit)
# Same principle applies to remaining api calls:
if all(result.remaining for result in rate_limit_results):
remaining_api_calls = min([result.remaining for result in rate_limit_results])
response['X-RateLimit-Remaining'] = str(remaining_api_calls)
else:
response['X-RateLimit-Remaining'] = str(0)
# The full reset time is the maximum of the reset times for the limits that get applied:
reset_time = time.time() + max([result.secs_to_freedom for result in rate_limit_results])
response['X-RateLimit-Reset'] = str(int(reset_time))
def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse: def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
if not settings.RATE_LIMITING: if not settings.RATE_LIMITING:
return response return response
from zerver.lib.rate_limiter import max_api_calls, RateLimitedUser
# Add X-RateLimit-*** headers # Add X-RateLimit-*** headers
if hasattr(request, '_ratelimit'): if hasattr(request, '_ratelimit'):
# Right now, the only kind of limiting requests is user-based. rate_limit_results = list(request._ratelimit.values())
ratelimit_user_results = request._ratelimit['RateLimitedUser'] self.set_response_headers(response, rate_limit_results)
entity = RateLimitedUser(request.user)
response['X-RateLimit-Limit'] = str(max_api_calls(entity))
response['X-RateLimit-Reset'] = str(int(time.time() + ratelimit_user_results['secs_to_freedom']))
if 'remaining' in ratelimit_user_results:
response['X-RateLimit-Remaining'] = str(ratelimit_user_results['remaining'])
return response return response
# TODO: When we have Django stubs, we should be able to fix the # TODO: When we have Django stubs, we should be able to fix the
@ -348,10 +360,10 @@ class RateLimitMiddleware(MiddlewareMixin):
entity_type = str(exception) # entity type is passed to RateLimited when raising entity_type = str(exception) # entity type is passed to RateLimited when raising
resp = json_error( resp = json_error(
_("API usage exceeded rate limit"), _("API usage exceeded rate limit"),
data={'retry-after': request._ratelimit[entity_type]['secs_to_freedom']}, data={'retry-after': request._ratelimit[entity_type].secs_to_freedom},
status=429 status=429
) )
resp['Retry-After'] = request._ratelimit[entity_type]['secs_to_freedom'] resp['Retry-After'] = request._ratelimit[entity_type].secs_to_freedom
return resp return resp
return None return None