""" License: This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/. """ import torch from hub import Dataset from hub.utils import Timer DATASET_NAMES = ["activeloop/mnist", "activeloop/cifar10_train"] BATCH_SIZES = [1, 16, 128] PREFETCH_SIZES = [1, 4, 16, 128] def time_iter_pytorch( dataset_name="activeloop/mnist", batch_size=1, prefetch_factor=0, process=None ): dset = Dataset(dataset_name, cache=False, storage_cache=False, mode="r") loader = torch.utils.data.DataLoader( dset.to_pytorch(), batch_size=batch_size, prefetch_factor=prefetch_factor, num_workers=1, ) with Timer( f"{dataset_name} PyTorch prefetch {prefetch_factor:03} in batches of {batch_size:03}" ): for idx, (image, label) in enumerate(loader): if process is not None: process(idx, image, label) def time_iter_tensorflow( dataset_name="activeloop/mnist", batch_size=1, prefetch_factor=0, process=None ): dset = Dataset(dataset_name, cache=False, storage_cache=False, mode="r") loader = dset.to_tensorflow().batch(batch_size).prefetch(prefetch_factor) with Timer( f"{dataset_name} TF prefetch {prefetch_factor:03} in batches of {batch_size:03}" ): for idx, batch in enumerate(loader): image = batch["image"] label = batch["label"] if process is not None: process(idx, image, label) if __name__ == "__main__": for name in DATASET_NAMES: for size in BATCH_SIZES: for prefetch in PREFETCH_SIZES: time_iter_pytorch(name, size, prefetch, None) time_iter_tensorflow(name, size, prefetch, None)