import: Use inspection to determine sequence names.

This commit is contained in:
Alex Vandiver 2024-08-14 02:56:06 +00:00 committed by Tim Abbott
parent 31623911d1
commit 73e5364838
1 changed files with 12 additions and 17 deletions

View File

@ -13,6 +13,7 @@ from django.conf import settings
from django.core.cache import cache
from django.core.validators import validate_email
from django.db import connection, transaction
from django.db.backends.utils import CursorWrapper
from django.utils.timezone import now as timezone_now
from psycopg2.extras import execute_values
from psycopg2.sql import SQL, Identifier
@ -458,18 +459,12 @@ def current_table_ids(data: TableData, table: TableName) -> list[int]:
return [item["id"] for item in data[table]]
def idseq(model_class: Any) -> str:
if model_class == RealmDomain:
return "zerver_realmalias_id_seq"
elif model_class == BotStorageData:
return "zerver_botuserstatedata_id_seq"
elif model_class == BotConfigData:
return "zerver_botuserconfigdata_id_seq"
elif model_class == UserTopic:
# The database table for this model was renamed from `mutedtopic` to
# `usertopic`, but the name of the sequence object remained the same.
return "zerver_mutedtopic_id_seq"
return f"{model_class._meta.db_table}_id_seq"
def idseq(model_class: Any, cursor: CursorWrapper) -> str:
sequences = connection.introspection.get_sequences(cursor, model_class._meta.db_table)
for sequence in sequences:
if sequence["column"] == "id":
return sequence["name"]
raise Exception(f"No sequence found for 'id' of {model_class}")
def allocate_ids(model_class: Any, count: int) -> list[int]:
@ -478,11 +473,11 @@ def allocate_ids(model_class: Any, count: int) -> list[int]:
imported into that table. Hence, this gives a reserved range of IDs to import the
converted Slack objects into the tables.
"""
conn = connection.cursor()
sequence = idseq(model_class)
conn.execute("select nextval(%s) from generate_series(1, %s)", [sequence, count])
query = conn.fetchall() # Each element in the result is a tuple like (5,)
conn.close()
with connection.cursor() as cursor:
sequence = idseq(model_class, cursor)
cursor.execute("select nextval(%s) from generate_series(1, %s)", [sequence, count])
query = cursor.fetchall() # Each element in the result is a tuple like (5,)
# convert List[Tuple[int]] to List[int]
return [item[0] for item in query]