[Feature] Dynamic specs #2143

Merged: 27 commits, merged on May 31, 2024
Changes from 1 commit
amend
vmoens committed May 26, 2024
commit 6c7a6d02ac73711e259acf6ee2aff7a0c3e5b474
74 changes: 39 additions & 35 deletions torchrl/envs/batched_envs.py
@@ -301,7 +301,7 @@ def __init__(
serial_for_single: bool = False,
non_blocking: bool = False,
mp_start_method: str = None,
use_buffers: bool=None,
use_buffers: bool = None,
):
super().__init__(device=device)
self.serial_for_single = serial_for_single
@@ -477,8 +477,10 @@ def _get_metadata(
if self._use_buffers is not False:
_use_buffers = not self.meta_data.has_dynamic_specs
if self._use_buffers and not _use_buffers:
warn("A value of use_buffers=True was passed but this is incompatible "
"with the list of environments provided. Turning use_buffers to False.")
warn(
"A value of use_buffers=True was passed but this is incompatible "
"with the list of environments provided. Turning use_buffers to False."
)
self._use_buffers = _use_buffers
if self.share_individual_td is None:
self.share_individual_td = False
@@ -503,8 +505,10 @@ def _get_metadata(
not metadata.has_dynamic_specs for metadata in self.meta_data
)
if self._use_buffers and not _use_buffers:
warn("A value of use_buffers=True was passed but this is incompatible "
"with the list of environments provided. Turning use_buffers to False.")
warn(
"A value of use_buffers=True was passed but this is incompatible "
"with the list of environments provided. Turning use_buffers to False."
)
self._use_buffers = _use_buffers

self._set_properties()
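
For orientation (this note is not part of the diff): the two hunks above make use_buffers an option that is inferred from the workers' metadata when left unset, and forced back to False with a warning when it conflicts with dynamic specs. A minimal usage sketch, assuming a hypothetical make_dynamic_env factory whose environment exposes variable-shaped (dynamic) specs:

# Hedged sketch, not part of the diff. `make_dynamic_env` is a hypothetical
# factory returning an env whose specs have dynamic (variable) shapes.
from torchrl.envs import ParallelEnv

def build_batched_envs(make_dynamic_env):
    # use_buffers=None (the default): inferred from the workers' metadata,
    # so it ends up False automatically when any worker reports dynamic specs.
    env = ParallelEnv(2, make_dynamic_env, use_buffers=None)
    # Forcing use_buffers=True against dynamic specs triggers the warning
    # added in this commit and falls back to use_buffers=False.
    env_forced = ParallelEnv(2, make_dynamic_env, use_buffers=True)
    return env, env_forced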
@@ -1383,26 +1387,27 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:
def _step_and_maybe_reset_no_buffers(
self, tensordict: TensorDictBase
) -> Tuple[TensorDictBase, TensorDictBase]:

for i, _data in enumerate(tensordict.unbind(0)):
self.parent_channels[i].send(("step_and_maybe_reset", _data))

results = [None] * self.num_workers

consumed_indices = []
events = set(range(self.num_workers))
while len(consumed_indices) < self.num_workers:
for i in list(events):
if self._events[i].is_set():
results[i] = self.parent_channels[i].recv()
self._events[i].clear()
consumed_indices.append(i)
events.discard(i)

out_next, out_root = zip(*(future for future in results))
return TensorDict.maybe_dense_stack(
out_next
), TensorDict.maybe_dense_stack(out_root)
return super().step_and_maybe_reset(tensordict)

# for i, _data in enumerate(tensordict.unbind(0)):
# self.parent_channels[i].send(("step_and_maybe_reset", _data))
#
# results = [None] * self.num_workers
#
# consumed_indices = []
# events = set(range(self.num_workers))
# while len(consumed_indices) < self.num_workers:
# for i in list(events):
# if self._events[i].is_set():
# results[i] = self.parent_channels[i].recv()
# self._events[i].clear()
# consumed_indices.append(i)
# events.discard(i)
#
# out_next, out_root = zip(*(future for future in results))
# return TensorDict.maybe_dense_stack(
# out_next
# ), TensorDict.maybe_dense_stack(out_root)

@torch.no_grad()
@_check_start
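
Context for the hunk above (not part of the diff): the buffer-less path now defers to the generic step_and_maybe_reset, and with dynamic specs the per-worker outputs may carry different shapes. Both the commented-out code and the generic path combine results with TensorDict.maybe_dense_stack, which stacks densely when shapes agree and otherwise falls back to a lazy stack. A small illustrative sketch with made-up shapes:

# Hedged illustration, not part of the diff.
import torch
from tensordict import TensorDict

# Two per-worker results whose "obs" entries have different leading dims,
# as can happen with dynamic specs.
a = TensorDict({"obs": torch.zeros(3, 2)}, batch_size=[])
b = TensorDict({"obs": torch.zeros(5, 2)}, batch_size=[])

# Shapes differ, so this should resolve to a lazy stack rather than a dense one.
stacked = TensorDict.maybe_dense_stack([a, b])
print(type(stacked).__name__)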
@@ -1578,6 +1583,7 @@ def _reset_no_buffers(
for i, channel in enumerate(self.parent_channels):
if not needs_resetting[i]:
out_tds.append(None)
continue
self._events[i].wait()
td = channel.recv()
out_tds.append(td)
@@ -2115,6 +2121,10 @@ def _run_worker_pipe_direct(
# we use 'data' to pass the keys that we need to pass to reset,
# because passing the entire buffer may have unwanted consequences
data, reset_kwargs = data
if data is not None:
data._fast_apply(
lambda x: x.clone() if x.device.type == "cuda" else x, out=data
)
cur_td = env.reset(
tensordict=data,
**reset_kwargs,
@@ -2130,7 +2140,6 @@
if not initialized:
raise RuntimeError("called 'init' before step")
i += 1
# No need to copy here since we don't write in-place
next_td = env._step(data)
if event is not None:
event.record()
@@ -2143,15 +2152,10 @@
if not initialized:
raise RuntimeError("called 'init' before step")
i += 1
# We must copy the root shared td here, or at least get rid of done:
# if we don't `td is root_shared_tensordict`
# which means that root_shared_tensordict will carry the content of next
# in the next iteration. When using StepCounter, it will look for an
# existing done state, find it and consider the env as done by input (not
# by output) of the step!
# Caveat: for RNN we may need some keys of the "next" TD so we pass the list
# through data
td, root_next_td = env.step_and_maybe_reset(data.clone())
data._fast_apply(
lambda x: x.clone() if x.device.type == "cuda" else x, out=data
)
td, root_next_td = env.step_and_maybe_reset(data)

if event is not None:
event.record()
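
A note on the two _fast_apply hunks above (this note is not part of the diff): instead of cloning the whole incoming tensordict as before (data.clone()), the worker now clones only CUDA leaves before handing data to env.reset / env.step_and_maybe_reset. A minimal sketch of that pattern, using the public apply in place of the private _fast_apply from the diff:

# Hedged sketch, not part of the diff.
import torch
from tensordict import TensorDict

def clone_cuda_leaves(data: TensorDict) -> TensorDict:
    # Clone CUDA tensors only; CPU tensors are passed through untouched.
    return data.apply(lambda x: x.clone() if x.device.type == "cuda" else x)

# Illustration with CPU-only data (no CUDA device required to run this):
td = TensorDict({"action": torch.zeros(4)}, batch_size=[])
td2 = clone_cuda_leaves(td)
assert td2["action"].data_ptr() == td["action"].data_ptr()  # CPU leaf not cloned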