auth: Rewrite data model for tracking enabled auth backends.

So far, we've used the BitField .authentication_methods on Realm
for tracking which backends are enabled for an organization. This
however made it a pain to add new backends (requiring altering the
column and a migration - particularly troublesome if someone wanted to
create their own custom auth backend for their server).

Instead this will be tracked through the existence of the appropriate
rows in the RealmAuthenticationMethods table.
This commit is contained in:
Mateusz Mandera 2023-04-16 21:53:22 +02:00 committed by Tim Abbott
parent 41f17bf392
commit ffa3aa8487
19 changed files with 184 additions and 70 deletions

View File

@ -22,6 +22,7 @@ from zerver.models import (
PreregistrationRealm,
Realm,
RealmAuditLog,
RealmAuthenticationMethod,
RealmUserDefault,
Stream,
UserProfile,
@ -29,6 +30,7 @@ from zerver.models import (
get_realm,
get_system_bot,
)
from zproject.backends import all_implemented_backend_names
if settings.CORPORATE_ENABLED:
from corporate.lib.support import get_support_url
@ -232,6 +234,14 @@ def do_create_realm(
create_system_user_groups_for_realm(realm)
# We create realms with all authentications methods enabled by default.
RealmAuthenticationMethod.objects.bulk_create(
[
RealmAuthenticationMethod(name=backend_name, realm=realm)
for backend_name in all_implemented_backend_names()
]
)
# Create stream once Realm object has been saved
notifications_stream = ensure_stream(
realm,

View File

@ -23,6 +23,7 @@ from zerver.models import (
Attachment,
Realm,
RealmAuditLog,
RealmAuthenticationMethod,
RealmReactivationStatus,
RealmUserDefault,
ScheduledEmail,
@ -130,9 +131,13 @@ def do_set_realm_authentication_methods(
old_value = realm.authentication_methods_dict()
with transaction.atomic():
for key, value in list(authentication_methods.items()):
index = getattr(realm.authentication_methods, key).number
realm.authentication_methods.set_bit(index, int(value))
realm.save(update_fields=["authentication_methods"])
# This does queries in a loop, but this isn't a performance sensitive
# path and is only run rarely.
if value:
RealmAuthenticationMethod.objects.get_or_create(realm=realm, name=key)
else:
RealmAuthenticationMethod.objects.filter(realm=realm, name=key).delete()
updated_value = realm.authentication_methods_dict()
RealmAuditLog.objects.create(
realm=realm,

View File

@ -47,21 +47,21 @@ def gitter_workspace_to_realm(
NOW = float(timezone_now().timestamp())
zerver_realm: List[ZerverFieldsT] = build_zerver_realm(realm_id, realm_subdomain, NOW, "Gitter")
realm = build_realm(zerver_realm, realm_id, domain_name)
# Users will have GitHub's generated noreply email addresses so their only way to log in
# at first is via GitHub. So we set GitHub to be the only authentication method enabled
# default to avoid user confusion.
assert len(zerver_realm) == 1
authentication_methods = [
(auth_method[0], False)
if auth_method[0] != GitHubAuthBackend.auth_backend_name
else (auth_method[0], True)
for auth_method in zerver_realm[0]["authentication_methods"]
realm["zerver_realmauthenticationmethod"] = [
{
"name": GitHubAuthBackend.auth_backend_name,
"realm": realm_id,
# The id doesn't matter since it gets set by the import later properly, but we need to set
# it to something in the dict.
"id": 1,
}
]
zerver_realm[0]["authentication_methods"] = authentication_methods
realm = build_realm(zerver_realm, realm_id, domain_name)
zerver_userprofile, avatars, user_map = build_userprofile(int(NOW), domain_name, gitter_data)
zerver_stream, zerver_defaultstream, stream_map = build_stream_map(int(NOW), gitter_data)
zerver_recipient, zerver_subscription = build_recipient_and_subscription(

View File

@ -40,6 +40,7 @@ from zerver.models import (
Subscription,
UserProfile,
)
from zproject.backends import all_implemented_backend_names
# stubs
ZerverFieldsT = Dict[str, Any]
@ -83,10 +84,8 @@ def build_zerver_realm(
string_id=realm_subdomain,
description=f"Organization imported from {other_product}!",
)
auth_methods = [[flag[0], flag[1]] for flag in realm.authentication_methods]
realm_dict = model_to_dict(realm, exclude=["authentication_methods"])
realm_dict = model_to_dict(realm)
realm_dict["date_created"] = time
realm_dict["authentication_methods"] = auth_methods
return [realm_dict]
@ -373,6 +372,10 @@ def build_realm(
zerver_realmemoji=[],
zerver_realmfilter=[],
zerver_realmplayground=[],
zerver_realmauthenticationmethod=[
{"realm": realm_id, "name": name, "id": i}
for i, name in enumerate(all_implemented_backend_names(), start=1)
],
)
return realm

View File

@ -47,6 +47,7 @@ from zerver.models import (
Reaction,
Realm,
RealmAuditLog,
RealmAuthenticationMethod,
RealmDomain,
RealmEmoji,
RealmFilter,
@ -141,6 +142,7 @@ ALL_ZULIP_TABLES = {
"zerver_reaction",
"zerver_realm",
"zerver_realmauditlog",
"zerver_realmauthenticationmethod",
"zerver_realmdomain",
"zerver_realmemoji",
"zerver_realmfilter",
@ -297,10 +299,6 @@ DATE_FIELDS: Dict[TableName, List[Field]] = {
"zerver_usertopic": ["last_updated"],
}
BITHANDLER_FIELDS: Dict[TableName, List[Field]] = {
"zerver_realm": ["authentication_methods"],
}
def sanity_check_output(data: TableData) -> None:
# First, we verify that the export tool has a declared
@ -438,12 +436,6 @@ def floatify_datetime_fields(data: TableData, table: TableName) -> None:
item[field] = dt.timestamp()
def listify_bithandler_fields(data: TableData, table: TableName) -> None:
for item in data[table]:
for field in BITHANDLER_FIELDS[table]:
item[field] = list(item[field])
class Config:
"""A Config object configures a single table for exporting (and, maybe
some day importing as well. This configuration defines what
@ -668,8 +660,6 @@ def export_from_config(
for t in exported_tables:
if t in DATE_FIELDS:
floatify_datetime_fields(response, t)
if table in BITHANDLER_FIELDS:
listify_bithandler_fields(response, table)
# Now walk our children. It's extremely important to respect
# the order of children here.
@ -690,6 +680,13 @@ def get_realm_config() -> Config:
is_seeded=True,
)
Config(
table="zerver_realmauthenticationmethod",
model=RealmAuthenticationMethod,
normal_parent=realm_config,
include_rows="realm_id__in",
)
Config(
table="zerver_defaultstream",
model=DefaultStream,

View File

@ -51,6 +51,7 @@ from zerver.models import (
Reaction,
Realm,
RealmAuditLog,
RealmAuthenticationMethod,
RealmDomain,
RealmEmoji,
RealmFilter,
@ -77,6 +78,7 @@ from zerver.models import (
)
realm_tables = [
("zerver_realmauthenticationmethod", RealmAuthenticationMethod, "realmauthenticationmethod"),
("zerver_defaultstream", DefaultStream, "defaultstream"),
("zerver_realmemoji", RealmEmoji, "realmemoji"),
("zerver_realmdomain", RealmDomain, "realmdomain"),
@ -105,6 +107,7 @@ ID_MAP: Dict[str, Dict[int, int]] = {
"subscription": {},
"defaultstream": {},
"reaction": {},
"realmauthenticationmethod": {},
"realmemoji": {},
"realmdomain": {},
"realmfilter": {},
@ -598,19 +601,6 @@ def fix_bitfield_keys(data: TableData, table: TableName, field_name: Field) -> N
del item[field_name + "_mask"]
def fix_realm_authentication_bitfield(data: TableData, table: TableName, field_name: Field) -> None:
"""Used to fixup the authentication_methods bitfield to be an integer."""
for item in data[table]:
# The ordering of bits here is important for the imported value
# to end up as expected.
charlist = ["1" if field[1] else "0" for field in item[field_name]]
charlist.reverse()
values_as_bitstring = "".join(charlist)
values_as_int = int(values_as_bitstring, 2)
item[field_name] = values_as_int
def remove_denormalized_recipient_column_from_data(data: TableData) -> None:
"""
The recipient column shouldn't be imported, we'll set the correct values
@ -955,7 +945,6 @@ def do_import_realm(import_dir: Path, subdomain: str, processes: int = 1) -> Rea
# Fix realm subdomain information
data["zerver_realm"][0]["string_id"] = subdomain
data["zerver_realm"][0]["name"] = subdomain
fix_realm_authentication_bitfield(data, "zerver_realm", "authentication_methods")
update_model_ids(Realm, data, "realm")
# Create the realm, but mark it deactivated for now, while we

View File

@ -7,11 +7,13 @@ from zerver.lib.user_groups import create_system_user_groups_for_realm
from zerver.models import (
Realm,
RealmAuditLog,
RealmAuthenticationMethod,
RealmUserDefault,
UserProfile,
get_client,
get_system_bot,
)
from zproject.backends import all_implemented_backend_names
def server_initialized() -> bool:
@ -28,6 +30,14 @@ def create_internal_realm() -> None:
RealmUserDefault.objects.create(realm=realm)
create_system_user_groups_for_realm(realm)
# We create realms with all authentications methods enabled by default.
RealmAuthenticationMethod.objects.bulk_create(
[
RealmAuthenticationMethod(name=backend_name, realm=realm)
for backend_name in all_implemented_backend_names()
]
)
# Create some client objects for common requests. Not required;
# just ensures these get low IDs in production, and in development
# avoids an extra database write for the first HTTP request in

View File

@ -53,7 +53,8 @@ Usage examples:
realm_dict = vars(realm).copy()
# Remove a field that is confusingly useless
del realm_dict["_state"]
# Fix the one bitfield to display useful data
# This is not an attribute of realm strictly speaking, but valuable info to include.
realm_dict["authentication_methods"] = str(realm.authentication_methods_dict())
for key in identifier_attributes:

View File

@ -0,0 +1,52 @@
# Generated by Django 4.2 on 2023-04-13 23:45
import django.db.models.deletion
from django.db import migrations, models
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.migrations.state import StateApps
def fill_RealmAuthenticationMethod_data(
apps: StateApps, schema_editor: BaseDatabaseSchemaEditor
) -> None:
Realm = apps.get_model("zerver", "Realm")
RealmAuthenticationMethod = apps.get_model("zerver", "RealmAuthenticationMethod")
for realm in Realm.objects.order_by("id"):
rows_to_create = []
for key, value in realm.authentication_methods.iteritems():
if value:
rows_to_create.append(RealmAuthenticationMethod(name=key, realm_id=realm.id))
RealmAuthenticationMethod.objects.bulk_create(rows_to_create)
class Migration(migrations.Migration):
atomic = False
dependencies = [
("zerver", "0435_scheduledmessage_rendered_content"),
]
operations = [
migrations.CreateModel(
name="RealmAuthenticationMethod",
fields=[
(
"id",
models.AutoField(
auto_created=True, primary_key=True, serialize=False, verbose_name="ID"
),
),
("name", models.CharField(max_length=80)),
(
"realm",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="zerver.realm"
),
),
],
options={
"unique_together": {("realm", "name")},
},
),
migrations.RunPython(fill_RealmAuthenticationMethod_data),
]

View File

@ -0,0 +1,16 @@
# Generated by Django 4.2 on 2023-04-16 10:55
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("zerver", "0436_realmauthenticationmethods"),
]
operations = [
migrations.RemoveField(
model_name="realm",
name="authentication_methods",
),
]

View File

@ -271,6 +271,21 @@ def clear_supported_auth_backends_cache() -> None:
supported_backends = None
class RealmAuthenticationMethod(models.Model):
"""
Tracks which authentication backends are enabled for a realm.
An enabled backend is represented in this table a row with appropriate
.realm value and .name matching the name of the target backend in the
AUTH_BACKEND_NAME_MAP dict.
"""
realm = models.ForeignKey("Realm", on_delete=CASCADE, db_index=True)
name = models.CharField(max_length=80)
class Meta:
unique_together = ("realm", "name")
class Realm(models.Model): # type: ignore[django-manager-missing] # django-stubs cannot resolve the custom CTEManager yet https://github.com/typeddjango/django-stubs/issues/1023
MAX_REALM_NAME_LENGTH = 40
MAX_REALM_DESCRIPTION_LENGTH = 1000
@ -318,10 +333,6 @@ class Realm(models.Model): # type: ignore[django-manager-missing] # django-stub
_max_invites = models.IntegerField(null=True, db_column="max_invites")
disallow_disposable_email_addresses = models.BooleanField(default=True)
authentication_methods: BitHandler = BitField(
flags=AUTHENTICATION_FLAGS,
default=2**31 - 1,
)
# Allow users to access web-public streams without login. This
# setting also controls API access of web-public streams.
@ -825,17 +836,21 @@ class Realm(models.Model): # type: ignore[django-manager-missing] # django-stub
on the server, this will not return an entry for "Email")."""
# This mapping needs to be imported from here due to the cyclic
# dependency.
from zproject.backends import AUTH_BACKEND_NAME_MAP
from zproject.backends import AUTH_BACKEND_NAME_MAP, all_implemented_backend_names
ret: Dict[str, bool] = {}
supported_backends = [type(backend) for backend in supported_auth_backends()]
# `authentication_methods` is a bitfield.types.BitHandler, not
# a true dict; since it is still python2- and python3-compat,
# `iteritems` is its method to iterate over its contents.
for k, v in self.authentication_methods.iteritems():
backend = AUTH_BACKEND_NAME_MAP[k]
if backend in supported_backends:
ret[k] = v
for backend_name in all_implemented_backend_names():
backend_class = AUTH_BACKEND_NAME_MAP[backend_name]
if backend_class in supported_backends:
ret[backend_name] = False
for realm_authentication_method in RealmAuthenticationMethod.objects.filter(
realm_id=self.id
):
backend_class = AUTH_BACKEND_NAME_MAP[realm_authentication_method.name]
if backend_class in supported_backends:
ret[realm_authentication_method.name] = True
return ret
# `realm` instead of `self` here to make sure the parameters of the cache key

View File

@ -59,6 +59,7 @@ from zerver.actions.invites import do_invite_users
from zerver.actions.realm_settings import (
do_deactivate_realm,
do_reactivate_realm,
do_set_realm_authentication_methods,
do_set_realm_property,
)
from zerver.actions.user_settings import do_change_password, do_change_user_setting
@ -247,9 +248,13 @@ class AuthBackendTest(ZulipTestCase):
if isinstance(backend, AUTH_BACKEND_NAME_MAP[backend_name]):
break
index = getattr(user_profile.realm.authentication_methods, backend_name).number
user_profile.realm.authentication_methods.set_bit(index, False)
user_profile.realm.save()
authentication_methods = user_profile.realm.authentication_methods_dict()
authentication_methods[backend_name] = False
do_set_realm_authentication_methods(
user_profile.realm, authentication_methods, acting_user=None
)
if "realm" in good_kwargs:
# Because this test is a little unfaithful to the ordering
# (i.e. we fetched the realm object before this function
@ -264,8 +269,11 @@ class AuthBackendTest(ZulipTestCase):
self.assertEqual(result["Location"], user_profile.realm.uri + "/login/")
else:
self.assertIsNone(result)
user_profile.realm.authentication_methods.set_bit(index, True)
user_profile.realm.save()
authentication_methods[backend_name] = True
do_set_realm_authentication_methods(
user_profile.realm, authentication_methods, acting_user=None
)
def test_dummy_backend(self) -> None:
realm = get_realm("zulip")

View File

@ -1211,7 +1211,7 @@ class FetchQueriesTest(ZulipTestCase):
self.login_user(user)
flush_per_request_caches()
with self.assert_database_query_count(37):
with self.assert_database_query_count(41):
with mock.patch("zerver.lib.events.always_want") as want_mock:
fetch_initial_state_data(user)
@ -1226,7 +1226,7 @@ class FetchQueriesTest(ZulipTestCase):
muted_topics=1,
muted_users=1,
presence=1,
realm=0,
realm=4,
realm_bot=1,
realm_domains=1,
realm_embedded_bots=0,

View File

@ -248,7 +248,7 @@ class HomeTest(ZulipTestCase):
# Verify succeeds once logged-in
flush_per_request_caches()
with self.assert_database_query_count(47):
with self.assert_database_query_count(51):
with patch("zerver.lib.cache.cache_set") as cache_mock:
result = self._get_home_page(stream="Denmark")
self.check_rendered_logged_in_app(result)
@ -439,7 +439,7 @@ class HomeTest(ZulipTestCase):
# Verify number of queries for Realm admin isn't much higher than for normal users.
self.login("iago")
flush_per_request_caches()
with self.assert_database_query_count(44):
with self.assert_database_query_count(48):
with patch("zerver.lib.cache.cache_set") as cache_mock:
result = self._get_home_page()
self.check_rendered_logged_in_app(result)
@ -471,7 +471,7 @@ class HomeTest(ZulipTestCase):
# Then for the second page load, measure the number of queries.
flush_per_request_caches()
with self.assert_database_query_count(42):
with self.assert_database_query_count(46):
result = self._get_home_page()
# Do a sanity check that our new streams were in the payload.

View File

@ -939,7 +939,7 @@ class LoginTest(ZulipTestCase):
ContentType.objects.clear_cache()
# Ensure the number of queries we make is not O(streams)
with self.assert_database_query_count(96), cache_tries_captured() as cache_tries:
with self.assert_database_query_count(102), cache_tries_captured() as cache_tries:
with self.captureOnCommitCallbacks(execute=True):
self.register(self.nonreg_email("test"), "test")

View File

@ -734,7 +734,7 @@ class SlackImporter(ZulipTestCase):
passed_realm["zerver_realm"][0]["description"], "Organization imported from Slack!"
)
self.assertEqual(passed_realm["zerver_userpresence"], [])
self.assert_length(passed_realm.keys(), 15)
self.assert_length(passed_realm.keys(), 16)
self.assertEqual(realm["zerver_stream"], [])
self.assertEqual(realm["zerver_userprofile"], [])
@ -1148,6 +1148,10 @@ class SlackImporter(ZulipTestCase):
self.assertEqual(Message.objects.filter(realm=realm).count(), 82)
# All auth backends are enabled initially.
for name, enabled in realm.authentication_methods_dict().items():
self.assertTrue(enabled)
Realm.objects.filter(name=test_realm_subdomain).delete()
remove_folder(output_dir)

View File

@ -786,7 +786,7 @@ class QueryCountTest(ZulipTestCase):
prereg_user = PreregistrationUser.objects.get(email="fred@zulip.com")
with self.assert_database_query_count(90):
with self.assert_database_query_count(91):
with cache_tries_captured() as cache_tries:
with self.capture_send_event_calls(expected_num_events=11) as events:
fred = do_create_user(

View File

@ -285,7 +285,7 @@ def update_realm(
# The following realm properties do not fit the pattern above
# authentication_methods is not supported by the do_set_realm_property
# framework because of its bitfield.
# framework because it's tracked through the RealmAuthenticationMethod table.
if authentication_methods is not None and (
realm.authentication_methods_dict() != authentication_methods
):

View File

@ -113,12 +113,16 @@ from zproject.settings_types import OIDCIdPConfigDict
redis_client = get_redis_client()
def all_implemented_backend_names() -> List[str]:
return list(AUTH_BACKEND_NAME_MAP.keys())
# This first batch of methods is used by other code in Zulip to check
# whether a given authentication backend is enabled for a given realm.
# In each case, we both needs to check at the server level (via
# `settings.AUTHENTICATION_BACKENDS`, queried via
# `django.contrib.auth.get_backends`) and at the realm level (via the
# `Realm.authentication_methods` BitField).
# `RealmAuthenticationMethod` table).
def pad_method_dict(method_dict: Dict[str, bool]) -> Dict[str, bool]:
"""Pads an authentication methods dict to contain all auth backends
supported by the software, regardless of whether they are