diff --git a/zephyr/tests.py b/zephyr/tests.py index 8f54bd9d43..3d3e4429b1 100644 --- a/zephyr/tests.py +++ b/zephyr/tests.py @@ -2563,30 +2563,17 @@ class UserPresenceTests(AuthedTestCase): self.assertEqual(email_to_domain(email), 'humbughq.com') class UnreadCountTests(AuthedTestCase): - - def test_initial_counts(self): - # All test users have a pointer at -1, so all messages are read - for user in UserProfile.objects.all(): - for message in UserMessage.objects.filter(user_profile=user): - self.assertTrue(message.flags.read) - - self.login('hamlet@humbughq.com') - for msg in self.get_old_messages(): - self.assertEqual(msg['flags'], ['read']) + def setUp(self): + self.unread_msgs = [self.send_message("iago@humbughq.com", "hamlet@humbughq.com", Recipient.PERSONAL, "hello"), + self.send_message("iago@humbughq.com", "hamlet@humbughq.com", Recipient.PERSONAL, "hello2")] def test_new_message(self): # Sending a new message results in unread UserMessages being created self.login("hamlet@humbughq.com") content = "Test message for unset read bit" - self.client.post("/json/send_message", {"type": "stream", - "to": "Verona", - "client": "test suite", - "content": content, - "subject": "Test subject"}) - last_msg = Message.objects.all().order_by("-id")[0] - self.assertEqual(last_msg.content, "Test message for unset read bit") + last_msg = self.send_message("hamlet@humbughq.com", "Verona", Recipient.STREAM, content) user_messages = list(UserMessage.objects.filter(message=last_msg)) - self.assertEqual(4, len(user_messages)) + self.assertEqual(len(user_messages) > 0, True) for um in user_messages: self.assertEqual(um.message.content, content) if um.user_profile.email != "hamlet@humbughq.com": @@ -2595,28 +2582,30 @@ class UnreadCountTests(AuthedTestCase): def test_update_flags(self): self.login("hamlet@humbughq.com") - result = self.client.post("/json/update_message_flags", {"messages": ujson.dumps([1, 2]), - "op": "add", - "flag": "read"}) + result = self.client.post("/json/update_message_flags", + {"messages": ujson.dumps([msg.id for msg in self.unread_msgs]), + "op": "add", + "flag": "read"}) self.assert_json_success(result) # Ensure we properly set the flags + found = 0 for msg in self.get_old_messages(): - if msg['id'] == 1: - self.assertEqual(msg['flags'], ['read']) - elif msg['id'] == 2: + if msg['id'] in [message.id for message in self.unread_msgs]: self.assertEqual(msg['flags'], ['read']) + found += 1 + self.assertEqual(found, 2) - result = self.client.post("/json/update_message_flags", {"messages": ujson.dumps([2]), + result = self.client.post("/json/update_message_flags", {"messages": ujson.dumps([self.unread_msgs[1].id]), "op": "remove", "flag": "read"}) self.assert_json_success(result) # Ensure we properly remove just one flag for msg in self.get_old_messages(): - if msg['id'] == 1: + if msg['id'] == self.unread_msgs[0].id: self.assertEqual(msg['flags'], ['read']) - elif msg['id'] == 2: + elif msg['id'] == self.unread_msgs[1].id: self.assertEqual(msg['flags'], []) def test_update_all_flags(self):