queue: Fix channel type for TornadoQueueClient.

The BlockingChannel annotations in TornadoQueueClient were flat-out
wrong.  BlockingChannel and Channel have no common base classes.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2021-08-02 18:42:32 -07:00 committed by Tim Abbott
parent 5751479932
commit 87799177b5
1 changed files with 16 additions and 14 deletions

View File

@ -4,26 +4,28 @@ import threading
import time
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
from typing import Any, Callable, Dict, Generic, List, Mapping, Optional, Set, TypeVar, Union
import orjson
import pika
import pika.adapters.tornado_connection
from django.conf import settings
from pika.adapters.blocking_connection import BlockingChannel
from pika.channel import Channel
from pika.spec import Basic
from tornado import ioloop
from zerver.lib.utils import statsd
MAX_REQUEST_RETRIES = 3
Consumer = Callable[[BlockingChannel, Basic.Deliver, pika.BasicProperties, bytes], None]
ChannelT = TypeVar("ChannelT", Channel, BlockingChannel)
Consumer = Callable[[ChannelT, Basic.Deliver, pika.BasicProperties, bytes], None]
# This simple queuing library doesn't expose much of the power of
# rabbitmq/pika's queuing system; its purpose is to just provide an
# interface for external files to put things into queues and take them
# out from bots without having to import pika code all over our codebase.
class QueueClient(metaclass=ABCMeta):
class QueueClient(Generic[ChannelT], metaclass=ABCMeta):
def __init__(
self,
# Disable RabbitMQ heartbeats by default because BlockingConnection can't process them
@ -31,8 +33,8 @@ class QueueClient(metaclass=ABCMeta):
) -> None:
self.log = logging.getLogger("zulip.queue")
self.queues: Set[str] = set()
self.channel: Optional[BlockingChannel] = None
self.consumers: Dict[str, Set[Consumer]] = defaultdict(set)
self.channel: Optional[ChannelT] = None
self.consumers: Dict[str, Set[Consumer[ChannelT]]] = defaultdict(set)
self.rabbitmq_heartbeat = rabbitmq_heartbeat
self.is_consuming = False
self._connect()
@ -78,7 +80,7 @@ class QueueClient(metaclass=ABCMeta):
def _generate_ctag(self, queue_name: str) -> str:
return f"{queue_name}_{str(random.getrandbits(16))}"
def _reconnect_consumer_callback(self, queue: str, consumer: Consumer) -> None:
def _reconnect_consumer_callback(self, queue: str, consumer: Consumer[ChannelT]) -> None:
self.log.info(f"Queue reconnecting saved consumer {consumer} to queue {queue}")
self.ensure_queue(
queue,
@ -98,11 +100,11 @@ class QueueClient(metaclass=ABCMeta):
return self.channel is not None
@abstractmethod
def ensure_queue(self, queue_name: str, callback: Callable[[BlockingChannel], None]) -> None:
def ensure_queue(self, queue_name: str, callback: Callable[[ChannelT], None]) -> None:
raise NotImplementedError
def publish(self, queue_name: str, body: bytes) -> None:
def do_publish(channel: BlockingChannel) -> None:
def do_publish(channel: ChannelT) -> None:
channel.basic_publish(
exchange="",
routing_key=queue_name,
@ -126,7 +128,7 @@ class QueueClient(metaclass=ABCMeta):
self.publish(queue_name, data)
class SimpleQueueClient(QueueClient):
class SimpleQueueClient(QueueClient[BlockingChannel]):
def _connect(self) -> None:
start = time.time()
self.connection = pika.BlockingConnection(self._get_parameters())
@ -227,7 +229,7 @@ calling _adapter_disconnect, ignoring",
)
class TornadoQueueClient(QueueClient):
class TornadoQueueClient(QueueClient[Channel]):
connection: Optional[ExceptionFreeTornadoConnection]
# Based on:
@ -237,7 +239,7 @@ class TornadoQueueClient(QueueClient):
# TornadoConnection can process heartbeats, so enable them.
rabbitmq_heartbeat=None
)
self._on_open_cbs: List[Callable[[BlockingChannel], None]] = []
self._on_open_cbs: List[Callable[[Channel], None]] = []
self._connection_failure_count = 0
def _connect(self) -> None:
@ -305,7 +307,7 @@ class TornadoQueueClient(QueueClient):
# Let _on_connection_closed deal with trying again.
self.log.warning("TornadoQueueClient couldn't open channel: connection already closed")
def _on_channel_open(self, channel: BlockingChannel) -> None:
def _on_channel_open(self, channel: Channel) -> None:
self.channel = channel
for callback in self._on_open_cbs:
callback(channel)
@ -316,7 +318,7 @@ class TornadoQueueClient(QueueClient):
if self.connection is not None:
self.connection.close()
def ensure_queue(self, queue_name: str, callback: Callable[[BlockingChannel], None]) -> None:
def ensure_queue(self, queue_name: str, callback: Callable[[Channel], None]) -> None:
def finish(frame: Any) -> None:
assert self.channel is not None
self.queues.add(queue_name)
@ -343,7 +345,7 @@ class TornadoQueueClient(QueueClient):
timeout: Optional[int] = None,
) -> None:
def wrapped_consumer(
ch: BlockingChannel,
ch: Channel,
method: Basic.Deliver,
properties: pika.BasicProperties,
body: bytes,