diff --git a/pyproject.toml b/pyproject.toml index 5e26a0d322..11106a80d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ no_implicit_reexport = false module = [ "ahocorasick.*", "aioapns.*", + "asgiref.*", "bitfield.*", "bmemcached.*", "bson.*", diff --git a/zerver/lib/async_utils.py b/zerver/lib/async_utils.py new file mode 100644 index 0000000000..858cfae539 --- /dev/null +++ b/zerver/lib/async_utils.py @@ -0,0 +1,16 @@ +import asyncio + + +class NoAutoCreateEventLoopPolicy(asyncio.DefaultEventLoopPolicy): # type: ignore[misc,valid-type] # https://github.com/python/typeshed/issues/7452 + """ + By default asyncio.get_event_loop() automatically creates an event + loop for the main thread if one isn't currently installed. Since + Django intentionally uninstalls the event loop within + sync_to_async, that autocreation proliferates confusing extra + event loops that will never be run. It is also deprecated in + Python 3.10. This policy disables it so we don't rely on it by + accident. + """ + + def get_event_loop(self) -> asyncio.AbstractEventLoop: # nocoverage + return asyncio.get_running_loop() diff --git a/zerver/lib/queue.py b/zerver/lib/queue.py index 0ec7ff4849..7ac4ee00c7 100644 --- a/zerver/lib/queue.py +++ b/zerver/lib/queue.py @@ -307,6 +307,8 @@ class TornadoQueueClient(QueueClient[Channel]): def _on_connection_closed( self, connection: pika.connection.Connection, reason: Exception ) -> None: + if self.connection is None: + return self._connection_failure_count = 1 retry_secs = self.CONNECTION_RETRY_SECS self.log.warning( @@ -335,6 +337,7 @@ class TornadoQueueClient(QueueClient[Channel]): def close(self) -> None: if self.connection is not None: self.connection.close() + self.connection = None def ensure_queue(self, queue_name: str, callback: Callable[[Channel], object]) -> None: def set_qos(frame: Any) -> None: diff --git a/zerver/management/commands/runtornado.py b/zerver/management/commands/runtornado.py index ef8dd6c367..15a927fd22 100644 --- a/zerver/management/commands/runtornado.py +++ b/zerver/management/commands/runtornado.py @@ -1,19 +1,25 @@ +import asyncio import logging -import sys +import signal +from contextlib import AsyncExitStack from typing import Any from urllib.parse import SplitResult import __main__ +from asgiref.sync import async_to_sync, sync_to_async from django.conf import settings from django.core.management.base import BaseCommand, CommandError, CommandParser -from tornado import autoreload, ioloop +from tornado import autoreload +from tornado.platform.asyncio import AsyncIOMainLoop, to_asyncio_future settings.RUNNING_INSIDE_TORNADO = True +from zerver.lib.async_utils import NoAutoCreateEventLoopPolicy from zerver.lib.debug import interactive_debug_listen from zerver.tornado.application import create_tornado_application, setup_tornado_rabbitmq from zerver.tornado.event_queue import ( add_client_gc_hook, + dump_event_queues, get_wrapped_process_notification, missedmessage_hook, setup_event_queue, @@ -23,6 +29,8 @@ from zerver.tornado.sharding import notify_tornado_queue_name if settings.USING_RABBITMQ: from zerver.lib.queue import TornadoQueueClient, set_queue_client +asyncio.set_event_loop_policy(NoAutoCreateEventLoopPolicy()) + class Command(BaseCommand): help = "Starts a Tornado Web server wrapping Django." @@ -56,53 +64,76 @@ class Command(BaseCommand): level=logging.INFO, format="%(asctime)s %(levelname)-8s %(message)s" ) - def inner_run() -> None: + async def inner_run() -> None: from django.utils import translation - translation.activate(settings.LANGUAGE_CODE) + AsyncIOMainLoop().install() + loop = asyncio.get_running_loop() + stop_fut = loop.create_future() - # We pass display_num_errors=False, since Django will - # likely display similar output anyway. - self.check(display_num_errors=False) - print(f"Tornado server (re)started on port {port}") + def stop() -> None: + if not stop_fut.done(): + stop_fut.set_result(None) - if settings.USING_RABBITMQ: - queue_client = TornadoQueueClient() - set_queue_client(queue_client) - # Process notifications received via RabbitMQ - queue_name = notify_tornado_queue_name(port) - queue_client.start_json_consumer( - queue_name, get_wrapped_process_notification(queue_name) + def add_signal_handlers() -> None: + loop.add_signal_handler(signal.SIGINT, stop), + loop.add_signal_handler(signal.SIGTERM, stop), + + def remove_signal_handlers() -> None: + loop.remove_signal_handler(signal.SIGINT), + loop.remove_signal_handler(signal.SIGTERM), + + async with AsyncExitStack() as stack: + stack.push_async_callback( + sync_to_async(remove_signal_handlers, thread_sensitive=True) ) + await sync_to_async(add_signal_handlers, thread_sensitive=True)() + + translation.activate(settings.LANGUAGE_CODE) + + # We pass display_num_errors=False, since Django will + # likely display similar output anyway. + self.check(display_num_errors=False) + print(f"Tornado server (re)started on port {port}") + + if settings.USING_RABBITMQ: + queue_client = TornadoQueueClient() + set_queue_client(queue_client) + # Process notifications received via RabbitMQ + queue_name = notify_tornado_queue_name(port) + stack.callback(queue_client.close) + queue_client.start_json_consumer( + queue_name, get_wrapped_process_notification(queue_name) + ) - try: # Application is an instance of Django's standard wsgi handler. application = create_tornado_application() # start tornado web server in single-threaded mode http_server = httpserver.HTTPServer(application, xheaders=True) + stack.push_async_callback( + lambda: to_asyncio_future(http_server.close_all_connections()) + ) + stack.callback(http_server.stop) http_server.listen(port, address=addr) from zerver.tornado.ioloop_logging import logging_data logging_data["port"] = str(port) - setup_event_queue(http_server, port) + await setup_event_queue(http_server, port) + stack.callback(dump_event_queues, port) add_client_gc_hook(missedmessage_hook) if settings.USING_RABBITMQ: setup_tornado_rabbitmq(queue_client) - instance = ioloop.IOLoop.instance() - if hasattr(__main__, "add_reload_hook"): autoreload.start() - instance.start() - except KeyboardInterrupt: + await stop_fut + # Monkey patch tornado.autoreload to prevent it from continuing # to watch for changes after catching our SystemExit. Otherwise # the user needs to press Ctrl+C twice. __main__.wait = lambda: None - sys.exit(0) - - inner_run() + async_to_sync(inner_run, force_new_loop=True)() diff --git a/zerver/tests/test_tornado.py b/zerver/tests/test_tornado.py index a564c97ee1..78f5c0fda5 100644 --- a/zerver/tests/test_tornado.py +++ b/zerver/tests/test_tornado.py @@ -1,63 +1,95 @@ +import asyncio import urllib.parse -from typing import Any, Dict, Optional +from functools import wraps +from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar +from unittest import TestResult import orjson +from asgiref.sync import async_to_sync, sync_to_async from django.conf import settings from django.core import signals from django.db import close_old_connections from django.test import override_settings from tornado.httpclient import HTTPResponse -from tornado.testing import AsyncHTTPTestCase +from tornado.ioloop import IOLoop +from tornado.platform.asyncio import AsyncIOMainLoop, to_asyncio_future +from tornado.testing import AsyncHTTPTestCase, AsyncTestCase from tornado.web import Application +from typing_extensions import ParamSpec from zerver.lib.test_classes import ZulipTestCase from zerver.tornado import event_queue from zerver.tornado.application import create_tornado_application from zerver.tornado.event_queue import process_event +P = ParamSpec("P") +T = TypeVar("T") + + +def async_to_sync_decorator(f: Callable[P, Awaitable[T]]) -> Callable[P, T]: + @wraps(f) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> T: + return async_to_sync(f)(*args, **kwargs) + + return wrapped + + +async def in_django_thread(f: Callable[[], T]) -> T: + return await asyncio.create_task(sync_to_async(f)()) + class TornadoWebTestCase(AsyncHTTPTestCase, ZulipTestCase): - def setUp(self) -> None: + @async_to_sync_decorator + async def setUp(self) -> None: super().setUp() signals.request_started.disconnect(close_old_connections) signals.request_finished.disconnect(close_old_connections) self.session_cookie: Optional[Dict[str, str]] = None - def tearDown(self) -> None: - super().tearDown() - self.session_cookie = None + @async_to_sync_decorator + async def tearDown(self) -> None: + # Skip tornado.testing.AsyncTestCase.tearDown because it tries to kill + # the current task. + super(AsyncTestCase, self).tearDown() + + def run(self, result: Optional[TestResult] = None) -> Optional[TestResult]: + return async_to_sync( + sync_to_async(super().run, thread_sensitive=False), force_new_loop=True + )(result) + + def get_new_ioloop(self) -> IOLoop: + return AsyncIOMainLoop() @override_settings(DEBUG=False) def get_app(self) -> Application: return create_tornado_application() - def tornado_client_get(self, path: str, **kwargs: Any) -> HTTPResponse: + async def tornado_client_get(self, path: str, **kwargs: Any) -> HTTPResponse: self.add_session_cookie(kwargs) kwargs["skip_user_agent"] = True self.set_http_headers(kwargs) if "HTTP_HOST" in kwargs: kwargs["headers"]["Host"] = kwargs["HTTP_HOST"] del kwargs["HTTP_HOST"] - return self.fetch(path, method="GET", **kwargs) - - def fetch_async(self, method: str, path: str, **kwargs: Any) -> None: - self.add_session_cookie(kwargs) - kwargs["skip_user_agent"] = True - self.set_http_headers(kwargs) - if "HTTP_HOST" in kwargs: - kwargs["headers"]["Host"] = kwargs["HTTP_HOST"] - del kwargs["HTTP_HOST"] - self.http_client.fetch( - self.get_url(path), - self.stop, - method=method, - **kwargs, + return await to_asyncio_future( + self.http_client.fetch(self.get_url(path), method="GET", **kwargs) ) - def client_get_async(self, path: str, **kwargs: Any) -> None: + async def fetch_async(self, method: str, path: str, **kwargs: Any) -> HTTPResponse: + self.add_session_cookie(kwargs) kwargs["skip_user_agent"] = True self.set_http_headers(kwargs) - self.fetch_async("GET", path, **kwargs) + if "HTTP_HOST" in kwargs: + kwargs["headers"]["Host"] = kwargs["HTTP_HOST"] + del kwargs["HTTP_HOST"] + return await to_asyncio_future( + self.http_client.fetch(self.get_url(path), method=method, **kwargs) + ) + + async def client_get_async(self, path: str, **kwargs: Any) -> HTTPResponse: + kwargs["skip_user_agent"] = True + self.set_http_headers(kwargs) + return await self.fetch_async("GET", path, **kwargs) def login_user(self, *args: Any, **kwargs: Any) -> None: super().login_user(*args, **kwargs) @@ -76,8 +108,8 @@ class TornadoWebTestCase(AsyncHTTPTestCase, ZulipTestCase): headers.update(self.get_session_cookie()) kwargs["headers"] = headers - def create_queue(self, **kwargs: Any) -> str: - response = self.tornado_client_get( + async def create_queue(self, **kwargs: Any) -> str: + response = await self.tornado_client_get( "/json/events?dont_block=true", subdomain="zulip", skip_user_agent=True, @@ -90,22 +122,23 @@ class TornadoWebTestCase(AsyncHTTPTestCase, ZulipTestCase): class EventsTestCase(TornadoWebTestCase): - def test_create_queue(self) -> None: - self.login_user(self.example_user("hamlet")) - queue_id = self.create_queue() + @async_to_sync_decorator + async def test_create_queue(self) -> None: + await in_django_thread(lambda: self.login_user(self.example_user("hamlet"))) + queue_id = await self.create_queue() self.assertIn(queue_id, event_queue.clients) - def test_events_async(self) -> None: - user_profile = self.example_user("hamlet") - self.login_user(user_profile) - event_queue_id = self.create_queue() + @async_to_sync_decorator + async def test_events_async(self) -> None: + user_profile = await in_django_thread(lambda: self.example_user("hamlet")) + await in_django_thread(lambda: self.login_user(user_profile)) + event_queue_id = await self.create_queue() data = { "queue_id": event_queue_id, "last_event_id": -1, } path = f"/json/events?{urllib.parse.urlencode(data)}" - self.client_get_async(path) def process_events() -> None: users = [user_profile.id] @@ -116,7 +149,7 @@ class EventsTestCase(TornadoWebTestCase): process_event(event, users) self.io_loop.call_later(0.1, process_events) - response = self.wait() + response = await self.client_get_async(path) self.assertEqual(response.headers["Vary"], "Accept-Language, Cookie") data = orjson.loads(response.body) self.assertEqual( diff --git a/zerver/tornado/application.py b/zerver/tornado/application.py index 6d29ceab18..439dd5b531 100644 --- a/zerver/tornado/application.py +++ b/zerver/tornado/application.py @@ -1,5 +1,3 @@ -import atexit - import tornado.web from django.conf import settings from tornado import autoreload @@ -10,7 +8,6 @@ from zerver.tornado.handlers import AsyncDjangoHandler def setup_tornado_rabbitmq(queue_client: TornadoQueueClient) -> None: # nocoverage # When tornado is shut down, disconnect cleanly from RabbitMQ - atexit.register(lambda: queue_client.close()) autoreload.add_reload_hook(lambda: queue_client.close()) diff --git a/zerver/tornado/event_queue.py b/zerver/tornado/event_queue.py index fe794177b9..714233c430 100644 --- a/zerver/tornado/event_queue.py +++ b/zerver/tornado/event_queue.py @@ -1,12 +1,9 @@ # See https://zulip.readthedocs.io/en/latest/subsystems/events-system.html for # high-level documentation on how this system works. -import atexit import copy import logging import os import random -import signal -import sys import time import traceback from collections import deque @@ -23,7 +20,6 @@ from typing import ( List, Mapping, MutableMapping, - NoReturn, Optional, Sequence, Set, @@ -604,24 +600,11 @@ def send_restart_events(immediate: bool = False) -> None: client.add_event(event) -def handle_sigterm(server: tornado.httpserver.HTTPServer) -> NoReturn: - logging.warning("Got SIGTERM, shutting down...") - server.stop() - tornado.ioloop.IOLoop.instance().stop() - sys.exit(1) - - -def setup_event_queue(server: tornado.httpserver.HTTPServer, port: int) -> None: +async def setup_event_queue(server: tornado.httpserver.HTTPServer, port: int) -> None: ioloop = tornado.ioloop.IOLoop.instance() if not settings.TEST_SUITE: load_event_queues(port) - atexit.register(dump_event_queues, port) - # Make sure we dump event queues even if we exit via signal - signal.signal( - signal.SIGTERM, - lambda signum, frame: ioloop.add_callback_from_signal(handle_sigterm, server), - ) autoreload.add_reload_hook(lambda: dump_event_queues(port)) try: diff --git a/zerver/tornado/handlers.py b/zerver/tornado/handlers.py index 8948786319..146708f664 100644 --- a/zerver/tornado/handlers.py +++ b/zerver/tornado/handlers.py @@ -1,9 +1,11 @@ +import asyncio import logging import urllib import weakref from typing import Any, Dict, List import tornado.web +from asgiref.sync import sync_to_async from django import http from django.core import signals from django.core.handlers import base @@ -63,7 +65,8 @@ def finish_handler( else: log_data["extra"] = "[{}/1/{}]".format(event_queue_id, contents[0]["type"]) - handler.zulip_finish( + tornado.ioloop.IOLoop.current().add_callback( + handler.zulip_finish, dict(result="success", msg="", events=contents, queue_id=event_queue_id), request, apply_markdown=apply_markdown, @@ -146,9 +149,11 @@ class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler): # Close the connection. self.finish() - def get(self, *args: Any, **kwargs: Any) -> None: + async def get(self, *args: Any, **kwargs: Any) -> None: request = self.convert_tornado_request_to_django_request() - response = self.get_response(request) + response = await asyncio.ensure_future( + sync_to_async(lambda: self.get_response(request), thread_sensitive=True)() + ) try: if hasattr(response, "asynchronous"): @@ -179,16 +184,16 @@ class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler): # the Django side; this triggers cleanup work like # resetting the urlconf and any cache/database # connections. - response.close() + await asyncio.ensure_future(sync_to_async(response.close, thread_sensitive=True)()) - def head(self, *args: Any, **kwargs: Any) -> None: - self.get(*args, **kwargs) + async def head(self, *args: Any, **kwargs: Any) -> None: + await self.get(*args, **kwargs) - def post(self, *args: Any, **kwargs: Any) -> None: - self.get(*args, **kwargs) + async def post(self, *args: Any, **kwargs: Any) -> None: + await self.get(*args, **kwargs) - def delete(self, *args: Any, **kwargs: Any) -> None: - self.get(*args, **kwargs) + async def delete(self, *args: Any, **kwargs: Any) -> None: + await self.get(*args, **kwargs) def on_connection_close(self) -> None: # Register a Tornado handler that runs when client-side @@ -201,7 +206,7 @@ class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler): if client_descriptor is not None: client_descriptor.disconnect_handler(client_closed=True) - def zulip_finish( + async def zulip_finish( self, result_dict: Dict[str, Any], old_request: HttpRequest, apply_markdown: bool ) -> None: # Function called when we want to break a long-polled @@ -257,7 +262,9 @@ class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler): res_type=result_dict["result"], data=result_dict, status=self.get_status() ) - response = self.get_response(request) + response = await asyncio.ensure_future( + sync_to_async(lambda: self.get_response(request), thread_sensitive=True)() + ) try: # Explicitly mark requests as varying by cookie, since the # middleware will not have seen a session access @@ -266,4 +273,4 @@ class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler): self.write_django_response_as_tornado_response(response) finally: # Tell Django we're done processing this request - response.close() + await asyncio.ensure_future(sync_to_async(response.close, thread_sensitive=True)()) diff --git a/zerver/tornado/views.py b/zerver/tornado/views.py index cb5c07e515..0edcfdd098 100644 --- a/zerver/tornado/views.py +++ b/zerver/tornado/views.py @@ -1,7 +1,8 @@ import time -from typing import Optional, Sequence +from typing import Callable, Optional, Sequence, TypeVar import orjson +from asgiref.sync import async_to_sync from django.http import HttpRequest, HttpResponse from django.utils.translation import gettext as _ @@ -20,10 +21,19 @@ from zerver.models import Client, UserProfile, get_client, get_user_profile_by_i from zerver.tornado.event_queue import fetch_events, get_client_descriptor, process_notification from zerver.tornado.exceptions import BadEventQueueIdError +T = TypeVar("T") + + +def in_tornado_thread(f: Callable[[], T]) -> T: + async def wrapped() -> T: + return f() + + return async_to_sync(wrapped)() + @internal_notify_view(True) def notify(request: HttpRequest) -> HttpResponse: - process_notification(orjson.loads(request.POST["data"])) + in_tornado_thread(lambda: process_notification(orjson.loads(request.POST["data"]))) return json_success(request) @@ -39,7 +49,7 @@ def cleanup_event_queue( log_data = RequestNotes.get_notes(request).log_data assert log_data is not None log_data["extra"] = f"[{queue_id}]" - client.cleanup() + in_tornado_thread(client.cleanup) return json_success(request) @@ -153,7 +163,7 @@ def get_events_backend( user_settings_object=user_settings_object, ) - result = fetch_events(events_query) + result = in_tornado_thread(lambda: fetch_events(events_query)) if "extra_log_data" in result: log_data = RequestNotes.get_notes(request).log_data assert log_data is not None