diff --git a/zerver/actions/message_send.py b/zerver/actions/message_send.py index b32ed2f996..c18147361d 100644 --- a/zerver/actions/message_send.py +++ b/zerver/actions/message_send.py @@ -68,6 +68,7 @@ from zerver.lib.notification_data import ( get_user_group_mentions_data, user_allows_notifications_in_StreamTopic, ) +from zerver.lib.query_helpers import query_for_ids from zerver.lib.queue import queue_json_publish from zerver.lib.recipient_users import recipient_for_user_profiles from zerver.lib.stream_subscription import ( @@ -101,7 +102,6 @@ from zerver.models import ( UserPresence, UserProfile, UserTopic, - query_for_ids, ) from zerver.models.clients import get_client from zerver.models.groups import SystemGroups diff --git a/zerver/lib/message.py b/zerver/lib/message.py index 25f7bdc860..e0d54dfd47 100644 --- a/zerver/lib/message.py +++ b/zerver/lib/message.py @@ -42,6 +42,7 @@ from zerver.lib.exceptions import JsonableError, MissingAuthenticationError from zerver.lib.markdown import MessageRenderingResult, markdown_convert, topic_links from zerver.lib.markdown import version as markdown_version from zerver.lib.mention import MentionData +from zerver.lib.query_helpers import query_for_ids from zerver.lib.request import RequestVariableConversionError from zerver.lib.stream_subscription import ( get_stream_subscriptions_for_user, @@ -74,7 +75,6 @@ from zerver.models import ( UserMessage, UserProfile, UserTopic, - query_for_ids, ) from zerver.models.constants import MAX_TOPIC_NAME_LENGTH from zerver.models.messages import get_usermessage_by_message_id diff --git a/zerver/lib/presence.py b/zerver/lib/presence.py index bb30b11ae6..06e09517da 100644 --- a/zerver/lib/presence.py +++ b/zerver/lib/presence.py @@ -6,9 +6,10 @@ from typing import Any, Dict, Mapping, Optional, Sequence, Set from django.conf import settings from django.utils.timezone import now as timezone_now +from zerver.lib.query_helpers import query_for_ids from zerver.lib.timestamp import datetime_to_timestamp from zerver.lib.users import check_user_can_access_all_users, get_accessible_user_ids -from zerver.models import PushDeviceToken, Realm, UserPresence, UserProfile, query_for_ids +from zerver.models import PushDeviceToken, Realm, UserPresence, UserProfile def get_presence_dicts_for_rows( diff --git a/zerver/lib/query_helpers.py b/zerver/lib/query_helpers.py new file mode 100644 index 0000000000..52f1dea666 --- /dev/null +++ b/zerver/lib/query_helpers.py @@ -0,0 +1,30 @@ +from typing import List, TypeVar + +from django.db import models +from django_stubs_ext import ValuesQuerySet + +ModelT = TypeVar("ModelT", bound=models.Model) +RowT = TypeVar("RowT") + + +def query_for_ids( + query: ValuesQuerySet[ModelT, RowT], + user_ids: List[int], + field: str, +) -> ValuesQuerySet[ModelT, RowT]: + """ + This function optimizes searches of the form + `user_profile_id in (1, 2, 3, 4)` by quickly + building the where clauses. Profiling shows significant + speedups over the normal Django-based approach. + + Use this very carefully! Also, the caller should + guard against empty lists of user_ids. + """ + assert user_ids + clause = f"{field} IN %s" + query = query.extra( + where=[clause], + params=(tuple(user_ids),), + ) + return query diff --git a/zerver/models/__init__.py b/zerver/models/__init__.py index 149902cfd0..8de3b242bc 100644 --- a/zerver/models/__init__.py +++ b/zerver/models/__init__.py @@ -1,9 +1,8 @@ -from typing import List, Tuple, TypeVar, Union +from typing import List, Tuple, Union from django.db import models from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models.sql.compiler import SQLCompiler -from django_stubs_ext import ValuesQuerySet from typing_extensions import override from zerver.models.alert_words import AlertWord as AlertWord @@ -98,30 +97,3 @@ class AndNonZero(models.Lookup[int]): lhs, lhs_params = self.process_lhs(compiler, connection) rhs, rhs_params = self.process_rhs(compiler, connection) return f"{lhs} & {rhs} != 0", lhs_params + rhs_params - - -ModelT = TypeVar("ModelT", bound=models.Model) -RowT = TypeVar("RowT") - - -def query_for_ids( - query: ValuesQuerySet[ModelT, RowT], - user_ids: List[int], - field: str, -) -> ValuesQuerySet[ModelT, RowT]: - """ - This function optimizes searches of the form - `user_profile_id in (1, 2, 3, 4)` by quickly - building the where clauses. Profiling shows significant - speedups over the normal Django-based approach. - - Use this very carefully! Also, the caller should - guard against empty lists of user_ids. - """ - assert user_ids - clause = f"{field} IN %s" - query = query.extra( - where=[clause], - params=(tuple(user_ids),), - ) - return query