user_groups: Add a recursive group membership model.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2021-09-28 17:46:57 -07:00 committed by Tim Abbott
parent 6ac9386a29
commit 1e5157b66c
6 changed files with 186 additions and 1 deletions

View File

@ -38,6 +38,7 @@ from zerver.models import (
CustomProfileField, CustomProfileField,
CustomProfileFieldValue, CustomProfileFieldValue,
DefaultStream, DefaultStream,
GroupGroupMembership,
Huddle, Huddle,
Message, Message,
Reaction, Reaction,
@ -127,6 +128,7 @@ ALL_ZULIP_TABLES = {
"zerver_defaultstreamgroup_streams", "zerver_defaultstreamgroup_streams",
"zerver_draft", "zerver_draft",
"zerver_emailchangestatus", "zerver_emailchangestatus",
"zerver_groupgroupmembership",
"zerver_huddle", "zerver_huddle",
"zerver_message", "zerver_message",
"zerver_missedmessageemailaddress", "zerver_missedmessageemailaddress",
@ -705,6 +707,13 @@ def get_realm_config() -> Config:
parent_key="user_group__in", parent_key="user_group__in",
) )
Config(
table="zerver_groupgroupmembership",
model=GroupGroupMembership,
normal_parent=user_groups_config,
parent_key="supergroup__in",
)
Config( Config(
custom_tables=[ custom_tables=[
"zerver_userprofile_crossrealm", "zerver_userprofile_crossrealm",

View File

@ -42,6 +42,7 @@ from zerver.models import (
CustomProfileField, CustomProfileField,
CustomProfileFieldValue, CustomProfileFieldValue,
DefaultStream, DefaultStream,
GroupGroupMembership,
Huddle, Huddle,
Message, Message,
MutedUser, MutedUser,
@ -121,6 +122,7 @@ ID_MAP: Dict[str, Dict[int, int]] = {
"service": {}, "service": {},
"usergroup": {}, "usergroup": {},
"usergroupmembership": {}, "usergroupmembership": {},
"groupgroupmembership": {},
"botstoragedata": {}, "botstoragedata": {},
"botconfigdata": {}, "botconfigdata": {},
"analytics_realmcount": {}, "analytics_realmcount": {},
@ -1121,6 +1123,9 @@ def do_import_realm(import_dir: Path, subdomain: str, processes: int = 1) -> Rea
re_map_foreign_keys_many_to_many( re_map_foreign_keys_many_to_many(
data, "zerver_usergroup", "direct_members", related_table="user_profile" data, "zerver_usergroup", "direct_members", related_table="user_profile"
) )
re_map_foreign_keys_many_to_many(
data, "zerver_usergroup", "direct_subgroups", related_table="usergroup"
)
update_model_ids(UserGroup, data, "usergroup") update_model_ids(UserGroup, data, "usergroup")
bulk_import_model(data, UserGroup) bulk_import_model(data, UserGroup)
@ -1133,6 +1138,15 @@ def do_import_realm(import_dir: Path, subdomain: str, processes: int = 1) -> Rea
update_model_ids(UserGroupMembership, data, "usergroupmembership") update_model_ids(UserGroupMembership, data, "usergroupmembership")
bulk_import_model(data, UserGroupMembership) bulk_import_model(data, UserGroupMembership)
re_map_foreign_keys(
data, "zerver_groupgroupmembership", "supergroup", related_table="usergroup"
)
re_map_foreign_keys(
data, "zerver_groupgroupmembership", "subgroup", related_table="usergroup"
)
update_model_ids(GroupGroupMembership, data, "groupgroupmembership")
bulk_import_model(data, GroupGroupMembership)
if "zerver_botstoragedata" in data: if "zerver_botstoragedata" in data:
re_map_foreign_keys( re_map_foreign_keys(
data, "zerver_botstoragedata", "bot_profile", related_table="user_profile" data, "zerver_botstoragedata", "bot_profile", related_table="user_profile"

View File

@ -1,7 +1,9 @@
from typing import Any, Dict, List from typing import Any, Dict, List
from django.db import transaction from django.db import transaction
from django.db.models import QuerySet
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from django_cte import With
from zerver.lib.exceptions import JsonableError from zerver.lib.exceptions import JsonableError
from zerver.models import Realm, UserGroup, UserGroupMembership, UserProfile from zerver.models import Realm, UserGroup, UserGroupMembership, UserProfile
@ -95,3 +97,34 @@ def get_direct_memberships_of_users(user_group: UserGroup, members: List[UserPro
user_group=user_group, user_profile__in=members user_group=user_group, user_profile__in=members
).values_list("user_profile_id", flat=True) ).values_list("user_profile_id", flat=True)
) )
# These recursive lookups use standard PostgreSQL common table
# expression (CTE) queries. These queries use the django-cte library,
# because upstream Django does not yet support CTE.
#
# https://www.postgresql.org/docs/current/queries-with.html
# https://pypi.org/project/django-cte/
# https://code.djangoproject.com/ticket/28919
def get_recursive_subgroups(user_group: UserGroup) -> "QuerySet[UserGroup]":
cte = With.recursive(
lambda cte: UserGroup.objects.filter(id=user_group.id)
.values("id")
.union(cte.join(UserGroup, direct_supergroups=cte.col.id).values("id"))
)
return cte.join(UserGroup, id=cte.col.id).with_cte(cte)
def get_recursive_group_members(user_group: UserGroup) -> "QuerySet[UserProfile]":
return UserProfile.objects.filter(direct_groups__in=get_recursive_subgroups(user_group))
def get_recursive_membership_groups(user_profile: UserProfile) -> "QuerySet[UserGroup]":
cte = With.recursive(
lambda cte: user_profile.direct_groups.values("id").union(
cte.join(UserGroup, direct_subgroups=cte.col.id).values("id")
)
)
return cte.join(UserGroup, id=cte.col.id).with_cte(cte)

View File

@ -0,0 +1,56 @@
# Generated by Django 3.2.7 on 2021-09-29 23:34
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("zerver", "0365_alter_user_group_related_fields"),
]
operations = [
migrations.CreateModel(
name="GroupGroupMembership",
fields=[
(
"id",
models.AutoField(
auto_created=True, primary_key=True, serialize=False, verbose_name="ID"
),
),
(
"subgroup",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="+",
to="zerver.usergroup",
),
),
(
"supergroup",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="+",
to="zerver.usergroup",
),
),
],
),
migrations.AddField(
model_name="usergroup",
name="direct_subgroups",
field=models.ManyToManyField(
related_name="direct_supergroups",
through="zerver.GroupGroupMembership",
to="zerver.UserGroup",
),
),
migrations.AddConstraint(
model_name="groupgroupmembership",
constraint=models.UniqueConstraint(
fields=("supergroup", "subgroup"), name="zerver_groupgroupmembership_uniq"
),
),
]

View File

@ -36,6 +36,7 @@ from django.utils.functional import Promise
from django.utils.timezone import now as timezone_now from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from django.utils.translation import gettext_lazy from django.utils.translation import gettext_lazy
from django_cte import CTEManager
from confirmation import settings as confirmation_settings from confirmation import settings as confirmation_settings
from zerver.lib import cache from zerver.lib import cache
@ -1976,11 +1977,19 @@ class PasswordTooWeakError(Exception):
class UserGroup(models.Model): class UserGroup(models.Model):
objects = CTEManager()
id: int = models.AutoField(auto_created=True, primary_key=True, verbose_name="ID") id: int = models.AutoField(auto_created=True, primary_key=True, verbose_name="ID")
name: str = models.CharField(max_length=100) name: str = models.CharField(max_length=100)
direct_members: Manager = models.ManyToManyField( direct_members: Manager = models.ManyToManyField(
UserProfile, through="UserGroupMembership", related_name="direct_groups" UserProfile, through="UserGroupMembership", related_name="direct_groups"
) )
direct_subgroups: Manager = models.ManyToManyField(
"self",
symmetrical=False,
through="GroupGroupMembership",
through_fields=("supergroup", "subgroup"),
related_name="direct_supergroups",
)
realm: Realm = models.ForeignKey(Realm, on_delete=CASCADE) realm: Realm = models.ForeignKey(Realm, on_delete=CASCADE)
description: str = models.TextField(default="") description: str = models.TextField(default="")
is_system_group: bool = models.BooleanField(default=False) is_system_group: bool = models.BooleanField(default=False)
@ -1998,6 +2007,19 @@ class UserGroupMembership(models.Model):
unique_together = (("user_group", "user_profile"),) unique_together = (("user_group", "user_profile"),)
class GroupGroupMembership(models.Model):
id: int = models.AutoField(auto_created=True, primary_key=True, verbose_name="ID")
supergroup: UserGroup = models.ForeignKey(UserGroup, on_delete=CASCADE, related_name="+")
subgroup: UserGroup = models.ForeignKey(UserGroup, on_delete=CASCADE, related_name="+")
class Meta:
constraints = [
models.UniqueConstraint(
fields=["supergroup", "subgroup"], name="zerver_groupgroupmembership_uniq"
)
]
def remote_user_to_email(remote_user: str) -> str: def remote_user_to_email(remote_user: str) -> str:
if settings.SSO_APPEND_DOMAIN is not None: if settings.SSO_APPEND_DOMAIN is not None:
remote_user += "@" + settings.SSO_APPEND_DOMAIN remote_user += "@" + settings.SSO_APPEND_DOMAIN

View File

@ -11,9 +11,19 @@ from zerver.lib.user_groups import (
create_user_group, create_user_group,
get_direct_memberships_of_users, get_direct_memberships_of_users,
get_direct_user_groups, get_direct_user_groups,
get_recursive_group_members,
get_recursive_membership_groups,
get_recursive_subgroups,
user_groups_in_realm_serialized, user_groups_in_realm_serialized,
) )
from zerver.models import Realm, UserGroup, UserGroupMembership, UserProfile, get_realm from zerver.models import (
GroupGroupMembership,
Realm,
UserGroup,
UserGroupMembership,
UserProfile,
get_realm,
)
class UserGroupTestCase(ZulipTestCase): class UserGroupTestCase(ZulipTestCase):
@ -50,6 +60,47 @@ class UserGroupTestCase(ZulipTestCase):
self.assert_length(user_groups, 1) self.assert_length(user_groups, 1)
self.assertEqual(user_groups[0].name, "support") self.assertEqual(user_groups[0].name, "support")
def test_recursive_queries_for_user_groups(self) -> None:
realm = get_realm("zulip")
iago = self.example_user("iago")
desdemona = self.example_user("desdemona")
shiva = self.example_user("shiva")
leadership_group = UserGroup.objects.create(realm=realm, name="Leadership")
UserGroupMembership.objects.create(user_profile=desdemona, user_group=leadership_group)
staff_group = UserGroup.objects.create(realm=realm, name="Staff")
UserGroupMembership.objects.create(user_profile=iago, user_group=staff_group)
GroupGroupMembership.objects.create(supergroup=staff_group, subgroup=leadership_group)
everyone_group = UserGroup.objects.create(realm=realm, name="Everyone")
UserGroupMembership.objects.create(user_profile=shiva, user_group=everyone_group)
GroupGroupMembership.objects.create(supergroup=everyone_group, subgroup=staff_group)
self.assertCountEqual(list(get_recursive_subgroups(leadership_group)), [leadership_group])
self.assertCountEqual(
list(get_recursive_subgroups(staff_group)), [leadership_group, staff_group]
)
self.assertCountEqual(
list(get_recursive_subgroups(everyone_group)),
[leadership_group, staff_group, everyone_group],
)
self.assertCountEqual(list(get_recursive_group_members(leadership_group)), [desdemona])
self.assertCountEqual(list(get_recursive_group_members(staff_group)), [desdemona, iago])
self.assertCountEqual(
list(get_recursive_group_members(everyone_group)), [desdemona, iago, shiva]
)
self.assertCountEqual(
list(get_recursive_membership_groups(desdemona)),
[leadership_group, staff_group, everyone_group],
)
self.assertCountEqual(
list(get_recursive_membership_groups(iago)), [staff_group, everyone_group]
)
self.assertCountEqual(list(get_recursive_membership_groups(shiva)), [everyone_group])
class UserGroupAPITestCase(UserGroupTestCase): class UserGroupAPITestCase(UserGroupTestCase):
def test_user_group_create(self) -> None: def test_user_group_create(self) -> None: