# -*- coding: utf-8 -*-
from django.conf import settings
from django.test import TestCase

import os
import shutil
import ujson

from mock import patch, MagicMock
from typing import Any, Dict, List, Optional, Set

from zerver.lib.actions import (
    do_claim_attachments,
)

from zerver.lib.export import (
    do_export_realm,
    export_usermessages_batch,
)
from zerver.lib.upload import (
    claim_attachment,
    upload_message_image,
)
from zerver.lib.utils import (
    query_chunker,
)
from zerver.lib.test_classes import (
    ZulipTestCase,
)
from zerver.lib.test_runner import slow

from zerver.models import (
    Message,
    Realm,
    Recipient,
    UserMessage,
)


def rm_tree(path):
    # type: (str) -> None
    """Recursively delete the directory at `path`, if it exists."""
    if os.path.exists(path):
        shutil.rmtree(path)


def get_set(data, table, field):
    # type: (Dict[str, Any], str, str) -> Set[str]
    """Return the set of values of `field` across all rows of `data[table]`."""
    return set(r[field] for r in data[table])


def find_by_id(data, table, db_id):
    # type: (Dict[str, Any], str, int) -> Dict[str, Any]
    """Return the first row of `data[table]` whose 'id' equals `db_id`.

    Raises IndexError if no such row exists, which fails the calling test.
    """
    return [r for r in data[table] if r['id'] == db_id][0]


class QueryUtilTest(ZulipTestCase):
    """Tests for zerver.lib.utils.query_chunker."""

    def _create_messages(self):
        # type: () -> None
        """Send five personal messages to othello from each of cordelia,
        hamlet and iago, so the chunker has rows from several senders."""
        for email in [self.example_email('cordelia'),
                      self.example_email('hamlet'),
                      self.example_email('iago')]:
            for _ in range(5):
                self.send_personal_message(email, self.example_email('othello'))

    @slow('creates lots of data')
    def test_query_chunker(self):
        # type: () -> None
        self._create_messages()

        cordelia = self.example_user('cordelia')
        hamlet = self.example_user('hamlet')

        def get_queries():
            # type: () -> List[Any]
            # Build fresh (unevaluated) querysets each time they are needed.
            queries = [
                Message.objects.filter(sender_id=cordelia.id),
                Message.objects.filter(sender_id=hamlet.id),
                Message.objects.exclude(sender_id__in=[cordelia.id, hamlet.id])
            ]
            return queries

        for query in get_queries():
            # For our test to be meaningful, we want non-empty queries
            # at first
            self.assertTrue(query.exists())

        queries = get_queries()

        all_msg_ids = set()  # type: Set[int]
        chunker = query_chunker(
            queries=queries,
            id_collector=all_msg_ids,
            chunk_size=20,
        )

        all_row_ids = []
        for chunk in chunker:
            for row in chunk:
                all_row_ids.append(row.id)

        # Rows must come out in ascending id order across all chunks.
        self.assertEqual(all_row_ids, sorted(all_row_ids))
        self.assertEqual(len(all_msg_ids), Message.objects.count())

        # Now just search for cordelia/hamlet.  Note that we don't really
        # need the order_by here, but it should be harmless.
        queries = [
            Message.objects.filter(sender_id=cordelia.id).order_by('id'),
            Message.objects.filter(sender_id=hamlet.id),
        ]
        all_msg_ids = set()
        chunker = query_chunker(
            queries=queries,
            id_collector=all_msg_ids,
            chunk_size=7,  # use a different size
        )
        list(chunker)  # exhaust the iterator
        self.assertEqual(
            len(all_msg_ids),
            Message.objects.filter(sender_id__in=[cordelia.id, hamlet.id]).count()
        )

        # Try just a single query to validate chunking.
        queries = [
            Message.objects.exclude(sender_id=cordelia.id),
        ]
        all_msg_ids = set()
        chunker = query_chunker(
            queries=queries,
            id_collector=all_msg_ids,
            chunk_size=11,  # use a different size each time
        )
        list(chunker)  # exhaust the iterator
        self.assertEqual(
            len(all_msg_ids),
            Message.objects.exclude(sender_id=cordelia.id).count()
        )
        self.assertGreater(len(all_msg_ids), 15)

        # Verify assertions about disjoint-ness: hamlet's messages are
        # matched by both queries, so the chunker must complain.
        queries = [
            Message.objects.exclude(sender_id=cordelia.id),
            Message.objects.filter(sender_id=hamlet.id),
        ]
        all_msg_ids = set()
        chunker = query_chunker(
            queries=queries,
            id_collector=all_msg_ids,
            chunk_size=13,  # use a different size each time
        )
        with self.assertRaises(AssertionError):
            list(chunker)  # exercise the iterator

        # Try to confuse things with ids part of the query...
        queries = [
            Message.objects.filter(id__lte=10),
            Message.objects.filter(id__gt=10),
        ]
        all_msg_ids = set()
        chunker = query_chunker(
            queries=queries,
            id_collector=all_msg_ids,
            chunk_size=11,  # use a different size each time
        )
        self.assertEqual(len(all_msg_ids), 0)  # until we actually use the iterator
        list(chunker)  # exhaust the iterator
        self.assertEqual(len(all_msg_ids), Message.objects.count())

        # Verify that we can just get the first chunk with a next() call.
        queries = [
            Message.objects.all(),
        ]
        all_msg_ids = set()
        chunker = query_chunker(
            queries=queries,
            id_collector=all_msg_ids,
            chunk_size=10,  # use a different size each time
        )
        first_chunk = next(chunker)  # type: ignore
        self.assertEqual(len(first_chunk), 10)
        self.assertEqual(len(all_msg_ids), 10)
        # The chunker yields rows in ascending id order (asserted above), so
        # compare against an explicitly id-ordered query instead of relying
        # on the model's default ordering.
        expected_msg = Message.objects.order_by('id')[0:10][5]
        actual_msg = first_chunk[5]
        self.assertEqual(actual_msg.content, expected_msg.content)
        self.assertEqual(actual_msg.sender_id, expected_msg.sender_id)


