From da6b0b1cc6c9730180a3966ace329c0826c5e10d Mon Sep 17 00:00:00 2001 From: Alex Vandiver Date: Thu, 8 Feb 2024 19:57:16 +0000 Subject: [PATCH] tornado: Add a web_reload_clients endpoint to notify web clients. --- zerver/tests/test_event_system.py | 13 +++++++++++++ zerver/tornado/application.py | 1 + zerver/tornado/event_queue.py | 3 ++- zerver/tornado/views.py | 29 ++++++++++++++++++++++++++++- zproject/urls.py | 9 ++++++++- 5 files changed, 52 insertions(+), 3 deletions(-) diff --git a/zerver/tests/test_event_system.py b/zerver/tests/test_event_system.py index 93f60a3b10..7412e39488 100644 --- a/zerver/tests/test_event_system.py +++ b/zerver/tests/test_event_system.py @@ -270,6 +270,19 @@ class EventsEndpointTest(ZulipTestCase): self.assertEqual(str(context.exception), "Missing 'data' argument") self.assertEqual(context.exception.http_status_code, 400) + def test_web_reload_clients(self) -> None: + # Minimal testing of the /api/internal/web_reload_clients endpoint + post_data = { + "client_count": "1", + "immediate": orjson.dumps(False).decode(), + "secret": settings.SHARED_SECRET, + } + req = HostRequestMock(post_data, tornado_handler=dummy_handler) + req.META["REMOTE_ADDR"] = "127.0.0.1" + result = self.client_post_request("/api/internal/web_reload_clients", req) + self.assert_json_success(result) + self.assertEqual(orjson.loads(result.content)["sent_events"], 0) + class GetEventsTest(ZulipTestCase): def tornado_call( diff --git a/zerver/tornado/application.py b/zerver/tornado/application.py index 02502a65cf..5d9dc392aa 100644 --- a/zerver/tornado/application.py +++ b/zerver/tornado/application.py @@ -21,6 +21,7 @@ def create_tornado_application(*, autoreload: bool = False) -> tornado.web.Appli r"/api/v1/events", r"/api/v1/events/internal", r"/api/internal/notify_tornado", + r"/api/internal/web_reload_clients", ) return tornado.web.Application( diff --git a/zerver/tornado/event_queue.py b/zerver/tornado/event_queue.py index 7d262b9602..acd834a418 100644 --- a/zerver/tornado/event_queue.py +++ b/zerver/tornado/event_queue.py @@ -666,7 +666,7 @@ def mark_clients_to_reload(queue_ids: Iterable[str]) -> None: web_reload_clients[qid] = True -def send_web_reload_client_events(immediate: bool = False, count: Optional[int] = None) -> None: +def send_web_reload_client_events(immediate: bool = False, count: Optional[int] = None) -> int: event: Dict[str, Any] = dict( type="web_reload_client", immediate=immediate, @@ -679,6 +679,7 @@ def send_web_reload_client_events(immediate: bool = False, count: Optional[int] client = clients[qid] if client.accepts_event(event): client.add_event(event) + return len(queue_ids) async def setup_event_queue( diff --git a/zerver/tornado/views.py b/zerver/tornado/views.py index 8ca6385f95..3276e05115 100644 --- a/zerver/tornado/views.py +++ b/zerver/tornado/views.py @@ -5,6 +5,7 @@ from asgiref.sync import async_to_sync from django.conf import settings from django.http import HttpRequest, HttpResponse from django.utils.translation import gettext as _ +from pydantic import Json from typing_extensions import ParamSpec from zerver.decorator import internal_api_view, process_client @@ -12,6 +13,7 @@ from zerver.lib.exceptions import JsonableError from zerver.lib.queue import get_queue_client from zerver.lib.request import REQ, RequestNotes, has_request_variables from zerver.lib.response import AsynchronousResponse, json_success +from zerver.lib.typed_endpoint import typed_endpoint from zerver.lib.validator import ( check_bool, check_dict, @@ -24,7 +26,12 @@ from zerver.models import Client, UserProfile from zerver.models.clients import get_client from zerver.models.users import get_user_profile_by_id from zerver.tornado.descriptors import is_current_port -from zerver.tornado.event_queue import access_client_descriptor, fetch_events, process_notification +from zerver.tornado.event_queue import ( + access_client_descriptor, + fetch_events, + process_notification, + send_web_reload_client_events, +) from zerver.tornado.sharding import get_user_tornado_port, notify_tornado_queue_name P = ParamSpec("P") @@ -49,6 +56,26 @@ def notify( return json_success(request) +@internal_api_view(True) +@typed_endpoint +def web_reload_clients( + request: HttpRequest, + *, + client_count: Optional[Json[int]] = None, + immediate: Json[bool] = False, +) -> HttpResponse: + sent_events = in_tornado_thread(send_web_reload_client_events)( + immediate=immediate, count=client_count + ) + return json_success( + request, + { + "sent_events": sent_events, + "complete": client_count is None or client_count != sent_events, + }, + ) + + @has_request_variables def cleanup_event_queue( request: HttpRequest, user_profile: UserProfile, queue_id: str = REQ() diff --git a/zproject/urls.py b/zproject/urls.py index a8387aea40..864a5352d0 100644 --- a/zproject/urls.py +++ b/zproject/urls.py @@ -19,7 +19,13 @@ from zerver.forms import LoggingSetPasswordForm from zerver.lib.integrations import WEBHOOK_INTEGRATIONS from zerver.lib.rest import rest_path from zerver.lib.url_redirects import DOCUMENTATION_REDIRECTS -from zerver.tornado.views import cleanup_event_queue, get_events, get_events_internal, notify +from zerver.tornado.views import ( + cleanup_event_queue, + get_events, + get_events_internal, + notify, + web_reload_clients, +) from zerver.views.alert_words import add_alert_words, list_alert_words, remove_alert_words from zerver.views.attachments import list_by_user, remove from zerver.views.auth import ( @@ -739,6 +745,7 @@ for app_name in settings.EXTRA_INSTALLED_APPS: urls += [ path("api/internal/email_mirror_message", email_mirror_message), path("api/internal/notify_tornado", notify), + path("api/internal/web_reload_clients", web_reload_clients), path("api/v1/events/internal", get_events_internal), ]