2016-12-29 12:02:44 +01:00
|
|
|
import logging
|
|
|
|
import requests
|
|
|
|
import ujson
|
|
|
|
|
|
|
|
from django.conf import settings
|
|
|
|
from django.contrib.auth import SESSION_KEY, BACKEND_SESSION_KEY, HASH_SESSION_KEY
|
|
|
|
from django.middleware.csrf import _get_new_csrf_token
|
|
|
|
from importlib import import_module
|
|
|
|
from tornado.ioloop import IOLoop
|
|
|
|
from tornado import gen
|
|
|
|
from tornado.httpclient import HTTPRequest
|
|
|
|
from tornado.websocket import websocket_connect, WebSocketClientConnection
|
2017-11-06 03:14:57 +01:00
|
|
|
from urllib.parse import urlparse, urlunparse, urljoin
|
2017-11-06 03:07:49 +01:00
|
|
|
from http.cookies import SimpleCookie
|
2016-12-29 12:02:44 +01:00
|
|
|
|
2017-08-25 07:54:40 +02:00
|
|
|
from zerver.models import get_system_bot
|
2016-12-29 12:02:44 +01:00
|
|
|
|
|
|
|
from typing import Any, Callable, Dict, Generator, Iterable, Optional
|
|
|
|
|
|
|
|
|
2017-11-05 11:52:10 +01:00
|
|
|
class WebsocketClient:
    """Websocket client that authenticates as a system bot and sends messages.

    Construction performs blocking setup work: it looks up the bot's user
    profile, creates a Django session for it directly in the session store,
    and fetches an event queue over HTTP using the resulting cookies.
    Calling ``run()`` then starts the Tornado IOLoop, opens the websocket,
    authenticates, and invokes the ``run_on_start`` callback.
    """

    def __init__(self, host_url: str, sockjs_url: str, sender_email: str,
                 run_on_start: Callable[..., None], validate_ssl: bool=True,
                 **run_kwargs: Any) -> None:
        """Log in as ``sender_email`` and prepare the connection state.

        Args:
            host_url: Base URL of the server; its scheme must be ``http`` or
                ``https`` (see ``scheme_dict``).
            sockjs_url: Path appended to the host when building the websocket
                URL.
            sender_email: Email passed to ``get_system_bot`` to find the
                sending user.
            run_on_start: Called as ``run_on_start(self, **run_kwargs)`` once
                the websocket is connected and authenticated.
            validate_ssl: Whether to verify TLS certificates for both the
                HTTP request and the websocket connection.
            **run_kwargs: Extra keyword arguments forwarded to
                ``run_on_start``.
        """
        # NOTE: Callable should take a WebsocketClient & kwargs, but this is not standardised
        self.validate_ssl = validate_ssl
        self.auth_email = sender_email
        self.user_profile = get_system_bot(sender_email)
        self.request_id_number = 0
        self.parsed_host_url = urlparse(host_url)
        self.sockjs_url = sockjs_url
        # The three steps below depend on each other and on the attributes
        # set above, so their order matters.
        self.cookie_dict = self._login()
        self.cookie_str = self._get_cookie_header(self.cookie_dict)
        self.events_data = self._get_queue_events(self.cookie_str)
        self.ioloop_instance = IOLoop.instance()
        self.run_on_start = run_on_start
        self.run_kwargs = run_kwargs
        # Maps the host URL scheme to the matching websocket scheme.
        self.scheme_dict = {'http': 'ws', 'https': 'wss'}
        # Set by connect(); None until the websocket has been opened.
        self.ws = None  # type: Optional[WebSocketClientConnection]

    def _login(self) -> Dict[str, str]:
        """Create a saved session for the bot and return its cookies.

        Returns:
            A dict mapping the session cookie name to the new session key and
            the CSRF cookie name to a freshly generated CSRF token.
        """
        # Ideally, we'd migrate this to use API auth instead of
        # stealing cookies, but this works for now.
        auth_backend = settings.AUTHENTICATION_BACKENDS[0]
        session_auth_hash = self.user_profile.get_session_auth_hash()
        engine = import_module(settings.SESSION_ENGINE)
        session = engine.SessionStore()  # type: ignore # import_module
        session[SESSION_KEY] = self.user_profile._meta.pk.value_to_string(self.user_profile)
        session[BACKEND_SESSION_KEY] = auth_backend
        session[HASH_SESSION_KEY] = session_auth_hash
        session.save()
        return {
            settings.SESSION_COOKIE_NAME: session.session_key,
            settings.CSRF_COOKIE_NAME: _get_new_csrf_token()}

    def _get_cookie_header(self, cookies: Dict[Any, Any]) -> str:
        """Render ``cookies`` as ``name=value`` pairs joined by ``;``."""
        return ';'.join(
            ["{}={}".format(name, value) for name, value in cookies.items()])

    @gen.coroutine
    def _websocket_auth(self, queue_events_data: Dict[str, Any],
                        cookies: Dict[str, str]) -> Generator[str, str, None]:
        """Send the auth frame and return the next two messages read.

        Args:
            queue_events_data: Event-queue registration data; only its
                ``'queue_id'`` entry is read here.
            cookies: Cookie name/value mapping; the CSRF token is looked up
                under ``settings.CSRF_COOKIE_NAME``.

        Returns (via ``gen.Return``):
            ``[response_ack, response_message]``, the two messages read from
            the websocket after the auth frame is written.
        """
        message = {
            "req_id": self._get_request_id(),
            "type": "auth",
            "request": {
                "csrf_token": cookies.get(settings.CSRF_COOKIE_NAME),
                "queue_id": queue_events_data['queue_id'],
                "status_inquiries": []
            }
        }
        auth_frame_str = ujson.dumps(message)
        # The frame is a JSON array holding the JSON-encoded message
        # (presumably the SockJS client framing -- confirm against server).
        self.ws.write_message(ujson.dumps([auth_frame_str]))
        response_ack = yield self.ws.read_message()
        response_message = yield self.ws.read_message()
        raise gen.Return([response_ack, response_message])

    def _get_queue_events(self, cookies_header: str) -> Dict[str, str]:
        """GET ``/json/events?dont_block=true`` and return the decoded JSON.

        The request is a blocking ``requests`` call authenticated with the
        given ``Cookie`` header value.
        """
        url = urljoin(self.parsed_host_url.geturl(), '/json/events?dont_block=true')
        response = requests.get(url, headers={'Cookie': cookies_header}, verify=self.validate_ssl)
        return response.json()

    @gen.engine
    def connect(self) -> Generator[str, WebSocketClientConnection, None]:
        """Open the websocket, authenticate, run the callback, stop the loop.

        Any exception raised along the way is logged rather than propagated.
        The IOLoop is stopped in both the success and the failure case.
        """
        try:
            request = HTTPRequest(url=self._get_websocket_url(), validate_cert=self.validate_ssl)
            request.headers.add('Cookie', self.cookie_str)
            self.ws = yield websocket_connect(request)
            # The first message is read and discarded
            # (presumably the server's open frame -- confirm).
            yield self.ws.read_message()
            yield self._websocket_auth(self.events_data, self.cookie_dict)
            self.run_on_start(self, **self.run_kwargs)
        except Exception as e:
            logging.exception(str(e))
        # Reached on both paths, so a single stop() call is sufficient.
        IOLoop.instance().stop()

    @gen.coroutine
    def send_message(self, client: str, type: str, subject: str, stream: str,
                     private_message_recepient: str,
                     content: str="") -> Generator[str, WebSocketClientConnection, None]:
        """Send a message request over the websocket and return two replies.

        Requires that ``connect()`` has already set ``self.ws``.

        Args:
            client: Value for the request's ``"client"`` field.
            type: Value for the request's ``"type"`` field.
            subject: Value for the request's ``"subject"`` field.
            stream: Value for the request's ``"stream"`` field.
            private_message_recepient: Recipient; sent both as
                ``"private_message_recipient"`` and, JSON-encoded inside a
                one-element list, as ``"to"``.
            content: Message body.

        Returns (via ``gen.Return``):
            ``[response_ack, response_message]``, the two messages read from
            the websocket after the request is written.
        """
        user_message = {
            "req_id": self._get_request_id(),
            "type": "request",
            "request": {
                "client": client,
                "type": type,
                "subject": subject,
                "stream": stream,
                "private_message_recipient": private_message_recepient,
                "content": content,
                "sender_id": self.user_profile.id,
                "queue_id": self.events_data['queue_id'],
                "to": ujson.dumps([private_message_recepient]),
                "reply_to": self.user_profile.email,
                "local_id": -1
            }
        }
        # Same framing as the auth request: a JSON array holding the
        # JSON-encoded message.
        self.ws.write_message(ujson.dumps([ujson.dumps(user_message)]))
        response_ack = yield self.ws.read_message()
        response_message = yield self.ws.read_message()
        raise gen.Return([response_ack, response_message])

    def run(self) -> None:
        """Schedule ``connect`` on the IOLoop and start the loop (blocking)."""
        self.ioloop_instance.add_callback(self.connect)
        self.ioloop_instance.start()

    def _get_websocket_url(self) -> str:
        """Build the ``ws://`` or ``wss://`` URL from the host and SockJS path.

        Raises:
            KeyError: If the host URL scheme is not in ``scheme_dict``.
        """
        return '{}://{}{}'.format(self.scheme_dict[self.parsed_host_url.scheme],
                                  self.parsed_host_url.netloc, self.sockjs_url)

    def _get_request_id(self) -> str:
        """Return the next request id, formatted as ``<queue_id>:<counter>``.

        Increments ``request_id_number`` on every call.
        """
        self.request_id_number += 1
        return ':'.join((self.events_data['queue_id'], str(self.request_id_number)))
|