queue: Switch batch interface to use the channel.consume iterator.

This low-level interface allows consuming from a queue with timeouts.
It can be used to consume either in batches (with an upper timeout)
or one message at a time.  This is notably more performant than
calling `.get()` repeatedly (which is what json_drain_queue does under
the hood); fetching messages that way is "*highly discouraged* as it
is *very inefficient*"[1].

Before this change:
```
$ ./manage.py queue_rate --count 10000 --batch
Purging queue...
Enqueue rate: 11158 / sec
Dequeue rate: 3075 / sec
```

After:
```
$ ./manage.py queue_rate --count 10000 --batch
Purging queue...
Enqueue rate: 11511 / sec
Dequeue rate: 19938 / sec
```

[1] https://www.rabbitmq.com/consumers.html#fetching
This commit is contained in:
Alex Vandiver 2020-10-09 13:50:53 -07:00 committed by Tim Abbott
parent 571f8b8664
commit f9358d5330
3 changed files with 87 additions and 61 deletions

View File

@ -33,6 +33,7 @@ class SimpleQueueClient:
self.channel: Optional[BlockingChannel] = None self.channel: Optional[BlockingChannel] = None
self.consumers: Dict[str, Set[Consumer]] = defaultdict(set) self.consumers: Dict[str, Set[Consumer]] = defaultdict(set)
self.rabbitmq_heartbeat = rabbitmq_heartbeat self.rabbitmq_heartbeat = rabbitmq_heartbeat
self.is_consuming = False
self._connect() self._connect()
def _connect(self) -> None: def _connect(self) -> None:
@ -203,16 +204,59 @@ class SimpleQueueClient:
with self.drain_queue(queue_name) as binary_messages: with self.drain_queue(queue_name) as binary_messages:
yield list(map(orjson.loads, binary_messages)) yield list(map(orjson.loads, binary_messages))
def start_json_consumer(self,
                        queue_name: str,
                        callback: Callable[[List[Dict[str, Any]]], None],
                        batch_size: int=1,
                        timeout: Optional[int]=None) -> None:
    """Consume JSON events from queue_name and hand them to callback in batches.

    Messages are read with the channel.consume() iterator, decoded with
    orjson, and accumulated until either batch_size events are pending
    or at least timeout seconds have passed since the last flush.  The
    batch is then passed to callback and, on success, everything up to
    the most recent delivery tag is acknowledged in a single
    basic_ack(multiple=True).

    Args:
        queue_name: Name of the queue to consume from.
        callback: Called with a non-empty list of decoded events.
        batch_size: Maximum number of events handed to callback at once.
        timeout: Upper bound, in seconds, on how long a partial batch is
            held back.  Ignored (forced to None) when batch_size is 1.

    Raises:
        Exception: Anything raised by callback or by basic_ack is
            re-raised after the pending deliveries have been nacked
            (when the channel is still open).
    """
    if batch_size == 1:
        # A batch of one is flushed as soon as its message arrives, so
        # there is never a partial batch for a timeout to flush.
        timeout = None

    def do_consume(channel: BlockingChannel) -> None:
        events: List[Dict[str, Any]] = []
        last_process = time.time()
        max_processed: Optional[int] = None
        self.is_consuming = True

        # This iterator technique will iteratively collect up to
        # batch_size events from the RabbitMQ queue (if present)
        # before calling the callback with the batch.  If not
        # enough events are present, it will sleep for at most
        # timeout seconds before calling the callback with the
        # batch of events it has.
        for method, properties, body in channel.consume(queue_name, inactivity_timeout=timeout):
            # body is None when the iterator woke up without a message
            # (presumably the inactivity timeout firing -- see the
            # pika consume() documentation).
            if body is not None:
                events.append(orjson.loads(body))
                max_processed = method.delivery_tag
            now = time.time()
            if len(events) >= batch_size or (timeout and now >= last_process + timeout):
                if events:
                    # Every appended event also recorded its delivery
                    # tag, so a non-empty batch implies a tag is set.
                    assert max_processed is not None
                    try:
                        callback(events)
                        channel.basic_ack(max_processed, multiple=True)
                    except Exception:
                        # Only nack while the channel is still usable:
                        # if the failure was caused by the channel
                        # closing, basic_nack would raise its own
                        # error and hide the original exception.
                        if channel.is_open:
                            channel.basic_nack(max_processed, multiple=True)
                        raise
                    events = []
                # Restart the flush timer even when nothing was pending.
                last_process = now
            # stop_consuming() clears this flag to request loop exit.
            if not self.is_consuming:
                break

    # NOTE(review): ensure_queue is defined outside this view; presumably
    # it declares the queue and then invokes do_consume with the channel.
    self.ensure_queue(queue_name, do_consume)
