mirror of https://github.com/zulip/zulip.git
migrate: Improve do_batch_update escaping correctness with psycopg2.sql.
Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
parent
cebac3f35a
commit
ffe5402c49
|
@ -1,15 +1,19 @@
|
||||||
from django.db import connection, migrations
|
from django.db import connection, migrations
|
||||||
from django.db.backends.postgresql.schema import DatabaseSchemaEditor
|
from django.db.backends.postgresql.schema import DatabaseSchemaEditor
|
||||||
from django.db.migrations.state import StateApps
|
from django.db.migrations.state import StateApps
|
||||||
|
from psycopg2.sql import SQL
|
||||||
|
|
||||||
from zerver.lib.migrate import do_batch_update
|
from zerver.lib.migrate import do_batch_update
|
||||||
|
|
||||||
|
|
||||||
def rebuild_pgroonga_index(apps: StateApps, schema_editor: DatabaseSchemaEditor) -> None:
|
def rebuild_pgroonga_index(apps: StateApps, schema_editor: DatabaseSchemaEditor) -> None:
|
||||||
with connection.cursor() as cursor:
|
with connection.cursor() as cursor:
|
||||||
do_batch_update(cursor, 'zerver_message', ['search_pgroonga'],
|
do_batch_update(
|
||||||
["escape_html(subject) || ' ' || rendered_content"],
|
cursor,
|
||||||
escape=False, batch_size=10000)
|
"zerver_message",
|
||||||
|
[SQL("search_pgroonga = escape_html(subject) || ' ' || rendered_content")],
|
||||||
|
batch_size=10000,
|
||||||
|
)
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
class Migration(migrations.Migration):
|
||||||
atomic = False
|
atomic = False
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from psycopg2.extensions import cursor
|
from psycopg2.extensions import cursor
|
||||||
|
from psycopg2.sql import Composable, Identifier, SQL
|
||||||
from typing import List, TypeVar
|
from typing import List, TypeVar
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
@ -8,21 +9,22 @@ CursorObj = TypeVar('CursorObj', bound=cursor)
|
||||||
|
|
||||||
def do_batch_update(cursor: CursorObj,
|
def do_batch_update(cursor: CursorObj,
|
||||||
table: str,
|
table: str,
|
||||||
cols: List[str],
|
assignments: List[Composable],
|
||||||
vals: List[str],
|
|
||||||
batch_size: int=10000,
|
batch_size: int=10000,
|
||||||
sleep: float=0.1,
|
sleep: float=0.1) -> None: # nocoverage
|
||||||
escape: bool=True) -> None: # nocoverage
|
|
||||||
# The string substitution below is complicated by our need to
|
# The string substitution below is complicated by our need to
|
||||||
# support multiple postgres versions.
|
# support multiple postgres versions.
|
||||||
stmt = '''
|
stmt = SQL('''
|
||||||
UPDATE %s
|
UPDATE {}
|
||||||
SET %s
|
SET {}
|
||||||
WHERE id >= %%s AND id < %%s
|
WHERE id >= %s AND id < %s
|
||||||
''' % (table, ', '.join(['%s = %%s' % (col) for col in cols]))
|
''').format(
|
||||||
|
Identifier(table),
|
||||||
|
SQL(', ').join(assignments),
|
||||||
|
)
|
||||||
|
|
||||||
cursor.execute("SELECT MIN(id), MAX(id) FROM %s" % (table,))
|
cursor.execute(SQL("SELECT MIN(id), MAX(id) FROM {}").format(Identifier(table)))
|
||||||
(min_id, max_id) = cursor.fetchall()[0]
|
(min_id, max_id) = cursor.fetchone()
|
||||||
if min_id is None:
|
if min_id is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -31,18 +33,14 @@ def do_batch_update(cursor: CursorObj,
|
||||||
lower = min_id
|
lower = min_id
|
||||||
upper = min_id + batch_size
|
upper = min_id + batch_size
|
||||||
print(' Updating range [%s,%s)' % (lower, upper))
|
print(' Updating range [%s,%s)' % (lower, upper))
|
||||||
params = list(vals) + [lower, upper]
|
cursor.execute(stmt, [lower, upper])
|
||||||
if escape:
|
|
||||||
cursor.execute(stmt, params=params)
|
|
||||||
else:
|
|
||||||
cursor.execute(stmt % tuple(params))
|
|
||||||
|
|
||||||
min_id = upper
|
min_id = upper
|
||||||
time.sleep(sleep)
|
time.sleep(sleep)
|
||||||
|
|
||||||
# Once we've finished, check if any new rows were inserted to the table
|
# Once we've finished, check if any new rows were inserted to the table
|
||||||
if min_id > max_id:
|
if min_id > max_id:
|
||||||
cursor.execute("SELECT MAX(id) FROM %s" % (table,))
|
cursor.execute(SQL("SELECT MAX(id) FROM {}").format(Identifier(table)))
|
||||||
max_id = cursor.fetchall()[0][0]
|
(max_id,) = cursor.fetchone()
|
||||||
|
|
||||||
print(" Finishing...", end='')
|
print(" Finishing...", end='')
|
||||||
|
|
Loading…
Reference in New Issue