fix_unreads: Use cursor.execute correctly.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2020-06-09 02:57:01 -07:00 committed by Tim Abbott
parent 14bbfe6ffb
commit 3aab9c03a9
2 changed files with 38 additions and 40 deletions

View File

@ -3,6 +3,8 @@ import logging
from typing import Callable, List, TypeVar from typing import Callable, List, TypeVar
from psycopg2.extensions import cursor from psycopg2.extensions import cursor
from psycopg2.sql import SQL
CursorObj = TypeVar('CursorObj', bound=cursor) CursorObj = TypeVar('CursorObj', bound=cursor)
from django.db import connection from django.db import connection
@ -25,7 +27,7 @@ def build_topic_mute_checker(cursor: CursorObj, user_profile: UserProfile) -> Ca
in zerver/lib/topic_mutes.py, but it works without the ORM, in zerver/lib/topic_mutes.py, but it works without the ORM,
so that we can use it in migrations. so that we can use it in migrations.
''' '''
query = ''' query = SQL('''
SELECT SELECT
recipient_id, recipient_id,
topic_name topic_name
@ -33,7 +35,7 @@ def build_topic_mute_checker(cursor: CursorObj, user_profile: UserProfile) -> Ca
zerver_mutedtopic zerver_mutedtopic
WHERE WHERE
user_profile_id = %s user_profile_id = %s
''' ''')
cursor.execute(query, [user_profile.id]) cursor.execute(query, [user_profile.id])
rows = cursor.fetchall() rows = cursor.fetchall()
@ -48,14 +50,13 @@ def build_topic_mute_checker(cursor: CursorObj, user_profile: UserProfile) -> Ca
return is_muted return is_muted
def update_unread_flags(cursor: CursorObj, user_message_ids: List[int]) -> None: def update_unread_flags(cursor: CursorObj, user_message_ids: List[int]) -> None:
um_id_list = ', '.join(str(id) for id in user_message_ids) query = SQL('''
query = '''
UPDATE zerver_usermessage UPDATE zerver_usermessage
SET flags = flags | 1 SET flags = flags | 1
WHERE id IN (%s) WHERE id IN %(user_message_ids)s
''' % (um_id_list,) ''')
cursor.execute(query) cursor.execute(query, {"user_message_ids": tuple(user_message_ids)})
def get_timing(message: str, f: Callable[[], None]) -> None: def get_timing(message: str, f: Callable[[], None]) -> None:
@ -71,7 +72,7 @@ def fix_unsubscribed(cursor: CursorObj, user_profile: UserProfile) -> None:
recipient_ids = [] recipient_ids = []
def find_recipients() -> None: def find_recipients() -> None:
query = ''' query = SQL('''
SELECT SELECT
zerver_subscription.recipient_id zerver_subscription.recipient_id
FROM FROM
@ -80,12 +81,12 @@ def fix_unsubscribed(cursor: CursorObj, user_profile: UserProfile) -> None:
zerver_recipient.id = zerver_subscription.recipient_id zerver_recipient.id = zerver_subscription.recipient_id
) )
WHERE ( WHERE (
zerver_subscription.user_profile_id = '%s' AND zerver_subscription.user_profile_id = %(user_profile_id)s AND
zerver_recipient.type = 2 AND zerver_recipient.type = 2 AND
(NOT zerver_subscription.active) (NOT zerver_subscription.active)
) )
''' ''')
cursor.execute(query, [user_profile.id]) cursor.execute(query, {"user_profile_id": user_profile.id})
rows = cursor.fetchall() rows = cursor.fetchall()
for row in rows: for row in rows:
recipient_ids.append(row[0]) recipient_ids.append(row[0])
@ -102,9 +103,7 @@ def fix_unsubscribed(cursor: CursorObj, user_profile: UserProfile) -> None:
user_message_ids = [] user_message_ids = []
def find() -> None: def find() -> None:
recips = ', '.join(str(id) for id in recipient_ids) query = SQL('''
query = '''
SELECT SELECT
zerver_usermessage.id zerver_usermessage.id
FROM FROM
@ -113,16 +112,16 @@ def fix_unsubscribed(cursor: CursorObj, user_profile: UserProfile) -> None:
zerver_message.id = zerver_usermessage.message_id zerver_message.id = zerver_usermessage.message_id
) )
WHERE ( WHERE (
zerver_usermessage.user_profile_id = %s AND zerver_usermessage.user_profile_id = %(user_profile_id)s AND
(zerver_usermessage.flags & 1) = 0 AND (zerver_usermessage.flags & 1) = 0 AND
zerver_message.recipient_id in (%s) zerver_message.recipient_id in %(recipient_ids)s
) )
''' % (user_profile.id, recips) ''')
logger.info(''' cursor.execute(query, {
EXPLAIN analyze''' + query.rstrip() + ';') "user_profile_id": user_profile.id,
"recipient_ids": tuple(recipient_ids),
cursor.execute(query) })
rows = cursor.fetchall() rows = cursor.fetchall()
for row in rows: for row in rows:
user_message_ids.append(row[0]) user_message_ids.append(row[0])
@ -154,7 +153,7 @@ def fix_pre_pointer(cursor: CursorObj, user_profile: UserProfile) -> None:
recipient_ids = [] recipient_ids = []
def find_non_muted_recipients() -> None: def find_non_muted_recipients() -> None:
query = ''' query = SQL('''
SELECT SELECT
zerver_subscription.recipient_id zerver_subscription.recipient_id
FROM FROM
@ -163,13 +162,13 @@ def fix_pre_pointer(cursor: CursorObj, user_profile: UserProfile) -> None:
zerver_recipient.id = zerver_subscription.recipient_id zerver_recipient.id = zerver_subscription.recipient_id
) )
WHERE ( WHERE (
zerver_subscription.user_profile_id = '%s' AND zerver_subscription.user_profile_id = %(user_profile_id)s AND
zerver_recipient.type = 2 AND zerver_recipient.type = 2 AND
(NOT zerver_subscription.is_muted) AND (NOT zerver_subscription.is_muted) AND
zerver_subscription.active zerver_subscription.active
) )
''' ''')
cursor.execute(query, [user_profile.id]) cursor.execute(query, {"user_profile_id": user_profile.id})
rows = cursor.fetchall() rows = cursor.fetchall()
for row in rows: for row in rows:
recipient_ids.append(row[0]) recipient_ids.append(row[0])
@ -186,11 +185,9 @@ def fix_pre_pointer(cursor: CursorObj, user_profile: UserProfile) -> None:
user_message_ids = [] user_message_ids = []
def find_old_ids() -> None: def find_old_ids() -> None:
recips = ', '.join(str(id) for id in recipient_ids)
is_topic_muted = build_topic_mute_checker(cursor, user_profile) is_topic_muted = build_topic_mute_checker(cursor, user_profile)
query = ''' query = SQL('''
SELECT SELECT
zerver_usermessage.id, zerver_usermessage.id,
zerver_message.recipient_id, zerver_message.recipient_id,
@ -201,17 +198,18 @@ def fix_pre_pointer(cursor: CursorObj, user_profile: UserProfile) -> None:
zerver_message.id = zerver_usermessage.message_id zerver_message.id = zerver_usermessage.message_id
) )
WHERE ( WHERE (
zerver_usermessage.user_profile_id = %s AND zerver_usermessage.user_profile_id = %(user_profile_id)s AND
zerver_usermessage.message_id <= %s AND zerver_usermessage.message_id <= %(pointer)s AND
(zerver_usermessage.flags & 1) = 0 AND (zerver_usermessage.flags & 1) = 0 AND
zerver_message.recipient_id in (%s) zerver_message.recipient_id in %(recipient_ids)s
) )
''' % (user_profile.id, pointer, recips) ''')
logger.info(''' cursor.execute(query, {
EXPLAIN analyze''' + query.rstrip() + ';') "user_profile_id": user_profile.id,
"pointer": pointer,
cursor.execute(query) "recipient_ids": tuple(recipient_ids),
})
rows = cursor.fetchall() rows = cursor.fetchall()
for (um_id, recipient_id, topic) in rows: for (um_id, recipient_id, topic) in rows:
if not is_topic_muted(recipient_id, topic): if not is_topic_muted(recipient_id, topic):

View File

@ -1292,11 +1292,11 @@ def import_attachments(data: TableData) -> None:
# better way to do this in Django 1.9 particularly. # better way to do this in Django 1.9 particularly.
with connection.cursor() as cursor: with connection.cursor() as cursor:
sql_template = SQL(''' sql_template = SQL('''
insert into {} ({}, {}) values %s INSERT INTO {m2m_table_name} ({parent_id}, {child_id}) VALUES %s
''').format( ''').format(
Identifier(m2m_table_name), m2m_table_name=Identifier(m2m_table_name),
Identifier(parent_id), parent_id=Identifier(parent_id),
Identifier(child_id), child_id=Identifier(child_id),
) )
tups = [(row[parent_id], row[child_id]) for row in m2m_rows] tups = [(row[parent_id], row[child_id]) for row in m2m_rows]
execute_values(cursor.cursor, sql_template, tups) execute_values(cursor.cursor, sql_template, tups)