Skip to content

Commit

Permalink
test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
levongh committed May 10, 2023
1 parent 7e02e05 commit 244d99d
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions deeplake/enterprise/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

from deeplake.util.remove_cache import get_base_storage
from deeplake.core.index.index import IndexEntry
from deeplake.tests.common import requires_torch, requires_libdeeplake
from deeplake.tests.common import (
requires_torch,
requires_libdeeplake,
convert_data_according_to_torch_version,
)
from deeplake.core.dataset import Dataset
from deeplake.constants import KB

Expand Down Expand Up @@ -63,7 +67,6 @@ def index_transform(sample):
@pytest.mark.parametrize(
"ds",
["hub_cloud_ds", "hub_cloud_gcs_ds"],
indirect=True,
)
def test_pytorch_small(ds):
with ds:
Expand Down Expand Up @@ -458,7 +461,7 @@ def test_pytorch_decode(hub_cloud_ds, compressed_image_paths, compression):
ptds = hub_cloud_ds.dataloader().pytorch(decode_method={"image": "tobytes"})

for i, batch in enumerate(ptds):
image = batch["image"]
image = convert_data_according_to_torch_version(batch["image"])
assert isinstance(image, bytes)
if i < 5 and not compression:
np.testing.assert_array_equal(
Expand Down

0 comments on commit 244d99d

Please sign in to comment.