Skip to content

Commit

Permalink
[Feature] Fine grained DeviceCastTransform (#2041)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Mar 28, 2024
1 parent c98754f commit 2c485dd
Show file tree
Hide file tree
Showing 3 changed files with 365 additions and 24 deletions.
234 changes: 233 additions & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9331,7 +9331,239 @@ def test_transform_inverse(self):
return


class TestDeviceCastTransform(TransformBase):
class TestDeviceCastTransformPart(TransformBase):
@pytest.mark.parametrize("in_keys", ["observation"])
@pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
def test_single_trans_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_inv):
env = ContinuousActionVecMockEnv(device="cpu:0")
env = TransformedEnv(
env,
DeviceCastTransform(
"cpu:1",
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
out_keys_inv=out_keys_inv,
),
)
assert env.device is None
check_env_specs(env)

@pytest.mark.parametrize("in_keys", ["observation"])
@pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
def test_serial_trans_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_inv):
def make_env():
return TransformedEnv(
ContinuousActionVecMockEnv(device="cpu:0"),
DeviceCastTransform(
"cpu:1",
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
out_keys_inv=out_keys_inv,
),
)

env = SerialEnv(2, make_env)
assert env.device is None
check_env_specs(env)

@pytest.mark.parametrize("in_keys", ["observation"])
@pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
def test_parallel_trans_env_check(
self, in_keys, out_keys, in_keys_inv, out_keys_inv
):
def make_env():
return TransformedEnv(
ContinuousActionVecMockEnv(device="cpu:0"),
DeviceCastTransform(
"cpu:1",
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
out_keys_inv=out_keys_inv,
),
)

env = ParallelEnv(
2,
make_env,
mp_start_method="fork" if not torch.cuda.is_available() else "spawn",
)
assert env.device is None
try:
check_env_specs(env)
finally:
env.close()

@pytest.mark.parametrize("in_keys", ["observation"])
@pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
def test_trans_serial_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_inv):
def make_env():
return ContinuousActionVecMockEnv(device="cpu:0")

env = TransformedEnv(
SerialEnv(2, make_env),
DeviceCastTransform(
"cpu:1",
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
out_keys_inv=out_keys_inv,
),
)
assert env.device is None
check_env_specs(env)

@pytest.mark.parametrize("in_keys", ["observation"])
@pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
def test_trans_parallel_env_check(
self, in_keys, out_keys, in_keys_inv, out_keys_inv
):
def make_env():
return ContinuousActionVecMockEnv(device="cpu:0")

env = TransformedEnv(
ParallelEnv(
2,
make_env,
mp_start_method="fork" if not torch.cuda.is_available() else "spawn",
),
DeviceCastTransform(
"cpu:1",
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
out_keys_inv=out_keys_inv,
),
)
assert env.device is None
try:
check_env_specs(env)
finally:
env.close()

def test_transform_no_env(self):
t = DeviceCastTransform("cpu:1", "cpu:0", in_keys=["a"], out_keys=["b"])
td = TensorDict({"a": torch.randn((), device="cpu:0")}, [], device="cpu:0")
tdt = t._call(td)
assert tdt.device is None

@pytest.mark.parametrize("in_keys", ["observation"])
@pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
def test_transform_env(self, in_keys, out_keys, in_keys_inv, out_keys_inv):
env = ContinuousActionVecMockEnv(device="cpu:0")
env = TransformedEnv(
env,
DeviceCastTransform(
"cpu:1",
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
out_keys_inv=out_keys_inv,
),
)
assert env.device is None
assert env.transform.device == torch.device("cpu:1")
assert env.transform.orig_device == torch.device("cpu:0")

def test_transform_compose(self):
t = Compose(
DeviceCastTransform(
"cpu:1",
"cpu:0",
in_keys=["a"],
out_keys=["b"],
in_keys_inv=["c"],
out_keys_inv=["d"],
)
)

td = TensorDict(
{
"a": torch.randn((), device="cpu:0"),
"c": torch.randn((), device="cpu:1"),
},
[],
device="cpu:0",
)
tdt = t._call(td)
tdit = t._inv_call(td)

assert tdt.device is None
assert tdit.device is None

def test_transform_model(self):
t = nn.Sequential(
Compose(
DeviceCastTransform(
"cpu:1",
"cpu:0",
in_keys=["a"],
out_keys=["b"],
in_keys_inv=["c"],
out_keys_inv=["d"],
)
)
)
td = TensorDict(
{
"a": torch.randn((), device="cpu:0"),
"c": torch.randn((), device="cpu:1"),
},
[],
device="cpu:0",
)
tdt = t(td)

assert tdt.device is None

@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
@pytest.mark.parametrize("storage", [LazyTensorStorage])
def test_transform_rb(self, rbclass, storage):
# we don't test casting to cuda on Memmap tensor storage since it's discouraged
t = Compose(
DeviceCastTransform(
"cpu:1",
"cpu:0",
in_keys=["a"],
out_keys=["b"],
in_keys_inv=["c"],
out_keys_inv=["d"],
)
)
rb = rbclass(storage=storage(max_size=20, device="auto"))
rb.append_transform(t)
td = TensorDict(
{
"a": torch.randn((), device="cpu:0"),
"c": torch.randn((), device="cpu:1"),
},
[],
device="cpu:0",
)
rb.add(td)
assert rb._storage._storage.device is None
assert rb.sample(4).device is None

def test_transform_inverse(self):
# Tested before
return


class TestDeviceCastTransformWhole(TransformBase):
def test_single_trans_env_check(self):
env = ContinuousActionVecMockEnv(device="cpu:0")
env = TransformedEnv(env, DeviceCastTransform("cpu:1"))
Expand Down
7 changes: 5 additions & 2 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def decorator(func):

def clear_device_(self):
"""A no-op for all leaf specs (which must have a device)."""
pass
return self

def encode(
self, val: Union[np.ndarray, torch.Tensor], *, ignore_device=False
Expand Down Expand Up @@ -866,6 +866,7 @@ def clear_device_(self):
"""Clears the device of the CompositeSpec."""
for spec in self._specs:
spec.clear_device_()
return self

def __getitem__(self, item):
is_key = isinstance(item, str) or (
Expand Down Expand Up @@ -3594,8 +3595,10 @@ def device(self, device: DEVICE_TYPING):

def clear_device_(self):
"""Clears the device of the CompositeSpec."""
for spec in self._specs:
self._device = None
for spec in self._specs.values():
spec.clear_device_()
return self

def __getitem__(self, idx):
"""Indexes the current CompositeSpec based on the provided index."""
Expand Down
Loading

0 comments on commit 2c485dd

Please sign in to comment.