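data_import: Fix IdMapper typing.

Make IdMapper generic over the source ID type (Generic[T]) instead of
keying its map on Any, and annotate every use in the Mattermost and
Rocket.Chat importers and their tests as IdMapper[str].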

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2024-07-17 13:45:14 -07:00 committed by Tim Abbott
parent 1fd3f983a5
commit 27b0618704
5 changed files with 58 additions and 56 deletions

View File

@ -62,7 +62,7 @@ def make_realm(realm_id: int, team: dict[str, Any]) -> ZerverFieldsT:
def process_user(
user_dict: dict[str, Any], realm_id: int, team_name: str, user_id_mapper: IdMapper
user_dict: dict[str, Any], realm_id: int, team_name: str, user_id_mapper: IdMapper[str]
) -> ZerverFieldsT:
def is_team_admin(user_dict: dict[str, Any]) -> bool:
if user_dict["teams"] is None:
@ -127,7 +127,7 @@ def process_user(
def convert_user_data(
user_handler: UserHandler,
user_id_mapper: IdMapper,
user_id_mapper: IdMapper[str],
user_data_map: dict[str, dict[str, Any]],
realm_id: int,
team_name: str,
@ -147,8 +147,8 @@ def convert_channel_data(
channel_data: list[ZerverFieldsT],
user_data_map: dict[str, dict[str, Any]],
subscriber_handler: SubscriberHandler,
stream_id_mapper: IdMapper,
user_id_mapper: IdMapper,
stream_id_mapper: IdMapper[str],
user_id_mapper: IdMapper[str],
realm_id: int,
team_name: str,
) -> list[ZerverFieldsT]:
@ -245,8 +245,8 @@ def convert_direct_message_group_data(
direct_message_group_data: list[ZerverFieldsT],
user_data_map: dict[str, dict[str, Any]],
subscriber_handler: SubscriberHandler,
huddle_id_mapper: IdMapper,
user_id_mapper: IdMapper,
huddle_id_mapper: IdMapper[str],
user_id_mapper: IdMapper[str],
realm_id: int,
team_name: str,
) -> list[ZerverFieldsT]:
@ -274,7 +274,7 @@ def build_reactions(
total_reactions: list[ZerverFieldsT],
reactions: list[ZerverFieldsT],
message_id: int,
user_id_mapper: IdMapper,
user_id_mapper: IdMapper[str],
zerver_realmemoji: list[ZerverFieldsT],
) -> None:
realmemoji = {}
@ -314,7 +314,7 @@ def build_reactions(
total_reactions.append(reaction_dict)
def get_mentioned_user_ids(raw_message: dict[str, Any], user_id_mapper: IdMapper) -> set[int]:
def get_mentioned_user_ids(raw_message: dict[str, Any], user_id_mapper: IdMapper[str]) -> set[int]:
user_ids = set()
content = raw_message["content"]
@ -406,7 +406,7 @@ def process_raw_message_batch(
realm_id: int,
raw_messages: list[dict[str, Any]],
subscriber_map: dict[int, set[int]],
user_id_mapper: IdMapper,
user_id_mapper: IdMapper[str],
user_handler: UserHandler,
get_recipient_id_from_receiver_name: Callable[[str, int], int],
is_pm_data: bool,
@ -549,7 +549,7 @@ def process_posts(
output_dir: str,
is_pm_data: bool,
masking_content: bool,
user_id_mapper: IdMapper,
user_id_mapper: IdMapper[str],
user_handler: UserHandler,
zerver_realmemoji: list[dict[str, Any]],
total_reactions: list[dict[str, Any]],
@ -658,9 +658,9 @@ def write_message_data(
subscriber_map: dict[int, set[int]],
output_dir: str,
masking_content: bool,
stream_id_mapper: IdMapper,
huddle_id_mapper: IdMapper,
user_id_mapper: IdMapper,
stream_id_mapper: IdMapper[str],
huddle_id_mapper: IdMapper[str],
user_id_mapper: IdMapper[str],
user_handler: UserHandler,
zerver_realmemoji: list[dict[str, Any]],
total_reactions: list[dict[str, Any]],
@ -894,9 +894,9 @@ def do_convert_data(mattermost_data_dir: str, output_dir: str, masking_content:
user_handler = UserHandler()
subscriber_handler = SubscriberHandler()
user_id_mapper = IdMapper()
stream_id_mapper = IdMapper()
huddle_id_mapper = IdMapper()
user_id_mapper = IdMapper[str]()
stream_id_mapper = IdMapper[str]()
huddle_id_mapper = IdMapper[str]()
print("Generating data for", team_name)
realm = make_realm(realm_id, team)

View File

@ -57,7 +57,7 @@ def process_users(
realm_id: int,
domain_name: str,
user_handler: UserHandler,
user_id_mapper: IdMapper,
user_id_mapper: IdMapper[str],
) -> None:
realm_owners: list[int] = []
bots: list[int] = []
@ -158,7 +158,7 @@ def get_stream_name(rc_channel: dict[str, Any]) -> str:
def convert_channel_data(
room_id_to_room_map: dict[str, dict[str, Any]],
team_id_to_team_map: dict[str, dict[str, Any]],
stream_id_mapper: IdMapper,
stream_id_mapper: IdMapper[str],
realm_id: int,
) -> list[ZerverFieldsT]:
streams = []
@ -205,8 +205,8 @@ def convert_stream_subscription_data(
user_id_to_user_map: dict[str, dict[str, Any]],
dsc_id_to_dsc_map: dict[str, dict[str, Any]],
zerver_stream: list[ZerverFieldsT],
stream_id_mapper: IdMapper,
user_id_mapper: IdMapper,
stream_id_mapper: IdMapper[str],
user_id_mapper: IdMapper[str],
subscriber_handler: SubscriberHandler,
) -> None:
stream_members_map: dict[int, set[int]] = {}
@ -240,8 +240,8 @@ def convert_stream_subscription_data(
def convert_direct_message_group_data(
huddle_id_to_huddle_map: dict[str, dict[str, Any]],
huddle_id_mapper: IdMapper,
user_id_mapper: IdMapper,
huddle_id_mapper: IdMapper[str],
user_id_mapper: IdMapper[str],
subscriber_handler: SubscriberHandler,
) -> list[ZerverFieldsT]:
zerver_direct_message_group: list[ZerverFieldsT] = []
@ -582,7 +582,7 @@ def process_raw_message_batch(
def get_topic_name(
message: dict[str, Any],
dsc_id_to_dsc_map: dict[str, dict[str, Any]],
thread_id_mapper: IdMapper,
thread_id_mapper: IdMapper[str],
is_pm_data: bool = False,
) -> str:
if is_pm_data:
@ -609,14 +609,14 @@ def process_messages(
subscriber_map: dict[int, set[int]],
is_pm_data: bool,
username_to_user_id_map: dict[str, str],
user_id_mapper: IdMapper,
user_id_mapper: IdMapper[str],
user_handler: UserHandler,
user_id_to_recipient_id: dict[int, int],
stream_id_mapper: IdMapper,
stream_id_mapper: IdMapper[str],
stream_id_to_recipient_id: dict[int, int],
huddle_id_mapper: IdMapper,
huddle_id_mapper: IdMapper[str],
huddle_id_to_recipient_id: dict[int, int],
thread_id_mapper: IdMapper,
thread_id_mapper: IdMapper[str],
room_id_to_room_map: dict[str, dict[str, Any]],
dsc_id_to_dsc_map: dict[str, dict[str, Any]],
direct_id_to_direct_map: dict[str, dict[str, Any]],
@ -1074,10 +1074,10 @@ def do_convert_data(rocketchat_data_dir: str, output_dir: str) -> None:
user_handler = UserHandler()
subscriber_handler = SubscriberHandler()
user_id_mapper = IdMapper()
stream_id_mapper = IdMapper()
huddle_id_mapper = IdMapper()
thread_id_mapper = IdMapper()
user_id_mapper = IdMapper[str]()
stream_id_mapper = IdMapper[str]()
huddle_id_mapper = IdMapper[str]()
thread_id_mapper = IdMapper[str]()
process_users(
user_id_to_user_map=user_id_to_user_map,

View File

@ -1,5 +1,5 @@
from collections.abc import Callable
from typing import Any
from typing import Generic, TypeVar
"""
This module helps you set up a bunch
@ -12,6 +12,8 @@ for data imports that's usually easy to
manage.
"""
T = TypeVar("T")
def _seq() -> Callable[[], int]:
i = 0
@ -54,15 +56,15 @@ import of the file.
NEXT_ID = sequencer()
class IdMapper:
class IdMapper(Generic[T]):
def __init__(self) -> None:
self.map: dict[Any, int] = {}
self.map: dict[T, int] = {}
self.cnt = 0
def has(self, their_id: Any) -> bool:
def has(self, their_id: T) -> bool:
return their_id in self.map
def get(self, their_id: Any) -> int:
def get(self, their_id: T) -> int:
if their_id in self.map:
return self.map[their_id]

View File

@ -84,7 +84,7 @@ class MatterMostImporter(ZulipTestCase):
)
def test_process_user(self) -> None:
user_id_mapper = IdMapper()
user_id_mapper = IdMapper[str]()
fixture_file_name = self.fixture_file_name("export.json", "mattermost_fixtures")
mattermost_data = mattermost_data_file_to_dict(fixture_file_name)
username_to_user = create_username_to_user_mapping(mattermost_data["user"])
@ -131,7 +131,7 @@ class MatterMostImporter(ZulipTestCase):
self.assertEqual(user["timezone"], "UTC")
def test_process_guest_user(self) -> None:
user_id_mapper = IdMapper()
user_id_mapper = IdMapper[str]()
fixture_file_name = self.fixture_file_name("guestExport.json", "mattermost_fixtures")
mattermost_data = mattermost_data_file_to_dict(fixture_file_name)
username_to_user = create_username_to_user_mapping(mattermost_data["user"])
@ -162,7 +162,7 @@ class MatterMostImporter(ZulipTestCase):
self.assertEqual(user["role"], UserProfile.ROLE_MEMBER)
def test_convert_user_data(self) -> None:
user_id_mapper = IdMapper()
user_id_mapper = IdMapper[str]()
realm_id = 3
fixture_file_name = self.fixture_file_name("export.json", "mattermost_fixtures")
mattermost_data = mattermost_data_file_to_dict(fixture_file_name)
@ -211,8 +211,8 @@ class MatterMostImporter(ZulipTestCase):
user_handler = UserHandler()
subscriber_handler = SubscriberHandler()
stream_id_mapper = IdMapper()
user_id_mapper = IdMapper()
stream_id_mapper = IdMapper[str]()
user_id_mapper = IdMapper[str]()
team_name = "gryffindor"
convert_user_data(
@ -346,8 +346,8 @@ class MatterMostImporter(ZulipTestCase):
user_handler = UserHandler()
subscriber_handler = SubscriberHandler()
huddle_id_mapper = IdMapper()
user_id_mapper = IdMapper()
huddle_id_mapper = IdMapper[str]()
user_id_mapper = IdMapper[str]()
team_name = "gryffindor"
convert_user_data(
@ -432,7 +432,7 @@ class MatterMostImporter(ZulipTestCase):
reset_mirror_dummy_users(username_to_user)
user_handler = UserHandler()
user_id_mapper = IdMapper()
user_id_mapper = IdMapper[str]()
team_name = "gryffindor"
convert_user_data(
@ -479,7 +479,7 @@ class MatterMostImporter(ZulipTestCase):
self.assertTrue(filecmp.cmp(attachment_path, attachment_out_path))
def test_get_mentioned_user_ids(self) -> None:
user_id_mapper = IdMapper()
user_id_mapper = IdMapper[str]()
harry_id = user_id_mapper.get("harry")
raw_message = {
@ -615,7 +615,7 @@ class MatterMostImporter(ZulipTestCase):
self.assertEqual(zerver_realmemoji[1]["name"], "tick")
tick_emoji_code = zerver_realmemoji[1]["id"]
user_id_mapper = IdMapper()
user_id_mapper = IdMapper[str]()
harry_id = user_id_mapper.get("harry")
ron_id = user_id_mapper.get("ron")

View File

@ -96,7 +96,7 @@ class RocketChatImporter(ZulipTestCase):
domain_name = "zulip.com"
user_handler = UserHandler()
user_id_mapper = IdMapper()
user_id_mapper = IdMapper[str]()
process_users(
user_id_to_user_map=user_id_to_user_map,
@ -242,7 +242,7 @@ class RocketChatImporter(ZulipTestCase):
rocketchat_data = rocketchat_data_to_dict(fixture_dir_name)
realm_id = 3
stream_id_mapper = IdMapper()
stream_id_mapper = IdMapper[str]()
room_id_to_room_map: dict[str, dict[str, Any]] = {}
team_id_to_team_map: dict[str, dict[str, Any]] = {}
@ -316,8 +316,8 @@ class RocketChatImporter(ZulipTestCase):
user_handler = UserHandler()
subscriber_handler = SubscriberHandler()
user_id_mapper = IdMapper()
stream_id_mapper = IdMapper()
user_id_mapper = IdMapper[str]()
stream_id_mapper = IdMapper[str]()
user_id_to_user_map = map_user_id_to_user(rocketchat_data["user"])
@ -423,8 +423,8 @@ class RocketChatImporter(ZulipTestCase):
user_handler = UserHandler()
subscriber_handler = SubscriberHandler()
user_id_mapper = IdMapper()
huddle_id_mapper = IdMapper()
user_id_mapper = IdMapper[str]()
huddle_id_mapper = IdMapper[str]()
user_id_to_user_map = map_user_id_to_user(rocketchat_data["user"])
@ -525,9 +525,9 @@ class RocketChatImporter(ZulipTestCase):
user_handler = UserHandler()
subscriber_handler = SubscriberHandler()
user_id_mapper = IdMapper()
stream_id_mapper = IdMapper()
huddle_id_mapper = IdMapper()
user_id_mapper = IdMapper[str]()
stream_id_mapper = IdMapper[str]()
huddle_id_mapper = IdMapper[str]()
user_id_to_user_map = map_user_id_to_user(rocketchat_data["user"])
@ -823,7 +823,7 @@ class RocketChatImporter(ZulipTestCase):
domain_name = "zulip.com"
user_handler = UserHandler()
user_id_mapper = IdMapper()
user_id_mapper = IdMapper[str]()
process_users(
user_id_to_user_map=user_id_to_user_map,