From 55b26da82badc6aae8147e1926addec74468b42e Mon Sep 17 00:00:00 2001 From: Anders Kaseorg Date: Mon, 11 Dec 2023 13:44:55 -0800 Subject: [PATCH] run-dev: Rewrite development proxy with aiohttp. This allows request cancellation to be propagated to the server. Signed-off-by: Anders Kaseorg --- requirements/dev.in | 3 + requirements/dev.txt | 1 + tools/run-dev | 225 ++++++++++++++----------------------- zerver/tornado/handlers.py | 5 - 4 files changed, 89 insertions(+), 145 deletions(-) diff --git a/requirements/dev.in b/requirements/dev.in index 73fd80e732..3fb3f73137 100644 --- a/requirements/dev.in +++ b/requirements/dev.in @@ -8,6 +8,9 @@ # moto s3 mock moto[s3] +# For tools/run-dev +aiohttp + # Needed for documentation links test Scrapy diff --git a/requirements/dev.txt b/requirements/dev.txt index 604661392a..1e53348c99 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -88,6 +88,7 @@ aiohttp==3.9.1 \ --hash=sha256:f800164276eec54e0af5c99feb9494c295118fc10a11b997bbb1348ba1a52065 \ --hash=sha256:ffcd828e37dc219a72c9012ec44ad2e7e3066bec6ff3aaa19e7d435dbf4032ca # via + # -r requirements/dev.in # aiohttp-retry # twilio aiohttp-retry==2.8.3 \ diff --git a/tools/run-dev b/tools/run-dev index d6231ded6c..9963e4ac3d 100755 --- a/tools/run-dev +++ b/tools/run-dev @@ -2,13 +2,13 @@ import argparse import asyncio import errno +import logging import os import pwd import signal import subprocess import sys -from typing import List, Sequence -from urllib.parse import urlunsplit +from typing import List TOOLS_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, os.path.dirname(TOOLS_DIR)) @@ -18,9 +18,9 @@ from tools.lib import sanity_check sanity_check.check_venv(__file__) -from tornado import httpclient, httputil, web -from tornado.platform.asyncio import AsyncIOMainLoop -from typing_extensions import override +import aiohttp +from aiohttp import hdrs, web +from returns.curry import partial from tools.lib.test_script import add_provision_check_override_param, assert_provisioning_status_ok @@ -54,11 +54,6 @@ parser.add_argument( help="Do not clear memcached on startup", ) parser.add_argument("--streamlined", action="store_true", help="Avoid process_queue, etc.") -parser.add_argument( - "--enable-tornado-logging", - action="store_true", - help="Enable access logs from tornado proxy server.", -) parser.add_argument( "--behind-https-proxy", action="store_true", @@ -204,135 +199,81 @@ def start_webpack_watcher() -> "subprocess.Popen[bytes]": return subprocess.Popen(webpack_cmd) -def transform_url(protocol: str, path: str, query: str, target_port: int, target_host: str) -> str: - # generate url with target host - host = ":".join((target_host, str(target_port))) - # Here we are going to rewrite the path a bit so that it is in parity with - # what we will have for production - newpath = urlunsplit((protocol, host, path, query, "")) - return newpath +session: aiohttp.ClientSession + +# https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1 +HOP_BY_HOP_HEADERS = { + hdrs.CONNECTION, + hdrs.KEEP_ALIVE, + hdrs.PROXY_AUTHENTICATE, + hdrs.PROXY_AUTHORIZATION, + hdrs.TE, + hdrs.TRAILER, + hdrs.TRANSFER_ENCODING, + hdrs.UPGRADE, +} + +# Headers that aiohttp would otherwise generate by default +SKIP_AUTO_HEADERS = { + hdrs.ACCEPT, + hdrs.ACCEPT_ENCODING, + hdrs.CONTENT_TYPE, + hdrs.USER_AGENT, +} -client: httpclient.AsyncHTTPClient - - -class BaseHandler(web.RequestHandler): - # target server ip - target_host: str = "127.0.0.1" - # target server port - target_port: int - - def _add_request_headers( - self, - exclude_lower_headers_list: Sequence[str] = [], - ) -> httputil.HTTPHeaders: - headers = httputil.HTTPHeaders() - for header, v in self.request.headers.get_all(): - if header.lower() not in exclude_lower_headers_list: - headers.add(header, v) - return headers - - @override - def get(self) -> None: - pass - - @override - def head(self) -> None: - pass - - @override - def post(self) -> None: - pass - - @override - def put(self) -> None: - pass - - @override - def patch(self) -> None: - pass - - @override - def options(self) -> None: - pass - - @override - def delete(self) -> None: - pass - - @override - async def prepare(self) -> None: - assert self.request.method is not None - assert self.request.remote_ip is not None - if "X-REAL-IP" not in self.request.headers: - self.request.headers["X-REAL-IP"] = self.request.remote_ip - if "X-FORWARDED_PORT" not in self.request.headers: - self.request.headers["X-FORWARDED-PORT"] = str(proxy_port) - url = transform_url( - self.request.protocol, - self.request.path, - self.request.query, - self.target_port, - self.target_host, - ) - try: - request = httpclient.HTTPRequest( - url=url, - method=self.request.method, - headers=self._add_request_headers(["upgrade-insecure-requests"]), - follow_redirects=False, - body=self.request.body, - allow_nonstandard_methods=True, - # use large timeouts to handle polling requests - connect_timeout=240.0, - request_timeout=240.0, - # https://github.com/tornadoweb/tornado/issues/2743 - decompress_response=False, - ) - response = await client.fetch(request, raise_error=False) - - self.set_status(response.code, response.reason) - self._headers = httputil.HTTPHeaders() # clear tornado default header - - for header, v in response.headers.get_all(): - # some header appear multiple times, eg 'Set-Cookie' - if header.lower() != "transfer-encoding": - self.add_header(header, v) - if response.body: - self.write(response.body) - except (ConnectionError, httpclient.HTTPError) as e: - self.set_status(500) - self.write("Internal server error:\n" + str(e)) - - -class WebPackHandler(BaseHandler): - target_port = webpack_port - - -class DjangoHandler(BaseHandler): - target_port = django_port - - -class TornadoHandler(BaseHandler): - target_port = tornado_port - - -class Application(web.Application): - def __init__(self, enable_logging: bool = False) -> None: - super().__init__( - [ - (r"/json/events.*", TornadoHandler), - (r"/api/v1/events.*", TornadoHandler), - (r"/webpack.*", WebPackHandler), - (r"/.*", DjangoHandler), +async def forward(upstream_port: int, request: web.Request) -> web.StreamResponse: + try: + upstream_response = await session.request( + request.method, + request.url.with_host("127.0.0.1").with_port(upstream_port), + headers=[ + (key, value) + for key, value in request.headers.items() + if key not in HOP_BY_HOP_HEADERS ], - enable_logging=enable_logging, + data=request.content.iter_any() if request.body_exists else None, + allow_redirects=False, + auto_decompress=False, + compress=False, + skip_auto_headers=SKIP_AUTO_HEADERS, ) + except aiohttp.ClientError as error: + logging.error( + "Failed to forward %s %s to port %d: %s", + request.method, + request.url.path, + upstream_port, + error, + ) + raise web.HTTPBadGateway from error - @override - def log_request(self, handler: web.RequestHandler) -> None: - if self.settings["enable_logging"]: - super().log_request(handler) + response = web.StreamResponse(status=upstream_response.status, reason=upstream_response.reason) + response.headers.extend( + (key, value) + for key, value in upstream_response.headers.items() + if key not in HOP_BY_HOP_HEADERS + ) + assert request.remote is not None + response.headers["X-Real-IP"] = request.remote + response.headers["X-Forwarded-Port"] = str(proxy_port) + await response.prepare(request) + async for data in upstream_response.content.iter_any(): + await response.write(data) + await response.write_eof() + return response + + +app = web.Application() +app.add_routes( + [ + web.route( + hdrs.METH_ANY, r"/{path:json/events|api/v1/events}", partial(forward, tornado_port) + ), + web.route(hdrs.METH_ANY, r"/{path:webpack/.*}", partial(forward, webpack_port)), + web.route(hdrs.METH_ANY, r"/{path:.*}", partial(forward, django_port)), + ] +) def print_listeners() -> None: @@ -365,13 +306,12 @@ def print_listeners() -> None: print() +runner: web.AppRunner children: List["subprocess.Popen[bytes]"] = [] async def serve() -> None: - global client - - AsyncIOMainLoop().install() + global runner, session if options.test: do_one_time_webpack_compile() @@ -380,10 +320,12 @@ async def serve() -> None: children.extend(subprocess.Popen(cmd) for cmd in server_processes()) - client = httpclient.AsyncHTTPClient() - app = Application(enable_logging=options.enable_tornado_logging) + session = aiohttp.ClientSession() + runner = web.AppRunner(app, auto_decompress=False, handler_cancellation=True) + await runner.setup() + site = web.TCPSite(runner, host=options.interface, port=proxy_port) try: - app.listen(proxy_port, address=options.interface) + await site.start() except OSError as e: if e.errno == errno.EADDRINUSE: print("\n\nERROR: You probably have another server running!!!\n\n") @@ -400,6 +342,9 @@ try: loop.add_signal_handler(s, loop.stop) loop.run_forever() finally: + loop.run_until_complete(runner.cleanup()) + loop.run_until_complete(session.close()) + for child in children: child.terminate() diff --git a/zerver/tornado/handlers.py b/zerver/tornado/handlers.py index 00a1f9694e..70180939f5 100644 --- a/zerver/tornado/handlers.py +++ b/zerver/tornado/handlers.py @@ -214,11 +214,6 @@ class AsyncDjangoHandler(tornado.web.RequestHandler): def on_connection_close(self) -> None: # Register a Tornado handler that runs when client-side # connections are closed to notify the events system. - # - # Note that in the development environment, the development - # proxy does not correctly close connections to Tornado when - # its clients (e.g. `curl`) close their connections. This - # code path is thus _unreachable except in production_. # If the client goes away, garbage collect the handler (with # attached request information).