messages: Remove use of @overload in access_message.

f92d43c690 added uses of `@overload` to probide multiple type
signatures for `access_message`, based on the `get_user_message`
parameter.  Unfortunately, mypy does not check the function body
against overload signatures, so it allows type errors to go
undetected.

Replace the overloads with two functions, for one of which also
returns the usermessage.  The third form, of only returning if the
usermessage exists, is not in a high-enough performance endpoint that
a third form is worth maintaining; it uses the usermessage form.
This commit is contained in:
Alex Vandiver 2024-04-11 20:48:10 +00:00 committed by Tim Abbott
parent 30f71639f0
commit 436dab0e01
4 changed files with 43 additions and 59 deletions

View File

@ -4,7 +4,7 @@ from zerver.actions.user_topics import do_set_user_topic_visibility_policy
from zerver.lib.emoji import check_emoji_request, get_emoji_data
from zerver.lib.exceptions import ReactionExistsError
from zerver.lib.message import (
access_message,
access_message_and_usermessage,
set_visibility_policy_possible,
should_change_visibility_policy,
visibility_policy_for_participation,
@ -126,8 +126,8 @@ def check_add_reaction(
emoji_code: Optional[str],
reaction_type: Optional[str],
) -> None:
message, has_user_message = access_message(
user_profile, message_id, lock_message=True, get_user_message="exists"
message, user_message = access_message_and_usermessage(
user_profile, message_id, lock_message=True
)
if emoji_code is None or reaction_type is None:
@ -181,7 +181,7 @@ def check_add_reaction(
# realm emoji).
check_emoji_request(user_profile.realm, emoji_name, emoji_code, reaction_type)
if not has_user_message:
if user_message is None:
# See called function for more context.
create_historical_user_messages(user_id=user_profile.id, message_ids=[message.id])

View File

@ -7,7 +7,6 @@ from typing import (
Collection,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
@ -15,7 +14,6 @@ from typing import (
Tuple,
TypedDict,
Union,
overload,
)
from django.conf import settings
@ -265,35 +263,11 @@ def messages_for_ids(
return message_list
@overload
def access_message(
user_profile: UserProfile,
message_id: int,
get_user_message: None = ...,
lock_message: bool = ...,
) -> Message: ...
@overload
def access_message(
user_profile: UserProfile,
message_id: int,
get_user_message: Literal["exists"],
lock_message: bool = ...,
) -> Tuple[Message, bool]: ...
@overload
def access_message(
user_profile: UserProfile,
message_id: int,
get_user_message: Literal["object"],
lock_message: bool = ...,
) -> Tuple[Message, Optional[UserMessage]]: ...
def access_message(
user_profile: UserProfile,
message_id: int,
get_user_message: Optional[Literal["exists", "object"]] = None,
lock_message: bool = False,
) -> Union[Message, Tuple[Message, bool], Tuple[Message, Optional[UserMessage]]]:
) -> Message:
"""You can access a message by ID in our APIs that either:
(1) You received or have previously accessed via starring
(aka have a UserMessage row for).
@ -319,26 +293,36 @@ def access_message(
except Message.DoesNotExist:
raise JsonableError(_("Invalid message(s)"))
if get_user_message == "object":
user_message = get_usermessage_by_message_id(user_profile, message_id)
has_user_message = lambda: user_message is not None
elif get_user_message == "exists":
local_exists = UserMessage.objects.filter(
user_profile=user_profile, message_id=message_id
).exists()
has_user_message = lambda: local_exists
else:
has_user_message = lambda: UserMessage.objects.filter(
user_profile=user_profile, message_id=message_id
).exists()
has_user_message = lambda: UserMessage.objects.filter(
user_profile=user_profile, message_id=message_id
).exists()
if has_message_access(user_profile, message, has_user_message=has_user_message):
if get_user_message is None:
return message
if get_user_message == "exists":
return (message, local_exists)
if get_user_message == "object":
return (message, user_message)
return message
raise JsonableError(_("Invalid message(s)"))
def access_message_and_usermessage(
user_profile: UserProfile,
message_id: int,
lock_message: bool = False,
) -> Tuple[Message, Optional[UserMessage]]:
"""As access_message, but also returns the usermessage, if any."""
try:
base_query = Message.objects.select_related(*Message.DEFAULT_SELECT_RELATED)
if lock_message:
# We want to lock only the `Message` row, and not the related fields
# because the `Message` row only has a possibility of races.
base_query = base_query.select_for_update(of=("self",))
message = base_query.get(id=message_id)
except Message.DoesNotExist:
raise JsonableError(_("Invalid message(s)"))
user_message = get_usermessage_by_message_id(user_profile, message_id)
has_user_message = lambda: user_message is not None
if has_message_access(user_profile, message, has_user_message=has_user_message):
return (message, user_message)
raise JsonableError(_("Invalid message(s)"))

View File

@ -42,7 +42,7 @@ from zerver.lib.avatar import absolute_avatar_url, get_avatar_for_inaccessible_u
from zerver.lib.display_recipient import get_display_recipient
from zerver.lib.emoji_utils import hex_codepoint_to_emoji
from zerver.lib.exceptions import ErrorCode, JsonableError
from zerver.lib.message import access_message, huddle_users
from zerver.lib.message import access_message_and_usermessage, huddle_users
from zerver.lib.outgoing_http import OutgoingSession
from zerver.lib.remote_server import (
send_json_to_push_bouncer,
@ -1289,11 +1289,8 @@ def handle_push_notification(user_profile_id: int, missed_message: Dict[str, Any
with transaction.atomic(savepoint=False):
try:
(message, user_message) = access_message(
user_profile,
missed_message["message_id"],
lock_message=True,
get_user_message="object",
(message, user_message) = access_message_and_usermessage(
user_profile, missed_message["message_id"], lock_message=True
)
except JsonableError:
if ArchivedMessage.objects.filter(id=missed_message["message_id"]).exists():

View File

@ -14,7 +14,12 @@ from zerver.actions.message_edit import check_update_message
from zerver.context_processors import get_valid_realm_from_request
from zerver.lib.exceptions import JsonableError
from zerver.lib.html_diff import highlight_html_differences
from zerver.lib.message import access_message, access_web_public_message, messages_for_ids
from zerver.lib.message import (
access_message,
access_message_and_usermessage,
access_web_public_message,
messages_for_ids,
)
from zerver.lib.request import RequestNotes
from zerver.lib.response import json_success
from zerver.lib.timestamp import datetime_to_timestamp
@ -197,9 +202,7 @@ def json_fetch_raw_message(
message = access_web_public_message(realm, message_id)
user_profile = None
else:
(message, user_message) = access_message(
maybe_user_profile, message_id, get_user_message="object"
)
(message, user_message) = access_message_and_usermessage(maybe_user_profile, message_id)
user_profile = maybe_user_profile
flags = ["read"]