diff --git a/zerver/lib/actions.py b/zerver/lib/actions.py index 6c8ddad37d..cb8c4c2beb 100644 --- a/zerver/lib/actions.py +++ b/zerver/lib/actions.py @@ -138,6 +138,7 @@ from zerver.lib.stream_subscription import ( get_stream_subscriptions_for_users, get_subscribed_stream_ids_for_user, get_subscriptions_for_send_message, + get_used_colors_for_user_ids, get_user_ids_for_streams, num_subscribers_for_stream_id, subscriber_ids_with_stream_history_access, @@ -3846,6 +3847,7 @@ def bulk_add_subscriptions( acting_user: Optional[UserProfile], ) -> SubT: users = list(users) + user_ids = [user.id for user in users] # Sanity check out callers for stream in streams: @@ -3856,6 +3858,8 @@ def bulk_add_subscriptions( recipient_id_to_stream = {stream.recipient_id: stream for stream in streams} + used_colors_for_user_ids: Dict[int, Set[str]] = get_used_colors_for_user_ids(user_ids) + subs_by_user: Dict[int, List[Subscription]] = defaultdict(list) all_subs_query = get_stream_subscriptions_for_users(users) for sub in all_subs_query: @@ -3866,7 +3870,7 @@ def bulk_add_subscriptions( subs_to_add: List[SubInfo] = [] for user_profile in users: my_subs = subs_by_user[user_profile.id] - used_colors = {sub.color for sub in my_subs} + used_colors = used_colors_for_user_ids.get(user_profile.id, set()) # Make a fresh set of all new recipient ids, and then we will # remove any for which our user already has a subscription diff --git a/zerver/lib/stream_subscription.py b/zerver/lib/stream_subscription.py index 002933365e..bc2d57584b 100644 --- a/zerver/lib/stream_subscription.py +++ b/zerver/lib/stream_subscription.py @@ -82,6 +82,30 @@ def get_stream_subscriptions_for_users(user_profiles: List[UserProfile]) -> Quer ) +def get_used_colors_for_user_ids(user_ids: List[int]) -> Dict[int, Set[str]]: + """Fetch which stream colors have already been used for each user in + user_ids. Uses an optimized query designed to support picking + colors when bulk-adding users to streams, which requires + inspecting all Subscription objects for the users, which can often + end up being all Subscription objects in the realm. + """ + query = ( + Subscription.objects.filter( + user_profile_id__in=user_ids, + recipient__type=Recipient.STREAM, + ) + .values("user_profile_id", "color") + .distinct() + ) + + result: Dict[int, Set[str]] = defaultdict(set) + + for row in list(query): + result[row["user_profile_id"]].add(row["color"]) + + return result + + def get_bulk_stream_subscriber_info( users: List[UserProfile], streams: List[Stream], diff --git a/zerver/tests/test_signup.py b/zerver/tests/test_signup.py index 1b794fafd3..1bbf862b3d 100644 --- a/zerver/tests/test_signup.py +++ b/zerver/tests/test_signup.py @@ -868,7 +868,7 @@ class LoginTest(ZulipTestCase): with queries_captured() as queries, cache_tries_captured() as cache_tries: self.register(self.nonreg_email("test"), "test") # Ensure the number of queries we make is not O(streams) - self.assert_length(queries, 89) + self.assert_length(queries, 90) # We can probably avoid a couple cache hits here, but there doesn't # seem to be any O(N) behavior. Some of the cache hits are related diff --git a/zerver/tests/test_subs.py b/zerver/tests/test_subs.py index 3db8ec1cc0..fbdfd289c1 100644 --- a/zerver/tests/test_subs.py +++ b/zerver/tests/test_subs.py @@ -3655,7 +3655,7 @@ class SubscriptionAPITest(ZulipTestCase): streams_to_sub, dict(principals=orjson.dumps([user1.id, user2.id]).decode()), ) - self.assert_length(queries, 35) + self.assert_length(queries, 36) for ev in [x for x in events if x["event"]["type"] not in ("message", "stream")]: if ev["event"]["op"] == "add": @@ -3680,7 +3680,7 @@ class SubscriptionAPITest(ZulipTestCase): streams_to_sub, dict(principals=orjson.dumps([self.test_user.id]).decode()), ) - self.assert_length(queries, 11) + self.assert_length(queries, 12) add_event, add_peer_event = events self.assertEqual(add_event["event"]["type"], "subscription") @@ -4062,7 +4062,7 @@ class SubscriptionAPITest(ZulipTestCase): # The only known O(N) behavior here is that we call # principal_to_user_profile for each of our users. - self.assert_length(queries, 18) + self.assert_length(queries, 19) self.assert_length(cache_tries, 4) def test_subscriptions_add_for_principal(self) -> None: @@ -4523,7 +4523,7 @@ class SubscriptionAPITest(ZulipTestCase): [new_streams[0]], dict(principals=orjson.dumps([user1.id, user2.id]).decode()), ) - self.assert_length(queries, 35) + self.assert_length(queries, 36) # Test creating private stream. with queries_captured() as queries: @@ -4533,7 +4533,7 @@ class SubscriptionAPITest(ZulipTestCase): dict(principals=orjson.dumps([user1.id, user2.id]).decode()), invite_only=True, ) - self.assert_length(queries, 34) + self.assert_length(queries, 35) # Test creating a public stream with announce when realm has a notification stream. notifications_stream = get_stream(self.streams[0], self.test_realm) @@ -4548,7 +4548,7 @@ class SubscriptionAPITest(ZulipTestCase): principals=orjson.dumps([user1.id, user2.id]).decode(), ), ) - self.assert_length(queries, 43) + self.assert_length(queries, 44) class GetStreamsTest(ZulipTestCase): diff --git a/zerver/tests/test_users.py b/zerver/tests/test_users.py index 23198d2e77..bcda1c3e81 100644 --- a/zerver/tests/test_users.py +++ b/zerver/tests/test_users.py @@ -799,7 +799,7 @@ class QueryCountTest(ZulipTestCase): acting_user=None, ) - self.assert_length(queries, 84) + self.assert_length(queries, 85) self.assert_length(cache_tries, 27) peer_add_events = [event for event in events if event["event"].get("op") == "peer_add"]