request: Refactor remote_server into RequestNotes.

This eliminates the possibility of having `request.user` as
`RemoteZulipServer` by refactoring it as an attribute of `RequestNotes`.

So we can effectively narrow the type of `request.user` by testing
`user.is_authenticated` in most cases (except that of `SCIMClient`) in
code paths that require access to `.format_requestor_for_logs` where we
previously expect either `UserProfile` or `RemoteZulipServer` backed by
the implied polymorphism.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
Zixuan James Li 2022-06-12 15:33:20 -04:00 committed by Tim Abbott
parent 3bc78d2473
commit b02779c005
6 changed files with 18 additions and 11 deletions

View File

@ -263,8 +263,7 @@ def validate_api_key(
if get_subdomain(request) != Realm.SUBDOMAIN_FOR_ROOT_DOMAIN: if get_subdomain(request) != Realm.SUBDOMAIN_FOR_ROOT_DOMAIN:
raise JsonableError(_("Invalid subdomain for push notifications bouncer")) raise JsonableError(_("Invalid subdomain for push notifications bouncer"))
request.user = remote_server RequestNotes.get_notes(request).remote_server = remote_server
# Skip updating UserActivity, since remote_server isn't actually a UserProfile object.
process_client(request) process_client(request)
return remote_server return remote_server
@ -996,12 +995,13 @@ def rate_limit() -> Callable[[ViewFuncT], ViewFuncT]:
return func(request, *args, **kwargs) return func(request, *args, **kwargs)
user = request.user user = request.user
remote_server = RequestNotes.get_notes(request).remote_server
if isinstance(user, AnonymousUser): if settings.ZILENCER_ENABLED and remote_server is not None:
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
elif not user.is_authenticated:
rate_limit_request_by_ip(request, domain="api_by_ip") rate_limit_request_by_ip(request, domain="api_by_ip")
return func(request, *args, **kwargs) return func(request, *args, **kwargs)
elif settings.ZILENCER_ENABLED and isinstance(user, RemoteZulipServer):
rate_limit_remote_server(request, user, domain="api_by_remote_server")
else: else:
assert isinstance(user, UserProfile) assert isinstance(user, UserProfile)
rate_limit_user(request, user, domain="api_by_user") rate_limit_user(request, user, domain="api_by_user")

View File

@ -21,6 +21,7 @@ from typing import (
) )
import orjson import orjson
from django.conf import settings
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
@ -32,6 +33,9 @@ from zerver.lib.types import Validator, ViewFuncT
from zerver.lib.validator import check_anything from zerver.lib.validator import check_anything
from zerver.models import Client, Realm from zerver.models import Client, Realm
if settings.ZILENCER_ENABLED:
from zilencer.models import RemoteZulipServer
@dataclass @dataclass
class RequestNotes(BaseNotes[HttpRequest, "RequestNotes"]): class RequestNotes(BaseNotes[HttpRequest, "RequestNotes"]):
@ -66,6 +70,7 @@ class RequestNotes(BaseNotes[HttpRequest, "RequestNotes"]):
tornado_handler_id: Optional[int] = None tornado_handler_id: Optional[int] = None
processed_parameters: Set[str] = field(default_factory=set) processed_parameters: Set[str] = field(default_factory=set)
ignored_parameters: Set[str] = field(default_factory=set) ignored_parameters: Set[str] = field(default_factory=set)
remote_server: Optional["RemoteZulipServer"] = None
@classmethod @classmethod
def init_notes(cls) -> "RequestNotes": def init_notes(cls) -> "RequestNotes":

View File

@ -299,7 +299,8 @@ class HostRequestMock(HttpRequest):
def __init__( def __init__(
self, self,
post_data: Dict[str, Any] = {}, post_data: Dict[str, Any] = {},
user_profile: Union[UserProfile, RemoteZulipServer, None] = None, user_profile: Union[UserProfile, None] = None,
remote_server: Optional[RemoteZulipServer] = None,
host: str = settings.EXTERNAL_HOST, host: str = settings.EXTERNAL_HOST,
client_name: Optional[str] = None, client_name: Optional[str] = None,
meta_data: Optional[Dict[str, Any]] = None, meta_data: Optional[Dict[str, Any]] = None,
@ -335,6 +336,7 @@ class HostRequestMock(HttpRequest):
log_data={}, log_data={},
tornado_handler_id=None if tornado_handler is None else tornado_handler.handler_id, tornado_handler_id=None if tornado_handler is None else tornado_handler.handler_id,
client=get_client(client_name) if client_name is not None else None, client=get_client(client_name) if client_name is not None else None,
remote_server=remote_server,
), ),
) )

View File

@ -407,9 +407,9 @@ class LogRequests(MiddlewareMixin):
request_notes = RequestNotes.get_notes(request) request_notes = RequestNotes.get_notes(request)
requestor_for_logs = request_notes.requestor_for_logs requestor_for_logs = request_notes.requestor_for_logs
if requestor_for_logs is None: if requestor_for_logs is None:
# Note that request.user is a Union[RemoteZulipServer, UserProfile, AnonymousUser], if request_notes.remote_server is not None:
# if it is present. requestor_for_logs = request_notes.remote_server.format_requestor_for_logs()
if hasattr(request.user, "format_requestor_for_logs"): elif request.user.is_authenticated:
requestor_for_logs = request.user.format_requestor_for_logs() requestor_for_logs = request.user.format_requestor_for_logs()
else: else:
requestor_for_logs = "unauth@{}".format(get_subdomain(request) or "root") requestor_for_logs = "unauth@{}".format(get_subdomain(request) or "root")

View File

@ -726,7 +726,7 @@ class RateLimitTestCase(ZulipTestCase):
) )
META = {"REMOTE_ADDR": "3.3.3.3"} META = {"REMOTE_ADDR": "3.3.3.3"}
req = HostRequestMock(client_name="external", user_profile=server, meta_data=META) req = HostRequestMock(client_name="external", remote_server=server, meta_data=META)
f = self.get_ratelimited_view() f = self.get_ratelimited_view()

View File

@ -251,7 +251,7 @@ class LogRequestsTest(ZulipTestCase):
def test_requestor_for_logs_as_remote_server(self) -> None: def test_requestor_for_logs_as_remote_server(self) -> None:
remote_server = RemoteZulipServer() remote_server = RemoteZulipServer()
request = HostRequestMock(user_profile=remote_server, meta_data=self.meta_data) request = HostRequestMock(remote_server=remote_server, meta_data=self.meta_data)
RequestNotes.get_notes(request).log_data = None RequestNotes.get_notes(request).log_data = None
with self.assertLogs("zulip.requests", level="INFO") as m: with self.assertLogs("zulip.requests", level="INFO") as m: