fix_unreads: Use cursor.execute correctly.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2020-06-09 02:57:01 -07:00 committed by Tim Abbott
parent 14bbfe6ffb
commit 3aab9c03a9
2 changed files with 38 additions and 40 deletions

View File

@ -3,6 +3,8 @@ import logging
from typing import Callable, List, TypeVar
from psycopg2.extensions import cursor
from psycopg2.sql import SQL
CursorObj = TypeVar('CursorObj', bound=cursor)
from django.db import connection
@ -25,7 +27,7 @@ def build_topic_mute_checker(cursor: CursorObj, user_profile: UserProfile) -> Ca
in zerver/lib/topic_mutes.py, but it works without the ORM,
so that we can use it in migrations.
'''
query = '''
query = SQL('''
SELECT
recipient_id,
topic_name
@ -33,7 +35,7 @@ def build_topic_mute_checker(cursor: CursorObj, user_profile: UserProfile) -> Ca
zerver_mutedtopic
WHERE
user_profile_id = %s
'''
''')
cursor.execute(query, [user_profile.id])
rows = cursor.fetchall()
@ -48,14 +50,13 @@ def build_topic_mute_checker(cursor: CursorObj, user_profile: UserProfile) -> Ca
return is_muted
def update_unread_flags(cursor: CursorObj, user_message_ids: List[int]) -> None:
um_id_list = ', '.join(str(id) for id in user_message_ids)
query = '''
query = SQL('''
UPDATE zerver_usermessage
SET flags = flags | 1
WHERE id IN (%s)
''' % (um_id_list,)
WHERE id IN %(user_message_ids)s
''')
cursor.execute(query)
cursor.execute(query, {"user_message_ids": tuple(user_message_ids)})
def get_timing(message: str, f: Callable[[], None]) -> None:
@ -71,7 +72,7 @@ def fix_unsubscribed(cursor: CursorObj, user_profile: UserProfile) -> None:
recipient_ids = []
def find_recipients() -> None:
query = '''
query = SQL('''
SELECT
zerver_subscription.recipient_id
FROM
@ -80,12 +81,12 @@ def fix_unsubscribed(cursor: CursorObj, user_profile: UserProfile) -> None:
zerver_recipient.id = zerver_subscription.recipient_id
)
WHERE (
zerver_subscription.user_profile_id = '%s' AND
zerver_subscription.user_profile_id = %(user_profile_id)s AND
zerver_recipient.type = 2 AND
(NOT zerver_subscription.active)
)
'''
cursor.execute(query, [user_profile.id])
''')
cursor.execute(query, {"user_profile_id": user_profile.id})
rows = cursor.fetchall()
for row in rows:
recipient_ids.append(row[0])
@ -102,9 +103,7 @@ def fix_unsubscribed(cursor: CursorObj, user_profile: UserProfile) -> None:
user_message_ids = []
def find() -> None:
recips = ', '.join(str(id) for id in recipient_ids)
query = '''
query = SQL('''
SELECT
zerver_usermessage.id
FROM
@ -113,16 +112,16 @@ def fix_unsubscribed(cursor: CursorObj, user_profile: UserProfile) -> None:
zerver_message.id = zerver_usermessage.message_id
)
WHERE (
zerver_usermessage.user_profile_id = %s AND
zerver_usermessage.user_profile_id = %(user_profile_id)s AND
(zerver_usermessage.flags & 1) = 0 AND
zerver_message.recipient_id in (%s)
zerver_message.recipient_id in %(recipient_ids)s
)
''' % (user_profile.id, recips)
''')
logger.info('''
EXPLAIN analyze''' + query.rstrip() + ';')
cursor.execute(query)
cursor.execute(query, {
"user_profile_id": user_profile.id,
"recipient_ids": tuple(recipient_ids),
})
rows = cursor.fetchall()
for row in rows:
user_message_ids.append(row[0])
@ -154,7 +153,7 @@ def fix_pre_pointer(cursor: CursorObj, user_profile: UserProfile) -> None:
recipient_ids = []
def find_non_muted_recipients() -> None:
query = '''
query = SQL('''
SELECT
zerver_subscription.recipient_id
FROM
@ -163,13 +162,13 @@ def fix_pre_pointer(cursor: CursorObj, user_profile: UserProfile) -> None:
zerver_recipient.id = zerver_subscription.recipient_id
)
WHERE (
zerver_subscription.user_profile_id = '%s' AND
zerver_subscription.user_profile_id = %(user_profile_id)s AND
zerver_recipient.type = 2 AND
(NOT zerver_subscription.is_muted) AND
zerver_subscription.active
)
'''
cursor.execute(query, [user_profile.id])
''')
cursor.execute(query, {"user_profile_id": user_profile.id})
rows = cursor.fetchall()
for row in rows:
recipient_ids.append(row[0])
@ -186,11 +185,9 @@ def fix_pre_pointer(cursor: CursorObj, user_profile: UserProfile) -> None:
user_message_ids = []
def find_old_ids() -> None:
recips = ', '.join(str(id) for id in recipient_ids)
is_topic_muted = build_topic_mute_checker(cursor, user_profile)
query = '''
query = SQL('''
SELECT
zerver_usermessage.id,
zerver_message.recipient_id,
@ -201,17 +198,18 @@ def fix_pre_pointer(cursor: CursorObj, user_profile: UserProfile) -> None:
zerver_message.id = zerver_usermessage.message_id
)
WHERE (
zerver_usermessage.user_profile_id = %s AND
zerver_usermessage.message_id <= %s AND
zerver_usermessage.user_profile_id = %(user_profile_id)s AND
zerver_usermessage.message_id <= %(pointer)s AND
(zerver_usermessage.flags & 1) = 0 AND
zerver_message.recipient_id in (%s)
zerver_message.recipient_id in %(recipient_ids)s
)
''' % (user_profile.id, pointer, recips)
''')
logger.info('''
EXPLAIN analyze''' + query.rstrip() + ';')
cursor.execute(query)
cursor.execute(query, {
"user_profile_id": user_profile.id,
"pointer": pointer,
"recipient_ids": tuple(recipient_ids),
})
rows = cursor.fetchall()
for (um_id, recipient_id, topic) in rows:
if not is_topic_muted(recipient_id, topic):

View File

@ -1292,11 +1292,11 @@ def import_attachments(data: TableData) -> None:
# better way to do this in Django 1.9 particularly.
with connection.cursor() as cursor:
sql_template = SQL('''
insert into {} ({}, {}) values %s
INSERT INTO {m2m_table_name} ({parent_id}, {child_id}) VALUES %s
''').format(
Identifier(m2m_table_name),
Identifier(parent_id),
Identifier(child_id),
m2m_table_name=Identifier(m2m_table_name),
parent_id=Identifier(parent_id),
child_id=Identifier(child_id),
)
tups = [(row[parent_id], row[child_id]) for row in m2m_rows]
execute_values(cursor.cursor, sql_template, tups)