class ExportTest(ZulipTestCase):
    """Tests for do_export_realm / export_usermessages_batch."""

    def setUp(self):
        # type: () -> None
        """Run base-class setup, then start with no local uploads directory."""
        super().setUp()
        rm_tree(settings.LOCAL_UPLOADS_DIR)

    def _make_output_dir(self):
        # type: () -> str
        """Recreate an empty 'var/test-export' directory and return its path."""
        output_dir = 'var/test-export'
        rm_tree(output_dir)
        os.makedirs(output_dir, exist_ok=True)
        return output_dir

    def _export_realm(self, realm, exportable_user_ids=None):
        # type: (Realm, Optional[Set[int]]) -> Dict[str, Any]
        """Export `realm` into a fresh output directory and load the result.

        `exportable_user_ids` is forwarded unchanged to do_export_realm.

        Returns a dict with the parsed 'realm', 'attachment' and 'message'
        JSON payloads, plus 'uploads_dir' (the exported uploads path).
        Only the first partial message file is converted into message.json.
        """
        output_dir = self._make_output_dir()
        with patch('logging.info'), patch('zerver.lib.export.create_soft_link'):
            do_export_realm(
                realm=realm,
                output_dir=output_dir,
                threads=0,
                exportable_user_ids=exportable_user_ids,
            )
            # TODO: Process the second partial file, which can be created
            #       for certain edge cases.
            export_usermessages_batch(
                input_path=os.path.join(output_dir, 'messages-000001.json.partial'),
                output_path=os.path.join(output_dir, 'message.json')
            )

        def read_file(fn):
            # type: (str) -> Any
            full_fn = os.path.join(output_dir, fn)
            with open(full_fn) as f:
                return ujson.load(f)

        result = {
            'realm': read_file('realm.json'),
            'attachment': read_file('attachment.json'),
            'message': read_file('message.json'),
            'uploads_dir': os.path.join(output_dir, 'uploads'),
        }  # type: Dict[str, Any]
        return result

    def test_attachment(self):
        # type: () -> None
        """An uploaded, claimed attachment is exported with its file contents."""
        message = Message.objects.all()[0]
        user_profile = message.sender
        url = upload_message_image(u'dummy.txt', len(b'zulip!'), u'text/plain', b'zulip!', user_profile)
        path_id = url.replace('/user_uploads/', '')
        claim_attachment(
            user_profile=user_profile,
            path_id=path_id,
            message=message,
            is_message_realm_public=True
        )

        realm = Realm.objects.get(string_id='zulip')
        full_data = self._export_realm(realm)

        data = full_data['attachment']
        self.assertEqual(len(data['zerver_attachment']), 1)
        record = data['zerver_attachment'][0]
        self.assertEqual(record['path_id'], path_id)

        fn = os.path.join(full_data['uploads_dir'], path_id)
        with open(fn) as f:
            self.assertEqual(f.read(), 'zulip!')

    def test_zulip_realm(self):
        # type: () -> None
        """Export the 'zulip' realm, first fully and then restricted to a
        subset of users."""
        realm = Realm.objects.get(string_id='zulip')
        full_data = self._export_realm(realm)

        realm_data = full_data['realm']
        self.assertEqual(len(realm_data['zerver_userprofile_crossrealm']), 0)
        self.assertEqual(len(realm_data['zerver_userprofile_mirrordummy']), 0)

        exported_user_emails = get_set(realm_data, 'zerver_userprofile', 'email')
        self.assertIn(self.example_email('cordelia'), exported_user_emails)
        self.assertIn('default-bot@zulip.com', exported_user_emails)
        self.assertIn('emailgateway@zulip.com', exported_user_emails)

        exported_streams = get_set(realm_data, 'zerver_stream', 'name')
        self.assertEqual(
            exported_streams,
            {u'Denmark', u'Rome', u'Scotland', u'Venice', u'Verona'}
        )

        message_data = full_data['message']
        um = UserMessage.objects.all()[0]
        exported_um = find_by_id(message_data, 'zerver_usermessage', um.id)
        self.assertEqual(exported_um['message'], um.message_id)
        self.assertEqual(exported_um['user_profile'], um.user_profile_id)

        exported_message = find_by_id(message_data, 'zerver_message', um.message_id)
        self.assertEqual(exported_message['content'], um.message.content)

        # TODO: split the restricted-users export below into its own test.

        # Now, restrict users
        cordelia = self.example_user('cordelia')
        hamlet = self.example_user('hamlet')
        user_ids = {cordelia.id, hamlet.id}

        full_data = self._export_realm(realm, exportable_user_ids=user_ids)
        realm_data = full_data['realm']

        exported_user_emails = get_set(realm_data, 'zerver_userprofile', 'email')
        self.assertIn(self.example_email('cordelia'), exported_user_emails)
        self.assertIn(self.example_email('hamlet'), exported_user_emails)
        self.assertNotIn('default-bot@zulip.com', exported_user_emails)
        self.assertNotIn(self.example_email('iago'), exported_user_emails)

        # Users outside the exportable set show up as mirror-dummy profiles.
        dummy_user_emails = get_set(realm_data, 'zerver_userprofile_mirrordummy', 'email')
        self.assertIn(self.example_email('iago'), dummy_user_emails)
        self.assertNotIn(self.example_email('cordelia'), dummy_user_emails)