From 71427239d00b96ed8cc4ed061e81e5faa681c435 Mon Sep 17 00:00:00 2001 From: PIG208 <359101898@qq.com> Date: Thu, 19 Aug 2021 01:40:01 +0800 Subject: [PATCH] typing: Replace CursorObj by CursorWrapper. --- zerver/lib/fix_unreads.py | 15 ++++++--------- zerver/lib/migrate.py | 8 +++----- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/zerver/lib/fix_unreads.py b/zerver/lib/fix_unreads.py index 247a766670..a29c462bb8 100644 --- a/zerver/lib/fix_unreads.py +++ b/zerver/lib/fix_unreads.py @@ -1,13 +1,10 @@ import logging import time -from typing import Callable, List, TypeVar - -from psycopg2.extensions import cursor -from psycopg2.sql import SQL - -CursorObj = TypeVar("CursorObj", bound=cursor) +from typing import Callable, List from django.db import connection +from django.db.backends.utils import CursorWrapper +from psycopg2.sql import SQL from zerver.models import UserProfile @@ -23,7 +20,7 @@ logger.setLevel(logging.WARNING) def build_topic_mute_checker( - cursor: CursorObj, user_profile: UserProfile + cursor: CursorWrapper, user_profile: UserProfile ) -> Callable[[int, str], bool]: """ This function is similar to the function of the same name @@ -52,7 +49,7 @@ def build_topic_mute_checker( return is_muted -def update_unread_flags(cursor: CursorObj, user_message_ids: List[int]) -> None: +def update_unread_flags(cursor: CursorWrapper, user_message_ids: List[int]) -> None: query = SQL( """ UPDATE zerver_usermessage @@ -72,7 +69,7 @@ def get_timing(message: str, f: Callable[[], None]) -> None: logger.info("elapsed time: %.03f\n", elapsed) -def fix_unsubscribed(cursor: CursorObj, user_profile: UserProfile) -> None: +def fix_unsubscribed(cursor: CursorWrapper, user_profile: UserProfile) -> None: recipient_ids = [] diff --git a/zerver/lib/migrate.py b/zerver/lib/migrate.py index 1189c16bd5..ab1b7b3d24 100644 --- a/zerver/lib/migrate.py +++ b/zerver/lib/migrate.py @@ -1,14 +1,12 @@ import time -from typing import List, TypeVar +from typing import List -from psycopg2.extensions import cursor +from django.db.backends.utils import CursorWrapper from psycopg2.sql import SQL, Composable, Identifier -CursorObj = TypeVar("CursorObj", bound=cursor) - def do_batch_update( - cursor: CursorObj, + cursor: CursorWrapper, table: str, assignments: List[Composable], batch_size: int = 10000,