diff --git a/zerver/lib/upload/__init__.py b/zerver/lib/upload/__init__.py index 27268487f5..e1e5d83fa5 100644 --- a/zerver/lib/upload/__init__.py +++ b/zerver/lib/upload/__init__.py @@ -224,8 +224,10 @@ def delete_message_attachments(path_ids: list[str]) -> None: return upload_backend.delete_message_attachments(path_ids) -def all_message_attachments(include_thumbnails: bool = False) -> Iterator[tuple[str, datetime]]: - return upload_backend.all_message_attachments(include_thumbnails) +def all_message_attachments( + *, include_thumbnails: bool = False, prefix: str = "" +) -> Iterator[tuple[str, datetime]]: + return upload_backend.all_message_attachments(include_thumbnails, prefix) # Avatar image uploads diff --git a/zerver/lib/upload/base.py b/zerver/lib/upload/base.py index cfc8ec4318..c869c08ff4 100644 --- a/zerver/lib/upload/base.py +++ b/zerver/lib/upload/base.py @@ -68,7 +68,9 @@ class ZulipUploadBackend: self.delete_message_attachment(path_id) def all_message_attachments( - self, include_thumbnails: bool = False + self, + include_thumbnails: bool = False, + prefix: str = "", ) -> Iterator[tuple[str, datetime]]: raise NotImplementedError diff --git a/zerver/lib/upload/local.py b/zerver/lib/upload/local.py index 33e61f20c1..438fdce62e 100644 --- a/zerver/lib/upload/local.py +++ b/zerver/lib/upload/local.py @@ -115,17 +115,22 @@ class LocalUploadBackend(ZulipUploadBackend): @override def all_message_attachments( - self, include_thumbnails: bool = False + self, + include_thumbnails: bool = False, + prefix: str = "", ) -> Iterator[tuple[str, datetime]]: assert settings.LOCAL_UPLOADS_DIR is not None top = settings.LOCAL_UPLOADS_DIR + "/files" - for dirname, subdirnames, files in os.walk(top): + start = top + if prefix != "": + start += f"/{prefix}" + for dirname, subdirnames, files in os.walk(start): if not include_thumbnails and dirname == top and "thumbnail" in subdirnames: subdirnames.remove("thumbnail") for f in files: fullpath = os.path.join(dirname, f) yield ( - os.path.relpath(fullpath, settings.LOCAL_UPLOADS_DIR + "/files"), + os.path.relpath(fullpath, top), timestamp_to_datetime(os.path.getmtime(fullpath)), ) diff --git a/zerver/lib/upload/s3.py b/zerver/lib/upload/s3.py index ae3c7bbaef..b27af1ff9a 100644 --- a/zerver/lib/upload/s3.py +++ b/zerver/lib/upload/s3.py @@ -290,11 +290,13 @@ class S3UploadBackend(ZulipUploadBackend): @override def all_message_attachments( - self, include_thumbnails: bool = False + self, + include_thumbnails: bool = False, + prefix: str = "", ) -> Iterator[tuple[str, datetime]]: client = self.uploads_bucket.meta.client paginator = client.get_paginator("list_objects_v2") - page_iterator = paginator.paginate(Bucket=self.uploads_bucket.name) + page_iterator = paginator.paginate(Bucket=self.uploads_bucket.name, Prefix=prefix) for page in page_iterator: if page["KeyCount"] > 0: diff --git a/zerver/tests/test_upload_local.py b/zerver/tests/test_upload_local.py index 36bd825a5c..4e86b16410 100644 --- a/zerver/tests/test_upload_local.py +++ b/zerver/tests/test_upload_local.py @@ -130,6 +130,15 @@ class LocalStorageTest(UploadSerializeMixin, ZulipTestCase): found_files = [r[0] for r in all_message_attachments()] self.assertEqual(sorted(found_files), ["bar/baz", "bar/troz", "foo", "test/other/file"]) + found_paths = [r[0] for r in all_message_attachments(prefix="bar")] + self.assertEqual(sorted(found_paths), ["bar/baz", "bar/troz"]) + + found_paths = [r[0] for r in all_message_attachments(prefix="test")] + self.assertEqual(found_paths, ["test/other/file"]) + + found_paths = [r[0] for r in all_message_attachments(prefix="missing")] + self.assertEqual(found_paths, []) + write_local_file("files", "thumbnail/thing", b"content") found_files = [r[0] for r in all_message_attachments()] self.assertEqual(sorted(found_files), ["bar/baz", "bar/troz", "foo", "test/other/file"]) diff --git a/zerver/tests/test_upload_s3.py b/zerver/tests/test_upload_s3.py index 0866c525fa..ebc561ef93 100644 --- a/zerver/tests/test_upload_s3.py +++ b/zerver/tests/test_upload_s3.py @@ -183,6 +183,15 @@ class S3Test(ZulipTestCase): found_paths = [r[0] for r in all_message_attachments()] self.assertEqual(sorted(found_paths), sorted(path_ids)) + found_paths = [r[0] for r in all_message_attachments(prefix=str(user_profile.realm_id))] + self.assertEqual(sorted(found_paths), sorted(path_ids)) + + found_paths = [r[0] for r in all_message_attachments(prefix=os.path.dirname(path_ids[0]))] + self.assertEqual(found_paths, [path_ids[0]]) + + found_paths = [r[0] for r in all_message_attachments(prefix="missing")] + self.assertEqual(found_paths, []) + found_paths = [r[0] for r in all_message_attachments(include_thumbnails=True)] for thumbnail_format in THUMBNAIL_OUTPUT_FORMATS: if thumbnail_format.animated: