exceptions: Make RateLimited into a subclass of JsonableError.

This simplifies the code, as it allows using the mechanism of converting
JsonableErrors into a response instead of having separate, but
ultimately similar, logic in RateLimitMiddleware.
We don't touch tests here because "rate limited" error responses are
already verified in test_external.py.
This commit is contained in:
Mateusz Mandera 2020-11-27 16:33:01 +01:00 committed by Alex Vandiver
parent 92ce2d0e31
commit 43a0c60e96
6 changed files with 45 additions and 25 deletions

View File

@ -349,7 +349,8 @@ class OurAuthenticationForm(AuthenticationForm):
self.user_cache = authenticate(request=self.request, username=username, password=password,
realm=realm, return_data=return_data)
except RateLimited as e:
secs_to_freedom = int(float(str(e)))
assert e.secs_to_freedom is not None
secs_to_freedom = int(e.secs_to_freedom)
raise ValidationError(AUTHENTICATION_RATE_LIMITED_ERROR.format(secs_to_freedom))
if return_data.get("inactive_realm"):

View File

@ -47,6 +47,7 @@ class ErrorCode(AbstractEnum):
INVALID_ZOOM_TOKEN = ()
UNAUTHENTICATED_USER = ()
NONEXISTENT_SUBDOMAIN = ()
RATE_LIMIT_HIT = ()
class JsonableError(Exception):
'''A standardized error format we can turn into a nice JSON HTTP response.
@ -111,6 +112,10 @@ class JsonableError(Exception):
# at construction time.
return '{_msg}'
@property
def extra_headers(self) -> Dict[str, Any]:
return {}
#
# Infrastructure -- not intended to be overridden in subclasses.
#
@ -179,9 +184,31 @@ class InvalidMarkdownIncludeStatement(JsonableError):
def msg_format() -> str:
return _("Invalid Markdown include statement: {include_statement}")
class RateLimited(Exception):
def __init__(self, msg: str="") -> None:
super().__init__(msg)
class RateLimited(JsonableError):
code = ErrorCode.RATE_LIMIT_HIT
http_status_code = 429
def __init__(self, secs_to_freedom: Optional[float]=None) -> None:
self.secs_to_freedom = secs_to_freedom
@staticmethod
def msg_format() -> str:
return _("API usage exceeded rate limit")
@property
def extra_headers(self) -> Dict[str, Any]:
extra_headers_dict = super().extra_headers
if self.secs_to_freedom is not None:
extra_headers_dict["Retry-After"] = self.secs_to_freedom
return extra_headers_dict
@property
def data(self) -> Dict[str, Any]:
data_dict = super().data
data_dict['retry-after'] = self.secs_to_freedom
return data_dict
class InvalidJSONError(JsonableError):
code = ErrorCode.INVALID_JSON

View File

@ -53,7 +53,7 @@ class RateLimitedObject(ABC):
# Abort this request if the user is over their rate limits
if ratelimited:
# Pass information about what kind of entity got limited in the exception:
raise RateLimited(str(time))
raise RateLimited(time)
calls_remaining, seconds_until_reset = self.api_calls_left()

View File

@ -66,10 +66,15 @@ def json_response_from_error(exception: JsonableError) -> HttpResponse:
middleware takes care of transforming it into a response by
calling this function.
'''
return json_response('error',
msg=exception.msg,
data=exception.data,
status=exception.http_status_code)
response = json_response('error',
msg=exception.msg,
data=exception.data,
status=exception.http_status_code)
for header, value in exception.extra_headers.items():
response[header] = value
return response
def json_error(msg: str, data: Mapping[str, Any]={}, status: int=400) -> HttpResponse:
return json_response(res_type="error", msg=msg, data=data, status=status)

View File

@ -23,7 +23,7 @@ from sentry_sdk.integrations.logging import ignore_logger
from zerver.lib.cache import get_remote_cache_requests, get_remote_cache_time
from zerver.lib.db import reset_queries
from zerver.lib.debug import maybe_tracemalloc_listen
from zerver.lib.exceptions import ErrorCode, JsonableError, MissingAuthenticationError, RateLimited
from zerver.lib.exceptions import ErrorCode, JsonableError, MissingAuthenticationError
from zerver.lib.html_to_text import get_content_description
from zerver.lib.markdown import get_markdown_requests, get_markdown_time
from zerver.lib.rate_limiter import RateLimitResult
@ -408,20 +408,6 @@ class RateLimitMiddleware(MiddlewareMixin):
return response
def process_exception(self, request: HttpRequest,
exception: Exception) -> Optional[HttpResponse]:
if isinstance(exception, RateLimited):
# secs_to_freedom is passed to RateLimited when raising
secs_to_freedom = float(str(exception))
resp = json_error(
_("API usage exceeded rate limit"),
data={'retry-after': secs_to_freedom},
status=429,
)
resp['Retry-After'] = secs_to_freedom
return resp
return None
class FlushDisplayRecipientCache(MiddlewareMixin):
def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
# We flush the per-request caches after every request, so they

View File

@ -98,7 +98,8 @@ def json_change_settings(request: HttpRequest, user_profile: UserProfile,
realm=user_profile.realm, return_data=return_data):
return json_error(_("Wrong password!"))
except RateLimited as e:
secs_to_freedom = int(float(str(e)))
assert e.secs_to_freedom is not None
secs_to_freedom = int(e.secs_to_freedom)
return json_error(
_("You're making too many attempts! Try again in {} seconds.").format(secs_to_freedom),
)