[BugFix,Feature] Allow non-tensor data in envs #1944

Merged (11 commits) on Jun 19, 2024
Changes from 1 commit
amend
vmoens committed Jun 11, 2024
commit 0529917c80c427810a6e52cd425c0963930f10e1
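
For context: the mechanism this commit builds on is tensordict's NonTensorData, which lets arbitrary Python objects ride alongside tensors in a TensorDict. A minimal sketch of the idea (illustrative only; it assumes the tensordict API as exercised by the tests below):

import torch
from tensordict import NonTensorData, TensorDict

# A tensordict can hold a non-tensor payload next to regular tensors.
td = TensorDict(
    {"obs": torch.zeros(3), "non_tensor": NonTensorData("metadata")},
    batch_size=[],
)
print(td.get("non_tensor").data)  # -> "metadata"

# Stacking tensordicts also stacks the non-tensor entries; batched envs
# rely on this to gather per-worker metadata into one stacked result.
stacked = torch.stack([td, td.clone()])
print(stacked.get("non_tensor").tolist())  # -> ["metadata", "metadata"]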
3 changes: 3 additions & 0 deletions test/mocking_classes.py
@@ -1833,6 +1833,9 @@ def __init__(self):
tensor=UnboundedContinuousTensorSpec(3),
non_tensor=NonTensorSpec(shape=()),
)
self.state_spec = CompositeSpec(
non_tensor=NonTensorSpec(shape=()),
)
self.reward_spec = UnboundedContinuousTensorSpec(1)
self.action_spec = UnboundedContinuousTensorSpec(1)

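The hunk above only adds the state_spec entry; for orientation, here is a condensed, hypothetical sketch of what an EnvWithMetadata-style mock looks like end to end (method bodies abbreviated, names mirror the diff):

import torch
from tensordict import NonTensorData, TensorDict
from torchrl.envs import EnvBase

class EnvWithMetadataSketch(EnvBase):
    # Hypothetical condensation of the EnvWithMetadata mock: a counter is
    # exposed both as a tensor observation and as a raw Python object.
    def _reset(self, tensordict):
        self.count = 0
        return TensorDict(
            {"tensor": torch.zeros(3), "non_tensor": NonTensorData(self.count)},
            batch_size=[],
        )

    def _step(self, tensordict):
        self.count += 1
        return TensorDict(
            {
                "tensor": torch.zeros(3),
                "non_tensor": NonTensorData(self.count),
                "reward": torch.zeros(1),
                "done": torch.zeros(1, dtype=torch.bool),
            },
            batch_size=[],
        )

    def _set_seed(self, seed):
        return seed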
22 changes: 13 additions & 9 deletions test/test_env.py
@@ -42,8 +42,8 @@
DiscreteActionConvMockEnvNumpy,
DiscreteActionVecMockEnv,
DummyModelBasedEnvBase,
EnvWithMetadata,
EnvWithDynamicSpec,
EnvWithMetadata,
HeterogeneousCountingEnv,
HeterogeneousCountingEnvPolicy,
MockBatchedLockedEnv,
@@ -3173,16 +3173,20 @@ def test_single(self, bwad):
assert r.get("non_tensor").tolist() == list(range(10))

@pytest.mark.parametrize("bwad", [True, False])
def test_serial(self, bwad):
env = SerialEnv(2, EnvWithMetadata)
r = env.rollout(10, break_when_any_done=bwad)
assert r.get("non_tensor").tolist() == [list(range(10))] * 2
@pytest.mark.parametrize("use_buffers", [False, True])
def test_serial(self, bwad, use_buffers):
N = 50
env = SerialEnv(2, EnvWithMetadata, use_buffers=use_buffers)
r = env.rollout(N, break_when_any_done=bwad)
assert r.get("non_tensor").tolist() == [list(range(N))] * 2

@pytest.mark.parametrize("bwad", [True, False])
def test_parallel(self, bwad):
env = ParallelEnv(2, EnvWithMetadata)
r = env.rollout(10, break_when_any_done=bwad)
assert r.get("non_tensor").tolist() == [list(range(10))] * 2
@pytest.mark.parametrize("use_buffers", [False, True])
def test_parallel(self, bwad, use_buffers):
N = 50
env = ParallelEnv(2, EnvWithMetadata, use_buffers=use_buffers)
r = env.rollout(N, break_when_any_done=bwad)
assert r.get("non_tensor").tolist() == [list(range(N))] * 2


if __name__ == "__main__":
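The new use_buffers parametrization checks that the same trajectories come back whether the batched env communicates through shared-memory buffers or not. Schematic usage (assuming the EnvWithMetadata mock from test/mocking_classes.py; not a public API):

from torchrl.envs import SerialEnv
from mocking_classes import EnvWithMetadata  # test helper, not a public API

env = SerialEnv(2, EnvWithMetadata, use_buffers=False)
r = env.rollout(5)
# Both workers return their python-object trajectory, stacked along dim 0.
assert r.get("non_tensor").tolist() == [list(range(5))] * 2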
7 changes: 0 additions & 7 deletions torchrl/data/tensor_specs.py
@@ -30,13 +30,6 @@

import tensordict
import torch
from tensordict import (
LazyStackedTensorDict,
NonTensorData,
TensorDict,
TensorDictBase,
unravel_key,
)
from tensordict import (
is_tensor_collection,
LazyStackedTensorDict,
127 changes: 84 additions & 43 deletions torchrl/envs/batched_envs.py
@@ -22,7 +22,10 @@
import torch

from tensordict import (
is_tensor_collection,LazyStackedTensorDict, NonTensorData, TensorDict,
is_tensor_collection,
LazyStackedTensorDict,
NonTensorData,
TensorDict,
TensorDictBase,
)
from tensordict._tensordict import unravel_key
@@ -659,6 +662,20 @@ def _create_td(self) -> None:
"batched environment base tensordict has the wrong shape"
)

# Non-tensor keys
non_tensor_keys = []
for spec in (
self.full_action_spec,
self.full_state_spec,
self.full_observation_spec,
self.full_reward_spec,
self.full_done_spec,
):
for key, spec in spec.items(True, True):
if isinstance(spec, NonTensorSpec):
non_tensor_keys.append(key)
self._non_tensor_keys = non_tensor_keys

if self._single_task:
self._env_input_keys = sorted(
list(self.input_spec["full_action_spec"].keys(True, True))
@@ -699,6 +716,15 @@ def _create_td(self) -> None:
)
)
env_output_keys = env_output_keys.union(self.reward_keys + self.done_keys)
env_obs_keys = [
key for key in env_obs_keys if key not in self._non_tensor_keys
]
env_input_keys = [
key for key in env_input_keys if key not in self._non_tensor_keys
]
env_output_keys = [
key for key in env_output_keys if key not in self._non_tensor_keys
]
self._env_obs_keys = sorted(env_obs_keys, key=_sort_keys)
self._env_input_keys = sorted(env_input_keys, key=_sort_keys)
self._env_output_keys = sorted(env_output_keys, key=_sort_keys)
@@ -725,20 +751,8 @@ def _create_td(self) -> None:
# output keys after step
self._selected_step_keys = {unravel_key(key) for key in self._env_output_keys}

# Non-tensor keys
non_tensor_keys = []
for key, spec in self.observation_spec.items(True, True):
if isinstance(spec, NonTensorSpec):
non_tensor_keys.append(key)
for key, spec in self.full_reward_spec.items(True, True):
if isinstance(spec, NonTensorSpec):
non_tensor_keys.append(key)
for key, spec in self.full_done_spec.items(True, True):
if isinstance(spec, NonTensorSpec):
non_tensor_keys.append(key)
self._non_tensor_keys = non_tensor_keys

if not self.share_individual_td:
shared_tensordict_parent = shared_tensordict_parent.filter_non_tensor_data()
shared_tensordict_parent = shared_tensordict_parent.select(
*self._selected_keys,
*(unravel_key(("next", key)) for key in self._env_output_keys),
@@ -757,7 +771,7 @@ def _create_td(self) -> None:
*self._selected_keys,
*(unravel_key(("next", key)) for key in self._env_output_keys),
strict=False,
)
).filter_non_tensor_data()
for tensordict in shared_tensordict_parent
]
shared_tensordict_parent = LazyStackedTensorDict.lazy_stack(
@@ -971,7 +985,8 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
(self.num_workers,), device=self.device, dtype=torch.bool
)

if not self._use_buffers:
out_tds = None
if not self._use_buffers or self._non_tensor_keys:
out_tds = [None] * self.num_workers

tds = []
@@ -1015,8 +1030,9 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
"share_individual_td argument to True."
)
raise
else:
if out_tds is not None:
out_tds[i] = _td

if not self._use_buffers:
result = LazyStackedTensorDict.maybe_dense_stack(out_tds)
return result
@@ -1034,6 +1050,11 @@ def select_and_clone(name, tensor):
nested_keys=True,
filter_empty=True,
)
if out_tds is not None:
out.update(
LazyStackedTensorDict(*out_tds), keys_to_update=self._non_tensor_keys
)

if out.device != device:
if device is None:
out = out.clear_device_()
@@ -1061,6 +1082,9 @@ def _step(
data_in.append(tensordict_in[i])

self._sync_m2w()
out_tds = None
if not self._use_buffers or self._non_tensor_keys:
out_tds = []

if self._use_buffers:
next_td = self.shared_tensordict_parent.get("next")
@@ -1071,12 +1095,13 @@
keys_to_update=list(self._env_output_keys),
non_blocking=self.non_blocking,
)
if out_tds is not None:
out_tds.append(out_td)
else:
tds = []
for i, _data_in in enumerate(data_in):
out_td = self._envs[i]._step(_data_in)
tds.append(out_td)
return LazyStackedTensorDict.maybe_dense_stack(tds)
out_tds.append(out_td)
return LazyStackedTensorDict.maybe_dense_stack(out_tds)

# We must pass a clone of the tensordict, as the values of this tensordict
# will be modified in-place at further steps
@@ -1087,6 +1112,10 @@ def select_and_clone(name, tensor):
return tensor.clone()

out = next_td.named_apply(select_and_clone, nested_keys=True, filter_empty=True)
if out_tds is not None:
out.update(
LazyStackedTensorDict(*out_tds), keys_to_update=self._non_tensor_keys
)

if out.device != device:
if device is None:
@@ -1352,7 +1381,6 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
"has_lazy_inputs": self.has_lazy_inputs,
"num_threads": num_sub_threads,
"non_blocking": self.non_blocking,
"_non_tensor_keys": self._non_tensor_keys,
}
)
if self._use_buffers:
@@ -1362,6 +1390,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
"_selected_input_keys": self._selected_input_keys,
"_selected_reset_keys": self._selected_reset_keys,
"_selected_step_keys": self._selected_step_keys,
"_non_tensor_keys": self._non_tensor_keys,
}
)
process = proc_fun(target=func, kwargs=kwargs[idx])
@@ -1467,6 +1496,7 @@ def step_and_maybe_reset(
next_td_passthrough, non_blocking=self.non_blocking
)
else:
# next_td_keys = None
data = [{} for _ in range(self.num_workers)]

if self._non_tensor_keys:
@@ -1485,15 +1515,10 @@ def step_and_maybe_reset(
event.clear()

if self._non_tensor_keys:
non_tensor_tds = []
for i in range(self.num_workers):
msg, non_tensor_td = self.parent_channels[i].recv()
self.shared_tensordicts[i].update_(
non_tensor_td,
keys_to_update=[
*self._non_tensor_keys,
*[("next", key) for key in self._non_tensor_keys],
],
)
non_tensor_tds.append(non_tensor_td)

# We must pass a clone of the tensordict, as the values of this tensordict
# will be modified in-place at further steps
@@ -1523,6 +1548,13 @@ def step_and_maybe_reset(
next_td = next_td.clone().clear_device_()
tensordict_ = tensordict_.clone().clear_device_()
tensordict.set("next", next_td)
if self._non_tensor_keys:
non_tensor_tds = LazyStackedTensorDict(*non_tensor_tds)
tensordict.update(
non_tensor_tds,
keys_to_update=[("next", key) for key in self._non_tensor_keys],
)
tensordict_.update(non_tensor_tds, keys_to_update=self._non_tensor_keys)
return tensordict, tensordict_

def _step_no_buffers(
@@ -1593,11 +1625,10 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
event.clear()

if self._non_tensor_keys:
non_tensor_tds = []
for i in range(self.num_workers):
msg, non_tensor_td = self.parent_channels[i].recv()
self.shared_tensordicts[i].get("next").update_(
non_tensor_td, keys_to_update=self._non_tensor_keys
)
non_tensor_tds.append(non_tensor_td)

# We must pass a clone of the tensordict, as the values of this tensordict
# will be modified in-place at further steps
@@ -1622,6 +1653,11 @@ def select_and_clone(name, tensor):
filter_empty=True,
device=device,
)
if self._non_tensor_keys:
out.update(
LazyStackedTensorDict(*non_tensor_tds),
keys_to_update=self._non_tensor_keys,
)
self._sync_w2m()
return out

@@ -1739,12 +1775,11 @@ def tentative_update(val, other):
event.wait(self._timeout)
event.clear()

workers_nontensor = []
if self._non_tensor_keys:
for i in workers:
for i, _ in outs:
msg, non_tensor_td = self.parent_channels[i].recv()
self.shared_tensordicts[i].update_(
non_tensor_td, keys_to_update=self._non_tensor_keys
)
workers_nontensor.append((i, non_tensor_td))

selected_output_keys = self._selected_reset_keys_filt
device = self.device
@@ -1767,6 +1802,11 @@ def select_and_clone(name, tensor):
filter_empty=True,
device=device,
)
if self._non_tensor_keys:
workers, nontensor = zip(*workers_nontensor)
out[torch.tensor(workers)] = LazyStackedTensorDict(*nontensor).select(
*self._non_tensor_keys
)
self._sync_w2m()
return out

@@ -2007,14 +2047,14 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
# No need to copy here since we don't write in-place
input = root_shared_tensordict
if data:
next_td_passthrough_keys = data.get("next_td_passthrough_keys", None)
next_td_passthrough_keys = data.get("next_td_passthrough_keys")
if next_td_passthrough_keys is not None:
input = input.set(
"next", next_shared_tensordict.select(*next_td_passthrough_keys)
)
non_tensor_data = data.get("non_tensor_data", None)
non_tensor_data = data.get("non_tensor_data")
if non_tensor_data is not None:
input.update_(non_tensor_data)
input.update(non_tensor_data)

next_td = env._step(input)
next_shared_tensordict.update_(next_td, non_blocking=non_blocking)
@@ -2051,9 +2091,10 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
)
non_tensor_data = data.get("non_tensor_data", None)
if non_tensor_data is not None:
input.update_(non_tensor_data)
input.update(non_tensor_data)
td, root_next_td = env.step_and_maybe_reset(input)
next_shared_tensordict.update_(td.pop("next"), non_blocking=non_blocking)
td_next = td.pop("next")
next_shared_tensordict.update_(td_next, non_blocking=non_blocking)
root_shared_tensordict.update_(root_next_td, non_blocking=non_blocking)

if event is not None:
@@ -2063,7 +2104,7 @@

if _non_tensor_keys:
ntd = root_next_td.select(*_non_tensor_keys)
ntd.set("next", next_shared_tensordict.select(*_non_tensor_keys))
ntd.set("next", td_next.select(*_non_tensor_keys))
child_pipe.send(("non_tensor", ntd))

del td, root_next_td
@@ -2148,12 +2189,12 @@ def _run_worker_pipe_direct(
env = env_fun
del env_fun
for spec in env.output_spec.values(True, True):
if spec.device.type == "cuda":
if spec.device is not None and spec.device.type == "cuda":
has_cuda = True
break
else:
for spec in env.input_spec.values(True, True):
if spec.device.type == "cuda":
if spec.device is not None and spec.device.type == "cuda":
has_cuda = True
break
else:
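
Taken together, the pattern for non-tensor data in the parallel case is a small side-channel protocol: tensor data travels through the shared-memory buffers, while each worker sends a ("non_tensor", payload) message over its pipe and the parent stacks the payloads into the output. A self-contained toy version of that exchange (a hypothetical simplification; the real logic lives in the worker loops above):

import multiprocessing as mp

def worker(child_pipe, rank):
    # In torchrl, tensor data is written into shared-memory buffers; the
    # non-tensor payload is sent explicitly over the pipe instead.
    payload = {"non_tensor": f"metadata from worker {rank}"}
    child_pipe.send(("non_tensor", payload))
    child_pipe.close()

if __name__ == "__main__":
    pipes, procs = [], []
    for rank in range(2):
        parent_end, child_end = mp.Pipe()
        proc = mp.Process(target=worker, args=(child_end, rank))
        proc.start()
        pipes.append(parent_end)
        procs.append(proc)
    # The parent gathers one message per worker and stacks the payloads,
    # mirroring LazyStackedTensorDict(*non_tensor_tds) in _step above.
    gathered = []
    for pipe in pipes:
        msg, payload = pipe.recv()
        assert msg == "non_tensor"
        gathered.append(payload)
    for proc in procs:
        proc.join()
    print(gathered)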