user_groups: Add get_recursive_strict_subgroups function.

This commit adds get_recursive_strict_subgroups function
which returns all the subgroups but not includes the user
group passed to the function.

We also update the test to check subgroups of named user
groups using the get_recursive_strict_subgroups function.
This is fine as we already test the get_recursive_subgroups
function.
This commit is contained in:
Sahil Batra 2024-04-24 14:56:50 +05:30 committed by Tim Abbott
parent ba196cfd6b
commit 75b1f32a19
2 changed files with 29 additions and 17 deletions

View File

@ -328,6 +328,18 @@ def get_recursive_subgroups(user_group: UserGroup) -> QuerySet[UserGroup]:
return cte.join(UserGroup, id=cte.col.group_id).with_cte(cte) return cte.join(UserGroup, id=cte.col.group_id).with_cte(cte)
def get_recursive_strict_subgroups(user_group: UserGroup) -> QuerySet[UserGroup]:
# Same as get_recursive_subgroups but does not include the
# user_group passed.
direct_subgroup_ids = user_group.direct_subgroups.all().values("id")
cte = With.recursive(
lambda cte: UserGroup.objects.filter(id__in=direct_subgroup_ids)
.values(group_id=F("id"))
.union(cte.join(UserGroup, direct_supergroups=cte.col.group_id).values(group_id=F("id")))
)
return cte.join(UserGroup, id=cte.col.group_id).with_cte(cte)
def get_recursive_group_members(user_group: UserGroup) -> QuerySet[UserProfile]: def get_recursive_group_members(user_group: UserGroup) -> QuerySet[UserProfile]:
return UserProfile.objects.filter(direct_groups__in=get_recursive_subgroups(user_group)) return UserProfile.objects.filter(direct_groups__in=get_recursive_subgroups(user_group))
@ -365,11 +377,7 @@ def get_subgroup_ids(user_group: UserGroup, *, direct_subgroup_only: bool = Fals
if direct_subgroup_only: if direct_subgroup_only:
subgroup_ids = user_group.direct_subgroups.all().values_list("id", flat=True) subgroup_ids = user_group.direct_subgroups.all().values_list("id", flat=True)
else: else:
subgroup_ids = ( subgroup_ids = get_recursive_strict_subgroups(user_group).values_list("id", flat=True)
get_recursive_subgroups(user_group)
.exclude(id=user_group.id)
.values_list("id", flat=True)
)
return list(subgroup_ids) return list(subgroup_ids)

View File

@ -25,6 +25,7 @@ from zerver.lib.user_groups import (
get_direct_user_groups, get_direct_user_groups,
get_recursive_group_members, get_recursive_group_members,
get_recursive_membership_groups, get_recursive_membership_groups,
get_recursive_strict_subgroups,
get_recursive_subgroups, get_recursive_subgroups,
get_subgroup_ids, get_subgroup_ids,
get_user_group_member_ids, get_user_group_member_ids,
@ -121,6 +122,13 @@ class UserGroupTestCase(ZulipTestCase):
[leadership_group, staff_group, everyone_group], [leadership_group, staff_group, everyone_group],
) )
self.assertCountEqual(list(get_recursive_strict_subgroups(leadership_group)), [])
self.assertCountEqual(list(get_recursive_strict_subgroups(staff_group)), [leadership_group])
self.assertCountEqual(
list(get_recursive_strict_subgroups(everyone_group)),
[leadership_group, staff_group],
)
self.assertCountEqual(list(get_recursive_group_members(leadership_group)), [desdemona]) self.assertCountEqual(list(get_recursive_group_members(leadership_group)), [desdemona])
self.assertCountEqual(list(get_recursive_group_members(staff_group)), [desdemona, iago]) self.assertCountEqual(list(get_recursive_group_members(staff_group)), [desdemona, iago])
self.assertCountEqual( self.assertCountEqual(
@ -162,35 +170,32 @@ class UserGroupTestCase(ZulipTestCase):
is_system_group=True, is_system_group=True,
) )
self.assertCountEqual(list(get_recursive_subgroups(owners_group)), [owners_group]) self.assertCountEqual(list(get_recursive_strict_subgroups(owners_group)), [])
self.assertCountEqual(list(get_recursive_strict_subgroups(admins_group)), [owners_group])
self.assertCountEqual( self.assertCountEqual(
list(get_recursive_subgroups(admins_group)), [owners_group, admins_group] list(get_recursive_strict_subgroups(moderators_group)),
[owners_group, admins_group],
) )
self.assertCountEqual( self.assertCountEqual(
list(get_recursive_subgroups(moderators_group)), list(get_recursive_strict_subgroups(full_members_group)),
[owners_group, admins_group, moderators_group], [owners_group, admins_group, moderators_group],
) )
self.assertCountEqual( self.assertCountEqual(
list(get_recursive_subgroups(full_members_group)), list(get_recursive_strict_subgroups(members_group)),
[owners_group, admins_group, moderators_group, full_members_group], [owners_group, admins_group, moderators_group, full_members_group],
) )
self.assertCountEqual( self.assertCountEqual(
list(get_recursive_subgroups(members_group)), list(get_recursive_strict_subgroups(everyone_group)),
[owners_group, admins_group, moderators_group, full_members_group, members_group],
)
self.assertCountEqual(
list(get_recursive_subgroups(everyone_group)),
[ [
owners_group, owners_group,
admins_group, admins_group,
moderators_group, moderators_group,
full_members_group, full_members_group,
members_group, members_group,
everyone_group,
], ],
) )
self.assertCountEqual( self.assertCountEqual(
list(get_recursive_subgroups(everyone_on_internet_group)), list(get_recursive_strict_subgroups(everyone_on_internet_group)),
[ [
owners_group, owners_group,
admins_group, admins_group,
@ -198,7 +203,6 @@ class UserGroupTestCase(ZulipTestCase):
full_members_group, full_members_group,
members_group, members_group,
everyone_group, everyone_group,
everyone_on_internet_group,
], ],
) )