zerver/tornado: Use python 3 syntax for typing.

This commit is contained in:
rht 2017-10-26 11:38:28 +02:00 committed by Tim Abbott
parent e296841447
commit 1047733486
7 changed files with 59 additions and 102 deletions

View File

@ -10,16 +10,14 @@ from zerver.lib.queue import get_queue_client
import tornado.autoreload
import tornado.web
def setup_tornado_rabbitmq():
# type: () -> None
def setup_tornado_rabbitmq() -> None:
# When tornado is shut down, disconnect cleanly from rabbitmq
if settings.USING_RABBITMQ:
queue_client = get_queue_client()
atexit.register(lambda: queue_client.close())
tornado.autoreload.add_reload_hook(lambda: queue_client.close())
def create_tornado_application():
# type: () -> tornado.web.Application
def create_tornado_application() -> tornado.web.Application:
urls = (r"/notify_tornado",
r"/json/events",
r"/api/v1/events",

View File

@ -8,11 +8,9 @@ class BadEventQueueIdError(JsonableError):
code = ErrorCode.BAD_EVENT_QUEUE_ID
data_fields = ['queue_id']
def __init__(self, queue_id):
# type: (Text) -> None
def __init__(self, queue_id: Text) -> None:
self.queue_id = queue_id # type: Text
@staticmethod
def msg_format():
# type: () -> Text
def msg_format() -> Text:
return _("Bad event queue id: {queue_id}")

View File

@ -25,30 +25,26 @@ from zerver.tornado.descriptors import get_descriptor_by_handler_id
from typing import Any, Callable, Dict, List, Optional
current_handler_id = 0
handlers = {} # type: Dict[int, AsyncDjangoHandler]
handlers = {} # type: Dict[int, 'AsyncDjangoHandler']
def get_handler_by_id(handler_id):
# type: (int) -> AsyncDjangoHandler
def get_handler_by_id(handler_id: int) -> 'AsyncDjangoHandler':
return handlers[handler_id]
def allocate_handler_id(handler):
# type: (AsyncDjangoHandler) -> int
def allocate_handler_id(handler: 'AsyncDjangoHandler') -> int:
global current_handler_id
handlers[current_handler_id] = handler
handler.handler_id = current_handler_id
current_handler_id += 1
return handler.handler_id
def clear_handler_by_id(handler_id):
# type: (int) -> None
def clear_handler_by_id(handler_id: int) -> None:
del handlers[handler_id]
def handler_stats_string():
# type: () -> str
def handler_stats_string() -> str:
return "%s handlers, latest ID %s" % (len(handlers), current_handler_id)
def finish_handler(handler_id, event_queue_id, contents, apply_markdown):
# type: (int, str, List[Dict[str, Any]], bool) -> None
def finish_handler(handler_id: int, event_queue_id: str,
contents: List[Dict[str, Any]], apply_markdown: bool) -> None:
err_msg = "Got error finishing handler for queue %s" % (event_queue_id,)
try:
# We call async_request_restart here in case we are
@ -80,8 +76,7 @@ def finish_handler(handler_id, event_queue_id, contents, apply_markdown):
class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler):
initLock = Lock()
def __init__(self, *args, **kwargs):
# type: (*Any, **Any) -> None
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(AsyncDjangoHandler, self).__init__(*args, **kwargs)
# Set up middleware if needed. We couldn't do this earlier, because
@ -97,13 +92,11 @@ class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler):
# be cleared when the handler finishes its response
allocate_handler_id(self)
def __repr__(self):
# type: () -> str
def __repr__(self) -> str:
descriptor = get_descriptor_by_handler_id(self.handler_id)
return "AsyncDjangoHandler<%s, %s>" % (self.handler_id, descriptor)
def load_middleware(self):
# type: () -> None
def load_middleware(self) -> None:
"""
Populate middleware lists from settings.MIDDLEWARE. This is copied
from Django. This uses settings.MIDDLEWARE setting with the old
@ -149,8 +142,7 @@ class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler):
# as a flag for initialization being complete.
self._middleware_chain = handler
def get(self, *args, **kwargs):
# type: (*Any, **Any) -> None
def get(self, *args: Any, **kwargs: Any) -> None:
environ = WSGIContainer.environ(self.request)
environ['PATH_INFO'] = urllib.parse.unquote(environ['PATH_INFO'])
request = WSGIRequest(environ)
@ -177,27 +169,22 @@ class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler):
self.write(response.content)
self.finish()
def head(self, *args, **kwargs):
# type: (*Any, **Any) -> None
def head(self, *args: Any, **kwargs: Any) -> None:
self.get(*args, **kwargs)
def post(self, *args, **kwargs):
# type: (*Any, **Any) -> None
def post(self, *args: Any, **kwargs: Any) -> None:
self.get(*args, **kwargs)
def delete(self, *args, **kwargs):
# type: (*Any, **Any) -> None
def delete(self, *args: Any, **kwargs: Any) -> None:
self.get(*args, **kwargs)
def on_connection_close(self):
# type: () -> None
def on_connection_close(self) -> None:
client_descriptor = get_descriptor_by_handler_id(self.handler_id)
if client_descriptor is not None:
client_descriptor.disconnect_handler(client_closed=True)
# Based on django.core.handlers.base: get_response
def get_response(self, request):
# type: (HttpRequest) -> HttpResponse
def get_response(self, request: HttpRequest) -> HttpResponse:
"Returns an HttpResponse object for the given HttpRequest"
try:
try:
@ -321,8 +308,8 @@ class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler):
return response
### Copied from get_response (above in this file)
def apply_response_middleware(self, request, response, resolver):
# type: (HttpRequest, HttpResponse, urlresolvers.RegexURLResolver) -> HttpResponse
def apply_response_middleware(self, request: HttpRequest, response: HttpResponse,
resolver: urlresolvers.RegexURLResolver) -> HttpResponse:
try:
# Apply response middleware, regardless of the response
for middleware_method in self._response_middleware:
@ -335,8 +322,8 @@ class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler):
return response
def zulip_finish(self, response, request, apply_markdown):
# type: (Dict[str, Any], HttpRequest, bool) -> None
def zulip_finish(self, response: Dict[str, Any], request: HttpRequest,
apply_markdown: bool) -> None:
# Make sure that Markdown rendering really happened, if requested.
# This is a security issue because it's where we escape HTML.
# c.f. ticket #64

View File

@ -16,8 +16,7 @@ class InstrumentedPollIOLoop(PollIOLoop):
def initialize(self, **kwargs): # type: ignore # TODO investigate likely buggy monkey patching here
super(InstrumentedPollIOLoop, self).initialize(impl=InstrumentedPoll(), **kwargs)
def instrument_tornado_ioloop():
# type: () -> None
def instrument_tornado_ioloop() -> None:
IOLoop.configure(InstrumentedPollIOLoop)
# A hack to keep track of how much time we spend working, versus sleeping in
@ -29,8 +28,7 @@ def instrument_tornado_ioloop():
# runs that might instantiate the default event loop.
class InstrumentedPoll(object):
def __init__(self):
# type: () -> None
def __init__(self) -> None:
self._underlying = orig_poll_impl()
self._times = [] # type: List[Tuple[float, float]]
self._last_print = 0.0
@ -38,13 +36,11 @@ class InstrumentedPoll(object):
# Python won't let us subclass e.g. select.epoll, so instead
# we proxy every method. __getattr__ handles anything we
# don't define elsewhere.
def __getattr__(self, name):
# type: (str) -> Any
def __getattr__(self, name: str) -> Any:
return getattr(self._underlying, name)
# Call the underlying poll method, and report timing data.
def poll(self, timeout):
# type: (float) -> Any
def poll(self, timeout: float) -> Any:
# Avoid accumulating a bunch of insignificant data points
# from short timeouts.

View File

@ -9,8 +9,7 @@ try:
from django.middleware.csrf import _compare_salted_tokens
except ImportError:
# This function was added in Django 1.10.
def _compare_salted_tokens(token1, token2):
# type: (str, str) -> bool
def _compare_salted_tokens(token1: str, token2: str) -> bool:
return token1 == token2
import sockjs.tornado
@ -34,8 +33,7 @@ from zerver.tornado.exceptions import BadEventQueueIdError
logger = logging.getLogger('zulip.socket')
def get_user_profile(session_id):
# type: (Optional[Text]) -> Optional[UserProfile]
def get_user_profile(session_id: Optional[Text]) -> Optional[UserProfile]:
if session_id is None:
return None
@ -50,14 +48,12 @@ def get_user_profile(session_id):
except (UserProfile.DoesNotExist, KeyError):
return None
connections = dict() # type: Dict[Union[int, str], SocketConnection]
connections = dict() # type: Dict[Union[int, str], 'SocketConnection']
def get_connection(id):
# type: (Union[int, str]) -> Optional[SocketConnection]
def get_connection(id: Union[int, str]) -> Optional['SocketConnection']:
return connections.get(id)
def register_connection(id, conn):
# type: (Union[int, str], SocketConnection) -> None
def register_connection(id: Union[int, str], conn: 'SocketConnection') -> None:
# Kill any old connections if they exist
if id in connections:
connections[id].close()
@ -65,28 +61,24 @@ def register_connection(id, conn):
conn.client_id = id
connections[conn.client_id] = conn
def deregister_connection(conn):
# type: (SocketConnection) -> None
def deregister_connection(conn: 'SocketConnection') -> None:
assert conn.client_id is not None
del connections[conn.client_id]
redis_client = get_redis_client()
def req_redis_key(req_id):
# type: (Text) -> Text
def req_redis_key(req_id: Text) -> Text:
return u'socket_req_status:%s' % (req_id,)
class CloseErrorInfo(object):
def __init__(self, status_code, err_msg):
# type: (int, str) -> None
def __init__(self, status_code: int, err_msg: str) -> None:
self.status_code = status_code
self.err_msg = err_msg
class SocketConnection(sockjs.tornado.SockJSConnection):
client_id = None # type: Optional[Union[int, str]]
def on_open(self, info):
# type: (ConnectionInfo) -> None
def on_open(self, info: ConnectionInfo) -> None:
log_data = dict(extra='[transport=%s]' % (self.session.transport_name,))
record_request_start_data(log_data)
@ -108,8 +100,7 @@ class SocketConnection(sockjs.tornado.SockJSConnection):
ioloop.add_callback(self.close)
return
def auth_timeout():
# type: () -> None
def auth_timeout() -> None:
self.close_info = CloseErrorInfo(408, "Timeout while waiting for authentication")
self.close()
@ -117,8 +108,7 @@ class SocketConnection(sockjs.tornado.SockJSConnection):
write_log_line(log_data, path='/socket/open', method='SOCKET',
remote_ip=info.ip, email='unknown', client_name='?')
def authenticate_client(self, msg):
# type: (Dict[str, Any]) -> None
def authenticate_client(self, msg: Dict[str, Any]) -> None:
if self.authenticated:
self.session.send_message({'req_id': msg['req_id'], 'type': 'response',
'response': {'result': 'error',
@ -173,8 +163,7 @@ class SocketConnection(sockjs.tornado.SockJSConnection):
ioloop = tornado.ioloop.IOLoop.instance()
ioloop.remove_timeout(self.timeout_handle)
def on_message(self, msg_raw):
# type: (str) -> None
def on_message(self, msg_raw: str) -> None:
log_data = dict(extra='[transport=%s' % (self.session.transport_name,))
record_request_start_data(log_data)
msg = ujson.loads(msg_raw)
@ -233,8 +222,7 @@ class SocketConnection(sockjs.tornado.SockJSConnection):
request_environ=dict(REMOTE_ADDR=self.session.conn_info.ip))),
fake_message_sender)
def on_close(self):
# type: () -> None
def on_close(self) -> None:
log_data = dict(extra='[transport=%s]' % (self.session.transport_name,))
record_request_start_data(log_data)
if self.close_info is not None:
@ -252,8 +240,7 @@ class SocketConnection(sockjs.tornado.SockJSConnection):
self.did_close = True
def fake_message_sender(event):
# type: (Dict[str, Any]) -> None
def fake_message_sender(event: Dict[str, Any]) -> None:
"""This function is used only for Casper and backend tests, where
rabbitmq is disabled"""
log_data = dict() # type: Dict[str, Any]
@ -280,8 +267,7 @@ def fake_message_sender(event):
'server_meta': server_meta}
respond_send_message(result)
def respond_send_message(data):
# type: (Mapping[str, Any]) -> None
def respond_send_message(data: Mapping[str, Any]) -> None:
log_data = data['server_meta']['log_data']
record_request_restart_data(log_data)
@ -314,6 +300,5 @@ sockjs_router = sockjs.tornado.SockJSRouter(SocketConnection, "/sockjs",
{'sockjs_url': 'https://%s/static/third/sockjs/sockjs-0.3.4.js' % (
settings.EXTERNAL_HOST,),
'disabled_transports': ['eventsource', 'htmlfile']})
def get_sockjs_router():
# type: () -> sockjs.tornado.SockJSRouter
def get_sockjs_router() -> sockjs.tornado.SockJSRouter:
return sockjs_router

View File

@ -20,14 +20,13 @@ import time
import ujson
@internal_notify_view(True)
def notify(request):
# type: (HttpRequest) -> HttpResponse
def notify(request: HttpRequest) -> HttpResponse:
process_notification(ujson.loads(request.POST['data']))
return json_success()
@has_request_variables
def cleanup_event_queue(request, user_profile, queue_id=REQ()):
# type: (HttpRequest, UserProfile, Text) -> HttpResponse
def cleanup_event_queue(request: HttpRequest, user_profile: UserProfile,
queue_id: Text=REQ()) -> HttpResponse:
client = get_client_descriptor(str(queue_id))
if client is None:
raise BadEventQueueIdError(queue_id)

View File

@ -40,8 +40,7 @@ class WebsocketClient(object):
self.scheme_dict = {'http': 'ws', 'https': 'wss'}
self.ws = None # type: Optional[WebSocketClientConnection]
def _login(self):
# type: () -> Dict[str,str]
def _login(self) -> Dict[str, str]:
# Ideally, we'd migrate this to use API auth instead of
# stealing cookies, but this works for now.
@ -57,14 +56,13 @@ class WebsocketClient(object):
settings.SESSION_COOKIE_NAME: session.session_key,
settings.CSRF_COOKIE_NAME: _get_new_csrf_token()}
def _get_cookie_header(self, cookies):
# type: (Dict[Any, Any]) -> str
def _get_cookie_header(self, cookies: Dict[Any, Any]) -> str:
return ';'.join(
["{}={}".format(name, value) for name, value in cookies.items()])
@gen.coroutine
def _websocket_auth(self, queue_events_data, cookies):
# type: (Dict[str, Dict[str, str]], SimpleCookie) -> Generator[str, str, None]
def _websocket_auth(self, queue_events_data: Dict[str, Dict[str, str]],
cookies: SimpleCookie) -> Generator[str, str, None]:
message = {
"req_id": self._get_request_id(),
"type": "auth",
@ -80,15 +78,13 @@ class WebsocketClient(object):
response_message = yield self.ws.read_message()
raise gen.Return([response_ack, response_message])
def _get_queue_events(self, cookies_header):
# type: (str) -> Dict[str, str]
def _get_queue_events(self, cookies_header: str) -> Dict[str, str]:
url = urljoin(self.parsed_host_url.geturl(), '/json/events?dont_block=true')
response = requests.get(url, headers={'Cookie': cookies_header}, verify=self.validate_ssl)
return response.json()
@gen.engine
def connect(self):
# type: () -> Generator[str, WebSocketClientConnection, None]
def connect(self) -> Generator[str, WebSocketClientConnection, None]:
try:
request = HTTPRequest(url=self._get_websocket_url(), validate_cert=self.validate_ssl)
request.headers.add('Cookie', self.cookie_str)
@ -102,8 +98,9 @@ class WebsocketClient(object):
IOLoop.instance().stop()
@gen.coroutine
def send_message(self, client, type, subject, stream, private_message_recepient, content=""):
# type: (str, str, str, str, str, str) -> Generator[str, WebSocketClientConnection, None]
def send_message(self, client: str, type: str, subject: str, stream: str,
private_message_recepient: str,
content: str="") -> Generator[str, WebSocketClientConnection, None]:
user_message = {
"req_id": self._get_request_id(),
"type": "request",
@ -126,17 +123,14 @@ class WebsocketClient(object):
response_message = yield self.ws.read_message()
raise gen.Return([response_ack, response_message])
def run(self):
# type: () -> None
def run(self) -> None:
self.ioloop_instance.add_callback(self.connect)
self.ioloop_instance.start()
def _get_websocket_url(self):
# type: () -> str
def _get_websocket_url(self) -> str:
return '{}://{}{}'.format(self.scheme_dict[self.parsed_host_url.scheme],
self.parsed_host_url.netloc, self.sockjs_url)
def _get_request_id(self):
# type: () -> Iterable[str]
def _get_request_id(self) -> Iterable[str]:
self.request_id_number += 1
return ':'.join((self.events_data['queue_id'], str(self.request_id_number)))