diff --git a/zerver/lib/rest.py b/zerver/lib/rest.py index 536580b95d..6950eb88ef 100644 --- a/zerver/lib/rest.py +++ b/zerver/lib/rest.py @@ -16,11 +16,14 @@ from django.conf import settings METHODS = ('GET', 'HEAD', 'POST', 'PUT', 'DELETE', 'PATCH') FLAGS = ('override_api_url_scheme') -def never_cache_responses(view_func: Callable[..., ReturnT]) -> Callable[..., ReturnT]: - """Patched version of the standard Django decorator that adds headers - to a response so that it will never be cached. +def default_never_cache_responses( + view_func: Callable[..., HttpResponse]) -> Callable[..., HttpResponse]: + """Patched version of the standard Django never_cache_responses + decorator that adds headers to a response so that it will never be + cached, unless the view code has already set a Cache-Control + header. - We need to patch this because our Django+Tornado + We also need to patch this because our Django+Tornado RespondAsynchronously hack involves returning a value that isn't a Django response object, on which add_never_cache_headers would crash. This only occurs in a case where client-side caching @@ -30,12 +33,14 @@ def never_cache_responses(view_func: Callable[..., ReturnT]) -> Callable[..., Re @wraps(view_func) def _wrapped_view_func(request: HttpRequest, *args: Any, **kwargs: Any) -> ReturnT: response = view_func(request, *args, **kwargs) - if response is not RespondAsynchronously: - add_never_cache_headers(response) + if response is RespondAsynchronously or response.has_header("Cache-Control"): + return response + + add_never_cache_headers(response) return response return _wrapped_view_func -@never_cache_responses +@default_never_cache_responses @csrf_exempt def rest_dispatch(request: HttpRequest, **kwargs: Any) -> HttpResponse: """Dispatch to a REST API endpoint. diff --git a/zerver/tests/test_home.py b/zerver/tests/test_home.py index bd68f4d688..f66a61fc05 100644 --- a/zerver/tests/test_home.py +++ b/zerver/tests/test_home.py @@ -240,6 +240,8 @@ class HomeTest(ZulipTestCase): with queries_captured() as queries: with patch('zerver.lib.cache.cache_set') as cache_mock: result = self._get_home_page(stream='Denmark') + self.assertEqual(set(result["Cache-Control"].split(", ")), + {"must-revalidate", "no-store", "no-cache"}) self.assert_length(queries, 45) self.assert_length(cache_mock.call_args_list, 7) @@ -442,6 +444,8 @@ class HomeTest(ZulipTestCase): self._sanity_check(result) html = result.content.decode('utf-8') self.assertIn('lunch', html) + self.assertEqual(set(result["Cache-Control"].split(", ")), + {"must-revalidate", "no-store", "no-cache"}) def test_notifications_stream(self) -> None: email = self.example_email("hamlet") diff --git a/zerver/tests/test_narrow.py b/zerver/tests/test_narrow.py index 4b5583787a..c85a4c512d 100644 --- a/zerver/tests/test_narrow.py +++ b/zerver/tests/test_narrow.py @@ -1029,6 +1029,9 @@ class GetOldMessagesTest(ZulipTestCase): payload = self.client_get("/json/messages", dict(post_params), **kwargs) self.assert_json_success(payload) + self.assertEqual(set(payload["Cache-Control"].split(", ")), + {"must-revalidate", "no-store", "no-cache", "max-age=0"}) + result = ujson.loads(payload.content) self.assertIn("messages", result) diff --git a/zerver/tests/test_upload.py b/zerver/tests/test_upload.py index 39131da15b..a544de817e 100644 --- a/zerver/tests/test_upload.py +++ b/zerver/tests/test_upload.py @@ -707,6 +707,8 @@ class FileUploadTest(UploadSerializeMixin, ZulipTestCase): self.assertIn(content_disposition, response['Content-disposition']) else: self.assertEqual(response.get('Content-disposition'), None) + self.assertEqual(set(response["Cache-Control"].split(", ")), + {"private", "immutable"}) check_xsend_links('zulip.txt', 'zulip.txt', 'filename="zulip.txt"') check_xsend_links('áéБД.txt', '%C3%A1%C3%A9%D0%91%D0%94.txt', diff --git a/zerver/views/upload.py b/zerver/views/upload.py index bdf86fc534..ab73fca3df 100644 --- a/zerver/views/upload.py +++ b/zerver/views/upload.py @@ -1,6 +1,7 @@ from django.http import HttpRequest, HttpResponse, HttpResponseForbidden, \ HttpResponseNotFound from django.shortcuts import redirect +from django.utils.cache import patch_cache_control from django.utils.translation import ugettext as _ from zerver.lib.response import json_success, json_error @@ -40,8 +41,10 @@ def serve_local(request: HttpRequest, path_id: str) -> HttpResponse: mimetype, encoding = guess_type(local_path) attachment = mimetype not in INLINE_MIME_TYPES - return sendfile(request, local_path, attachment=attachment, - mimetype=mimetype, encoding=encoding) + response = sendfile(request, local_path, attachment=attachment, + mimetype=mimetype, encoding=encoding) + patch_cache_control(response, private=True, immutable=True) + return response def serve_file_backend(request: HttpRequest, user_profile: UserProfile, realm_id_str: str, filename: str) -> HttpResponse: