From 230802ee222a6850618517674b991e93a9ed9e6b Mon Sep 17 00:00:00 2001
From: Zev Benjamin
Date: Sat, 1 Mar 2014 11:20:04 -0500
Subject: [PATCH] migrate.py: Add multiple columns simultaneously

(imported from commit 6cd01fcce6a6e18ce57be6f4da1fd394120b1f99)
---
 zerver/lib/migrate.py | 32 +++++++++++++++++++++-----------
 1 file changed, 21 insertions(+), 11 deletions(-)

diff --git a/zerver/lib/migrate.py b/zerver/lib/migrate.py
index 0816f82aae..b0a1f9474a 100644
--- a/zerver/lib/migrate.py
+++ b/zerver/lib/migrate.py
@@ -15,43 +15,53 @@ def validate(sql_thingy):
     if not re.match('^[a-z][a-z\d_]+$', sql_thingy):
         raise Exception('Invalid SQL object: %s' % (sql_thingy,))
 
-def do_batch_update(db, table, col, val, batch_size=10000, sleep=0.1):
+def do_batch_update(db, table, cols, vals, batch_size=10000, sleep=0.1):
     validate(table)
-    validate(col)
+    for col in cols:
+        validate(col)
     stmt = '''
         UPDATE %s
-        SET %s = %%s
+        SET (%s) = (%s)
         WHERE id >= %%s AND id < %%s
-    ''' % (table, col)
+    ''' % (table, ', '.join(cols), ', '.join(['%s'] * len(cols)))
+    print stmt
     (min_id, max_id) = db.execute("SELECT MIN(id), MAX(id) FROM %s" % (table,))[0]
     if min_id is None:
         return
+
+    print "%s rows need updating" % (max_id - min_id,)
     while min_id <= max_id:
         lower = min_id
         upper = min_id + batch_size
         print '%s about to update range [%s,%s)' % (time.asctime(), lower, upper)
         db.start_transaction()
-        db.execute(stmt, params=[val, lower, upper])
+        params = list(vals) + [lower, upper]
+        db.execute(stmt, params=params)
         db.commit_transaction()
         min_id = upper
         time.sleep(sleep)
 
-def add_bool_column(db, table, col):
+def add_bool_columns(db, table, cols):
     validate(table)
-    validate(col)
+    for col in cols:
+        validate(col)
     coltype = 'boolean'
     val = 'false'
 
-    stmt = 'ALTER TABLE %s ADD %s %s' % (table, col, coltype)
+    stmt = ('ALTER TABLE %s ' % (table,)) \
+           + ', '.join(['ADD %s %s' % (col, coltype) for col in cols])
     timed_ddl(db, stmt)
 
-    stmt = 'ALTER TABLE %s ALTER %s SET DEFAULT %s' % (table, col, val)
+    stmt = ('ALTER TABLE %s ' % (table,)) \
+           + ', '.join(['ALTER %s SET DEFAULT %s' % (col, val) for col in cols])
     timed_ddl(db, stmt)
 
-    do_batch_update(db, table, col, val)
+    vals = [val] * len(cols)
+    do_batch_update(db, table, cols, vals)
 
     stmt = 'ANALYZE %s' % (table,)
     timed_ddl(db, stmt)
 
-    stmt = 'ALTER TABLE %s ALTER %s SET NOT NULL' % (table, col)
+    stmt = ('ALTER TABLE %s ' % (table,)) \
+           + ', '.join(['ALTER %s SET NOT NULL' % (col,) for col in cols])
     timed_ddl(db, stmt)