diff --git a/zerver/lib/retention.py b/zerver/lib/retention.py index 1c3b4f1f1c..5ea02d5848 100644 --- a/zerver/lib/retention.py +++ b/zerver/lib/retention.py @@ -329,3 +329,96 @@ def move_messages_to_archive(message_ids: List[int], chunk_size: int=MESSAGE_BAT # Clean up attachments: archived_attachments = ArchivedAttachment.objects.filter(messages__id__in=message_ids).distinct() Attachment.objects.filter(messages__isnull=True, id__in=archived_attachments).delete() + +def restore_messages_from_archive(archive_transaction_id: int) -> List[int]: + query = """ + INSERT INTO zerver_message ({dst_fields}) + SELECT {src_fields} + FROM zerver_archivedmessage + LEFT JOIN zerver_message ON zerver_archivedmessage.id = zerver_message.id + WHERE zerver_archivedmessage.archive_transaction_id = {archive_transaction_id} + AND zerver_message.id is NULL + RETURNING id + """ + return move_rows(Message, query, src_db_table='zerver_archivedmessage', returning_id=True, + archive_transaction_id=archive_transaction_id) + +def restore_models_with_message_key_from_archive(archive_transaction_id: int) -> None: + for model in models_with_message_key: + query = """ + INSERT INTO {table_name} ({dst_fields}) + SELECT {src_fields} + FROM {archive_table_name} + INNER JOIN zerver_archivedmessage ON {archive_table_name}.message_id = zerver_archivedmessage.id + LEFT JOIN {table_name} ON {archive_table_name}.id = {table_name}.id + WHERE zerver_archivedmessage.archive_transaction_id = {archive_transaction_id} + AND {table_name}.id IS NULL + """ + + move_rows(model['class'], query, src_db_table=model['archive_table_name'], + table_name=model['table_name'], + archive_transaction_id=archive_transaction_id, + archive_table_name=model['archive_table_name']) + +def restore_attachments_from_archive(archive_transaction_id: int) -> None: + query = """ + INSERT INTO zerver_attachment ({dst_fields}) + SELECT {src_fields} + FROM zerver_archivedattachment + INNER JOIN zerver_archivedattachment_messages + ON zerver_archivedattachment_messages.archivedattachment_id = zerver_archivedattachment.id + INNER JOIN zerver_archivedmessage + ON zerver_archivedattachment_messages.archivedmessage_id = zerver_archivedmessage.id + LEFT JOIN zerver_attachment ON zerver_archivedattachment.id = zerver_attachment.id + WHERE zerver_archivedmessage.archive_transaction_id = {archive_transaction_id} + AND zerver_attachment.id IS NULL + GROUP BY zerver_archivedattachment.id + """ + move_rows(Attachment, query, src_db_table='zerver_archivedattachment', + archive_transaction_id=archive_transaction_id) + +def restore_attachment_messages_from_archive(archive_transaction_id: int) -> None: + query = """ + INSERT INTO zerver_attachment_messages (id, attachment_id, message_id) + SELECT zerver_archivedattachment_messages.id, + zerver_archivedattachment_messages.archivedattachment_id, + zerver_archivedattachment_messages.archivedmessage_id + FROM zerver_archivedattachment_messages + INNER JOIN zerver_archivedmessage + ON zerver_archivedattachment_messages.archivedmessage_id = zerver_archivedmessage.id + LEFT JOIN zerver_attachment_messages + ON zerver_archivedattachment_messages.id = zerver_attachment_messages.id + WHERE zerver_archivedmessage.archive_transaction_id = {archive_transaction_id} + AND zerver_attachment_messages.id IS NULL + """ + with connection.cursor() as cursor: + cursor.execute(query.format(archive_transaction_id=archive_transaction_id)) + +@transaction.atomic +def restore_data_from_archive(archive_transaction: ArchiveTransaction) -> None: + restore_messages_from_archive(archive_transaction.id) + restore_models_with_message_key_from_archive(archive_transaction.id) + restore_attachments_from_archive(archive_transaction.id) + restore_attachment_messages_from_archive(archive_transaction.id) + + archive_transaction.restored = True + archive_transaction.save() + +def restore_data_from_archive_by_transactions(archive_transactions: List[ArchiveTransaction]) -> None: + # Looping over the list of ids means we're batching the restoration process by the size of the + # transactions: + for archive_transaction in archive_transactions: + restore_data_from_archive(archive_transaction) + +def restore_data_from_archive_by_realm(realm: Realm) -> None: + transactions = ArchiveTransaction.objects.exclude(restored=True).filter(realm=realm) + restore_data_from_archive_by_transactions(transactions) + +def restore_all_data_from_archive(restore_manual_transactions: bool=True) -> None: + for realm in Realm.objects.all(): + restore_data_from_archive_by_realm(realm) + + if restore_manual_transactions: + restore_data_from_archive_by_transactions( + ArchiveTransaction.objects.exclude(restored=True).filter(type=ArchiveTransaction.MANUAL) + ) diff --git a/zerver/tests/test_retention.py b/zerver/tests/test_retention.py index 0f24d6b556..67c401a01e 100644 --- a/zerver/tests/test_retention.py +++ b/zerver/tests/test_retention.py @@ -15,6 +15,7 @@ from zerver.models import (Message, Realm, UserProfile, Stream, ArchivedUserMess from zerver.lib.retention import ( archive_messages, move_messages_to_archive, + restore_all_data_from_archive ) # Class with helper functions useful for testing archiving of reactions: @@ -181,6 +182,9 @@ class TestArchiveMessagesGeneral(ArchiveMessagesTestingBase): archive_messages() self._verify_archive_data(expired_msg_ids, expired_usermsg_ids) + restore_all_data_from_archive() + self._verify_restored_data(expired_msg_ids, expired_usermsg_ids) + def test_expired_messages_in_one_realm(self) -> None: """Test with a retention policy set for only the MIT realm""" self._set_realm_message_retention_value(self.zulip_realm, None) @@ -209,6 +213,9 @@ class TestArchiveMessagesGeneral(ArchiveMessagesTestingBase): archive_messages() self._verify_archive_data(expired_msg_ids, expired_usermsg_ids) + restore_all_data_from_archive() + self._verify_restored_data(expired_msg_ids, expired_usermsg_ids) + self._set_realm_message_retention_value(self.zulip_realm, ZULIP_REALM_DAYS) def test_different_stream_realm_policies(self) -> None: @@ -302,6 +309,9 @@ class TestArchiveMessagesGeneral(ArchiveMessagesTestingBase): # Make sure we archived what neeeded: self._verify_archive_data(expired_msg_ids, expired_usermsg_ids) + restore_all_data_from_archive() + self._verify_restored_data(expired_msg_ids, expired_usermsg_ids) + def test_archiving_attachments(self) -> None: """End-to-end test for the logic for archiving attachments. This test is hard to read without first reading _send_messages_with_attachments""" @@ -342,6 +352,16 @@ class TestArchiveMessagesGeneral(ArchiveMessagesTestingBase): sorted(msgs_ids.values()) ) + restore_all_data_from_archive() + # Attachments should have been restored: + self.assertEqual(Attachment.objects.count(), 3) + self.assertEqual(ArchivedAttachment.objects.count(), 3) # Archived data doesn't get deleted by restoring. + self.assertEqual( + list(Attachment.objects.distinct('messages__id').order_by('messages__id').values_list( + 'messages__id', flat=True)), + sorted(msgs_ids.values()) + ) + class TestArchivingSubMessages(ArchiveMessagesTestingBase): def test_archiving_submessages(self) -> None: expired_msg_ids = self._make_expired_zulip_messages(2) @@ -385,6 +405,12 @@ class TestArchivingSubMessages(ArchiveMessagesTestingBase): set(submessage_ids) ) + restore_all_data_from_archive() + self.assertEqual( + set(SubMessage.objects.filter(id__in=submessage_ids).values_list('id', flat=True)), + set(submessage_ids) + ) + class TestArchivingReactions(ArchiveMessagesTestingBase, EmojiReactionBase): def test_archiving_reactions(self) -> None: expired_msg_ids = self._make_expired_zulip_messages(2) @@ -408,6 +434,12 @@ class TestArchivingReactions(ArchiveMessagesTestingBase, EmojiReactionBase): set(reaction_ids) ) + restore_all_data_from_archive() + self.assertEqual( + set(Reaction.objects.filter(id__in=reaction_ids).values_list('id', flat=True)), + set(reaction_ids) + ) + class MoveMessageToArchiveBase(RetentionTestingBase): def setUp(self) -> None: self.sender = 'hamlet@zulip.com' @@ -441,6 +473,9 @@ class MoveMessageToArchiveGeneral(MoveMessageToArchiveBase): move_messages_to_archive(message_ids=msg_ids) self._verify_archive_data(msg_ids, usermsg_ids) + restore_all_data_from_archive() + self._verify_restored_data(msg_ids, usermsg_ids) + def test_stream_messages_archiving(self) -> None: msg_ids = [self.send_stream_message(self.sender, "Verona") for i in range(0, 3)] @@ -450,6 +485,9 @@ class MoveMessageToArchiveGeneral(MoveMessageToArchiveBase): move_messages_to_archive(message_ids=msg_ids) self._verify_archive_data(msg_ids, usermsg_ids) + restore_all_data_from_archive() + self._verify_restored_data(msg_ids, usermsg_ids) + def test_archiving_messages_second_time(self) -> None: msg_ids = [self.send_stream_message(self.sender, "Verona") for i in range(0, 3)] @@ -512,6 +550,20 @@ class MoveMessageToArchiveGeneral(MoveMessageToArchiveBase): archivedattachment__id=attachment_id).values_list("id", flat=True)) ) + restore_all_data_from_archive() + self._verify_restored_data(msg_ids, usermsg_ids) + + restored_attachment_ids = list( + Attachment.objects.filter(messages__id__in=msg_ids).values_list("id", flat=True) + ) + + self.assertEqual(set(attachment_ids), set(restored_attachment_ids)) + for attachment_id in restored_attachment_ids: + self.assertEqual( + set(attachment_id_to_message_ids[attachment_id]), + set(Message.objects.filter(attachment__id=attachment_id).values_list("id", flat=True)) + ) + def test_archiving_message_with_shared_attachment(self) -> None: # Make sure that attachments still in use in other messages don't get deleted: self._create_attachments() @@ -551,6 +603,13 @@ class MoveMessageToArchiveGeneral(MoveMessageToArchiveBase): # Now the attachment should have been deleted: self.assertEqual(Attachment.objects.count(), 0) + # Restore everything: + restore_all_data_from_archive() + self.assertEqual( + set(Attachment.objects.filter(messages__id=msg_id).values_list("id", flat=True)), + set(attachment_ids) + ) + class MoveMessageToArchiveWithSubMessages(MoveMessageToArchiveBase): def test_archiving_message_with_submessages(self) -> None: msg_id = self.send_stream_message(self.sender, "Verona") @@ -585,6 +644,12 @@ class MoveMessageToArchiveWithSubMessages(MoveMessageToArchiveBase): ) self.assertEqual(SubMessage.objects.filter(id__in=submessage_ids).count(), 0) + restore_all_data_from_archive() + self.assertEqual( + set(SubMessage.objects.filter(id__in=submessage_ids).values_list('id', flat=True)), + set(submessage_ids) + ) + class MoveMessageToArchiveWithReactions(MoveMessageToArchiveBase, EmojiReactionBase): def test_archiving_message_with_reactions(self) -> None: msg_id = self.send_stream_message(self.sender, "Verona") @@ -604,3 +669,9 @@ class MoveMessageToArchiveWithReactions(MoveMessageToArchiveBase, EmojiReactionB set(reaction_ids) ) self.assertEqual(Reaction.objects.filter(id__in=reaction_ids).count(), 0) + + restore_all_data_from_archive() + self.assertEqual( + set(Reaction.objects.filter(id__in=reaction_ids).values_list('id', flat=True)), + set(reaction_ids) + )