def local_queue_size(self) -> int: def local_queue_size(self) -> int:
assert self.channel is not None assert self.channel is not None
return self.channel.get_waiting_message_count() + len(self.channel._pending_events) return self.channel.get_waiting_message_count() + len(self.channel._pending_events)
def start_consuming(self) -> None: def start_consuming(self) -> None:
assert self.channel is not None assert self.channel is not None
assert not self.is_consuming
self.is_consuming = True
self.channel.start_consuming() self.channel.start_consuming()
def stop_consuming(self) -> None: def stop_consuming(self) -> None:
assert self.channel is not None assert self.channel is not None
assert self.is_consuming
self.is_consuming = False
self.channel.stop_consuming() self.channel.stop_consuming()
# Patch pika.adapters.tornado_connection.TornadoConnection so that a socket error doesn't # Patch pika.adapters.tornado_connection.TornadoConnection so that a socket error doesn't

View File

@ -33,17 +33,6 @@ from zerver.worker.queue_processors import (
Event = Dict[str, Any] Event = Dict[str, Any]
# This is used for testing LoopQueueProcessingWorker, which
# would run forever if we don't mock time.sleep to abort the
# loop.
class AbortLoop(Exception):
pass
loopworker_sleep_mock = patch(
'zerver.worker.queue_processors.time.sleep',
side_effect=AbortLoop,
)
class WorkerTest(ZulipTestCase): class WorkerTest(ZulipTestCase):
class FakeClient: class FakeClient:
def __init__(self) -> None: def __init__(self) -> None:
@ -71,6 +60,19 @@ class WorkerTest(ZulipTestCase):
self.queues[queue_name] = [] self.queues[queue_name] = []
yield events yield events
def start_json_consumer(self,
                        queue_name: str,
                        callback: Callable[[List[Dict[str, Any]]], None],
                        batch_size: int=1,
                        timeout: Optional[int]=None) -> None:
    """Synchronously drain the fake queue, passing events to callback in batches.

    Events are removed from the front of self.queues[queue_name] and
    delivered in lists of at most batch_size; a shorter final list is
    delivered when the queue runs out.  timeout is accepted but not
    used.  Returns once the queue is empty.
    """
    pending = self.queues[queue_name]
    # A batch always carries at least one event, even if batch_size < 1.
    take = max(batch_size, 1)
    while pending:
        batch = pending[:take]
        # Remove the delivered events before invoking the callback so it
        # observes the queue without them.
        del pending[:take]
        callback(batch)
def local_queue_size(self) -> int: def local_queue_size(self) -> int:
return sum([len(q) for q in self.queues.values()]) return sum([len(q) for q in self.queues.values()])
@ -103,39 +105,31 @@ class WorkerTest(ZulipTestCase):
) )
fake_client.enqueue('user_activity', data_old_format) fake_client.enqueue('user_activity', data_old_format)
with loopworker_sleep_mock: with simulated_queue_client(lambda: fake_client):
with simulated_queue_client(lambda: fake_client): worker = queue_processors.UserActivityWorker()
worker = queue_processors.UserActivityWorker() worker.setup()
worker.setup() worker.start()
try: activity_records = UserActivity.objects.filter(
worker.start() user_profile = user.id,
except AbortLoop: client = get_client('ios'),
pass )
activity_records = UserActivity.objects.filter( self.assertEqual(len(activity_records), 1)
user_profile = user.id, self.assertEqual(activity_records[0].count, 2)
client = get_client('ios'),
)
self.assertEqual(len(activity_records), 1)
self.assertEqual(activity_records[0].count, 2)
# Now process the event a second time and confirm count goes # Now process the event a second time and confirm count goes
# up. Ideally, we'd use an event with a slightly newer # up. Ideally, we'd use an event with a slightly newer
# time, but it's not really important. # time, but it's not really important.
fake_client.enqueue('user_activity', data) fake_client.enqueue('user_activity', data)
with loopworker_sleep_mock: with simulated_queue_client(lambda: fake_client):
with simulated_queue_client(lambda: fake_client): worker = queue_processors.UserActivityWorker()
worker = queue_processors.UserActivityWorker() worker.setup()
worker.setup() worker.start()
try: activity_records = UserActivity.objects.filter(
worker.start() user_profile = user.id,
except AbortLoop: client = get_client('ios'),
pass )
activity_records = UserActivity.objects.filter( self.assertEqual(len(activity_records), 1)
user_profile = user.id, self.assertEqual(activity_records[0].count, 3)
client = get_client('ios'),
)
self.assertEqual(len(activity_records), 1)
self.assertEqual(activity_records[0].count, 3)
def test_missed_message_worker(self) -> None: def test_missed_message_worker(self) -> None:
cordelia = self.example_user('cordelia') cordelia = self.example_user('cordelia')
@ -596,14 +590,11 @@ class WorkerTest(ZulipTestCase):
except OSError: # nocoverage # error handling for the directory not existing except OSError: # nocoverage # error handling for the directory not existing
pass pass
with loopworker_sleep_mock, simulated_queue_client(lambda: fake_client): with simulated_queue_client(lambda: fake_client):
loopworker = UnreliableLoopWorker() loopworker = UnreliableLoopWorker()
loopworker.setup() loopworker.setup()
with patch('logging.exception') as logging_exception_mock: with patch('logging.exception') as logging_exception_mock:
try: loopworker.start()
loopworker.start()
except AbortLoop:
pass
logging_exception_mock.assert_called_once_with( logging_exception_mock.assert_called_once_with(
"Problem handling data on queue %s", "unreliable_loopworker", "Problem handling data on queue %s", "unreliable_loopworker",
stack_info=True, stack_info=True,

View File

@ -348,25 +348,18 @@ class QueueProcessingWorker(ABC):
self.q.stop_consuming() self.q.stop_consuming()
class LoopQueueProcessingWorker(QueueProcessingWorker): class LoopQueueProcessingWorker(QueueProcessingWorker):
sleep_delay = 0 sleep_delay = 1
sleep_only_if_empty = True batch_size = 100
is_consuming = False
def start(self) -> None: # nocoverage def start(self) -> None: # nocoverage
assert self.q is not None assert self.q is not None
self.initialize_statistics() self.initialize_statistics()
self.is_consuming = True self.q.start_json_consumer(
while self.is_consuming: self.queue_name,
with self.q.json_drain_queue(self.queue_name) as events: lambda events: self.do_consume(self.consume_batch, events),
self.do_consume(self.consume_batch, events) batch_size=self.batch_size,
# To avoid spinning the CPU, we go to sleep if there's timeout=self.sleep_delay,
# nothing in the queue, or for certain queues with )
# sleep_only_if_empty=False, unconditionally.
if not self.sleep_only_if_empty or len(events) == 0:
time.sleep(self.sleep_delay)
def stop(self) -> None:
self.is_consuming = False
@abstractmethod @abstractmethod
def consume_batch(self, events: List[Dict[str, Any]]) -> None: def consume_batch(self, events: List[Dict[str, Any]]) -> None:
@ -460,8 +453,6 @@ class UserActivityWorker(LoopQueueProcessingWorker):
common events from doing an action multiple times. common events from doing an action multiple times.
""" """
sleep_delay = 10
sleep_only_if_empty = True
client_id_map: Dict[str, int] = {} client_id_map: Dict[str, int] = {}
def start(self) -> None: def start(self) -> None: