2017-09-25 09:47:15 +02:00
|
|
|
from __future__ import absolute_import
|
|
|
|
|
2017-11-07 07:56:26 +01:00
|
|
|
from collections import defaultdict
|
2017-09-25 09:47:15 +02:00
|
|
|
from django.db import transaction
|
2017-11-01 10:04:16 +01:00
|
|
|
from django.utils.translation import ugettext as _
|
|
|
|
from zerver.lib.exceptions import JsonableError
|
2017-09-25 09:47:15 +02:00
|
|
|
from zerver.models import UserProfile, Realm, UserGroupMembership, UserGroup
|
2017-11-07 07:56:26 +01:00
|
|
|
from typing import Dict, Iterable, List, Text, Tuple, Any
|
2017-09-25 09:47:15 +02:00
|
|
|
|
2017-11-01 10:04:16 +01:00
|
|
|
def access_user_group_by_id(user_group_id: int, realm: Realm) -> UserGroup:
    """Look up the UserGroup with ``user_group_id`` inside ``realm``.

    The lookup is filtered by realm, so a group that exists in a different
    realm is reported exactly like a missing one.

    Raises:
        JsonableError: with the message "Invalid user group" when no
            matching group exists.
    """
    try:
        return UserGroup.objects.get(id=user_group_id, realm=realm)
    except UserGroup.DoesNotExist:
        raise JsonableError(_("Invalid user group"))
|
|
|
|
|
2017-11-05 11:15:10 +01:00
|
|
|
def user_groups_in_realm(realm: Realm) -> List[UserGroup]:
    """Return every UserGroup whose realm is ``realm``, materialized as a list."""
    return list(UserGroup.objects.filter(realm=realm))
|
|
|
|
|
2017-11-05 11:15:10 +01:00
|
|
|
def user_groups_in_realm_serialized(realm: Realm) -> List[Dict[Text, Any]]:
    """Return every user group in ``realm`` as a plain dict, ordered by id.

    Each dict has the keys ``id``, ``name``, ``description`` and
    ``members``; ``members`` is the ascending list of user_profile ids
    that have a UserGroupMembership row for that group.

    This function is used in do_events_register code path so this code
    should be performant. We need to do 2 database queries because
    Django's ORM doesn't properly support the left join between
    UserGroup and UserGroupMembership that we need.
    """
    realm_groups = UserGroup.objects.filter(realm=realm)
    # Keyed by the integer UserGroup id so membership rows can be attached
    # to their group in a single pass over the second query.
    group_dicts = {}  # type: Dict[int, Dict[Text, Any]]
    for user_group in realm_groups:
        group_dicts[user_group.id] = dict(
            id=user_group.id,
            name=user_group.name,
            description=user_group.description,
            members=[],
        )

    membership = UserGroupMembership.objects.filter(user_group__realm=realm).values_list(
        'user_group_id', 'user_profile_id')
    for (user_group_id, user_profile_id) in membership:
        group_dict = group_dicts.get(user_group_id)
        if group_dict is None:
            # The two queries are separate (no transaction here), so a group
            # created after the first query can appear in the membership
            # rows.  Skip it instead of failing the whole request with a
            # KeyError.
            continue
        group_dict['members'].append(user_profile_id)

    for group_dict in group_dicts.values():
        group_dict['members'].sort()

    return sorted(group_dicts.values(), key=lambda group_dict: group_dict['id'])
|
2017-11-07 07:56:26 +01:00
|
|
|
|
2017-11-05 11:15:10 +01:00
|
|
|
def get_user_groups(user_profile: UserProfile) -> List[UserGroup]:
    """Return the UserGroups reachable through ``user_profile.usergroup_set``.

    Presumably these are the groups the user is a member of; confirm this
    against the UserGroup model's relation definition.
    """
    related_groups = user_profile.usergroup_set.all()
    return list(related_groups)
|
|
|
|
|
2017-11-05 11:15:10 +01:00
|
|
|
def check_add_user_to_user_group(user_profile: UserProfile, user_group: UserGroup) -> bool:
    """Make sure ``user_profile`` has a membership row in ``user_group``.

    Returns:
        True if a new UserGroupMembership row was created, False if one
        already existed.
    """
    # get_or_create yields (object, created); only the flag is reported.
    unused_membership, was_created = UserGroupMembership.objects.get_or_create(
        user_group=user_group, user_profile=user_profile)
    return was_created
|
|
|
|
|
2017-11-05 11:15:10 +01:00
|
|
|
def remove_user_from_user_group(user_profile: UserProfile, user_group: UserGroup) -> int:
    """Delete the membership row(s) linking ``user_profile`` to ``user_group``.

    Returns:
        The total number of objects deleted, as reported by Django's
        ``QuerySet.delete()`` (0 when the user was not a member).
    """
    # QuerySet.delete() returns (total_deleted, per_model_counts).  Index it
    # instead of unpacking into ``_``: in this module ``_`` is the ugettext
    # alias, and shadowing it locally invites translation bugs later.
    num_deleted = UserGroupMembership.objects.filter(
        user_profile=user_profile, user_group=user_group).delete()[0]
    return num_deleted
|
|
|
|
|
2017-11-05 11:15:10 +01:00
|
|
|
def check_remove_user_from_user_group(user_profile: UserProfile, user_group: UserGroup) -> bool:
    """Remove ``user_profile`` from ``user_group`` on a best-effort basis.

    Returns:
        True if at least one membership row was deleted; False if the user
        was not a member or if the removal raised an exception.  Exceptions
        are logged (with traceback) rather than silently discarded, but are
        still not propagated to the caller.
    """
    import logging

    try:
        num_deleted = remove_user_from_user_group(user_profile, user_group)
    except Exception:
        # Callers only get a bool back, so keep this best-effort, but record
        # the failure instead of hiding it completely.
        logging.getLogger(__name__).exception(
            "Failed to remove user from user group")
        return False
    return bool(num_deleted)
|
|
|
|
|
2017-11-05 11:15:10 +01:00
|
|
|
def create_user_group(name: Text, members: List[UserProfile], realm: Realm,
                      description: Text='') -> UserGroup:
    """Create a UserGroup in ``realm`` and enroll every user in ``members``.

    The group row and its membership rows are written inside a single
    ``transaction.atomic()`` block, so they are committed together.

    Returns:
        The newly created UserGroup.
    """
    with transaction.atomic():
        user_group = UserGroup.objects.create(
            name=name, realm=realm, description=description)
        # One bulk INSERT instead of one query per member.
        new_memberships = [
            UserGroupMembership(user_profile=member, user_group=user_group)
            for member in members
        ]
        UserGroupMembership.objects.bulk_create(new_memberships)
    return user_group
|
2017-11-02 08:53:30 +01:00
|
|
|
|
2017-11-27 05:27:04 +01:00
|
|
|
def get_memberships_of_users(user_group: UserGroup, members: List[UserProfile]) -> List[int]:
    """Return the user_profile ids of those ``members`` who belong to ``user_group``.

    Despite the name, this yields plain integer user_profile ids, not
    UserGroupMembership objects.
    """
    matching_rows = UserGroupMembership.objects.filter(
        user_group=user_group,
        user_profile__in=members,
    )
    return list(matching_rows.values_list('user_profile_id', flat=True))
|