Skip to content

Commit

Permalink
[AL-2017] Add decode method to Pytorch API (activeloopai#1991)
Browse files Browse the repository at this point in the history
* added decode method to new dataloader

* add warning

* update error

* added pil decompression to sample compressed chunk

* add decode method to python dataloader

* experimental -> enterprise

* added new dataloader api

* updated tests

* fixes after merge

* cleanup

* lint fix

* convert pil to np

* remove

* remove unused import

* fixes and tests

* [DL-895] Enterprise API reference (activeloopai#2006)

* init

* update

* docstrings

* update docs

* fix

* fixes

* fix pytorch docstring

* fix error

* update eval docs

* eval doc

* add note

* darg

* fix legacy dataloaders

* add linting

* lint fixes

* mypy fix

* darg fix

* fixes

* test fix

* docs update

Co-authored-by: Fayaz Rahman <fayazrahman4u@gmail.com>
  • Loading branch information
AbhinavTuli and FayazRahman authored Nov 15, 2022
1 parent ca708f1 commit 9d2c9a3
Show file tree
Hide file tree
Showing 30 changed files with 530 additions and 253 deletions.
4 changes: 4 additions & 0 deletions deeplake/core/chunk/sample_compressed_chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def read_sample( # type: ignore
stream: bool = False,
decompress: bool = True,
is_tile: bool = False,
to_pil: bool = False,
):
if self.is_empty_tensor:
raise EmptyTensorError(
Expand Down Expand Up @@ -129,7 +130,10 @@ def read_sample( # type: ignore
end_idx=stop,
step=step,
reverse=reverse,
to_pil=to_pil,
)
if to_pil:
return sample

if squeeze:
sample = sample.squeeze(0)
Expand Down
14 changes: 13 additions & 1 deletion deeplake/core/chunk_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
from deeplake.core.sample import Sample
from itertools import chain, repeat
from collections.abc import Iterable
from PIL import Image # type: ignore


class ChunkEngine:
Expand Down Expand Up @@ -1481,7 +1482,8 @@ def read_sample_from_chunk(
cast: bool = True,
copy: bool = False,
decompress: bool = True,
) -> np.ndarray:
to_pil: bool = False,
) -> Union[np.ndarray, Image.Image]:
enc = self.chunk_id_encoder
if self.is_fixed_shape and self.sample_compression is None:
num_samples_per_chunk = self.num_samples_per_chunk
Expand All @@ -1490,6 +1492,16 @@ def read_sample_from_chunk(
local_sample_index = enc.translate_index_relative_to_chunks(
global_sample_index
)
if to_pil:
assert isinstance(chunk, SampleCompressedChunk)
return chunk.read_sample(
local_sample_index,
cast=cast,
copy=copy,
decompress=decompress,
to_pil=True,
)

return chunk.read_sample(
local_sample_index, cast=cast, copy=copy, decompress=decompress
)
Expand Down
11 changes: 7 additions & 4 deletions deeplake/core/compression.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import io
from logging import warning
import deeplake
from deeplake.util.exceptions import (
Expand All @@ -18,7 +17,7 @@
from typing import Union, Tuple, Sequence, List, Optional, BinaryIO
import numpy as np
from pathlib import Path
from PIL import Image, UnidentifiedImageError # type: ignore
from PIL import Image # type: ignore
from io import BytesIO

import mmap
Expand Down Expand Up @@ -248,7 +247,8 @@ def decompress_array(
end_idx: Optional[int] = None,
step: Optional[int] = None,
reverse: bool = False,
) -> np.ndarray:
to_pil: bool = False,
) -> Union[np.ndarray, Image.Image]:
"""Decompress some buffer into a numpy array, or into a PIL image when ``to_pil`` is True. It is expected that all meta information is
stored inside `buffer`.
Expand All @@ -265,13 +265,14 @@ def decompress_array(
end_idx: (int, Optional): Applicable only for video compressions. Index of last frame (exclusive).
step: (int, Optional): Applicable only for video compressions. Step size for seeking.
reverse (bool): Applicable only for video compressions. Reverses output numpy array if set to True.
to_pil (bool): If True, will return a PIL image instead of a numpy array.
Raises:
SampleDecompressionError: If decompression fails.
ValueError: If dtype and shape are not specified for byte compression.
Returns:
np.ndarray: Array from the decompressed buffer.
Union[np.ndarray, Image.Image]: Decompressed array or PIL image.
"""
compr_type = get_compression_type(compression)
if compr_type == BYTE_COMPRESSION:
Expand Down Expand Up @@ -304,6 +305,8 @@ def decompress_array(
if not isinstance(buffer, str):
buffer = BytesIO(buffer) # type: ignore
img = Image.open(buffer) # type: ignore
if to_pil:
return img
arr = np.array(img)
if shape is not None:
arr = arr.reshape(shape)
Expand Down
134 changes: 111 additions & 23 deletions deeplake/core/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,7 +1344,7 @@ def diff(
The dictionary will always have 2 keys, "dataset" and "tensors". The values corresponding to these keys are detailed below:
- If ``id_1`` and ``id_2`` are None, both the keys will have a single list as their value. This list will contain a dictionary describing changes compared to the previous commit.
- If only ``id_1`` is provided, both keys will have a tuple of 2 lists as their value. The lists will contain dictionaries describing commitwise differences between commits. The 2 lists will range from current state and ``id_1` to most recent common ancestor the commits respectively.
- If only ``id_1`` is provided, both keys will have a tuple of 2 lists as their value. The lists will contain dictionaries describing commitwise differences between commits. The 2 lists will range from the current state and ``id_1`` to the most recent common ancestor of the commits, respectively.
- If only ``id_2`` is provided, a ValueError will be raised.
- If both ``id_1`` and ``id_2`` are provided, both keys will have a tuple of 2 lists as their value. The lists will contain dictionaries describing commitwise differences between commits. The 2 lists will range from ``id_1`` and ``id_2`` to the most recent common ancestor of the commits, respectively.
Expand Down Expand Up @@ -1464,7 +1464,6 @@ def pytorch(
self,
transform: Optional[Callable] = None,
tensors: Optional[Sequence[str]] = None,
tobytes: Union[bool, Sequence[str]] = False,
num_workers: int = 1,
batch_size: int = 1,
drop_last: bool = False,
Expand All @@ -1477,29 +1476,37 @@ def pytorch(
return_index: bool = True,
pad_tensors: bool = False,
transform_kwargs: Optional[Dict[str, Any]] = None,
decode_method: Optional[Dict[str, str]] = None,
):
"""Converts the dataset into a pytorch Dataloader.
Args:
transform (Callable, Optional): Transformation function to be applied to each sample.
tensors (List, Optional): Optionally provide a list of tensor names in the ordering that your training script expects. For example, if you have a dataset that has "image" and "label" tensors, if `tensors=["image", "label"]`, your training script should expect each batch will be provided as a tuple of (image, label).
tobytes (bool): If ``True``, samples will not be decompressed and their raw bytes will be returned instead of numpy arrays. Can also be a list of tensors, in which case those tensors alone will not be decompressed.
tensors (List, Optional): Optionally provide a list of tensor names in the ordering that your training script expects. For example, if you have a dataset that has "image" and "label" tensors, if ``tensors=["image", "label"]``, your training script should expect each batch will be provided as a tuple of (image, label).
num_workers (int): The number of workers to use for fetching data in parallel.
batch_size (int): Number of samples per batch to load. Default value is 1.
drop_last (bool): Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
if ``False`` and the size of dataset is not divisible by the batch size, then the last batch will be smaller. Default value is False.
if ``False`` and the size of dataset is not divisible by the batch size, then the last batch will be smaller. Default value is ``False``.
Read torch.utils.data.DataLoader docs for more details.
collate_fn (Callable, Optional): merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
Read torch.utils.data.DataLoader docs for more details.
pin_memory (bool): If ``True``, the data loader will copy Tensors into CUDA pinned memory before returning them. Default value is False.
pin_memory (bool): If ``True``, the data loader will copy Tensors into CUDA pinned memory before returning them. Default value is ``False``.
Read torch.utils.data.DataLoader docs for more details.
shuffle (bool): If ``True``, the data loader will shuffle the data indices. Default value is False. Details about how Deep Lake shuffles data can be found at https://docs.activeloop.ai/how-hub-works/shuffling-in-ds.pytorch
shuffle (bool): If ``True``, the data loader will shuffle the data indices. Default value is False. Details about how Deep Lake shuffles data can be found at `Shuffling in ds.pytorch() <https://docs.activeloop.ai/how-it-works/shuffling-in-ds.pytorch>`_
buffer_size (int): The size of the buffer used to shuffle the data in MBs. Defaults to 2048 MB. Increasing the buffer_size will increase the extent of shuffling.
use_local_cache (bool): If ``True``, the data loader will use a local cache to store data. The default cache location is ~/.activeloop/cache, but it can be changed by setting the LOCAL_CACHE_PREFIX environment variable. This is useful when the dataset can fit on the machine and we don't want to fetch the data multiple times for each iteration. Default value is False
use_local_cache (bool): If ``True``, the data loader will use a local cache to store data. The default cache location is ~/.activeloop/cache, but it can be changed by setting the ``LOCAL_CACHE_PREFIX`` environment variable. This is useful when the dataset can fit on the machine and we don't want to fetch the data multiple times for each iteration. Default value is ``False``
use_progress_bar (bool): If ``True``, tqdm will be wrapped around the returned dataloader. Default value is True.
return_index (bool): If ``True``, the returned dataloader will have a key "index" that contains the index of the sample(s) in the original dataset. Default value is True.
pad_tensors (bool): If ``True``, shorter tensors will be padded to the length of the longest tensor. Default value is False.
transform_kwargs (optional, Dict[str, Any]): Additional kwargs to be passed to `transform`.
transform_kwargs (optional, Dict[str, Any]): Additional kwargs to be passed to ``transform``.
decode_method (Dict[str, str], Optional): A dictionary of decode methods for each tensor. Defaults to ``None``.
- Supported decode methods are:
:'numpy': Default behaviour. Returns samples as numpy arrays.
:'tobytes': Returns raw bytes of the samples.
:'pil': Returns samples as PIL images. Especially useful when the transformation uses torchvision transforms that
require PIL images as input. Only supported for tensors with ``sample_compression='jpeg'`` or ``'png'``.
Returns:
A torch.utils.data.DataLoader object.
Expand All @@ -1520,7 +1527,6 @@ def pytorch(
self,
transform=transform,
tensors=tensors,
tobytes=tobytes,
num_workers=num_workers,
batch_size=batch_size,
drop_last=drop_last,
Expand All @@ -1531,13 +1537,82 @@ def pytorch(
use_local_cache=use_local_cache,
return_index=return_index,
pad_tensors=pad_tensors,
decode_method=decode_method,
)

if use_progress_bar:
dataloader = tqdm(dataloader, desc=self.path, total=len(self) // batch_size)
dataset_read(self)
return dataloader

def dataloader(self):
    """Returns a :class:`~deeplake.enterprise.DeepLakeDataLoader` object. To use this, install deeplake with ``pip install deeplake[enterprise]``.

    Returns:
        ~deeplake.enterprise.DeepLakeDataLoader: A :class:`deeplake.enterprise.DeepLakeDataLoader` object.

    Examples:

        Creating a simple dataloader object which returns a batch of numpy arrays

        >>> import deeplake
        >>> ds_train = deeplake.load('hub://activeloop/fashion-mnist-train')
        >>> train_loader = ds_train.dataloader().numpy()
        >>> for i, data in enumerate(train_loader):
        ...     # custom logic on data
        ...     pass

        Creating dataloader with custom transformation and batch size

        >>> import deeplake
        >>> import torch
        >>> from torchvision import datasets, transforms, models
        >>>
        >>> ds_train = deeplake.load('hub://activeloop/fashion-mnist-train')
        >>> tform = transforms.Compose([
        ...     transforms.ToPILImage(), # Must convert to PIL image for subsequent operations to run
        ...     transforms.RandomRotation(20), # Image augmentation
        ...     transforms.ToTensor(), # Must convert to pytorch tensor for subsequent operations to run
        ...     transforms.Normalize([0.5], [0.5]),
        ... ])
        ...
        >>> batch_size = 32
        >>> # create a dataloader by chaining a transform function and a batch size; it returns batches of pytorch tensors
        >>> train_loader = ds_train.dataloader()\\
        ...     .transform({'images': tform, 'labels': None})\\
        ...     .batch(batch_size)\\
        ...     .shuffle()\\
        ...     .pytorch()
        ...
        >>> # loop over the elements
        >>> for i, data in enumerate(train_loader):
        ...     # custom logic on data
        ...     pass

        Creating dataloader and chaining with query

        >>> ds_train = deeplake.load('hub://activeloop/coco-train')
        >>> train_loader = ds_train.dataloader()\\
        ...     .query("(select * where contains(categories, 'car') limit 1000) union (select * where contains(categories, 'motorcycle') limit 1000)")\\
        ...     .pytorch()
        ...
        >>> # loop over the elements
        >>> for i, data in enumerate(train_loader):
        ...     # custom logic on data
        ...     pass

    **Restrictions**

    The new high performance C++ dataloader is part of our Growth and Enterprise Plan.

    - Users of our Community plan can create dataloaders on Activeloop datasets ("hub://activeloop/..." datasets).
    - To run queries on your own datasets, `upgrade your organization's plan <https://www.activeloop.ai/pricing/>`_.
    """
    # Imported at call time rather than at module level; presumably so the
    # optional ``deeplake[enterprise]`` extra is only required when this
    # method is actually used -- TODO confirm.
    from deeplake.enterprise import dataloader

    # Delegate to the enterprise factory, handing it this dataset.
    return dataloader(self)

@deeplake_reporter.record_call
def filter(
self,
Expand Down Expand Up @@ -1591,7 +1666,7 @@ def filter(
return ret

def query(self, query_string: str):
"""Returns a sliced :class:`~deeplake.core.dataset.Dataset` with given query results.
"""Returns a sliced :class:`~deeplake.core.dataset.Dataset` with given query results. To use this, install deeplake with ``pip install deeplake[enterprise]``.
It allows to run SQL like queries on dataset and extract results. See supported keywords and the Tensor Query Language documentation
:ref:`here <tql>`.
Expand All @@ -1616,8 +1691,15 @@ def query(self, query_string: str):
>>> ds_train = deeplake.load('hub://activeloop/coco-train')
>>> query_ds_train = ds_train.query("(select * where contains(categories, 'car') limit 1000) union (select * where contains(categories, 'motorcycle') limit 1000)")
**Restrictions**
Querying datasets is part of our Growth and Enterprise Plan.
- Users of our Community plan can only perform queries on Activeloop datasets ("hub://activeloop/..." datasets).
- To run queries on your own datasets, `upgrade your organization's plan <https://www.activeloop.ai/pricing/>`_.
"""
from deeplake.experimental import query
from deeplake.enterprise import query

return query(self, query_string)

Expand All @@ -1627,14 +1709,13 @@ def sample_by(
replace: Optional[bool] = True,
size: Optional[int] = None,
):
"""Returns a sliced :class:`~deeplake.core.dataset.Dataset` with given weighted sampler applied
"""Returns a sliced :class:`~deeplake.core.dataset.Dataset` with given weighted sampler applied.
To use this, install deeplake with ``pip install deeplake[enterprise]``.
Args:
weights: (Union[str, list, tuple]): If it's string then tql will be run to calculate the weights based on the expression. list and tuple will be treated as the list of the weights per sample
replace: Optional[bool] If true the samples can be repeated in the result view.
(default: ``True``).
size: Optional[int] The length of the result view.
(default: ``len(dataset)``)
weights: (Union[str, list, tuple]): If it's string then tql will be run to calculate the weights based on the expression. list and tuple will be treated as the list of the weights per sample.
replace: Optional[bool] If true the samples can be repeated in the result view. Defaults to ``True``.
size: Optional[int] The length of the result view. Defaults to length of the dataset.
Returns:
Expand All @@ -1651,19 +1732,26 @@ def sample_by(
Sample the dataset treating `labels` tensor as weights.
>>> import deeplake
>>> from deeplake.experimental import query
>>> ds = deeplake.load('hub://activeloop/fashion-mnist-train')
>>> sampled_ds = ds.sample_by("labels")
>>> sampled_ds = ds.sample_by("max_weight(labels == 5: 10, labels == 6: 5)")
Sample the dataset with the given weights;
>>> ds = deeplake.load('hub://activeloop/coco-train')
>>> weights = list()
>>> for i in range(0, len(ds)):
>>> weights.append(i % 5)
>>> for i in range(len(ds)):
... weights.append(i % 5)
...
>>> sampled_ds = ds.sample_by(weights, replace=False)
**Restrictions**
Querying datasets is part of our Growth and Enterprise Plan.
- Users of our Community plan can only use ``sample_by`` on Activeloop datasets ("hub://activeloop/..." datasets).
- To use sampling functionality on your own datasets, `upgrade your organization's plan <https://www.activeloop.ai/pricing/>`_.
"""
from deeplake.experimental import sample_by
from deeplake.enterprise import sample_by

return sample_by(self, weights, replace, size)

Expand Down
Loading

0 comments on commit 9d2c9a3

Please sign in to comment.