mirror of https://github.com/zulip/zulip.git
request: Refactor remote_server into RequestNotes.
This eliminates the possibility of having `request.user` as `RemoteZulipServer` by refactoring it as an attribute of `RequestNotes`. So we can effectively narrow the type of `request.user` by testing `user.is_authenticated` in most cases (except that of `SCIMClient`) in code paths that require access to `.format_requestor_for_logs` where we previously expect either `UserProfile` or `RemoteZulipServer` backed by the implied polymorphism. Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
parent
3bc78d2473
commit
b02779c005
|
@ -263,8 +263,7 @@ def validate_api_key(
|
||||||
|
|
||||||
if get_subdomain(request) != Realm.SUBDOMAIN_FOR_ROOT_DOMAIN:
|
if get_subdomain(request) != Realm.SUBDOMAIN_FOR_ROOT_DOMAIN:
|
||||||
raise JsonableError(_("Invalid subdomain for push notifications bouncer"))
|
raise JsonableError(_("Invalid subdomain for push notifications bouncer"))
|
||||||
request.user = remote_server
|
RequestNotes.get_notes(request).remote_server = remote_server
|
||||||
# Skip updating UserActivity, since remote_server isn't actually a UserProfile object.
|
|
||||||
process_client(request)
|
process_client(request)
|
||||||
return remote_server
|
return remote_server
|
||||||
|
|
||||||
|
@ -996,12 +995,13 @@ def rate_limit() -> Callable[[ViewFuncT], ViewFuncT]:
|
||||||
return func(request, *args, **kwargs)
|
return func(request, *args, **kwargs)
|
||||||
|
|
||||||
user = request.user
|
user = request.user
|
||||||
|
remote_server = RequestNotes.get_notes(request).remote_server
|
||||||
|
|
||||||
if isinstance(user, AnonymousUser):
|
if settings.ZILENCER_ENABLED and remote_server is not None:
|
||||||
|
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
|
||||||
|
elif not user.is_authenticated:
|
||||||
rate_limit_request_by_ip(request, domain="api_by_ip")
|
rate_limit_request_by_ip(request, domain="api_by_ip")
|
||||||
return func(request, *args, **kwargs)
|
return func(request, *args, **kwargs)
|
||||||
elif settings.ZILENCER_ENABLED and isinstance(user, RemoteZulipServer):
|
|
||||||
rate_limit_remote_server(request, user, domain="api_by_remote_server")
|
|
||||||
else:
|
else:
|
||||||
assert isinstance(user, UserProfile)
|
assert isinstance(user, UserProfile)
|
||||||
rate_limit_user(request, user, domain="api_by_user")
|
rate_limit_user(request, user, domain="api_by_user")
|
||||||
|
|
|
@ -21,6 +21,7 @@ from typing import (
|
||||||
)
|
)
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
|
from django.conf import settings
|
||||||
from django.core.exceptions import ValidationError
|
from django.core.exceptions import ValidationError
|
||||||
from django.http import HttpRequest, HttpResponse
|
from django.http import HttpRequest, HttpResponse
|
||||||
from django.utils.translation import gettext as _
|
from django.utils.translation import gettext as _
|
||||||
|
@ -32,6 +33,9 @@ from zerver.lib.types import Validator, ViewFuncT
|
||||||
from zerver.lib.validator import check_anything
|
from zerver.lib.validator import check_anything
|
||||||
from zerver.models import Client, Realm
|
from zerver.models import Client, Realm
|
||||||
|
|
||||||
|
if settings.ZILENCER_ENABLED:
|
||||||
|
from zilencer.models import RemoteZulipServer
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RequestNotes(BaseNotes[HttpRequest, "RequestNotes"]):
|
class RequestNotes(BaseNotes[HttpRequest, "RequestNotes"]):
|
||||||
|
@ -66,6 +70,7 @@ class RequestNotes(BaseNotes[HttpRequest, "RequestNotes"]):
|
||||||
tornado_handler_id: Optional[int] = None
|
tornado_handler_id: Optional[int] = None
|
||||||
processed_parameters: Set[str] = field(default_factory=set)
|
processed_parameters: Set[str] = field(default_factory=set)
|
||||||
ignored_parameters: Set[str] = field(default_factory=set)
|
ignored_parameters: Set[str] = field(default_factory=set)
|
||||||
|
remote_server: Optional["RemoteZulipServer"] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_notes(cls) -> "RequestNotes":
|
def init_notes(cls) -> "RequestNotes":
|
||||||
|
|
|
@ -299,7 +299,8 @@ class HostRequestMock(HttpRequest):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
post_data: Dict[str, Any] = {},
|
post_data: Dict[str, Any] = {},
|
||||||
user_profile: Union[UserProfile, RemoteZulipServer, None] = None,
|
user_profile: Union[UserProfile, None] = None,
|
||||||
|
remote_server: Optional[RemoteZulipServer] = None,
|
||||||
host: str = settings.EXTERNAL_HOST,
|
host: str = settings.EXTERNAL_HOST,
|
||||||
client_name: Optional[str] = None,
|
client_name: Optional[str] = None,
|
||||||
meta_data: Optional[Dict[str, Any]] = None,
|
meta_data: Optional[Dict[str, Any]] = None,
|
||||||
|
@ -335,6 +336,7 @@ class HostRequestMock(HttpRequest):
|
||||||
log_data={},
|
log_data={},
|
||||||
tornado_handler_id=None if tornado_handler is None else tornado_handler.handler_id,
|
tornado_handler_id=None if tornado_handler is None else tornado_handler.handler_id,
|
||||||
client=get_client(client_name) if client_name is not None else None,
|
client=get_client(client_name) if client_name is not None else None,
|
||||||
|
remote_server=remote_server,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -407,9 +407,9 @@ class LogRequests(MiddlewareMixin):
|
||||||
request_notes = RequestNotes.get_notes(request)
|
request_notes = RequestNotes.get_notes(request)
|
||||||
requestor_for_logs = request_notes.requestor_for_logs
|
requestor_for_logs = request_notes.requestor_for_logs
|
||||||
if requestor_for_logs is None:
|
if requestor_for_logs is None:
|
||||||
# Note that request.user is a Union[RemoteZulipServer, UserProfile, AnonymousUser],
|
if request_notes.remote_server is not None:
|
||||||
# if it is present.
|
requestor_for_logs = request_notes.remote_server.format_requestor_for_logs()
|
||||||
if hasattr(request.user, "format_requestor_for_logs"):
|
elif request.user.is_authenticated:
|
||||||
requestor_for_logs = request.user.format_requestor_for_logs()
|
requestor_for_logs = request.user.format_requestor_for_logs()
|
||||||
else:
|
else:
|
||||||
requestor_for_logs = "unauth@{}".format(get_subdomain(request) or "root")
|
requestor_for_logs = "unauth@{}".format(get_subdomain(request) or "root")
|
||||||
|
|
|
@ -726,7 +726,7 @@ class RateLimitTestCase(ZulipTestCase):
|
||||||
)
|
)
|
||||||
META = {"REMOTE_ADDR": "3.3.3.3"}
|
META = {"REMOTE_ADDR": "3.3.3.3"}
|
||||||
|
|
||||||
req = HostRequestMock(client_name="external", user_profile=server, meta_data=META)
|
req = HostRequestMock(client_name="external", remote_server=server, meta_data=META)
|
||||||
|
|
||||||
f = self.get_ratelimited_view()
|
f = self.get_ratelimited_view()
|
||||||
|
|
||||||
|
|
|
@ -251,7 +251,7 @@ class LogRequestsTest(ZulipTestCase):
|
||||||
|
|
||||||
def test_requestor_for_logs_as_remote_server(self) -> None:
|
def test_requestor_for_logs_as_remote_server(self) -> None:
|
||||||
remote_server = RemoteZulipServer()
|
remote_server = RemoteZulipServer()
|
||||||
request = HostRequestMock(user_profile=remote_server, meta_data=self.meta_data)
|
request = HostRequestMock(remote_server=remote_server, meta_data=self.meta_data)
|
||||||
RequestNotes.get_notes(request).log_data = None
|
RequestNotes.get_notes(request).log_data = None
|
||||||
|
|
||||||
with self.assertLogs("zulip.requests", level="INFO") as m:
|
with self.assertLogs("zulip.requests", level="INFO") as m:
|
||||||
|
|
Loading…
Reference in New Issue