diff --git a/zerver/lib/export.py b/zerver/lib/export.py index 8a861ed53f..7e12baf987 100644 --- a/zerver/lib/export.py +++ b/zerver/lib/export.py @@ -38,6 +38,7 @@ from zerver.models import ( CustomProfileField, CustomProfileFieldValue, DefaultStream, + GroupGroupMembership, Huddle, Message, Reaction, @@ -127,6 +128,7 @@ ALL_ZULIP_TABLES = { "zerver_defaultstreamgroup_streams", "zerver_draft", "zerver_emailchangestatus", + "zerver_groupgroupmembership", "zerver_huddle", "zerver_message", "zerver_missedmessageemailaddress", @@ -705,6 +707,13 @@ def get_realm_config() -> Config: parent_key="user_group__in", ) + Config( + table="zerver_groupgroupmembership", + model=GroupGroupMembership, + normal_parent=user_groups_config, + parent_key="supergroup__in", + ) + Config( custom_tables=[ "zerver_userprofile_crossrealm", diff --git a/zerver/lib/import_realm.py b/zerver/lib/import_realm.py index d2c4934581..e7ad6b4c27 100644 --- a/zerver/lib/import_realm.py +++ b/zerver/lib/import_realm.py @@ -42,6 +42,7 @@ from zerver.models import ( CustomProfileField, CustomProfileFieldValue, DefaultStream, + GroupGroupMembership, Huddle, Message, MutedUser, @@ -121,6 +122,7 @@ ID_MAP: Dict[str, Dict[int, int]] = { "service": {}, "usergroup": {}, "usergroupmembership": {}, + "groupgroupmembership": {}, "botstoragedata": {}, "botconfigdata": {}, "analytics_realmcount": {}, @@ -1121,6 +1123,9 @@ def do_import_realm(import_dir: Path, subdomain: str, processes: int = 1) -> Rea re_map_foreign_keys_many_to_many( data, "zerver_usergroup", "direct_members", related_table="user_profile" ) + re_map_foreign_keys_many_to_many( + data, "zerver_usergroup", "direct_subgroups", related_table="usergroup" + ) update_model_ids(UserGroup, data, "usergroup") bulk_import_model(data, UserGroup) @@ -1133,6 +1138,15 @@ def do_import_realm(import_dir: Path, subdomain: str, processes: int = 1) -> Rea update_model_ids(UserGroupMembership, data, "usergroupmembership") bulk_import_model(data, UserGroupMembership) + re_map_foreign_keys( + data, "zerver_groupgroupmembership", "supergroup", related_table="usergroup" + ) + re_map_foreign_keys( + data, "zerver_groupgroupmembership", "subgroup", related_table="usergroup" + ) + update_model_ids(GroupGroupMembership, data, "groupgroupmembership") + bulk_import_model(data, GroupGroupMembership) + if "zerver_botstoragedata" in data: re_map_foreign_keys( data, "zerver_botstoragedata", "bot_profile", related_table="user_profile" diff --git a/zerver/lib/user_groups.py b/zerver/lib/user_groups.py index 26dec2b773..bf1f9e78e4 100644 --- a/zerver/lib/user_groups.py +++ b/zerver/lib/user_groups.py @@ -1,7 +1,9 @@ from typing import Any, Dict, List from django.db import transaction +from django.db.models import QuerySet from django.utils.translation import gettext as _ +from django_cte import With from zerver.lib.exceptions import JsonableError from zerver.models import Realm, UserGroup, UserGroupMembership, UserProfile @@ -95,3 +97,34 @@ def get_direct_memberships_of_users(user_group: UserGroup, members: List[UserPro user_group=user_group, user_profile__in=members ).values_list("user_profile_id", flat=True) ) + + +# These recursive lookups use standard PostgreSQL common table +# expression (CTE) queries. These queries use the django-cte library, +# because upstream Django does not yet support CTE. +# +# https://www.postgresql.org/docs/current/queries-with.html +# https://pypi.org/project/django-cte/ +# https://code.djangoproject.com/ticket/28919 + + +def get_recursive_subgroups(user_group: UserGroup) -> "QuerySet[UserGroup]": + cte = With.recursive( + lambda cte: UserGroup.objects.filter(id=user_group.id) + .values("id") + .union(cte.join(UserGroup, direct_supergroups=cte.col.id).values("id")) + ) + return cte.join(UserGroup, id=cte.col.id).with_cte(cte) + + +def get_recursive_group_members(user_group: UserGroup) -> "QuerySet[UserProfile]": + return UserProfile.objects.filter(direct_groups__in=get_recursive_subgroups(user_group)) + + +def get_recursive_membership_groups(user_profile: UserProfile) -> "QuerySet[UserGroup]": + cte = With.recursive( + lambda cte: user_profile.direct_groups.values("id").union( + cte.join(UserGroup, direct_subgroups=cte.col.id).values("id") + ) + ) + return cte.join(UserGroup, id=cte.col.id).with_cte(cte) diff --git a/zerver/migrations/0366_group_group_membership.py b/zerver/migrations/0366_group_group_membership.py new file mode 100644 index 0000000000..5f12ec3616 --- /dev/null +++ b/zerver/migrations/0366_group_group_membership.py @@ -0,0 +1,56 @@ +# Generated by Django 3.2.7 on 2021-09-29 23:34 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("zerver", "0365_alter_user_group_related_fields"), + ] + + operations = [ + migrations.CreateModel( + name="GroupGroupMembership", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name="ID" + ), + ), + ( + "subgroup", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to="zerver.usergroup", + ), + ), + ( + "supergroup", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to="zerver.usergroup", + ), + ), + ], + ), + migrations.AddField( + model_name="usergroup", + name="direct_subgroups", + field=models.ManyToManyField( + related_name="direct_supergroups", + through="zerver.GroupGroupMembership", + to="zerver.UserGroup", + ), + ), + migrations.AddConstraint( + model_name="groupgroupmembership", + constraint=models.UniqueConstraint( + fields=("supergroup", "subgroup"), name="zerver_groupgroupmembership_uniq" + ), + ), + ] diff --git a/zerver/models.py b/zerver/models.py index a4fef62328..89824ffafe 100644 --- a/zerver/models.py +++ b/zerver/models.py @@ -36,6 +36,7 @@ from django.utils.functional import Promise from django.utils.timezone import now as timezone_now from django.utils.translation import gettext as _ from django.utils.translation import gettext_lazy +from django_cte import CTEManager from confirmation import settings as confirmation_settings from zerver.lib import cache @@ -1976,11 +1977,19 @@ class PasswordTooWeakError(Exception): class UserGroup(models.Model): + objects = CTEManager() id: int = models.AutoField(auto_created=True, primary_key=True, verbose_name="ID") name: str = models.CharField(max_length=100) direct_members: Manager = models.ManyToManyField( UserProfile, through="UserGroupMembership", related_name="direct_groups" ) + direct_subgroups: Manager = models.ManyToManyField( + "self", + symmetrical=False, + through="GroupGroupMembership", + through_fields=("supergroup", "subgroup"), + related_name="direct_supergroups", + ) realm: Realm = models.ForeignKey(Realm, on_delete=CASCADE) description: str = models.TextField(default="") is_system_group: bool = models.BooleanField(default=False) @@ -1998,6 +2007,19 @@ class UserGroupMembership(models.Model): unique_together = (("user_group", "user_profile"),) +class GroupGroupMembership(models.Model): + id: int = models.AutoField(auto_created=True, primary_key=True, verbose_name="ID") + supergroup: UserGroup = models.ForeignKey(UserGroup, on_delete=CASCADE, related_name="+") + subgroup: UserGroup = models.ForeignKey(UserGroup, on_delete=CASCADE, related_name="+") + + class Meta: + constraints = [ + models.UniqueConstraint( + fields=["supergroup", "subgroup"], name="zerver_groupgroupmembership_uniq" + ) + ] + + def remote_user_to_email(remote_user: str) -> str: if settings.SSO_APPEND_DOMAIN is not None: remote_user += "@" + settings.SSO_APPEND_DOMAIN diff --git a/zerver/tests/test_user_groups.py b/zerver/tests/test_user_groups.py index 82bf0b5f68..0e9d208b4e 100644 --- a/zerver/tests/test_user_groups.py +++ b/zerver/tests/test_user_groups.py @@ -11,9 +11,19 @@ from zerver.lib.user_groups import ( create_user_group, get_direct_memberships_of_users, get_direct_user_groups, + get_recursive_group_members, + get_recursive_membership_groups, + get_recursive_subgroups, user_groups_in_realm_serialized, ) -from zerver.models import Realm, UserGroup, UserGroupMembership, UserProfile, get_realm +from zerver.models import ( + GroupGroupMembership, + Realm, + UserGroup, + UserGroupMembership, + UserProfile, + get_realm, +) class UserGroupTestCase(ZulipTestCase): @@ -50,6 +60,47 @@ class UserGroupTestCase(ZulipTestCase): self.assert_length(user_groups, 1) self.assertEqual(user_groups[0].name, "support") + def test_recursive_queries_for_user_groups(self) -> None: + realm = get_realm("zulip") + iago = self.example_user("iago") + desdemona = self.example_user("desdemona") + shiva = self.example_user("shiva") + + leadership_group = UserGroup.objects.create(realm=realm, name="Leadership") + UserGroupMembership.objects.create(user_profile=desdemona, user_group=leadership_group) + + staff_group = UserGroup.objects.create(realm=realm, name="Staff") + UserGroupMembership.objects.create(user_profile=iago, user_group=staff_group) + GroupGroupMembership.objects.create(supergroup=staff_group, subgroup=leadership_group) + + everyone_group = UserGroup.objects.create(realm=realm, name="Everyone") + UserGroupMembership.objects.create(user_profile=shiva, user_group=everyone_group) + GroupGroupMembership.objects.create(supergroup=everyone_group, subgroup=staff_group) + + self.assertCountEqual(list(get_recursive_subgroups(leadership_group)), [leadership_group]) + self.assertCountEqual( + list(get_recursive_subgroups(staff_group)), [leadership_group, staff_group] + ) + self.assertCountEqual( + list(get_recursive_subgroups(everyone_group)), + [leadership_group, staff_group, everyone_group], + ) + + self.assertCountEqual(list(get_recursive_group_members(leadership_group)), [desdemona]) + self.assertCountEqual(list(get_recursive_group_members(staff_group)), [desdemona, iago]) + self.assertCountEqual( + list(get_recursive_group_members(everyone_group)), [desdemona, iago, shiva] + ) + + self.assertCountEqual( + list(get_recursive_membership_groups(desdemona)), + [leadership_group, staff_group, everyone_group], + ) + self.assertCountEqual( + list(get_recursive_membership_groups(iago)), [staff_group, everyone_group] + ) + self.assertCountEqual(list(get_recursive_membership_groups(shiva)), [everyone_group]) + class UserGroupAPITestCase(UserGroupTestCase): def test_user_group_create(self) -> None: