diff --git a/zerver/lib/socket.py b/zerver/lib/socket.py index 7edf8f9b41..c9fe049619 100644 --- a/zerver/lib/socket.py +++ b/zerver/lib/socket.py @@ -1,5 +1,5 @@ from __future__ import absolute_import -from typing import Any +from typing import Any, Union, Optional from django.conf import settings from django.utils.importlib import import_module @@ -7,6 +7,7 @@ from django.utils import timezone from django.contrib.sessions.models import Session as djSession import sockjs.tornado +from sockjs.tornado.session import ConnectionInfo import tornado.ioloop import ujson import logging @@ -27,6 +28,7 @@ logger = logging.getLogger('zulip.socket') djsession_engine = import_module(settings.SESSION_ENGINE) def get_user_profile(session_id): + # type: (str) -> Optional[UserProfile] if session_id is None: return None @@ -41,12 +43,14 @@ def get_user_profile(session_id): except (UserProfile.DoesNotExist, KeyError): return None -connections = dict() # type: Dict[int, SocketConnection] +connections = dict() # type: Dict[Union[int, str], SocketConnection] def get_connection(id): + # type: (Union[int, str]) -> SocketConnection return connections.get(id) def register_connection(id, conn): + # type: (Union[int, str], SocketConnection) -> None # Kill any old connections if they exist if id in connections: connections[id].close() @@ -55,26 +59,31 @@ def register_connection(id, conn): connections[conn.client_id] = conn def deregister_connection(conn): + # type: (SocketConnection) -> None del connections[conn.client_id] redis_client = get_redis_client() def req_redis_key(req_id): + # type: (str) -> str return 'socket_req_status:%s' % (req_id,) class SocketAuthError(Exception): def __init__(self, msg): + # type: (str) -> None self.msg = msg class CloseErrorInfo(object): def __init__(self, status_code, err_msg): + # type: (int, str) -> None self.status_code = status_code self.err_msg = err_msg class SocketConnection(sockjs.tornado.SockJSConnection): - client_id = None # type: str + client_id = None # type: Union[int, str] def on_open(self, info): + # type: (ConnectionInfo) -> None log_data = dict(extra='[transport=%s]' % (self.session.transport_name,)) record_request_start_data(log_data) @@ -97,6 +106,7 @@ class SocketConnection(sockjs.tornado.SockJSConnection): return def auth_timeout(): + # type: () -> None self.close_info = CloseErrorInfo(408, "Timeout while waiting for authentication") self.close() @@ -105,6 +115,7 @@ class SocketConnection(sockjs.tornado.SockJSConnection): remote_ip=info.ip, email='unknown', client_name='?') def authenticate_client(self, msg): + # type: (Dict[str, Any]) -> None if self.authenticated: self.session.send_message({'req_id': msg['req_id'], 'type': 'response', 'response': {'result': 'error', 'msg': 'Already authenticated'}}) @@ -151,10 +162,11 @@ class SocketConnection(sockjs.tornado.SockJSConnection): ioloop = tornado.ioloop.IOLoop.instance() ioloop.remove_timeout(self.timeout_handle) - def on_message(self, msg): + def on_message(self, msg_raw): + # type: (str) -> None log_data = dict(extra='[transport=%s' % (self.session.transport_name,)) record_request_start_data(log_data) - msg = ujson.loads(msg) + msg = ujson.loads(msg_raw) if self.did_close: logger.info("Received message on already closed socket! transport=%s user=%s client_id=%s" @@ -211,6 +223,7 @@ class SocketConnection(sockjs.tornado.SockJSConnection): fake_message_sender) def on_close(self): + # type: () -> None log_data = dict(extra='[transport=%s]' % (self.session.transport_name,)) record_request_start_data(log_data) if self.close_info is not None: @@ -229,6 +242,7 @@ class SocketConnection(sockjs.tornado.SockJSConnection): self.did_close = True def fake_message_sender(event): + # type: (Dict[str, Any]) -> None log_data = dict() # type: Dict[str, Any] record_request_start_data(log_data) @@ -254,6 +268,7 @@ def fake_message_sender(event): respond_send_message(result) def respond_send_message(data): + # type: (Dict[str, Any]) -> None log_data = data['server_meta']['log_data'] record_request_restart_data(log_data) @@ -286,4 +301,5 @@ sockjs_router = sockjs.tornado.SockJSRouter(SocketConnection, "/sockjs", {'sockjs_url': 'https://%s/static/third/sockjs/sockjs-0.3.4.js' % (settings.EXTERNAL_HOST,), 'disabled_transports': ['eventsource', 'htmlfile']}) def get_sockjs_router(): + # type: () -> sockjs.tornado.SockJSRouter return sockjs_router