mirror of https://github.com/zulip/zulip.git
migrate: Improve do_batch_update escaping correctness with psycopg2.sql.
Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
parent
cebac3f35a
commit
ffe5402c49
|
@ -1,15 +1,19 @@
|
|||
from django.db import connection, migrations
|
||||
from django.db.backends.postgresql.schema import DatabaseSchemaEditor
|
||||
from django.db.migrations.state import StateApps
|
||||
from psycopg2.sql import SQL
|
||||
|
||||
from zerver.lib.migrate import do_batch_update
|
||||
|
||||
|
||||
def rebuild_pgroonga_index(apps: StateApps, schema_editor: DatabaseSchemaEditor) -> None:
|
||||
with connection.cursor() as cursor:
|
||||
do_batch_update(cursor, 'zerver_message', ['search_pgroonga'],
|
||||
["escape_html(subject) || ' ' || rendered_content"],
|
||||
escape=False, batch_size=10000)
|
||||
do_batch_update(
|
||||
cursor,
|
||||
"zerver_message",
|
||||
[SQL("search_pgroonga = escape_html(subject) || ' ' || rendered_content")],
|
||||
batch_size=10000,
|
||||
)
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
atomic = False
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from psycopg2.extensions import cursor
|
||||
from psycopg2.sql import Composable, Identifier, SQL
|
||||
from typing import List, TypeVar
|
||||
|
||||
import time
|
||||
|
@ -8,21 +9,22 @@ CursorObj = TypeVar('CursorObj', bound=cursor)
|
|||
|
||||
def do_batch_update(cursor: CursorObj,
|
||||
table: str,
|
||||
cols: List[str],
|
||||
vals: List[str],
|
||||
assignments: List[Composable],
|
||||
batch_size: int=10000,
|
||||
sleep: float=0.1,
|
||||
escape: bool=True) -> None: # nocoverage
|
||||
sleep: float=0.1) -> None: # nocoverage
|
||||
# The string substitution below is complicated by our need to
|
||||
# support multiple postgres versions.
|
||||
stmt = '''
|
||||
UPDATE %s
|
||||
SET %s
|
||||
WHERE id >= %%s AND id < %%s
|
||||
''' % (table, ', '.join(['%s = %%s' % (col) for col in cols]))
|
||||
stmt = SQL('''
|
||||
UPDATE {}
|
||||
SET {}
|
||||
WHERE id >= %s AND id < %s
|
||||
''').format(
|
||||
Identifier(table),
|
||||
SQL(', ').join(assignments),
|
||||
)
|
||||
|
||||
cursor.execute("SELECT MIN(id), MAX(id) FROM %s" % (table,))
|
||||
(min_id, max_id) = cursor.fetchall()[0]
|
||||
cursor.execute(SQL("SELECT MIN(id), MAX(id) FROM {}").format(Identifier(table)))
|
||||
(min_id, max_id) = cursor.fetchone()
|
||||
if min_id is None:
|
||||
return
|
||||
|
||||
|
@ -31,18 +33,14 @@ def do_batch_update(cursor: CursorObj,
|
|||
lower = min_id
|
||||
upper = min_id + batch_size
|
||||
print(' Updating range [%s,%s)' % (lower, upper))
|
||||
params = list(vals) + [lower, upper]
|
||||
if escape:
|
||||
cursor.execute(stmt, params=params)
|
||||
else:
|
||||
cursor.execute(stmt % tuple(params))
|
||||
cursor.execute(stmt, [lower, upper])
|
||||
|
||||
min_id = upper
|
||||
time.sleep(sleep)
|
||||
|
||||
# Once we've finished, check if any new rows were inserted to the table
|
||||
if min_id > max_id:
|
||||
cursor.execute("SELECT MAX(id) FROM %s" % (table,))
|
||||
max_id = cursor.fetchall()[0][0]
|
||||
cursor.execute(SQL("SELECT MAX(id) FROM {}").format(Identifier(table)))
|
||||
(max_id,) = cursor.fetchone()
|
||||
|
||||
print(" Finishing...", end='')
|
||||
|
|
Loading…
Reference in New Issue