diff --git a/hub/core/dataset.py b/hub/core/dataset.py index 531a88afa8..d62db5d564 100644 --- a/hub/core/dataset.py +++ b/hub/core/dataset.py @@ -1,4 +1,5 @@ import hub +from tqdm import tqdm # type: ignore import pickle import warnings import posixpath @@ -476,13 +477,14 @@ def pytorch( transform: Optional[Callable] = None, tensors: Optional[Sequence[str]] = None, num_workers: int = 1, - batch_size: Optional[int] = 1, + batch_size: int = 1, drop_last: bool = False, collate_fn: Optional[Callable] = None, pin_memory: bool = False, shuffle: bool = False, buffer_size: int = 10 * 1000, use_local_cache: bool = False, + use_progress_bar: bool = False, ): """Converts the dataset into a pytorch Dataloader. @@ -494,7 +496,7 @@ def pytorch( transform (Callable, optional) : Transformation function to be applied to each sample. tensors (List, optional): Optionally provide a list of tensor names in the ordering that your training script expects. For example, if you have a dataset that has "image" and "label" tensors, if `tensors=["image", "label"]`, your training script should expect each batch will be provided as a tuple of (image, label). num_workers (int): The number of workers to use for fetching data in parallel. - batch_size (int, optional): Number of samples per batch to load. Default value is 1. + batch_size (int): Number of samples per batch to load. Default value is 1. drop_last (bool): Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. Default value is False. Read torch.utils.data.DataLoader docs for more details. @@ -505,13 +507,14 @@ def pytorch( shuffle (bool): If True, the data loader will shuffle the data indices. Default value is False. buffer_size (int): The size of the buffer used to prefetch/shuffle in MB. The buffer uses shared memory under the hood. Default value is 10 GB. 
Increasing the buffer_size will increase the extent of shuffling. use_local_cache (bool): If True, the data loader will use a local cache to store data. This is useful when the dataset can fit on the machine and we don't want to fetch the data multiple times for each iteration. Default value is False. + use_progress_bar (bool): If True, tqdm will be wrapped around the returned dataloader. Default value is False. Returns: A torch.utils.data.DataLoader object. """ from hub.integrations import dataset_to_pytorch - return dataset_to_pytorch( + dataloader = dataset_to_pytorch( self, transform, tensors, @@ -525,6 +528,11 @@ def pytorch( use_local_cache=use_local_cache, ) + if use_progress_bar: + dataloader = tqdm(dataloader, desc=self.path, total=len(self) // batch_size) + + return dataloader + def _get_total_meta(self): """Returns tensor metas all together""" return {