From b15941610d0c030a13e0d6d25f3560a1e7e33bae Mon Sep 17 00:00:00 2001 From: Tim Abbott Date: Tue, 11 May 2021 15:31:03 -0700 Subject: [PATCH] message: Support avoiding database queries in has_message_access. If the caller has already fetched the Stream or subscription details for the user, those can be passed to has_message_access to avoid extra database queries. --- zerver/lib/message.py | 24 +++++++++++++++++++++-- zerver/tests/test_message_edit.py | 32 +++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/zerver/lib/message.py b/zerver/lib/message.py index 0f0908bb46..f827c14463 100644 --- a/zerver/lib/message.py +++ b/zerver/lib/message.py @@ -677,8 +677,21 @@ def access_message( def has_message_access( - user_profile: UserProfile, message: Message, user_message: Optional[UserMessage] + user_profile: UserProfile, + message: Message, + user_message: Optional[UserMessage], + *, + stream: Optional[Stream] = None, + is_subscribed: Optional[bool] = None, ) -> bool: + """ + Returns whether a user has access to a given message. + + * The user_message parameter must be provded if the user has a UserMessage + row for the target message. + * The optional stream parameter is validated; is_subscribed is not. + """ + # If you have a user_message object, you have access. if user_message is not None: return True @@ -687,7 +700,11 @@ def has_message_access( # You can't access private messages you didn't receive return False - stream = Stream.objects.get(id=message.recipient.type_id) + if stream is None: + stream = Stream.objects.get(id=message.recipient.type_id) + else: + assert stream.recipient_id == message.recipient_id + if stream.realm != user_profile.realm: # You can't access public stream messages in other realms return False @@ -701,6 +718,9 @@ def has_message_access( return True # is_history_public_to_subscribers, so check if you're subscribed + if is_subscribed is not None: + return is_subscribed + return Subscription.objects.filter( user_profile=user_profile, active=True, recipient=message.recipient ).exists() diff --git a/zerver/tests/test_message_edit.py b/zerver/tests/test_message_edit.py index 4c5eb9e7b1..383dcc39cb 100644 --- a/zerver/tests/test_message_edit.py +++ b/zerver/tests/test_message_edit.py @@ -1363,6 +1363,12 @@ class EditMessageTest(ZulipTestCase): has_message_access(guest_user, Message.objects.get(id=msg_id_to_test_acesss), None), True, ) + self.assertEqual( + has_message_access( + guest_user, Message.objects.get(id=msg_id_to_test_acesss), None, stream=old_stream + ), + True, + ) self.assertEqual( has_message_access(non_guest_user, Message.objects.get(id=msg_id_to_test_acesss), None), True, @@ -1387,6 +1393,32 @@ class EditMessageTest(ZulipTestCase): has_message_access(non_guest_user, Message.objects.get(id=msg_id_to_test_acesss), None), True, ) + self.assertEqual( + # If the guest user were subscribed to the new stream, + # they'd have access; has_message_access does not validate + # the is_subscribed parameter. + has_message_access( + guest_user, + Message.objects.get(id=msg_id_to_test_acesss), + None, + stream=new_stream, + is_subscribed=True, + ), + True, + ) + + self.assertEqual( + has_message_access( + guest_user, Message.objects.get(id=msg_id_to_test_acesss), None, stream=new_stream + ), + False, + ) + with self.assertRaises(AssertionError): + # Raises assertion if you pass an invalid stream. + has_message_access( + guest_user, Message.objects.get(id=msg_id_to_test_acesss), None, stream=old_stream + ) + self.assertEqual( UserMessage.objects.filter( user_profile_id=non_guest_user.id,