From 436dab0e017f43382519ff5ad0b6bd88e049ed9b Mon Sep 17 00:00:00 2001 From: Alex Vandiver Date: Thu, 11 Apr 2024 20:48:10 +0000 Subject: [PATCH] messages: Remove use of @overload in access_message. f92d43c6908e added uses of `@overload` to probide multiple type signatures for `access_message`, based on the `get_user_message` parameter. Unfortunately, mypy does not check the function body against overload signatures, so it allows type errors to go undetected. Replace the overloads with two functions, for one of which also returns the usermessage. The third form, of only returning if the usermessage exists, is not in a high-enough performance endpoint that a third form is worth maintaining; it uses the usermessage form. --- zerver/actions/reactions.py | 8 ++-- zerver/lib/message.py | 74 +++++++++++++------------------- zerver/lib/push_notifications.py | 9 ++-- zerver/views/message_edit.py | 11 +++-- 4 files changed, 43 insertions(+), 59 deletions(-) diff --git a/zerver/actions/reactions.py b/zerver/actions/reactions.py index 2462b19ea1..c6c0c06e6b 100644 --- a/zerver/actions/reactions.py +++ b/zerver/actions/reactions.py @@ -4,7 +4,7 @@ from zerver.actions.user_topics import do_set_user_topic_visibility_policy from zerver.lib.emoji import check_emoji_request, get_emoji_data from zerver.lib.exceptions import ReactionExistsError from zerver.lib.message import ( - access_message, + access_message_and_usermessage, set_visibility_policy_possible, should_change_visibility_policy, visibility_policy_for_participation, @@ -126,8 +126,8 @@ def check_add_reaction( emoji_code: Optional[str], reaction_type: Optional[str], ) -> None: - message, has_user_message = access_message( - user_profile, message_id, lock_message=True, get_user_message="exists" + message, user_message = access_message_and_usermessage( + user_profile, message_id, lock_message=True ) if emoji_code is None or reaction_type is None: @@ -181,7 +181,7 @@ def check_add_reaction( # realm emoji). check_emoji_request(user_profile.realm, emoji_name, emoji_code, reaction_type) - if not has_user_message: + if user_message is None: # See called function for more context. create_historical_user_messages(user_id=user_profile.id, message_ids=[message.id]) diff --git a/zerver/lib/message.py b/zerver/lib/message.py index 557de205f0..6d18db147f 100644 --- a/zerver/lib/message.py +++ b/zerver/lib/message.py @@ -7,7 +7,6 @@ from typing import ( Collection, Dict, List, - Literal, Mapping, Optional, Sequence, @@ -15,7 +14,6 @@ from typing import ( Tuple, TypedDict, Union, - overload, ) from django.conf import settings @@ -265,35 +263,11 @@ def messages_for_ids( return message_list -@overload def access_message( user_profile: UserProfile, message_id: int, - get_user_message: None = ..., - lock_message: bool = ..., -) -> Message: ... -@overload -def access_message( - user_profile: UserProfile, - message_id: int, - get_user_message: Literal["exists"], - lock_message: bool = ..., -) -> Tuple[Message, bool]: ... -@overload -def access_message( - user_profile: UserProfile, - message_id: int, - get_user_message: Literal["object"], - lock_message: bool = ..., -) -> Tuple[Message, Optional[UserMessage]]: ... - - -def access_message( - user_profile: UserProfile, - message_id: int, - get_user_message: Optional[Literal["exists", "object"]] = None, lock_message: bool = False, -) -> Union[Message, Tuple[Message, bool], Tuple[Message, Optional[UserMessage]]]: +) -> Message: """You can access a message by ID in our APIs that either: (1) You received or have previously accessed via starring (aka have a UserMessage row for). @@ -319,26 +293,36 @@ def access_message( except Message.DoesNotExist: raise JsonableError(_("Invalid message(s)")) - if get_user_message == "object": - user_message = get_usermessage_by_message_id(user_profile, message_id) - has_user_message = lambda: user_message is not None - elif get_user_message == "exists": - local_exists = UserMessage.objects.filter( - user_profile=user_profile, message_id=message_id - ).exists() - has_user_message = lambda: local_exists - else: - has_user_message = lambda: UserMessage.objects.filter( - user_profile=user_profile, message_id=message_id - ).exists() + has_user_message = lambda: UserMessage.objects.filter( + user_profile=user_profile, message_id=message_id + ).exists() if has_message_access(user_profile, message, has_user_message=has_user_message): - if get_user_message is None: - return message - if get_user_message == "exists": - return (message, local_exists) - if get_user_message == "object": - return (message, user_message) + return message + raise JsonableError(_("Invalid message(s)")) + + +def access_message_and_usermessage( + user_profile: UserProfile, + message_id: int, + lock_message: bool = False, +) -> Tuple[Message, Optional[UserMessage]]: + """As access_message, but also returns the usermessage, if any.""" + try: + base_query = Message.objects.select_related(*Message.DEFAULT_SELECT_RELATED) + if lock_message: + # We want to lock only the `Message` row, and not the related fields + # because the `Message` row only has a possibility of races. + base_query = base_query.select_for_update(of=("self",)) + message = base_query.get(id=message_id) + except Message.DoesNotExist: + raise JsonableError(_("Invalid message(s)")) + + user_message = get_usermessage_by_message_id(user_profile, message_id) + has_user_message = lambda: user_message is not None + + if has_message_access(user_profile, message, has_user_message=has_user_message): + return (message, user_message) raise JsonableError(_("Invalid message(s)")) diff --git a/zerver/lib/push_notifications.py b/zerver/lib/push_notifications.py index e32cd510b7..24aea636e2 100644 --- a/zerver/lib/push_notifications.py +++ b/zerver/lib/push_notifications.py @@ -42,7 +42,7 @@ from zerver.lib.avatar import absolute_avatar_url, get_avatar_for_inaccessible_u from zerver.lib.display_recipient import get_display_recipient from zerver.lib.emoji_utils import hex_codepoint_to_emoji from zerver.lib.exceptions import ErrorCode, JsonableError -from zerver.lib.message import access_message, huddle_users +from zerver.lib.message import access_message_and_usermessage, huddle_users from zerver.lib.outgoing_http import OutgoingSession from zerver.lib.remote_server import ( send_json_to_push_bouncer, @@ -1289,11 +1289,8 @@ def handle_push_notification(user_profile_id: int, missed_message: Dict[str, Any with transaction.atomic(savepoint=False): try: - (message, user_message) = access_message( - user_profile, - missed_message["message_id"], - lock_message=True, - get_user_message="object", + (message, user_message) = access_message_and_usermessage( + user_profile, missed_message["message_id"], lock_message=True ) except JsonableError: if ArchivedMessage.objects.filter(id=missed_message["message_id"]).exists(): diff --git a/zerver/views/message_edit.py b/zerver/views/message_edit.py index 613e0689fa..254eeeb455 100644 --- a/zerver/views/message_edit.py +++ b/zerver/views/message_edit.py @@ -14,7 +14,12 @@ from zerver.actions.message_edit import check_update_message from zerver.context_processors import get_valid_realm_from_request from zerver.lib.exceptions import JsonableError from zerver.lib.html_diff import highlight_html_differences -from zerver.lib.message import access_message, access_web_public_message, messages_for_ids +from zerver.lib.message import ( + access_message, + access_message_and_usermessage, + access_web_public_message, + messages_for_ids, +) from zerver.lib.request import RequestNotes from zerver.lib.response import json_success from zerver.lib.timestamp import datetime_to_timestamp @@ -197,9 +202,7 @@ def json_fetch_raw_message( message = access_web_public_message(realm, message_id) user_profile = None else: - (message, user_message) = access_message( - maybe_user_profile, message_id, get_user_message="object" - ) + (message, user_message) = access_message_and_usermessage(maybe_user_profile, message_id) user_profile = maybe_user_profile flags = ["read"]