run-dev: Rewrite development proxy with aiohttp.

This allows request cancellation to be propagated to the server.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2023-12-11 13:44:55 -08:00 committed by Tim Abbott
parent c1988a14a7
commit 55b26da82b
4 changed files with 89 additions and 145 deletions

View File

@ -8,6 +8,9 @@
# moto s3 mock # moto s3 mock
moto[s3] moto[s3]
# For tools/run-dev
aiohttp
# Needed for documentation links test # Needed for documentation links test
Scrapy Scrapy

View File

@ -88,6 +88,7 @@ aiohttp==3.9.1 \
--hash=sha256:f800164276eec54e0af5c99feb9494c295118fc10a11b997bbb1348ba1a52065 \ --hash=sha256:f800164276eec54e0af5c99feb9494c295118fc10a11b997bbb1348ba1a52065 \
--hash=sha256:ffcd828e37dc219a72c9012ec44ad2e7e3066bec6ff3aaa19e7d435dbf4032ca --hash=sha256:ffcd828e37dc219a72c9012ec44ad2e7e3066bec6ff3aaa19e7d435dbf4032ca
# via # via
# -r requirements/dev.in
# aiohttp-retry # aiohttp-retry
# twilio # twilio
aiohttp-retry==2.8.3 \ aiohttp-retry==2.8.3 \

View File

