diff --git a/pgroonga/migrations/0002_html_escape_subject.py b/pgroonga/migrations/0002_html_escape_subject.py
index b8c8d9b82b..776d4c99f2 100644
--- a/pgroonga/migrations/0002_html_escape_subject.py
+++ b/pgroonga/migrations/0002_html_escape_subject.py
@@ -1,15 +1,19 @@
from django.db import connection, migrations
from django.db.backends.postgresql.schema import DatabaseSchemaEditor
from django.db.migrations.state import StateApps
+from psycopg2.sql import SQL
from zerver.lib.migrate import do_batch_update
def rebuild_pgroonga_index(apps: StateApps, schema_editor: DatabaseSchemaEditor) -> None:
with connection.cursor() as cursor:
- do_batch_update(cursor, 'zerver_message', ['search_pgroonga'],
- ["escape_html(subject) || ' ' || rendered_content"],
- escape=False, batch_size=10000)
+ do_batch_update(
+ cursor,
+ "zerver_message",
+ [SQL("search_pgroonga = escape_html(subject) || ' ' || rendered_content")],
+ batch_size=10000,
+ )
class Migration(migrations.Migration):
atomic = False
diff --git a/zerver/lib/migrate.py b/zerver/lib/migrate.py
index 8fd6f700c6..be0830a888 100644
--- a/zerver/lib/migrate.py
+++ b/zerver/lib/migrate.py
@@ -1,4 +1,5 @@
from psycopg2.extensions import cursor
+from psycopg2.sql import Composable, Identifier, SQL
from typing import List, TypeVar
import time
@@ -8,21 +9,22 @@ CursorObj = TypeVar('CursorObj', bound=cursor)
def do_batch_update(cursor: CursorObj,
table: str,
- cols: List[str],
- vals: List[str],
+ assignments: List[Composable],
batch_size: int=10000,
- sleep: float=0.1,
- escape: bool=True) -> None: # nocoverage
+ sleep: float=0.1) -> None: # nocoverage
# The string substitution below is complicated by our need to
# support multiple postgres versions.
- stmt = '''
- UPDATE %s
- SET %s
- WHERE id >= %%s AND id < %%s
- ''' % (table, ', '.join(['%s = %%s' % (col) for col in cols]))
+ stmt = SQL('''
+ UPDATE {}
+ SET {}
+ WHERE id >= %s AND id < %s
+ ''').format(
+ Identifier(table),
+ SQL(', ').join(assignments),
+ )
- cursor.execute("SELECT MIN(id), MAX(id) FROM %s" % (table,))
- (min_id, max_id) = cursor.fetchall()[0]
+ cursor.execute(SQL("SELECT MIN(id), MAX(id) FROM {}").format(Identifier(table)))
+ (min_id, max_id) = cursor.fetchone()
if min_id is None:
return
@@ -31,18 +33,14 @@ def do_batch_update(cursor: CursorObj,
lower = min_id
upper = min_id + batch_size
print(' Updating range [%s,%s)' % (lower, upper))
- params = list(vals) + [lower, upper]
- if escape:
- cursor.execute(stmt, params=params)
- else:
- cursor.execute(stmt % tuple(params))
+ cursor.execute(stmt, [lower, upper])
min_id = upper
time.sleep(sleep)
# Once we've finished, check if any new rows were inserted to the table
if min_id > max_id:
- cursor.execute("SELECT MAX(id) FROM %s" % (table,))
- max_id = cursor.fetchall()[0][0]
+ cursor.execute(SQL("SELECT MAX(id) FROM {}").format(Identifier(table)))
+ (max_id,) = cursor.fetchone()
print(" Finishing...", end='')