request: Refactor remote_server into RequestNotes.

This eliminates the possibility of having `request.user` as
`RemoteZulipServer` by refactoring it as an attribute of `RequestNotes`.

So we can effectively narrow the type of `request.user` by testing
`user.is_authenticated` in most cases (except that of `SCIMClient`) in
code paths that require access to `.format_requestor_for_logs` where we
previously expect either `UserProfile` or `RemoteZulipServer` backed by
the implied polymorphism.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
Zixuan James Li 2022-06-12 15:33:20 -04:00 committed by Tim Abbott
parent 3bc78d2473
commit b02779c005
6 changed files with 18 additions and 11 deletions

View File

@ -263,8 +263,7 @@ def validate_api_key(
if get_subdomain(request) != Realm.SUBDOMAIN_FOR_ROOT_DOMAIN:
raise JsonableError(_("Invalid subdomain for push notifications bouncer"))
request.user = remote_server
# Skip updating UserActivity, since remote_server isn't actually a UserProfile object.
RequestNotes.get_notes(request).remote_server = remote_server
process_client(request)
return remote_server
@ -996,12 +995,13 @@ def rate_limit() -> Callable[[ViewFuncT], ViewFuncT]:
return func(request, *args, **kwargs)
user = request.user
remote_server = RequestNotes.get_notes(request).remote_server
if isinstance(user, AnonymousUser):
if settings.ZILENCER_ENABLED and remote_server is not None:
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
elif not user.is_authenticated:
rate_limit_request_by_ip(request, domain="api_by_ip")
return func(request, *args, **kwargs)
elif settings.ZILENCER_ENABLED and isinstance(user, RemoteZulipServer):
rate_limit_remote_server(request, user, domain="api_by_remote_server")
else:
assert isinstance(user, UserProfile)
rate_limit_user(request, user, domain="api_by_user")

View File

@ -21,6 +21,7 @@ from typing import (
)
import orjson
from django.conf import settings
from django.core.exceptions import ValidationError
from django.http import HttpRequest, HttpResponse
from django.utils.translation import gettext as _
@ -32,6 +33,9 @@ from zerver.lib.types import Validator, ViewFuncT
from zerver.lib.validator import check_anything
from zerver.models import Client, Realm
if settings.ZILENCER_ENABLED:
from zilencer.models import RemoteZulipServer
@dataclass
class RequestNotes(BaseNotes[HttpRequest, "RequestNotes"]):
@ -66,6 +70,7 @@ class RequestNotes(BaseNotes[HttpRequest, "RequestNotes"]):
tornado_handler_id: Optional[int] = None
processed_parameters: Set[str] = field(default_factory=set)
ignored_parameters: Set[str] = field(default_factory=set)
remote_server: Optional["RemoteZulipServer"] = None
@classmethod
def init_notes(cls) -> "RequestNotes":

View File

@ -299,7 +299,8 @@ class HostRequestMock(HttpRequest):
def __init__(
self,
post_data: Dict[str, Any] = {},
user_profile: Union[UserProfile, RemoteZulipServer, None] = None,
user_profile: Union[UserProfile, None] = None,
remote_server: Optional[RemoteZulipServer] = None,
host: str = settings.EXTERNAL_HOST,
client_name: Optional[str] = None,
meta_data: Optional[Dict[str, Any]] = None,
@ -335,6 +336,7 @@ class HostRequestMock(HttpRequest):
log_data={},
tornado_handler_id=None if tornado_handler is None else tornado_handler.handler_id,
client=get_client(client_name) if client_name is not None else None,
remote_server=remote_server,
),
)

View File

@ -407,9 +407,9 @@ class LogRequests(MiddlewareMixin):
request_notes = RequestNotes.get_notes(request)
requestor_for_logs = request_notes.requestor_for_logs
if requestor_for_logs is None:
# Note that request.user is a Union[RemoteZulipServer, UserProfile, AnonymousUser],
# if it is present.
if hasattr(request.user, "format_requestor_for_logs"):
if request_notes.remote_server is not None:
requestor_for_logs = request_notes.remote_server.format_requestor_for_logs()
elif request.user.is_authenticated:
requestor_for_logs = request.user.format_requestor_for_logs()
else:
requestor_for_logs = "unauth@{}".format(get_subdomain(request) or "root")

View File

@ -726,7 +726,7 @@ class RateLimitTestCase(ZulipTestCase):
)
META = {"REMOTE_ADDR": "3.3.3.3"}
req = HostRequestMock(client_name="external", user_profile=server, meta_data=META)
req = HostRequestMock(client_name="external", remote_server=server, meta_data=META)
f = self.get_ratelimited_view()

View File

@ -251,7 +251,7 @@ class LogRequestsTest(ZulipTestCase):
def test_requestor_for_logs_as_remote_server(self) -> None:
remote_server = RemoteZulipServer()
request = HostRequestMock(user_profile=remote_server, meta_data=self.meta_data)
request = HostRequestMock(remote_server=remote_server, meta_data=self.meta_data)
RequestNotes.get_notes(request).log_data = None
with self.assertLogs("zulip.requests", level="INFO") as m: