mypy: Improve type checks for user display recipients.

This commit is contained in:
Steve Howell 2023-08-10 14:12:37 +00:00 committed by Tim Abbott
parent 1b7880fc21
commit f8ec00b895
6 changed files with 17 additions and 23 deletions

View File

@ -34,16 +34,17 @@ def get_display_recipient_cache_key(
@cache_with_key(get_display_recipient_cache_key, timeout=3600 * 24 * 7)
def get_display_recipient_remote_cache(
recipient_id: int, recipient_type: int, recipient_type_id: Optional[int]
) -> DisplayRecipientT:
) -> List[UserDisplayRecipient]:
"""
returns: an appropriate object describing the recipient. For a
stream this will be the stream name as a string. For a huddle or
personal, it will be an array of dicts about each recipient.
This returns an appropriate object describing the recipient of a
direct message (whether individual or group).
It will be an array of dicts for each recipient.
Do not use this for streams.
"""
if recipient_type == Recipient.STREAM: # nocoverage
assert recipient_type_id is not None
stream = Stream.objects.values("name").get(id=recipient_type_id)
return stream["name"]
assert recipient_type != Recipient.STREAM
# The main priority for ordering here is being deterministic.
# Right now, we order by ID, which matches the ordering of user

View File

@ -460,7 +460,6 @@ def process_missed_message(to: str, message: EmailMessage) -> None:
internal_send_private_message(user_profile, recipient_user, body)
elif recipient.type == Recipient.HUDDLE:
display_recipient = get_display_recipient(recipient)
assert not isinstance(display_recipient, str)
emails = [user_dict["email"] for user_dict in display_recipient]
recipient_str = ", ".join(emails)
internal_send_huddle_message(user_profile.realm, user_profile, emails, body)

View File

@ -269,7 +269,6 @@ def build_message_list(
elif message.recipient.type == Recipient.HUDDLE:
grouping = {"huddle": message.recipient_id}
display_recipient = get_display_recipient(message.recipient)
assert not isinstance(display_recipient, str)
narrow_link = huddle_narrow_url(
user=user,
display_recipient=display_recipient,
@ -475,8 +474,6 @@ def do_send_missedmessage_events_reply_in_zulip(
senders = list({m["message"].sender for m in missed_messages})
if missed_messages[0]["message"].recipient.type == Recipient.HUDDLE:
display_recipient = get_display_recipient(missed_messages[0]["message"].recipient)
# Make sure that this is a list of strings, not a string.
assert not isinstance(display_recipient, str)
narrow_url = huddle_narrow_url(
user=user_profile,
display_recipient=display_recipient,

View File

@ -989,15 +989,12 @@ def render_markdown(
def huddle_users(recipient_id: int) -> str:
display_recipient: DisplayRecipientT = get_display_recipient_by_id(
display_recipient: List[UserDisplayRecipient] = get_display_recipient_by_id(
recipient_id,
Recipient.HUDDLE,
None,
)
# str is for streams.
assert not isinstance(display_recipient, str)
user_ids: List[int] = [obj["id"] for obj in display_recipient]
user_ids = sorted(user_ids)
return ",".join(str(uid) for uid in user_ids)

View File

@ -87,7 +87,6 @@ from zerver.lib.pysa import mark_sanitized
from zerver.lib.timestamp import datetime_to_timestamp
from zerver.lib.types import (
DefaultStreamDict,
DisplayRecipientT,
ExtendedFieldElement,
ExtendedValidator,
FieldElement,
@ -99,6 +98,7 @@ from zerver.lib.types import (
RealmPlaygroundDict,
RealmUserValidator,
UnspecifiedValue,
UserDisplayRecipient,
UserFieldElement,
Validator,
)
@ -191,12 +191,12 @@ def query_for_ids(
# could be replaced with smarter bulk-fetching logic that deduplicates
# queries for the same recipient; this is just a convenient way to
# write that code.
per_request_display_recipient_cache: Dict[int, DisplayRecipientT] = {}
per_request_display_recipient_cache: Dict[int, List[UserDisplayRecipient]] = {}
def get_display_recipient_by_id(
recipient_id: int, recipient_type: int, recipient_type_id: Optional[int]
) -> DisplayRecipientT:
) -> List[UserDisplayRecipient]:
"""
returns: an object describing the recipient (using a cache).
If the type is a stream, the type_id must be an int; a string is returned.
@ -211,7 +211,7 @@ def get_display_recipient_by_id(
return per_request_display_recipient_cache[recipient_id]
def get_display_recipient(recipient: "Recipient") -> DisplayRecipientT:
def get_display_recipient(recipient: "Recipient") -> List[UserDisplayRecipient]:
return get_display_recipient_by_id(
recipient.id,
recipient.type,

View File

@ -43,7 +43,7 @@ from zerver.lib.streams import StreamDict, create_streams_if_needed, get_public_
from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.test_helpers import HostRequestMock, get_user_messages, queries_captured
from zerver.lib.topic import MATCH_TOPIC, RESOLVED_TOPIC_PREFIX, TOPIC_NAME
from zerver.lib.types import DisplayRecipientT
from zerver.lib.types import UserDisplayRecipient
from zerver.lib.upload.base import create_attachment
from zerver.lib.url_encoding import near_message_url
from zerver.lib.user_topics import set_topic_visibility_policy
@ -2024,11 +2024,11 @@ class GetOldMessagesTest(ZulipTestCase):
"""
me = self.example_user("hamlet")
def dr_emails(dr: DisplayRecipientT) -> str:
def dr_emails(dr: List[UserDisplayRecipient]) -> str:
assert isinstance(dr, list)
return ",".join(sorted({*(r["email"] for r in dr), me.email}))
def dr_ids(dr: DisplayRecipientT) -> List[int]:
def dr_ids(dr: List[UserDisplayRecipient]) -> List[int]:
assert isinstance(dr, list)
return sorted({*(r["id"] for r in dr), self.example_user("hamlet").id})