mypy: Improve type checks for user display recipients.

This commit is contained in:
Steve Howell 2023-08-10 14:12:37 +00:00 committed by Tim Abbott
parent 1b7880fc21
commit f8ec00b895
6 changed files with 17 additions and 23 deletions

View File

@ -34,16 +34,17 @@ def get_display_recipient_cache_key(
@cache_with_key(get_display_recipient_cache_key, timeout=3600 * 24 * 7) @cache_with_key(get_display_recipient_cache_key, timeout=3600 * 24 * 7)
def get_display_recipient_remote_cache( def get_display_recipient_remote_cache(
recipient_id: int, recipient_type: int, recipient_type_id: Optional[int] recipient_id: int, recipient_type: int, recipient_type_id: Optional[int]
) -> DisplayRecipientT: ) -> List[UserDisplayRecipient]:
""" """
returns: an appropriate object describing the recipient. For a This returns an appropriate object describing the recipient of a
stream this will be the stream name as a string. For a huddle or direct message (whether individual or group).
personal, it will be an array of dicts about each recipient.
It will be an array of dicts for each recipient.
Do not use this for streams.
""" """
if recipient_type == Recipient.STREAM: # nocoverage
assert recipient_type_id is not None assert recipient_type != Recipient.STREAM
stream = Stream.objects.values("name").get(id=recipient_type_id)
return stream["name"]
# The main priority for ordering here is being deterministic. # The main priority for ordering here is being deterministic.
# Right now, we order by ID, which matches the ordering of user # Right now, we order by ID, which matches the ordering of user

View File

@ -460,7 +460,6 @@ def process_missed_message(to: str, message: EmailMessage) -> None:
internal_send_private_message(user_profile, recipient_user, body) internal_send_private_message(user_profile, recipient_user, body)
elif recipient.type == Recipient.HUDDLE: elif recipient.type == Recipient.HUDDLE:
display_recipient = get_display_recipient(recipient) display_recipient = get_display_recipient(recipient)
assert not isinstance(display_recipient, str)
emails = [user_dict["email"] for user_dict in display_recipient] emails = [user_dict["email"] for user_dict in display_recipient]
recipient_str = ", ".join(emails) recipient_str = ", ".join(emails)
internal_send_huddle_message(user_profile.realm, user_profile, emails, body) internal_send_huddle_message(user_profile.realm, user_profile, emails, body)

View File

@ -269,7 +269,6 @@ def build_message_list(
elif message.recipient.type == Recipient.HUDDLE: elif message.recipient.type == Recipient.HUDDLE:
grouping = {"huddle": message.recipient_id} grouping = {"huddle": message.recipient_id}
display_recipient = get_display_recipient(message.recipient) display_recipient = get_display_recipient(message.recipient)
assert not isinstance(display_recipient, str)
narrow_link = huddle_narrow_url( narrow_link = huddle_narrow_url(
user=user, user=user,
display_recipient=display_recipient, display_recipient=display_recipient,
@ -475,8 +474,6 @@ def do_send_missedmessage_events_reply_in_zulip(
senders = list({m["message"].sender for m in missed_messages}) senders = list({m["message"].sender for m in missed_messages})
if missed_messages[0]["message"].recipient.type == Recipient.HUDDLE: if missed_messages[0]["message"].recipient.type == Recipient.HUDDLE:
display_recipient = get_display_recipient(missed_messages[0]["message"].recipient) display_recipient = get_display_recipient(missed_messages[0]["message"].recipient)
# Make sure that this is a list of strings, not a string.
assert not isinstance(display_recipient, str)
narrow_url = huddle_narrow_url( narrow_url = huddle_narrow_url(
user=user_profile, user=user_profile,
display_recipient=display_recipient, display_recipient=display_recipient,

View File

@ -989,15 +989,12 @@ def render_markdown(
def huddle_users(recipient_id: int) -> str: def huddle_users(recipient_id: int) -> str:
display_recipient: DisplayRecipientT = get_display_recipient_by_id( display_recipient: List[UserDisplayRecipient] = get_display_recipient_by_id(
recipient_id, recipient_id,
Recipient.HUDDLE, Recipient.HUDDLE,
None, None,
) )
# str is for streams.
assert not isinstance(display_recipient, str)
user_ids: List[int] = [obj["id"] for obj in display_recipient] user_ids: List[int] = [obj["id"] for obj in display_recipient]
user_ids = sorted(user_ids) user_ids = sorted(user_ids)
return ",".join(str(uid) for uid in user_ids) return ",".join(str(uid) for uid in user_ids)

View File

@ -87,7 +87,6 @@ from zerver.lib.pysa import mark_sanitized
from zerver.lib.timestamp import datetime_to_timestamp from zerver.lib.timestamp import datetime_to_timestamp
from zerver.lib.types import ( from zerver.lib.types import (
DefaultStreamDict, DefaultStreamDict,
DisplayRecipientT,
ExtendedFieldElement, ExtendedFieldElement,
ExtendedValidator, ExtendedValidator,
FieldElement, FieldElement,
@ -99,6 +98,7 @@ from zerver.lib.types import (
RealmPlaygroundDict, RealmPlaygroundDict,
RealmUserValidator, RealmUserValidator,
UnspecifiedValue, UnspecifiedValue,
UserDisplayRecipient,
UserFieldElement, UserFieldElement,
Validator, Validator,
) )
@ -191,12 +191,12 @@ def query_for_ids(
# could be replaced with smarter bulk-fetching logic that deduplicates # could be replaced with smarter bulk-fetching logic that deduplicates
# queries for the same recipient; this is just a convenient way to # queries for the same recipient; this is just a convenient way to
# write that code. # write that code.
per_request_display_recipient_cache: Dict[int, DisplayRecipientT] = {} per_request_display_recipient_cache: Dict[int, List[UserDisplayRecipient]] = {}
def get_display_recipient_by_id( def get_display_recipient_by_id(
recipient_id: int, recipient_type: int, recipient_type_id: Optional[int] recipient_id: int, recipient_type: int, recipient_type_id: Optional[int]
) -> DisplayRecipientT: ) -> List[UserDisplayRecipient]:
""" """
returns: an object describing the recipient (using a cache). returns: an object describing the recipient (using a cache).
If the type is a stream, the type_id must be an int; a string is returned. If the type is a stream, the type_id must be an int; a string is returned.
@ -211,7 +211,7 @@ def get_display_recipient_by_id(
return per_request_display_recipient_cache[recipient_id] return per_request_display_recipient_cache[recipient_id]
def get_display_recipient(recipient: "Recipient") -> DisplayRecipientT: def get_display_recipient(recipient: "Recipient") -> List[UserDisplayRecipient]:
return get_display_recipient_by_id( return get_display_recipient_by_id(
recipient.id, recipient.id,
recipient.type, recipient.type,

View File

@ -43,7 +43,7 @@ from zerver.lib.streams import StreamDict, create_streams_if_needed, get_public_
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.test_helpers import HostRequestMock, get_user_messages, queries_captured from zerver.lib.test_helpers import HostRequestMock, get_user_messages, queries_captured
from zerver.lib.topic import MATCH_TOPIC, RESOLVED_TOPIC_PREFIX, TOPIC_NAME from zerver.lib.topic import MATCH_TOPIC, RESOLVED_TOPIC_PREFIX, TOPIC_NAME
from zerver.lib.types import DisplayRecipientT from zerver.lib.types import UserDisplayRecipient
from zerver.lib.upload.base import create_attachment from zerver.lib.upload.base import create_attachment
from zerver.lib.url_encoding import near_message_url from zerver.lib.url_encoding import near_message_url
from zerver.lib.user_topics import set_topic_visibility_policy from zerver.lib.user_topics import set_topic_visibility_policy
@ -2024,11 +2024,11 @@ class GetOldMessagesTest(ZulipTestCase):
""" """
me = self.example_user("hamlet") me = self.example_user("hamlet")
def dr_emails(dr: DisplayRecipientT) -> str: def dr_emails(dr: List[UserDisplayRecipient]) -> str:
assert isinstance(dr, list) assert isinstance(dr, list)
return ",".join(sorted({*(r["email"] for r in dr), me.email})) return ",".join(sorted({*(r["email"] for r in dr), me.email}))
def dr_ids(dr: DisplayRecipientT) -> List[int]: def dr_ids(dr: List[UserDisplayRecipient]) -> List[int]:
assert isinstance(dr, list) assert isinstance(dr, list)
return sorted({*(r["id"] for r in dr), self.example_user("hamlet").id}) return sorted({*(r["id"] for r in dr), self.example_user("hamlet").id})