messages: Use overloads to only return a user_message if needed.

This commit is contained in:
Alex Vandiver 2024-03-22 05:45:17 +00:00 committed by Tim Abbott
parent 6ace34c374
commit f92d43c690
9 changed files with 59 additions and 15 deletions

View File

@ -1217,7 +1217,7 @@ def check_update_message(
and raises a JsonableError if otherwise.
It returns the number changed.
"""
message, ignored_user_message = access_message(user_profile, message_id, lock_message=True)
message = access_message(user_profile, message_id, lock_message=True)
# If there is a change to the content, check that it hasn't been too long
# Allow an extra 20 seconds since we potentially allow editing 15 seconds

View File

@ -126,7 +126,9 @@ def check_add_reaction(
emoji_code: Optional[str],
reaction_type: Optional[str],
) -> None:
message, user_message = access_message(user_profile, message_id, lock_message=True)
message, has_user_message = access_message(
user_profile, message_id, lock_message=True, get_user_message="exists"
)
if emoji_code is None or reaction_type is None:
emoji_data = get_emoji_data(message.realm_id, emoji_name)
@ -179,7 +181,7 @@ def check_add_reaction(
# realm emoji).
check_emoji_request(user_profile.realm, emoji_name, emoji_code, reaction_type)
if user_message is None:
if not has_user_message:
# See called function for more context.
create_historical_user_messages(user_id=user_profile.id, message_ids=[message.id])

View File

@ -5,6 +5,7 @@ from typing import (
Collection,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
@ -12,6 +13,7 @@ from typing import (
Tuple,
TypedDict,
Union,
overload,
)
from django.conf import settings
@ -260,11 +262,35 @@ def messages_for_ids(
return message_list
@overload
def access_message(
user_profile: UserProfile,
message_id: int,
get_user_message: None = ...,
lock_message: bool = ...,
) -> Message: ...
@overload
def access_message(
user_profile: UserProfile,
message_id: int,
get_user_message: Literal["exists"],
lock_message: bool = ...,
) -> Tuple[Message, bool]: ...
@overload
def access_message(
user_profile: UserProfile,
message_id: int,
get_user_message: Literal["object"],
lock_message: bool = ...,
) -> Tuple[Message, Optional[UserMessage]]: ...
def access_message(
user_profile: UserProfile,
message_id: int,
get_user_message: Optional[Literal["exists", "object"]] = None,
lock_message: bool = False,
) -> Tuple[Message, Optional[UserMessage]]:
) -> Union[Message, Tuple[Message, bool], Tuple[Message, Optional[UserMessage]]]:
"""You can access a message by ID in our APIs that either:
(1) You received or have previously accessed via starring
(aka have a UserMessage row for).
@ -290,10 +316,21 @@ def access_message(
except Message.DoesNotExist:
raise JsonableError(_("Invalid message(s)"))
user_message = get_usermessage_by_message_id(user_profile, message_id)
if get_user_message == "object":
user_message = get_usermessage_by_message_id(user_profile, message_id)
has_user_message = user_message is not None
else:
has_user_message = UserMessage.objects.filter(
user_profile=user_profile, message_id=message_id
).exists()
if has_message_access(user_profile, message, has_user_message=user_message is not None):
return (message, user_message)
if has_message_access(user_profile, message, has_user_message=has_user_message):
if get_user_message is None:
return message
if get_user_message == "exists":
return (message, has_user_message)
if get_user_message == "object":
return (message, user_message)
raise JsonableError(_("Invalid message(s)"))

View File

@ -1299,7 +1299,10 @@ def handle_push_notification(user_profile_id: int, missed_message: Dict[str, Any
with transaction.atomic(savepoint=False):
try:
(message, user_message) = access_message(
user_profile, missed_message["message_id"], lock_message=True
user_profile,
missed_message["message_id"],
lock_message=True,
get_user_message="object",
)
except JsonableError:
if ArchivedMessage.objects.filter(id=missed_message["message_id"]).exists():

View File

@ -1472,7 +1472,7 @@ Output:
"""
Mark all messages within the topic associated with message `target_message_id` as resolved.
"""
message, _ = access_message(acting_user, target_message_id)
message = access_message(acting_user, target_message_id)
return self.api_patch(
acting_user,
f"/api/v1/messages/{target_message_id}",

View File

@ -97,7 +97,7 @@ def get_message_edit_history(
) -> HttpResponse:
if not user_profile.realm.allow_edit_history:
raise JsonableError(_("Message edit history is disabled in this organization"))
message, ignored_user_message = access_message(user_profile, message_id)
message = access_message(user_profile, message_id)
# Extract the message edit history from the message
if message.edit_history is not None:
@ -175,7 +175,7 @@ def delete_message_backend(
# concurrently are serialized properly with deleting the message; this prevents a deadlock
# that would otherwise happen because of the other transaction holding a lock on the `Message`
# row.
message, ignored_user_message = access_message(user_profile, message_id, lock_message=True)
message = access_message(user_profile, message_id, lock_message=True)
validate_can_delete_message(user_profile, message)
try:
do_delete_messages(user_profile.realm, [message])
@ -197,7 +197,9 @@ def json_fetch_raw_message(
message = access_web_public_message(realm, message_id)
user_profile = None
else:
(message, user_message) = access_message(maybe_user_profile, message_id)
(message, user_message) = access_message(
maybe_user_profile, message_id, get_user_message="object"
)
user_profile = maybe_user_profile
flags = ["read"]

View File

@ -40,7 +40,7 @@ def remove_reaction(
emoji_code: Optional[str] = REQ(default=None),
reaction_type: str = REQ(default="unicode_emoji"),
) -> HttpResponse:
message, user_message = access_message(user_profile, message_id, lock_message=True)
message = access_message(user_profile, message_id, lock_message=True)
if emoji_code is None:
if emoji_name is None:

View File

@ -16,7 +16,7 @@ def read_receipts(
user_profile: UserProfile,
message_id: int = REQ(converter=to_non_negative_int, path_only=True),
) -> HttpResponse:
message = access_message(user_profile, message_id)[0]
message = access_message(user_profile, message_id)
if not user_profile.realm.enable_read_receipts:
raise JsonableError(_("Read receipts are disabled in this organization."))

View File

@ -24,7 +24,7 @@ def process_submessage(
msg_type: str = REQ(),
content: str = REQ(),
) -> HttpResponse:
message, user_message = access_message(user_profile, message_id, lock_message=True)
message = access_message(user_profile, message_id, lock_message=True)
verify_submessage_sender(
message_id=message.id,