From bd6a2b149c812b687dab52cf923391be961cbc06 Mon Sep 17 00:00:00 2001 From: Anders Kaseorg Date: Mon, 2 Aug 2021 17:30:24 -0700 Subject: [PATCH] queue: Split common part of SimpleQueueClient into new base class. Signed-off-by: Anders Kaseorg --- zerver/lib/queue.py | 61 ++++++++++++++++++++++++-------------- zerver/tests/test_queue.py | 2 +- 2 files changed, 39 insertions(+), 24 deletions(-) diff --git a/zerver/lib/queue.py b/zerver/lib/queue.py index 8243efe884..0fbca5dbb2 100644 --- a/zerver/lib/queue.py +++ b/zerver/lib/queue.py @@ -2,6 +2,7 @@ import logging import random import threading import time +from abc import ABCMeta, abstractmethod from collections import defaultdict from typing import Any, Callable, Dict, List, Mapping, Optional, Set @@ -22,7 +23,7 @@ Consumer = Callable[[BlockingChannel, Basic.Deliver, pika.BasicProperties, bytes # rabbitmq/pika's queuing system; its purpose is to just provide an # interface for external files to put things into queues and take them # out from bots without having to import pika code all over our codebase. -class SimpleQueueClient: +class QueueClient(metaclass=ABCMeta): def __init__( self, # Disable RabbitMQ heartbeats by default because BlockingConnection can't process them @@ -36,17 +37,13 @@ class SimpleQueueClient: self.is_consuming = False self._connect() + @abstractmethod def _connect(self) -> None: - start = time.time() - self.connection = pika.BlockingConnection(self._get_parameters()) - self.channel = self.connection.channel() - self.log.info(f"SimpleQueueClient connected (connecting took {time.time() - start:.3f}s)") + raise NotImplementedError + @abstractmethod def _reconnect(self) -> None: - self.connection = None - self.channel = None - self.queues = set() - self._connect() + raise NotImplementedError def _get_parameters(self) -> pika.ConnectionParameters: credentials = pika.PlainCredentials(settings.RABBITMQ_USERNAME, settings.RABBITMQ_PASSWORD) @@ -97,24 +94,12 @@ class SimpleQueueClient: for consumer in consumers: self._reconnect_consumer_callback(queue, consumer) - def close(self) -> None: - if self.connection: - self.connection.close() - def ready(self) -> bool: return self.channel is not None + @abstractmethod def ensure_queue(self, queue_name: str, callback: Callable[[BlockingChannel], None]) -> None: - """Ensure that a given queue has been declared, and then call - the callback with no arguments.""" - if self.connection is None or not self.connection.is_open: - self._connect() - - assert self.channel is not None - if queue_name not in self.queues: - self.channel.queue_declare(queue=queue_name, durable=True) - self.queues.add(queue_name) - callback(self.channel) + raise NotImplementedError def publish(self, queue_name: str, body: bytes) -> None: def do_publish(channel: BlockingChannel) -> None: @@ -140,6 +125,36 @@ class SimpleQueueClient: self._reconnect() self.publish(queue_name, data) + +class SimpleQueueClient(QueueClient): + def _connect(self) -> None: + start = time.time() + self.connection = pika.BlockingConnection(self._get_parameters()) + self.channel = self.connection.channel() + self.log.info(f"SimpleQueueClient connected (connecting took {time.time() - start:.3f}s)") + + def _reconnect(self) -> None: + self.connection = None + self.channel = None + self.queues = set() + self._connect() + + def close(self) -> None: + if self.connection is not None: + self.connection.close() + + def ensure_queue(self, queue_name: str, callback: Callable[[BlockingChannel], None]) -> None: + """Ensure that a given queue has been declared, and then call + the callback with no arguments.""" + if self.connection is None or not self.connection.is_open: + self._connect() + + assert self.channel is not None + if queue_name not in self.queues: + self.channel.queue_declare(queue=queue_name, durable=True) + self.queues.add(queue_name) + callback(self.channel) + def start_json_consumer( self, queue_name: str, diff --git a/zerver/tests/test_queue.py b/zerver/tests/test_queue.py index bf7ea4c5eb..f5c98d5c72 100644 --- a/zerver/tests/test_queue.py +++ b/zerver/tests/test_queue.py @@ -13,8 +13,8 @@ class TestTornadoQueueClient(ZulipTestCase): @mock.patch("zerver.lib.queue.ExceptionFreeTornadoConnection", autospec=True) def test_on_open_closed(self, mock_cxn: mock.MagicMock) -> None: with self.assertLogs("zulip.queue", "WARNING") as m: + mock_cxn().channel.side_effect = ConnectionClosed("500", "test") connection = TornadoQueueClient() - connection.connection.channel.side_effect = ConnectionClosed("500", "test") connection._on_open(mock.MagicMock()) self.assertEqual( m.output,