diff --git a/zerver/actions/message_edit.py b/zerver/actions/message_edit.py index fc17a013fb..93a01a7832 100644 --- a/zerver/actions/message_edit.py +++ b/zerver/actions/message_edit.py @@ -1217,7 +1217,7 @@ def check_update_message( and raises a JsonableError if otherwise. It returns the number changed. """ - message, ignored_user_message = access_message(user_profile, message_id, lock_message=True) + message = access_message(user_profile, message_id, lock_message=True) # If there is a change to the content, check that it hasn't been too long # Allow an extra 20 seconds since we potentially allow editing 15 seconds diff --git a/zerver/actions/reactions.py b/zerver/actions/reactions.py index 0b668df8fe..24b2ced9fd 100644 --- a/zerver/actions/reactions.py +++ b/zerver/actions/reactions.py @@ -126,7 +126,9 @@ def check_add_reaction( emoji_code: Optional[str], reaction_type: Optional[str], ) -> None: - message, user_message = access_message(user_profile, message_id, lock_message=True) + message, has_user_message = access_message( + user_profile, message_id, lock_message=True, get_user_message="exists" + ) if emoji_code is None or reaction_type is None: emoji_data = get_emoji_data(message.realm_id, emoji_name) @@ -179,7 +181,7 @@ def check_add_reaction( # realm emoji). check_emoji_request(user_profile.realm, emoji_name, emoji_code, reaction_type) - if user_message is None: + if not has_user_message: # See called function for more context. create_historical_user_messages(user_id=user_profile.id, message_ids=[message.id]) diff --git a/zerver/lib/message.py b/zerver/lib/message.py index fa4fe002b0..6769cb20ba 100644 --- a/zerver/lib/message.py +++ b/zerver/lib/message.py @@ -5,6 +5,7 @@ from typing import ( Collection, Dict, List, + Literal, Mapping, Optional, Sequence, @@ -12,6 +13,7 @@ from typing import ( Tuple, TypedDict, Union, + overload, ) from django.conf import settings @@ -260,11 +262,35 @@ def messages_for_ids( return message_list +@overload def access_message( user_profile: UserProfile, message_id: int, + get_user_message: None = ..., + lock_message: bool = ..., +) -> Message: ... +@overload +def access_message( + user_profile: UserProfile, + message_id: int, + get_user_message: Literal["exists"], + lock_message: bool = ..., +) -> Tuple[Message, bool]: ... +@overload +def access_message( + user_profile: UserProfile, + message_id: int, + get_user_message: Literal["object"], + lock_message: bool = ..., +) -> Tuple[Message, Optional[UserMessage]]: ... + + +def access_message( + user_profile: UserProfile, + message_id: int, + get_user_message: Optional[Literal["exists", "object"]] = None, lock_message: bool = False, -) -> Tuple[Message, Optional[UserMessage]]: +) -> Union[Message, Tuple[Message, bool], Tuple[Message, Optional[UserMessage]]]: """You can access a message by ID in our APIs that either: (1) You received or have previously accessed via starring (aka have a UserMessage row for). @@ -290,10 +316,21 @@ def access_message( except Message.DoesNotExist: raise JsonableError(_("Invalid message(s)")) - user_message = get_usermessage_by_message_id(user_profile, message_id) + if get_user_message == "object": + user_message = get_usermessage_by_message_id(user_profile, message_id) + has_user_message = user_message is not None + else: + has_user_message = UserMessage.objects.filter( + user_profile=user_profile, message_id=message_id + ).exists() - if has_message_access(user_profile, message, has_user_message=user_message is not None): - return (message, user_message) + if has_message_access(user_profile, message, has_user_message=has_user_message): + if get_user_message is None: + return message + if get_user_message == "exists": + return (message, has_user_message) + if get_user_message == "object": + return (message, user_message) raise JsonableError(_("Invalid message(s)")) diff --git a/zerver/lib/push_notifications.py b/zerver/lib/push_notifications.py index b11bc9c59f..8fb3ed38fc 100644 --- a/zerver/lib/push_notifications.py +++ b/zerver/lib/push_notifications.py @@ -1299,7 +1299,10 @@ def handle_push_notification(user_profile_id: int, missed_message: Dict[str, Any with transaction.atomic(savepoint=False): try: (message, user_message) = access_message( - user_profile, missed_message["message_id"], lock_message=True + user_profile, + missed_message["message_id"], + lock_message=True, + get_user_message="object", ) except JsonableError: if ArchivedMessage.objects.filter(id=missed_message["message_id"]).exists(): diff --git a/zerver/lib/test_classes.py b/zerver/lib/test_classes.py index 8d85a1b3cd..461c4f0703 100644 --- a/zerver/lib/test_classes.py +++ b/zerver/lib/test_classes.py @@ -1472,7 +1472,7 @@ Output: """ Mark all messages within the topic associated with message `target_message_id` as resolved. """ - message, _ = access_message(acting_user, target_message_id) + message = access_message(acting_user, target_message_id) return self.api_patch( acting_user, f"/api/v1/messages/{target_message_id}", diff --git a/zerver/views/message_edit.py b/zerver/views/message_edit.py index aae308d0fa..613e0689fa 100644 --- a/zerver/views/message_edit.py +++ b/zerver/views/message_edit.py @@ -97,7 +97,7 @@ def get_message_edit_history( ) -> HttpResponse: if not user_profile.realm.allow_edit_history: raise JsonableError(_("Message edit history is disabled in this organization")) - message, ignored_user_message = access_message(user_profile, message_id) + message = access_message(user_profile, message_id) # Extract the message edit history from the message if message.edit_history is not None: @@ -175,7 +175,7 @@ def delete_message_backend( # concurrently are serialized properly with deleting the message; this prevents a deadlock # that would otherwise happen because of the other transaction holding a lock on the `Message` # row. - message, ignored_user_message = access_message(user_profile, message_id, lock_message=True) + message = access_message(user_profile, message_id, lock_message=True) validate_can_delete_message(user_profile, message) try: do_delete_messages(user_profile.realm, [message]) @@ -197,7 +197,9 @@ def json_fetch_raw_message( message = access_web_public_message(realm, message_id) user_profile = None else: - (message, user_message) = access_message(maybe_user_profile, message_id) + (message, user_message) = access_message( + maybe_user_profile, message_id, get_user_message="object" + ) user_profile = maybe_user_profile flags = ["read"] diff --git a/zerver/views/reactions.py b/zerver/views/reactions.py index a0e6a14451..740b977a7d 100644 --- a/zerver/views/reactions.py +++ b/zerver/views/reactions.py @@ -40,7 +40,7 @@ def remove_reaction( emoji_code: Optional[str] = REQ(default=None), reaction_type: str = REQ(default="unicode_emoji"), ) -> HttpResponse: - message, user_message = access_message(user_profile, message_id, lock_message=True) + message = access_message(user_profile, message_id, lock_message=True) if emoji_code is None: if emoji_name is None: diff --git a/zerver/views/read_receipts.py b/zerver/views/read_receipts.py index 01f0757354..f1b7b7a118 100644 --- a/zerver/views/read_receipts.py +++ b/zerver/views/read_receipts.py @@ -16,7 +16,7 @@ def read_receipts( user_profile: UserProfile, message_id: int = REQ(converter=to_non_negative_int, path_only=True), ) -> HttpResponse: - message = access_message(user_profile, message_id)[0] + message = access_message(user_profile, message_id) if not user_profile.realm.enable_read_receipts: raise JsonableError(_("Read receipts are disabled in this organization.")) diff --git a/zerver/views/submessage.py b/zerver/views/submessage.py index bbcd0939dc..b82ab81ab6 100644 --- a/zerver/views/submessage.py +++ b/zerver/views/submessage.py @@ -24,7 +24,7 @@ def process_submessage( msg_type: str = REQ(), content: str = REQ(), ) -> HttpResponse: - message, user_message = access_message(user_profile, message_id, lock_message=True) + message = access_message(user_profile, message_id, lock_message=True) verify_submessage_sender( message_id=message.id,