mirror of https://github.com/zulip/zulip.git
queue: Split common part of SimpleQueueClient into new base class.
Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
parent
6fe67f0143
commit
bd6a2b149c
|
@ -2,6 +2,7 @@ import logging
|
|||
import random
|
||||
import threading
|
||||
import time
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Set
|
||||
|
||||
|
@ -22,7 +23,7 @@ Consumer = Callable[[BlockingChannel, Basic.Deliver, pika.BasicProperties, bytes
|
|||
# rabbitmq/pika's queuing system; its purpose is to just provide an
|
||||
# interface for external files to put things into queues and take them
|
||||
# out from bots without having to import pika code all over our codebase.
|
||||
class SimpleQueueClient:
|
||||
class QueueClient(metaclass=ABCMeta):
|
||||
def __init__(
|
||||
self,
|
||||
# Disable RabbitMQ heartbeats by default because BlockingConnection can't process them
|
||||
|
@ -36,17 +37,13 @@ class SimpleQueueClient:
|
|||
self.is_consuming = False
|
||||
self._connect()
|
||||
|
||||
@abstractmethod
|
||||
def _connect(self) -> None:
|
||||
start = time.time()
|
||||
self.connection = pika.BlockingConnection(self._get_parameters())
|
||||
self.channel = self.connection.channel()
|
||||
self.log.info(f"SimpleQueueClient connected (connecting took {time.time() - start:.3f}s)")
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _reconnect(self) -> None:
|
||||
self.connection = None
|
||||
self.channel = None
|
||||
self.queues = set()
|
||||
self._connect()
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_parameters(self) -> pika.ConnectionParameters:
|
||||
credentials = pika.PlainCredentials(settings.RABBITMQ_USERNAME, settings.RABBITMQ_PASSWORD)
|
||||
|
@ -97,24 +94,12 @@ class SimpleQueueClient:
|
|||
for consumer in consumers:
|
||||
self._reconnect_consumer_callback(queue, consumer)
|
||||
|
||||
def close(self) -> None:
|
||||
if self.connection:
|
||||
self.connection.close()
|
||||
|
||||
def ready(self) -> bool:
|
||||
return self.channel is not None
|
||||
|
||||
@abstractmethod
|
||||
def ensure_queue(self, queue_name: str, callback: Callable[[BlockingChannel], None]) -> None:
|
||||
"""Ensure that a given queue has been declared, and then call
|
||||
the callback with no arguments."""
|
||||
if self.connection is None or not self.connection.is_open:
|
||||
self._connect()
|
||||
|
||||
assert self.channel is not None
|
||||
if queue_name not in self.queues:
|
||||
self.channel.queue_declare(queue=queue_name, durable=True)
|
||||
self.queues.add(queue_name)
|
||||
callback(self.channel)
|
||||
raise NotImplementedError
|
||||
|
||||
def publish(self, queue_name: str, body: bytes) -> None:
|
||||
def do_publish(channel: BlockingChannel) -> None:
|
||||
|
@ -140,6 +125,36 @@ class SimpleQueueClient:
|
|||
self._reconnect()
|
||||
self.publish(queue_name, data)
|
||||
|
||||
|
||||
class SimpleQueueClient(QueueClient):
|
||||
def _connect(self) -> None:
|
||||
start = time.time()
|
||||
self.connection = pika.BlockingConnection(self._get_parameters())
|
||||
self.channel = self.connection.channel()
|
||||
self.log.info(f"SimpleQueueClient connected (connecting took {time.time() - start:.3f}s)")
|
||||
|
||||
def _reconnect(self) -> None:
|
||||
self.connection = None
|
||||
self.channel = None
|
||||
self.queues = set()
|
||||
self._connect()
|
||||
|
||||
def close(self) -> None:
|
||||
if self.connection is not None:
|
||||
self.connection.close()
|
||||
|
||||
def ensure_queue(self, queue_name: str, callback: Callable[[BlockingChannel], None]) -> None:
|
||||
"""Ensure that a given queue has been declared, and then call
|
||||
the callback with no arguments."""
|
||||
if self.connection is None or not self.connection.is_open:
|
||||
self._connect()
|
||||
|
||||
assert self.channel is not None
|
||||
if queue_name not in self.queues:
|
||||
self.channel.queue_declare(queue=queue_name, durable=True)
|
||||
self.queues.add(queue_name)
|
||||
callback(self.channel)
|
||||
|
||||
def start_json_consumer(
|
||||
self,
|
||||
queue_name: str,
|
||||
|
|
|
@ -13,8 +13,8 @@ class TestTornadoQueueClient(ZulipTestCase):
|
|||
@mock.patch("zerver.lib.queue.ExceptionFreeTornadoConnection", autospec=True)
|
||||
def test_on_open_closed(self, mock_cxn: mock.MagicMock) -> None:
|
||||
with self.assertLogs("zulip.queue", "WARNING") as m:
|
||||
mock_cxn().channel.side_effect = ConnectionClosed("500", "test")
|
||||
connection = TornadoQueueClient()
|
||||
connection.connection.channel.side_effect = ConnectionClosed("500", "test")
|
||||
connection._on_open(mock.MagicMock())
|
||||
self.assertEqual(
|
||||
m.output,
|
||||
|
|
Loading…
Reference in New Issue