migrate: Improve do_batch_update escaping correctness with psycopg2.sql.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2020-05-03 16:15:36 -07:00 committed by Tim Abbott
parent cebac3f35a
commit ffe5402c49
2 changed files with 23 additions and 21 deletions

View File

@ -1,15 +1,19 @@
from django.db import connection, migrations from django.db import connection, migrations
from django.db.backends.postgresql.schema import DatabaseSchemaEditor from django.db.backends.postgresql.schema import DatabaseSchemaEditor
from django.db.migrations.state import StateApps from django.db.migrations.state import StateApps
from psycopg2.sql import SQL
from zerver.lib.migrate import do_batch_update from zerver.lib.migrate import do_batch_update
def rebuild_pgroonga_index(apps: StateApps, schema_editor: DatabaseSchemaEditor) -> None: def rebuild_pgroonga_index(apps: StateApps, schema_editor: DatabaseSchemaEditor) -> None:
with connection.cursor() as cursor: with connection.cursor() as cursor:
do_batch_update(cursor, 'zerver_message', ['search_pgroonga'], do_batch_update(
["escape_html(subject) || ' ' || rendered_content"], cursor,
escape=False, batch_size=10000) "zerver_message",
[SQL("search_pgroonga = escape_html(subject) || ' ' || rendered_content")],
batch_size=10000,
)
class Migration(migrations.Migration): class Migration(migrations.Migration):
atomic = False atomic = False

View File

@ -1,4 +1,5 @@
from psycopg2.extensions import cursor from psycopg2.extensions import cursor
from psycopg2.sql import Composable, Identifier, SQL
from typing import List, TypeVar from typing import List, TypeVar
import time import time
@ -8,21 +9,22 @@ CursorObj = TypeVar('CursorObj', bound=cursor)
def do_batch_update(cursor: CursorObj, def do_batch_update(cursor: CursorObj,
table: str, table: str,
cols: List[str], assignments: List[Composable],
vals: List[str],
batch_size: int=10000, batch_size: int=10000,
sleep: float=0.1, sleep: float=0.1) -> None: # nocoverage
escape: bool=True) -> None: # nocoverage
# The string substitution below is complicated by our need to # The string substitution below is complicated by our need to
# support multiple postgres versions. # support multiple postgres versions.
stmt = ''' stmt = SQL('''
UPDATE %s UPDATE {}
SET %s SET {}
WHERE id >= %%s AND id < %%s WHERE id >= %s AND id < %s
''' % (table, ', '.join(['%s = %%s' % (col) for col in cols])) ''').format(
Identifier(table),
SQL(', ').join(assignments),
)
cursor.execute("SELECT MIN(id), MAX(id) FROM %s" % (table,)) cursor.execute(SQL("SELECT MIN(id), MAX(id) FROM {}").format(Identifier(table)))
(min_id, max_id) = cursor.fetchall()[0] (min_id, max_id) = cursor.fetchone()
if min_id is None: if min_id is None:
return return
@ -31,18 +33,14 @@ def do_batch_update(cursor: CursorObj,
lower = min_id lower = min_id
upper = min_id + batch_size upper = min_id + batch_size
print(' Updating range [%s,%s)' % (lower, upper)) print(' Updating range [%s,%s)' % (lower, upper))
params = list(vals) + [lower, upper] cursor.execute(stmt, [lower, upper])
if escape:
cursor.execute(stmt, params=params)
else:
cursor.execute(stmt % tuple(params))
min_id = upper min_id = upper
time.sleep(sleep) time.sleep(sleep)
# Once we've finished, check if any new rows were inserted to the table # Once we've finished, check if any new rows were inserted to the table
if min_id > max_id: if min_id > max_id:
cursor.execute("SELECT MAX(id) FROM %s" % (table,)) cursor.execute(SQL("SELECT MAX(id) FROM {}").format(Identifier(table)))
max_id = cursor.fetchall()[0][0] (max_id,) = cursor.fetchone()
print(" Finishing...", end='') print(" Finishing...", end='')