zulip/zerver/lib/migrate.py

125 lines
4.0 KiB
Python

from __future__ import print_function
from typing import Any, Callable, Dict, List, Tuple
from django.db.models.query import QuerySet
import re
import time
def timed_ddl(db, stmt):
# type: (Any, str) -> None
print()
print(time.asctime())
print(stmt)
t = time.time()
db.execute(stmt)
delay = time.time() - t
print('Took %.2fs' % (delay,))
def validate(sql_thingy):
# type: (str) -> None
# Do basic validation that table/col name is safe.
if not re.match('^[a-z][a-z\d_]+$', sql_thingy):
raise Exception('Invalid SQL object: %s' % (sql_thingy,))
def do_batch_update(db, table, cols, vals, batch_size=10000, sleep=0.1):
# type: (Any, str, List[str], List[str], int, float) -> None
validate(table)
for col in cols:
validate(col)
stmt = '''
UPDATE %s
SET (%s) = (%s)
WHERE id >= %%s AND id < %%s
''' % (table, ', '.join(cols), ', '.join(['%s'] * len(cols)))
print(stmt)
(min_id, max_id) = db.execute("SELECT MIN(id), MAX(id) FROM %s" % (table,))[0]
if min_id is None:
return
print("%s rows need updating" % (max_id - min_id,))
while min_id <= max_id:
lower = min_id
upper = min_id + batch_size
print('%s about to update range [%s,%s)' % (time.asctime(), lower, upper))
db.start_transaction()
params = list(vals) + [lower, upper]
db.execute(stmt, params=params)
db.commit_transaction()
min_id = upper
time.sleep(sleep)
def add_bool_columns(db, table, cols):
# type: (Any, str, List[str]) -> None
validate(table)
for col in cols:
validate(col)
coltype = 'boolean'
val = 'false'
stmt = (('ALTER TABLE %s ' % (table,)) +
', '.join(['ADD %s %s' % (col, coltype) for col in cols]))
timed_ddl(db, stmt)
stmt = (('ALTER TABLE %s ' % (table,)) +
', '.join(['ALTER %s SET DEFAULT %s' % (col, val) for col in cols]))
timed_ddl(db, stmt)
vals = [val] * len(cols)
do_batch_update(db, table, cols, vals)
stmt = 'ANALYZE %s' % (table,)
timed_ddl(db, stmt)
stmt = (('ALTER TABLE %s ' % (table,)) +
', '.join(['ALTER %s SET NOT NULL' % (col,) for col in cols]))
timed_ddl(db, stmt)
def create_index_if_nonexistant(db, table, col, index):
# type: (Any, str, str, str) -> None
validate(table)
validate(col)
validate(index)
test = """SELECT relname FROM pg_class
WHERE relname = %s"""
if len(db.execute(test, params=[index])) != 0:
print("Not creating index '%s' because it already exists" % (index,))
else:
stmt = "CREATE INDEX %s ON %s (%s)" % (index, table, col)
timed_ddl(db, stmt)
def act_on_message_ranges(db, orm, tasks, batch_size=5000, sleep=0.5):
# type: (Any, Dict[str, Any], List[Tuple[Callable[[QuerySet], QuerySet], Callable[[QuerySet], None]]], int , float) -> None
# tasks should be an array of (filterer, action) tuples
# where filterer is a function that returns a filtered QuerySet
# and action is a function that acts on a QuerySet
all_objects = orm['zerver.Message'].objects
try:
min_id = all_objects.all().order_by('id')[0].id
except IndexError:
print('There is no work to do')
return
max_id = all_objects.all().order_by('-id')[0].id
print("max_id = %d" % (max_id,))
overhead = int((max_id + 1 - min_id) / batch_size * sleep / 60)
print("Expect this to take at least %d minutes, just due to sleeps alone." % (overhead,))
while min_id <= max_id:
lower = min_id
upper = min_id + batch_size - 1
if upper > max_id:
upper = max_id
print('%s about to update range %s to %s' % (time.asctime(), lower, upper))
db.start_transaction()
for filterer, action in tasks:
objects = all_objects.filter(id__range=(lower, upper))
targets = filterer(objects)
action(targets)
db.commit_transaction()
min_id = upper + 1
time.sleep(sleep)