Skip to content

Commit

Permalink
[BugFix,Feature] Allow non-tensor data in envs (#1944)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 19, 2024
1 parent 47a2627 commit 038a615
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 37 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,7 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
UnboundedDiscreteTensorSpec
LazyStackedTensorSpec
LazyStackedCompositeSpec
NonTensorSpec

Reinforcement Learning From Human Feedback (RLHF)
-------------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ Each env will have the following attributes:
all the output keys (:obj:`"full_observation_spec"`, :obj:`"full_reward_spec"` and :obj:`"full_done_spec"`).
It is locked and should not be modified directly.

If the environment carries non-tensor data, a :class:`~torchrl.data.NonTensorSpec`
instance can be used.

Importantly, the environment spec shapes should contain the batch size, e.g.
an environment with :obj:`env.batch_size == torch.Size([4])` should have
an :obj:`env.action_spec` with shape :obj:`torch.Size([4, action_size])`.
Expand Down
34 changes: 34 additions & 0 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
CompositeSpec,
DiscreteTensorSpec,
MultiOneHotDiscreteTensorSpec,
NonTensorSpec,
OneHotDiscreteTensorSpec,
TensorSpec,
UnboundedContinuousTensorSpec,
Expand Down Expand Up @@ -1825,6 +1826,39 @@ def _set_seed(self, seed: Optional[int]):
torch.manual_seed(seed)


class EnvWithMetadata(EnvBase):
    """Mock environment whose output carries a non-tensor counter.

    The ``"non_tensor"`` entry is declared through a :class:`NonTensorSpec`
    in both the observation and state specs; it starts at ``0`` on reset and
    is incremented by one on every step.
    """

    def __init__(self):
        super().__init__()
        # Observations mix a regular tensor entry with a non-tensor one.
        self.observation_spec = CompositeSpec(
            tensor=UnboundedContinuousTensorSpec(3),
            non_tensor=NonTensorSpec(shape=()),
        )
        # The non-tensor counter is also part of the input state.
        self.state_spec = CompositeSpec(non_tensor=NonTensorSpec(shape=()))
        self.reward_spec = UnboundedContinuousTensorSpec(1)
        self.action_spec = UnboundedContinuousTensorSpec(1)

    def _reset(self, tensordict):
        # Fresh zeroed observation with the counter reset to 0.
        out = self.observation_spec.zero()
        out.set_non_tensor("non_tensor", 0)
        out.update(self.full_done_spec.zero())
        return out

    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Carry the counter forward, incremented by one.
        out = self.observation_spec.zero()
        out.set_non_tensor("non_tensor", tensordict["non_tensor"] + 1)
        out.update(self.full_done_spec.zero())
        out.update(self.full_reward_spec.zero())
        return out

    def _set_seed(self, seed: Optional[int]):
        # Stateless env: nothing to seed, echo the value back.
        return seed


class AutoResettingCountingEnv(CountingEnv):
def _step(self, tensordict):
tensordict = super()._step(tensordict)
Expand Down
27 changes: 27 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
DiscreteActionVecMockEnv,
DummyModelBasedEnvBase,
EnvWithDynamicSpec,
EnvWithMetadata,
HeterogeneousCountingEnv,
HeterogeneousCountingEnvPolicy,
MockBatchedLockedEnv,
Expand Down Expand Up @@ -2395,6 +2396,7 @@ def test_parallel(
@pytest.mark.parametrize(
"envclass",
[
EnvWithMetadata,
ContinuousActionConvMockEnv,
ContinuousActionConvMockEnvNumpy,
ContinuousActionVecMockEnv,
Expand All @@ -2419,6 +2421,7 @@ def test_mocking_envs(envclass):
env.set_seed(100)
reset = env.reset()
_ = env.rand_step(reset)
r = env.rollout(3)
check_env_specs(env, seed=100, return_contiguous=False)


Expand Down Expand Up @@ -3162,6 +3165,30 @@ def test_batched_dynamic(self, break_when_any_done):
assert_allclose_td(rollout_no_buffers_serial, rollout_no_buffers_parallel)


class TestNonTensorEnv:
    """Non-tensor data must survive single, serial and parallel rollouts."""

    @pytest.mark.parametrize("bwad", [True, False])
    def test_single(self, bwad):
        rollout = EnvWithMetadata().rollout(10, break_when_any_done=bwad)
        # The counter starts at 0 and increments once per step.
        assert rollout.get("non_tensor").tolist() == list(range(10))

    @pytest.mark.parametrize("bwad", [True, False])
    @pytest.mark.parametrize("use_buffers", [False, True])
    def test_serial(self, bwad, use_buffers):
        n_steps = 50
        env = SerialEnv(2, EnvWithMetadata, use_buffers=use_buffers)
        rollout = env.rollout(n_steps, break_when_any_done=bwad)
        # Both sub-envs produce the same 0..N-1 sequence.
        assert rollout.get("non_tensor").tolist() == [list(range(n_steps))] * 2

    @pytest.mark.parametrize("bwad", [True, False])
    @pytest.mark.parametrize("use_buffers", [False, True])
    def test_parallel(self, bwad, use_buffers):
        n_steps = 50
        env = ParallelEnv(2, EnvWithMetadata, use_buffers=use_buffers)
        rollout = env.rollout(n_steps, break_when_any_done=bwad)
        assert rollout.get("non_tensor").tolist() == [list(range(n_steps))] * 2


if __name__ == "__main__":
    # Allow running this test module directly; any unrecognized CLI flags
    # are forwarded verbatim to pytest.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest_args = [__file__, "--capture", "no", "--exitfirst"]
    pytest.main(pytest_args + unknown)
10 changes: 6 additions & 4 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten
indexed tensor
"""
raise NotImplementedError
...

@abc.abstractmethod
def expand(self, *shape):
Expand All @@ -656,7 +656,7 @@ def expand(self, *shape):
from it if the current dimension is a singleton.
"""
raise NotImplementedError
...

def squeeze(self, dim: int | None = None):
"""Returns a new Spec with all the dimensions of size ``1`` removed.
Expand Down Expand Up @@ -740,7 +740,7 @@ def is_in(self, val: torch.Tensor) -> bool:
boolean indicating if values belongs to the TensorSpec box
"""
raise NotImplementedError
...

def contains(self, item):
"""Returns whether a sample is contained within the space defined by the TensorSpec.
Expand Down Expand Up @@ -2120,7 +2120,9 @@ def is_in(self, val: torch.Tensor) -> bool:
return (
isinstance(val, NonTensorData)
and val.shape == shape
and val.device == self.device
        # We relax constraints on device as they're hard to enforce for non-tensor
        # tensordicts and pointless
# and val.device == self.device
and val.dtype == self.dtype
)

Expand Down
Loading

0 comments on commit 038a615

Please sign in to comment.