diff --git a/pgroonga/migrations/0002_html_escape_subject.py b/pgroonga/migrations/0002_html_escape_subject.py index b8c8d9b82b..776d4c99f2 100644 --- a/pgroonga/migrations/0002_html_escape_subject.py +++ b/pgroonga/migrations/0002_html_escape_subject.py @@ -1,15 +1,19 @@ from django.db import connection, migrations from django.db.backends.postgresql.schema import DatabaseSchemaEditor from django.db.migrations.state import StateApps +from psycopg2.sql import SQL from zerver.lib.migrate import do_batch_update def rebuild_pgroonga_index(apps: StateApps, schema_editor: DatabaseSchemaEditor) -> None: with connection.cursor() as cursor: - do_batch_update(cursor, 'zerver_message', ['search_pgroonga'], - ["escape_html(subject) || ' ' || rendered_content"], - escape=False, batch_size=10000) + do_batch_update( + cursor, + "zerver_message", + [SQL("search_pgroonga = escape_html(subject) || ' ' || rendered_content")], + batch_size=10000, + ) class Migration(migrations.Migration): atomic = False diff --git a/zerver/lib/migrate.py b/zerver/lib/migrate.py index 8fd6f700c6..be0830a888 100644 --- a/zerver/lib/migrate.py +++ b/zerver/lib/migrate.py @@ -1,4 +1,5 @@ from psycopg2.extensions import cursor +from psycopg2.sql import Composable, Identifier, SQL from typing import List, TypeVar import time @@ -8,21 +9,22 @@ CursorObj = TypeVar('CursorObj', bound=cursor) def do_batch_update(cursor: CursorObj, table: str, - cols: List[str], - vals: List[str], + assignments: List[Composable], batch_size: int=10000, - sleep: float=0.1, - escape: bool=True) -> None: # nocoverage + sleep: float=0.1) -> None: # nocoverage # The string substitution below is complicated by our need to # support multiple postgres versions. - stmt = ''' - UPDATE %s - SET %s - WHERE id >= %%s AND id < %%s - ''' % (table, ', '.join(['%s = %%s' % (col) for col in cols])) + stmt = SQL(''' + UPDATE {} + SET {} + WHERE id >= %s AND id < %s + ''').format( + Identifier(table), + SQL(', ').join(assignments), + ) - cursor.execute("SELECT MIN(id), MAX(id) FROM %s" % (table,)) - (min_id, max_id) = cursor.fetchall()[0] + cursor.execute(SQL("SELECT MIN(id), MAX(id) FROM {}").format(Identifier(table))) + (min_id, max_id) = cursor.fetchone() if min_id is None: return @@ -31,18 +33,14 @@ def do_batch_update(cursor: CursorObj, lower = min_id upper = min_id + batch_size print(' Updating range [%s,%s)' % (lower, upper)) - params = list(vals) + [lower, upper] - if escape: - cursor.execute(stmt, params=params) - else: - cursor.execute(stmt % tuple(params)) + cursor.execute(stmt, [lower, upper]) min_id = upper time.sleep(sleep) # Once we've finished, check if any new rows were inserted to the table if min_id > max_id: - cursor.execute("SELECT MAX(id) FROM %s" % (table,)) - max_id = cursor.fetchall()[0][0] + cursor.execute(SQL("SELECT MAX(id) FROM {}").format(Identifier(table))) + (max_id,) = cursor.fetchone() print(" Finishing...", end='')