@ -2,13 +2,13 @@
import argparse import argparse
import asyncio import asyncio
import errno import errno
import logging
import os import os
import pwd import pwd
import signal import signal
import subprocess import subprocess
import sys import sys
from typing import List, Sequence from typing import List
from urllib.parse import urlunsplit
TOOLS_DIR = os.path.dirname(os.path.abspath(__file__)) TOOLS_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(TOOLS_DIR)) sys.path.insert(0, os.path.dirname(TOOLS_DIR))
@ -18,9 +18,9 @@ from tools.lib import sanity_check
sanity_check.check_venv(__file__) sanity_check.check_venv(__file__)
from tornado import httpclient, httputil, web import aiohttp
from tornado.platform.asyncio import AsyncIOMainLoop from aiohttp import hdrs, web
from typing_extensions import override from returns.curry import partial
from tools.lib.test_script import add_provision_check_override_param, assert_provisioning_status_ok from tools.lib.test_script import add_provision_check_override_param, assert_provisioning_status_ok
@ -54,11 +54,6 @@ parser.add_argument(
help="Do not clear memcached on startup", help="Do not clear memcached on startup",
) )
parser.add_argument("--streamlined", action="store_true", help="Avoid process_queue, etc.") parser.add_argument("--streamlined", action="store_true", help="Avoid process_queue, etc.")
parser.add_argument(
"--enable-tornado-logging",
action="store_true",
help="Enable access logs from tornado proxy server.",
)
parser.add_argument( parser.add_argument(
"--behind-https-proxy", "--behind-https-proxy",
action="store_true", action="store_true",
@ -204,135 +199,81 @@ def start_webpack_watcher() -> "subprocess.Popen[bytes]":
return subprocess.Popen(webpack_cmd) return subprocess.Popen(webpack_cmd)
def transform_url(protocol: str, path: str, query: str, target_port: int, target_host: str) -> str: session: aiohttp.ClientSession
# generate url with target host
host = ":".join((target_host, str(target_port))) # https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1
# Here we are going to rewrite the path a bit so that it is in parity with HOP_BY_HOP_HEADERS = {
# what we will have for production hdrs.CONNECTION,
newpath = urlunsplit((protocol, host, path, query, "")) hdrs.KEEP_ALIVE,
return newpath hdrs.PROXY_AUTHENTICATE,
hdrs.PROXY_AUTHORIZATION,
hdrs.TE,
hdrs.TRAILER,
hdrs.TRANSFER_ENCODING,
hdrs.UPGRADE,
}
# Headers that aiohttp would otherwise generate by default
SKIP_AUTO_HEADERS = {
hdrs.ACCEPT,
hdrs.ACCEPT_ENCODING,
hdrs.CONTENT_TYPE,
hdrs.USER_AGENT,
}
client: httpclient.AsyncHTTPClient async def forward(upstream_port: int, request: web.Request) -> web.StreamResponse:
try:
upstream_response = await session.request(
class BaseHandler(web.RequestHandler): request.method,
# target server ip request.url.with_host("127.0.0.1").with_port(upstream_port),
target_host: str = "127.0.0.1" headers=[
# target server port (key, value)
target_port: int for key, value in request.headers.items()
if key not in HOP_BY_HOP_HEADERS
def _add_request_headers(
self,
exclude_lower_headers_list: Sequence[str] = [],
) -> httputil.HTTPHeaders:
headers = httputil.HTTPHeaders()
for header, v in self.request.headers.get_all():
if header.lower() not in exclude_lower_headers_list:
headers.add(header, v)
return headers
@override
def get(self) -> None:
pass
@override
def head(self) -> None:
pass
@override
def post(self) -> None:
pass
@override
def put(self) -> None:
pass
@override
def patch(self) -> None:
pass
@override
def options(self) -> None:
pass
@override
def delete(self) -> None:
pass
@override
async def prepare(self) -> None:
assert self.request.method is not None
assert self.request.remote_ip is not None
if "X-REAL-IP" not in self.request.headers:
self.request.headers["X-REAL-IP"] = self.request.remote_ip
if "X-FORWARDED_PORT" not in self.request.headers:
self.request.headers["X-FORWARDED-PORT"] = str(proxy_port)
url = transform_url(
self.request.protocol,
self.request.path,
self.request.query,
self.target_port,
self.target_host,
)
try:
request = httpclient.HTTPRequest(
url=url,
method=self.request.method,
headers=self._add_request_headers(["upgrade-insecure-requests"]),
follow_redirects=False,
body=self.request.body,
allow_nonstandard_methods=True,
# use large timeouts to handle polling requests
connect_timeout=240.0,
request_timeout=240.0,
# https://github.com/tornadoweb/tornado/issues/2743
decompress_response=False,
)
response = await client.fetch(request, raise_error=False)
self.set_status(response.code, response.reason)
self._headers = httputil.HTTPHeaders() # clear tornado default header
for header, v in response.headers.get_all():
# some header appear multiple times, eg 'Set-Cookie'
if header.lower() != "transfer-encoding":
self.add_header(header, v)
if response.body:
self.write(response.body)
except (ConnectionError, httpclient.HTTPError) as e:
self.set_status(500)
self.write("Internal server error:\n" + str(e))
class WebPackHandler(BaseHandler):
target_port = webpack_port
class DjangoHandler(BaseHandler):
target_port = django_port
class TornadoHandler(BaseHandler):
target_port = tornado_port
class Application(web.Application):
def __init__(self, enable_logging: bool = False) -> None:
super().__init__(
[
(r"/json/events.*", TornadoHandler),
(r"/api/v1/events.*", TornadoHandler),
(r"/webpack.*", WebPackHandler),
(r"/.*", DjangoHandler),
], ],
enable_logging=enable_logging, data=request.content.iter_any() if request.body_exists else None,
allow_redirects=False,
auto_decompress=False,
compress=False,
skip_auto_headers=SKIP_AUTO_HEADERS,
) )
except aiohttp.ClientError as error:
logging.error(
"Failed to forward %s %s to port %d: %s",
request.method,
request.url.path,
upstream_port,
error,
)
raise web.HTTPBadGateway from error
@override response = web.StreamResponse(status=upstream_response.status, reason=upstream_response.reason)
def log_request(self, handler: web.RequestHandler) -> None: response.headers.extend(
if self.settings["enable_logging"]: (key, value)
super().log_request(handler) for key, value in upstream_response.headers.items()
if key not in HOP_BY_HOP_HEADERS
)
assert request.remote is not None
response.headers["X-Real-IP"] = request.remote
response.headers["X-Forwarded-Port"] = str(proxy_port)
await response.prepare(request)
async for data in upstream_response.content.iter_any():
await response.write(data)
await response.write_eof()
return response
app = web.Application()
app.add_routes(
[
web.route(
hdrs.METH_ANY, r"/{path:json/events|api/v1/events}", partial(forward, tornado_port)
),
web.route(hdrs.METH_ANY, r"/{path:webpack/.*}", partial(forward, webpack_port)),
web.route(hdrs.METH_ANY, r"/{path:.*}", partial(forward, django_port)),
]
)
def print_listeners() -> None: def print_listeners() -> None:
@ -365,13 +306,12 @@ def print_listeners() -> None:
print() print()
runner: web.AppRunner
children: List["subprocess.Popen[bytes]"] = [] children: List["subprocess.Popen[bytes]"] = []
async def serve() -> None: async def serve() -> None:
global client global runner, session
AsyncIOMainLoop().install()
if options.test: if options.test:
do_one_time_webpack_compile() do_one_time_webpack_compile()
@ -380,10 +320,12 @@ async def serve() -> None:
children.extend(subprocess.Popen(cmd) for cmd in server_processes()) children.extend(subprocess.Popen(cmd) for cmd in server_processes())
client = httpclient.AsyncHTTPClient() session = aiohttp.ClientSession()
app = Application(enable_logging=options.enable_tornado_logging) runner = web.AppRunner(app, auto_decompress=False, handler_cancellation=True)
await runner.setup()
site = web.TCPSite(runner, host=options.interface, port=proxy_port)
try: try:
app.listen(proxy_port, address=options.interface) await site.start()
except OSError as e: except OSError as e:
if e.errno == errno.EADDRINUSE: if e.errno == errno.EADDRINUSE:
print("\n\nERROR: You probably have another server running!!!\n\n") print("\n\nERROR: You probably have another server running!!!\n\n")
@ -400,6 +342,9 @@ try:
loop.add_signal_handler(s, loop.stop) loop.add_signal_handler(s, loop.stop)
loop.run_forever() loop.run_forever()
finally: finally:
loop.run_until_complete(runner.cleanup())
loop.run_until_complete(session.close())
for child in children: for child in children:
child.terminate() child.terminate()

View File

@ -214,11 +214,6 @@ class AsyncDjangoHandler(tornado.web.RequestHandler):
def on_connection_close(self) -> None: def on_connection_close(self) -> None:
# Register a Tornado handler that runs when client-side # Register a Tornado handler that runs when client-side
# connections are closed to notify the events system. # connections are closed to notify the events system.
#
# Note that in the development environment, the development
# proxy does not correctly close connections to Tornado when
# its clients (e.g. `curl`) close their connections. This
# code path is thus _unreachable except in production_.
# If the client goes away, garbage collect the handler (with # If the client goes away, garbage collect the handler (with
# attached request information). # attached request information).