diff --git a/.github/workflows/zulip-ci.yml b/.github/workflows/zulip-ci.yml index 0b69c5d510..c2788c1c8f 100644 --- a/.github/workflows/zulip-ci.yml +++ b/.github/workflows/zulip-ci.yml @@ -173,7 +173,7 @@ jobs: - name: Run backend tests run: | source tools/ci/activate-venv - ./tools/test-backend --coverage --xml-report --no-html-report --include-webhooks --no-cov-cleanup --ban-console-output + ./tools/test-backend --coverage --xml-report --no-html-report --include-webhooks --include-transaction-tests --no-cov-cleanup --ban-console-output - name: Run mypy run: | diff --git a/tools/linter_lib/custom_check.py b/tools/linter_lib/custom_check.py index 7f8e5e2ad4..0c87c715a6 100644 --- a/tools/linter_lib/custom_check.py +++ b/tools/linter_lib/custom_check.py @@ -224,7 +224,7 @@ python_rules = RuleList( rules=[ { "pattern": "subject|SUBJECT", - "exclude_pattern": "subject to the|email|outbox|account deactivation", + "exclude_pattern": "subject to the|email|outbox|account deactivation|is subject to", "description": "avoid subject as a var", "good_lines": ["topic_name"], "bad_lines": ['subject="foo"', " MAX_SUBJECT_LEN"], diff --git a/zerver/actions/user_groups.py b/zerver/actions/user_groups.py index c55b0503a0..5b5ce51b04 100644 --- a/zerver/actions/user_groups.py +++ b/zerver/actions/user_groups.py @@ -8,7 +8,6 @@ from django.utils.translation import gettext as _ from zerver.lib.exceptions import JsonableError from zerver.lib.user_groups import ( - access_user_group_by_id, get_role_based_system_groups_dict, set_defaults_for_group_settings, ) @@ -402,8 +401,8 @@ def do_send_delete_user_group_event(realm: Realm, user_group_id: int, realm_id: send_event(realm, event, active_user_ids(realm_id)) -def check_delete_user_group(user_group_id: int, *, acting_user: UserProfile) -> None: - user_group = access_user_group_by_id(user_group_id, acting_user, for_read=False) +def check_delete_user_group(user_group: UserGroup, *, acting_user: UserProfile) -> None: + user_group_id = user_group.id user_group.delete() do_send_delete_user_group_event(acting_user.realm, user_group_id, acting_user.realm.id) diff --git a/zerver/lib/test_classes.py b/zerver/lib/test_classes.py index 131d2f078f..16164ecffd 100644 --- a/zerver/lib/test_classes.py +++ b/zerver/lib/test_classes.py @@ -1911,9 +1911,7 @@ class ZulipTestCase(ZulipTestCaseMixin, TestCase): self.assert_length(lst, expected_num_events) -def get_row_ids_in_all_tables() -> ( - Iterator[Tuple[str, Set[int]]] -): # nocoverage # Will be tested with the UserGroup transaction test case +def get_row_ids_in_all_tables() -> Iterator[Tuple[str, Set[int]]]: all_models = apps.get_models(include_auto_created=True) ignored_tables = {"django_session"} @@ -1947,13 +1945,11 @@ class ZulipTransactionTestCase(ZulipTestCaseMixin, TransactionTestCase): ZulipTransactionTestCase tests if they leak state. """ - def setUp(self) -> None: # nocoverage # Will be tested with the UserGroup transaction test case + def setUp(self) -> None: super().setUp() self.models_ids_set = dict(get_row_ids_in_all_tables()) - def tearDown( - self, - ) -> None: # nocoverage # Will be tested with the UserGroup transaction test case + def tearDown(self) -> None: """Verifies that the test did not adjust the set of rows in the test database. This is a sanity check to help ensure that tests using this class do not have unintended side effects on the @@ -1972,7 +1968,6 @@ class ZulipTransactionTestCase(ZulipTestCaseMixin, TransactionTestCase): TransactionTestCase, so that the test database does not get flushed/deleted after each test using this class. """ - # nocoverage # Will be tested with the UserGroup transaction test case class WebhookTestCase(ZulipTestCase): diff --git a/zerver/lib/test_helpers.py b/zerver/lib/test_helpers.py index 7bbf5e870b..4073ef1b3d 100644 --- a/zerver/lib/test_helpers.py +++ b/zerver/lib/test_helpers.py @@ -517,6 +517,8 @@ def write_instrumentation_reports(full_suite: bool, include_webhooks: bool) -> N "static/(?P.+)", "flush_caches", "external_content/(?P[^/]+)/(?P[^/]+)", + # Such endpoints are only used in certain test cases that can be skipped + "testing/(?P.+)", # These are SCIM2 urls overridden from django-scim2 to return Not Implemented. # We actually test them, but it's not being detected as a tested pattern, # possibly due to the use of re_path. TODO: Investigate and get them diff --git a/zerver/lib/user_groups.py b/zerver/lib/user_groups.py index 55b69fe50d..fefb0d09ee 100644 --- a/zerver/lib/user_groups.py +++ b/zerver/lib/user_groups.py @@ -1,4 +1,6 @@ -from typing import Dict, Iterable, List, Mapping, Sequence, TypedDict +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Collection, Dict, Iterable, Iterator, List, Mapping, TypedDict from django.db import transaction from django.db.models import F, QuerySet @@ -28,9 +30,31 @@ class UserGroupDict(TypedDict): can_mention_group: int +@dataclass +class LockedUserGroupContext: + """User groups in this dataclass are guaranteeed to be locked until the + end of the current transaction. + + supergroup is the user group to have subgroups added or removed; + direct_subgroups are user groups that are recursively queried for subgroups; + recursive_subgroups include direct_subgroups and their descendants. + """ + + supergroup: UserGroup + direct_subgroups: List[UserGroup] + recursive_subgroups: List[UserGroup] + + def has_user_group_access( - user_group: UserGroup, user_profile: UserProfile, *, for_read: bool + user_group: UserGroup, user_profile: UserProfile, *, for_read: bool, as_subgroup: bool ) -> bool: + if user_group.realm_id != user_profile.realm_id: + return False + + if as_subgroup: + # At this time, we only check for realm ID of a potential subgroup. + return True + if for_read and not user_profile.is_guest: # Everyone is allowed to read a user group and check who # are its members. Guests should be unable to reach this @@ -57,29 +81,87 @@ def access_user_group_by_id( user_group_id: int, user_profile: UserProfile, *, for_read: bool ) -> UserGroup: try: - user_group = UserGroup.objects.get(id=user_group_id, realm=user_profile.realm) + if for_read: + user_group = UserGroup.objects.get(id=user_group_id, realm=user_profile.realm) + else: + user_group = UserGroup.objects.select_for_update().get( + id=user_group_id, realm=user_profile.realm + ) except UserGroup.DoesNotExist: raise JsonableError(_("Invalid user group")) - if not has_user_group_access(user_group, user_profile, for_read=for_read): + if not has_user_group_access(user_group, user_profile, for_read=for_read, as_subgroup=False): raise JsonableError(_("Insufficient permission")) return user_group -def access_user_groups_as_potential_subgroups( - user_group_ids: Sequence[int], acting_user: UserProfile -) -> List[UserGroup]: - user_groups = UserGroup.objects.filter(id__in=user_group_ids, realm=acting_user.realm) +@contextmanager +def lock_subgroups_with_respect_to_supergroup( + potential_subgroup_ids: Collection[int], potential_supergroup_id: int, acting_user: UserProfile +) -> Iterator[LockedUserGroupContext]: + """This locks the user groups with the given potential_subgroup_ids, as well + as their indirect subgroups, followed by the potential supergroup. It + ensures that we lock the user groups in a consistent order topologically to + avoid unnecessary deadlocks on non-conflicting queries. - valid_group_ids = [group.id for group in user_groups] - invalid_group_ids = [group_id for group_id in user_group_ids if group_id not in valid_group_ids] - if invalid_group_ids: - raise JsonableError( - _("Invalid user group ID: {group_id}").format(group_id=invalid_group_ids[0]) + Regardless of whether the user groups returned are used, always call this + helper before making changes to subgroup memberships. This avoids + introducing cycles among user groups when there is a race condition in + which one of these subgroups become an ancestor of the parent user group in + another transaction. + + Note that it only does a permission check on the potential supergroup, + not the potential subgroups or their recursive subgroups. + """ + with transaction.atomic(savepoint=False): + # Calling list with the QuerySet forces its evaluation putting a lock on + # the queried rows. + recursive_subgroups = list( + get_recursive_subgroups_for_groups( + potential_subgroup_ids, acting_user.realm + ).select_for_update(nowait=True) ) + # TODO: This select_for_update query is subject to deadlocking, and + # better error handling is needed. We may use + # select_for_update(nowait=True) and release the locks held by ending + # the transaction with a JsonableError by handling the DatabaseError. + # But at the current scale of concurrent requests, we rely on + # Postgres's deadlock detection when it occurs. + potential_supergroup = access_user_group_by_id( + potential_supergroup_id, acting_user, for_read=False + ) + # We avoid making a separate query for user_group_ids because the + # recursive query already returns those user groups. + potential_subgroups = [ + user_group + for user_group in recursive_subgroups + if user_group.id in potential_subgroup_ids + ] - return list(user_groups) + # We expect that the passed user_group_ids each corresponds to an + # existing user group. + group_ids_found = [group.id for group in potential_subgroups] + group_ids_not_found = [ + group_id for group_id in potential_subgroup_ids if group_id not in group_ids_found + ] + if group_ids_not_found: + raise JsonableError( + _("Invalid user group ID: {group_id}").format(group_id=group_ids_not_found[0]) + ) + + for subgroup in potential_subgroups: + # At this time, we only do a check on the realm ID of the fetched + # subgroup. This would be caught by the check earlier, so there is + # no coverage here. + if not has_user_group_access(subgroup, acting_user, for_read=False, as_subgroup=True): + raise JsonableError(_("Insufficient permission")) # nocoverage + + yield LockedUserGroupContext( + direct_subgroups=potential_subgroups, + recursive_subgroups=recursive_subgroups, + supergroup=potential_supergroup, + ) def access_user_group_for_setting( @@ -266,9 +348,11 @@ def get_subgroup_ids(user_group: UserGroup, *, direct_subgroup_only: bool = Fals return list(subgroup_ids) -def get_recursive_subgroups_for_groups(user_group_ids: List[int]) -> QuerySet[UserGroup]: +def get_recursive_subgroups_for_groups( + user_group_ids: Iterable[int], realm: Realm +) -> QuerySet[UserGroup]: cte = With.recursive( - lambda cte: UserGroup.objects.filter(id__in=user_group_ids) + lambda cte: UserGroup.objects.filter(id__in=user_group_ids, realm=realm) .values(group_id=F("id")) .union(cte.join(UserGroup, direct_supergroups=cte.col.group_id).values(group_id=F("id"))) ) diff --git a/zerver/tests/test_events.py b/zerver/tests/test_events.py index 0f87fe14eb..6ed768152e 100644 --- a/zerver/tests/test_events.py +++ b/zerver/tests/test_events.py @@ -1463,9 +1463,7 @@ class NormalActionsTest(BaseAction): check_user_group_remove_subgroups("events[0]", events[0]) # Test remove event - events = self.verify_action( - lambda: check_delete_user_group(backend.id, acting_user=othello) - ) + events = self.verify_action(lambda: check_delete_user_group(backend, acting_user=othello)) check_user_group_remove("events[0]", events[0]) def test_default_stream_groups_events(self) -> None: diff --git a/zerver/tests/test_user_groups.py b/zerver/tests/test_user_groups.py index e61f57d816..a78d64cd78 100644 --- a/zerver/tests/test_user_groups.py +++ b/zerver/tests/test_user_groups.py @@ -3,10 +3,13 @@ from typing import Iterable, Optional from unittest import mock import orjson +from django.db import transaction from django.utils.timezone import now as timezone_now +from zerver.actions.create_realm import do_create_realm from zerver.actions.realm_settings import do_set_realm_property from zerver.actions.user_groups import ( + add_subgroups_to_user_group, check_add_user_group, create_user_group_in_database, promote_new_full_members, @@ -24,6 +27,7 @@ from zerver.lib.user_groups import ( get_recursive_subgroups, get_subgroup_ids, get_user_group_member_ids, + has_user_group_access, is_user_in_group, user_groups_in_realm_serialized, ) @@ -229,6 +233,23 @@ class UserGroupTestCase(ZulipTestCase): self.assertFalse(is_user_in_group(moderators_group, hamlet)) self.assertFalse(is_user_in_group(moderators_group, hamlet, direct_member_only=True)) + def test_has_user_group_access_to_subgroup(self) -> None: + iago = self.example_user("iago") + zulip_realm = get_realm("zulip") + zulip_group = check_add_user_group(zulip_realm, "zulip", [], acting_user=None) + moderators_group = UserGroup.objects.get( + name=UserGroup.MODERATORS_GROUP_NAME, realm=zulip_realm, is_system_group=True + ) + + lear_realm = get_realm("lear") + lear_group = check_add_user_group(lear_realm, "test", [], acting_user=None) + + self.assertFalse(has_user_group_access(lear_group, iago, for_read=False, as_subgroup=True)) + self.assertTrue(has_user_group_access(zulip_group, iago, for_read=False, as_subgroup=True)) + self.assertTrue( + has_user_group_access(moderators_group, iago, for_read=False, as_subgroup=True) + ) + class UserGroupAPITestCase(UserGroupTestCase): def test_user_group_create(self) -> None: @@ -580,8 +601,10 @@ class UserGroupAPITestCase(UserGroupTestCase): self.assertEqual(UserGroup.objects.filter(realm=hamlet.realm).count(), 9) self.assertEqual(UserGroupMembership.objects.count(), 44) self.assertFalse(UserGroup.objects.filter(id=user_group.id).exists()) - # Test when invalid user group is supplied - result = self.client_delete("/json/user_groups/1111") + # Test when invalid user group is supplied; transaction needed for + # error handling + with transaction.atomic(): + result = self.client_delete("/json/user_groups/1111") self.assert_json_error(result, "Invalid user group") lear_realm = get_realm("lear") @@ -804,7 +827,8 @@ class UserGroupAPITestCase(UserGroupTestCase): def check_delete_user_group(acting_user: str, error_msg: Optional[str] = None) -> None: self.login(acting_user) user_group = UserGroup.objects.get(name="support") - result = self.client_delete(f"/json/user_groups/{user_group.id}") + with transaction.atomic(): + result = self.client_delete(f"/json/user_groups/{user_group.id}") if error_msg is None: self.assert_json_success(result) self.assert_length(UserGroup.objects.filter(realm=realm), 9) @@ -1460,3 +1484,27 @@ class UserGroupAPITestCase(UserGroupTestCase): ).content ) self.assertCountEqual(result_dict["subgroups"], [admins_group.id]) + + def test_add_subgroup_from_wrong_realm(self) -> None: + other_realm = do_create_realm("other", "Other Realm") + other_user_group = check_add_user_group(other_realm, "user_group", [], acting_user=None) + + realm = get_realm("zulip") + zulip_group = check_add_user_group(realm, "zulip_test", [], acting_user=None) + + self.login("iago") + result = self.client_post( + f"/json/user_groups/{zulip_group.id}/subgroups", + {"add": orjson.dumps([other_user_group.id]).decode()}, + ) + self.assert_json_error(result, f"Invalid user group ID: {other_user_group.id}") + + # Having a subgroup from another realm is very unlikely because we do + # not allow cross-realm subgroups being added in the first place. But we + # test the handling in this scenario for completeness. + add_subgroups_to_user_group(zulip_group, [other_user_group], acting_user=None) + result = self.client_post( + f"/json/user_groups/{zulip_group.id}/subgroups", + {"delete": orjson.dumps([other_user_group.id]).decode()}, + ) + self.assert_json_error(result, f"Invalid user group ID: {other_user_group.id}") diff --git a/zerver/transaction_tests/test_user_groups.py b/zerver/transaction_tests/test_user_groups.py new file mode 100644 index 0000000000..adb82b58b6 --- /dev/null +++ b/zerver/transaction_tests/test_user_groups.py @@ -0,0 +1,158 @@ +import threading +from typing import TYPE_CHECKING, List, Optional + +import orjson +from django.db import connections, transaction + +from zerver.actions.user_groups import add_subgroups_to_user_group, check_add_user_group +from zerver.lib.test_classes import ZulipTransactionTestCase +from zerver.models import Realm, UserGroup, get_realm +from zerver.views.development import user_groups as user_group_view + +if TYPE_CHECKING: + from django.test.client import _MonkeyPatchedWSGIResponse as TestHttpResponse + + +class UserGroupRaceConditionTestCase(ZulipTransactionTestCase): + created_user_groups: List[UserGroup] = [] + counter = 0 + CHAIN_LENGTH = 3 + + def tearDown(self) -> None: + # Clean up the user groups created to minimize leakage + with transaction.atomic(): + for group in self.created_user_groups: + group.delete() + transaction.on_commit(lambda: self.created_user_groups.clear()) + + super().tearDown() + + def create_user_group_chain(self, realm: Realm) -> List[UserGroup]: + """Build a user groups forming a chain through group-group memberships + returning a list where each group is the supergroup of its subsequent group. + """ + groups = [ + check_add_user_group(realm, f"chain #{self.counter + i}", [], acting_user=None) + for i in range(self.CHAIN_LENGTH) + ] + self.counter += self.CHAIN_LENGTH + self.created_user_groups.extend(groups) + prev_group = groups[0] + for group in groups[1:]: + add_subgroups_to_user_group(prev_group, [group], acting_user=None) + prev_group = group + return groups + + def test_lock_subgroups_with_respect_to_supergroup(self) -> None: + realm = get_realm("zulip") + self.login("iago") + test_case = self + + class RacingThread(threading.Thread): + def __init__( + self, + subgroup_ids: List[int], + supergroup_id: int, + ) -> None: + threading.Thread.__init__(self) + self.response: Optional["TestHttpResponse"] = None + self.subgroup_ids = subgroup_ids + self.supergroup_id = supergroup_id + + def run(self) -> None: + try: + self.response = test_case.client_post( + url=f"/testing/user_groups/{self.supergroup_id}/subgroups", + info={"add": orjson.dumps(self.subgroup_ids).decode()}, + ) + finally: + # Close all thread-local database connections + connections.close_all() + + def assert_thread_success_count( + t1: RacingThread, + t2: RacingThread, + *, + success_count: int, + error_messsage: str = "", + ) -> None: + help_msg = """We access the test endpoint that wraps around the +real subgroup update endpoint by synchronizing them after the acquisition of the +first lock in the critical region. Though unlikely, this test might fail as we +have no control over the scheduler when the barrier timeouts. +""".strip() + barrier = threading.Barrier(parties=2, timeout=3) + + user_group_view.set_sync_after_recursive_query(barrier) + t1.start() + t2.start() + + succeeded = 0 + for t in [t1, t2]: + t.join() + response = t.response + if response is not None and response.status_code == 200: + succeeded += 1 + continue + + assert response is not None + self.assert_json_error(response, error_messsage) + # Race condition resolution should only allow one thread to succeed + self.assertEqual( + succeeded, + success_count, + f"Exactly {success_count} thread(s) should succeed.\n{help_msg}", + ) + + foo_chain = self.create_user_group_chain(realm) + bar_chain = self.create_user_group_chain(realm) + # These two threads are conflicting because a cycle would be formed if + # both of them succeed. There is a deadlock in such circular dependency. + assert_thread_success_count( + RacingThread( + subgroup_ids=[foo_chain[0].id], + supergroup_id=bar_chain[-1].id, + ), + RacingThread( + subgroup_ids=[bar_chain[-1].id], + supergroup_id=foo_chain[0].id, + ), + success_count=1, + error_messsage="Deadlock detected", + ) + + foo_chain = self.create_user_group_chain(realm) + bar_chain = self.create_user_group_chain(realm) + # These two requests would succeed if they didn't race with each other. + # However, both threads will attempt to grab a lock on overlapping rows + # when they first do the recursive query for subgroups. In this case, we + # expect that one of the threads fails due to nowait=True for the + # .select_for_update() call. + assert_thread_success_count( + RacingThread( + subgroup_ids=[foo_chain[0].id], + supergroup_id=bar_chain[-1].id, + ), + RacingThread( + subgroup_ids=[foo_chain[1].id], + supergroup_id=bar_chain[-1].id, + ), + success_count=1, + error_messsage="Busy lock detected", + ) + + foo_chain = self.create_user_group_chain(realm) + bar_chain = self.create_user_group_chain(realm) + baz_chain = self.create_user_group_chain(realm) + # Adding non-conflicting subgroups should succeed. + assert_thread_success_count( + RacingThread( + subgroup_ids=[foo_chain[1].id, foo_chain[2].id, baz_chain[2].id], + supergroup_id=baz_chain[0].id, + ), + RacingThread( + subgroup_ids=[bar_chain[1].id, bar_chain[2].id], + supergroup_id=baz_chain[0].id, + ), + success_count=2, + ) diff --git a/zerver/views/development/user_groups.py b/zerver/views/development/user_groups.py new file mode 100644 index 0000000000..37cdec827c --- /dev/null +++ b/zerver/views/development/user_groups.py @@ -0,0 +1,65 @@ +import threading +from typing import Any, Optional +from unittest import mock + +from django.db import OperationalError, transaction +from django.http import HttpRequest, HttpResponse + +from zerver.lib.exceptions import JsonableError +from zerver.lib.request import REQ, has_request_variables +from zerver.lib.response import json_success +from zerver.lib.user_groups import access_user_group_by_id +from zerver.lib.validator import check_int +from zerver.models import UserGroup, UserProfile +from zerver.views.user_groups import update_subgroups_of_user_group + +BARRIER: Optional[threading.Barrier] = None + + +def set_sync_after_recursive_query(barrier: Optional[threading.Barrier]) -> None: + global BARRIER + BARRIER = barrier + + +@has_request_variables +def dev_update_subgroups( + request: HttpRequest, + user_profile: UserProfile, + user_group_id: int = REQ(json_validator=check_int, path_only=True), +) -> HttpResponse: + # The test is expected to set up the barrier before accessing this endpoint. + assert BARRIER is not None + try: + with transaction.atomic(), mock.patch( + "zerver.lib.user_groups.access_user_group_by_id" + ) as m: + + def wait_after_recursive_query(*args: Any, **kwargs: Any) -> UserGroup: + # When updating the subgroups, we access the supergroup group + # only after finishing the recursive query. + BARRIER.wait() + return access_user_group_by_id(*args, **kwargs) + + m.side_effect = wait_after_recursive_query + + update_subgroups_of_user_group(request, user_profile, user_group_id=user_group_id) + except OperationalError as err: + msg = str(err) + if "deadlock detected" in msg: + raise JsonableError("Deadlock detected") + else: + assert "could not obtain lock" in msg + # This error is possible when nowait is set the True, which only + # applies to the recursive query on the subgroups. Because the + # recursive query fails, this thread must have not waited on the + # barrier yet. + BARRIER.wait() + raise JsonableError("Busy lock detected") + except ( + threading.BrokenBarrierError + ): # nocoverage # This is only possible when timeout happens or there is a programming error + raise JsonableError( + "Broken barrier. The tester should make sure that the exact number of parties have waited on the barrier set by the previous immediate set_sync_after_first_lock call" + ) + + return json_success(request) diff --git a/zerver/views/user_groups.py b/zerver/views/user_groups.py index 27d6af1ec1..32e1ce9a6d 100644 --- a/zerver/views/user_groups.py +++ b/zerver/views/user_groups.py @@ -1,6 +1,7 @@ from typing import List, Optional, Sequence from django.conf import settings +from django.db import transaction from django.http import HttpRequest, HttpResponse from django.utils.translation import gettext as _ from django.utils.translation import override as override_language @@ -25,14 +26,13 @@ from zerver.lib.response import json_success from zerver.lib.user_groups import ( access_user_group_by_id, access_user_group_for_setting, - access_user_groups_as_potential_subgroups, check_user_group_name, get_direct_memberships_of_users, - get_recursive_subgroups_for_groups, get_subgroup_ids, get_user_group_direct_member_ids, get_user_group_member_ids, is_user_in_group, + lock_subgroups_with_respect_to_supergroup, user_groups_in_realm_serialized, ) from zerver.lib.users import access_user_by_id, user_ids_to_users @@ -95,6 +95,7 @@ def get_user_group(request: HttpRequest, user_profile: UserProfile) -> HttpRespo return json_success(request, data={"user_groups": user_groups}) +@transaction.atomic @require_user_group_edit_permission @has_request_variables def edit_user_group( @@ -153,7 +154,11 @@ def delete_user_group( user_profile: UserProfile, user_group_id: int = REQ(json_validator=check_int, path_only=True), ) -> HttpResponse: - check_delete_user_group(user_group_id, acting_user=user_profile) + # For deletion, the user group's recursive subgroups and the user group itself are locked. + with lock_subgroups_with_respect_to_supergroup( + [user_group_id], user_group_id, acting_user=user_profile + ) as context: + check_delete_user_group(context.supergroup, acting_user=user_profile) return json_success(request) @@ -231,6 +236,7 @@ def notify_for_user_group_subscription_changes( do_send_messages(notifications) +@transaction.atomic def add_members_to_group_backend( request: HttpRequest, user_profile: UserProfile, user_group_id: int, members: Sequence[int] ) -> HttpResponse: @@ -260,6 +266,7 @@ def add_members_to_group_backend( return json_success(request) +@transaction.atomic def remove_members_from_group_backend( request: HttpRequest, user_profile: UserProfile, user_group_id: int, members: Sequence[int] ) -> HttpResponse: @@ -292,28 +299,33 @@ def add_subgroups_to_group_backend( if not subgroup_ids: return json_success(request) - subgroups = access_user_groups_as_potential_subgroups(subgroup_ids, user_profile) - user_group = access_user_group_by_id(user_group_id, user_profile, for_read=False) - existing_direct_subgroup_ids = user_group.direct_subgroups.all().values_list("id", flat=True) - for group in subgroups: - if group.id in existing_direct_subgroup_ids: - raise JsonableError( - _("User group {group_id} is already a subgroup of this group.").format( - group_id=group.id + with lock_subgroups_with_respect_to_supergroup( + subgroup_ids, user_group_id, user_profile + ) as context: + existing_direct_subgroup_ids = context.supergroup.direct_subgroups.all().values_list( + "id", flat=True + ) + for group in context.direct_subgroups: + if group.id in existing_direct_subgroup_ids: + raise JsonableError( + _("User group {group_id} is already a subgroup of this group.").format( + group_id=group.id + ) ) + + recursive_subgroup_ids = { + recursive_subgroup.id for recursive_subgroup in context.recursive_subgroups + } + if user_group_id in recursive_subgroup_ids: + raise JsonableError( + _( + "User group {user_group_id} is already a subgroup of one of the passed subgroups." + ).format(user_group_id=user_group_id) ) - subgroup_ids = [group.id for group in subgroups] - if user_group_id in get_recursive_subgroups_for_groups(subgroup_ids).values_list( - "id", flat=True - ): - raise JsonableError( - _( - "User group {user_group_id} is already a subgroup of one of the passed subgroups." - ).format(user_group_id=user_group_id) + add_subgroups_to_user_group( + context.supergroup, context.direct_subgroups, acting_user=user_profile ) - - add_subgroups_to_user_group(user_group, subgroups, acting_user=user_profile) return json_success(request) @@ -323,18 +335,27 @@ def remove_subgroups_from_group_backend( if not subgroup_ids: return json_success(request) - subgroups = access_user_groups_as_potential_subgroups(subgroup_ids, user_profile) - user_group = access_user_group_by_id(user_group_id, user_profile, for_read=False) - existing_direct_subgroup_ids = user_group.direct_subgroups.all().values_list("id", flat=True) - for group in subgroups: - if group.id not in existing_direct_subgroup_ids: - raise JsonableError( - _("User group {group_id} is not a subgroup of this group.").format( - group_id=group.id + with lock_subgroups_with_respect_to_supergroup( + subgroup_ids, user_group_id, user_profile + ) as context: + # While the recursive subgroups in the context are not used, it is important that + # we acquire a lock for these rows while updating the subgroups to acquire the locks + # in a consistent order for subgroup membership changes. + existing_direct_subgroup_ids = context.supergroup.direct_subgroups.all().values_list( + "id", flat=True + ) + for group in context.direct_subgroups: + if group.id not in existing_direct_subgroup_ids: + raise JsonableError( + _("User group {group_id} is not a subgroup of this group.").format( + group_id=group.id + ) ) - ) - remove_subgroups_from_user_group(user_group, subgroups, acting_user=user_profile) + remove_subgroups_from_user_group( + context.supergroup, context.direct_subgroups, acting_user=user_profile + ) + return json_success(request) diff --git a/zproject/dev_urls.py b/zproject/dev_urls.py index cb0d8fbb97..7255245219 100644 --- a/zproject/dev_urls.py +++ b/zproject/dev_urls.py @@ -10,6 +10,7 @@ from django.urls import path from django.views.generic import TemplateView from django.views.static import serve +from zerver.lib.rest import rest_path from zerver.views.auth import config_error, login_page from zerver.views.development.cache import remove_caches from zerver.views.development.camo import handle_camo_url @@ -31,6 +32,7 @@ from zerver.views.development.registration import ( register_development_realm, register_development_user, ) +from zerver.views.development.user_groups import dev_update_subgroups # These URLs are available only in the development environment @@ -98,6 +100,14 @@ urls = [ path("external_content//", handle_camo_url), ] +testing_urls = [ + rest_path( + "testing/user_groups//subgroups", + POST=(dev_update_subgroups, {"intentionally_undocumented"}), + ), +] +urls += testing_urls + v1_api_mobile_patterns = [ # This is for the signing in through the devAuthBackEnd on mobile apps. path("dev_fetch_api_key", api_dev_fetch_api_key),