test_tornado: Avoid deprecated AsyncHTTPTestCase.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2022-07-05 14:28:31 -07:00 committed by Tim Abbott
parent b4cf9ad777
commit 6c79b8f2f1
1 changed files with 21 additions and 35 deletions

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import socket
import urllib.parse import urllib.parse
from functools import wraps from functools import wraps
from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar
@ -10,11 +11,9 @@ from django.conf import settings
from django.core import signals from django.core import signals
from django.db import close_old_connections from django.db import close_old_connections
from django.test import override_settings from django.test import override_settings
from tornado.httpclient import HTTPResponse from tornado import netutil
from tornado.ioloop import IOLoop from tornado.httpclient import AsyncHTTPClient, HTTPResponse
from tornado.platform.asyncio import AsyncIOMainLoop from tornado.httpserver import HTTPServer
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase
from tornado.web import Application
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
@ -38,51 +37,41 @@ async def in_django_thread(f: Callable[[], T]) -> T:
return await asyncio.create_task(sync_to_async(f)()) return await asyncio.create_task(sync_to_async(f)())
class TornadoWebTestCase(AsyncHTTPTestCase, ZulipTestCase): class TornadoWebTestCase(ZulipTestCase):
@async_to_sync_decorator @async_to_sync_decorator
async def setUp(self) -> None: async def setUp(self) -> None:
super().setUp() super().setUp()
with override_settings(DEBUG=False):
self.http_server = HTTPServer(create_tornado_application())
sock = netutil.bind_sockets(0, "127.0.0.1", family=socket.AF_INET)[0]
self.port = sock.getsockname()[1]
self.http_server.add_sockets([sock])
self.http_client = AsyncHTTPClient()
signals.request_started.disconnect(close_old_connections) signals.request_started.disconnect(close_old_connections)
signals.request_finished.disconnect(close_old_connections) signals.request_finished.disconnect(close_old_connections)
self.session_cookie: Optional[Dict[str, str]] = None self.session_cookie: Optional[Dict[str, str]] = None
@async_to_sync_decorator @async_to_sync_decorator
async def tearDown(self) -> None: async def tearDown(self) -> None:
# Skip tornado.testing.AsyncTestCase.tearDown because it tries to kill self.http_client.close()
# the current task. self.http_server.stop()
super(AsyncTestCase, self).tearDown() super().tearDown()
def run(self, result: Optional[TestResult] = None) -> Optional[TestResult]: def run(self, result: Optional[TestResult] = None) -> Optional[TestResult]:
return async_to_sync( return async_to_sync(
sync_to_async(super().run, thread_sensitive=False), force_new_loop=True sync_to_async(super().run, thread_sensitive=False), force_new_loop=True
)(result) )(result)
def get_new_ioloop(self) -> IOLoop:
return AsyncIOMainLoop()
@override_settings(DEBUG=False)
def get_app(self) -> Application:
return create_tornado_application()
async def tornado_client_get(self, path: str, **kwargs: Any) -> HTTPResponse:
self.add_session_cookie(kwargs)
self.set_http_headers(kwargs, skip_user_agent=True)
if "HTTP_HOST" in kwargs:
kwargs["headers"]["Host"] = kwargs["HTTP_HOST"]
del kwargs["HTTP_HOST"]
return await self.http_client.fetch(self.get_url(path), method="GET", **kwargs)
async def fetch_async(self, method: str, path: str, **kwargs: Any) -> HTTPResponse: async def fetch_async(self, method: str, path: str, **kwargs: Any) -> HTTPResponse:
self.add_session_cookie(kwargs) self.add_session_cookie(kwargs)
self.set_http_headers(kwargs, skip_user_agent=True) self.set_http_headers(kwargs, skip_user_agent=True)
if "HTTP_HOST" in kwargs: if "HTTP_HOST" in kwargs:
kwargs["headers"]["Host"] = kwargs["HTTP_HOST"] kwargs["headers"]["Host"] = kwargs["HTTP_HOST"]
del kwargs["HTTP_HOST"] del kwargs["HTTP_HOST"]
return await self.http_client.fetch(self.get_url(path), method=method, **kwargs) return await self.http_client.fetch(
f"http://127.0.0.1:{self.port}{path}", method=method, **kwargs
async def client_get_async(self, path: str, **kwargs: Any) -> HTTPResponse: )
self.set_http_headers(kwargs, skip_user_agent=True)
return await self.fetch_async("GET", path, **kwargs)
def login_user(self, *args: Any, **kwargs: Any) -> None: def login_user(self, *args: Any, **kwargs: Any) -> None:
super().login_user(*args, **kwargs) super().login_user(*args, **kwargs)
@ -102,10 +91,7 @@ class TornadoWebTestCase(AsyncHTTPTestCase, ZulipTestCase):
kwargs["headers"] = headers kwargs["headers"] = headers
async def create_queue(self, **kwargs: Any) -> str: async def create_queue(self, **kwargs: Any) -> str:
response = await self.tornado_client_get( response = await self.fetch_async("GET", "/json/events?dont_block=true", subdomain="zulip")
"/json/events?dont_block=true",
subdomain="zulip",
)
self.assertEqual(response.code, 200) self.assertEqual(response.code, 200)
body = orjson.loads(response.body) body = orjson.loads(response.body)
self.assertEqual(body["events"], []) self.assertEqual(body["events"], [])
@ -142,11 +128,11 @@ class EventsTestCase(TornadoWebTestCase):
def wrapped_fetch_events(**query: Any) -> Dict[str, Any]: def wrapped_fetch_events(**query: Any) -> Dict[str, Any]:
ret = event_queue.fetch_events(**query) ret = event_queue.fetch_events(**query)
self.io_loop.add_callback(process_events) asyncio.get_running_loop().call_soon(process_events)
return ret return ret
with mock.patch("zerver.tornado.views.fetch_events", side_effect=wrapped_fetch_events): with mock.patch("zerver.tornado.views.fetch_events", side_effect=wrapped_fetch_events):
response = await self.client_get_async(path) response = await self.fetch_async("GET", path)
self.assertEqual(response.headers["Vary"], "Accept-Language, Cookie") self.assertEqual(response.headers["Vary"], "Accept-Language, Cookie")
data = orjson.loads(response.body) data = orjson.loads(response.body)