migrate: Improve do_batch_update escaping correctness with psycopg2.sql.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2020-05-03 16:15:36 -07:00 committed by Tim Abbott
parent cebac3f35a
commit ffe5402c49
2 changed files with 23 additions and 21 deletions

View File

@ -1,15 +1,19 @@
from django.db import connection, migrations
from django.db.backends.postgresql.schema import DatabaseSchemaEditor
from django.db.migrations.state import StateApps
from psycopg2.sql import SQL
from zerver.lib.migrate import do_batch_update
def rebuild_pgroonga_index(apps: StateApps, schema_editor: DatabaseSchemaEditor) -> None:
with connection.cursor() as cursor:
do_batch_update(cursor, 'zerver_message', ['search_pgroonga'],
["escape_html(subject) || ' ' || rendered_content"],
escape=False, batch_size=10000)
do_batch_update(
cursor,
"zerver_message",
[SQL("search_pgroonga = escape_html(subject) || ' ' || rendered_content")],
batch_size=10000,
)
class Migration(migrations.Migration):
atomic = False

View File

@ -1,4 +1,5 @@
from psycopg2.extensions import cursor
from psycopg2.sql import Composable, Identifier, SQL
from typing import List, TypeVar
import time
@ -8,21 +9,22 @@ CursorObj = TypeVar('CursorObj', bound=cursor)
def do_batch_update(cursor: CursorObj,
table: str,
cols: List[str],
vals: List[str],
assignments: List[Composable],
batch_size: int=10000,
sleep: float=0.1,
escape: bool=True) -> None: # nocoverage
sleep: float=0.1) -> None: # nocoverage
# The string substitution below is complicated by our need to
# support multiple postgres versions.
stmt = '''
UPDATE %s
SET %s
WHERE id >= %%s AND id < %%s
''' % (table, ', '.join(['%s = %%s' % (col) for col in cols]))
stmt = SQL('''
UPDATE {}
SET {}
WHERE id >= %s AND id < %s
''').format(
Identifier(table),
SQL(', ').join(assignments),
)
cursor.execute("SELECT MIN(id), MAX(id) FROM %s" % (table,))
(min_id, max_id) = cursor.fetchall()[0]
cursor.execute(SQL("SELECT MIN(id), MAX(id) FROM {}").format(Identifier(table)))
(min_id, max_id) = cursor.fetchone()
if min_id is None:
return
@ -31,18 +33,14 @@ def do_batch_update(cursor: CursorObj,
lower = min_id
upper = min_id + batch_size
print(' Updating range [%s,%s)' % (lower, upper))
params = list(vals) + [lower, upper]
if escape:
cursor.execute(stmt, params=params)
else:
cursor.execute(stmt % tuple(params))
cursor.execute(stmt, [lower, upper])
min_id = upper
time.sleep(sleep)
# Once we've finished, check if any new rows were inserted to the table
if min_id > max_id:
cursor.execute("SELECT MAX(id) FROM %s" % (table,))
max_id = cursor.fetchall()[0][0]
cursor.execute(SQL("SELECT MAX(id) FROM {}").format(Identifier(table)))
(max_id,) = cursor.fetchone()
print(" Finishing...", end='')