# -*- coding: utf-8 -*- from __future__ import absolute_import from __future__ import print_function from django.conf import settings from django.test import TestCase import os import shutil import ujson from mock import patch, MagicMock from six.moves import range from typing import Any, Dict, List, Set from zerver.lib.actions import ( do_claim_attachments, ) from zerver.lib.export import ( do_export_realm, export_usermessages_batch, ) from zerver.lib.upload import ( claim_attachment, upload_message_image, ) from zerver.lib.utils import ( mkdir_p, query_chunker, ) from zerver.lib.test_classes import ( ZulipTestCase, ) from zerver.lib.test_runner import slow from zerver.models import ( Message, Realm, Recipient, UserMessage, ) def rm_tree(path): # type: (str) -> None if os.path.exists(path): shutil.rmtree(path) class QueryUtilTest(ZulipTestCase): def _create_messages(self): # type: () -> None for email in [self.example_email('cordelia'), self.example_email('hamlet'), self.example_email('iago')]: for _ in range(5): self.send_message(email, self.example_email('othello'), Recipient.PERSONAL) @slow('creates lots of data') def test_query_chunker(self): # type: () -> None self._create_messages() cordelia = self.example_user('cordelia') hamlet = self.example_user('hamlet') def get_queries(): # type: () -> List[Any] queries = [ Message.objects.filter(sender_id=cordelia.id), Message.objects.filter(sender_id=hamlet.id), Message.objects.exclude(sender_id__in=[cordelia.id, hamlet.id]) ] return queries for query in get_queries(): # For our test to be meaningful, we want non-empty queries # at first assert len(list(query)) > 0 queries = get_queries() all_msg_ids = set() # type: Set[int] chunker = query_chunker( queries=queries, id_collector=all_msg_ids, chunk_size=20, ) all_row_ids = [] for chunk in chunker: for row in chunk: all_row_ids.append(row.id) self.assertEqual(all_row_ids, sorted(all_row_ids)) self.assertEqual(len(all_msg_ids), len(Message.objects.all())) # Now just search for cordelia/hamlet. Note that we don't really # need the order_by here, but it should be harmless. queries = [ Message.objects.filter(sender_id=cordelia.id).order_by('id'), Message.objects.filter(sender_id=hamlet.id), ] all_msg_ids = set() chunker = query_chunker( queries=queries, id_collector=all_msg_ids, chunk_size=7, # use a different size ) list(chunker) # exhaust the iterator self.assertEqual( len(all_msg_ids), len(Message.objects.filter(sender_id__in=[cordelia.id, hamlet.id])) ) # Try just a single query to validate chunking. queries = [ Message.objects.exclude(sender_id=cordelia.id), ] all_msg_ids = set() chunker = query_chunker( queries=queries, id_collector=all_msg_ids, chunk_size=11, # use a different size each time ) list(chunker) # exhaust the iterator self.assertEqual( len(all_msg_ids), len(Message.objects.exclude(sender_id=cordelia.id)) ) self.assertTrue(len(all_msg_ids) > 15) # Verify assertions about disjoint-ness. queries = [ Message.objects.exclude(sender_id=cordelia.id), Message.objects.filter(sender_id=hamlet.id), ] all_msg_ids = set() chunker = query_chunker( queries=queries, id_collector=all_msg_ids, chunk_size=13, # use a different size each time ) with self.assertRaises(AssertionError): list(chunker) # exercise the iterator # Try to confuse things with ids part of the query... queries = [ Message.objects.filter(id__lte=10), Message.objects.filter(id__gt=10), ] all_msg_ids = set() chunker = query_chunker( queries=queries, id_collector=all_msg_ids, chunk_size=11, # use a different size each time ) self.assertEqual(len(all_msg_ids), 0) # until we actually use the iterator list(chunker) # exhaust the iterator self.assertEqual(len(all_msg_ids), len(Message.objects.all())) # Verify that we can just get the first chunk with a next() call. queries = [ Message.objects.all(), ] all_msg_ids = set() chunker = query_chunker( queries=queries, id_collector=all_msg_ids, chunk_size=10, # use a different size each time ) first_chunk = next(chunker) # type: ignore self.assertEqual(len(first_chunk), 10) self.assertEqual(len(all_msg_ids), 10) expected_msg = Message.objects.all()[0:10][5] actual_msg = first_chunk[5] self.assertEqual(actual_msg.content, expected_msg.content) self.assertEqual(actual_msg.sender_id, expected_msg.sender_id) class ExportTest(ZulipTestCase): def setUp(self): # type: () -> None rm_tree(settings.LOCAL_UPLOADS_DIR) def _make_output_dir(self): # type: () -> str output_dir = 'var/test-export' rm_tree(output_dir) mkdir_p(output_dir) return output_dir def _export_realm(self, realm, exportable_user_ids=None): # type: (Realm, Set[int]) -> Dict[str, Any] output_dir = self._make_output_dir() with patch('logging.info'), patch('zerver.lib.export.create_soft_link'): do_export_realm( realm=realm, output_dir=output_dir, threads=0, exportable_user_ids=exportable_user_ids, ) # TODO: Process the second partial file, which can be created # for certain edge cases. export_usermessages_batch( input_path=os.path.join(output_dir, 'messages-000001.json.partial'), output_path=os.path.join(output_dir, 'message.json') ) def read_file(fn): # type: (str) -> Any full_fn = os.path.join(output_dir, fn) with open(full_fn) as f: return ujson.load(f) result = {} result['realm'] = read_file('realm.json') result['attachment'] = read_file('attachment.json') result['message'] = read_file('message.json') result['uploads_dir'] = os.path.join(output_dir, 'uploads') return result def test_attachment(self): # type: () -> None message = Message.objects.all()[0] user_profile = message.sender url = upload_message_image(u'dummy.txt', len(b'zulip!'), u'text/plain', b'zulip!', user_profile) path_id = url.replace('/user_uploads/', '') claim_attachment( user_profile=user_profile, path_id=path_id, message=message, is_message_realm_public=True ) realm = Realm.objects.get(string_id='zulip') full_data = self._export_realm(realm) data = full_data['attachment'] self.assertEqual(len(data['zerver_attachment']), 1) record = data['zerver_attachment'][0] self.assertEqual(record['path_id'], path_id) fn = os.path.join(full_data['uploads_dir'], path_id) with open(fn) as f: self.assertEqual(f.read(), 'zulip!') def test_zulip_realm(self): # type: () -> None realm = Realm.objects.get(string_id='zulip') full_data = self._export_realm(realm) data = full_data['realm'] self.assertEqual(len(data['zerver_userprofile_crossrealm']), 0) self.assertEqual(len(data['zerver_userprofile_mirrordummy']), 0) def get_set(table, field): # type: (str, str) -> Set[str] values = set(r[field] for r in data[table]) # print('set(%s)' % sorted(values)) return values def find_by_id(table, db_id): # type: (str, int) -> Dict[str, Any] return [ r for r in data[table] if r['id'] == db_id][0] exported_user_emails = get_set('zerver_userprofile', 'email') self.assertIn(self.example_email('cordelia'), exported_user_emails) self.assertIn('default-bot@zulip.com', exported_user_emails) self.assertIn('emailgateway@zulip.com', exported_user_emails) exported_streams = get_set('zerver_stream', 'name') self.assertEqual( exported_streams, set([u'Denmark', u'Rome', u'Scotland', u'Venice', u'Verona']) ) data = full_data['message'] um = UserMessage.objects.all()[0] exported_um = find_by_id('zerver_usermessage', um.id) self.assertEqual(exported_um['message'], um.message_id) self.assertEqual(exported_um['user_profile'], um.user_profile_id) exported_message = find_by_id('zerver_message', um.message_id) self.assertEqual(exported_message['content'], um.message.content) # TODO, extract get_set/find_by_id, so we can split this test up # Now, restrict users cordelia = self.example_user('cordelia') hamlet = self.example_user('hamlet') user_ids = set([cordelia.id, hamlet.id]) full_data = self._export_realm(realm, exportable_user_ids=user_ids) data = full_data['realm'] exported_user_emails = get_set('zerver_userprofile', 'email') self.assertIn(self.example_email('cordelia'), exported_user_emails) self.assertIn(self.example_email('hamlet'), exported_user_emails) self.assertNotIn('default-bot@zulip.com', exported_user_emails) self.assertNotIn(self.example_email('iago'), exported_user_emails) dummy_user_emails = get_set('zerver_userprofile_mirrordummy', 'email') self.assertIn(self.example_email('iago'), dummy_user_emails) self.assertNotIn(self.example_email('cordelia'), dummy_user_emails)