runtornado: Switch to asyncio event loop.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
Anders Kaseorg authored on 2022-03-18 00:34:10 -07:00 · committed by Alex Vandiver
parent c263bfdb41
commit 6fd1a558b7
9 changed files with 177 additions and 96 deletions


@@ -49,6 +49,7 @@ no_implicit_reexport = false
 module = [
     "ahocorasick.*",
     "aioapns.*",
+    "asgiref.*",
     "bitfield.*",
     "bmemcached.*",
     "bson.*",

zerver/lib/async_utils.py (new file, 16 lines)

@@ -0,0 +1,16 @@
+import asyncio
+
+
+class NoAutoCreateEventLoopPolicy(asyncio.DefaultEventLoopPolicy):  # type: ignore[misc,valid-type]  # https://github.com/python/typeshed/issues/7452
+    """
+    By default asyncio.get_event_loop() automatically creates an event
+    loop for the main thread if one isn't currently installed.  Since
+    Django intentionally uninstalls the event loop within
+    sync_to_async, that autocreation proliferates confusing extra
+    event loops that will never be run.  It is also deprecated in
+    Python 3.10.  This policy disables it so we don't rely on it by
+    accident.
+    """
+
+    def get_event_loop(self) -> asyncio.AbstractEventLoop:  # nocoverage
+        return asyncio.get_running_loop()
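
The policy only changes what happens when no loop is running; inside a running loop it simply delegates to asyncio.get_running_loop(). A rough sketch of the difference (illustrative only, not part of this commit):

import asyncio

from zerver.lib.async_utils import NoAutoCreateEventLoopPolicy

asyncio.set_event_loop_policy(NoAutoCreateEventLoopPolicy())

try:
    asyncio.get_event_loop()  # no loop is running here
except RuntimeError:
    print("no event loop was auto-created")


async def main() -> None:
    # With a loop running, get_event_loop() still returns it, because the
    # policy delegates to asyncio.get_running_loop().
    assert asyncio.get_event_loop() is asyncio.get_running_loop()


asyncio.run(main())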


@@ -307,6 +307,8 @@ class TornadoQueueClient(QueueClient[Channel]):
     def _on_connection_closed(
         self, connection: pika.connection.Connection, reason: Exception
     ) -> None:
+        if self.connection is None:
+            return
         self._connection_failure_count = 1
         retry_secs = self.CONNECTION_RETRY_SECS
         self.log.warning(
@@ -335,6 +337,7 @@ class TornadoQueueClient(QueueClient[Channel]):
     def close(self) -> None:
         if self.connection is not None:
             self.connection.close()
+            self.connection = None
 
     def ensure_queue(self, queue_name: str, callback: Callable[[Channel], object]) -> None:
         def set_qos(frame: Any) -> None:
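
Taken together, these two queue.py changes let the reconnect logic tell an intentional shutdown (close() clears self.connection) apart from a dropped connection. A rough sketch of that pattern with hypothetical names, independent of pika and Tornado:

from typing import Optional


class FakeQueueClient:
    def __init__(self) -> None:
        self.connection: Optional[object] = object()

    def _on_connection_closed(self) -> None:
        if self.connection is None:
            return  # close() already ran; this shutdown was intentional
        print("connection lost unexpectedly; scheduling a reconnect")

    def close(self) -> None:
        if self.connection is not None:
            self.connection = None  # mark the close as deliberate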


@@ -1,19 +1,25 @@
+import asyncio
 import logging
-import sys
+import signal
+from contextlib import AsyncExitStack
 from typing import Any
 from urllib.parse import SplitResult
 
 import __main__
+from asgiref.sync import async_to_sync, sync_to_async
 from django.conf import settings
 from django.core.management.base import BaseCommand, CommandError, CommandParser
-from tornado import autoreload, ioloop
+from tornado import autoreload
+from tornado.platform.asyncio import AsyncIOMainLoop, to_asyncio_future
 
 settings.RUNNING_INSIDE_TORNADO = True
 
+from zerver.lib.async_utils import NoAutoCreateEventLoopPolicy
 from zerver.lib.debug import interactive_debug_listen
 from zerver.tornado.application import create_tornado_application, setup_tornado_rabbitmq
 from zerver.tornado.event_queue import (
     add_client_gc_hook,
+    dump_event_queues,
     get_wrapped_process_notification,
     missedmessage_hook,
     setup_event_queue,
@@ -23,6 +29,8 @@ from zerver.tornado.sharding import notify_tornado_queue_name
 if settings.USING_RABBITMQ:
     from zerver.lib.queue import TornadoQueueClient, set_queue_client
 
+asyncio.set_event_loop_policy(NoAutoCreateEventLoopPolicy())
+
 
 class Command(BaseCommand):
     help = "Starts a Tornado Web server wrapping Django."
@@ -56,9 +64,31 @@ class Command(BaseCommand):
             level=logging.INFO, format="%(asctime)s %(levelname)-8s %(message)s"
         )
 
-        def inner_run() -> None:
+        async def inner_run() -> None:
             from django.utils import translation
 
+            AsyncIOMainLoop().install()
+            loop = asyncio.get_running_loop()
+            stop_fut = loop.create_future()
+
+            def stop() -> None:
+                if not stop_fut.done():
+                    stop_fut.set_result(None)
+
+            def add_signal_handlers() -> None:
+                loop.add_signal_handler(signal.SIGINT, stop),
+                loop.add_signal_handler(signal.SIGTERM, stop),
+
+            def remove_signal_handlers() -> None:
+                loop.remove_signal_handler(signal.SIGINT),
+                loop.remove_signal_handler(signal.SIGTERM),
+
+            async with AsyncExitStack() as stack:
+                stack.push_async_callback(
+                    sync_to_async(remove_signal_handlers, thread_sensitive=True)
+                )
+                await sync_to_async(add_signal_handlers, thread_sensitive=True)()
+
-            translation.activate(settings.LANGUAGE_CODE)
-
-            # We pass display_num_errors=False, since Django will
+                translation.activate(settings.LANGUAGE_CODE)
+
+                # We pass display_num_errors=False, since Django will
@@ -71,38 +101,39 @@
                     set_queue_client(queue_client)
                     # Process notifications received via RabbitMQ
                     queue_name = notify_tornado_queue_name(port)
+                    stack.callback(queue_client.close)
                     queue_client.start_json_consumer(
                         queue_name, get_wrapped_process_notification(queue_name)
                     )
 
-            try:
                 # Application is an instance of Django's standard wsgi handler.
                 application = create_tornado_application()
 
                 # start tornado web server in single-threaded mode
                 http_server = httpserver.HTTPServer(application, xheaders=True)
+                stack.push_async_callback(
+                    lambda: to_asyncio_future(http_server.close_all_connections())
+                )
+                stack.callback(http_server.stop)
                 http_server.listen(port, address=addr)
 
                 from zerver.tornado.ioloop_logging import logging_data
 
                 logging_data["port"] = str(port)
-                setup_event_queue(http_server, port)
+                await setup_event_queue(http_server, port)
+                stack.callback(dump_event_queues, port)
                 add_client_gc_hook(missedmessage_hook)
                 if settings.USING_RABBITMQ:
                     setup_tornado_rabbitmq(queue_client)
 
-                instance = ioloop.IOLoop.instance()
-
                 if hasattr(__main__, "add_reload_hook"):
                     autoreload.start()
 
-                instance.start()
-            except KeyboardInterrupt:
+                await stop_fut
+
                 # Monkey patch tornado.autoreload to prevent it from continuing
                 # to watch for changes after catching our SystemExit. Otherwise
                 # the user needs to press Ctrl+C twice.
                 __main__.wait = lambda: None
 
-                sys.exit(0)
-
-        inner_run()
+        async_to_sync(inner_run, force_new_loop=True)()
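
The whole server now runs inside one asyncio event loop: async_to_sync(inner_run, force_new_loop=True)() creates the loop (the policy above forbids auto-creation), and the AsyncExitStack unwinds the registered shutdown steps in reverse order once stop_fut is resolved by SIGINT/SIGTERM. A condensed, self-contained sketch of that shape; the serve/cleanup bodies are placeholders, not Zulip code:

import asyncio
import signal
from contextlib import AsyncExitStack

from asgiref.sync import async_to_sync


async def inner_run() -> None:
    loop = asyncio.get_running_loop()
    stop_fut: "asyncio.Future[None]" = loop.create_future()

    def stop() -> None:
        if not stop_fut.done():
            stop_fut.set_result(None)

    loop.add_signal_handler(signal.SIGTERM, stop)
    loop.call_later(1, stop)  # simulate a shutdown signal for this demo

    async with AsyncExitStack() as stack:
        stack.callback(lambda: print("cleanup runs after stop_fut resolves"))
        print("serving; waiting for SIGTERM (or the simulated one)")
        await stop_fut  # park here until stop() fires

    loop.remove_signal_handler(signal.SIGTERM)


# force_new_loop=True hands this thread a fresh event loop.
async_to_sync(inner_run, force_new_loop=True)()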


@@ -1,63 +1,95 @@
+import asyncio
 import urllib.parse
-from typing import Any, Dict, Optional
+from functools import wraps
+from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar
+from unittest import TestResult
 
 import orjson
+from asgiref.sync import async_to_sync, sync_to_async
 from django.conf import settings
 from django.core import signals
 from django.db import close_old_connections
 from django.test import override_settings
 from tornado.httpclient import HTTPResponse
-from tornado.testing import AsyncHTTPTestCase
+from tornado.ioloop import IOLoop
+from tornado.platform.asyncio import AsyncIOMainLoop, to_asyncio_future
+from tornado.testing import AsyncHTTPTestCase, AsyncTestCase
 from tornado.web import Application
+from typing_extensions import ParamSpec
 
 from zerver.lib.test_classes import ZulipTestCase
 from zerver.tornado import event_queue
 from zerver.tornado.application import create_tornado_application
 from zerver.tornado.event_queue import process_event
 
+P = ParamSpec("P")
+T = TypeVar("T")
+
+
+def async_to_sync_decorator(f: Callable[P, Awaitable[T]]) -> Callable[P, T]:
+    @wraps(f)
+    def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
+        return async_to_sync(f)(*args, **kwargs)
+
+    return wrapped
+
+
+async def in_django_thread(f: Callable[[], T]) -> T:
+    return await asyncio.create_task(sync_to_async(f)())
+
 
 class TornadoWebTestCase(AsyncHTTPTestCase, ZulipTestCase):
-    def setUp(self) -> None:
+    @async_to_sync_decorator
+    async def setUp(self) -> None:
         super().setUp()
         signals.request_started.disconnect(close_old_connections)
         signals.request_finished.disconnect(close_old_connections)
         self.session_cookie: Optional[Dict[str, str]] = None
 
-    def tearDown(self) -> None:
-        super().tearDown()
-        self.session_cookie = None
+    @async_to_sync_decorator
+    async def tearDown(self) -> None:
+        # Skip tornado.testing.AsyncTestCase.tearDown because it tries to kill
+        # the current task.
+        super(AsyncTestCase, self).tearDown()
+
+    def run(self, result: Optional[TestResult] = None) -> Optional[TestResult]:
+        return async_to_sync(
+            sync_to_async(super().run, thread_sensitive=False), force_new_loop=True
+        )(result)
+
+    def get_new_ioloop(self) -> IOLoop:
+        return AsyncIOMainLoop()
 
     @override_settings(DEBUG=False)
     def get_app(self) -> Application:
         return create_tornado_application()
 
-    def tornado_client_get(self, path: str, **kwargs: Any) -> HTTPResponse:
+    async def tornado_client_get(self, path: str, **kwargs: Any) -> HTTPResponse:
         self.add_session_cookie(kwargs)
         kwargs["skip_user_agent"] = True
         self.set_http_headers(kwargs)
         if "HTTP_HOST" in kwargs:
             kwargs["headers"]["Host"] = kwargs["HTTP_HOST"]
             del kwargs["HTTP_HOST"]
-        return self.fetch(path, method="GET", **kwargs)
-
-    def fetch_async(self, method: str, path: str, **kwargs: Any) -> None:
-        self.add_session_cookie(kwargs)
-        kwargs["skip_user_agent"] = True
-        self.set_http_headers(kwargs)
-        if "HTTP_HOST" in kwargs:
-            kwargs["headers"]["Host"] = kwargs["HTTP_HOST"]
-            del kwargs["HTTP_HOST"]
-        self.http_client.fetch(
-            self.get_url(path),
-            self.stop,
-            method=method,
-            **kwargs,
-        )
+        return await to_asyncio_future(
+            self.http_client.fetch(self.get_url(path), method="GET", **kwargs)
+        )
 
-    def client_get_async(self, path: str, **kwargs: Any) -> None:
+    async def fetch_async(self, method: str, path: str, **kwargs: Any) -> HTTPResponse:
+        self.add_session_cookie(kwargs)
         kwargs["skip_user_agent"] = True
         self.set_http_headers(kwargs)
-        self.fetch_async("GET", path, **kwargs)
+        if "HTTP_HOST" in kwargs:
+            kwargs["headers"]["Host"] = kwargs["HTTP_HOST"]
+            del kwargs["HTTP_HOST"]
+        return await to_asyncio_future(
+            self.http_client.fetch(self.get_url(path), method=method, **kwargs)
+        )
+
+    async def client_get_async(self, path: str, **kwargs: Any) -> HTTPResponse:
+        kwargs["skip_user_agent"] = True
+        self.set_http_headers(kwargs)
+        return await self.fetch_async("GET", path, **kwargs)
 
     def login_user(self, *args: Any, **kwargs: Any) -> None:
         super().login_user(*args, **kwargs)
@@ -76,8 +108,8 @@ class TornadoWebTestCase(AsyncHTTPTestCase, ZulipTestCase):
         headers.update(self.get_session_cookie())
         kwargs["headers"] = headers
 
-    def create_queue(self, **kwargs: Any) -> str:
-        response = self.tornado_client_get(
+    async def create_queue(self, **kwargs: Any) -> str:
+        response = await self.tornado_client_get(
             "/json/events?dont_block=true",
             subdomain="zulip",
             skip_user_agent=True,
@@ -90,22 +122,23 @@
 class EventsTestCase(TornadoWebTestCase):
-    def test_create_queue(self) -> None:
-        self.login_user(self.example_user("hamlet"))
-        queue_id = self.create_queue()
+    @async_to_sync_decorator
+    async def test_create_queue(self) -> None:
+        await in_django_thread(lambda: self.login_user(self.example_user("hamlet")))
+        queue_id = await self.create_queue()
         self.assertIn(queue_id, event_queue.clients)
 
-    def test_events_async(self) -> None:
-        user_profile = self.example_user("hamlet")
-        self.login_user(user_profile)
-        event_queue_id = self.create_queue()
+    @async_to_sync_decorator
+    async def test_events_async(self) -> None:
+        user_profile = await in_django_thread(lambda: self.example_user("hamlet"))
+        await in_django_thread(lambda: self.login_user(user_profile))
+        event_queue_id = await self.create_queue()
         data = {
             "queue_id": event_queue_id,
             "last_event_id": -1,
         }
         path = f"/json/events?{urllib.parse.urlencode(data)}"
-        self.client_get_async(path)
 
         def process_events() -> None:
             users = [user_profile.id]
@@ -116,7 +149,7 @@ class EventsTestCase(TornadoWebTestCase):
             process_event(event, users)
 
         self.io_loop.call_later(0.1, process_events)
-        response = self.wait()
+        response = await self.client_get_async(path)
         self.assertEqual(response.headers["Vary"], "Accept-Language, Cookie")
         data = orjson.loads(response.body)
         self.assertEqual(
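
The test changes all follow one pattern: the test body is a coroutine driven by async_to_sync_decorator, HTTP and event-queue work stays on the asyncio loop, and Django ORM calls are pushed to a worker thread via in_django_thread. A minimal sketch of how such a decorator lets unittest call an async method synchronously (illustrative only, outside the real test suite):

import asyncio
from functools import wraps
from typing import Any, Awaitable, Callable, TypeVar

from asgiref.sync import async_to_sync

T = TypeVar("T")


def async_to_sync_decorator(f: Callable[..., Awaitable[T]]) -> Callable[..., T]:
    @wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> T:
        return async_to_sync(f)(*args, **kwargs)

    return wrapped


class Demo:
    @async_to_sync_decorator
    async def test_something(self) -> None:
        # unittest sees a plain synchronous method, but the body can await.
        await asyncio.sleep(0)


Demo().test_something()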


@@ -1,5 +1,3 @@
-import atexit
-
 import tornado.web
 from django.conf import settings
 from tornado import autoreload
@@ -10,7 +8,6 @@ from zerver.tornado.handlers import AsyncDjangoHandler
 
 def setup_tornado_rabbitmq(queue_client: TornadoQueueClient) -> None:  # nocoverage
     # When tornado is shut down, disconnect cleanly from RabbitMQ
-    atexit.register(lambda: queue_client.close())
     autoreload.add_reload_hook(lambda: queue_client.close())


@@ -1,12 +1,9 @@
 # See https://zulip.readthedocs.io/en/latest/subsystems/events-system.html for
 # high-level documentation on how this system works.
-import atexit
 import copy
 import logging
 import os
 import random
-import signal
-import sys
 import time
 import traceback
 from collections import deque
@@ -23,7 +20,6 @@ from typing import (
     List,
     Mapping,
     MutableMapping,
-    NoReturn,
     Optional,
     Sequence,
     Set,
@@ -604,24 +600,11 @@
         client.add_event(event)
 
 
-def handle_sigterm(server: tornado.httpserver.HTTPServer) -> NoReturn:
-    logging.warning("Got SIGTERM, shutting down...")
-    server.stop()
-    tornado.ioloop.IOLoop.instance().stop()
-    sys.exit(1)
-
-
-def setup_event_queue(server: tornado.httpserver.HTTPServer, port: int) -> None:
+async def setup_event_queue(server: tornado.httpserver.HTTPServer, port: int) -> None:
     ioloop = tornado.ioloop.IOLoop.instance()
     if not settings.TEST_SUITE:
         load_event_queues(port)
-        atexit.register(dump_event_queues, port)
-
-        # Make sure we dump event queues even if we exit via signal
-        signal.signal(
-            signal.SIGTERM,
-            lambda signum, frame: ioloop.add_callback_from_signal(handle_sigterm, server),
-        )
         autoreload.add_reload_hook(lambda: dump_event_queues(port))
 
     try:


@@ -1,9 +1,11 @@
+import asyncio
 import logging
 import urllib
 import weakref
 from typing import Any, Dict, List
 
 import tornado.web
+from asgiref.sync import sync_to_async
 from django import http
 from django.core import signals
 from django.core.handlers import base
@@ -63,7 +65,8 @@ def finish_handler(
         else:
             log_data["extra"] = "[{}/1/{}]".format(event_queue_id, contents[0]["type"])
 
-        handler.zulip_finish(
+        tornado.ioloop.IOLoop.current().add_callback(
+            handler.zulip_finish,
             dict(result="success", msg="", events=contents, queue_id=event_queue_id),
             request,
             apply_markdown=apply_markdown,
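
Because zulip_finish is now a coroutine, finish_handler no longer calls it directly; it hands it to IOLoop.current().add_callback, and Tornado starts the awaitable that the callback returns on the loop. A toy illustration of that scheduling (the names here are placeholders, not Zulip code):

import asyncio

from tornado.ioloop import IOLoop


async def zulip_finish_like(result: str) -> None:
    print("finishing with", result)


async def main() -> None:
    # add_callback runs the callback on the loop; the coroutine it returns
    # is then awaited by Tornado.
    IOLoop.current().add_callback(zulip_finish_like, "success")
    await asyncio.sleep(0.1)  # give the scheduled callback a chance to run


asyncio.run(main())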
@@ -146,9 +149,11 @@ class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler):
             # Close the connection.
             self.finish()
 
-    def get(self, *args: Any, **kwargs: Any) -> None:
+    async def get(self, *args: Any, **kwargs: Any) -> None:
         request = self.convert_tornado_request_to_django_request()
-        response = self.get_response(request)
+        response = await asyncio.ensure_future(
+            sync_to_async(lambda: self.get_response(request), thread_sensitive=True)()
+        )
 
         try:
             if hasattr(response, "asynchronous"):
@@ -179,16 +184,16 @@ class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler):
             # the Django side; this triggers cleanup work like
             # resetting the urlconf and any cache/database
             # connections.
-            response.close()
+            await asyncio.ensure_future(sync_to_async(response.close, thread_sensitive=True)())
 
-    def head(self, *args: Any, **kwargs: Any) -> None:
-        self.get(*args, **kwargs)
+    async def head(self, *args: Any, **kwargs: Any) -> None:
+        await self.get(*args, **kwargs)
 
-    def post(self, *args: Any, **kwargs: Any) -> None:
-        self.get(*args, **kwargs)
+    async def post(self, *args: Any, **kwargs: Any) -> None:
+        await self.get(*args, **kwargs)
 
-    def delete(self, *args: Any, **kwargs: Any) -> None:
-        self.get(*args, **kwargs)
+    async def delete(self, *args: Any, **kwargs: Any) -> None:
+        await self.get(*args, **kwargs)
 
     def on_connection_close(self) -> None:
         # Register a Tornado handler that runs when client-side
@@ -201,7 +206,7 @@ class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler):
         if client_descriptor is not None:
             client_descriptor.disconnect_handler(client_closed=True)
 
-    def zulip_finish(
+    async def zulip_finish(
         self, result_dict: Dict[str, Any], old_request: HttpRequest, apply_markdown: bool
     ) -> None:
         # Function called when we want to break a long-polled
@@ -257,7 +262,9 @@ class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler):
             res_type=result_dict["result"], data=result_dict, status=self.get_status()
         )
 
-        response = self.get_response(request)
+        response = await asyncio.ensure_future(
+            sync_to_async(lambda: self.get_response(request), thread_sensitive=True)()
+        )
 
         try:
             # Explicitly mark requests as varying by cookie, since the
             # middleware will not have seen a session access
@@ -266,4 +273,4 @@ class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler):
             self.write_django_response_as_tornado_response(response)
         finally:
             # Tell Django we're done processing this request
-            response.close()
+            await asyncio.ensure_future(sync_to_async(response.close, thread_sensitive=True)())
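
The handler methods are now native coroutines, and every call into Django's synchronous machinery (get_response, response.close) goes through sync_to_async(..., thread_sensitive=True) so it runs on the single worker thread Django expects. A rough sketch of that bridge with a stand-in for get_response (not the actual AsyncDjangoHandler code):

import asyncio

from asgiref.sync import sync_to_async


def get_response(request: str) -> str:
    # Stand-in for Django's synchronous request handling.
    return f"response for {request}"


async def handle(request: str) -> str:
    # thread_sensitive=True keeps all such calls on one worker thread,
    # which is what Django's connection handling assumes.
    return await sync_to_async(get_response, thread_sensitive=True)(request)


print(asyncio.run(handle("GET /")))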


@@ -1,7 +1,8 @@
 import time
-from typing import Optional, Sequence
+from typing import Callable, Optional, Sequence, TypeVar
 
 import orjson
+from asgiref.sync import async_to_sync
 from django.http import HttpRequest, HttpResponse
 from django.utils.translation import gettext as _
@@ -20,10 +21,19 @@ from zerver.models import Client, UserProfile, get_client, get_user_profile_by_i
 from zerver.tornado.event_queue import fetch_events, get_client_descriptor, process_notification
 from zerver.tornado.exceptions import BadEventQueueIdError
 
+T = TypeVar("T")
+
+
+def in_tornado_thread(f: Callable[[], T]) -> T:
+    async def wrapped() -> T:
+        return f()
+
+    return async_to_sync(wrapped)()
+
 
 @internal_notify_view(True)
 def notify(request: HttpRequest) -> HttpResponse:
-    process_notification(orjson.loads(request.POST["data"]))
+    in_tornado_thread(lambda: process_notification(orjson.loads(request.POST["data"])))
     return json_success(request)
@@ -39,7 +49,7 @@
     log_data = RequestNotes.get_notes(request).log_data
     assert log_data is not None
     log_data["extra"] = f"[{queue_id}]"
-    client.cleanup()
+    in_tornado_thread(client.cleanup)
     return json_success(request)
@@ -153,7 +163,7 @@
         user_settings_object=user_settings_object,
     )
 
-    result = fetch_events(events_query)
+    result = in_tornado_thread(lambda: fetch_events(events_query))
     if "extra_log_data" in result:
         log_data = RequestNotes.get_notes(request).log_data
         assert log_data is not None
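
On this side of the bridge the Django view code runs in a worker thread (the Tornado handler calls it through sync_to_async), while the event-queue data structures still live on the asyncio loop; in_tornado_thread uses async_to_sync to hop a synchronous callable back onto that loop and wait for its result. A toy round trip showing the same shape (names are illustrative, not Zulip code):

import asyncio
import threading
from typing import Callable, TypeVar

from asgiref.sync import async_to_sync, sync_to_async

T = TypeVar("T")


def in_event_loop_thread(f: Callable[[], T]) -> T:
    async def wrapped() -> T:
        return f()

    # asgiref routes this back to the event loop that started the current
    # sync_to_async context and blocks until the result is ready.
    return async_to_sync(wrapped)()


def django_style_view() -> str:
    # Runs in a worker thread; the lambda below runs back on the loop's thread.
    return in_event_loop_thread(lambda: threading.current_thread().name)


async def main() -> None:
    print("event loop thread:", threading.current_thread().name)
    print("callable ran on:  ", await sync_to_async(django_style_view)())


asyncio.run(main())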