Skip to content

Commit

Permalink
Adds context without progress bars for upload/download
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobmarks committed Apr 5, 2024
1 parent 1179d9a commit a5291a6
Showing 1 changed file with 37 additions and 15 deletions.
52 changes: 37 additions & 15 deletions fiftyone/utils/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
import logging
import os
from packaging.requirements import Requirement
Expand All @@ -31,6 +32,11 @@
callback=lambda: fou.ensure_package("huggingface_hub>=0.20.0"),
)

hfu = fou.lazy_import(
"huggingface_hub.utils",
callback=lambda: fou.ensure_package("huggingface_hub>=0.20.0"),
)


DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
DEFAULT_MEDIA_TYPE = "image"
Expand Down Expand Up @@ -154,11 +160,12 @@ def push_to_hub(

## Upload the dataset to the repo
api = hfh.HfApi(token=token)
api.upload_folder(
folder_path=tmp_dir,
repo_id=repo_id,
repo_type="dataset",
)
with _no_progress_bars():
api.upload_folder(
folder_path=tmp_dir,
repo_id=repo_id,
repo_type="dataset",
)

# Upload preview image or video if provided
if preview_path is not None:
Expand Down Expand Up @@ -457,6 +464,18 @@ def _parse_subset_kwargs(**kwargs):
return subsets


@contextmanager
def _no_progress_bars():
pbs_disabled = hfu.are_progress_bars_disabled()
hfu.disable_progress_bars()
try:
yield
finally:
# Restore the original state
if not pbs_disabled:
hfu.enable_progress_bars()


class HFHubParquetFilesDatasetConfig(HFHubDatasetConfig):
"""Config for a Hugging Face Hub dataset that is stored as parquet files.
Expand Down Expand Up @@ -1094,14 +1113,16 @@ def _load_fiftyone_dataset_from_config(config, **kwargs):

if dataset_type_name == "FiftyOneDataset" and max_samples is not None:
# If the dataset is a FiftyOneDataset, download only the necessary files
hfh.snapshot_download(
**init_download_kwargs,
ignore_patterns="data/*",
)
with _no_progress_bars():
hfh.snapshot_download(
**init_download_kwargs,
ignore_patterns="data/*",
)
else:
hfh.snapshot_download(
**init_download_kwargs,
)
with _no_progress_bars():
hfh.snapshot_download(
**init_download_kwargs,
)

dataset_type = getattr(
__import__("fiftyone.types", fromlist=[dataset_type_name]),
Expand Down Expand Up @@ -1130,9 +1151,10 @@ def _load_fiftyone_dataset_from_config(config, **kwargs):
logger.info(f"Downloading {len(filepaths)} media files...")
filenames = [os.path.basename(fp) for fp in filepaths]
allowed_globs = ["data/" + fn for fn in filenames]
hfh.snapshot_download(
**init_download_kwargs, allow_patterns=allowed_globs
)
with _no_progress_bars():
hfh.snapshot_download(
**init_download_kwargs, allow_patterns=allowed_globs
)
return dataset


Expand Down

0 comments on commit a5291a6

Please sign in to comment.