diff --git a/deeplake/tests/common.py b/deeplake/tests/common.py index eaaa5f6397..fac807d1d2 100644 --- a/deeplake/tests/common.py +++ b/deeplake/tests/common.py @@ -142,9 +142,7 @@ def __exit__(self, *args, **kwargs): def convert_data_according_to_torch_version(batch): - import torch - - if torch.__version__ < "2.0.0": + if isinstance(batch, List): return batch[0] else: return batch