Skip to content

Commit

Permalink
only download necessary media files for FiftyOneDatasets
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobmarks committed Apr 4, 2024
1 parent d5ea06e commit b71e3fb
Showing 1 changed file with 36 additions and 3 deletions.
39 changes: 36 additions & 3 deletions fiftyone/utils/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,12 @@ def _resolve_dataset_name(config, **kwargs):
return name


def _get_files_to_download(dataset):
filepaths = dataset.values("filepath")
filepaths = [fp for fp in filepaths if not os.path.exists(fp)]
return filepaths


def _load_fiftyone_dataset_from_config(config, **kwargs):
logger.info("Loading dataset")

Expand All @@ -1086,12 +1092,27 @@ def _load_fiftyone_dataset_from_config(config, **kwargs):
splits = _parse_split_kwargs(**kwargs)

download_dir = _get_download_dir(config._repo_id, **kwargs)
hfh.snapshot_download(
repo_id=config._repo_id, repo_type="dataset", local_dir=download_dir
)

init_download_kwargs = {
"repo_id": config._repo_id,
"repo_type": "dataset",
"local_dir": download_dir,
}

dataset_type_name = config._format.strip()

if dataset_type_name == "FiftyOneDataset":
# If the dataset is a FiftyOneDataset, skip the media under data/ for now so
# that only the necessary files are downloaded later
hfh.snapshot_download(
**init_download_kwargs,
ignore_patterns="data/*",
)
else:
hfh.snapshot_download(
**init_download_kwargs,
)

dataset_type = getattr(
__import__("fiftyone.types", fromlist=[dataset_type_name]),
dataset_type_name,
Expand All @@ -1110,6 +1131,18 @@ def _load_fiftyone_dataset_from_config(config, **kwargs):
dataset_kwargs["name"] = name

dataset = fod.Dataset.from_dir(download_dir, **dataset_kwargs)

if dataset_type_name != "FiftyOneDataset":
return dataset

filepaths = _get_files_to_download(dataset)
if filepaths:
logger.info(f"Downloading {len(filepaths)} media files...")
filenames = [os.path.basename(fp) for fp in filepaths]
allowed_globs = ["data/" + fn for fn in filenames]
hfh.snapshot_download(
**init_download_kwargs, allow_patterns=allowed_globs
)
return dataset


Expand Down

0 comments on commit b71e3fb

Please sign in to comment.