diff --git a/zerver/actions/invites.py b/zerver/actions/invites.py index 829a88384c..97a01724d7 100644 --- a/zerver/actions/invites.py +++ b/zerver/actions/invites.py @@ -5,7 +5,7 @@ from typing import Any, Collection, Dict, List, Optional, Sequence, Set, Tuple, from django.conf import settings from django.contrib.contenttypes.models import ContentType from django.db import transaction -from django.db.models import Q, Sum +from django.db.models import Q, QuerySet, Sum from django.utils.timezone import now as timezone_now from django.utils.translation import gettext as _ from zxcvbn import zxcvbn @@ -70,7 +70,7 @@ def do_send_confirmation_email( return activation_url -def estimate_recent_invites(realms: Collection[Realm], *, days: int) -> int: +def estimate_recent_invites(realms: Collection[Realm] | QuerySet[Realm], *, days: int) -> int: """An upper bound on the number of invites sent in the last `days` days""" recent_invites = RealmCount.objects.filter( realm__in=realms, diff --git a/zerver/actions/message_send.py b/zerver/actions/message_send.py index 746e0012d2..f7cbff5c46 100644 --- a/zerver/actions/message_send.py +++ b/zerver/actions/message_send.py @@ -238,8 +238,8 @@ def get_recipient_info( if recipient.type == Recipient.PERSONAL: # The sender and recipient may be the same id, so # de-duplicate using a set. - message_to_user_ids: Collection[int] = list({recipient.type_id, sender_id}) - assert len(message_to_user_ids) in [1, 2] + message_to_user_id_set = {recipient.type_id, sender_id} + assert len(message_to_user_id_set) in [1, 2] elif recipient.type == Recipient.STREAM: # Anybody calling us w/r/t a stream message needs to supply @@ -302,9 +302,9 @@ def get_recipient_info( .order_by("user_profile_id") ) - message_to_user_ids = list() + message_to_user_id_set = set() for row in subscription_rows: - message_to_user_ids.append(row["user_profile_id"]) + message_to_user_id_set.add(row["user_profile_id"]) # We store the 'sender_muted_stream' information here to avoid db query at # a later stage when we perform automatically unmute topic in muted stream operation. if row["user_profile_id"] == sender_id: @@ -373,21 +373,18 @@ def get_recipient_info( ) elif recipient.type == Recipient.DIRECT_MESSAGE_GROUP: - message_to_user_ids = get_huddle_user_ids(recipient) + message_to_user_id_set = set(get_huddle_user_ids(recipient)) else: raise ValueError("Bad recipient type") - message_to_user_id_set = set(message_to_user_ids) - - user_ids = set(message_to_user_id_set) # Important note: Because we haven't rendered Markdown yet, we # don't yet know which of these possibly-mentioned users was # actually mentioned in the message (in other words, the # mention syntax might have been in a code block or otherwise # escaped). `get_ids_for` will filter these extra user rows # for our data structures not related to bots - user_ids |= possibly_mentioned_user_ids + user_ids = message_to_user_id_set | possibly_mentioned_user_ids if user_ids: query: ValuesQuerySet[UserProfile, ActiveUserDict] = UserProfile.objects.filter( diff --git a/zerver/lib/bulk_create.py b/zerver/lib/bulk_create.py index 8232d1b69e..86dadba18b 100644 --- a/zerver/lib/bulk_create.py +++ b/zerver/lib/bulk_create.py @@ -1,6 +1,6 @@ from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple, Type, Union -from django.db.models import Model +from django.db.models import Model, QuerySet from django.utils.timezone import now as timezone_now from zerver.lib.create_user import create_user_profile, get_display_email_address @@ -163,7 +163,9 @@ def bulk_create_users( def bulk_set_users_or_streams_recipient_fields( model: Type[Model], - objects: Union[Collection[UserProfile], Collection[Stream]], + objects: Union[ + Collection[UserProfile], QuerySet[UserProfile], Collection[Stream], QuerySet[Stream] + ], recipients: Optional[Iterable[Recipient]] = None, ) -> None: assert model in [UserProfile, Stream] diff --git a/zerver/lib/digest.py b/zerver/lib/digest.py index ef4c624786..c3cb1e18e8 100644 --- a/zerver/lib/digest.py +++ b/zerver/lib/digest.py @@ -7,7 +7,7 @@ from typing import Any, Collection, Dict, Iterator, List, Optional, Set, Tuple from django.conf import settings from django.db import transaction -from django.db.models import Exists, OuterRef +from django.db.models import Exists, OuterRef, QuerySet from django.utils.timezone import now as timezone_now from django.utils.translation import gettext as _ from typing_extensions import TypeAlias @@ -331,7 +331,7 @@ def get_slim_stream_id_map(realm: Realm) -> Dict[int, Stream]: def bulk_get_digest_context( - users: Collection[UserProfile], cutoff: float + users: Collection[UserProfile] | QuerySet[UserProfile], cutoff: float ) -> Iterator[Tuple[UserProfile, Dict[str, Any]]]: # We expect a non-empty list of users all from the same realm. assert users diff --git a/zerver/lib/message.py b/zerver/lib/message.py index 6d18db147f..1a190a2a55 100644 --- a/zerver/lib/message.py +++ b/zerver/lib/message.py @@ -420,7 +420,10 @@ def has_message_access( def bulk_access_messages( - user_profile: UserProfile, messages: Collection[Message], *, stream: Optional[Stream] = None + user_profile: UserProfile, + messages: Collection[Message] | QuerySet[Message], + *, + stream: Optional[Stream] = None, ) -> List[Message]: """This function does the full has_message_access check for each message. If stream is provided, it is used to avoid unnecessary diff --git a/zerver/lib/message_cache.py b/zerver/lib/message_cache.py index aa9b31bf4a..c9cadde085 100644 --- a/zerver/lib/message_cache.py +++ b/zerver/lib/message_cache.py @@ -2,7 +2,7 @@ import copy import zlib from datetime import datetime from email.headerregistry import Address -from typing import Any, Collection, Dict, List, Optional, TypedDict +from typing import Any, Dict, Iterable, List, Optional, TypedDict import orjson @@ -78,7 +78,7 @@ def message_to_encoded_cache(message: Message, realm_id: Optional[int] = None) - def update_message_cache( - changed_messages: Collection[Message], realm_id: Optional[int] = None + changed_messages: Iterable[Message], realm_id: Optional[int] = None ) -> List[int]: """Updates the message as stored in the to_dict cache (for serving messages).""" @@ -273,7 +273,7 @@ class MessageDict: @staticmethod def messages_to_encoded_cache( - messages: Collection[Message], realm_id: Optional[int] = None + messages: Iterable[Message], realm_id: Optional[int] = None ) -> Dict[int, bytes]: messages_dict = MessageDict.messages_to_encoded_cache_helper(messages, realm_id) encoded_messages = {msg["id"]: stringify_message_dict(msg) for msg in messages_dict} @@ -281,7 +281,7 @@ class MessageDict: @staticmethod def messages_to_encoded_cache_helper( - messages: Collection[Message], realm_id: Optional[int] = None + messages: Iterable[Message], realm_id: Optional[int] = None ) -> List[Dict[str, Any]]: # Near duplicate of the build_message_dict + get_raw_db_rows # code path that accepts already fetched Message objects diff --git a/zerver/lib/stream_subscription.py b/zerver/lib/stream_subscription.py index da63a52117..9edc7807b4 100644 --- a/zerver/lib/stream_subscription.py +++ b/zerver/lib/stream_subscription.py @@ -180,7 +180,7 @@ def get_users_for_streams(stream_ids: Set[int]) -> Dict[int, Set[UserProfile]]: def bulk_get_subscriber_peer_info( realm: Realm, - streams: Collection[Stream], + streams: Collection[Stream] | QuerySet[Stream], ) -> SubscriberPeerInfo: """ Glossary: diff --git a/zerver/lib/test_classes.py b/zerver/lib/test_classes.py index d897f5a6db..1ba5be0350 100644 --- a/zerver/lib/test_classes.py +++ b/zerver/lib/test_classes.py @@ -47,6 +47,7 @@ from django.urls import resolve from django.utils import translation from django.utils.module_loading import import_string from django.utils.timezone import now as timezone_now +from django_stubs_ext import ValuesQuerySet from fakeldap import MockLDAP from openapi_core.contrib.django import DjangoOpenAPIRequest, DjangoOpenAPIResponse from requests import PreparedRequest @@ -1245,7 +1246,7 @@ Output: """ self.assertEqual(self.get_json_error(result, status_code=status_code), msg) - def assert_length(self, items: Collection[Any], count: int) -> None: + def assert_length(self, items: Collection[Any] | ValuesQuerySet[Any, Any], count: int) -> None: actual_count = len(items) if actual_count != count: # nocoverage print("\nITEMS:\n") diff --git a/zerver/management/commands/sync_ldap_user_data.py b/zerver/management/commands/sync_ldap_user_data.py index 59e1e6a236..5b885e364d 100644 --- a/zerver/management/commands/sync_ldap_user_data.py +++ b/zerver/management/commands/sync_ldap_user_data.py @@ -1,10 +1,11 @@ import logging from argparse import ArgumentParser -from typing import Any, Collection +from typing import Any from django.conf import settings from django.core.management.base import CommandError from django.db import transaction +from django.db.models import QuerySet from typing_extensions import override from zerver.lib.logging_util import log_to_file @@ -20,7 +21,7 @@ log_to_file(logger, settings.LDAP_SYNC_LOG_PATH) # Run this on a cron job to pick up on name changes. @transaction.atomic def sync_ldap_user_data( - user_profiles: Collection[UserProfile], deactivation_protection: bool = True + user_profiles: QuerySet[UserProfile], deactivation_protection: bool = True ) -> None: logger.info("Starting update.") try: