run-dev: Rewrite development proxy with aiohttp.

This allows request cancellation to be propagated to the server.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2023-12-11 13:44:55 -08:00 committed by Tim Abbott
parent c1988a14a7
commit 55b26da82b
4 changed files with 89 additions and 145 deletions

View File

@ -8,6 +8,9 @@
# moto s3 mock
moto[s3]
# For tools/run-dev
aiohttp
# Needed for documentation links test
Scrapy

View File

@ -88,6 +88,7 @@ aiohttp==3.9.1 \
--hash=sha256:f800164276eec54e0af5c99feb9494c295118fc10a11b997bbb1348ba1a52065 \
--hash=sha256:ffcd828e37dc219a72c9012ec44ad2e7e3066bec6ff3aaa19e7d435dbf4032ca
# via
# -r requirements/dev.in
# aiohttp-retry
# twilio
aiohttp-retry==2.8.3 \

View File

@ -2,13 +2,13 @@
import argparse
import asyncio
import errno
import logging
import os
import pwd
import signal
import subprocess
import sys
from typing import List, Sequence
from urllib.parse import urlunsplit
from typing import List
TOOLS_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(TOOLS_DIR))
@ -18,9 +18,9 @@ from tools.lib import sanity_check
sanity_check.check_venv(__file__)
from tornado import httpclient, httputil, web
from tornado.platform.asyncio import AsyncIOMainLoop
from typing_extensions import override
import aiohttp
from aiohttp import hdrs, web
from returns.curry import partial
from tools.lib.test_script import add_provision_check_override_param, assert_provisioning_status_ok
@ -54,11 +54,6 @@ parser.add_argument(
help="Do not clear memcached on startup",
)
parser.add_argument("--streamlined", action="store_true", help="Avoid process_queue, etc.")
parser.add_argument(
"--enable-tornado-logging",
action="store_true",
help="Enable access logs from tornado proxy server.",
)
parser.add_argument(
"--behind-https-proxy",
action="store_true",
@ -204,135 +199,81 @@ def start_webpack_watcher() -> "subprocess.Popen[bytes]":
return subprocess.Popen(webpack_cmd)
def transform_url(protocol: str, path: str, query: str, target_port: int, target_host: str) -> str:
# generate url with target host
host = ":".join((target_host, str(target_port)))
# Here we are going to rewrite the path a bit so that it is in parity with
# what we will have for production
newpath = urlunsplit((protocol, host, path, query, ""))
return newpath
session: aiohttp.ClientSession
# https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1
HOP_BY_HOP_HEADERS = {
hdrs.CONNECTION,
hdrs.KEEP_ALIVE,
hdrs.PROXY_AUTHENTICATE,
hdrs.PROXY_AUTHORIZATION,
hdrs.TE,
hdrs.TRAILER,
hdrs.TRANSFER_ENCODING,
hdrs.UPGRADE,
}
# Headers that aiohttp would otherwise generate by default
SKIP_AUTO_HEADERS = {
hdrs.ACCEPT,
hdrs.ACCEPT_ENCODING,
hdrs.CONTENT_TYPE,
hdrs.USER_AGENT,
}
client: httpclient.AsyncHTTPClient
class BaseHandler(web.RequestHandler):
# target server ip
target_host: str = "127.0.0.1"
# target server port
target_port: int
def _add_request_headers(
self,
exclude_lower_headers_list: Sequence[str] = [],
) -> httputil.HTTPHeaders:
headers = httputil.HTTPHeaders()
for header, v in self.request.headers.get_all():
if header.lower() not in exclude_lower_headers_list:
headers.add(header, v)
return headers
@override
def get(self) -> None:
pass
@override
def head(self) -> None:
pass
@override
def post(self) -> None:
pass
@override
def put(self) -> None:
pass
@override
def patch(self) -> None:
pass
@override
def options(self) -> None:
pass
@override
def delete(self) -> None:
pass
@override
async def prepare(self) -> None:
assert self.request.method is not None
assert self.request.remote_ip is not None
if "X-REAL-IP" not in self.request.headers:
self.request.headers["X-REAL-IP"] = self.request.remote_ip
if "X-FORWARDED_PORT" not in self.request.headers:
self.request.headers["X-FORWARDED-PORT"] = str(proxy_port)
url = transform_url(
self.request.protocol,
self.request.path,
self.request.query,
self.target_port,
self.target_host,
)
async def forward(upstream_port: int, request: web.Request) -> web.StreamResponse:
try:
request = httpclient.HTTPRequest(
url=url,
method=self.request.method,
headers=self._add_request_headers(["upgrade-insecure-requests"]),
follow_redirects=False,
body=self.request.body,
allow_nonstandard_methods=True,
# use large timeouts to handle polling requests
connect_timeout=240.0,
request_timeout=240.0,
# https://github.com/tornadoweb/tornado/issues/2743
decompress_response=False,
)
response = await client.fetch(request, raise_error=False)
self.set_status(response.code, response.reason)
self._headers = httputil.HTTPHeaders() # clear tornado default header
for header, v in response.headers.get_all():
# some header appear multiple times, eg 'Set-Cookie'
if header.lower() != "transfer-encoding":
self.add_header(header, v)
if response.body:
self.write(response.body)
except (ConnectionError, httpclient.HTTPError) as e:
self.set_status(500)
self.write("Internal server error:\n" + str(e))
class WebPackHandler(BaseHandler):
target_port = webpack_port
class DjangoHandler(BaseHandler):
target_port = django_port
class TornadoHandler(BaseHandler):
target_port = tornado_port
class Application(web.Application):
def __init__(self, enable_logging: bool = False) -> None:
super().__init__(
[
(r"/json/events.*", TornadoHandler),
(r"/api/v1/events.*", TornadoHandler),
(r"/webpack.*", WebPackHandler),
(r"/.*", DjangoHandler),
upstream_response = await session.request(
request.method,
request.url.with_host("127.0.0.1").with_port(upstream_port),
headers=[
(key, value)
for key, value in request.headers.items()
if key not in HOP_BY_HOP_HEADERS
],
enable_logging=enable_logging,
data=request.content.iter_any() if request.body_exists else None,
allow_redirects=False,
auto_decompress=False,
compress=False,
skip_auto_headers=SKIP_AUTO_HEADERS,
)
except aiohttp.ClientError as error:
logging.error(
"Failed to forward %s %s to port %d: %s",
request.method,
request.url.path,
upstream_port,
error,
)
raise web.HTTPBadGateway from error
@override
def log_request(self, handler: web.RequestHandler) -> None:
if self.settings["enable_logging"]:
super().log_request(handler)
response = web.StreamResponse(status=upstream_response.status, reason=upstream_response.reason)
response.headers.extend(
(key, value)
for key, value in upstream_response.headers.items()
if key not in HOP_BY_HOP_HEADERS
)
assert request.remote is not None
response.headers["X-Real-IP"] = request.remote
response.headers["X-Forwarded-Port"] = str(proxy_port)
await response.prepare(request)
async for data in upstream_response.content.iter_any():
await response.write(data)
await response.write_eof()
return response
app = web.Application()
app.add_routes(
[
web.route(
hdrs.METH_ANY, r"/{path:json/events|api/v1/events}", partial(forward, tornado_port)
),
web.route(hdrs.METH_ANY, r"/{path:webpack/.*}", partial(forward, webpack_port)),
web.route(hdrs.METH_ANY, r"/{path:.*}", partial(forward, django_port)),
]
)
def print_listeners() -> None:
@ -365,13 +306,12 @@ def print_listeners() -> None:
print()
runner: web.AppRunner
children: List["subprocess.Popen[bytes]"] = []
async def serve() -> None:
global client
AsyncIOMainLoop().install()
global runner, session
if options.test:
do_one_time_webpack_compile()
@ -380,10 +320,12 @@ async def serve() -> None:
children.extend(subprocess.Popen(cmd) for cmd in server_processes())
client = httpclient.AsyncHTTPClient()
app = Application(enable_logging=options.enable_tornado_logging)
session = aiohttp.ClientSession()
runner = web.AppRunner(app, auto_decompress=False, handler_cancellation=True)
await runner.setup()
site = web.TCPSite(runner, host=options.interface, port=proxy_port)
try:
app.listen(proxy_port, address=options.interface)
await site.start()
except OSError as e:
if e.errno == errno.EADDRINUSE:
print("\n\nERROR: You probably have another server running!!!\n\n")
@ -400,6 +342,9 @@ try:
loop.add_signal_handler(s, loop.stop)
loop.run_forever()
finally:
loop.run_until_complete(runner.cleanup())
loop.run_until_complete(session.close())
for child in children:
child.terminate()

View File

@ -214,11 +214,6 @@ class AsyncDjangoHandler(tornado.web.RequestHandler):
def on_connection_close(self) -> None:
# Register a Tornado handler that runs when client-side
# connections are closed to notify the events system.
#
# Note that in the development environment, the development
# proxy does not correctly close connections to Tornado when
# its clients (e.g. `curl`) close their connections. This
# code path is thus _unreachable except in production_.
# If the client goes away, garbage collect the handler (with
# attached request information).