From f8ec00b895cf3a528c3cd40ac673ece1c1b2fba8 Mon Sep 17 00:00:00 2001 From: Steve Howell Date: Thu, 10 Aug 2023 14:12:37 +0000 Subject: [PATCH] mypy: Improve type checks for user display recipients. --- zerver/lib/display_recipient.py | 17 +++++++++-------- zerver/lib/email_mirror.py | 1 - zerver/lib/email_notifications.py | 3 --- zerver/lib/message.py | 5 +---- zerver/models.py | 8 ++++---- zerver/tests/test_message_fetch.py | 6 +++--- 6 files changed, 17 insertions(+), 23 deletions(-) diff --git a/zerver/lib/display_recipient.py b/zerver/lib/display_recipient.py index ed8367b465..f72d11166d 100644 --- a/zerver/lib/display_recipient.py +++ b/zerver/lib/display_recipient.py @@ -34,16 +34,17 @@ def get_display_recipient_cache_key( @cache_with_key(get_display_recipient_cache_key, timeout=3600 * 24 * 7) def get_display_recipient_remote_cache( recipient_id: int, recipient_type: int, recipient_type_id: Optional[int] -) -> DisplayRecipientT: +) -> List[UserDisplayRecipient]: """ - returns: an appropriate object describing the recipient. For a - stream this will be the stream name as a string. For a huddle or - personal, it will be an array of dicts about each recipient. + This returns an appropriate object describing the recipient of a + direct message (whether individual or group). + + It will be an array of dicts for each recipient. + + Do not use this for streams. """ - if recipient_type == Recipient.STREAM: # nocoverage - assert recipient_type_id is not None - stream = Stream.objects.values("name").get(id=recipient_type_id) - return stream["name"] + + assert recipient_type != Recipient.STREAM # The main priority for ordering here is being deterministic. # Right now, we order by ID, which matches the ordering of user diff --git a/zerver/lib/email_mirror.py b/zerver/lib/email_mirror.py index 0165d1d5e0..71d26b8a1c 100644 --- a/zerver/lib/email_mirror.py +++ b/zerver/lib/email_mirror.py @@ -460,7 +460,6 @@ def process_missed_message(to: str, message: EmailMessage) -> None: internal_send_private_message(user_profile, recipient_user, body) elif recipient.type == Recipient.HUDDLE: display_recipient = get_display_recipient(recipient) - assert not isinstance(display_recipient, str) emails = [user_dict["email"] for user_dict in display_recipient] recipient_str = ", ".join(emails) internal_send_huddle_message(user_profile.realm, user_profile, emails, body) diff --git a/zerver/lib/email_notifications.py b/zerver/lib/email_notifications.py index 839e10f049..926f7412ce 100644 --- a/zerver/lib/email_notifications.py +++ b/zerver/lib/email_notifications.py @@ -269,7 +269,6 @@ def build_message_list( elif message.recipient.type == Recipient.HUDDLE: grouping = {"huddle": message.recipient_id} display_recipient = get_display_recipient(message.recipient) - assert not isinstance(display_recipient, str) narrow_link = huddle_narrow_url( user=user, display_recipient=display_recipient, @@ -475,8 +474,6 @@ def do_send_missedmessage_events_reply_in_zulip( senders = list({m["message"].sender for m in missed_messages}) if missed_messages[0]["message"].recipient.type == Recipient.HUDDLE: display_recipient = get_display_recipient(missed_messages[0]["message"].recipient) - # Make sure that this is a list of strings, not a string. - assert not isinstance(display_recipient, str) narrow_url = huddle_narrow_url( user=user_profile, display_recipient=display_recipient, diff --git a/zerver/lib/message.py b/zerver/lib/message.py index 3e90550de8..fb8e4aab1b 100644 --- a/zerver/lib/message.py +++ b/zerver/lib/message.py @@ -989,15 +989,12 @@ def render_markdown( def huddle_users(recipient_id: int) -> str: - display_recipient: DisplayRecipientT = get_display_recipient_by_id( + display_recipient: List[UserDisplayRecipient] = get_display_recipient_by_id( recipient_id, Recipient.HUDDLE, None, ) - # str is for streams. - assert not isinstance(display_recipient, str) - user_ids: List[int] = [obj["id"] for obj in display_recipient] user_ids = sorted(user_ids) return ",".join(str(uid) for uid in user_ids) diff --git a/zerver/models.py b/zerver/models.py index 14f9bee6a2..50d0ca7aeb 100644 --- a/zerver/models.py +++ b/zerver/models.py @@ -87,7 +87,6 @@ from zerver.lib.pysa import mark_sanitized from zerver.lib.timestamp import datetime_to_timestamp from zerver.lib.types import ( DefaultStreamDict, - DisplayRecipientT, ExtendedFieldElement, ExtendedValidator, FieldElement, @@ -99,6 +98,7 @@ from zerver.lib.types import ( RealmPlaygroundDict, RealmUserValidator, UnspecifiedValue, + UserDisplayRecipient, UserFieldElement, Validator, ) @@ -191,12 +191,12 @@ def query_for_ids( # could be replaced with smarter bulk-fetching logic that deduplicates # queries for the same recipient; this is just a convenient way to # write that code. -per_request_display_recipient_cache: Dict[int, DisplayRecipientT] = {} +per_request_display_recipient_cache: Dict[int, List[UserDisplayRecipient]] = {} def get_display_recipient_by_id( recipient_id: int, recipient_type: int, recipient_type_id: Optional[int] -) -> DisplayRecipientT: +) -> List[UserDisplayRecipient]: """ returns: an object describing the recipient (using a cache). If the type is a stream, the type_id must be an int; a string is returned. @@ -211,7 +211,7 @@ def get_display_recipient_by_id( return per_request_display_recipient_cache[recipient_id] -def get_display_recipient(recipient: "Recipient") -> DisplayRecipientT: +def get_display_recipient(recipient: "Recipient") -> List[UserDisplayRecipient]: return get_display_recipient_by_id( recipient.id, recipient.type, diff --git a/zerver/tests/test_message_fetch.py b/zerver/tests/test_message_fetch.py index a61b1ab117..5cfcad040e 100644 --- a/zerver/tests/test_message_fetch.py +++ b/zerver/tests/test_message_fetch.py @@ -43,7 +43,7 @@ from zerver.lib.streams import StreamDict, create_streams_if_needed, get_public_ from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_helpers import HostRequestMock, get_user_messages, queries_captured from zerver.lib.topic import MATCH_TOPIC, RESOLVED_TOPIC_PREFIX, TOPIC_NAME -from zerver.lib.types import DisplayRecipientT +from zerver.lib.types import UserDisplayRecipient from zerver.lib.upload.base import create_attachment from zerver.lib.url_encoding import near_message_url from zerver.lib.user_topics import set_topic_visibility_policy @@ -2024,11 +2024,11 @@ class GetOldMessagesTest(ZulipTestCase): """ me = self.example_user("hamlet") - def dr_emails(dr: DisplayRecipientT) -> str: + def dr_emails(dr: List[UserDisplayRecipient]) -> str: assert isinstance(dr, list) return ",".join(sorted({*(r["email"] for r in dr), me.email})) - def dr_ids(dr: DisplayRecipientT) -> List[int]: + def dr_ids(dr: List[UserDisplayRecipient]) -> List[int]: assert isinstance(dr, list) return sorted({*(r["id"] for r in dr), self.example_user("hamlet").id})