mirror of https://github.com/zulip/zulip.git
rate_limiter: Handle multiple types of rate limiting in middleware.
As more types of rate limiting of requests are added, one request may end up having various limits applied to it - and the middleware needs to be able to handle that. We implement that through a set_response_headers function, which sets the X-RateLimit-* headers in a sensible way based on all the limits that were applied to the request.
This commit is contained in:
parent
677764d9ca
commit
a6a2d70320
|
@ -275,10 +275,11 @@ def rate_limit_request_by_entity(request: HttpRequest, entity: RateLimitedObject
|
||||||
entity_type = type(entity).__name__
|
entity_type = type(entity).__name__
|
||||||
if not hasattr(request, '_ratelimit'):
|
if not hasattr(request, '_ratelimit'):
|
||||||
request._ratelimit = {}
|
request._ratelimit = {}
|
||||||
request._ratelimit[entity_type] = {}
|
request._ratelimit[entity_type] = RateLimitResult(
|
||||||
request._ratelimit[entity_type]['applied_limits'] = True
|
entity=entity,
|
||||||
request._ratelimit[entity_type]['secs_to_freedom'] = time
|
secs_to_freedom=time,
|
||||||
request._ratelimit[entity_type]['over_limit'] = ratelimited
|
over_limit=ratelimited
|
||||||
|
)
|
||||||
# Abort this request if the user is over their rate limits
|
# Abort this request if the user is over their rate limits
|
||||||
if ratelimited:
|
if ratelimited:
|
||||||
# Pass information about what kind of entity got limited in the exception:
|
# Pass information about what kind of entity got limited in the exception:
|
||||||
|
@ -286,5 +287,16 @@ def rate_limit_request_by_entity(request: HttpRequest, entity: RateLimitedObject
|
||||||
|
|
||||||
calls_remaining, time_reset = api_calls_left(entity)
|
calls_remaining, time_reset = api_calls_left(entity)
|
||||||
|
|
||||||
request._ratelimit[entity_type]['remaining'] = calls_remaining
|
request._ratelimit[entity_type].remaining = calls_remaining
|
||||||
request._ratelimit[entity_type]['secs_to_freedom'] = time_reset
|
request._ratelimit[entity_type].secs_to_freedom = time_reset
|
||||||
|
|
||||||
|
class RateLimitResult:
|
||||||
|
def __init__(self, entity: RateLimitedObject, secs_to_freedom: float, over_limit: bool,
|
||||||
|
remaining: Optional[int]=None) -> None:
|
||||||
|
if over_limit:
|
||||||
|
assert not remaining
|
||||||
|
|
||||||
|
self.entity = entity
|
||||||
|
self.secs_to_freedom = secs_to_freedom
|
||||||
|
self.over_limit = over_limit
|
||||||
|
self.remaining = remaining
|
||||||
|
|
|
@ -26,6 +26,7 @@ from zerver.lib.db import reset_queries
|
||||||
from zerver.lib.exceptions import ErrorCode, JsonableError, RateLimited
|
from zerver.lib.exceptions import ErrorCode, JsonableError, RateLimited
|
||||||
from zerver.lib.html_to_text import get_content_description
|
from zerver.lib.html_to_text import get_content_description
|
||||||
from zerver.lib.queue import queue_json_publish
|
from zerver.lib.queue import queue_json_publish
|
||||||
|
from zerver.lib.rate_limiter import RateLimitResult, max_api_calls
|
||||||
from zerver.lib.response import json_error, json_response_from_error
|
from zerver.lib.response import json_error, json_response_from_error
|
||||||
from zerver.lib.subdomains import get_subdomain
|
from zerver.lib.subdomains import get_subdomain
|
||||||
from zerver.lib.utils import statsd
|
from zerver.lib.utils import statsd
|
||||||
|
@ -322,20 +323,31 @@ def csrf_failure(request: HttpRequest, reason: str="") -> HttpResponse:
|
||||||
return html_csrf_failure(request, reason)
|
return html_csrf_failure(request, reason)
|
||||||
|
|
||||||
class RateLimitMiddleware(MiddlewareMixin):
|
class RateLimitMiddleware(MiddlewareMixin):
|
||||||
|
def set_response_headers(self, response: HttpResponse,
|
||||||
|
rate_limit_results: List[RateLimitResult]) -> None:
|
||||||
|
# The limit on the action that was requested is the minimum of the limits that get applied:
|
||||||
|
limit = min([max_api_calls(result.entity) for result in rate_limit_results])
|
||||||
|
response['X-RateLimit-Limit'] = str(limit)
|
||||||
|
# Same principle applies to remaining api calls:
|
||||||
|
if all(result.remaining for result in rate_limit_results):
|
||||||
|
remaining_api_calls = min([result.remaining for result in rate_limit_results])
|
||||||
|
response['X-RateLimit-Remaining'] = str(remaining_api_calls)
|
||||||
|
else:
|
||||||
|
response['X-RateLimit-Remaining'] = str(0)
|
||||||
|
|
||||||
|
# The full reset time is the maximum of the reset times for the limits that get applied:
|
||||||
|
reset_time = time.time() + max([result.secs_to_freedom for result in rate_limit_results])
|
||||||
|
response['X-RateLimit-Reset'] = str(int(reset_time))
|
||||||
|
|
||||||
def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
|
def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
|
||||||
if not settings.RATE_LIMITING:
|
if not settings.RATE_LIMITING:
|
||||||
return response
|
return response
|
||||||
|
|
||||||
from zerver.lib.rate_limiter import max_api_calls, RateLimitedUser
|
|
||||||
# Add X-RateLimit-*** headers
|
# Add X-RateLimit-*** headers
|
||||||
if hasattr(request, '_ratelimit'):
|
if hasattr(request, '_ratelimit'):
|
||||||
# Right now, the only kind of limiting requests is user-based.
|
rate_limit_results = list(request._ratelimit.values())
|
||||||
ratelimit_user_results = request._ratelimit['RateLimitedUser']
|
self.set_response_headers(response, rate_limit_results)
|
||||||
entity = RateLimitedUser(request.user)
|
|
||||||
response['X-RateLimit-Limit'] = str(max_api_calls(entity))
|
|
||||||
response['X-RateLimit-Reset'] = str(int(time.time() + ratelimit_user_results['secs_to_freedom']))
|
|
||||||
if 'remaining' in ratelimit_user_results:
|
|
||||||
response['X-RateLimit-Remaining'] = str(ratelimit_user_results['remaining'])
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
# TODO: When we have Django stubs, we should be able to fix the
|
# TODO: When we have Django stubs, we should be able to fix the
|
||||||
|
@ -348,10 +360,10 @@ class RateLimitMiddleware(MiddlewareMixin):
|
||||||
entity_type = str(exception) # entity type is passed to RateLimited when raising
|
entity_type = str(exception) # entity type is passed to RateLimited when raising
|
||||||
resp = json_error(
|
resp = json_error(
|
||||||
_("API usage exceeded rate limit"),
|
_("API usage exceeded rate limit"),
|
||||||
data={'retry-after': request._ratelimit[entity_type]['secs_to_freedom']},
|
data={'retry-after': request._ratelimit[entity_type].secs_to_freedom},
|
||||||
status=429
|
status=429
|
||||||
)
|
)
|
||||||
resp['Retry-After'] = request._ratelimit[entity_type]['secs_to_freedom']
|
resp['Retry-After'] = request._ratelimit[entity_type].secs_to_freedom
|
||||||
return resp
|
return resp
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue