user_groups: Make locks required for updating user group memberships.

**Background**

User groups are expected to comply with the DAG constraint for the
many-to-many inter-group membership. The check for this constraint has
to be performed recursively so that we can find all direct and indirect
subgroups of the user group to be added.

This kind of check is vulnerable to phantom reads which is possible at
the default read committed isolation level because we cannot guarantee
that the check is still valid when we are adding the subgroups to the
user group.

**Solution**

To avoid having another transaction concurrently update one of the
to-be-subgroup after the recursive check is done, and before the subgroup
is added, we use SELECT FOR UPDATE to lock the user group rows.

The lock needs to be acquired before a group membership change is about
to occur before any check has been conducted.

Suppose that we are adding subgroup B to supergroup A, the locking protocol
is specified as follows:

1. Acquire a lock for B and all its direct and indirect subgroups.
2. Acquire a lock for A.

For the removal of user groups, we acquire a lock for the user group to
be removed with all its direct and indirect subgroups. This is the special
case A=B, which is still complaint with the protocol.

**Error handling**

We currently rely on Postgres' deadlock detection to abort transactions
and show an error for the users. In the future, we might need some
recovery mechanism or at least better error handling.

**Notes**

An important note is that we need to reuse the recursive CTE query that
finds the direct and indirect subgroups when applying the lock on the
rows. And the lock needs to be acquired the same way for the addition and
removal of direct subgroups.

User membership change (as opposed to user group membership) is not
affected. Read-only queries aren't either. The locks only protect
critical regions where the user group dependency graph might violate
the DAG constraint, where users are not participating.

**Testing**

We implement a transaction test case targeting some typical scenarios
when an internal server error is expected to happen (this means that the
user group view makes the correct decision to abort the transaction when
something goes wrong with locks).

To achieve this, we add a development view intended only for unit tests.
It has a global BARRIER that can be shared across threads, so that we
can synchronize them to consistently reproduce certain potential race
conditions prevented by the database locks.

The transaction test case lanuches pairs of threads initiating possibly
conflicting requests at the same time. The tests are set up such that exactly N
of them are expected to succeed with a certain error message (while we don't
know each one).

**Security notes**

get_recursive_subgroups_for_groups will no longer fetch user groups from
other realms. As a result, trying to add/remove a subgroup from another
realm results in a UserGroup not found error response.

We also implement subgroup-specific checks in has_user_group_access to
keep permission managing in a single place. Do note that the API
currently don't have a way to violate that check because we are only
checking the realm ID now.
This commit is contained in:
Zixuan James Li 2023-06-16 22:39:52 -04:00 committed by Tim Abbott
parent 9f7fab4213
commit a081428ad2
12 changed files with 446 additions and 66 deletions

View File

@ -173,7 +173,7 @@ jobs:
- name: Run backend tests - name: Run backend tests
run: | run: |
source tools/ci/activate-venv source tools/ci/activate-venv
./tools/test-backend --coverage --xml-report --no-html-report --include-webhooks --no-cov-cleanup --ban-console-output ./tools/test-backend --coverage --xml-report --no-html-report --include-webhooks --include-transaction-tests --no-cov-cleanup --ban-console-output
- name: Run mypy - name: Run mypy
run: | run: |

View File

