Skip to content

Commit

Permalink
[Feature] Fix DType casting lazy init (pytorch#1589)
Browse files: browse the repository at this point in the history
  • Loading branch information
vmoens authored Oct 2, 2023
1 parent: db1a7d4; commit: 59d29b8
Show file tree
Hide file tree
Showing 5 changed files with 393 additions and 176 deletions.
11 changes: 4 additions & 7 deletions test/test_transforms.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2259,12 +2259,6 @@ def test_double2float(self, keys, keys_inv, device):
)
action_spec = double2float.transform_input_spec(input_spec)
assert action_spec.dtype == torch.float

elif len(keys) == 1:
observation_spec = BoundedTensorSpec(0, 1, (1, 3, 3), dtype=torch.double)
observation_spec = double2float.transform_observation_spec(observation_spec)
assert observation_spec.dtype == torch.float

else:
observation_spec = CompositeSpec(
{
Expand All @@ -2274,7 +2268,7 @@ def test_double2float(self, keys, keys_inv, device):
)
observation_spec = double2float.transform_observation_spec(observation_spec)
for key in keys:
assert observation_spec[key].dtype == torch.float
assert observation_spec[key].dtype == torch.float, key

@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize(
Expand Down Expand Up @@ -2326,6 +2320,7 @@ def test_single_env_no_inkeys(self):
base_env.state_spec[key] = spec.to(torch.float64)
if base_env.action_spec.dtype == torch.float32:
base_env.action_spec = base_env.action_spec.to(torch.float64)
check_env_specs(base_env)
env = TransformedEnv(
base_env,
DoubleToFloat(),
Expand All @@ -2335,6 +2330,8 @@ def test_single_env_no_inkeys(self):
for spec in env.state_spec.values(True, True):
assert spec.dtype == torch.float32
assert env.action_spec.dtype != torch.float64
assert env.transform.in_keys == env.transform.out_keys
assert env.transform.in_keys_inv == env.transform.out_keys_inv
check_env_specs(env)

def test_single_trans_env_check(self, dtype_fixture): # noqa: F811
Expand Down
24 changes: 11 additions & 13 deletions torchrl/envs/transforms/rlhf.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from copy import deepcopy
from copy import copy, deepcopy

import torch
from tensordict import TensorDictBase, unravel_key
Expand Down Expand Up @@ -93,24 +93,22 @@ def __init__(
if in_keys is None:
in_keys = self.DEFAULT_IN_KEYS
if out_keys is None:
out_keys = in_keys
if not isinstance(in_keys, list):
in_keys = [in_keys]
if not isinstance(out_keys, list):
out_keys = [out_keys]
if not is_seq_of_nested_key(in_keys) or not is_seq_of_nested_key(out_keys):
out_keys = copy(in_keys)
super().__init__(in_keys=in_keys, out_keys=out_keys)
if not is_seq_of_nested_key(self.in_keys) or not is_seq_of_nested_key(
self.out_keys
):
raise ValueError(
f"invalid in_keys / out_keys:\nin_keys={in_keys} \nout_keys={out_keys}"
f"invalid in_keys / out_keys:\nin_keys={self.in_keys} \nout_keys={self.out_keys}"
)
if len(in_keys) != 1 or len(out_keys) != 1:
if len(self.in_keys) != 1 or len(self.out_keys) != 1:
raise ValueError(
f"Only one in_key/out_key is allowed, got in_keys={in_keys}, out_keys={out_keys}."
f"Only one in_key/out_key is allowed, got in_keys={self.in_keys}, out_keys={self.out_keys}."
)
super().__init__(in_keys=in_keys, out_keys=out_keys)
# for convenience, convert out_keys to tuples
self.out_keys = [
self._out_keys = [
out_key if isinstance(out_key, tuple) else (out_key,)
for out_key in self.out_keys
for out_key in self._out_keys
]

# update the in_keys for dispatch etc
Expand Down
Loading

0 comments on commit 59d29b8

Please sign in to comment.