From 85b7dbddbc6740f801e8483365be07c9b89a874f Mon Sep 17 00:00:00 2001 From: Sahil Batra Date: Sat, 20 Apr 2024 20:33:33 +0530 Subject: [PATCH] groups: Update subgroup to be NamedUserGroup. --- zerver/lib/user_groups.py | 14 +++++---- ...sergroup_foreign_keys_to_namedusergroup.py | 18 ++++++++++++ zerver/models/groups.py | 4 +-- zerver/tests/test_user_groups.py | 29 +++++++++++-------- 4 files changed, 46 insertions(+), 19 deletions(-) diff --git a/zerver/lib/user_groups.py b/zerver/lib/user_groups.py index ca2be48be1..f3c0d534ee 100644 --- a/zerver/lib/user_groups.py +++ b/zerver/lib/user_groups.py @@ -325,21 +325,25 @@ def get_recursive_subgroups(user_group: UserGroup) -> QuerySet[UserGroup]: cte = With.recursive( lambda cte: UserGroup.objects.filter(id=user_group.id) .values(group_id=F("id")) - .union(cte.join(UserGroup, direct_supergroups=cte.col.group_id).values(group_id=F("id"))) + .union( + cte.join(NamedUserGroup, direct_supergroups=cte.col.group_id).values(group_id=F("id")) + ) ) return cte.join(UserGroup, id=cte.col.group_id).with_cte(cte) -def get_recursive_strict_subgroups(user_group: UserGroup) -> QuerySet[UserGroup]: +def get_recursive_strict_subgroups(user_group: UserGroup) -> QuerySet[NamedUserGroup]: # Same as get_recursive_subgroups but does not include the # user_group passed. direct_subgroup_ids = user_group.direct_subgroups.all().values("id") cte = With.recursive( - lambda cte: UserGroup.objects.filter(id__in=direct_subgroup_ids) + lambda cte: NamedUserGroup.objects.filter(id__in=direct_subgroup_ids) .values(group_id=F("id")) - .union(cte.join(UserGroup, direct_supergroups=cte.col.group_id).values(group_id=F("id"))) + .union( + cte.join(NamedUserGroup, direct_supergroups=cte.col.group_id).values(group_id=F("id")) + ) ) - return cte.join(UserGroup, id=cte.col.group_id).with_cte(cte) + return cte.join(NamedUserGroup, id=cte.col.group_id).with_cte(cte) def get_recursive_group_members(user_group: UserGroup) -> QuerySet[UserProfile]: diff --git a/zerver/migrations/0514_update_usergroup_foreign_keys_to_namedusergroup.py b/zerver/migrations/0514_update_usergroup_foreign_keys_to_namedusergroup.py index 086999a07a..6014850658 100644 --- a/zerver/migrations/0514_update_usergroup_foreign_keys_to_namedusergroup.py +++ b/zerver/migrations/0514_update_usergroup_foreign_keys_to_namedusergroup.py @@ -24,4 +24,22 @@ class Migration(migrations.Migration): null=True, on_delete=django.db.models.deletion.CASCADE, to="zerver.namedusergroup" ), ), + migrations.AlterField( + model_name="groupgroupmembership", + name="subgroup", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to="zerver.namedusergroup", + ), + ), + migrations.AlterField( + model_name="usergroup", + name="direct_subgroups", + field=models.ManyToManyField( + related_name="direct_supergroups", + through="zerver.GroupGroupMembership", + to="zerver.namedusergroup", + ), + ), ] diff --git a/zerver/models/groups.py b/zerver/models/groups.py index 1ca99eba9a..b18be2b316 100644 --- a/zerver/models/groups.py +++ b/zerver/models/groups.py @@ -27,7 +27,7 @@ class UserGroup(models.Model): # type: ignore[django-manager-missing] # django- UserProfile, through="zerver.UserGroupMembership", related_name="direct_groups" ) direct_subgroups = models.ManyToManyField( - "self", + "zerver.NamedUserGroup", symmetrical=False, through="zerver.GroupGroupMembership", through_fields=("supergroup", "subgroup"), @@ -123,7 +123,7 @@ class UserGroupMembership(models.Model): class GroupGroupMembership(models.Model): supergroup = models.ForeignKey(UserGroup, on_delete=CASCADE, related_name="+") - subgroup = models.ForeignKey(UserGroup, on_delete=CASCADE, related_name="+") + subgroup = models.ForeignKey(NamedUserGroup, on_delete=CASCADE, related_name="+") class Meta: constraints = [ diff --git a/zerver/tests/test_user_groups.py b/zerver/tests/test_user_groups.py index 0c7506aa06..4348731e5a 100644 --- a/zerver/tests/test_user_groups.py +++ b/zerver/tests/test_user_groups.py @@ -33,7 +33,14 @@ from zerver.lib.user_groups import ( is_user_in_group, user_groups_in_realm_serialized, ) -from zerver.models import GroupGroupMembership, Realm, UserGroup, UserGroupMembership, UserProfile +from zerver.models import ( + GroupGroupMembership, + NamedUserGroup, + Realm, + UserGroup, + UserGroupMembership, + UserProfile, +) from zerver.models.groups import SystemGroups from zerver.models.realms import get_realm @@ -130,12 +137,10 @@ class UserGroupTestCase(ZulipTestCase): ) self.assertCountEqual(list(get_recursive_strict_subgroups(leadership_group)), []) - self.assertCountEqual( - list(get_recursive_strict_subgroups(staff_group)), [leadership_group.usergroup_ptr] - ) + self.assertCountEqual(list(get_recursive_strict_subgroups(staff_group)), [leadership_group]) self.assertCountEqual( list(get_recursive_strict_subgroups(everyone_group)), - [leadership_group.usergroup_ptr, staff_group.usergroup_ptr], + [leadership_group, staff_group], ) self.assertCountEqual(list(get_recursive_group_members(leadership_group)), [desdemona]) @@ -155,25 +160,25 @@ class UserGroupTestCase(ZulipTestCase): def test_subgroups_of_role_based_system_groups(self) -> None: realm = get_realm("zulip") - owners_group = UserGroup.objects.get( + owners_group = NamedUserGroup.objects.get( realm=realm, name=SystemGroups.OWNERS, is_system_group=True ) - admins_group = UserGroup.objects.get( + admins_group = NamedUserGroup.objects.get( realm=realm, name=SystemGroups.ADMINISTRATORS, is_system_group=True ) - moderators_group = UserGroup.objects.get( + moderators_group = NamedUserGroup.objects.get( realm=realm, name=SystemGroups.MODERATORS, is_system_group=True ) - full_members_group = UserGroup.objects.get( + full_members_group = NamedUserGroup.objects.get( realm=realm, name=SystemGroups.FULL_MEMBERS, is_system_group=True ) - members_group = UserGroup.objects.get( + members_group = NamedUserGroup.objects.get( realm=realm, name=SystemGroups.MEMBERS, is_system_group=True ) - everyone_group = UserGroup.objects.get( + everyone_group = NamedUserGroup.objects.get( realm=realm, name=SystemGroups.EVERYONE, is_system_group=True ) - everyone_on_internet_group = UserGroup.objects.get( + everyone_on_internet_group = NamedUserGroup.objects.get( realm=realm, name=SystemGroups.EVERYONE_ON_INTERNET, is_system_group=True,