diff --git a/pgroonga/migrations/0002_html_escape_subject.py b/pgroonga/migrations/0002_html_escape_subject.py
index eecf8969fb..f5cb23bd55 100644
--- a/pgroonga/migrations/0002_html_escape_subject.py
+++ b/pgroonga/migrations/0002_html_escape_subject.py
@@ -3,19 +3,13 @@ from django.db import models, migrations, connection
from django.contrib.postgres import operations
from django.db.backends.postgresql_psycopg2.schema import DatabaseSchemaEditor
from django.db.migrations.state import StateApps
+from zerver.lib.migrate import do_batch_update
def rebuild_pgroonga_index(apps: StateApps, schema_editor: DatabaseSchemaEditor) -> None:
- BATCH_SIZE = 10000
-
- Message = apps.get_model("zerver", "Message")
- message_ids = Message.objects.values_list('id', flat=True)
with connection.cursor() as cursor:
- for i in range(0, len(message_ids), BATCH_SIZE):
- batch_ids = ', '.join(str(id) for id in message_ids[i:i+BATCH_SIZE])
- cursor.execute("UPDATE zerver_message SET "
- "search_pgroonga = "
- "escape_html(subject) || ' ' || rendered_content "
- "WHERE id IN (%s)" % (batch_ids,))
+ do_batch_update(cursor, 'zerver_message', ['search_pgroonga'],
+ ["escape_html(subject) || ' ' || rendered_content"],
+ escape=False, batch_size=10000)
class Migration(migrations.Migration):
atomic = False
diff --git a/tools/linter_lib/custom_check.py b/tools/linter_lib/custom_check.py
index 4b9ff315b1..67e5178c75 100644
--- a/tools/linter_lib/custom_check.py
+++ b/tools/linter_lib/custom_check.py
@@ -449,6 +449,7 @@ def build_custom_checkers(by_lang):
'zerver/migrations/0041_create_attachments_for_old_messages.py',
'zerver/migrations/0060_move_avatars_to_be_uid_based.py',
'zerver/migrations/0104_fix_unreads.py',
+ 'pgroonga/migrations/0002_html_escape_subject.py',
]),
'description': "Don't import models or other code in migrations; see docs/subsystems/schema-migrations.md",
},
diff --git a/zerver/lib/migrate.py b/zerver/lib/migrate.py
index 46c4bdc5f4..d812abb264 100644
--- a/zerver/lib/migrate.py
+++ b/zerver/lib/migrate.py
@@ -1,8 +1,12 @@
-from typing import Any, Callable, Dict, List, Tuple
from django.db.models.query import QuerySet
+from psycopg2.extensions import cursor
+from typing import Any, Callable, Dict, List, Tuple, TypeVar
+
import re
import time
+CursorObj = TypeVar('CursorObj', bound=cursor)
+
def create_index_if_not_exist(index_name: str, table_name: str, column_string: str,
where_clause: str) -> str:
#
@@ -25,3 +29,43 @@ def create_index_if_not_exist(index_name: str, table_name: str, column_string: s
END$$;
''' % (index_name, index_name, table_name, column_string, where_clause)
return stmt
+
+
+def do_batch_update(cursor: CursorObj,
+ table: str,
+ cols: List[str],
+ vals: List[str],
+ batch_size: int=10000,
+ sleep: float=0.1,
+ escape: bool=True) -> None: # nocoverage
+ stmt = '''
+ UPDATE %s
+ SET (%s) = (%s)
+ WHERE id >= %%s AND id < %%s
+ ''' % (table, ', '.join(cols), ', '.join(['%s'] * len(cols)))
+
+ cursor.execute("SELECT MIN(id), MAX(id) FROM %s" % (table,))
+ (min_id, max_id) = cursor.fetchall()[0]
+ if min_id is None:
+ return
+
+ print("\n Range of rows to update: [%s, %s]" % (min_id, max_id))
+ while min_id <= max_id:
+ lower = min_id
+ upper = min_id + batch_size
+ print(' Updating range [%s,%s)' % (lower, upper))
+ params = list(vals) + [lower, upper]
+ if escape:
+ cursor.execute(stmt, params=params)
+ else:
+ cursor.execute(stmt % tuple(params))
+
+ min_id = upper
+ time.sleep(sleep)
+
+ # Once we've finished, check if any new rows were inserted to the table
+ if min_id > max_id:
+ cursor.execute("SELECT MAX(id) FROM %s" % (table,))
+ max_id = cursor.fetchall()[0][0]
+
+ print(" Finishing...", end='')