mirror of https://github.com/zulip/zulip.git
queue: Fix channel type for TornadoQueueClient.
The BlockingChannel annotations in TornadoQueueClient were flat-out wrong. BlockingChannel and Channel have no common base classes. Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
parent
5751479932
commit
87799177b5
|
@ -4,26 +4,28 @@ import threading
|
|||
import time
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
|
||||
from typing import Any, Callable, Dict, Generic, List, Mapping, Optional, Set, TypeVar, Union
|
||||
|
||||
import orjson
|
||||
import pika
|
||||
import pika.adapters.tornado_connection
|
||||
from django.conf import settings
|
||||
from pika.adapters.blocking_connection import BlockingChannel
|
||||
from pika.channel import Channel
|
||||
from pika.spec import Basic
|
||||
from tornado import ioloop
|
||||
|
||||
from zerver.lib.utils import statsd
|
||||
|
||||
MAX_REQUEST_RETRIES = 3
|
||||
Consumer = Callable[[BlockingChannel, Basic.Deliver, pika.BasicProperties, bytes], None]
|
||||
ChannelT = TypeVar("ChannelT", Channel, BlockingChannel)
|
||||
Consumer = Callable[[ChannelT, Basic.Deliver, pika.BasicProperties, bytes], None]
|
||||
|
||||
# This simple queuing library doesn't expose much of the power of
|
||||
# rabbitmq/pika's queuing system; its purpose is to just provide an
|
||||
# interface for external files to put things into queues and take them
|
||||
# out from bots without having to import pika code all over our codebase.
|
||||
class QueueClient(metaclass=ABCMeta):
|
||||
class QueueClient(Generic[ChannelT], metaclass=ABCMeta):
|
||||
def __init__(
|
||||
self,
|
||||
# Disable RabbitMQ heartbeats by default because BlockingConnection can't process them
|
||||
|
@ -31,8 +33,8 @@ class QueueClient(metaclass=ABCMeta):
|
|||
) -> None:
|
||||
self.log = logging.getLogger("zulip.queue")
|
||||
self.queues: Set[str] = set()
|
||||
self.channel: Optional[BlockingChannel] = None
|
||||
self.consumers: Dict[str, Set[Consumer]] = defaultdict(set)
|
||||
self.channel: Optional[ChannelT] = None
|
||||
self.consumers: Dict[str, Set[Consumer[ChannelT]]] = defaultdict(set)
|
||||
self.rabbitmq_heartbeat = rabbitmq_heartbeat
|
||||
self.is_consuming = False
|
||||
self._connect()
|
||||
|
@ -78,7 +80,7 @@ class QueueClient(metaclass=ABCMeta):
|
|||
def _generate_ctag(self, queue_name: str) -> str:
|
||||
return f"{queue_name}_{str(random.getrandbits(16))}"
|
||||
|
||||
def _reconnect_consumer_callback(self, queue: str, consumer: Consumer) -> None:
|
||||
def _reconnect_consumer_callback(self, queue: str, consumer: Consumer[ChannelT]) -> None:
|
||||
self.log.info(f"Queue reconnecting saved consumer {consumer} to queue {queue}")
|
||||
self.ensure_queue(
|
||||
queue,
|
||||
|
@ -98,11 +100,11 @@ class QueueClient(metaclass=ABCMeta):
|
|||
return self.channel is not None
|
||||
|
||||
@abstractmethod
|
||||
def ensure_queue(self, queue_name: str, callback: Callable[[BlockingChannel], None]) -> None:
|
||||
def ensure_queue(self, queue_name: str, callback: Callable[[ChannelT], None]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def publish(self, queue_name: str, body: bytes) -> None:
|
||||
def do_publish(channel: BlockingChannel) -> None:
|
||||
def do_publish(channel: ChannelT) -> None:
|
||||
channel.basic_publish(
|
||||
exchange="",
|
||||
routing_key=queue_name,
|
||||
|
@ -126,7 +128,7 @@ class QueueClient(metaclass=ABCMeta):
|
|||
self.publish(queue_name, data)
|
||||
|
||||
|
||||
class SimpleQueueClient(QueueClient):
|
||||
class SimpleQueueClient(QueueClient[BlockingChannel]):
|
||||
def _connect(self) -> None:
|
||||
start = time.time()
|
||||
self.connection = pika.BlockingConnection(self._get_parameters())
|
||||
|
@ -227,7 +229,7 @@ calling _adapter_disconnect, ignoring",
|
|||
)
|
||||
|
||||
|
||||
class TornadoQueueClient(QueueClient):
|
||||
class TornadoQueueClient(QueueClient[Channel]):
|
||||
connection: Optional[ExceptionFreeTornadoConnection]
|
||||
|
||||
# Based on:
|
||||
|
@ -237,7 +239,7 @@ class TornadoQueueClient(QueueClient):
|
|||
# TornadoConnection can process heartbeats, so enable them.
|
||||
rabbitmq_heartbeat=None
|
||||
)
|
||||
self._on_open_cbs: List[Callable[[BlockingChannel], None]] = []
|
||||
self._on_open_cbs: List[Callable[[Channel], None]] = []
|
||||
self._connection_failure_count = 0
|
||||
|
||||
def _connect(self) -> None:
|
||||
|
@ -305,7 +307,7 @@ class TornadoQueueClient(QueueClient):
|
|||
# Let _on_connection_closed deal with trying again.
|
||||
self.log.warning("TornadoQueueClient couldn't open channel: connection already closed")
|
||||
|
||||
def _on_channel_open(self, channel: BlockingChannel) -> None:
|
||||
def _on_channel_open(self, channel: Channel) -> None:
|
||||
self.channel = channel
|
||||
for callback in self._on_open_cbs:
|
||||
callback(channel)
|
||||
|
@ -316,7 +318,7 @@ class TornadoQueueClient(QueueClient):
|
|||
if self.connection is not None:
|
||||
self.connection.close()
|
||||
|
||||
def ensure_queue(self, queue_name: str, callback: Callable[[BlockingChannel], None]) -> None:
|
||||
def ensure_queue(self, queue_name: str, callback: Callable[[Channel], None]) -> None:
|
||||
def finish(frame: Any) -> None:
|
||||
assert self.channel is not None
|
||||
self.queues.add(queue_name)
|
||||
|
@ -343,7 +345,7 @@ class TornadoQueueClient(QueueClient):
|
|||
timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
def wrapped_consumer(
|
||||
ch: BlockingChannel,
|
||||
ch: Channel,
|
||||
method: Basic.Deliver,
|
||||
properties: pika.BasicProperties,
|
||||
body: bytes,
|
||||
|
|
Loading…
Reference in New Issue