queue: Fix channel type for TornadoQueueClient.

The BlockingChannel annotations in TornadoQueueClient were flat-out
wrong.  BlockingChannel and Channel have no common base classes.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2021-08-02 18:42:32 -07:00 committed by Tim Abbott
parent 5751479932
commit 87799177b5
1 changed files with 16 additions and 14 deletions

View File

@ -4,26 +4,28 @@ import threading
import time import time
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections import defaultdict from collections import defaultdict
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union from typing import Any, Callable, Dict, Generic, List, Mapping, Optional, Set, TypeVar, Union
import orjson import orjson
import pika import pika
import pika.adapters.tornado_connection import pika.adapters.tornado_connection
from django.conf import settings from django.conf import settings
from pika.adapters.blocking_connection import BlockingChannel from pika.adapters.blocking_connection import BlockingChannel
from pika.channel import Channel
from pika.spec import Basic from pika.spec import Basic
from tornado import ioloop from tornado import ioloop
from zerver.lib.utils import statsd from zerver.lib.utils import statsd
MAX_REQUEST_RETRIES = 3 MAX_REQUEST_RETRIES = 3
Consumer = Callable[[BlockingChannel, Basic.Deliver, pika.BasicProperties, bytes], None] ChannelT = TypeVar("ChannelT", Channel, BlockingChannel)
Consumer = Callable[[ChannelT, Basic.Deliver, pika.BasicProperties, bytes], None]
# This simple queuing library doesn't expose much of the power of # This simple queuing library doesn't expose much of the power of
# rabbitmq/pika's queuing system; its purpose is to just provide an # rabbitmq/pika's queuing system; its purpose is to just provide an
# interface for external files to put things into queues and take them # interface for external files to put things into queues and take them
# out from bots without having to import pika code all over our codebase. # out from bots without having to import pika code all over our codebase.
class QueueClient(metaclass=ABCMeta): class QueueClient(Generic[ChannelT], metaclass=ABCMeta):
def __init__( def __init__(
self, self,
# Disable RabbitMQ heartbeats by default because BlockingConnection can't process them # Disable RabbitMQ heartbeats by default because BlockingConnection can't process them
@ -31,8 +33,8 @@ class QueueClient(metaclass=ABCMeta):
) -> None: ) -> None:
self.log = logging.getLogger("zulip.queue") self.log = logging.getLogger("zulip.queue")
self.queues: Set[str] = set() self.queues: Set[str] = set()
self.channel: Optional[BlockingChannel] = None self.channel: Optional[ChannelT] = None
self.consumers: Dict[str, Set[Consumer]] = defaultdict(set) self.consumers: Dict[str, Set[Consumer[ChannelT]]] = defaultdict(set)
self.rabbitmq_heartbeat = rabbitmq_heartbeat self.rabbitmq_heartbeat = rabbitmq_heartbeat
self.is_consuming = False self.is_consuming = False
self._connect() self._connect()
@ -78,7 +80,7 @@ class QueueClient(metaclass=ABCMeta):
def _generate_ctag(self, queue_name: str) -> str: def _generate_ctag(self, queue_name: str) -> str:
return f"{queue_name}_{str(random.getrandbits(16))}" return f"{queue_name}_{str(random.getrandbits(16))}"
def _reconnect_consumer_callback(self, queue: str, consumer: Consumer) -> None: def _reconnect_consumer_callback(self, queue: str, consumer: Consumer[ChannelT]) -> None:
self.log.info(f"Queue reconnecting saved consumer {consumer} to queue {queue}") self.log.info(f"Queue reconnecting saved consumer {consumer} to queue {queue}")
self.ensure_queue( self.ensure_queue(
queue, queue,
@ -98,11 +100,11 @@ class QueueClient(metaclass=ABCMeta):
return self.channel is not None return self.channel is not None
@abstractmethod @abstractmethod
def ensure_queue(self, queue_name: str, callback: Callable[[BlockingChannel], None]) -> None: def ensure_queue(self, queue_name: str, callback: Callable[[ChannelT], None]) -> None:
raise NotImplementedError raise NotImplementedError
def publish(self, queue_name: str, body: bytes) -> None: def publish(self, queue_name: str, body: bytes) -> None:
def do_publish(channel: BlockingChannel) -> None: def do_publish(channel: ChannelT) -> None:
channel.basic_publish( channel.basic_publish(
exchange="", exchange="",
routing_key=queue_name, routing_key=queue_name,
@ -126,7 +128,7 @@ class QueueClient(metaclass=ABCMeta):
self.publish(queue_name, data) self.publish(queue_name, data)
class SimpleQueueClient(QueueClient): class SimpleQueueClient(QueueClient[BlockingChannel]):
def _connect(self) -> None: def _connect(self) -> None:
start = time.time() start = time.time()
self.connection = pika.BlockingConnection(self._get_parameters()) self.connection = pika.BlockingConnection(self._get_parameters())
@ -227,7 +229,7 @@ calling _adapter_disconnect, ignoring",
) )
class TornadoQueueClient(QueueClient): class TornadoQueueClient(QueueClient[Channel]):
connection: Optional[ExceptionFreeTornadoConnection] connection: Optional[ExceptionFreeTornadoConnection]
# Based on: # Based on:
@ -237,7 +239,7 @@ class TornadoQueueClient(QueueClient):
# TornadoConnection can process heartbeats, so enable them. # TornadoConnection can process heartbeats, so enable them.
rabbitmq_heartbeat=None rabbitmq_heartbeat=None
) )
self._on_open_cbs: List[Callable[[BlockingChannel], None]] = [] self._on_open_cbs: List[Callable[[Channel], None]] = []
self._connection_failure_count = 0 self._connection_failure_count = 0
def _connect(self) -> None: def _connect(self) -> None:
@ -305,7 +307,7 @@ class TornadoQueueClient(QueueClient):
# Let _on_connection_closed deal with trying again. # Let _on_connection_closed deal with trying again.
self.log.warning("TornadoQueueClient couldn't open channel: connection already closed") self.log.warning("TornadoQueueClient couldn't open channel: connection already closed")
def _on_channel_open(self, channel: BlockingChannel) -> None: def _on_channel_open(self, channel: Channel) -> None:
self.channel = channel self.channel = channel
for callback in self._on_open_cbs: for callback in self._on_open_cbs:
callback(channel) callback(channel)
@ -316,7 +318,7 @@ class TornadoQueueClient(QueueClient):
if self.connection is not None: if self.connection is not None:
self.connection.close() self.connection.close()
def ensure_queue(self, queue_name: str, callback: Callable[[BlockingChannel], None]) -> None: def ensure_queue(self, queue_name: str, callback: Callable[[Channel], None]) -> None:
def finish(frame: Any) -> None: def finish(frame: Any) -> None:
assert self.channel is not None assert self.channel is not None
self.queues.add(queue_name) self.queues.add(queue_name)
@ -343,7 +345,7 @@ class TornadoQueueClient(QueueClient):
timeout: Optional[int] = None, timeout: Optional[int] = None,
) -> None: ) -> None:
def wrapped_consumer( def wrapped_consumer(
ch: BlockingChannel, ch: Channel,
method: Basic.Deliver, method: Basic.Deliver,
properties: pika.BasicProperties, properties: pika.BasicProperties,
body: bytes, body: bytes,