Skip to content

Commit

Permalink
Merge pull request #2342 from activeloopai/tensor_calculation
Browse files Browse the repository at this point in the history
corrected tensor calculation logic
  • Loading branch information
levongh authored May 11, 2023
2 parents 494515a + baa69dc commit 54a1b8c
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 5 deletions.
8 changes: 6 additions & 2 deletions deeplake/enterprise/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

from deeplake.util.remove_cache import get_base_storage
from deeplake.core.index.index import IndexEntry
from deeplake.tests.common import requires_torch, requires_libdeeplake
from deeplake.tests.common import (
requires_torch,
requires_libdeeplake,
convert_data_according_to_torch_version,
)
from deeplake.core.dataset import Dataset
from deeplake.constants import KB

Expand Down Expand Up @@ -458,7 +462,7 @@ def test_pytorch_decode(hub_cloud_ds, compressed_image_paths, compression):
ptds = hub_cloud_ds.dataloader().pytorch(decode_method={"image": "tobytes"})

for i, batch in enumerate(ptds):
image = batch["image"]
image = convert_data_according_to_torch_version(batch["image"])
assert isinstance(image, bytes)
if i < 5 and not compression:
np.testing.assert_array_equal(
Expand Down
4 changes: 4 additions & 0 deletions deeplake/integrations/pytorch/shuffle_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ def _num_torch_tensors(self, sample):
return 0
if isinstance(sample, TorchTensor):
return 1
elif isinstance(sample, bytes):
return 0
elif isinstance(sample, str):
return 0
elif isinstance(sample, dict):
return sum(self._num_torch_tensors(tensor) for tensor in sample.values())
elif isinstance(sample, Sequence):
Expand Down
2 changes: 2 additions & 0 deletions deeplake/requirements/tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@ boto3-stubs[essential]
lz4
rich
wandb

pandas; python_version >= '3.11' and sys_platform == 'win32'
4 changes: 1 addition & 3 deletions deeplake/tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,7 @@ def __exit__(self, *args, **kwargs):


def convert_data_according_to_torch_version(batch):
import torch

if torch.__version__ < "2.0.0":
if isinstance(batch, List):
return batch[0]
else:
return batch

0 comments on commit 54a1b8c

Please sign in to comment.