diff --git a/zerver/management/commands/process_queue.py b/zerver/management/commands/process_queue.py index 24a634b315..c059a19b6d 100644 --- a/zerver/management/commands/process_queue.py +++ b/zerver/management/commands/process_queue.py @@ -86,6 +86,7 @@ class Command(BaseCommand): signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGUSR1, signal_handler) + worker.ENABLE_TIMEOUTS = True worker.start() class Threaded_worker(threading.Thread): diff --git a/zerver/tests/test_queue_worker.py b/zerver/tests/test_queue_worker.py index 76d1a8d448..8259eb4710 100644 --- a/zerver/tests/test_queue_worker.py +++ b/zerver/tests/test_queue_worker.py @@ -26,6 +26,7 @@ from zerver.worker.queue_processors import ( LoopQueueProcessingWorker, MissedMessageWorker, QueueProcessingWorker, + WorkerTimeoutException, get_active_worker_queues, ) @@ -621,6 +622,47 @@ class WorkerTest(ZulipTestCase): self.assertEqual([event["type"] for event in events], ['good', 'fine', 'unexpected behaviour', 'back to normal']) + def test_timeouts(self) -> None: + processed = [] + + @queue_processors.assign_queue('timeout_worker') + class TimeoutWorker(queue_processors.QueueProcessingWorker): + MAX_CONSUME_SECONDS = 1 + + def consume(self, data: Mapping[str, Any]) -> None: + if data["type"] == 'timeout': + time.sleep(5) + processed.append(data["type"]) + + fake_client = self.FakeClient() + for msg in ['good', 'fine', 'timeout', 'back to normal']: + fake_client.queue.append(('timeout_worker', {'type': msg})) + + fn = os.path.join(settings.QUEUE_ERROR_DIR, 'timeout_worker.errors') + try: + os.remove(fn) + except OSError: # nocoverage # error handling for the directory not existing + pass + + with simulated_queue_client(lambda: fake_client): + worker = TimeoutWorker() + worker.setup() + worker.ENABLE_TIMEOUTS = True + with patch('logging.exception') as logging_exception_mock: + worker.start() + logging_exception_mock.assert_called_once_with( + "%s in queue %s", str(WorkerTimeoutException(1, 1)), "timeout_worker", + stack_info=True, + ) + + self.assertEqual(processed, ['good', 'fine', 'back to normal']) + with open(fn) as f: + line = f.readline().strip() + events = orjson.loads(line.split('\t')[1]) + self.assert_length(events, 1) + event = events[0] + self.assertEqual(event["type"], 'timeout') + def test_worker_noname(self) -> None: class TestWorker(queue_processors.QueueProcessingWorker): def __init__(self) -> None: diff --git a/zerver/worker/queue_processors.py b/zerver/worker/queue_processors.py index 009d170d44..27a6f8f5bc 100644 --- a/zerver/worker/queue_processors.py +++ b/zerver/worker/queue_processors.py @@ -4,6 +4,7 @@ import copy import datetime import email import email.policy +import functools import logging import os import signal @@ -17,6 +18,7 @@ from collections import defaultdict, deque from email.message import EmailMessage from functools import wraps from threading import Lock, Timer +from types import FrameType from typing import ( Any, Callable, @@ -104,6 +106,14 @@ from zerver.models import ( logger = logging.getLogger(__name__) +class WorkerTimeoutException(Exception): + def __init__(self, limit: int, event_count: int) -> None: + self.limit = limit + self.event_count = event_count + + def __str__(self) -> str: + return f"Timed out after {self.limit * self.event_count} seconds processing {self.event_count} events" + class WorkerDeclarationException(Exception): pass @@ -163,8 +173,13 @@ def retry_send_email_failures( return wrapper +def timer_expired(limit: int, event_count: int, signal: int, frame: FrameType) -> None: + raise WorkerTimeoutException(limit, event_count) + class QueueProcessingWorker(ABC): queue_name: str + MAX_CONSUME_SECONDS: Optional[int] = 10 + ENABLE_TIMEOUTS = False CONSUME_ITERATIONS_BEFORE_UPDATE_STATS_NUM = 50 MAX_SECONDS_BEFORE_UPDATE_STATS = 30 @@ -247,11 +262,17 @@ class QueueProcessingWorker(ABC): self.update_statistics(self.get_remaining_queue_size()) time_start = time.time() - consume_func(events) + if self.MAX_CONSUME_SECONDS and self.ENABLE_TIMEOUTS: + signal.signal(signal.SIGALRM, functools.partial(timer_expired, self.MAX_CONSUME_SECONDS, len(events))) + signal.alarm(self.MAX_CONSUME_SECONDS * len(events)) + consume_func(events) + signal.alarm(0) + else: + consume_func(events) consume_time_seconds = time.time() - time_start self.consumed_since_last_emptied += len(events) - except Exception: - self._handle_consume_exception(events) + except Exception as e: + self._handle_consume_exception(events, e) finally: flush_per_request_caches() reset_queries() @@ -281,13 +302,17 @@ class QueueProcessingWorker(ABC): consume_func = lambda events: self.consume(events[0]) self.do_consume(consume_func, [data]) - def _handle_consume_exception(self, events: List[Dict[str, Any]]) -> None: + def _handle_consume_exception(self, events: List[Dict[str, Any]], exception: Exception) -> None: with configure_scope() as scope: scope.set_context("events", { "data": events, "queue_name": self.queue_name, }) - logging.exception("Problem handling data on queue %s", self.queue_name, stack_info=True) + if isinstance(exception, WorkerTimeoutException): + logging.exception("%s in queue %s", + str(exception), self.queue_name, stack_info=True) + else: + logging.exception("Problem handling data on queue %s", self.queue_name, stack_info=True) if not os.path.exists(settings.QUEUE_ERROR_DIR): os.mkdir(settings.QUEUE_ERROR_DIR) # nocoverage # Use 'mark_sanitized' to prevent Pysa from detecting this false positive @@ -738,6 +763,10 @@ class DeferredWorker(QueueProcessingWorker): can provide a low-latency HTTP response or avoid risk of request timeouts for an operation that could in rare cases take minutes). """ + # Because these operations have no SLO, and can take minutes, + # remove any processing timeouts + MAX_CONSUME_SECONDS = None + def consume(self, event: Dict[str, Any]) -> None: if event['type'] == 'mark_stream_messages_as_read': user_profile = get_user_profile_by_id(event['user_profile_id']) diff --git a/zilencer/management/commands/queue_rate.py b/zilencer/management/commands/queue_rate.py index 344ccc0c74..2641659b4b 100644 --- a/zilencer/management/commands/queue_rate.py +++ b/zilencer/management/commands/queue_rate.py @@ -46,6 +46,7 @@ class Command(BaseCommand): worker: QueueProcessingWorker = NoopWorker(count, options["slow"]) if options["batch"]: worker = BatchNoopWorker(count, options["slow"]) + worker.ENABLE_TIMEOUTS = True worker.setup() assert worker.q is not None assert worker.q.channel is not None