mirror of https://github.com/zulip/zulip.git
python: Avoid relying on Collection supertype of QuerySet.
QuerySet doesn’t implement __contains__, so it can’t be a subtype of Container or Collection (https://code.djangoproject.com/ticket/35154). This incorrect subtyping annotation was removed in https://github.com/typeddjango/django-stubs/pull/1925, so we need to stop relying on it before upgrading to django-stubs 5. Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
parent
5654d051f7
commit
f31579a220
|
@ -5,7 +5,7 @@ from typing import Any, Collection, Dict, List, Optional, Sequence, Set, Tuple,
|
|||
from django.conf import settings
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.db import transaction
|
||||
from django.db.models import Q, Sum
|
||||
from django.db.models import Q, QuerySet, Sum
|
||||
from django.utils.timezone import now as timezone_now
|
||||
from django.utils.translation import gettext as _
|
||||
from zxcvbn import zxcvbn
|
||||
|
@ -70,7 +70,7 @@ def do_send_confirmation_email(
|
|||
return activation_url
|
||||
|
||||
|
||||
def estimate_recent_invites(realms: Collection[Realm], *, days: int) -> int:
|
||||
def estimate_recent_invites(realms: Collection[Realm] | QuerySet[Realm], *, days: int) -> int:
|
||||
"""An upper bound on the number of invites sent in the last `days` days"""
|
||||
recent_invites = RealmCount.objects.filter(
|
||||
realm__in=realms,
|
||||
|
|
|
@ -238,8 +238,8 @@ def get_recipient_info(
|
|||
if recipient.type == Recipient.PERSONAL:
|
||||
# The sender and recipient may be the same id, so
|
||||
# de-duplicate using a set.
|
||||
message_to_user_ids: Collection[int] = list({recipient.type_id, sender_id})
|
||||
assert len(message_to_user_ids) in [1, 2]
|
||||
message_to_user_id_set = {recipient.type_id, sender_id}
|
||||
assert len(message_to_user_id_set) in [1, 2]
|
||||
|
||||
elif recipient.type == Recipient.STREAM:
|
||||
# Anybody calling us w/r/t a stream message needs to supply
|
||||
|
@ -302,9 +302,9 @@ def get_recipient_info(
|
|||
.order_by("user_profile_id")
|
||||
)
|
||||
|
||||
message_to_user_ids = list()
|
||||
message_to_user_id_set = set()
|
||||
for row in subscription_rows:
|
||||
message_to_user_ids.append(row["user_profile_id"])
|
||||
message_to_user_id_set.add(row["user_profile_id"])
|
||||
# We store the 'sender_muted_stream' information here to avoid db query at
|
||||
# a later stage when we perform automatically unmute topic in muted stream operation.
|
||||
if row["user_profile_id"] == sender_id:
|
||||
|
@ -373,21 +373,18 @@ def get_recipient_info(
|
|||
)
|
||||
|
||||
elif recipient.type == Recipient.DIRECT_MESSAGE_GROUP:
|
||||
message_to_user_ids = get_huddle_user_ids(recipient)
|
||||
message_to_user_id_set = set(get_huddle_user_ids(recipient))
|
||||
|
||||
else:
|
||||
raise ValueError("Bad recipient type")
|
||||
|
||||
message_to_user_id_set = set(message_to_user_ids)
|
||||
|
||||
user_ids = set(message_to_user_id_set)
|
||||
# Important note: Because we haven't rendered Markdown yet, we
|
||||
# don't yet know which of these possibly-mentioned users was
|
||||
# actually mentioned in the message (in other words, the
|
||||
# mention syntax might have been in a code block or otherwise
|
||||
# escaped). `get_ids_for` will filter these extra user rows
|
||||
# for our data structures not related to bots
|
||||
user_ids |= possibly_mentioned_user_ids
|
||||
user_ids = message_to_user_id_set | possibly_mentioned_user_ids
|
||||
|
||||
if user_ids:
|
||||
query: ValuesQuerySet[UserProfile, ActiveUserDict] = UserProfile.objects.filter(
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
from django.db.models import Model
|
||||
from django.db.models import Model, QuerySet
|
||||
from django.utils.timezone import now as timezone_now
|
||||
|
||||
from zerver.lib.create_user import create_user_profile, get_display_email_address
|
||||
|
@ -163,7 +163,9 @@ def bulk_create_users(
|
|||
|
||||
def bulk_set_users_or_streams_recipient_fields(
|
||||
model: Type[Model],
|
||||
objects: Union[Collection[UserProfile], Collection[Stream]],
|
||||
objects: Union[
|
||||
Collection[UserProfile], QuerySet[UserProfile], Collection[Stream], QuerySet[Stream]
|
||||
],
|
||||
recipients: Optional[Iterable[Recipient]] = None,
|
||||
) -> None:
|
||||
assert model in [UserProfile, Stream]
|
||||
|
|
|
@ -7,7 +7,7 @@ from typing import Any, Collection, Dict, Iterator, List, Optional, Set, Tuple
|
|||
|
||||
from django.conf import settings
|
||||
from django.db import transaction
|
||||
from django.db.models import Exists, OuterRef
|
||||
from django.db.models import Exists, OuterRef, QuerySet
|
||||
from django.utils.timezone import now as timezone_now
|
||||
from django.utils.translation import gettext as _
|
||||
from typing_extensions import TypeAlias
|
||||
|
@ -331,7 +331,7 @@ def get_slim_stream_id_map(realm: Realm) -> Dict[int, Stream]:
|
|||
|
||||
|
||||
def bulk_get_digest_context(
|
||||
users: Collection[UserProfile], cutoff: float
|
||||
users: Collection[UserProfile] | QuerySet[UserProfile], cutoff: float
|
||||
) -> Iterator[Tuple[UserProfile, Dict[str, Any]]]:
|
||||
# We expect a non-empty list of users all from the same realm.
|
||||
assert users
|
||||
|
|
|
@ -420,7 +420,10 @@ def has_message_access(
|
|||
|
||||
|
||||
def bulk_access_messages(
|
||||
user_profile: UserProfile, messages: Collection[Message], *, stream: Optional[Stream] = None
|
||||
user_profile: UserProfile,
|
||||
messages: Collection[Message] | QuerySet[Message],
|
||||
*,
|
||||
stream: Optional[Stream] = None,
|
||||
) -> List[Message]:
|
||||
"""This function does the full has_message_access check for each
|
||||
message. If stream is provided, it is used to avoid unnecessary
|
||||
|
|
|
@ -2,7 +2,7 @@ import copy
|
|||
import zlib
|
||||
from datetime import datetime
|
||||
from email.headerregistry import Address
|
||||
from typing import Any, Collection, Dict, List, Optional, TypedDict
|
||||
from typing import Any, Dict, Iterable, List, Optional, TypedDict
|
||||
|
||||
import orjson
|
||||
|
||||
|
@ -78,7 +78,7 @@ def message_to_encoded_cache(message: Message, realm_id: Optional[int] = None) -
|
|||
|
||||
|
||||
def update_message_cache(
|
||||
changed_messages: Collection[Message], realm_id: Optional[int] = None
|
||||
changed_messages: Iterable[Message], realm_id: Optional[int] = None
|
||||
) -> List[int]:
|
||||
"""Updates the message as stored in the to_dict cache (for serving
|
||||
messages)."""
|
||||
|
@ -273,7 +273,7 @@ class MessageDict:
|
|||
|
||||
@staticmethod
|
||||
def messages_to_encoded_cache(
|
||||
messages: Collection[Message], realm_id: Optional[int] = None
|
||||
messages: Iterable[Message], realm_id: Optional[int] = None
|
||||
) -> Dict[int, bytes]:
|
||||
messages_dict = MessageDict.messages_to_encoded_cache_helper(messages, realm_id)
|
||||
encoded_messages = {msg["id"]: stringify_message_dict(msg) for msg in messages_dict}
|
||||
|
@ -281,7 +281,7 @@ class MessageDict:
|
|||
|
||||
@staticmethod
|
||||
def messages_to_encoded_cache_helper(
|
||||
messages: Collection[Message], realm_id: Optional[int] = None
|
||||
messages: Iterable[Message], realm_id: Optional[int] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
# Near duplicate of the build_message_dict + get_raw_db_rows
|
||||
# code path that accepts already fetched Message objects
|
||||
|
|
|
@ -180,7 +180,7 @@ def get_users_for_streams(stream_ids: Set[int]) -> Dict[int, Set[UserProfile]]:
|
|||
|
||||
def bulk_get_subscriber_peer_info(
|
||||
realm: Realm,
|
||||
streams: Collection[Stream],
|
||||
streams: Collection[Stream] | QuerySet[Stream],
|
||||
) -> SubscriberPeerInfo:
|
||||
"""
|
||||
Glossary:
|
||||
|
|
|
@ -47,6 +47,7 @@ from django.urls import resolve
|
|||
from django.utils import translation
|
||||
from django.utils.module_loading import import_string
|
||||
from django.utils.timezone import now as timezone_now
|
||||
from django_stubs_ext import ValuesQuerySet
|
||||
from fakeldap import MockLDAP
|
||||
from openapi_core.contrib.django import DjangoOpenAPIRequest, DjangoOpenAPIResponse
|
||||
from requests import PreparedRequest
|
||||
|
@ -1245,7 +1246,7 @@ Output:
|
|||
"""
|
||||
self.assertEqual(self.get_json_error(result, status_code=status_code), msg)
|
||||
|
||||
def assert_length(self, items: Collection[Any], count: int) -> None:
|
||||
def assert_length(self, items: Collection[Any] | ValuesQuerySet[Any, Any], count: int) -> None:
|
||||
actual_count = len(items)
|
||||
if actual_count != count: # nocoverage
|
||||
print("\nITEMS:\n")
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
import logging
|
||||
from argparse import ArgumentParser
|
||||
from typing import Any, Collection
|
||||
from typing import Any
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.management.base import CommandError
|
||||
from django.db import transaction
|
||||
from django.db.models import QuerySet
|
||||
from typing_extensions import override
|
||||
|
||||
from zerver.lib.logging_util import log_to_file
|
||||
|
@ -20,7 +21,7 @@ log_to_file(logger, settings.LDAP_SYNC_LOG_PATH)
|
|||
# Run this on a cron job to pick up on name changes.
|
||||
@transaction.atomic
|
||||
def sync_ldap_user_data(
|
||||
user_profiles: Collection[UserProfile], deactivation_protection: bool = True
|
||||
user_profiles: QuerySet[UserProfile], deactivation_protection: bool = True
|
||||
) -> None:
|
||||
logger.info("Starting update.")
|
||||
try:
|
||||
|
|
Loading…
Reference in New Issue