@ -224,7 +224,7 @@ python_rules = RuleList(
rules=[ rules=[
{ {
"pattern": "subject|SUBJECT", "pattern": "subject|SUBJECT",
"exclude_pattern": "subject to the|email|outbox|account deactivation", "exclude_pattern": "subject to the|email|outbox|account deactivation|is subject to",
"description": "avoid subject as a var", "description": "avoid subject as a var",
"good_lines": ["topic_name"], "good_lines": ["topic_name"],
"bad_lines": ['subject="foo"', " MAX_SUBJECT_LEN"], "bad_lines": ['subject="foo"', " MAX_SUBJECT_LEN"],

View File

@ -8,7 +8,6 @@ from django.utils.translation import gettext as _
from zerver.lib.exceptions import JsonableError from zerver.lib.exceptions import JsonableError
from zerver.lib.user_groups import ( from zerver.lib.user_groups import (
access_user_group_by_id,
get_role_based_system_groups_dict, get_role_based_system_groups_dict,
set_defaults_for_group_settings, set_defaults_for_group_settings,
) )
@ -402,8 +401,8 @@ def do_send_delete_user_group_event(realm: Realm, user_group_id: int, realm_id:
send_event(realm, event, active_user_ids(realm_id)) send_event(realm, event, active_user_ids(realm_id))
def check_delete_user_group(user_group_id: int, *, acting_user: UserProfile) -> None: def check_delete_user_group(user_group: UserGroup, *, acting_user: UserProfile) -> None:
user_group = access_user_group_by_id(user_group_id, acting_user, for_read=False) user_group_id = user_group.id
user_group.delete() user_group.delete()
do_send_delete_user_group_event(acting_user.realm, user_group_id, acting_user.realm.id) do_send_delete_user_group_event(acting_user.realm, user_group_id, acting_user.realm.id)

View File

@ -1911,9 +1911,7 @@ class ZulipTestCase(ZulipTestCaseMixin, TestCase):
self.assert_length(lst, expected_num_events) self.assert_length(lst, expected_num_events)
def get_row_ids_in_all_tables() -> ( def get_row_ids_in_all_tables() -> Iterator[Tuple[str, Set[int]]]:
Iterator[Tuple[str, Set[int]]]
): # nocoverage # Will be tested with the UserGroup transaction test case
all_models = apps.get_models(include_auto_created=True) all_models = apps.get_models(include_auto_created=True)
ignored_tables = {"django_session"} ignored_tables = {"django_session"}
@ -1947,13 +1945,11 @@ class ZulipTransactionTestCase(ZulipTestCaseMixin, TransactionTestCase):
ZulipTransactionTestCase tests if they leak state. ZulipTransactionTestCase tests if they leak state.
""" """
def setUp(self) -> None: # nocoverage # Will be tested with the UserGroup transaction test case def setUp(self) -> None:
super().setUp() super().setUp()
self.models_ids_set = dict(get_row_ids_in_all_tables()) self.models_ids_set = dict(get_row_ids_in_all_tables())
def tearDown( def tearDown(self) -> None:
self,
) -> None: # nocoverage # Will be tested with the UserGroup transaction test case
"""Verifies that the test did not adjust the set of rows in the test """Verifies that the test did not adjust the set of rows in the test
database. This is a sanity check to help ensure that tests database. This is a sanity check to help ensure that tests
using this class do not have unintended side effects on the using this class do not have unintended side effects on the
@ -1972,7 +1968,6 @@ class ZulipTransactionTestCase(ZulipTestCaseMixin, TransactionTestCase):
TransactionTestCase, so that the test database does not get TransactionTestCase, so that the test database does not get
flushed/deleted after each test using this class. flushed/deleted after each test using this class.
""" """
# nocoverage # Will be tested with the UserGroup transaction test case
class WebhookTestCase(ZulipTestCase): class WebhookTestCase(ZulipTestCase):

View File

@ -517,6 +517,8 @@ def write_instrumentation_reports(full_suite: bool, include_webhooks: bool) -> N
"static/(?P<path>.+)", "static/(?P<path>.+)",
"flush_caches", "flush_caches",
"external_content/(?P<digest>[^/]+)/(?P<received_url>[^/]+)", "external_content/(?P<digest>[^/]+)/(?P<received_url>[^/]+)",
# Such endpoints are only used in certain test cases that can be skipped
"testing/(?P<path>.+)",
# These are SCIM2 urls overridden from django-scim2 to return Not Implemented. # These are SCIM2 urls overridden from django-scim2 to return Not Implemented.
# We actually test them, but it's not being detected as a tested pattern, # We actually test them, but it's not being detected as a tested pattern,
# possibly due to the use of re_path. TODO: Investigate and get them # possibly due to the use of re_path. TODO: Investigate and get them

View File

@ -1,4 +1,6 @@
from typing import Dict, Iterable, List, Mapping, Sequence, TypedDict from contextlib import contextmanager
from dataclasses import dataclass
from typing import Collection, Dict, Iterable, Iterator, List, Mapping, TypedDict
from django.db import transaction from django.db import transaction
from django.db.models import F, QuerySet from django.db.models import F, QuerySet
@ -28,9 +30,31 @@ class UserGroupDict(TypedDict):
can_mention_group: int can_mention_group: int
@dataclass
class LockedUserGroupContext:
"""User groups in this dataclass are guaranteeed to be locked until the
end of the current transaction.
supergroup is the user group to have subgroups added or removed;
direct_subgroups are user groups that are recursively queried for subgroups;
recursive_subgroups include direct_subgroups and their descendants.
"""
supergroup: UserGroup
direct_subgroups: List[UserGroup]
recursive_subgroups: List[UserGroup]
def has_user_group_access( def has_user_group_access(
user_group: UserGroup, user_profile: UserProfile, *, for_read: bool user_group: UserGroup, user_profile: UserProfile, *, for_read: bool, as_subgroup: bool
) -> bool: ) -> bool:
if user_group.realm_id != user_profile.realm_id:
return False
if as_subgroup:
# At this time, we only check for realm ID of a potential subgroup.
return True
if for_read and not user_profile.is_guest: if for_read and not user_profile.is_guest:
# Everyone is allowed to read a user group and check who # Everyone is allowed to read a user group and check who
# are its members. Guests should be unable to reach this # are its members. Guests should be unable to reach this
@ -57,29 +81,87 @@ def access_user_group_by_id(
user_group_id: int, user_profile: UserProfile, *, for_read: bool user_group_id: int, user_profile: UserProfile, *, for_read: bool
) -> UserGroup: ) -> UserGroup:
try: try:
if for_read:
user_group = UserGroup.objects.get(id=user_group_id, realm=user_profile.realm) user_group = UserGroup.objects.get(id=user_group_id, realm=user_profile.realm)
else:
user_group = UserGroup.objects.select_for_update().get(
id=user_group_id, realm=user_profile.realm
)
except UserGroup.DoesNotExist: except UserGroup.DoesNotExist:
raise JsonableError(_("Invalid user group")) raise JsonableError(_("Invalid user group"))
if not has_user_group_access(user_group, user_profile, for_read=for_read): if not has_user_group_access(user_group, user_profile, for_read=for_read, as_subgroup=False):
raise JsonableError(_("Insufficient permission")) raise JsonableError(_("Insufficient permission"))
return user_group return user_group
def access_user_groups_as_potential_subgroups( @contextmanager
user_group_ids: Sequence[int], acting_user: UserProfile def lock_subgroups_with_respect_to_supergroup(
) -> List[UserGroup]: potential_subgroup_ids: Collection[int], potential_supergroup_id: int, acting_user: UserProfile
user_groups = UserGroup.objects.filter(id__in=user_group_ids, realm=acting_user.realm) ) -> Iterator[LockedUserGroupContext]:
"""This locks the user groups with the given potential_subgroup_ids, as well
as their indirect subgroups, followed by the potential supergroup. It
ensures that we lock the user groups in a consistent order topologically to
avoid unnecessary deadlocks on non-conflicting queries.
valid_group_ids = [group.id for group in user_groups] Regardless of whether the user groups returned are used, always call this
invalid_group_ids = [group_id for group_id in user_group_ids if group_id not in valid_group_ids] helper before making changes to subgroup memberships. This avoids
if invalid_group_ids: introducing cycles among user groups when there is a race condition in
which one of these subgroups become an ancestor of the parent user group in
another transaction.
Note that it only does a permission check on the potential supergroup,
not the potential subgroups or their recursive subgroups.
"""
with transaction.atomic(savepoint=False):
# Calling list with the QuerySet forces its evaluation putting a lock on
# the queried rows.
recursive_subgroups = list(
get_recursive_subgroups_for_groups(
potential_subgroup_ids, acting_user.realm
).select_for_update(nowait=True)
)
# TODO: This select_for_update query is subject to deadlocking, and
# better error handling is needed. We may use
# select_for_update(nowait=True) and release the locks held by ending
# the transaction with a JsonableError by handling the DatabaseError.
# But at the current scale of concurrent requests, we rely on
# Postgres's deadlock detection when it occurs.
potential_supergroup = access_user_group_by_id(
potential_supergroup_id, acting_user, for_read=False
)
# We avoid making a separate query for user_group_ids because the
# recursive query already returns those user groups.
potential_subgroups = [
user_group
for user_group in recursive_subgroups
if user_group.id in potential_subgroup_ids
]
# We expect that the passed user_group_ids each corresponds to an
# existing user group.
group_ids_found = [group.id for group in potential_subgroups]
group_ids_not_found = [
group_id for group_id in potential_subgroup_ids if group_id not in group_ids_found
]
if group_ids_not_found:
raise JsonableError( raise JsonableError(
_("Invalid user group ID: {group_id}").format(group_id=invalid_group_ids[0]) _("Invalid user group ID: {group_id}").format(group_id=group_ids_not_found[0])
) )
return list(user_groups) for subgroup in potential_subgroups:
# At this time, we only do a check on the realm ID of the fetched
# subgroup. This would be caught by the check earlier, so there is
# no coverage here.
if not has_user_group_access(subgroup, acting_user, for_read=False, as_subgroup=True):
raise JsonableError(_("Insufficient permission")) # nocoverage
yield LockedUserGroupContext(
direct_subgroups=potential_subgroups,
recursive_subgroups=recursive_subgroups,
supergroup=potential_supergroup,
)
def access_user_group_for_setting( def access_user_group_for_setting(
@ -266,9 +348,11 @@ def get_subgroup_ids(user_group: UserGroup, *, direct_subgroup_only: bool = Fals
return list(subgroup_ids) return list(subgroup_ids)
def get_recursive_subgroups_for_groups(user_group_ids: List[int]) -> QuerySet[UserGroup]: def get_recursive_subgroups_for_groups(
user_group_ids: Iterable[int], realm: Realm
) -> QuerySet[UserGroup]:
cte = With.recursive( cte = With.recursive(
lambda cte: UserGroup.objects.filter(id__in=user_group_ids) lambda cte: UserGroup.objects.filter(id__in=user_group_ids, realm=realm)
.values(group_id=F("id")) .values(group_id=F("id"))
.union(cte.join(UserGroup, direct_supergroups=cte.col.group_id).values(group_id=F("id"))) .union(cte.join(UserGroup, direct_supergroups=cte.col.group_id).values(group_id=F("id")))
) )

View File

@ -1463,9 +1463,7 @@ class NormalActionsTest(BaseAction):
check_user_group_remove_subgroups("events[0]", events[0]) check_user_group_remove_subgroups("events[0]", events[0])
# Test remove event # Test remove event
events = self.verify_action( events = self.verify_action(lambda: check_delete_user_group(backend, acting_user=othello))
lambda: check_delete_user_group(backend.id, acting_user=othello)
)
check_user_group_remove("events[0]", events[0]) check_user_group_remove("events[0]", events[0])
def test_default_stream_groups_events(self) -> None: def test_default_stream_groups_events(self) -> None:

View File

@ -3,10 +3,13 @@ from typing import Iterable, Optional
from unittest import mock from unittest import mock
import orjson import orjson
from django.db import transaction
from django.utils.timezone import now as timezone_now from django.utils.timezone import now as timezone_now
from zerver.actions.create_realm import do_create_realm
from zerver.actions.realm_settings import do_set_realm_property from zerver.actions.realm_settings import do_set_realm_property
from zerver.actions.user_groups import ( from zerver.actions.user_groups import (
add_subgroups_to_user_group,
check_add_user_group, check_add_user_group,
create_user_group_in_database, create_user_group_in_database,
promote_new_full_members, promote_new_full_members,
@ -24,6 +27,7 @@ from zerver.lib.user_groups import (
get_recursive_subgroups, get_recursive_subgroups,
get_subgroup_ids, get_subgroup_ids,
get_user_group_member_ids, get_user_group_member_ids,
has_user_group_access,
is_user_in_group, is_user_in_group,
user_groups_in_realm_serialized, user_groups_in_realm_serialized,
) )
@ -229,6 +233,23 @@ class UserGroupTestCase(ZulipTestCase):
self.assertFalse(is_user_in_group(moderators_group, hamlet)) self.assertFalse(is_user_in_group(moderators_group, hamlet))
self.assertFalse(is_user_in_group(moderators_group, hamlet, direct_member_only=True)) self.assertFalse(is_user_in_group(moderators_group, hamlet, direct_member_only=True))
def test_has_user_group_access_to_subgroup(self) -> None:
iago = self.example_user("iago")
zulip_realm = get_realm("zulip")
zulip_group = check_add_user_group(zulip_realm, "zulip", [], acting_user=None)
moderators_group = UserGroup.objects.get(
name=UserGroup.MODERATORS_GROUP_NAME, realm=zulip_realm, is_system_group=True
)
lear_realm = get_realm("lear")
lear_group = check_add_user_group(lear_realm, "test", [], acting_user=None)
self.assertFalse(has_user_group_access(lear_group, iago, for_read=False, as_subgroup=True))
self.assertTrue(has_user_group_access(zulip_group, iago, for_read=False, as_subgroup=True))
self.assertTrue(
has_user_group_access(moderators_group, iago, for_read=False, as_subgroup=True)
)
class UserGroupAPITestCase(UserGroupTestCase): class UserGroupAPITestCase(UserGroupTestCase):
def test_user_group_create(self) -> None: def test_user_group_create(self) -> None:
@ -580,7 +601,9 @@ class UserGroupAPITestCase(UserGroupTestCase):
self.assertEqual(UserGroup.objects.filter(realm=hamlet.realm).count(), 9) self.assertEqual(UserGroup.objects.filter(realm=hamlet.realm).count(), 9)
self.assertEqual(UserGroupMembership.objects.count(), 44) self.assertEqual(UserGroupMembership.objects.count(), 44)
self.assertFalse(UserGroup.objects.filter(id=user_group.id).exists()) self.assertFalse(UserGroup.objects.filter(id=user_group.id).exists())
# Test when invalid user group is supplied # Test when invalid user group is supplied; transaction needed for
# error handling
with transaction.atomic():
result = self.client_delete("/json/user_groups/1111") result = self.client_delete("/json/user_groups/1111")
self.assert_json_error(result, "Invalid user group") self.assert_json_error(result, "Invalid user group")
@ -804,6 +827,7 @@ class UserGroupAPITestCase(UserGroupTestCase):
def check_delete_user_group(acting_user: str, error_msg: Optional[str] = None) -> None: def check_delete_user_group(acting_user: str, error_msg: Optional[str] = None) -> None:
self.login(acting_user) self.login(acting_user)
user_group = UserGroup.objects.get(name="support") user_group = UserGroup.objects.get(name="support")
with transaction.atomic():
result = self.client_delete(f"/json/user_groups/{user_group.id}") result = self.client_delete(f"/json/user_groups/{user_group.id}")
if error_msg is None: if error_msg is None:
self.assert_json_success(result) self.assert_json_success(result)
@ -1460,3 +1484,27 @@ class UserGroupAPITestCase(UserGroupTestCase):
).content ).content
) )
self.assertCountEqual(result_dict["subgroups"], [admins_group.id]) self.assertCountEqual(result_dict["subgroups"], [admins_group.id])
def test_add_subgroup_from_wrong_realm(self) -> None:
other_realm = do_create_realm("other", "Other Realm")
other_user_group = check_add_user_group(other_realm, "user_group", [], acting_user=None)
realm = get_realm("zulip")
zulip_group = check_add_user_group(realm, "zulip_test", [], acting_user=None)
self.login("iago")
result = self.client_post(
f"/json/user_groups/{zulip_group.id}/subgroups",
{"add": orjson.dumps([other_user_group.id]).decode()},
)
self.assert_json_error(result, f"Invalid user group ID: {other_user_group.id}")
# Having a subgroup from another realm is very unlikely because we do
# not allow cross-realm subgroups being added in the first place. But we
# test the handling in this scenario for completeness.
add_subgroups_to_user_group(zulip_group, [other_user_group], acting_user=None)
result = self.client_post(
f"/json/user_groups/{zulip_group.id}/subgroups",
{"delete": orjson.dumps([other_user_group.id]).decode()},
)
self.assert_json_error(result, f"Invalid user group ID: {other_user_group.id}")

View File

@ -0,0 +1,158 @@
import threading
from typing import TYPE_CHECKING, List, Optional
import orjson
from django.db import connections, transaction
from zerver.actions.user_groups import add_subgroups_to_user_group, check_add_user_group
from zerver.lib.test_classes import ZulipTransactionTestCase
from zerver.models import Realm, UserGroup, get_realm
from zerver.views.development import user_groups as user_group_view
if TYPE_CHECKING:
from django.test.client import _MonkeyPatchedWSGIResponse as TestHttpResponse
class UserGroupRaceConditionTestCase(ZulipTransactionTestCase):
created_user_groups: List[UserGroup] = []
counter = 0
CHAIN_LENGTH = 3
def tearDown(self) -> None:
# Clean up the user groups created to minimize leakage
with transaction.atomic():
for group in self.created_user_groups:
group.delete()
transaction.on_commit(lambda: self.created_user_groups.clear())
super().tearDown()
def create_user_group_chain(self, realm: Realm) -> List[UserGroup]:
"""Build a user groups forming a chain through group-group memberships
returning a list where each group is the supergroup of its subsequent group.
"""
groups = [
check_add_user_group(realm, f"chain #{self.counter + i}", [], acting_user=None)
for i in range(self.CHAIN_LENGTH)
]
self.counter += self.CHAIN_LENGTH
self.created_user_groups.extend(groups)
prev_group = groups[0]
for group in groups[1:]:
add_subgroups_to_user_group(prev_group, [group], acting_user=None)
prev_group = group
return groups
def test_lock_subgroups_with_respect_to_supergroup(self) -> None:
realm = get_realm("zulip")
self.login("iago")
test_case = self
class RacingThread(threading.Thread):
def __init__(
self,
subgroup_ids: List[int],
supergroup_id: int,
) -> None:
threading.Thread.__init__(self)
self.response: Optional["TestHttpResponse"] = None
self.subgroup_ids = subgroup_ids
self.supergroup_id = supergroup_id
def run(self) -> None:
try:
self.response = test_case.client_post(
url=f"/testing/user_groups/{self.supergroup_id}/subgroups",
info={"add": orjson.dumps(self.subgroup_ids).decode()},
)
finally:
# Close all thread-local database connections
connections.close_all()
def assert_thread_success_count(
t1: RacingThread,
t2: RacingThread,
*,
success_count: int,
error_messsage: str = "",
) -> None:
help_msg = """We access the test endpoint that wraps around the
real subgroup update endpoint by synchronizing them after the acquisition of the
first lock in the critical region. Though unlikely, this test might fail as we
have no control over the scheduler when the barrier timeouts.
""".strip()
barrier = threading.Barrier(parties=2, timeout=3)
user_group_view.set_sync_after_recursive_query(barrier)
t1.start()
t2.start()
succeeded = 0
for t in [t1, t2]:
t.join()
response = t.response
if response is not None and response.status_code == 200:
succeeded += 1
continue
assert response is not None
self.assert_json_error(response, error_messsage)
# Race condition resolution should only allow one thread to succeed
self.assertEqual(
succeeded,
success_count,
f"Exactly {success_count} thread(s) should succeed.\n{help_msg}",
)
foo_chain = self.create_user_group_chain(realm)
bar_chain = self.create_user_group_chain(realm)
# These two threads are conflicting because a cycle would be formed if
# both of them succeed. There is a deadlock in such circular dependency.
assert_thread_success_count(
RacingThread(
subgroup_ids=[foo_chain[0].id],
supergroup_id=bar_chain[-1].id,
),
RacingThread(
subgroup_ids=[bar_chain[-1].id],
supergroup_id=foo_chain[0].id,
),
success_count=1,
error_messsage="Deadlock detected",
)
foo_chain = self.create_user_group_chain(realm)
bar_chain = self.create_user_group_chain(realm)
# These two requests would succeed if they didn't race with each other.
# However, both threads will attempt to grab a lock on overlapping rows
# when they first do the recursive query for subgroups. In this case, we
# expect that one of the threads fails due to nowait=True for the
# .select_for_update() call.
assert_thread_success_count(
RacingThread(
subgroup_ids=[foo_chain[0].id],
supergroup_id=bar_chain[-1].id,
),
RacingThread(
subgroup_ids=[foo_chain[1].id],
supergroup_id=bar_chain[-1].id,
),
success_count=1,
error_messsage="Busy lock detected",
)
foo_chain = self.create_user_group_chain(realm)
bar_chain = self.create_user_group_chain(realm)
baz_chain = self.create_user_group_chain(realm)
# Adding non-conflicting subgroups should succeed.
assert_thread_success_count(
RacingThread(
subgroup_ids=[foo_chain[1].id, foo_chain[2].id, baz_chain[2].id],
supergroup_id=baz_chain[0].id,
),
RacingThread(
subgroup_ids=[bar_chain[1].id, bar_chain[2].id],
supergroup_id=baz_chain[0].id,
),
success_count=2,
)

View File

@ -0,0 +1,65 @@
import threading
from typing import Any, Optional
from unittest import mock
from django.db import OperationalError, transaction
from django.http import HttpRequest, HttpResponse
from zerver.lib.exceptions import JsonableError
from zerver.lib.request import REQ, has_request_variables
from zerver.lib.response import json_success
from zerver.lib.user_groups import access_user_group_by_id
from zerver.lib.validator import check_int
from zerver.models import UserGroup, UserProfile
from zerver.views.user_groups import update_subgroups_of_user_group
BARRIER: Optional[threading.Barrier] = None
def set_sync_after_recursive_query(barrier: Optional[threading.Barrier]) -> None:
global BARRIER
BARRIER = barrier
@has_request_variables
def dev_update_subgroups(
request: HttpRequest,
user_profile: UserProfile,
user_group_id: int = REQ(json_validator=check_int, path_only=True),
) -> HttpResponse:
# The test is expected to set up the barrier before accessing this endpoint.
assert BARRIER is not None
try:
with transaction.atomic(), mock.patch(
"zerver.lib.user_groups.access_user_group_by_id"
) as m:
def wait_after_recursive_query(*args: Any, **kwargs: Any) -> UserGroup:
# When updating the subgroups, we access the supergroup group
# only after finishing the recursive query.
BARRIER.wait()
return access_user_group_by_id(*args, **kwargs)
m.side_effect = wait_after_recursive_query
update_subgroups_of_user_group(request, user_profile, user_group_id=user_group_id)
except OperationalError as err:
msg = str(err)
if "deadlock detected" in msg:
raise JsonableError("Deadlock detected")
else:
assert "could not obtain lock" in msg
# This error is possible when nowait is set the True, which only
# applies to the recursive query on the subgroups. Because the
# recursive query fails, this thread must have not waited on the
# barrier yet.
BARRIER.wait()
raise JsonableError("Busy lock detected")
except (
threading.BrokenBarrierError
): # nocoverage # This is only possible when timeout happens or there is a programming error
raise JsonableError(
"Broken barrier. The tester should make sure that the exact number of parties have waited on the barrier set by the previous immediate set_sync_after_first_lock call"
)
return json_success(request)

View File

@ -1,6 +1,7 @@
from typing import List, Optional, Sequence from typing import List, Optional, Sequence
from django.conf import settings from django.conf import settings
from django.db import transaction
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from django.utils.translation import override as override_language from django.utils.translation import override as override_language
@ -25,14 +26,13 @@ from zerver.lib.response import json_success
from zerver.lib.user_groups import ( from zerver.lib.user_groups import (
access_user_group_by_id, access_user_group_by_id,
access_user_group_for_setting, access_user_group_for_setting,
access_user_groups_as_potential_subgroups,
check_user_group_name, check_user_group_name,
get_direct_memberships_of_users, get_direct_memberships_of_users,
get_recursive_subgroups_for_groups,
get_subgroup_ids, get_subgroup_ids,
get_user_group_direct_member_ids, get_user_group_direct_member_ids,
get_user_group_member_ids, get_user_group_member_ids,
is_user_in_group, is_user_in_group,
lock_subgroups_with_respect_to_supergroup,
user_groups_in_realm_serialized, user_groups_in_realm_serialized,
) )
from zerver.lib.users import access_user_by_id, user_ids_to_users from zerver.lib.users import access_user_by_id, user_ids_to_users
@ -95,6 +95,7 @@ def get_user_group(request: HttpRequest, user_profile: UserProfile) -> HttpRespo
return json_success(request, data={"user_groups": user_groups}) return json_success(request, data={"user_groups": user_groups})
@transaction.atomic
@require_user_group_edit_permission @require_user_group_edit_permission
@has_request_variables @has_request_variables
def edit_user_group( def edit_user_group(
@ -153,7 +154,11 @@ def delete_user_group(
user_profile: UserProfile, user_profile: UserProfile,
user_group_id: int = REQ(json_validator=check_int, path_only=True), user_group_id: int = REQ(json_validator=check_int, path_only=True),
) -> HttpResponse: ) -> HttpResponse:
check_delete_user_group(user_group_id, acting_user=user_profile) # For deletion, the user group's recursive subgroups and the user group itself are locked.
with lock_subgroups_with_respect_to_supergroup(
[user_group_id], user_group_id, acting_user=user_profile
) as context:
check_delete_user_group(context.supergroup, acting_user=user_profile)
return json_success(request) return json_success(request)
@ -231,6 +236,7 @@ def notify_for_user_group_subscription_changes(
do_send_messages(notifications) do_send_messages(notifications)
@transaction.atomic
def add_members_to_group_backend( def add_members_to_group_backend(
request: HttpRequest, user_profile: UserProfile, user_group_id: int, members: Sequence[int] request: HttpRequest, user_profile: UserProfile, user_group_id: int, members: Sequence[int]
) -> HttpResponse: ) -> HttpResponse:
@ -260,6 +266,7 @@ def add_members_to_group_backend(
return json_success(request) return json_success(request)
@transaction.atomic
def remove_members_from_group_backend( def remove_members_from_group_backend(
request: HttpRequest, user_profile: UserProfile, user_group_id: int, members: Sequence[int] request: HttpRequest, user_profile: UserProfile, user_group_id: int, members: Sequence[int]
) -> HttpResponse: ) -> HttpResponse:
@ -292,10 +299,13 @@ def add_subgroups_to_group_backend(
if not subgroup_ids: if not subgroup_ids:
return json_success(request) return json_success(request)
subgroups = access_user_groups_as_potential_subgroups(subgroup_ids, user_profile) with lock_subgroups_with_respect_to_supergroup(
user_group = access_user_group_by_id(user_group_id, user_profile, for_read=False) subgroup_ids, user_group_id, user_profile
existing_direct_subgroup_ids = user_group.direct_subgroups.all().values_list("id", flat=True) ) as context:
for group in subgroups: existing_direct_subgroup_ids = context.supergroup.direct_subgroups.all().values_list(
"id", flat=True
)
for group in context.direct_subgroups:
if group.id in existing_direct_subgroup_ids: if group.id in existing_direct_subgroup_ids:
raise JsonableError( raise JsonableError(
_("User group {group_id} is already a subgroup of this group.").format( _("User group {group_id} is already a subgroup of this group.").format(
@ -303,17 +313,19 @@ def add_subgroups_to_group_backend(
) )
) )
subgroup_ids = [group.id for group in subgroups] recursive_subgroup_ids = {
if user_group_id in get_recursive_subgroups_for_groups(subgroup_ids).values_list( recursive_subgroup.id for recursive_subgroup in context.recursive_subgroups
"id", flat=True }
): if user_group_id in recursive_subgroup_ids:
raise JsonableError( raise JsonableError(
_( _(
"User group {user_group_id} is already a subgroup of one of the passed subgroups." "User group {user_group_id} is already a subgroup of one of the passed subgroups."
).format(user_group_id=user_group_id) ).format(user_group_id=user_group_id)
) )
add_subgroups_to_user_group(user_group, subgroups, acting_user=user_profile) add_subgroups_to_user_group(
context.supergroup, context.direct_subgroups, acting_user=user_profile
)
return json_success(request) return json_success(request)
@ -323,10 +335,16 @@ def remove_subgroups_from_group_backend(
if not subgroup_ids: if not subgroup_ids:
return json_success(request) return json_success(request)
subgroups = access_user_groups_as_potential_subgroups(subgroup_ids, user_profile) with lock_subgroups_with_respect_to_supergroup(
user_group = access_user_group_by_id(user_group_id, user_profile, for_read=False) subgroup_ids, user_group_id, user_profile
existing_direct_subgroup_ids = user_group.direct_subgroups.all().values_list("id", flat=True) ) as context:
for group in subgroups: # While the recursive subgroups in the context are not used, it is important that
# we acquire a lock for these rows while updating the subgroups to acquire the locks
# in a consistent order for subgroup membership changes.
existing_direct_subgroup_ids = context.supergroup.direct_subgroups.all().values_list(
"id", flat=True
)
for group in context.direct_subgroups:
if group.id not in existing_direct_subgroup_ids: if group.id not in existing_direct_subgroup_ids:
raise JsonableError( raise JsonableError(
_("User group {group_id} is not a subgroup of this group.").format( _("User group {group_id} is not a subgroup of this group.").format(
@ -334,7 +352,10 @@ def remove_subgroups_from_group_backend(
) )
) )
remove_subgroups_from_user_group(user_group, subgroups, acting_user=user_profile) remove_subgroups_from_user_group(
context.supergroup, context.direct_subgroups, acting_user=user_profile
)
return json_success(request) return json_success(request)

View File

@ -10,6 +10,7 @@ from django.urls import path
from django.views.generic import TemplateView from django.views.generic import TemplateView
from django.views.static import serve from django.views.static import serve
from zerver.lib.rest import rest_path
from zerver.views.auth import config_error, login_page from zerver.views.auth import config_error, login_page
from zerver.views.development.cache import remove_caches from zerver.views.development.cache import remove_caches
from zerver.views.development.camo import handle_camo_url from zerver.views.development.camo import handle_camo_url
@ -31,6 +32,7 @@ from zerver.views.development.registration import (
register_development_realm, register_development_realm,
register_development_user, register_development_user,
) )
from zerver.views.development.user_groups import dev_update_subgroups
# These URLs are available only in the development environment # These URLs are available only in the development environment
@ -98,6 +100,14 @@ urls = [
path("external_content/<digest>/<received_url>", handle_camo_url), path("external_content/<digest>/<received_url>", handle_camo_url),
] ]
testing_urls = [
rest_path(
"testing/user_groups/<int:user_group_id>/subgroups",
POST=(dev_update_subgroups, {"intentionally_undocumented"}),
),
]
urls += testing_urls
v1_api_mobile_patterns = [ v1_api_mobile_patterns = [
# This is for the signing in through the devAuthBackEnd on mobile apps. # This is for the signing in through the devAuthBackEnd on mobile apps.
path("dev_fetch_api_key", api_dev_fetch_api_key), path("dev_fetch_api_key", api_dev_fetch_api_key),