do_mark_all_as_read: Split up the work into batches.

Fixes #15403.
This commit is contained in:
Mateusz Mandera 2022-10-02 21:32:36 +02:00 committed by Tim Abbott
parent ef468322f1
commit a410f6b241
10 changed files with 179 additions and 32 deletions

View File

@ -22,6 +22,13 @@ export function mark_all_as_read() {
success: () => { success: () => {
// After marking all messages as read, we reload the browser. // After marking all messages as read, we reload the browser.
// This is useful to avoid leaving ourselves deep in the past. // This is useful to avoid leaving ourselves deep in the past.
// This is also the currently intended behavior in case of partial success
// (response code 200 with result "partially_completed"),
// where the request times out after marking some messages as read,
// so we don't need to distinguish that scenario here.
// TODO: The frontend handling of partial success can be improved
// by re-running the request in a loop, while showing some status indicator
// to the user.
reload.initiate({ reload.initiate({
immediate: true, immediate: true,
save_pointer: false, save_pointer: false,

View File

@ -20,6 +20,14 @@ format used by the Zulip server that they are interacting with.
## Changes in Zulip 6.0 ## Changes in Zulip 6.0
**Feature level 153**
* [`POST /mark_all_as_read`](/api/mark-all-as-read): Messages are now
marked as read in batches, so that progress will be made even if the
request times out because of an extremely large number of unread
messages to process. Upon timeout, a success response with a
"partially_completed" result will be returned by the server.
**Feature level 152** **Feature level 152**
* [`PATCH /messages/{message_id}`](/api/update-message): The * [`PATCH /messages/{message_id}`](/api/update-message): The

View File

@ -33,7 +33,7 @@ DESKTOP_WARNING_VERSION = "5.4.3"
# Changes should be accompanied by documentation explaining what the # Changes should be accompanied by documentation explaining what the
# new level means in templates/zerver/api/changelog.md, as well as # new level means in templates/zerver/api/changelog.md, as well as
# "**Changes**" entries in the endpoint's documentation in `zulip.yaml`. # "**Changes**" entries in the endpoint's documentation in `zulip.yaml`.
API_FEATURE_LEVEL = 152 API_FEATURE_LEVEL = 153
# Bump the minor PROVISION_VERSION to indicate that folks should provision # Bump the minor PROVISION_VERSION to indicate that folks should provision
# only when going from an old version of the code to a newer version. Bump # only when going from an old version of the code to a newer version. Bump

View File

@ -44,37 +44,52 @@ def do_mark_all_as_read(user_profile: UserProfile) -> int:
) )
do_clear_mobile_push_notifications_for_ids([user_profile.id], all_push_message_ids) do_clear_mobile_push_notifications_for_ids([user_profile.id], all_push_message_ids)
batch_size = 2000
count = 0
while True:
with transaction.atomic(savepoint=False): with transaction.atomic(savepoint=False):
query = ( query = (
UserMessage.select_for_update_query() UserMessage.select_for_update_query()
.filter(user_profile=user_profile) .filter(user_profile=user_profile)
.extra(where=[UserMessage.where_unread()]) .extra(where=[UserMessage.where_unread()])[:batch_size]
) )
count = query.update( # This updated_count is the same as the number of UserMessage
# rows selected, because due to the FOR UPDATE lock, we're guaranteed
# that all the selected rows will indeed be updated.
# UPDATE queries don't support LIMIT, so we have to use a subquery
# to do batching.
updated_count = UserMessage.objects.filter(id__in=query).update(
flags=F("flags").bitor(UserMessage.flags.read), flags=F("flags").bitor(UserMessage.flags.read),
) )
event_time = timezone_now()
do_increment_logging_stat(
user_profile,
COUNT_STATS["messages_read::hour"],
None,
event_time,
increment=updated_count,
)
do_increment_logging_stat(
user_profile,
COUNT_STATS["messages_read_interactions::hour"],
None,
event_time,
increment=min(1, updated_count),
)
count += updated_count
if updated_count < batch_size:
break
event = asdict( event = asdict(
ReadMessagesEvent( ReadMessagesEvent(
messages=[], # we don't send messages, since the client reloads anyway messages=[], # we don't send messages, since the client reloads anyway
all=True, all=True,
) )
) )
event_time = timezone_now()
send_event(user_profile.realm, event, [user_profile.id]) send_event(user_profile.realm, event, [user_profile.id])
do_increment_logging_stat(
user_profile, COUNT_STATS["messages_read::hour"], None, event_time, increment=count
)
do_increment_logging_stat(
user_profile,
COUNT_STATS["messages_read_interactions::hour"],
None,
event_time,
increment=min(1, count),
)
return count return count

View File

@ -37,6 +37,7 @@ class ErrorCode(Enum):
PASSWORD_RESET_REQUIRED = auto() PASSWORD_RESET_REQUIRED = auto()
AUTHENTICATION_FAILED = auto() AUTHENTICATION_FAILED = auto()
UNAUTHORIZED = auto() UNAUTHORIZED = auto()
REQUEST_TIMEOUT = auto()
class JsonableError(Exception): class JsonableError(Exception):

View File

@ -45,6 +45,10 @@ def json_success(request: HttpRequest, data: Mapping[str, Any] = {}) -> HttpResp
return json_response(data=data) return json_response(data=data)
def json_partial_success(request: HttpRequest, data: Mapping[str, Any] = {}) -> HttpResponse:
    """Build the HTTP 200 response for a request that was only partly carried out.

    Delegates to json_response with the result type "partially_completed"
    and forwards `data` unchanged.  `request` is accepted but is not used
    in the body.
    """
    partial_response = json_response(status=200, data=data, res_type="partially_completed")
    return partial_response
def json_response_from_error(exception: JsonableError) -> HttpResponse: def json_response_from_error(exception: JsonableError) -> HttpResponse:
""" """
This should only be needed in middleware; in app code, just raise. This should only be needed in middleware; in app code, just raise.

View File

@ -412,7 +412,9 @@ def validate_against_openapi_schema(
if (endpoint, method) in EXCLUDE_DOCUMENTED_ENDPOINTS: if (endpoint, method) in EXCLUDE_DOCUMENTED_ENDPOINTS:
return True return True
# Check if the response matches its code # Check if the response matches its code
if status_code.startswith("2") and (content.get("result", "success").lower() != "success"): if status_code.startswith("2") and (
content.get("result", "success").lower() not in ["success", "partially_completed"]
):
raise SchemaError("Response is not 200 but is validating against 200 schema") raise SchemaError("Response is not 200 but is validating against 200 schema")
# Code is not declared but appears in various 400 responses. If # Code is not declared but appears in various 400 responses. If
# common, it can be added to 400 response schema # common, it can be added to 400 response schema

View File

@ -4533,9 +4533,41 @@ paths:
tags: ["messages"] tags: ["messages"]
description: | description: |
Marks all of the current user's unread messages as read. Marks all of the current user's unread messages as read.
**Changes**: Before Zulip 6.0 (feature level 153), this
request did a single atomic operation, which could time out
when there were tens of thousands of unread messages to mark as read.
It now marks messages as read in batches, starting with the
newest messages, so that progress will be made even if the
request times out.
If the server's processing is interrupted by a timeout, it
will return an HTTP 200 success response with result
"partially_completed". A correct client should repeat the
request when handling such a response.
responses: responses:
"200": "200":
$ref: "#/components/responses/SimpleSuccess" description: Success or partial success.
content:
application/json:
schema:
oneOf:
- allOf:
- $ref: "#/components/schemas/JsonSuccess"
- $ref: "#/components/schemas/SuccessDescription"
- allOf:
- $ref: "#/components/schemas/PartiallyCompleted"
- example:
{
"code": "REQUEST_TIMEOUT",
"msg": "",
"result": "partially_completed",
}
description: |
If the request exceeds its processing time limit after having
successfully marked some messages as read, response code 200
with result "partially_completed" and code "REQUEST_TIMEOUT" will be returned like this:
/mark_stream_as_read: /mark_stream_as_read:
post: post:
operationId: mark-stream-as-read operationId: mark-stream-as-read
@ -16695,6 +16727,23 @@ components:
- error - error
msg: msg:
type: string type: string
PartiallyCompleted:
allOf:
- $ref: "#/components/schemas/JsonResponseBase"
- required:
- result
- code
additionalProperties: false
properties:
result:
enum:
- partially_completed
code:
type: string
description: |
A string that identifies the cause of the partial completion of the request.
msg:
type: string
ApiKeyResponse: ApiKeyResponse:
allOf: allOf:
- $ref: "#/components/schemas/JsonSuccessBase" - $ref: "#/components/schemas/JsonSuccessBase"

View File

@ -1,4 +1,5 @@
from typing import TYPE_CHECKING, Any, List, Mapping, Set from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Mapping, Set
from unittest import mock from unittest import mock
import orjson import orjson
@ -22,6 +23,7 @@ from zerver.lib.message import (
) )
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.test_helpers import get_subscription from zerver.lib.test_helpers import get_subscription
from zerver.lib.timeout import TimeoutExpired
from zerver.lib.user_topics import add_topic_mute from zerver.lib.user_topics import add_topic_mute
from zerver.models import ( from zerver.models import (
Message, Message,
@ -49,6 +51,18 @@ def check_flags(flags: List[str], expected: Set[str]) -> None:
raise AssertionError(f"expected flags (ignoring has_alert_word) to be {expected}") raise AssertionError(f"expected flags (ignoring has_alert_word) to be {expected}")
@contextmanager
def timeout_mock() -> Iterator[None]:
    """Replace zerver.views.message_flags.timeout with a plain function call.

    timeout() doesn't work in the test environment with database operations
    (they don't get committed), so for the duration of the ``with`` block it
    is swapped for a stand-in that just calls the function it was given.
    """

    def run_without_timeout(seconds: int, func: Callable[[], object]) -> object:
        # `seconds` is ignored here: the stand-in enforces no time limit.
        return func()

    patcher = mock.patch("zerver.views.message_flags.timeout", new=run_without_timeout)
    with patcher:
        yield
class FirstUnreadAnchorTests(ZulipTestCase): class FirstUnreadAnchorTests(ZulipTestCase):
""" """
HISTORICAL NOTE: HISTORICAL NOTE:
@ -62,6 +76,7 @@ class FirstUnreadAnchorTests(ZulipTestCase):
self.login("hamlet") self.login("hamlet")
# Mark all existing messages as read # Mark all existing messages as read
with timeout_mock():
result = self.client_post("/json/mark_all_as_read") result = self.client_post("/json/mark_all_as_read")
self.assert_json_success(result) self.assert_json_success(result)
@ -121,6 +136,7 @@ class FirstUnreadAnchorTests(ZulipTestCase):
def test_visible_messages_use_first_unread_anchor(self) -> None: def test_visible_messages_use_first_unread_anchor(self) -> None:
self.login("hamlet") self.login("hamlet")
with timeout_mock():
result = self.client_post("/json/mark_all_as_read") result = self.client_post("/json/mark_all_as_read")
self.assert_json_success(result) self.assert_json_success(result)
@ -563,11 +579,52 @@ class PushNotificationMarkReadFlowsTest(ZulipTestCase):
[third_message_id, fourth_message_id], [third_message_id, fourth_message_id],
) )
with timeout_mock():
result = self.client_post("/json/mark_all_as_read", {}) result = self.client_post("/json/mark_all_as_read", {})
self.assertEqual(self.get_mobile_push_notification_ids(user_profile), []) self.assertEqual(self.get_mobile_push_notification_ids(user_profile), [])
mock_push_notifications.assert_called() mock_push_notifications.assert_called()
class MarkAllAsReadEndpointTest(ZulipTestCase):
    """Tests for POST /json/mark_all_as_read."""

    def _unread_count(self, user_profile: UserProfile) -> int:
        # Count the user's UserMessage rows that match the unread filter.
        unread_rows = UserMessage.objects.filter(user_profile=user_profile).extra(
            where=[UserMessage.where_unread()]
        )
        return unread_rows.count()

    def test_mark_all_as_read_endpoint(self) -> None:
        """After a successful request, the user has no unread messages left."""
        self.login("hamlet")
        hamlet = self.example_user("hamlet")
        othello = self.example_user("othello")
        self.subscribe(hamlet, "Denmark")

        # Send a few stream and personal messages from othello.
        for _ in range(4):
            self.send_stream_message(othello, "Verona", "test")
            self.send_personal_message(othello, hamlet, "test")

        # Sanity check: there must be something to mark as read.
        self.assertNotEqual(self._unread_count(hamlet), 0)

        with timeout_mock():
            response = self.client_post("/json/mark_all_as_read", {})
        self.assert_json_success(response)

        self.assertEqual(self._unread_count(hamlet), 0)

    def test_mark_all_as_read_timeout_response(self) -> None:
        """A TimeoutExpired from timeout() yields a 200 "partially_completed" body."""
        self.login("hamlet")
        timeout_patch = mock.patch(
            "zerver.views.message_flags.timeout", side_effect=TimeoutExpired
        )
        with timeout_patch:
            response = self.client_post("/json/mark_all_as_read", {})

        expected_body = {"result": "partially_completed", "msg": "", "code": "REQUEST_TIMEOUT"}
        self.assertEqual(response.status_code, 200)
        self.assertEqual(orjson.loads(response.content), expected_body)
class GetUnreadMsgsTest(ZulipTestCase): class GetUnreadMsgsTest(ZulipTestCase):
def mute_stream(self, user_profile: UserProfile, stream: Stream) -> None: def mute_stream(self, user_profile: UserProfile, stream: Stream) -> None:
recipient = Recipient.objects.get(type_id=stream.id, type=Recipient.STREAM) recipient = Recipient.objects.get(type_id=stream.id, type=Recipient.STREAM)

View File

@ -8,10 +8,11 @@ from zerver.actions.message_flags import (
do_mark_stream_messages_as_read, do_mark_stream_messages_as_read,
do_update_message_flags, do_update_message_flags,
) )
from zerver.lib.exceptions import JsonableError from zerver.lib.exceptions import ErrorCode, JsonableError
from zerver.lib.request import REQ, RequestNotes, has_request_variables from zerver.lib.request import REQ, RequestNotes, has_request_variables
from zerver.lib.response import json_success from zerver.lib.response import json_partial_success, json_success
from zerver.lib.streams import access_stream_by_id from zerver.lib.streams import access_stream_by_id
from zerver.lib.timeout import TimeoutExpired, timeout
from zerver.lib.topic import user_message_exists_for_topic from zerver.lib.topic import user_message_exists_for_topic
from zerver.lib.validator import check_int, check_list from zerver.lib.validator import check_int, check_list
from zerver.models import UserActivity, UserProfile from zerver.models import UserActivity, UserProfile
@ -50,7 +51,10 @@ def update_message_flags(
@has_request_variables @has_request_variables
def mark_all_as_read(request: HttpRequest, user_profile: UserProfile) -> HttpResponse: def mark_all_as_read(request: HttpRequest, user_profile: UserProfile) -> HttpResponse:
request_notes = RequestNotes.get_notes(request) request_notes = RequestNotes.get_notes(request)
count = do_mark_all_as_read(user_profile) try:
count = timeout(50, lambda: do_mark_all_as_read(user_profile))
except TimeoutExpired:
return json_partial_success(request, data={"code": ErrorCode.REQUEST_TIMEOUT.name})
log_data_str = f"[{count} updated]" log_data_str = f"[{count} updated]"
assert request_notes.log_data is not None assert request_notes.log_data is not None