[Feature] Dynamic specs #2143

Merged
27 commits merged on May 31, 2024
Changes from 1 commit
amend
vmoens committed May 26, 2024
commit 7b94c03f885255eb07eb8f78700a18c181b8f965
56 changes: 40 additions & 16 deletions torchrl/envs/batched_envs.py
@@ -10,6 +10,7 @@
import gc

import os
import tempfile
import weakref
from collections import OrderedDict
from copy import copy, deepcopy
@@ -184,6 +185,9 @@ class BatchedEnvBase(EnvBase):
Uses the default start method if not indicated ('spawn' by default in
TorchRL if not set differently before the first import).
To be used only with :class:`~torchrl.envs.ParallelEnv` subclasses.
use_buffers (bool, optional): whether communication between workers should
occur via circular preallocated memory buffers. Defaults to ``True`` unless
one of the environments has dynamic specs.

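For readers of the diff: a minimal usage sketch of the new flag (hypothetical toy setup; `Pendulum-v1` and the Gym backend are stand-ins, while `ParallelEnv` and `use_buffers` come from the diff above):

```python
from torchrl.envs import GymEnv, ParallelEnv

def make_env():
    # Any EnvBase factory works; a Gym env is used here for illustration.
    return GymEnv("Pendulum-v1")

# use_buffers=None (the default) infers buffering from the sub-envs' specs;
# use_buffers=False forces plain pipe communication, as required for
# dynamic specs.
env = ParallelEnv(2, make_env, use_buffers=False)
td = env.reset()
td = env.step(env.rand_action(td))
env.close()
```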
.. note::
One can pass keyword arguments to each sub-environment using the following
@@ -297,13 +301,15 @@ def __init__(
serial_for_single: bool = False,
non_blocking: bool = False,
mp_start_method: str = None,
use_buffers: bool = None,
):
super().__init__(device=device)
self.serial_for_single = serial_for_single
self.is_closed = True
self.num_sub_threads = num_sub_threads
self.num_threads = num_threads
self._cache_in_keys = None
self._use_buffers = use_buffers

self._single_task = callable(create_env_fn) or (len(set(create_env_fn)) == 1)
if callable(create_env_fn):
@@ -353,7 +359,6 @@ def __init__(
f"Cannot use mp_start_method={mp_start_method} with envs of type {type(self)}."
)
self._mp_start_method = mp_start_method
self._use_buffers = None

@property
def non_blocking(self):
@@ -469,7 +474,12 @@ def _get_metadata(
self.meta_data = meta_data.expand(
*(self.num_workers, *meta_data.batch_size)
)
self._use_buffers = not self.meta_data.has_dynamic_specs
if self._use_buffers is not False:
_use_buffers = not self.meta_data.has_dynamic_specs
if self._use_buffers and not _use_buffers:
warn("A value of use_buffers=True was passed but this is incompatible "
"with the list of environments provided. Turning use_buffers to False.")
self._use_buffers = _use_buffers
if self.share_individual_td is None:
self.share_individual_td = False
else:
@@ -489,9 +499,14 @@ def _get_metadata(
"be True to accomodate non-stackable tensors."
)
self.share_individual_td = share_individual_td
self._use_buffers = all(
_use_buffers = all(
not metadata.has_dynamic_specs for metadata in self.meta_data
)
if self._use_buffers and not _use_buffers:
warn("A value of use_buffers=True was passed but this is incompatible "
"with the list of environments provided. Turning use_buffers to False.")
self._use_buffers = _use_buffers

self._set_properties()

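Condensed, the resolution rule implemented in both branches above is: an explicit `False` wins; an explicit `True` is downgraded with a warning when any sub-env has dynamic specs; `None` is inferred. A standalone paraphrase (illustrative names, not torchrl API):

```python
import warnings

def resolve_use_buffers(requested, specs_are_static):
    """Paraphrase of the diff's logic. `requested` is the user flag
    (True / False / None); `specs_are_static` is True when no sub-env
    has dynamic specs."""
    if requested is False:
        return False  # explicit opt-out always wins
    inferred = specs_are_static
    if requested and not inferred:
        warnings.warn(
            "use_buffers=True is incompatible with dynamic specs; "
            "setting use_buffers to False."
        )
    return inferred  # None -> inferred; True -> only when possible
```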
def update_kwargs(self, kwargs: Union[dict, List[dict]]) -> None:
@@ -1368,17 +1383,26 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:
def _step_and_maybe_reset_no_buffers(
self, tensordict: TensorDictBase
) -> Tuple[TensorDictBase, TensorDictBase]:
for i, data in enumerate(tensordict.unbind(0)):
self.parent_channels[i].send(("step_and_maybe_reset", data))
out_next, out_root = [], []
for i, channel in enumerate(self.parent_channels):
self._events[i].wait()
next_td, root_td = channel.recv()
out_root.append(root_td)
out_next.append(next_td)
return LazyStackedTensorDict.maybe_dense_stack(

for i, _data in enumerate(tensordict.unbind(0)):
self.parent_channels[i].send(("step_and_maybe_reset", _data))

results = [None] * self.num_workers

consumed_indices = []
events = set(range(self.num_workers))
while len(consumed_indices) < self.num_workers:
for i in list(events):
if self._events[i].is_set():
results[i] = self.parent_channels[i].recv()
self._events[i].clear()
consumed_indices.append(i)
events.discard(i)

out_next, out_root = zip(*results)
return TensorDict.maybe_dense_stack(
out_next
), LazyStackedTensorDict.maybe_dense_stack(out_root)
), TensorDict.maybe_dense_stack(out_root)

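The rewritten receive loop polls the workers' events and consumes results in completion order rather than submission order, then stacks with `TensorDict.maybe_dense_stack` (dense when shapes allow, lazy otherwise). A self-contained sketch of the same polling pattern with plain `multiprocessing` (hypothetical worker, not torchrl code):

```python
import multiprocessing as mp

def worker(pipe, event, x):
    pipe.send(x * x)  # publish the result first ...
    event.set()       # ... then signal readiness (the commit's ordering)

if __name__ == "__main__":
    pipes, events, procs = [], [], []
    for x in range(3):
        parent_end, child_end = mp.Pipe()
        ev = mp.Event()
        proc = mp.Process(target=worker, args=(child_end, ev, x))
        proc.start()
        pipes.append(parent_end)
        events.append(ev)
        procs.append(proc)

    results = [None] * 3
    pending = set(range(3))
    while pending:  # consume in completion order, like the loop above
        for i in list(pending):
            if events[i].is_set():
                results[i] = pipes[i].recv()
                events[i].clear()
                pending.discard(i)

    print(results)  # [0, 1, 4]
    for proc in procs:
        proc.join()
```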
@torch.no_grad()
@_check_start
@@ -1401,7 +1425,7 @@ def step_and_maybe_reset(
keys_to_update=self._env_input_keys,
non_blocking=self.non_blocking,
)
next_td_passthrough = tensordict.get("next")
next_td_passthrough = tensordict.get("next", default=None)
if next_td_passthrough is not None:
# if we have input "next" data (eg, RNNs which pass the next state)
# the sub-envs will need to process them through step_and_maybe_reset.
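The explicit default in `tensordict.get("next", default=None)` is presumably there so a missing `"next"` entry keeps returning `None` even on tensordict versions where `get` without a default raises for absent keys. A quick illustration (assumes the `tensordict` package):

```python
from tensordict import TensorDict

td = TensorDict({}, batch_size=[])
nxt = td.get("next", default=None)  # always None when "next" is absent
assert nxt is None
```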
@@ -1483,7 +1507,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
keys_to_update=list(self._env_input_keys),
non_blocking=self.non_blocking,
)
next_td_passthrough = tensordict.get("next")
next_td_passthrough = tensordict.get("next", None)
if next_td_passthrough is not None:
# if we have input "next" data (eg, RNNs which pass the next state)
# the sub-envs will need to process them through step_and_maybe_reset.
@@ -2132,8 +2156,8 @@ def _run_worker_pipe_direct(
if event is not None:
event.record()
event.synchronize()
mp_event.set()
child_pipe.send((td, root_next_td))
mp_event.set()
del td, root_next_td

elif cmd == "close":
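Note the reordering at the end of `_run_worker_pipe_direct`: setting `mp_event` only after `child_pipe.send(...)` makes the event a reliable "the result is in the pipe" signal, which the buffer-free polling loop above depends on. With the old order, the parent could observe a set event and block in `recv()` on a worker whose payload had not been written yet, serializing what is meant to be completion-order consumption.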
49 changes: 35 additions & 14 deletions torchrl/envs/common.py
@@ -348,14 +348,14 @@ def __init__(
self.__dict__.setdefault("_batch_size", None)
if device is not None:
self.__dict__["_device"] = torch.device(device)
output_spec = self.__dict__.get("_output_spec", None)
output_spec = self.__dict__.get("_output_spec")
if output_spec is not None:
self.__dict__["_output_spec"] = (
output_spec.to(self.device)
if self.device is not None
else output_spec
)
input_spec = self.__dict__.get("_input_spec", None)
input_spec = self.__dict__.get("_input_spec")
if input_spec is not None:
self.__dict__["_input_spec"] = (
input_spec.to(self.device)
@@ -477,12 +477,12 @@ def shape(self):

@property
def device(self) -> torch.device:
device = self.__dict__.get("_device", None)
device = self.__dict__.get("_device")
return device

@device.setter
def device(self, value: torch.device) -> None:
device = self.__dict__.get("_device", None)
device = self.__dict__.get("_device")
if device is None:
self.__dict__["_device"] = value
return
@@ -557,7 +557,7 @@ def input_spec(self) -> TensorSpec:


"""
input_spec = self.__dict__.get("_input_spec", None)
input_spec = self.__dict__.get("_input_spec")
if input_spec is None:
input_spec = CompositeSpec(
full_state_spec=None,
@@ -617,7 +617,7 @@ def output_spec(self) -> TensorSpec:


"""
output_spec = self.__dict__.get("_output_spec", None)
output_spec = self.__dict__.get("_output_spec")
if output_spec is None:
output_spec = CompositeSpec(
shape=self.batch_size,
@@ -638,7 +638,7 @@ def action_keys(self) -> List[NestedKey]:

Keys are sorted by depth in the data tree.
"""
action_keys = self.__dict__.get("_action_keys", None)
action_keys = self.__dict__.get("_action_keys")
if action_keys is not None:
return action_keys
keys = self.input_spec["full_action_spec"].keys(True, True)
@@ -648,6 +648,22 @@
self.__dict__["_action_keys"] = keys
return keys

@property
def state_keys(self) -> List[NestedKey]:
"""The state keys of an environment.

By default, there will only be one key named "state".

Keys are sorted by depth in the data tree.
"""
state_keys = self.__dict__.get("_state_keys")
if state_keys is not None:
return state_keys
keys = self.input_spec["full_state_spec"].keys(True, True)
keys = sorted(keys, key=_repr_by_depth)
self.__dict__["_state_keys"] = keys
return keys

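A usage sketch for the new property (hypothetical Gym-backed env; the exact output depends on the env's `full_state_spec`, and for many envs the list is simply empty):

```python
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
print(env.action_keys)  # e.g. ['action']
print(env.state_keys)   # keys of full_state_spec, sorted by depth (often [])
```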
@property
def action_key(self) -> NestedKey:
"""The action key of an environment.
@@ -833,7 +849,7 @@ def reward_keys(self) -> List[NestedKey]:

Keys are sorted by depth in the data tree.
"""
reward_keys = self.__dict__.get("_reward_keys", None)
reward_keys = self.__dict__.get("_reward_keys")
if reward_keys is not None:
return reward_keys

@@ -1029,7 +1045,7 @@ def done_keys(self) -> List[NestedKey]:

Keys are sorted by depth in the data tree.
"""
done_keys = self.__dict__.get("_done_keys", None)
done_keys = self.__dict__.get("_done_keys")
if done_keys is not None:
return done_keys
done_keys = sorted(self.full_done_spec.keys(True, True), key=_repr_by_depth)
@@ -1384,6 +1400,10 @@ def state_spec(self) -> CompositeSpec:
def state_spec(self, value: CompositeSpec) -> None:
try:
self.input_spec.unlock_()
try:
delattr(self, "_state_keys")
except AttributeError:
pass
if value is None:
self.input_spec["full_state_spec"] = CompositeSpec(
device=self.device, shape=self.batch_size
@@ -2616,7 +2636,7 @@ def add_truncated_keys(self) -> EnvBase:

@property
def _step_mdp(self):
step_func = self.__dict__.get("_step_mdp_value", None)
step_func = self.__dict__.get("_step_mdp_value")
if step_func is None:
step_func = _StepMDP(self, exclude_action=False)
self.__dict__["_step_mdp_value"] = step_func
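The `self.__dict__.get(...)` idiom recurring in these properties is manual memoization: the computed value is stashed straight into the instance dict so later reads skip recomputation, and invalidation is a plain `del`/`None` write. A generic sketch of the pattern (illustrative class, not torchrl code):

```python
class Expensive:
    @property
    def table(self):
        cached = self.__dict__.get("_table")   # None if absent or invalidated
        if cached is None:
            cached = self._build_table()       # compute once
            self.__dict__["_table"] = cached   # stash on the instance
        return cached

    def _build_table(self):
        return {"built": True}

    def empty_cache(self):
        self.__dict__["_table"] = None         # force recomputation
```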
@@ -2769,7 +2789,7 @@ def step_and_maybe_reset(

@property
def _simple_done(self):
_simple_done = self.__dict__.get("_simple_done_value", None)
_simple_done = self.__dict__.get("_simple_done_value")
if _simple_done is None:
key_set = set(self.full_done_spec.keys())
_simple_done = key_set == {
@@ -2823,6 +2843,7 @@ def empty_cache(self):
self.__dict__["_reward_keys"] = None
self.__dict__["_done_keys"] = None
self.__dict__["_action_keys"] = None
self.__dict__["_state_keys"] = None
self.__dict__["_done_keys_group"] = None

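Clearing `_state_keys` here (and in the `state_spec` setter above via `delattr`) matters precisely because of that memoization; a stale cache would keep serving keys for a spec that no longer exists. Continuing the `Expensive` sketch from above:

```python
e = Expensive()
first = e.table              # computed and cached
e.empty_cache()              # what empty_cache()/delattr achieve here
assert e.table is not first  # recomputed, not served stale
```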
@property
@@ -2836,7 +2857,7 @@ def reset_keys(self) -> List[NestedKey]:

Keys are sorted by depth in the data tree.
"""
reset_keys = self.__dict__.get("_reset_keys", None)
reset_keys = self.__dict__.get("_reset_keys")
if reset_keys is not None:
return reset_keys

@@ -2880,7 +2901,7 @@ def done_keys_groups(self):
inner lists contain the done keys (eg, done and truncated) that can
be read to determine a reset when it is absent.
"""
done_keys_group = self.__dict__.get("_done_keys_group", None)
done_keys_group = self.__dict__.get("_done_keys_group")
if done_keys_group is not None:
return done_keys_group

@@ -3038,7 +3059,7 @@ def __init__(
self._init_env() # runs all the steps to have a ready-to-use env

def _sync_device(self):
sync_func = self.__dict__.get("_sync_device_val", None)
sync_func = self.__dict__.get("_sync_device_val")
if sync_func is None:
device = self.device
if device.type != "cuda":