zulip/zerver/transaction_tests/test_user_groups.py

159 lines
6.0 KiB
Python

import threading
from typing import TYPE_CHECKING, List, Optional
import orjson
from django.db import connections, transaction
from zerver.actions.user_groups import add_subgroups_to_user_group, check_add_user_group
from zerver.lib.test_classes import ZulipTransactionTestCase
from zerver.models import Realm, UserGroup, get_realm
from zerver.views.development import user_groups as user_group_view
if TYPE_CHECKING:
from django.test.client import _MonkeyPatchedWSGIResponse as TestHttpResponse
class UserGroupRaceConditionTestCase(ZulipTransactionTestCase):
created_user_groups: List[UserGroup] = []
counter = 0
CHAIN_LENGTH = 3
def tearDown(self) -> None:
# Clean up the user groups created to minimize leakage
with transaction.atomic():
for group in self.created_user_groups:
group.delete()
transaction.on_commit(lambda: self.created_user_groups.clear())
super().tearDown()
def create_user_group_chain(self, realm: Realm) -> List[UserGroup]:
"""Build a user groups forming a chain through group-group memberships
returning a list where each group is the supergroup of its subsequent group.
"""
groups = [
check_add_user_group(realm, f"chain #{self.counter + i}", [], acting_user=None)
for i in range(self.CHAIN_LENGTH)
]
self.counter += self.CHAIN_LENGTH
self.created_user_groups.extend(groups)
prev_group = groups[0]
for group in groups[1:]:
add_subgroups_to_user_group(prev_group, [group], acting_user=None)
prev_group = group
return groups
def test_lock_subgroups_with_respect_to_supergroup(self) -> None:
realm = get_realm("zulip")
self.login("iago")
test_case = self
class RacingThread(threading.Thread):
def __init__(
self,
subgroup_ids: List[int],
supergroup_id: int,
) -> None:
threading.Thread.__init__(self)
self.response: Optional["TestHttpResponse"] = None
self.subgroup_ids = subgroup_ids
self.supergroup_id = supergroup_id
def run(self) -> None:
try:
self.response = test_case.client_post(
url=f"/testing/user_groups/{self.supergroup_id}/subgroups",
info={"add": orjson.dumps(self.subgroup_ids).decode()},
)
finally:
# Close all thread-local database connections
connections.close_all()
def assert_thread_success_count(
t1: RacingThread,
t2: RacingThread,
*,
success_count: int,
error_messsage: str = "",
) -> None:
help_msg = """We access the test endpoint that wraps around the
real subgroup update endpoint by synchronizing them after the acquisition of the
first lock in the critical region. Though unlikely, this test might fail as we
have no control over the scheduler when the barrier timeouts.
""".strip()
barrier = threading.Barrier(parties=2, timeout=3)
user_group_view.set_sync_after_recursive_query(barrier)
t1.start()
t2.start()
succeeded = 0
for t in [t1, t2]:
t.join()
response = t.response
if response is not None and response.status_code == 200:
succeeded += 1
continue
assert response is not None
self.assert_json_error(response, error_messsage)
# Race condition resolution should only allow one thread to succeed
self.assertEqual(
succeeded,
success_count,
f"Exactly {success_count} thread(s) should succeed.\n{help_msg}",
)
foo_chain = self.create_user_group_chain(realm)
bar_chain = self.create_user_group_chain(realm)
# These two threads are conflicting because a cycle would be formed if
# both of them succeed. There is a deadlock in such circular dependency.
assert_thread_success_count(
RacingThread(
subgroup_ids=[foo_chain[0].id],
supergroup_id=bar_chain[-1].id,
),
RacingThread(
subgroup_ids=[bar_chain[-1].id],
supergroup_id=foo_chain[0].id,
),
success_count=1,
error_messsage="Deadlock detected",
)
foo_chain = self.create_user_group_chain(realm)
bar_chain = self.create_user_group_chain(realm)
# These two requests would succeed if they didn't race with each other.
# However, both threads will attempt to grab a lock on overlapping rows
# when they first do the recursive query for subgroups. In this case, we
# expect that one of the threads fails due to nowait=True for the
# .select_for_update() call.
assert_thread_success_count(
RacingThread(
subgroup_ids=[foo_chain[0].id],
supergroup_id=bar_chain[-1].id,
),
RacingThread(
subgroup_ids=[foo_chain[1].id],
supergroup_id=bar_chain[-1].id,
),
success_count=1,
error_messsage="Busy lock detected",
)
foo_chain = self.create_user_group_chain(realm)
bar_chain = self.create_user_group_chain(realm)
baz_chain = self.create_user_group_chain(realm)
# Adding non-conflicting subgroups should succeed.
assert_thread_success_count(
RacingThread(
subgroup_ids=[foo_chain[1].id, foo_chain[2].id, baz_chain[2].id],
supergroup_id=baz_chain[0].id,
),
RacingThread(
subgroup_ids=[bar_chain[1].id, bar_chain[2].id],
supergroup_id=baz_chain[0].id,
),
success_count=2,
)