python: Avoid relying on Collection supertype of QuerySet.

QuerySet doesn’t implement __contains__, so it can’t be a subtype of
Container or Collection (https://code.djangoproject.com/ticket/35154).
This incorrect subtyping annotation was removed in
https://github.com/typeddjango/django-stubs/pull/1925, so we need to
stop relying on it before upgrading to django-stubs 5.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2024-04-16 20:28:33 -07:00 committed by Tim Abbott
parent 5654d051f7
commit f31579a220
9 changed files with 28 additions and 24 deletions

View File

@ -5,7 +5,7 @@ from typing import Any, Collection, Dict, List, Optional, Sequence, Set, Tuple,
from django.conf import settings
from django.contrib.contenttypes.models import ContentType
from django.db import transaction
from django.db.models import Q, Sum
from django.db.models import Q, QuerySet, Sum
from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _
from zxcvbn import zxcvbn
@ -70,7 +70,7 @@ def do_send_confirmation_email(
return activation_url
def estimate_recent_invites(realms: Collection[Realm], *, days: int) -> int:
def estimate_recent_invites(realms: Collection[Realm] | QuerySet[Realm], *, days: int) -> int:
"""An upper bound on the number of invites sent in the last `days` days"""
recent_invites = RealmCount.objects.filter(
realm__in=realms,

View File

@ -238,8 +238,8 @@ def get_recipient_info(
if recipient.type == Recipient.PERSONAL:
# The sender and recipient may be the same id, so
# de-duplicate using a set.
message_to_user_ids: Collection[int] = list({recipient.type_id, sender_id})
assert len(message_to_user_ids) in [1, 2]
message_to_user_id_set = {recipient.type_id, sender_id}
assert len(message_to_user_id_set) in [1, 2]
elif recipient.type == Recipient.STREAM:
# Anybody calling us w/r/t a stream message needs to supply
@ -302,9 +302,9 @@ def get_recipient_info(
.order_by("user_profile_id")
)
message_to_user_ids = list()
message_to_user_id_set = set()
for row in subscription_rows:
message_to_user_ids.append(row["user_profile_id"])
message_to_user_id_set.add(row["user_profile_id"])
# We store the 'sender_muted_stream' information here to avoid db query at
# a later stage when we perform automatically unmute topic in muted stream operation.
if row["user_profile_id"] == sender_id:
@ -373,21 +373,18 @@ def get_recipient_info(
)
elif recipient.type == Recipient.DIRECT_MESSAGE_GROUP:
message_to_user_ids = get_huddle_user_ids(recipient)
message_to_user_id_set = set(get_huddle_user_ids(recipient))
else:
raise ValueError("Bad recipient type")
message_to_user_id_set = set(message_to_user_ids)
user_ids = set(message_to_user_id_set)
# Important note: Because we haven't rendered Markdown yet, we
# don't yet know which of these possibly-mentioned users was
# actually mentioned in the message (in other words, the
# mention syntax might have been in a code block or otherwise
# escaped). `get_ids_for` will filter these extra user rows
# for our data structures not related to bots
user_ids |= possibly_mentioned_user_ids
user_ids = message_to_user_id_set | possibly_mentioned_user_ids
if user_ids:
query: ValuesQuerySet[UserProfile, ActiveUserDict] = UserProfile.objects.filter(

View File

@ -1,6 +1,6 @@
from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
from django.db.models import Model
from django.db.models import Model, QuerySet
from django.utils.timezone import now as timezone_now
from zerver.lib.create_user import create_user_profile, get_display_email_address
@ -163,7 +163,9 @@ def bulk_create_users(
def bulk_set_users_or_streams_recipient_fields(
model: Type[Model],
objects: Union[Collection[UserProfile], Collection[Stream]],
objects: Union[
Collection[UserProfile], QuerySet[UserProfile], Collection[Stream], QuerySet[Stream]
],
recipients: Optional[Iterable[Recipient]] = None,
) -> None:
assert model in [UserProfile, Stream]

View File

@ -7,7 +7,7 @@ from typing import Any, Collection, Dict, Iterator, List, Optional, Set, Tuple
from django.conf import settings
from django.db import transaction
from django.db.models import Exists, OuterRef
from django.db.models import Exists, OuterRef, QuerySet
from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _
from typing_extensions import TypeAlias
@ -331,7 +331,7 @@ def get_slim_stream_id_map(realm: Realm) -> Dict[int, Stream]:
def bulk_get_digest_context(
users: Collection[UserProfile], cutoff: float
users: Collection[UserProfile] | QuerySet[UserProfile], cutoff: float
) -> Iterator[Tuple[UserProfile, Dict[str, Any]]]:
# We expect a non-empty list of users all from the same realm.
assert users

View File

@ -420,7 +420,10 @@ def has_message_access(
def bulk_access_messages(
user_profile: UserProfile, messages: Collection[Message], *, stream: Optional[Stream] = None
user_profile: UserProfile,
messages: Collection[Message] | QuerySet[Message],
*,
stream: Optional[Stream] = None,
) -> List[Message]:
"""This function does the full has_message_access check for each
message. If stream is provided, it is used to avoid unnecessary

View File

@ -2,7 +2,7 @@ import copy
import zlib
from datetime import datetime
from email.headerregistry import Address
from typing import Any, Collection, Dict, List, Optional, TypedDict
from typing import Any, Dict, Iterable, List, Optional, TypedDict
import orjson
@ -78,7 +78,7 @@ def message_to_encoded_cache(message: Message, realm_id: Optional[int] = None) -
def update_message_cache(
changed_messages: Collection[Message], realm_id: Optional[int] = None
changed_messages: Iterable[Message], realm_id: Optional[int] = None
) -> List[int]:
"""Updates the message as stored in the to_dict cache (for serving
messages)."""
@ -273,7 +273,7 @@ class MessageDict:
@staticmethod
def messages_to_encoded_cache(
messages: Collection[Message], realm_id: Optional[int] = None
messages: Iterable[Message], realm_id: Optional[int] = None
) -> Dict[int, bytes]:
messages_dict = MessageDict.messages_to_encoded_cache_helper(messages, realm_id)
encoded_messages = {msg["id"]: stringify_message_dict(msg) for msg in messages_dict}
@ -281,7 +281,7 @@ class MessageDict:
@staticmethod
def messages_to_encoded_cache_helper(
messages: Collection[Message], realm_id: Optional[int] = None
messages: Iterable[Message], realm_id: Optional[int] = None
) -> List[Dict[str, Any]]:
# Near duplicate of the build_message_dict + get_raw_db_rows
# code path that accepts already fetched Message objects

View File

@ -180,7 +180,7 @@ def get_users_for_streams(stream_ids: Set[int]) -> Dict[int, Set[UserProfile]]:
def bulk_get_subscriber_peer_info(
realm: Realm,
streams: Collection[Stream],
streams: Collection[Stream] | QuerySet[Stream],
) -> SubscriberPeerInfo:
"""
Glossary:

View File

@ -47,6 +47,7 @@ from django.urls import resolve
from django.utils import translation
from django.utils.module_loading import import_string
from django.utils.timezone import now as timezone_now
from django_stubs_ext import ValuesQuerySet
from fakeldap import MockLDAP
from openapi_core.contrib.django import DjangoOpenAPIRequest, DjangoOpenAPIResponse
from requests import PreparedRequest
@ -1245,7 +1246,7 @@ Output:
"""
self.assertEqual(self.get_json_error(result, status_code=status_code), msg)
def assert_length(self, items: Collection[Any], count: int) -> None:
def assert_length(self, items: Collection[Any] | ValuesQuerySet[Any, Any], count: int) -> None:
actual_count = len(items)
if actual_count != count: # nocoverage
print("\nITEMS:\n")

View File

@ -1,10 +1,11 @@
import logging
from argparse import ArgumentParser
from typing import Any, Collection
from typing import Any
from django.conf import settings
from django.core.management.base import CommandError
from django.db import transaction
from django.db.models import QuerySet
from typing_extensions import override
from zerver.lib.logging_util import log_to_file
@ -20,7 +21,7 @@ log_to_file(logger, settings.LDAP_SYNC_LOG_PATH)
# Run this on a cron job to pick up on name changes.
@transaction.atomic
def sync_ldap_user_data(
user_profiles: Collection[UserProfile], deactivation_protection: bool = True
user_profiles: QuerySet[UserProfile], deactivation_protection: bool = True
) -> None:
logger.info("Starting update.")
try: