import logging
import time
from typing import Callable, List, TypeVar

from django.db import connection
from psycopg2.extensions import cursor
from psycopg2.sql import SQL

from zerver.models import UserProfile

CursorObj = TypeVar('CursorObj', bound=cursor)

'''
NOTE!  Be careful modifying this library, as it is used
in a migration, and it needs to be valid for the state
of the database that is in place when the 0104_fix_unreads
migration runs.
'''

logger = logging.getLogger('zulip.fix_unreads')
logger.setLevel(logging.WARNING)


def build_topic_mute_checker(cursor: CursorObj, user_profile: UserProfile) -> Callable[[int, str], bool]:
    '''
    This function is similar to the function of the same name
    in zerver/lib/topic_mutes.py, but it works without the ORM,
    so that we can use it in migrations.

    It loads every (recipient_id, topic_name) row of zerver_mutedtopic
    for the given user once, and returns a predicate
    is_muted(recipient_id, topic) that answers from that in-memory set.
    Topic names are compared case-insensitively (both sides are
    lowercased).
    '''
    query = SQL('''
        SELECT
            recipient_id,
            topic_name
        FROM
            zerver_mutedtopic
        WHERE
            user_profile_id = %s
    ''')
    cursor.execute(query, [user_profile.id])
    rows = cursor.fetchall()

    # Normalize topic case once here so each lookup is a plain set
    # membership test.
    tups = {
        (recipient_id, topic_name.lower())
        for (recipient_id, topic_name) in rows
    }

    def is_muted(recipient_id: int, topic: str) -> bool:
        return (recipient_id, topic.lower()) in tups

    return is_muted


def update_unread_flags(cursor: CursorObj, user_message_ids: List[int]) -> None:
    '''
    Set the lowest bit of zerver_usermessage.flags (flags | 1) on every
    row whose id is in user_message_ids.

    NOTE(review): bit 1 is presumably the "read" flag; confirm against
    the UserMessage model.

    Does nothing when user_message_ids is empty.
    '''
    if not user_message_ids:
        # psycopg2 renders an empty tuple as "()", and "IN ()" is a
        # syntax error in PostgreSQL, so there is nothing valid to run.
        return

    query = SQL('''
        UPDATE zerver_usermessage
        SET flags = flags | 1
        WHERE id IN %(user_message_ids)s
    ''')

    cursor.execute(query, {"user_message_ids": tuple(user_message_ids)})


def get_timing(message: str, f: Callable[[], None]) -> None:
    '''
    Log message, call f(), then log the wall-clock seconds f() took.
    Both log lines are emitted at INFO level.
    '''
    start = time.time()
    logger.info(message)
    f()
    elapsed = time.time() - start
    logger.info('elapsed time: %.03f\n', elapsed)


def fix_unsubscribed(cursor: CursorObj, user_profile: UserProfile) -> None:
    '''
    Flag (via update_unread_flags) the user's messages whose flags bit 1
    is still clear and whose recipient is one the user has an inactive
    subscription to.

    Only recipients with zerver_recipient.type = 2 are considered.
    NOTE(review): type 2 presumably means "stream", as the log messages
    below suggest; confirm against the Recipient model.
    '''
    # The nested helpers fill these lists in place so that each step can
    # be handed to get_timing() as a zero-argument callable.
    recipient_ids: List[int] = []

    def find_recipients() -> None:
        query = SQL('''
            SELECT
                zerver_subscription.recipient_id
            FROM
                zerver_subscription
            INNER JOIN zerver_recipient ON (
                zerver_recipient.id = zerver_subscription.recipient_id
            )
            WHERE (
                zerver_subscription.user_profile_id = %(user_profile_id)s AND
                zerver_recipient.type = 2 AND
                (NOT zerver_subscription.active)
            )
        ''')
        cursor.execute(query, {"user_profile_id": user_profile.id})
        rows = cursor.fetchall()
        recipient_ids.extend(row[0] for row in rows)
        # Lazy %-args: the list is only stringified if INFO is enabled.
        logger.info('%s', recipient_ids)

    get_timing(
        'get recipients',
        find_recipients,
    )

    if not recipient_ids:
        return

    user_message_ids: List[int] = []

    def find() -> None:
        query = SQL('''
            SELECT
                zerver_usermessage.id
            FROM
                zerver_usermessage
            INNER JOIN zerver_message ON (
                zerver_message.id = zerver_usermessage.message_id
            )
            WHERE (
                zerver_usermessage.user_profile_id = %(user_profile_id)s AND
                (zerver_usermessage.flags & 1) = 0 AND
                zerver_message.recipient_id in %(recipient_ids)s
            )
        ''')

        cursor.execute(query, {
            "user_profile_id": user_profile.id,
            "recipient_ids": tuple(recipient_ids),
        })
        rows = cursor.fetchall()
        user_message_ids.extend(row[0] for row in rows)
        logger.info('rows found: %d', len(user_message_ids))

    get_timing(
        'finding unread messages for non-active streams',
        find,
    )

    if not user_message_ids:
        return

    # Named apply_fix rather than fix so it does not shadow the
    # module-level fix() entry point.
    def apply_fix() -> None:
        update_unread_flags(cursor, user_message_ids)

    get_timing(
        'fixing unread messages for non-active streams',
        apply_fix,
    )


def fix_pre_pointer(cursor: CursorObj, user_profile: UserProfile) -> None:
    '''
    Flag (via update_unread_flags) the user's messages at or before
    user_profile.pointer whose flags bit 1 is still clear, restricted to
    active, non-muted subscriptions (zerver_recipient.type = 2) and
    skipping topics the user has muted.

    Returns immediately when the user's pointer is falsy.
    '''
    pointer = user_profile.pointer

    if not pointer:
        return

    recipient_ids: List[int] = []

    def find_non_muted_recipients() -> None:
        query = SQL('''
            SELECT
                zerver_subscription.recipient_id
            FROM
                zerver_subscription
            INNER JOIN zerver_recipient ON (
                zerver_recipient.id = zerver_subscription.recipient_id
            )
            WHERE (
                zerver_subscription.user_profile_id = %(user_profile_id)s AND
                zerver_recipient.type = 2 AND
                (NOT zerver_subscription.is_muted) AND
                zerver_subscription.active
            )
        ''')

        cursor.execute(query, {"user_profile_id": user_profile.id})
        rows = cursor.fetchall()
        recipient_ids.extend(row[0] for row in rows)
        # Lazy %-args: the list is only stringified if INFO is enabled.
        logger.info('%s', recipient_ids)

    get_timing(
        'find_non_muted_recipients',
        find_non_muted_recipients,
    )

    if not recipient_ids:
        return

    user_message_ids: List[int] = []

    def find_old_ids() -> None:
        is_topic_muted = build_topic_mute_checker(cursor, user_profile)

        query = SQL('''
            SELECT
                zerver_usermessage.id,
                zerver_message.recipient_id,
                zerver_message.subject
            FROM
                zerver_usermessage
            INNER JOIN zerver_message ON (
                zerver_message.id = zerver_usermessage.message_id
            )
            WHERE (
                zerver_usermessage.user_profile_id = %(user_profile_id)s AND
                zerver_usermessage.message_id <= %(pointer)s AND
                (zerver_usermessage.flags & 1) = 0 AND
                zerver_message.recipient_id in %(recipient_ids)s
            )
        ''')

        cursor.execute(query, {
            "user_profile_id": user_profile.id,
            "pointer": pointer,
            "recipient_ids": tuple(recipient_ids),
        })
        rows = cursor.fetchall()
        # Topic mutes are filtered in Python rather than in SQL, using
        # the case-insensitive checker built above.
        user_message_ids.extend(
            um_id
            for (um_id, recipient_id, topic) in rows
            if not is_topic_muted(recipient_id, topic)
        )
        logger.info('rows found: %d', len(user_message_ids))

    get_timing(
        'finding pre-pointer messages that are not muted',
        find_old_ids,
    )

    if not user_message_ids:
        return

    # Named apply_fix rather than fix so it does not shadow the
    # module-level fix() entry point.
    def apply_fix() -> None:
        update_unread_flags(cursor, user_message_ids)

    get_timing(
        'fixing unread messages for pre-pointer non-muted messages',
        apply_fix,
    )


def fix(user_profile: UserProfile) -> None:
    '''
    Run both repairs (fix_unsubscribed, then fix_pre_pointer) for one
    user, sharing a single cursor from Django's default connection.
    '''
    logger.info('\n---\nFixing %s:', user_profile.id)
    with connection.cursor() as cursor:
        fix_unsubscribed(cursor, user_profile)
        fix_pre_pointer(cursor, user_profile)