From 27b06187047d23ee13faf879ff732f8a3301c646 Mon Sep 17 00:00:00 2001 From: Anders Kaseorg Date: Wed, 17 Jul 2024 13:45:14 -0700 Subject: [PATCH] data_import: Fix IdMapper typing. Signed-off-by: Anders Kaseorg --- zerver/data_import/mattermost.py | 32 ++++++++++++------------ zerver/data_import/rocketchat.py | 30 +++++++++++----------- zerver/data_import/sequencer.py | 12 +++++---- zerver/tests/test_mattermost_importer.py | 20 +++++++-------- zerver/tests/test_rocketchat_importer.py | 20 +++++++-------- 5 files changed, 58 insertions(+), 56 deletions(-) diff --git a/zerver/data_import/mattermost.py b/zerver/data_import/mattermost.py index 7cd5d34c3b..9f3e6a5fbf 100644 --- a/zerver/data_import/mattermost.py +++ b/zerver/data_import/mattermost.py @@ -62,7 +62,7 @@ def make_realm(realm_id: int, team: dict[str, Any]) -> ZerverFieldsT: def process_user( - user_dict: dict[str, Any], realm_id: int, team_name: str, user_id_mapper: IdMapper + user_dict: dict[str, Any], realm_id: int, team_name: str, user_id_mapper: IdMapper[str] ) -> ZerverFieldsT: def is_team_admin(user_dict: dict[str, Any]) -> bool: if user_dict["teams"] is None: @@ -127,7 +127,7 @@ def process_user( def convert_user_data( user_handler: UserHandler, - user_id_mapper: IdMapper, + user_id_mapper: IdMapper[str], user_data_map: dict[str, dict[str, Any]], realm_id: int, team_name: str, @@ -147,8 +147,8 @@ def convert_channel_data( channel_data: list[ZerverFieldsT], user_data_map: dict[str, dict[str, Any]], subscriber_handler: SubscriberHandler, - stream_id_mapper: IdMapper, - user_id_mapper: IdMapper, + stream_id_mapper: IdMapper[str], + user_id_mapper: IdMapper[str], realm_id: int, team_name: str, ) -> list[ZerverFieldsT]: @@ -245,8 +245,8 @@ def convert_direct_message_group_data( direct_message_group_data: list[ZerverFieldsT], user_data_map: dict[str, dict[str, Any]], subscriber_handler: SubscriberHandler, - huddle_id_mapper: IdMapper, - user_id_mapper: IdMapper, + huddle_id_mapper: IdMapper[str], + user_id_mapper: IdMapper[str], realm_id: int, team_name: str, ) -> list[ZerverFieldsT]: @@ -274,7 +274,7 @@ def build_reactions( total_reactions: list[ZerverFieldsT], reactions: list[ZerverFieldsT], message_id: int, - user_id_mapper: IdMapper, + user_id_mapper: IdMapper[str], zerver_realmemoji: list[ZerverFieldsT], ) -> None: realmemoji = {} @@ -314,7 +314,7 @@ def build_reactions( total_reactions.append(reaction_dict) -def get_mentioned_user_ids(raw_message: dict[str, Any], user_id_mapper: IdMapper) -> set[int]: +def get_mentioned_user_ids(raw_message: dict[str, Any], user_id_mapper: IdMapper[str]) -> set[int]: user_ids = set() content = raw_message["content"] @@ -406,7 +406,7 @@ def process_raw_message_batch( realm_id: int, raw_messages: list[dict[str, Any]], subscriber_map: dict[int, set[int]], - user_id_mapper: IdMapper, + user_id_mapper: IdMapper[str], user_handler: UserHandler, get_recipient_id_from_receiver_name: Callable[[str, int], int], is_pm_data: bool, @@ -549,7 +549,7 @@ def process_posts( output_dir: str, is_pm_data: bool, masking_content: bool, - user_id_mapper: IdMapper, + user_id_mapper: IdMapper[str], user_handler: UserHandler, zerver_realmemoji: list[dict[str, Any]], total_reactions: list[dict[str, Any]], @@ -658,9 +658,9 @@ def write_message_data( subscriber_map: dict[int, set[int]], output_dir: str, masking_content: bool, - stream_id_mapper: IdMapper, - huddle_id_mapper: IdMapper, - user_id_mapper: IdMapper, + stream_id_mapper: IdMapper[str], + huddle_id_mapper: IdMapper[str], + user_id_mapper: IdMapper[str], user_handler: UserHandler, zerver_realmemoji: list[dict[str, Any]], total_reactions: list[dict[str, Any]], @@ -894,9 +894,9 @@ def do_convert_data(mattermost_data_dir: str, output_dir: str, masking_content: user_handler = UserHandler() subscriber_handler = SubscriberHandler() - user_id_mapper = IdMapper() - stream_id_mapper = IdMapper() - huddle_id_mapper = IdMapper() + user_id_mapper = IdMapper[str]() + stream_id_mapper = IdMapper[str]() + huddle_id_mapper = IdMapper[str]() print("Generating data for", team_name) realm = make_realm(realm_id, team) diff --git a/zerver/data_import/rocketchat.py b/zerver/data_import/rocketchat.py index 3c07fcc164..6e31611145 100644 --- a/zerver/data_import/rocketchat.py +++ b/zerver/data_import/rocketchat.py @@ -57,7 +57,7 @@ def process_users( realm_id: int, domain_name: str, user_handler: UserHandler, - user_id_mapper: IdMapper, + user_id_mapper: IdMapper[str], ) -> None: realm_owners: list[int] = [] bots: list[int] = [] @@ -158,7 +158,7 @@ def get_stream_name(rc_channel: dict[str, Any]) -> str: def convert_channel_data( room_id_to_room_map: dict[str, dict[str, Any]], team_id_to_team_map: dict[str, dict[str, Any]], - stream_id_mapper: IdMapper, + stream_id_mapper: IdMapper[str], realm_id: int, ) -> list[ZerverFieldsT]: streams = [] @@ -205,8 +205,8 @@ def convert_stream_subscription_data( user_id_to_user_map: dict[str, dict[str, Any]], dsc_id_to_dsc_map: dict[str, dict[str, Any]], zerver_stream: list[ZerverFieldsT], - stream_id_mapper: IdMapper, - user_id_mapper: IdMapper, + stream_id_mapper: IdMapper[str], + user_id_mapper: IdMapper[str], subscriber_handler: SubscriberHandler, ) -> None: stream_members_map: dict[int, set[int]] = {} @@ -240,8 +240,8 @@ def convert_stream_subscription_data( def convert_direct_message_group_data( huddle_id_to_huddle_map: dict[str, dict[str, Any]], - huddle_id_mapper: IdMapper, - user_id_mapper: IdMapper, + huddle_id_mapper: IdMapper[str], + user_id_mapper: IdMapper[str], subscriber_handler: SubscriberHandler, ) -> list[ZerverFieldsT]: zerver_direct_message_group: list[ZerverFieldsT] = [] @@ -582,7 +582,7 @@ def process_raw_message_batch( def get_topic_name( message: dict[str, Any], dsc_id_to_dsc_map: dict[str, dict[str, Any]], - thread_id_mapper: IdMapper, + thread_id_mapper: IdMapper[str], is_pm_data: bool = False, ) -> str: if is_pm_data: @@ -609,14 +609,14 @@ def process_messages( subscriber_map: dict[int, set[int]], is_pm_data: bool, username_to_user_id_map: dict[str, str], - user_id_mapper: IdMapper, + user_id_mapper: IdMapper[str], user_handler: UserHandler, user_id_to_recipient_id: dict[int, int], - stream_id_mapper: IdMapper, + stream_id_mapper: IdMapper[str], stream_id_to_recipient_id: dict[int, int], - huddle_id_mapper: IdMapper, + huddle_id_mapper: IdMapper[str], huddle_id_to_recipient_id: dict[int, int], - thread_id_mapper: IdMapper, + thread_id_mapper: IdMapper[str], room_id_to_room_map: dict[str, dict[str, Any]], dsc_id_to_dsc_map: dict[str, dict[str, Any]], direct_id_to_direct_map: dict[str, dict[str, Any]], @@ -1074,10 +1074,10 @@ def do_convert_data(rocketchat_data_dir: str, output_dir: str) -> None: user_handler = UserHandler() subscriber_handler = SubscriberHandler() - user_id_mapper = IdMapper() - stream_id_mapper = IdMapper() - huddle_id_mapper = IdMapper() - thread_id_mapper = IdMapper() + user_id_mapper = IdMapper[str]() + stream_id_mapper = IdMapper[str]() + huddle_id_mapper = IdMapper[str]() + thread_id_mapper = IdMapper[str]() process_users( user_id_to_user_map=user_id_to_user_map, diff --git a/zerver/data_import/sequencer.py b/zerver/data_import/sequencer.py index c009d1563e..d6ed379e7d 100644 --- a/zerver/data_import/sequencer.py +++ b/zerver/data_import/sequencer.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from typing import Any +from typing import Generic, TypeVar """ This module helps you set up a bunch @@ -12,6 +12,8 @@ for data imports that's usually easy to manage. """ +T = TypeVar("T") + def _seq() -> Callable[[], int]: i = 0 @@ -54,15 +56,15 @@ import of the file. NEXT_ID = sequencer() -class IdMapper: +class IdMapper(Generic[T]): def __init__(self) -> None: - self.map: dict[Any, int] = {} + self.map: dict[T, int] = {} self.cnt = 0 - def has(self, their_id: Any) -> bool: + def has(self, their_id: T) -> bool: return their_id in self.map - def get(self, their_id: Any) -> int: + def get(self, their_id: T) -> int: if their_id in self.map: return self.map[their_id] diff --git a/zerver/tests/test_mattermost_importer.py b/zerver/tests/test_mattermost_importer.py index b6edbb0a15..2402717444 100644 --- a/zerver/tests/test_mattermost_importer.py +++ b/zerver/tests/test_mattermost_importer.py @@ -84,7 +84,7 @@ class MatterMostImporter(ZulipTestCase): ) def test_process_user(self) -> None: - user_id_mapper = IdMapper() + user_id_mapper = IdMapper[str]() fixture_file_name = self.fixture_file_name("export.json", "mattermost_fixtures") mattermost_data = mattermost_data_file_to_dict(fixture_file_name) username_to_user = create_username_to_user_mapping(mattermost_data["user"]) @@ -131,7 +131,7 @@ class MatterMostImporter(ZulipTestCase): self.assertEqual(user["timezone"], "UTC") def test_process_guest_user(self) -> None: - user_id_mapper = IdMapper() + user_id_mapper = IdMapper[str]() fixture_file_name = self.fixture_file_name("guestExport.json", "mattermost_fixtures") mattermost_data = mattermost_data_file_to_dict(fixture_file_name) username_to_user = create_username_to_user_mapping(mattermost_data["user"]) @@ -162,7 +162,7 @@ class MatterMostImporter(ZulipTestCase): self.assertEqual(user["role"], UserProfile.ROLE_MEMBER) def test_convert_user_data(self) -> None: - user_id_mapper = IdMapper() + user_id_mapper = IdMapper[str]() realm_id = 3 fixture_file_name = self.fixture_file_name("export.json", "mattermost_fixtures") mattermost_data = mattermost_data_file_to_dict(fixture_file_name) @@ -211,8 +211,8 @@ class MatterMostImporter(ZulipTestCase): user_handler = UserHandler() subscriber_handler = SubscriberHandler() - stream_id_mapper = IdMapper() - user_id_mapper = IdMapper() + stream_id_mapper = IdMapper[str]() + user_id_mapper = IdMapper[str]() team_name = "gryffindor" convert_user_data( @@ -346,8 +346,8 @@ class MatterMostImporter(ZulipTestCase): user_handler = UserHandler() subscriber_handler = SubscriberHandler() - huddle_id_mapper = IdMapper() - user_id_mapper = IdMapper() + huddle_id_mapper = IdMapper[str]() + user_id_mapper = IdMapper[str]() team_name = "gryffindor" convert_user_data( @@ -432,7 +432,7 @@ class MatterMostImporter(ZulipTestCase): reset_mirror_dummy_users(username_to_user) user_handler = UserHandler() - user_id_mapper = IdMapper() + user_id_mapper = IdMapper[str]() team_name = "gryffindor" convert_user_data( @@ -479,7 +479,7 @@ class MatterMostImporter(ZulipTestCase): self.assertTrue(filecmp.cmp(attachment_path, attachment_out_path)) def test_get_mentioned_user_ids(self) -> None: - user_id_mapper = IdMapper() + user_id_mapper = IdMapper[str]() harry_id = user_id_mapper.get("harry") raw_message = { @@ -615,7 +615,7 @@ class MatterMostImporter(ZulipTestCase): self.assertEqual(zerver_realmemoji[1]["name"], "tick") tick_emoji_code = zerver_realmemoji[1]["id"] - user_id_mapper = IdMapper() + user_id_mapper = IdMapper[str]() harry_id = user_id_mapper.get("harry") ron_id = user_id_mapper.get("ron") diff --git a/zerver/tests/test_rocketchat_importer.py b/zerver/tests/test_rocketchat_importer.py index f9bc59bb8b..1ab6ecf956 100644 --- a/zerver/tests/test_rocketchat_importer.py +++ b/zerver/tests/test_rocketchat_importer.py @@ -96,7 +96,7 @@ class RocketChatImporter(ZulipTestCase): domain_name = "zulip.com" user_handler = UserHandler() - user_id_mapper = IdMapper() + user_id_mapper = IdMapper[str]() process_users( user_id_to_user_map=user_id_to_user_map, @@ -242,7 +242,7 @@ class RocketChatImporter(ZulipTestCase): rocketchat_data = rocketchat_data_to_dict(fixture_dir_name) realm_id = 3 - stream_id_mapper = IdMapper() + stream_id_mapper = IdMapper[str]() room_id_to_room_map: dict[str, dict[str, Any]] = {} team_id_to_team_map: dict[str, dict[str, Any]] = {} @@ -316,8 +316,8 @@ class RocketChatImporter(ZulipTestCase): user_handler = UserHandler() subscriber_handler = SubscriberHandler() - user_id_mapper = IdMapper() - stream_id_mapper = IdMapper() + user_id_mapper = IdMapper[str]() + stream_id_mapper = IdMapper[str]() user_id_to_user_map = map_user_id_to_user(rocketchat_data["user"]) @@ -423,8 +423,8 @@ class RocketChatImporter(ZulipTestCase): user_handler = UserHandler() subscriber_handler = SubscriberHandler() - user_id_mapper = IdMapper() - huddle_id_mapper = IdMapper() + user_id_mapper = IdMapper[str]() + huddle_id_mapper = IdMapper[str]() user_id_to_user_map = map_user_id_to_user(rocketchat_data["user"]) @@ -525,9 +525,9 @@ class RocketChatImporter(ZulipTestCase): user_handler = UserHandler() subscriber_handler = SubscriberHandler() - user_id_mapper = IdMapper() - stream_id_mapper = IdMapper() - huddle_id_mapper = IdMapper() + user_id_mapper = IdMapper[str]() + stream_id_mapper = IdMapper[str]() + huddle_id_mapper = IdMapper[str]() user_id_to_user_map = map_user_id_to_user(rocketchat_data["user"]) @@ -823,7 +823,7 @@ class RocketChatImporter(ZulipTestCase): domain_name = "zulip.com" user_handler = UserHandler() - user_id_mapper = IdMapper() + user_id_mapper = IdMapper[str]() process_users( user_id_to_user_map=user_id_to_user_map,