diff --git a/zerver/lib/queue.py b/zerver/lib/queue.py index 6e0decf9ae..d9f97daaa3 100644 --- a/zerver/lib/queue.py +++ b/zerver/lib/queue.py @@ -33,6 +33,7 @@ class SimpleQueueClient: self.channel: Optional[BlockingChannel] = None self.consumers: Dict[str, Set[Consumer]] = defaultdict(set) self.rabbitmq_heartbeat = rabbitmq_heartbeat + self.is_consuming = False self._connect() def _connect(self) -> None: @@ -203,16 +204,59 @@ class SimpleQueueClient: with self.drain_queue(queue_name) as binary_messages: yield list(map(orjson.loads, binary_messages)) + def start_json_consumer(self, + queue_name: str, + callback: Callable[[List[Dict[str, Any]]], None], + batch_size: int=1, + timeout: Optional[int]=None) -> None: + if batch_size == 1: + timeout = None + + def do_consume(channel: BlockingChannel) -> None: + events: List[Dict[str, Any]] = [] + last_process = time.time() + max_processed: Optional[int] = None + self.is_consuming = True + + # This iterator technique will iteratively collect up to + # batch_size events from the RabbitMQ queue (if present) + # before calling the callback with the batch. If not + # enough events are present, it will sleep for at most + # timeout seconds before calling the callback with the + # batch of events it has. + for method, properties, body in channel.consume(queue_name, inactivity_timeout=timeout): + if body is not None: + events.append(orjson.loads(body)) + max_processed = method.delivery_tag + now = time.time() + if len(events) >= batch_size or (timeout and now >= last_process + timeout): + if events: + try: + callback(events) + channel.basic_ack(max_processed, multiple=True) + except Exception: + channel.basic_nack(max_processed, multiple=True) + raise + events = [] + last_process = now + if not self.is_consuming: + break + self.ensure_queue(queue_name, do_consume) + def local_queue_size(self) -> int: assert self.channel is not None return self.channel.get_waiting_message_count() + len(self.channel._pending_events) def start_consuming(self) -> None: assert self.channel is not None + assert not self.is_consuming + self.is_consuming = True self.channel.start_consuming() def stop_consuming(self) -> None: assert self.channel is not None + assert self.is_consuming + self.is_consuming = False self.channel.stop_consuming() # Patch pika.adapters.tornado_connection.TornadoConnection so that a socket error doesn't diff --git a/zerver/tests/test_queue_worker.py b/zerver/tests/test_queue_worker.py index 1ba6b3a11a..7024c49380 100644 --- a/zerver/tests/test_queue_worker.py +++ b/zerver/tests/test_queue_worker.py @@ -33,17 +33,6 @@ from zerver.worker.queue_processors import ( Event = Dict[str, Any] -# This is used for testing LoopQueueProcessingWorker, which -# would run forever if we don't mock time.sleep to abort the -# loop. -class AbortLoop(Exception): - pass - -loopworker_sleep_mock = patch( - 'zerver.worker.queue_processors.time.sleep', - side_effect=AbortLoop, -) - class WorkerTest(ZulipTestCase): class FakeClient: def __init__(self) -> None: @@ -71,6 +60,19 @@ class WorkerTest(ZulipTestCase): self.queues[queue_name] = [] yield events + def start_json_consumer(self, + queue_name: str, + callback: Callable[[List[Dict[str, Any]]], None], + batch_size: int=1, + timeout: Optional[int]=None) -> None: + chunk: List[Dict[str, Any]] = [] + queue = self.queues[queue_name] + while queue: + chunk.append(queue.pop(0)) + if len(chunk) >= batch_size or not len(queue): + callback(chunk) + chunk = [] + def local_queue_size(self) -> int: return sum([len(q) for q in self.queues.values()]) @@ -103,39 +105,31 @@ class WorkerTest(ZulipTestCase): ) fake_client.enqueue('user_activity', data_old_format) - with loopworker_sleep_mock: - with simulated_queue_client(lambda: fake_client): - worker = queue_processors.UserActivityWorker() - worker.setup() - try: - worker.start() - except AbortLoop: - pass - activity_records = UserActivity.objects.filter( - user_profile = user.id, - client = get_client('ios'), - ) - self.assertEqual(len(activity_records), 1) - self.assertEqual(activity_records[0].count, 2) + with simulated_queue_client(lambda: fake_client): + worker = queue_processors.UserActivityWorker() + worker.setup() + worker.start() + activity_records = UserActivity.objects.filter( + user_profile = user.id, + client = get_client('ios'), + ) + self.assertEqual(len(activity_records), 1) + self.assertEqual(activity_records[0].count, 2) # Now process the event a second time and confirm count goes # up. Ideally, we'd use an event with a slightly newer # time, but it's not really important. fake_client.enqueue('user_activity', data) - with loopworker_sleep_mock: - with simulated_queue_client(lambda: fake_client): - worker = queue_processors.UserActivityWorker() - worker.setup() - try: - worker.start() - except AbortLoop: - pass - activity_records = UserActivity.objects.filter( - user_profile = user.id, - client = get_client('ios'), - ) - self.assertEqual(len(activity_records), 1) - self.assertEqual(activity_records[0].count, 3) + with simulated_queue_client(lambda: fake_client): + worker = queue_processors.UserActivityWorker() + worker.setup() + worker.start() + activity_records = UserActivity.objects.filter( + user_profile = user.id, + client = get_client('ios'), + ) + self.assertEqual(len(activity_records), 1) + self.assertEqual(activity_records[0].count, 3) def test_missed_message_worker(self) -> None: cordelia = self.example_user('cordelia') @@ -596,14 +590,11 @@ class WorkerTest(ZulipTestCase): except OSError: # nocoverage # error handling for the directory not existing pass - with loopworker_sleep_mock, simulated_queue_client(lambda: fake_client): + with simulated_queue_client(lambda: fake_client): loopworker = UnreliableLoopWorker() loopworker.setup() with patch('logging.exception') as logging_exception_mock: - try: - loopworker.start() - except AbortLoop: - pass + loopworker.start() logging_exception_mock.assert_called_once_with( "Problem handling data on queue %s", "unreliable_loopworker", stack_info=True, diff --git a/zerver/worker/queue_processors.py b/zerver/worker/queue_processors.py index 63d8992fe2..ffa3c882a3 100644 --- a/zerver/worker/queue_processors.py +++ b/zerver/worker/queue_processors.py @@ -348,25 +348,18 @@ class QueueProcessingWorker(ABC): self.q.stop_consuming() class LoopQueueProcessingWorker(QueueProcessingWorker): - sleep_delay = 0 - sleep_only_if_empty = True - is_consuming = False + sleep_delay = 1 + batch_size = 100 def start(self) -> None: # nocoverage assert self.q is not None self.initialize_statistics() - self.is_consuming = True - while self.is_consuming: - with self.q.json_drain_queue(self.queue_name) as events: - self.do_consume(self.consume_batch, events) - # To avoid spinning the CPU, we go to sleep if there's - # nothing in the queue, or for certain queues with - # sleep_only_if_empty=False, unconditionally. - if not self.sleep_only_if_empty or len(events) == 0: - time.sleep(self.sleep_delay) - - def stop(self) -> None: - self.is_consuming = False + self.q.start_json_consumer( + self.queue_name, + lambda events: self.do_consume(self.consume_batch, events), + batch_size=self.batch_size, + timeout=self.sleep_delay, + ) @abstractmethod def consume_batch(self, events: List[Dict[str, Any]]) -> None: @@ -460,8 +453,6 @@ class UserActivityWorker(LoopQueueProcessingWorker): common events from doing an action multiple times. """ - sleep_delay = 10 - sleep_only_if_empty = True client_id_map: Dict[str, int] = {} def start(self) -> None: