Skip to content

Commit

Permalink
[Feature] Device transform (pytorch#1472)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Oct 10, 2023
1 parent 3366f93 commit 3c63a58
Show file tree
Hide file tree
Showing 8 changed files with 287 additions and 38 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ to be able to create this other composition:
CatTensors
CenterCrop
Compose
DeviceCastTransform
DiscreteActionProjection
DoubleToFloat
DTypeCastTransform
Expand Down
2 changes: 1 addition & 1 deletion test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(
**kwargs,
):
super().__init__(
device="cpu",
device=kwargs.pop("device", "cpu"),
dtype=torch.get_default_dtype(),
)
self.set_seed(seed)
Expand Down
102 changes: 102 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@
from torchrl.data import (
BoundedTensorSpec,
CompositeSpec,
LazyMemmapStorage,
LazyTensorStorage,
ReplayBuffer,
TensorDictReplayBuffer,
TensorStorage,
UnboundedContinuousTensorSpec,
)
from torchrl.envs import (
Expand All @@ -49,6 +51,7 @@
CatTensors,
CenterCrop,
Compose,
DeviceCastTransform,
DiscreteActionProjection,
DoubleToFloat,
EnvBase,
Expand Down Expand Up @@ -8133,6 +8136,105 @@ def test_kl_lstm(self):
klt(env.rollout(3, policy))


class TestDeviceCastTransform(TransformBase):
    """Tests for :class:`DeviceCastTransform`.

    Every test uses two CPU devices with distinct indices (``"cpu:0"`` and
    ``"cpu:1"``) and asserts on the ``device`` attribute of the env, the
    transform or the tensordict it returns.
    """

    def test_single_trans_env_check(self):
        """A transformed single env reports the destination device."""
        env = ContinuousActionVecMockEnv(device="cpu:0")
        env = TransformedEnv(env, DeviceCastTransform("cpu:1"))
        assert env.device == torch.device("cpu:1")
        check_env_specs(env)

    def test_serial_trans_env_check(self):
        """A SerialEnv of transformed envs reports the destination device."""

        def make_env():
            return TransformedEnv(
                ContinuousActionVecMockEnv(device="cpu:0"),
                DeviceCastTransform("cpu:1"),
            )

        env = SerialEnv(2, make_env)
        assert env.device == torch.device("cpu:1")
        check_env_specs(env)

    def test_parallel_trans_env_check(self):
        """A ParallelEnv of transformed envs reports the destination device."""

        def make_env():
            return TransformedEnv(
                ContinuousActionVecMockEnv(device="cpu:0"),
                DeviceCastTransform("cpu:1"),
            )

        env = ParallelEnv(2, make_env)
        assert env.device == torch.device("cpu:1")
        check_env_specs(env)

    def test_trans_serial_env_check(self):
        """A transformed SerialEnv reports the destination device."""

        def make_env():
            return ContinuousActionVecMockEnv(device="cpu:0")

        env = TransformedEnv(SerialEnv(2, make_env), DeviceCastTransform("cpu:1"))
        assert env.device == torch.device("cpu:1")
        check_env_specs(env)

    def test_trans_parallel_env_check(self):
        """A transformed ParallelEnv reports the destination device."""

        def make_env():
            return ContinuousActionVecMockEnv(device="cpu:0")

        env = TransformedEnv(ParallelEnv(2, make_env), DeviceCastTransform("cpu:1"))
        assert env.device == torch.device("cpu:1")
        check_env_specs(env)

    def test_transform_no_env(self):
        """Without an env, ``_call`` moves a cpu:0 tensordict to cpu:1."""
        t = DeviceCastTransform("cpu:1", "cpu:0")
        assert t._call(TensorDict({}, [], device="cpu:0")).device == torch.device(
            "cpu:1"
        )

    def test_transform_compose(self):
        """Inside a Compose, ``_call`` casts to cpu:1 and ``_inv_call`` to cpu:0."""
        t = Compose(DeviceCastTransform("cpu:1", "cpu:0"))
        assert t._call(TensorDict({}, [], device="cpu:0")).device == torch.device(
            "cpu:1"
        )
        assert t._inv_call(TensorDict({}, [], device="cpu:1")).device == torch.device(
            "cpu:0"
        )

    def test_transform_env(self):
        """The transform exposes the destination and the env's original device."""
        env = ContinuousActionVecMockEnv(device="cpu:0")
        assert env.device == torch.device("cpu:0")
        env = TransformedEnv(env, DeviceCastTransform("cpu:1"))
        assert env.device == torch.device("cpu:1")
        assert env.transform.device == torch.device("cpu:1")
        assert env.transform.orig_device == torch.device("cpu:0")

    def test_transform_model(self):
        """The transform casts to cpu:1 when run as part of an ``nn.Sequential``."""
        t = Compose(DeviceCastTransform("cpu:1", "cpu:0"))
        m = nn.Sequential(t)
        # Call the container rather than the bare transform so that the
        # nn.Module forward path is what the assertion exercises.
        assert m(TensorDict({}, [], device="cpu:0")).device == torch.device("cpu:1")

    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
    @pytest.mark.parametrize(
        "storage", [TensorStorage, LazyTensorStorage, LazyMemmapStorage]
    )
    def test_transform_rb(self, rbclass, storage):
        """Data added on cpu:1 is stored on cpu:0 and sampled back on cpu:1."""
        t = Compose(DeviceCastTransform("cpu:1", "cpu:0"))
        # TensorStorage is the only storage here that is given a pre-allocated
        # buffer; the lazy storages are built from max_size alone.
        storage_kwargs = (
            {
                "storage": TensorDict(
                    {"a": torch.zeros(20, 1, device="cpu:0")}, [20], device="cpu:0"
                )
            }
            if storage is TensorStorage
            else {}
        )
        rb = rbclass(storage=storage(max_size=20, device="auto", **storage_kwargs))
        rb.append_transform(t)
        rb.add(TensorDict({"a": [1]}, [], device="cpu:1"))
        assert rb._storage._storage.device == torch.device("cpu:0")
        assert rb.sample(4).device == torch.device("cpu:1")

    def test_transform_inverse(self):
        """``_inv_call`` moves a cpu:1 tensordict back to cpu:0."""
        t = DeviceCastTransform("cpu:1", "cpu:0")
        assert t._inv_call(TensorDict({}, [], device="cpu:1")).device == torch.device(
            "cpu:0"
        )


if __name__ == "__main__":
    # Forward any command-line arguments argparse does not recognise to pytest.
    _, extra_args = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *extra_args])
33 changes: 22 additions & 11 deletions torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,13 @@ def add(self, data: Any) -> int:
Returns:
index where the data lives in the replay buffer.
"""
if self._transform is not None and (
is_tensor_collection(data) or len(self._transform)
):
data = self._transform.inv(data)
return self._add(data)

def _add(self, data):
with self._replay_lock:
index = self._writer.add(data)
self._sampler.add(index)
Expand All @@ -271,9 +278,9 @@ def extend(self, data: Sequence) -> torch.Tensor:
Returns:
Indices of the data added to the replay buffer.
"""
if self._transform is not None and is_tensor_collection(data):
data = self._transform.inv(data)
elif self._transform is not None and len(self._transform):
if self._transform is not None and (
is_tensor_collection(data) or len(self._transform)
):
data = self._transform.inv(data)
return self._extend(data)

Expand Down Expand Up @@ -675,19 +682,24 @@ def _get_priority(self, tensordict: TensorDictBase) -> Optional[torch.Tensor]:
return priority

def add(self, data: TensorDictBase) -> int:
if self._transform is not None:
data = self._transform.inv(data)

if is_tensor_collection(data):
data_add = TensorDict(
{
"_data": data,
},
batch_size=[],
device=data.device,
)
if data.batch_size:
data_add["_rb_batch_size"] = torch.tensor(data.batch_size)

else:
data_add = data
index = super().add(data_add)

index = super()._add(data_add)
if is_tensor_collection(data_add):
data_add.set("index", index)

Expand All @@ -699,7 +711,8 @@ def add(self, data: TensorDictBase) -> int:
def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor:
if is_tensor_collection(tensordicts):
tensordicts = TensorDict(
{"_data": tensordicts}, batch_size=tensordicts.batch_size[:1]
{"_data": tensordicts},
batch_size=tensordicts.batch_size[:1],
)
if tensordicts.batch_dims > 1:
# we want the tensordict to have one dimension only. The batch size
Expand Down Expand Up @@ -730,14 +743,12 @@ def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor:
stacked_td = tensordicts

if self._transform is not None:
stacked_td.set("_data", self._transform.inv(stacked_td.get("_data")))
tensordicts = self._transform.inv(stacked_td.get("_data"))
stacked_td.set("_data", tensordicts)
if tensordicts.device is not None:
stacked_td = stacked_td.to(tensordicts.device)

index = super()._extend(stacked_td)
# stacked_td.set(
# "index",
# torch.tensor(index, dtype=torch.int, device=stacked_td.device),
# inplace=True,
# )
self.update_tensordict_priority(stacked_td)
return index

Expand Down
35 changes: 27 additions & 8 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,14 @@ class TensorStorage(Storage):
"""A storage for tensors and tensordicts.
Args:
data (tensor or TensorDict): the data buffer to be used.
storage (tensor or TensorDict): the data buffer to be used.
max_size (int): size of the storage, i.e. maximum number of elements stored
in the buffer.
device (torch.device, optional): device where the sampled tensors will be
stored and sent. Default is :obj:`torch.device("cpu")`.
If "auto" is passed, the device is automatically gathered from the
first batch of data passed. This is not enabled by default to avoid
data placed on GPU by mistake, causing OOM issues.
Examples:
>>> data = TensorDict({
Expand Down Expand Up @@ -230,7 +233,7 @@ def __new__(cls, *args, **kwargs):
cls._storage = None
return super().__new__(cls)

def __init__(self, storage, max_size=None, device=None):
def __init__(self, storage, max_size=None, device="cpu"):
if not ((storage is None) ^ (max_size is None)):
if storage is None:
raise ValueError("Expected storage to be non-null.")
Expand All @@ -247,7 +250,13 @@ def __init__(self, storage, max_size=None, device=None):
self._len = max_size
else:
self._len = 0
self.device = device if device else torch.device("cpu")
self.device = (
torch.device(device)
if device != "auto"
else storage.device
if storage is not None
else "auto"
)
self._storage = storage

def state_dict(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -345,6 +354,9 @@ class LazyTensorStorage(TensorStorage):
in the buffer.
device (torch.device, optional): device where the sampled tensors will be
stored and sent. Default is :obj:`torch.device("cpu")`.
If "auto" is passed, the device is automatically gathered from the
first batch of data passed. This is not enabled by default to avoid
data placed on GPU by mistake, causing OOM issues.
Examples:
>>> data = TensorDict({
Expand Down Expand Up @@ -396,12 +408,14 @@ class LazyTensorStorage(TensorStorage):
"""

def __init__(self, max_size, device=None):
super().__init__(None, max_size, device=device)
def __init__(self, max_size, device="cpu"):
super().__init__(storage=None, max_size=max_size, device=device)

def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
if VERBOSE:
print("Creating a TensorStorage...")
if self.device == "auto":
self.device = data.device
if isinstance(data, torch.Tensor):
# if Tensor, we just create an empty tensor of the desired shape, device and dtype
out = torch.empty(
Expand Down Expand Up @@ -436,6 +450,9 @@ class LazyMemmapStorage(LazyTensorStorage):
scratch_dir (str or path): directory where memmap-tensors will be written.
device (torch.device, optional): device where the sampled tensors will be
stored and sent. Default is :obj:`torch.device("cpu")`.
If "auto" is passed, the device is automatically gathered from the
first batch of data passed. This is not enabled by default to avoid
data placed on GPU by mistake, causing OOM issues.
Examples:
>>> data = TensorDict({
Expand Down Expand Up @@ -486,15 +503,15 @@ class LazyMemmapStorage(LazyTensorStorage):
"""

def __init__(self, max_size, scratch_dir=None, device=None):
def __init__(self, max_size, scratch_dir=None, device="cpu"):
super().__init__(max_size)
self.initialized = False
self.scratch_dir = None
if scratch_dir is not None:
self.scratch_dir = str(scratch_dir)
if self.scratch_dir[-1] != "/":
self.scratch_dir += "/"
self.device = device if device else torch.device("cpu")
self.device = torch.device(device) if device != "auto" else device
self._len = 0

def state_dict(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -552,6 +569,8 @@ def load_state_dict(self, state_dict):
def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
if VERBOSE:
print("Creating a MemmapStorage...")
if self.device == "auto":
self.device = data.device
if isinstance(data, torch.Tensor):
# if Tensor, we just create a MemmapTensor of the desired shape, device and dtype
out = MemmapTensor(
Expand Down Expand Up @@ -682,7 +701,7 @@ def _get_default_collate(storage, _is_tensordict=False):
return torch.utils.data._utils.collate.default_collate
elif isinstance(storage, LazyMemmapStorage):
return _collate_as_tensor
elif isinstance(storage, (LazyTensorStorage,)):
elif isinstance(storage, (TensorStorage,)):
return _collate_contiguous
else:
raise NotImplementedError(
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
CatTensors,
CenterCrop,
Compose,
DeviceCastTransform,
DiscreteActionProjection,
DoubleToFloat,
DTypeCastTransform,
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
CatTensors,
CenterCrop,
Compose,
DeviceCastTransform,
DiscreteActionProjection,
DoubleToFloat,
DTypeCastTransform,
Expand Down
Loading

0 comments on commit 3c63a58

Please sign in to comment.