diff --git a/zerver/tornado/application.py b/zerver/tornado/application.py index 253881d09b..95c44ee77a 100644 --- a/zerver/tornado/application.py +++ b/zerver/tornado/application.py @@ -10,16 +10,14 @@ from zerver.lib.queue import get_queue_client import tornado.autoreload import tornado.web -def setup_tornado_rabbitmq(): - # type: () -> None +def setup_tornado_rabbitmq() -> None: # When tornado is shut down, disconnect cleanly from rabbitmq if settings.USING_RABBITMQ: queue_client = get_queue_client() atexit.register(lambda: queue_client.close()) tornado.autoreload.add_reload_hook(lambda: queue_client.close()) -def create_tornado_application(): - # type: () -> tornado.web.Application +def create_tornado_application() -> tornado.web.Application: urls = (r"/notify_tornado", r"/json/events", r"/api/v1/events", diff --git a/zerver/tornado/exceptions.py b/zerver/tornado/exceptions.py index 0ba35b4a64..d69c3732f3 100644 --- a/zerver/tornado/exceptions.py +++ b/zerver/tornado/exceptions.py @@ -8,11 +8,9 @@ class BadEventQueueIdError(JsonableError): code = ErrorCode.BAD_EVENT_QUEUE_ID data_fields = ['queue_id'] - def __init__(self, queue_id): - # type: (Text) -> None + def __init__(self, queue_id: Text) -> None: self.queue_id = queue_id # type: Text @staticmethod - def msg_format(): - # type: () -> Text + def msg_format() -> Text: return _("Bad event queue id: {queue_id}") diff --git a/zerver/tornado/handlers.py b/zerver/tornado/handlers.py index fb4f2e3541..6da4e08c8d 100644 --- a/zerver/tornado/handlers.py +++ b/zerver/tornado/handlers.py @@ -25,30 +25,26 @@ from zerver.tornado.descriptors import get_descriptor_by_handler_id from typing import Any, Callable, Dict, List, Optional current_handler_id = 0 -handlers = {} # type: Dict[int, AsyncDjangoHandler] +handlers = {} # type: Dict[int, 'AsyncDjangoHandler'] -def get_handler_by_id(handler_id): - # type: (int) -> AsyncDjangoHandler +def get_handler_by_id(handler_id: int) -> 'AsyncDjangoHandler': return handlers[handler_id] -def allocate_handler_id(handler): - # type: (AsyncDjangoHandler) -> int +def allocate_handler_id(handler: 'AsyncDjangoHandler') -> int: global current_handler_id handlers[current_handler_id] = handler handler.handler_id = current_handler_id current_handler_id += 1 return handler.handler_id -def clear_handler_by_id(handler_id): - # type: (int) -> None +def clear_handler_by_id(handler_id: int) -> None: del handlers[handler_id] -def handler_stats_string(): - # type: () -> str +def handler_stats_string() -> str: return "%s handlers, latest ID %s" % (len(handlers), current_handler_id) -def finish_handler(handler_id, event_queue_id, contents, apply_markdown): - # type: (int, str, List[Dict[str, Any]], bool) -> None +def finish_handler(handler_id: int, event_queue_id: str, + contents: List[Dict[str, Any]], apply_markdown: bool) -> None: err_msg = "Got error finishing handler for queue %s" % (event_queue_id,) try: # We call async_request_restart here in case we are @@ -80,8 +76,7 @@ def finish_handler(handler_id, event_queue_id, contents, apply_markdown): class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler): initLock = Lock() - def __init__(self, *args, **kwargs): - # type: (*Any, **Any) -> None + def __init__(self, *args: Any, **kwargs: Any) -> None: super(AsyncDjangoHandler, self).__init__(*args, **kwargs) # Set up middleware if needed. We couldn't do this earlier, because @@ -97,13 +92,11 @@ class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler): # be cleared when the handler finishes its response allocate_handler_id(self) - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: descriptor = get_descriptor_by_handler_id(self.handler_id) return "AsyncDjangoHandler<%s, %s>" % (self.handler_id, descriptor) - def load_middleware(self): - # type: () -> None + def load_middleware(self) -> None: """ Populate middleware lists from settings.MIDDLEWARE. This is copied from Django. This uses settings.MIDDLEWARE setting with the old @@ -149,8 +142,7 @@ class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler): # as a flag for initialization being complete. self._middleware_chain = handler - def get(self, *args, **kwargs): - # type: (*Any, **Any) -> None + def get(self, *args: Any, **kwargs: Any) -> None: environ = WSGIContainer.environ(self.request) environ['PATH_INFO'] = urllib.parse.unquote(environ['PATH_INFO']) request = WSGIRequest(environ) @@ -177,27 +169,22 @@ class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler): self.write(response.content) self.finish() - def head(self, *args, **kwargs): - # type: (*Any, **Any) -> None + def head(self, *args: Any, **kwargs: Any) -> None: self.get(*args, **kwargs) - def post(self, *args, **kwargs): - # type: (*Any, **Any) -> None + def post(self, *args: Any, **kwargs: Any) -> None: self.get(*args, **kwargs) - def delete(self, *args, **kwargs): - # type: (*Any, **Any) -> None + def delete(self, *args: Any, **kwargs: Any) -> None: self.get(*args, **kwargs) - def on_connection_close(self): - # type: () -> None + def on_connection_close(self) -> None: client_descriptor = get_descriptor_by_handler_id(self.handler_id) if client_descriptor is not None: client_descriptor.disconnect_handler(client_closed=True) # Based on django.core.handlers.base: get_response - def get_response(self, request): - # type: (HttpRequest) -> HttpResponse + def get_response(self, request: HttpRequest) -> HttpResponse: "Returns an HttpResponse object for the given HttpRequest" try: try: @@ -321,8 +308,8 @@ class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler): return response ### Copied from get_response (above in this file) - def apply_response_middleware(self, request, response, resolver): - # type: (HttpRequest, HttpResponse, urlresolvers.RegexURLResolver) -> HttpResponse + def apply_response_middleware(self, request: HttpRequest, response: HttpResponse, + resolver: urlresolvers.RegexURLResolver) -> HttpResponse: try: # Apply response middleware, regardless of the response for middleware_method in self._response_middleware: @@ -335,8 +322,8 @@ class AsyncDjangoHandler(tornado.web.RequestHandler, base.BaseHandler): return response - def zulip_finish(self, response, request, apply_markdown): - # type: (Dict[str, Any], HttpRequest, bool) -> None + def zulip_finish(self, response: Dict[str, Any], request: HttpRequest, + apply_markdown: bool) -> None: # Make sure that Markdown rendering really happened, if requested. # This is a security issue because it's where we escape HTML. # c.f. ticket #64 diff --git a/zerver/tornado/ioloop_logging.py b/zerver/tornado/ioloop_logging.py index 0ce09c48d6..87c50350e0 100644 --- a/zerver/tornado/ioloop_logging.py +++ b/zerver/tornado/ioloop_logging.py @@ -16,8 +16,7 @@ class InstrumentedPollIOLoop(PollIOLoop): def initialize(self, **kwargs): # type: ignore # TODO investigate likely buggy monkey patching here super(InstrumentedPollIOLoop, self).initialize(impl=InstrumentedPoll(), **kwargs) -def instrument_tornado_ioloop(): - # type: () -> None +def instrument_tornado_ioloop() -> None: IOLoop.configure(InstrumentedPollIOLoop) # A hack to keep track of how much time we spend working, versus sleeping in @@ -29,8 +28,7 @@ def instrument_tornado_ioloop(): # runs that might instantiate the default event loop. class InstrumentedPoll(object): - def __init__(self): - # type: () -> None + def __init__(self) -> None: self._underlying = orig_poll_impl() self._times = [] # type: List[Tuple[float, float]] self._last_print = 0.0 @@ -38,13 +36,11 @@ class InstrumentedPoll(object): # Python won't let us subclass e.g. select.epoll, so instead # we proxy every method. __getattr__ handles anything we # don't define elsewhere. - def __getattr__(self, name): - # type: (str) -> Any + def __getattr__(self, name: str) -> Any: return getattr(self._underlying, name) # Call the underlying poll method, and report timing data. - def poll(self, timeout): - # type: (float) -> Any + def poll(self, timeout: float) -> Any: # Avoid accumulating a bunch of insignificant data points # from short timeouts. diff --git a/zerver/tornado/socket.py b/zerver/tornado/socket.py index 1eb0c9398c..ee6924a21d 100644 --- a/zerver/tornado/socket.py +++ b/zerver/tornado/socket.py @@ -9,8 +9,7 @@ try: from django.middleware.csrf import _compare_salted_tokens except ImportError: # This function was added in Django 1.10. - def _compare_salted_tokens(token1, token2): - # type: (str, str) -> bool + def _compare_salted_tokens(token1: str, token2: str) -> bool: return token1 == token2 import sockjs.tornado @@ -34,8 +33,7 @@ from zerver.tornado.exceptions import BadEventQueueIdError logger = logging.getLogger('zulip.socket') -def get_user_profile(session_id): - # type: (Optional[Text]) -> Optional[UserProfile] +def get_user_profile(session_id: Optional[Text]) -> Optional[UserProfile]: if session_id is None: return None @@ -50,14 +48,12 @@ def get_user_profile(session_id): except (UserProfile.DoesNotExist, KeyError): return None -connections = dict() # type: Dict[Union[int, str], SocketConnection] +connections = dict() # type: Dict[Union[int, str], 'SocketConnection'] -def get_connection(id): - # type: (Union[int, str]) -> Optional[SocketConnection] +def get_connection(id: Union[int, str]) -> Optional['SocketConnection']: return connections.get(id) -def register_connection(id, conn): - # type: (Union[int, str], SocketConnection) -> None +def register_connection(id: Union[int, str], conn: 'SocketConnection') -> None: # Kill any old connections if they exist if id in connections: connections[id].close() @@ -65,28 +61,24 @@ def register_connection(id, conn): conn.client_id = id connections[conn.client_id] = conn -def deregister_connection(conn): - # type: (SocketConnection) -> None +def deregister_connection(conn: 'SocketConnection') -> None: assert conn.client_id is not None del connections[conn.client_id] redis_client = get_redis_client() -def req_redis_key(req_id): - # type: (Text) -> Text +def req_redis_key(req_id: Text) -> Text: return u'socket_req_status:%s' % (req_id,) class CloseErrorInfo(object): - def __init__(self, status_code, err_msg): - # type: (int, str) -> None + def __init__(self, status_code: int, err_msg: str) -> None: self.status_code = status_code self.err_msg = err_msg class SocketConnection(sockjs.tornado.SockJSConnection): client_id = None # type: Optional[Union[int, str]] - def on_open(self, info): - # type: (ConnectionInfo) -> None + def on_open(self, info: ConnectionInfo) -> None: log_data = dict(extra='[transport=%s]' % (self.session.transport_name,)) record_request_start_data(log_data) @@ -108,8 +100,7 @@ class SocketConnection(sockjs.tornado.SockJSConnection): ioloop.add_callback(self.close) return - def auth_timeout(): - # type: () -> None + def auth_timeout() -> None: self.close_info = CloseErrorInfo(408, "Timeout while waiting for authentication") self.close() @@ -117,8 +108,7 @@ class SocketConnection(sockjs.tornado.SockJSConnection): write_log_line(log_data, path='/socket/open', method='SOCKET', remote_ip=info.ip, email='unknown', client_name='?') - def authenticate_client(self, msg): - # type: (Dict[str, Any]) -> None + def authenticate_client(self, msg: Dict[str, Any]) -> None: if self.authenticated: self.session.send_message({'req_id': msg['req_id'], 'type': 'response', 'response': {'result': 'error', @@ -173,8 +163,7 @@ class SocketConnection(sockjs.tornado.SockJSConnection): ioloop = tornado.ioloop.IOLoop.instance() ioloop.remove_timeout(self.timeout_handle) - def on_message(self, msg_raw): - # type: (str) -> None + def on_message(self, msg_raw: str) -> None: log_data = dict(extra='[transport=%s' % (self.session.transport_name,)) record_request_start_data(log_data) msg = ujson.loads(msg_raw) @@ -233,8 +222,7 @@ class SocketConnection(sockjs.tornado.SockJSConnection): request_environ=dict(REMOTE_ADDR=self.session.conn_info.ip))), fake_message_sender) - def on_close(self): - # type: () -> None + def on_close(self) -> None: log_data = dict(extra='[transport=%s]' % (self.session.transport_name,)) record_request_start_data(log_data) if self.close_info is not None: @@ -252,8 +240,7 @@ class SocketConnection(sockjs.tornado.SockJSConnection): self.did_close = True -def fake_message_sender(event): - # type: (Dict[str, Any]) -> None +def fake_message_sender(event: Dict[str, Any]) -> None: """This function is used only for Casper and backend tests, where rabbitmq is disabled""" log_data = dict() # type: Dict[str, Any] @@ -280,8 +267,7 @@ def fake_message_sender(event): 'server_meta': server_meta} respond_send_message(result) -def respond_send_message(data): - # type: (Mapping[str, Any]) -> None +def respond_send_message(data: Mapping[str, Any]) -> None: log_data = data['server_meta']['log_data'] record_request_restart_data(log_data) @@ -314,6 +300,5 @@ sockjs_router = sockjs.tornado.SockJSRouter(SocketConnection, "/sockjs", {'sockjs_url': 'https://%s/static/third/sockjs/sockjs-0.3.4.js' % ( settings.EXTERNAL_HOST,), 'disabled_transports': ['eventsource', 'htmlfile']}) -def get_sockjs_router(): - # type: () -> sockjs.tornado.SockJSRouter +def get_sockjs_router() -> sockjs.tornado.SockJSRouter: return sockjs_router diff --git a/zerver/tornado/views.py b/zerver/tornado/views.py index 498f010b56..1a1ca36bfa 100644 --- a/zerver/tornado/views.py +++ b/zerver/tornado/views.py @@ -20,14 +20,13 @@ import time import ujson @internal_notify_view(True) -def notify(request): - # type: (HttpRequest) -> HttpResponse +def notify(request: HttpRequest) -> HttpResponse: process_notification(ujson.loads(request.POST['data'])) return json_success() @has_request_variables -def cleanup_event_queue(request, user_profile, queue_id=REQ()): - # type: (HttpRequest, UserProfile, Text) -> HttpResponse +def cleanup_event_queue(request: HttpRequest, user_profile: UserProfile, + queue_id: Text=REQ()) -> HttpResponse: client = get_client_descriptor(str(queue_id)) if client is None: raise BadEventQueueIdError(queue_id) diff --git a/zerver/tornado/websocket_client.py b/zerver/tornado/websocket_client.py index 23629b05c5..b45fb69248 100644 --- a/zerver/tornado/websocket_client.py +++ b/zerver/tornado/websocket_client.py @@ -40,8 +40,7 @@ class WebsocketClient(object): self.scheme_dict = {'http': 'ws', 'https': 'wss'} self.ws = None # type: Optional[WebSocketClientConnection] - def _login(self): - # type: () -> Dict[str,str] + def _login(self) -> Dict[str, str]: # Ideally, we'd migrate this to use API auth instead of # stealing cookies, but this works for now. @@ -57,14 +56,13 @@ class WebsocketClient(object): settings.SESSION_COOKIE_NAME: session.session_key, settings.CSRF_COOKIE_NAME: _get_new_csrf_token()} - def _get_cookie_header(self, cookies): - # type: (Dict[Any, Any]) -> str + def _get_cookie_header(self, cookies: Dict[Any, Any]) -> str: return ';'.join( ["{}={}".format(name, value) for name, value in cookies.items()]) @gen.coroutine - def _websocket_auth(self, queue_events_data, cookies): - # type: (Dict[str, Dict[str, str]], SimpleCookie) -> Generator[str, str, None] + def _websocket_auth(self, queue_events_data: Dict[str, Dict[str, str]], + cookies: SimpleCookie) -> Generator[str, str, None]: message = { "req_id": self._get_request_id(), "type": "auth", @@ -80,15 +78,13 @@ class WebsocketClient(object): response_message = yield self.ws.read_message() raise gen.Return([response_ack, response_message]) - def _get_queue_events(self, cookies_header): - # type: (str) -> Dict[str, str] + def _get_queue_events(self, cookies_header: str) -> Dict[str, str]: url = urljoin(self.parsed_host_url.geturl(), '/json/events?dont_block=true') response = requests.get(url, headers={'Cookie': cookies_header}, verify=self.validate_ssl) return response.json() @gen.engine - def connect(self): - # type: () -> Generator[str, WebSocketClientConnection, None] + def connect(self) -> Generator[str, WebSocketClientConnection, None]: try: request = HTTPRequest(url=self._get_websocket_url(), validate_cert=self.validate_ssl) request.headers.add('Cookie', self.cookie_str) @@ -102,8 +98,9 @@ class WebsocketClient(object): IOLoop.instance().stop() @gen.coroutine - def send_message(self, client, type, subject, stream, private_message_recepient, content=""): - # type: (str, str, str, str, str, str) -> Generator[str, WebSocketClientConnection, None] + def send_message(self, client: str, type: str, subject: str, stream: str, + private_message_recepient: str, + content: str="") -> Generator[str, WebSocketClientConnection, None]: user_message = { "req_id": self._get_request_id(), "type": "request", @@ -126,17 +123,14 @@ class WebsocketClient(object): response_message = yield self.ws.read_message() raise gen.Return([response_ack, response_message]) - def run(self): - # type: () -> None + def run(self) -> None: self.ioloop_instance.add_callback(self.connect) self.ioloop_instance.start() - def _get_websocket_url(self): - # type: () -> str + def _get_websocket_url(self) -> str: return '{}://{}{}'.format(self.scheme_dict[self.parsed_host_url.scheme], self.parsed_host_url.netloc, self.sockjs_url) - def _get_request_id(self): - # type: () -> Iterable[str] + def _get_request_id(self) -> Iterable[str]: self.request_id_number += 1 return ':'.join((self.events_data['queue_id'], str(self.request_id_number)))