Skip to content

Commit

Permalink
[2 liner] TorchDataset.__len__ (#1268)
Browse files (browse the repository at this point in the history)
  • Loading branch information
verbose-void authored Oct 18, 2021
1 parent a015ac2 commit 269833e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
5 changes: 5 additions & 0 deletions hub/integrations/pytorch/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def __init__(
if next_storage is not None:
next_storage.clear()

self.hub_dataset = dataset

try:
self.cache = cache(
cache_storage=cache_storage,
Expand All @@ -80,6 +82,9 @@ def __init__(
"Underlying storage of the dataset in MemoryProvider which is not supported."
)

def __len__(self):
    """Return the length of the hub dataset stored on this object.

    Delegates to ``len()`` on ``self.hub_dataset``, the dataset reference
    saved during construction.
    """
    wrapped_dataset = self.hub_dataset
    return len(wrapped_dataset)

def __iter__(self):
for value in self.cache.iterate_samples():
if value is not None:
Expand Down
2 changes: 2 additions & 0 deletions hub/integrations/tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def test_pytorch_small(ds):

dl = ds.pytorch(num_workers=2, batch_size=1)

assert len(dl.dataset) == 16

for _ in range(2):
for i, batch in enumerate(dl):
np.testing.assert_array_equal(
Expand Down

0 comments on commit 269833e

Please sign in to comment.