refactor: Extract get_stream_subscriptions_for_users().

This commit is contained in:
Steve Howell 2017-10-29 11:15:35 -07:00 committed by Tim Abbott
parent b3192d17ab
commit 08ad26f913
2 changed files with 12 additions and 3 deletions

View File

@ -38,6 +38,7 @@ from zerver.lib.stream_subscription import (
get_active_subscriptions_for_stream_id,
get_active_subscriptions_for_stream_ids,
get_stream_subscriptions_for_user,
get_stream_subscriptions_for_users,
num_subscribers_for_stream_id,
)
from zerver.lib.stream_topic import StreamTopicTarget
@ -2104,6 +2105,8 @@ def get_user_ids_for_streams(streams):
def bulk_add_subscriptions(streams, users, from_stream_creation=False, acting_user=None):
# type: (Iterable[Stream], Iterable[UserProfile], bool, Optional[UserProfile]) -> Tuple[List[Tuple[UserProfile, Stream]], List[Tuple[UserProfile, Stream]]]
users = list(users)
recipients_map = bulk_get_recipients(Recipient.STREAM, [stream.id for stream in streams]) # type: Mapping[int, Recipient]
recipients = [recipient.id for recipient in recipients_map.values()] # type: List[int]
@ -2112,9 +2115,8 @@ def bulk_add_subscriptions(streams, users, from_stream_creation=False, acting_us
stream_map[recipients_map[stream.id].id] = stream
subs_by_user = defaultdict(list) # type: Dict[int, List[Subscription]]
all_subs_query = Subscription.objects.select_related("user_profile")
for sub in all_subs_query.filter(user_profile__in=users,
recipient__type=Recipient.STREAM):
all_subs_query = get_stream_subscriptions_for_users(users).select_related('user_profile')
for sub in all_subs_query:
subs_by_user[sub.user_profile_id].append(sub)
already_subscribed = [] # type: List[Tuple[UserProfile, Stream]]

View File

@ -30,6 +30,13 @@ def get_stream_subscriptions_for_user(user_profile):
recipient__type=Recipient.STREAM,
)
def get_stream_subscriptions_for_users(user_profiles):
# type: (List[UserProfile]) -> QuerySet
return Subscription.objects.filter(
user_profile__in=user_profiles,
recipient__type=Recipient.STREAM,
)
def num_subscribers_for_stream_id(stream_id):
# type: (int) -> int
return get_active_subscriptions_for_stream_id(stream_id).filter(