mirror of https://github.com/zulip/zulip.git
rate_limiter: Handle multiple types of rate limiting in middleware.
As more types of rate limiting of requests are added, one request may end up having various limits applied to it - and the middleware needs to be able to handle that. We implement that through a set_response_headers function, which sets the X-RateLimit-* headers in a sensible way based on all the limits that were applied to the request.
This commit is contained in:
parent
677764d9ca
commit
a6a2d70320
|
@ -275,10 +275,11 @@ def rate_limit_request_by_entity(request: HttpRequest, entity: RateLimitedObject
|
|||
entity_type = type(entity).__name__
|
||||
if not hasattr(request, '_ratelimit'):
|
||||
request._ratelimit = {}
|
||||
request._ratelimit[entity_type] = {}
|
||||
request._ratelimit[entity_type]['applied_limits'] = True
|
||||
request._ratelimit[entity_type]['secs_to_freedom'] = time
|
||||
request._ratelimit[entity_type]['over_limit'] = ratelimited
|
||||
request._ratelimit[entity_type] = RateLimitResult(
|
||||
entity=entity,
|
||||
secs_to_freedom=time,
|
||||
over_limit=ratelimited
|
||||
)
|
||||
# Abort this request if the user is over their rate limits
|
||||
if ratelimited:
|
||||
# Pass information about what kind of entity got limited in the exception:
|
||||
|
@ -286,5 +287,16 @@ def rate_limit_request_by_entity(request: HttpRequest, entity: RateLimitedObject
|
|||
|
||||
calls_remaining, time_reset = api_calls_left(entity)
|
||||
|
||||
request._ratelimit[entity_type]['remaining'] = calls_remaining
|
||||
request._ratelimit[entity_type]['secs_to_freedom'] = time_reset
|
||||
request._ratelimit[entity_type].remaining = calls_remaining
|
||||
request._ratelimit[entity_type].secs_to_freedom = time_reset
|
||||
|
||||
class RateLimitResult:
|
||||
def __init__(self, entity: RateLimitedObject, secs_to_freedom: float, over_limit: bool,
|
||||
remaining: Optional[int]=None) -> None:
|
||||
if over_limit:
|
||||
assert not remaining
|
||||
|
||||
self.entity = entity
|
||||
self.secs_to_freedom = secs_to_freedom
|
||||
self.over_limit = over_limit
|
||||
self.remaining = remaining
|
||||
|
|
|
@ -26,6 +26,7 @@ from zerver.lib.db import reset_queries
|
|||
from zerver.lib.exceptions import ErrorCode, JsonableError, RateLimited
|
||||
from zerver.lib.html_to_text import get_content_description
|
||||
from zerver.lib.queue import queue_json_publish
|
||||
from zerver.lib.rate_limiter import RateLimitResult, max_api_calls
|
||||
from zerver.lib.response import json_error, json_response_from_error
|
||||
from zerver.lib.subdomains import get_subdomain
|
||||
from zerver.lib.utils import statsd
|
||||
|
@ -322,20 +323,31 @@ def csrf_failure(request: HttpRequest, reason: str="") -> HttpResponse:
|
|||
return html_csrf_failure(request, reason)
|
||||
|
||||
class RateLimitMiddleware(MiddlewareMixin):
|
||||
def set_response_headers(self, response: HttpResponse,
|
||||
rate_limit_results: List[RateLimitResult]) -> None:
|
||||
# The limit on the action that was requested is the minimum of the limits that get applied:
|
||||
limit = min([max_api_calls(result.entity) for result in rate_limit_results])
|
||||
response['X-RateLimit-Limit'] = str(limit)
|
||||
# Same principle applies to remaining api calls:
|
||||
if all(result.remaining for result in rate_limit_results):
|
||||
remaining_api_calls = min([result.remaining for result in rate_limit_results])
|
||||
response['X-RateLimit-Remaining'] = str(remaining_api_calls)
|
||||
else:
|
||||
response['X-RateLimit-Remaining'] = str(0)
|
||||
|
||||
# The full reset time is the maximum of the reset times for the limits that get applied:
|
||||
reset_time = time.time() + max([result.secs_to_freedom for result in rate_limit_results])
|
||||
response['X-RateLimit-Reset'] = str(int(reset_time))
|
||||
|
||||
def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
|
||||
if not settings.RATE_LIMITING:
|
||||
return response
|
||||
|
||||
from zerver.lib.rate_limiter import max_api_calls, RateLimitedUser
|
||||
# Add X-RateLimit-*** headers
|
||||
if hasattr(request, '_ratelimit'):
|
||||
# Right now, the only kind of limiting requests is user-based.
|
||||
ratelimit_user_results = request._ratelimit['RateLimitedUser']
|
||||
entity = RateLimitedUser(request.user)
|
||||
response['X-RateLimit-Limit'] = str(max_api_calls(entity))
|
||||
response['X-RateLimit-Reset'] = str(int(time.time() + ratelimit_user_results['secs_to_freedom']))
|
||||
if 'remaining' in ratelimit_user_results:
|
||||
response['X-RateLimit-Remaining'] = str(ratelimit_user_results['remaining'])
|
||||
rate_limit_results = list(request._ratelimit.values())
|
||||
self.set_response_headers(response, rate_limit_results)
|
||||
|
||||
return response
|
||||
|
||||
# TODO: When we have Django stubs, we should be able to fix the
|
||||
|
@ -348,10 +360,10 @@ class RateLimitMiddleware(MiddlewareMixin):
|
|||
entity_type = str(exception) # entity type is passed to RateLimited when raising
|
||||
resp = json_error(
|
||||
_("API usage exceeded rate limit"),
|
||||
data={'retry-after': request._ratelimit[entity_type]['secs_to_freedom']},
|
||||
data={'retry-after': request._ratelimit[entity_type].secs_to_freedom},
|
||||
status=429
|
||||
)
|
||||
resp['Retry-After'] = request._ratelimit[entity_type]['secs_to_freedom']
|
||||
resp['Retry-After'] = request._ratelimit[entity_type].secs_to_freedom
|
||||
return resp
|
||||
return None
|
||||
|
||||
|
|
Loading…
Reference in New Issue