bots: Remove private stream subscriptions on changing bot owner.

We remove bot's subscriptions for private streams to which the
new owner is not subscribed and keep the ones to which the new
owner is subscribed on changing owner.

This commit also changes the code for sending subscription
remove events to use transaction.on_commit since we call
the function inside a transactopn in do_change_bot_owner and
this also requires some changes in tests in test_events.
This commit is contained in:
Sahil Batra 2022-05-07 12:26:33 +05:30 committed by Tim Abbott
parent ba00907946
commit 35d5609996
8 changed files with 146 additions and 17 deletions

View File

@ -5,6 +5,8 @@ from django.db import transaction
from django.utils.timezone import now as timezone_now from django.utils.timezone import now as timezone_now
from zerver.actions.create_user import created_bot_event from zerver.actions.create_user import created_bot_event
from zerver.actions.streams import bulk_remove_subscriptions
from zerver.lib.streams import get_subscribed_private_streams_for_user
from zerver.models import RealmAuditLog, Stream, UserProfile, active_user_ids, bot_owner_user_ids from zerver.models import RealmAuditLog, Stream, UserProfile, active_user_ids, bot_owner_user_ids
from zerver.tornado.django_api import send_event_on_commit from zerver.tornado.django_api import send_event_on_commit
@ -70,6 +72,34 @@ def send_bot_owner_update_events(
send_event_on_commit(user_profile.realm, event, active_user_ids(user_profile.realm_id)) send_event_on_commit(user_profile.realm, event, active_user_ids(user_profile.realm_id))
def remove_bot_from_inaccessible_private_streams(
user_profile: UserProfile, *, acting_user: Optional[UserProfile]
) -> None:
assert user_profile.bot_owner is not None
new_owner_subscribed_private_streams = get_subscribed_private_streams_for_user(
user_profile.bot_owner
)
new_owner_subscribed_private_stream_ids = [
stream.id for stream in new_owner_subscribed_private_streams
]
bot_subscribed_private_streams = get_subscribed_private_streams_for_user(user_profile)
bot_subscribed_private_stream_ids = [stream.id for stream in bot_subscribed_private_streams]
stream_ids_to_unsubscribe = set(bot_subscribed_private_stream_ids) - set(
new_owner_subscribed_private_stream_ids
)
unsubscribed_streams = [
stream
for stream in bot_subscribed_private_streams
if stream.id in stream_ids_to_unsubscribe
]
bulk_remove_subscriptions(
user_profile.realm, [user_profile], unsubscribed_streams, acting_user=acting_user
)
@transaction.atomic(durable=True) @transaction.atomic(durable=True)
def do_change_bot_owner( def do_change_bot_owner(
user_profile: UserProfile, bot_owner: UserProfile, acting_user: Union[UserProfile, None] user_profile: UserProfile, bot_owner: UserProfile, acting_user: Union[UserProfile, None]
@ -88,6 +118,8 @@ def do_change_bot_owner(
send_bot_owner_update_events(user_profile, bot_owner, previous_owner) send_bot_owner_update_events(user_profile, bot_owner, previous_owner)
remove_bot_from_inaccessible_private_streams(user_profile, acting_user=acting_user)
@transaction.atomic(durable=True) @transaction.atomic(durable=True)
def do_change_default_sending_stream( def do_change_default_sending_stream(

View File

@ -28,7 +28,7 @@ from zerver.lib.email_mirror_helpers import encode_email_address
from zerver.lib.exceptions import JsonableError from zerver.lib.exceptions import JsonableError
from zerver.lib.mention import silent_mention_syntax_for_user from zerver.lib.mention import silent_mention_syntax_for_user
from zerver.lib.message import get_last_message_id from zerver.lib.message import get_last_message_id
from zerver.lib.queue import queue_json_publish from zerver.lib.queue import queue_event_on_commit, queue_json_publish
from zerver.lib.stream_color import pick_colors from zerver.lib.stream_color import pick_colors
from zerver.lib.stream_subscription import ( from zerver.lib.stream_subscription import (
SubInfo, SubInfo,
@ -718,7 +718,7 @@ def notify_subscriptions_removed(
) -> None: ) -> None:
payload = [dict(name=stream.name, stream_id=stream.id) for stream in streams] payload = [dict(name=stream.name, stream_id=stream.id) for stream in streams]
event = dict(type="subscription", op="remove", subscriptions=payload) event = dict(type="subscription", op="remove", subscriptions=payload)
send_event(realm, event, [user_profile.id]) send_event_on_commit(realm, event, [user_profile.id])
SubAndRemovedT: TypeAlias = Tuple[ SubAndRemovedT: TypeAlias = Tuple[
@ -750,7 +750,7 @@ def send_subscription_remove_events(
stream.recipient_id for stream in streams_by_user[user_profile.id] stream.recipient_id for stream in streams_by_user[user_profile.id]
], ],
} }
queue_json_publish("deferred_work", event) queue_event_on_commit("deferred_work", event)
send_peer_remove_events( send_peer_remove_events(
realm=realm, realm=realm,

View File

@ -13,6 +13,7 @@ import pika.adapters.tornado_connection
import pika.connection import pika.connection
import pika.exceptions import pika.exceptions
from django.conf import settings from django.conf import settings
from django.db import transaction
from pika.adapters.blocking_connection import BlockingChannel from pika.adapters.blocking_connection import BlockingChannel
from pika.channel import Channel from pika.channel import Channel
from pika.spec import Basic from pika.spec import Basic
@ -439,6 +440,10 @@ def queue_json_publish(
get_worker(queue_name, disable_timeout=True).consume_single_event(event) get_worker(queue_name, disable_timeout=True).consume_single_event(event)
def queue_event_on_commit(queue_name: str, event: Dict[str, Any]) -> None:
transaction.on_commit(lambda: queue_json_publish(queue_name, event))
def retry_event( def retry_event(
queue_name: str, event: Dict[str, Any], failure_processor: Callable[[Dict[str, Any]], None] queue_name: str, event: Dict[str, Any], failure_processor: Callable[[Dict[str, Any]], None]
) -> None: ) -> None:

View File

@ -956,3 +956,20 @@ def do_get_streams(
stream["is_default"] = stream["stream_id"] in default_stream_ids stream["is_default"] = stream["stream_id"] in default_stream_ids
return stream_dicts return stream_dicts
def get_subscribed_private_streams_for_user(user_profile: UserProfile) -> QuerySet[Stream]:
exists_expression = Exists(
Subscription.objects.filter(
user_profile=user_profile,
active=True,
is_user_active=True,
recipient_id=OuterRef("recipient_id"),
),
)
subscribed_private_streams = (
Stream.objects.filter(realm=user_profile.realm, invite_only=True, deactivated=False)
.annotate(subscribed=exists_expression)
.filter(subscribed=True)
)
return subscribed_private_streams

View File

@ -21,6 +21,7 @@ from zerver.models import (
Realm, Realm,
RealmUserDefault, RealmUserDefault,
Service, Service,
Subscription,
UserProfile, UserProfile,
get_bot_services, get_bot_services,
get_realm, get_realm,
@ -1102,6 +1103,54 @@ class BotTest(ZulipTestCase, UploadSerializeMixin):
bot_user.refresh_from_db() bot_user.refresh_from_db()
self.assertEqual(bot_user.bot_owner, cordelia) self.assertEqual(bot_user.bot_owner, cordelia)
def test_patch_bot_owner_with_private_streams(self) -> None:
self.login("iago")
hamlet = self.example_user("hamlet")
self.create_bot()
bot_realm = get_realm("zulip")
bot_email = "hambot-bot@zulip.testserver"
bot_user = get_user(bot_email, bot_realm)
private_stream = self.make_stream("private_stream", invite_only=True)
public_stream = self.make_stream("public_stream")
self.subscribe(bot_user, "private_stream")
self.subscribe(self.example_user("iago"), "private_stream")
self.subscribe(bot_user, "public_stream")
self.subscribe(self.example_user("iago"), "public_stream")
private_stream_test = self.make_stream("private_stream_test", invite_only=True)
self.subscribe(self.example_user("hamlet"), "private_stream_test")
self.subscribe(bot_user, "private_stream_test")
bot_info = {
"bot_owner_id": hamlet.id,
}
result = self.client_patch(f"/json/bots/{bot_user.id}", bot_info)
self.assert_json_success(result)
bot_user = get_user(bot_email, bot_realm)
assert bot_user.bot_owner is not None
self.assertEqual(bot_user.bot_owner.id, hamlet.id)
assert private_stream.recipient_id is not None
self.assertFalse(
Subscription.objects.filter(
user_profile=bot_user, recipient_id=private_stream.recipient_id, active=True
).exists()
)
assert private_stream_test.recipient_id is not None
self.assertTrue(
Subscription.objects.filter(
user_profile=bot_user, recipient_id=private_stream_test.recipient_id, active=True
).exists()
)
assert public_stream.recipient_id is not None
self.assertTrue(
Subscription.objects.filter(
user_profile=bot_user, recipient_id=public_stream.recipient_id, active=True
).exists()
)
def test_patch_bot_avatar(self) -> None: def test_patch_bot_avatar(self) -> None:
self.login("hamlet") self.login("hamlet")
bot_info = { bot_info = {

View File

@ -2398,6 +2398,29 @@ class NormalActionsTest(BaseAction):
check_realm_bot_add("events[0]", events[0]) check_realm_bot_add("events[0]", events[0])
check_realm_user_update("events[1]", events[1], "bot_owner_id") check_realm_user_update("events[1]", events[1], "bot_owner_id")
def test_peer_remove_events_on_changing_bot_owner(self) -> None:
previous_owner = self.example_user("aaron")
self.user_profile = self.example_user("iago")
bot = self.create_test_bot("test2", previous_owner, full_name="Test2 Testerson")
private_stream = self.make_stream("private_stream", invite_only=True)
self.make_stream("public_stream")
self.subscribe(bot, "private_stream")
self.subscribe(self.example_user("aaron"), "private_stream")
self.subscribe(bot, "public_stream")
self.subscribe(self.example_user("aaron"), "public_stream")
self.make_stream("private_stream_test", invite_only=True)
self.subscribe(self.example_user("iago"), "private_stream_test")
self.subscribe(bot, "private_stream_test")
action = lambda: do_change_bot_owner(bot, self.user_profile, previous_owner)
events = self.verify_action(action, num_events=3)
check_realm_bot_update("events[0]", events[0], "owner_id")
check_realm_user_update("events[1]", events[1], "bot_owner_id")
check_subscription_peer_remove("events[2]", events[2])
self.assertEqual(events[2]["stream_ids"], [private_stream.id])
def test_do_update_outgoing_webhook_service(self) -> None: def test_do_update_outgoing_webhook_service(self) -> None:
self.user_profile = self.example_user("iago") self.user_profile = self.example_user("iago")
bot = self.create_test_bot( bot = self.create_test_bot(

View File

@ -1899,6 +1899,7 @@ class MarkUnreadTest(ZulipTestCase):
] ]
# Unsubscribing generates an event in the deferred_work queue # Unsubscribing generates an event in the deferred_work queue
# that marks the above messages as read. # that marks the above messages as read.
with self.captureOnCommitCallbacks(execute=True):
self.unsubscribe(receiver, stream_name) self.unsubscribe(receiver, stream_name)
after_unsubscribe_stream_message_ids = [ after_unsubscribe_stream_message_ids = [
self.send_stream_message( self.send_stream_message(

View File

@ -2590,6 +2590,7 @@ class StreamAdminTest(ZulipTestCase):
with self.assert_database_query_count(query_count): with self.assert_database_query_count(query_count):
with cache_tries_captured() as cache_tries: with cache_tries_captured() as cache_tries:
with self.captureOnCommitCallbacks(execute=True):
result = self.client_delete( result = self.client_delete(
"/json/users/me/subscriptions", "/json/users/me/subscriptions",
{ {
@ -4940,9 +4941,9 @@ class SubscriptionAPITest(ZulipTestCase):
self.subscribe(user3, "private_stream") self.subscribe(user3, "private_stream")
# Sends 3 peer-remove events and 2 unsubscribe events. # Sends 3 peer-remove events and 2 unsubscribe events.
with self.capture_send_event_calls(expected_num_events=5) as events:
with self.assert_database_query_count(16): with self.assert_database_query_count(16):
with self.assert_memcached_count(3): with self.assert_memcached_count(3):
with self.capture_send_event_calls(expected_num_events=5) as events:
bulk_remove_subscriptions( bulk_remove_subscriptions(
realm, realm,
[user1, user2], [user1, user2],
@ -5455,6 +5456,7 @@ class SubscriptionAPITest(ZulipTestCase):
self.assertEqual(result[1]["stream_id"], stream2.id) self.assertEqual(result[1]["stream_id"], stream2.id)
self.assertEqual(result[2]["stream_id"], private.id) self.assertEqual(result[2]["stream_id"], private.id)
with self.captureOnCommitCallbacks(execute=True):
# Unsubscribing should mark all the messages in stream2 as read # Unsubscribing should mark all the messages in stream2 as read
self.unsubscribe(user, "stream2") self.unsubscribe(user, "stream2")
self.unsubscribe(user, "private_stream") self.unsubscribe(user, "private_stream")