Skip to content

Commit

Permalink
Merge pull request #457 from activeloopai/fix/to_pytorch_flush
Browse files Browse the repository at this point in the history
Hotfix for to_pytorch
  • Loading branch information
AbhinavTuli authored Jan 15, 2021
2 parents a394e9b + 6a0486d commit 2a9a7e1
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
9 changes: 5 additions & 4 deletions hub/api/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,13 +519,14 @@ def to_pytorch(
num_samples: int, optional
The number of samples required of the dataset that needs to be converted
"""
if "torch" not in sys.modules:
try:
import torch
except ModuleNotFoundError:
raise ModuleNotInstalledException("torch")
import torch

global torch

self.flush() # FIXME Without this some tests in test_converters.py fails, not clear why
if "r" not in self.mode:
self.flush() # FIXME Without this some tests in test_converters.py fails, not clear why
return TorchDataset(
self,
transform,
Expand Down
6 changes: 6 additions & 0 deletions hub/api/tests/test_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,12 @@ def test_to_from_pytorch():
assert (res_ds["label", "d", "e", i].numpy() == i * np.ones((5, 3))).all()


@pytest.mark.skipif(not pytorch_loaded(), reason="requires pytorch to be loaded")
def test_to_pytorch_bug():
ds = hub.Dataset("activeloop/mnist", mode="r")
data = ds.to_pytorch()


if __name__ == "__main__":
with Timer("Test Converters"):
with Timer("from MNIST"):
Expand Down

0 comments on commit 2a9a7e1

Please sign in to comment.