From b71e3fb58d265af784cb4a7020b1adedbffd5b6d Mon Sep 17 00:00:00 2001 From: Jacob Marks Date: Thu, 4 Apr 2024 18:10:44 -0400 Subject: [PATCH] only download necessary media files for FiftyOneDatasets --- fiftyone/utils/huggingface.py | 39 ++++++++++++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/fiftyone/utils/huggingface.py b/fiftyone/utils/huggingface.py index d81e21eedf..55d880a4bb 100644 --- a/fiftyone/utils/huggingface.py +++ b/fiftyone/utils/huggingface.py @@ -1077,6 +1077,12 @@ def _resolve_dataset_name(config, **kwargs): return name +def _get_files_to_download(dataset): + filepaths = dataset.values("filepath") + filepaths = [fp for fp in filepaths if not os.path.exists(fp)] + return filepaths + + def _load_fiftyone_dataset_from_config(config, **kwargs): logger.info("Loading dataset") @@ -1086,12 +1092,27 @@ def _load_fiftyone_dataset_from_config(config, **kwargs): splits = _parse_split_kwargs(**kwargs) download_dir = _get_download_dir(config._repo_id, **kwargs) - hfh.snapshot_download( - repo_id=config._repo_id, repo_type="dataset", local_dir=download_dir - ) + + init_download_kwargs = { + "repo_id": config._repo_id, + "repo_type": "dataset", + "local_dir": download_dir, + } dataset_type_name = config._format.strip() + if dataset_type_name == "FiftyOneDataset": + # If the dataset is a FiftyOneDataset, we can smart only download the + # necessary files + hfh.snapshot_download( + **init_download_kwargs, + ignore_patterns="data/*", + ) + else: + hfh.snapshot_download( + **init_download_kwargs, + ) + dataset_type = getattr( __import__("fiftyone.types", fromlist=[dataset_type_name]), dataset_type_name, @@ -1110,6 +1131,18 @@ def _load_fiftyone_dataset_from_config(config, **kwargs): dataset_kwargs["name"] = name dataset = fod.Dataset.from_dir(download_dir, **dataset_kwargs) + + if dataset_type_name != "FiftyOneDataset": + return dataset + + filepaths = _get_files_to_download(dataset) + if filepaths: + logger.info(f"Downloading {len(filepaths)} media files...") + filenames = [os.path.basename(fp) for fp in filepaths] + allowed_globs = ["data/" + fn for fn in filenames] + hfh.snapshot_download( + **init_download_kwargs, allow_patterns=allowed_globs + ) return dataset