event_queue: Fix strict_optional errors.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2020-07-04 17:29:31 -07:00 committed by Tim Abbott
parent 0d7539dc50
commit e3fc74fd20
3 changed files with 12 additions and 13 deletions

View File

@ -40,8 +40,6 @@ strict_optional = True
# General exclusions to work on # General exclusions to work on
[mypy-zerver.tornado.event_queue]
strict_optional = False
[mypy-zerver.lib.outgoing_webhook] [mypy-zerver.lib.outgoing_webhook]
strict_optional = False strict_optional = False
[mypy-zerver.lib.markdown] # for __init__.py [mypy-zerver.lib.markdown] # for __init__.py

View File

@ -12,6 +12,5 @@ def set_descriptor_by_handler_id(handler_id: int,
client_descriptor: 'ClientDescriptor') -> None: client_descriptor: 'ClientDescriptor') -> None:
descriptors_by_handler_id[handler_id] = client_descriptor descriptors_by_handler_id[handler_id] = client_descriptor
def clear_descriptor_by_handler_id(handler_id: int, def clear_descriptor_by_handler_id(handler_id: int) -> None:
client_descriptor: 'ClientDescriptor') -> None:
del descriptors_by_handler_id[handler_id] del descriptors_by_handler_id[handler_id]

View File

@ -226,7 +226,7 @@ class ClientDescriptor:
def disconnect_handler(self, client_closed: bool=False) -> None: def disconnect_handler(self, client_closed: bool=False) -> None:
if self.current_handler_id: if self.current_handler_id:
clear_descriptor_by_handler_id(self.current_handler_id, None) clear_descriptor_by_handler_id(self.current_handler_id)
clear_handler_by_id(self.current_handler_id) clear_handler_by_id(self.current_handler_id)
if client_closed: if client_closed:
logging.info("Client disconnected for queue %s (%s via %s)", logging.info("Client disconnected for queue %s (%s via %s)",
@ -388,7 +388,10 @@ def add_client_gc_hook(hook: Callable[[int, ClientDescriptor, bool], None]) -> N
gc_hooks.append(hook) gc_hooks.append(hook)
def get_client_descriptor(queue_id: str) -> ClientDescriptor: def get_client_descriptor(queue_id: str) -> ClientDescriptor:
return clients.get(queue_id) try:
return clients[queue_id]
except KeyError:
raise BadEventQueueIdError(queue_id)
def get_client_descriptors_for_user(user_profile_id: int) -> List[ClientDescriptor]: def get_client_descriptors_for_user(user_profile_id: int) -> List[ClientDescriptor]:
return user_clients.get(user_profile_id, []) return user_clients.get(user_profile_id, [])
@ -535,9 +538,9 @@ def setup_event_queue(port: int) -> None:
send_restart_events(immediate=settings.DEVELOPMENT) send_restart_events(immediate=settings.DEVELOPMENT)
def fetch_events(query: Mapping[str, Any]) -> Dict[str, Any]: def fetch_events(query: Mapping[str, Any]) -> Dict[str, Any]:
queue_id: str = query["queue_id"] queue_id: Optional[str] = query["queue_id"]
dont_block: bool = query["dont_block"] dont_block: bool = query["dont_block"]
last_event_id: int = query["last_event_id"] last_event_id: Optional[int] = query["last_event_id"]
user_profile_id: int = query["user_profile_id"] user_profile_id: int = query["user_profile_id"]
new_queue_data: Optional[MutableMapping[str, Any]] = query.get("new_queue_data") new_queue_data: Optional[MutableMapping[str, Any]] = query.get("new_queue_data")
client_type_name: str = query["client_type_name"] client_type_name: str = query["client_type_name"]
@ -549,6 +552,7 @@ def fetch_events(query: Mapping[str, Any]) -> Dict[str, Any]:
extra_log_data = "" extra_log_data = ""
if queue_id is None: if queue_id is None:
if dont_block: if dont_block:
assert new_queue_data is not None
client = allocate_client_descriptor(new_queue_data) client = allocate_client_descriptor(new_queue_data)
queue_id = client.event_queue.id queue_id = client.event_queue.id
else: else:
@ -557,8 +561,6 @@ def fetch_events(query: Mapping[str, Any]) -> Dict[str, Any]:
if last_event_id is None: if last_event_id is None:
raise JsonableError(_("Missing 'last_event_id' argument")) raise JsonableError(_("Missing 'last_event_id' argument"))
client = get_client_descriptor(queue_id) client = get_client_descriptor(queue_id)
if client is None:
raise BadEventQueueIdError(queue_id)
if user_profile_id != client.user_profile_id: if user_profile_id != client.user_profile_id:
raise JsonableError(_("You are not authorized to get events from this queue")) raise JsonableError(_("You are not authorized to get events from this queue"))
if ( if (
@ -706,7 +708,7 @@ def missedmessage_hook(user_profile_id: int, client: ClientDescriptor, last_for_
continue continue
assert 'flags' in event assert 'flags' in event
flags = event.get('flags') flags = event['flags']
mentioned = 'mentioned' in flags and 'read' not in flags mentioned = 'mentioned' in flags and 'read' not in flags
private_message = event['message']['type'] == 'private' private_message = event['message']['type'] == 'private'
@ -799,7 +801,7 @@ def maybe_enqueue_notifications(user_profile_id: int, message_id: int, private_m
class ClientInfo(TypedDict): class ClientInfo(TypedDict):
client: ClientDescriptor client: ClientDescriptor
flags: Optional[Iterable[str]] flags: Iterable[str]
is_sender: bool is_sender: bool
def get_client_info_for_message_event(event_template: Mapping[str, Any], def get_client_info_for_message_event(event_template: Mapping[str, Any],
@ -1049,7 +1051,7 @@ def process_message_update_event(event_template: Mapping[str, Any],
def maybe_enqueue_notifications_for_message_update(user_profile_id: UserProfile, def maybe_enqueue_notifications_for_message_update(user_profile_id: UserProfile,
message_id: int, message_id: int,
stream_name: str, stream_name: Optional[str],
prior_mention_user_ids: Set[int], prior_mention_user_ids: Set[int],
mention_user_ids: Set[int], mention_user_ids: Set[int],
wildcard_mention_notify: bool, wildcard_mention_notify: bool,