[Feature] Extend TensorDictPrimer default_value options (#2071)
Co-authored-by: Vincent Moens <vmoens@meta.com>
albertbou92 and vmoens authored Apr 18, 2024
1 parent 36c89dc commit 6b87184
Showing 3 changed files with 154 additions and 40 deletions.
89 changes: 72 additions & 17 deletions test/test_transforms.py
@@ -6423,17 +6423,11 @@ def test_trans_parallel_env_check(self):
finally:
env.close()

def test_trans_serial_env_check(self):
with pytest.raises(RuntimeError, match="The leading shape of the primer specs"):
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([4])),
)
_ = env.observation_spec

@pytest.mark.parametrize("spec_shape", [[4], [2, 4]])
def test_trans_serial_env_check(self, spec_shape):
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 4])),
TensorDictPrimer(mykey=UnboundedContinuousTensorSpec(spec_shape)),
)
check_env_specs(env)
assert "mykey" in env.reset().keys()
@@ -6533,6 +6527,72 @@ def test_tensordictprimer_batching(self, batched_class, break_when_any_done):
r1 = env.rollout(100, break_when_any_done=break_when_any_done)
tensordict.tensordict.assert_allclose_td(r0, r1)

def test_callable_default_value(self):
def create_tensor():
return torch.ones(3)

env = TransformedEnv(
ContinuousActionVecMockEnv(),
TensorDictPrimer(
mykey=UnboundedContinuousTensorSpec([3]), default_value=create_tensor
),
)
check_env_specs(env)
assert "mykey" in env.reset().keys()
assert ("next", "mykey") in env.rollout(3).keys(True)

def test_dict_default_value(self):

# Test with a dict of float default values
key1_spec = UnboundedContinuousTensorSpec([3])
key2_spec = UnboundedContinuousTensorSpec([3])
env = TransformedEnv(
ContinuousActionVecMockEnv(),
TensorDictPrimer(
mykey1=key1_spec,
mykey2=key2_spec,
default_value={
"mykey1": 1.0,
"mykey2": 2.0,
},
),
)
check_env_specs(env)
reset_td = env.reset()
assert "mykey1" in reset_td.keys()
assert "mykey2" in reset_td.keys()
rollout_td = env.rollout(3)
assert ("next", "mykey1") in rollout_td.keys(True)
assert ("next", "mykey2") in rollout_td.keys(True)
assert (rollout_td.get(("next", "mykey1")) == 1.0).all()
assert (rollout_td.get(("next", "mykey2")) == 2.0).all()

# Test with a dict of callable default values
key1_spec = UnboundedContinuousTensorSpec([3])
key2_spec = DiscreteTensorSpec(3, dtype=torch.int64)
env = TransformedEnv(
ContinuousActionVecMockEnv(),
TensorDictPrimer(
mykey1=key1_spec,
mykey2=key2_spec,
default_value={
"mykey1": lambda: torch.ones(3),
"mykey2": lambda: torch.tensor(1, dtype=torch.int64),
},
),
)
check_env_specs(env)
reset_td = env.reset()
assert "mykey1" in reset_td.keys()
assert "mykey2" in reset_td.keys()
rollout_td = env.rollout(3)
assert ("next", "mykey1") in rollout_td.keys(True)
assert ("next", "mykey2") in rollout_td.keys(True)
assert (rollout_td.get(("next", "mykey1")) == torch.ones(3)).all
assert (
rollout_td.get(("next", "mykey2")) == torch.tensor(1, dtype=torch.int64)
).all


class TestTimeMaxPool(TransformBase):
@pytest.mark.parametrize("T", [2, 4])
@@ -6813,18 +6873,13 @@ def make_env():
finally:
env.close()

def test_trans_serial_env_check(self):
@pytest.mark.parametrize("shape", [(), (2,)])
def test_trans_serial_env_check(self, shape):
state_dim = 7
action_dim = 7
with pytest.raises(RuntimeError, match="The leading shape of the primer"):
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=()),
)
check_env_specs(env)
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=(2,)),
gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=shape),
)
try:
check_env_specs(env)
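Note (not part of the commit): two validation paths added to `__init__` in transforms.py below fail fast and are worth a hedged sketch of their own; the error messages are taken from the diff:

    import pytest
    from torchrl.data import UnboundedContinuousTensorSpec
    from torchrl.envs.transforms import TensorDictPrimer

    # Random filling and an explicit default value are mutually exclusive.
    with pytest.raises(ValueError, match="incompatible"):
        TensorDictPrimer(
            mykey=UnboundedContinuousTensorSpec([3]),
            random=True,
            default_value=1.0,
        )

    # A default_value dict must cover exactly the primer keys.
    with pytest.raises(ValueError, match="must match the primers keys"):
        TensorDictPrimer(
            mykey=UnboundedContinuousTensorSpec([3]),
            default_value={"other_key": 1.0},
        )
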
104 changes: 82 additions & 22 deletions torchrl/envs/transforms/transforms.py
@@ -4435,8 +4435,12 @@ class TensorDictPrimer(Transform):
random (bool, optional): if ``True``, the values will be drawn randomly from
the TensorSpec domain (or a unit Gaussian if unbounded). Otherwise a fixed value will be assumed.
Defaults to `False`.
default_value (float, optional): if non-random filling is chosen, this
value will be used to populate the tensors. Defaults to `0.0`.
default_value (float, Callable, Dict[NestedKey, float], Dict[NestedKey, Callable], optional): If non-random
filling is chosen, `default_value` will be used to populate the tensors. If `default_value` is a float,
all elements of the tensors will be set to that value. If it is a callable, this callable is expected to
return a tensor fitting the specs, and it will be used to generate the tensors. Finally, if `default_value`
is a dictionary of tensors or a dictionary of callables with keys matching those of the specs, these will
be used to generate the corresponding tensors. Defaults to `0.0`.
reset_key (NestedKey, optional): the reset key to be used as partial
reset indicator. Must be unique. If not provided, defaults to the
only reset key of the parent environment (if it has only one)
@@ -4493,8 +4497,11 @@ class TensorDictPrimer(Transform):
def __init__(
self,
primers: dict | CompositeSpec = None,
random: bool = False,
default_value: float = 0.0,
random: bool | None = None,
default_value: float
| Callable
| Dict[NestedKey, float]
| Dict[NestedKey, Callable] = None,
reset_key: NestedKey | None = None,
**kwargs,
):
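
Note (not part of the commit): a minimal sketch of the call patterns this signature now accepts, composed from the tests above; the transforms are constructed but not attached to an env here:

    import torch
    from torchrl.data import UnboundedContinuousTensorSpec
    from torchrl.envs.transforms import TensorDictPrimer

    spec = UnboundedContinuousTensorSpec([3])

    # float: every element of the primed tensor is set to the value
    t_float = TensorDictPrimer(mykey=spec, default_value=1.0)
    # callable: invoked at prime time; must return a tensor fitting the spec
    t_callable = TensorDictPrimer(mykey=spec, default_value=lambda: torch.ones(3))
    # dict keyed like the specs, with a float or callable per entry
    t_dict = TensorDictPrimer(
        mykey=spec, default_value={"mykey": lambda: torch.ones(3)}
    )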
@@ -4509,8 +4516,31 @@ def __init__(
if not isinstance(kwargs, CompositeSpec):
kwargs = CompositeSpec(kwargs)
self.primers = kwargs
if random and default_value:
raise ValueError(
"Setting random to True and providing a default_value are incompatible."
)
default_value = (
default_value or 0.0
) # if not random and no default value, use 0.0
self.random = random
if isinstance(default_value, dict):
default_value = TensorDict(default_value, [])
default_value_keys = default_value.keys(
True,
True,
is_leaf=lambda x: issubclass(x, (NonTensorData, torch.Tensor)),
)
if set(default_value_keys) != set(self.primers.keys(True, True)):
raise ValueError(
"If a default_value dictionary is provided, it must match the primers keys."
)
else:
default_value = {
key: default_value for key in self.primers.keys(True, True)
}
self.default_value = default_value
self._validated = False
self.reset_key = reset_key

# sanity check
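
Note (not part of the commit): the block above normalizes every accepted `default_value` form to a per-key mapping. A standalone rendition of the scalar branch, with hypothetical keys "a" and "b":

    # A scalar default is broadcast to all primer keys; a dict input is instead
    # wrapped in a TensorDict so its keys can be validated against the specs.
    primer_keys = ["a", "b"]  # stands in for self.primers.keys(True, True)
    default_value = 2.0
    if not isinstance(default_value, dict):
        default_value = {key: default_value for key in primer_keys}
    assert default_value == {"a": 2.0, "b": 2.0}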
@@ -4563,6 +4593,9 @@ def to(self, *args, **kwargs):
self.primers = self.primers.to(device)
return super().to(*args, **kwargs)

def _expand_shape(self, spec):
return spec.expand((*self.parent.batch_size, *spec.shape))

def transform_observation_spec(
self, observation_spec: CompositeSpec
) -> CompositeSpec:
@@ -4572,15 +4605,13 @@ def transform_observation_spec(
)
for key, spec in self.primers.items():
if spec.shape[: len(observation_spec.shape)] != observation_spec.shape:
raise RuntimeError(
f"The leading shape of the primer specs ({self.__class__}) should match the one of the parent env. "
f"Got observation_spec.shape={observation_spec.shape} but the '{key}' entry's shape is {spec.shape}."
)
expanded_spec = self._expand_shape(spec)
spec = expanded_spec
try:
device = observation_spec.device
except RuntimeError:
device = self.device
observation_spec[key] = spec.to(device)
observation_spec[key] = self.primers[key] = spec.to(device)
return observation_spec

def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
@@ -4593,8 +4624,13 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
def _batch_size(self):
return self.parent.batch_size

def _validate_value_tensor(self, value, spec):
if not spec.is_in(value):
raise RuntimeError(f"Value ({value}) is not in the spec domain ({spec}).")
return True

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
for key, spec in self.primers.items():
for key, spec in self.primers.items(True, True):
if spec.shape[: len(tensordict.shape)] != tensordict.shape:
raise RuntimeError(
"The leading shape of the spec must match the tensordict's, "
@@ -4605,11 +4641,21 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.random:
value = spec.rand()
else:
value = torch.full_like(
spec.zero(),
self.default_value,
)
value = self.default_value[key]
if callable(value):
value = value()
if not self._validated:
self._validate_value_tensor(value, spec)
else:
value = torch.full(
spec.shape,
value,
device=spec.device,
)

tensordict.set(key, value)
if not self._validated:
self._validated = True
return tensordict

def _step(
@@ -4638,22 +4684,36 @@ def _reset(
)
_reset = _get_reset(self.reset_key, tensordict)
if _reset.any():
for key, spec in self.primers.items():
for key, spec in self.primers.items(True, True):
if self.random:
value = spec.rand(shape)
else:
value = torch.full_like(
spec.zero(shape),
self.default_value,
)
prev_val = tensordict.get(key, 0.0)
value = torch.where(expand_as_right(_reset, value), value, prev_val)
value = self.default_value[key]
if callable(value):
value = value()
if not self._validated:
self._validate_value_tensor(value, spec)
else:
value = torch.full(
spec.shape,
value,
device=spec.device,
)
prev_val = tensordict.get(key, 0.0)
value = torch.where(
expand_as_right(_reset, value), value, prev_val
)
tensordict_reset.set(key, value)
self._validated = True
return tensordict_reset

def __repr__(self) -> str:
class_name = self.__class__.__name__
return f"{class_name}(primers={self.primers}, default_value={self.default_value}, random={self.random})"
default_value = {
key: value if isinstance(value, float) else "Callable"
for key, value in self.default_value.items()
}
return f"{class_name}(primers={self.primers}, default_value={default_value}, random={self.random})"


class PinMemoryTransform(Transform):
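Note (not part of the commit): the subtlest piece of `_reset` above is the partial-reset merge. A self-contained sketch with hypothetical values:

    import torch
    from tensordict.utils import expand_as_right

    # Two sub-envs; only the first one is resetting.
    _reset = torch.tensor([True, False])
    primed = torch.ones(2, 3)             # freshly primed default values
    prev_val = torch.full((2, 3), 5.0)    # values carried over from the last step

    # The reset mask is right-expanded to the value shape: reset sub-envs take
    # the primed value, the others keep their previous entries.
    merged = torch.where(expand_as_right(_reset, primed), primed, prev_val)
    assert (merged[0] == 1.0).all() and (merged[1] == 5.0).all()
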
1 change: 0 additions & 1 deletion torchrl/objectives/value/functional.py
@@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import functools
import math

import warnings
