Commit fcb04e4

[Refactor,Performance] Faster collectors (bis) (pytorch#1331)

vmoens authored Jul 7, 2023
1 parent 17faca8 commit fcb04e4

Showing 3 changed files with 61 additions and 24 deletions.
15 changes: 15 additions & 0 deletions benchmarks/conftest.py

@@ -5,6 +5,7 @@

 import os
 import time
+import warnings
 from collections import defaultdict

 import pytest
@@ -54,3 +55,17 @@ def fin():

 def pytest_addoption(parser):
     parser.addoption("--rank", action="store")
+
+
+@pytest.fixture(autouse=True)
+def set_warnings() -> None:
+    warnings.filterwarnings(
+        "ignore",
+        category=UserWarning,
+        message=r"Lazy modules are a new feature under heavy development",
+    )
+    warnings.filterwarnings(
+        "ignore",
+        category=UserWarning,
+        message=r"Couldn't cast the policy onto the desired device on remote process",
+    )
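
Note: because the fixture is marked autouse=True, pytest applies these filters to every test in the benchmarks suite without any test having to request the fixture by name. A standalone sketch of the pattern (file name and warning text are illustrative, not from this commit):

    # conftest.py
    import warnings

    import pytest


    @pytest.fixture(autouse=True)
    def quiet_warnings() -> None:
        # autouse=True: applied to every test collected under this conftest,
        # with no explicit fixture argument needed in the tests themselves.
        warnings.filterwarnings(
            "ignore",
            category=UserWarning,
            message=r"some noisy library warning",
        )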
5 changes: 5 additions & 0 deletions test/conftest.py

@@ -61,3 +61,8 @@ def set_warnings() -> None:
         category=UserWarning,
         message=r"Lazy modules are a new feature under heavy development",
     )
+    warnings.filterwarnings(
+        "ignore",
+        category=UserWarning,
+        message=r"Couldn't cast the policy onto the desired device on remote process",
+    )
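
Note: warnings.filterwarnings compiles the message argument as a regular expression and matches it against the beginning of the warning text, so a distinctive prefix would be enough; the full message does not have to be reproduced. A quick standalone illustration (not from this commit):

    import warnings

    warnings.filterwarnings(
        "ignore",
        category=UserWarning,
        message=r"Couldn't cast the policy",  # prefix match suffices
    )
    # Silenced: the pattern matches the start of the message.
    warnings.warn("Couldn't cast the policy onto the desired device on remote process")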
65 changes: 41 additions & 24 deletions torchrl/collectors/collectors.py

@@ -771,9 +771,8 @@ def iterator(self) -> Iterator[TensorDictBase]:
                 break

     def _step_and_maybe_reset(self) -> None:
-        done = self._tensordict.get(("next", *self.env.done_key))
+        done = self._tensordict.get(("next", self.env.done_key))
         truncated = self._tensordict.get(("next", "truncated"), None)
-        traj_ids = self._tensordict.get(("collector", "traj_ids"))

         self._tensordict = step_mdp(
             self._tensordict,
@@ -788,23 +787,28 @@ def _step_and_maybe_reset(self) -> None:
         done_or_terminated = (
             (done | truncated) if truncated is not None else done.clone()
         )
+
         if done_or_terminated.any():
+            traj_ids = self._tensordict.get(("collector", "traj_ids"))
+            traj_ids = traj_ids.clone()
             # collectors do not support passing other tensors than `"_reset"`
             # to `reset()`.
             _reset = done_or_terminated
             td_reset = self._tensordict.select().set("_reset", _reset)
             td_reset = self.env.reset(td_reset)
             td_reset.del_("_reset")
             traj_done_or_terminated = done_or_terminated.sum(
                 tuple(range(self._tensordict.batch_dims, done_or_terminated.ndim)),
                 dtype=torch.bool,
             )
             if td_reset.batch_dims:
                 # better cloning here than when passing the td for stacking
+                # cloning is necessary to avoid modifying dones in-place
+                self._tensordict = self._tensordict.clone()
                 self._tensordict.get_sub_tensordict(traj_done_or_terminated).update(
-                    td_reset[traj_done_or_terminated], inplace=True
+                    td_reset[traj_done_or_terminated]
                 )
             else:
-                self._tensordict.update(td_reset, inplace=True)
+                self._tensordict.update(td_reset)

             done = self._tensordict.get(self.env.done_key)
             if done.any():
@@ -814,9 +818,7 @@ def _step_and_maybe_reset(self) -> None:
             traj_ids[traj_done_or_terminated] = traj_ids.max() + torch.arange(
                 1, traj_done_or_terminated.sum() + 1, device=traj_ids.device
             )
-        self._tensordict.set_(
-            ("collector", "traj_ids"), traj_ids
-        )  # no ops if they already match
+            self._tensordict.set(("collector", "traj_ids"), traj_ids)

     @torch.no_grad()
     def rollout(self) -> TensorDictBase:
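
Note: the net effect of the _step_and_maybe_reset changes is to replace in-place writes (update(..., inplace=True), set_) with out-of-place ones: traj_ids is now fetched and cloned only when something finished, and the whole tensordict is cloned before a partial reset, so snapshots already appended by rollout are never mutated after the fact. The trajectory-id rule itself is unchanged: finished slots receive fresh ids strictly above the current maximum. A standalone sketch of that rule (names and values are illustrative):

    import torch

    # One trajectory id per environment slot; True marks finished slots.
    traj_ids = torch.tensor([0, 1, 2, 3])
    done = torch.tensor([False, True, False, True])

    # Clone first so previously stored references keep their old values,
    # then hand out brand-new ids above the current maximum.
    traj_ids = traj_ids.clone()
    traj_ids[done] = traj_ids.max() + torch.arange(1, done.sum() + 1)
    print(traj_ids)  # tensor([0, 4, 2, 5])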
@@ -827,40 +829,55 @@ def rollout(self) -> TensorDictBase:
         """
         if self.reset_at_each_iter:
-            self._tensordict.update(self.env.reset(), inplace=True)
+            self._tensordict.update(self.env.reset())

         # self._tensordict.fill_(("collector", "step_count"), 0)
         self._tensordict_out.fill_(("collector", "traj_ids"), -1)

+        tensordicts = []
         with set_exploration_type(self.exploration_type):
-            for j in range(self.frames_per_batch):
+            for t in range(self.frames_per_batch):
                 if self._frames < self.init_random_frames:
                     self.env.rand_step(self._tensordict)
                 else:
                     self.policy(self._tensordict)
                     self.env.step(self._tensordict)
-
-                # we must clone all the values, since the step / traj_id updates are done in-place
-                try:
-                    self._tensordict_out[..., j] = self._tensordict
-                except RuntimeError:
-                    # unlock the output tensordict to allow for new keys to be written
-                    # these will be missed during the sync but at least we won't get an error during the update
-                    is_shared = self._tensordict_out.is_shared()
-                    self._tensordict_out.unlock_()
-                    self._tensordict_out[..., j] = self._tensordict
-                    if is_shared:
-                        self._tensordict_out.share_memory_()
-                    else:
-                        self._tensordict_out.lock()
+                tensordicts.append(self._tensordict.to(self.storing_device))

                 self._step_and_maybe_reset()
                 if (
                     self.interruptor is not None
                     and self.interruptor.collection_stopped()
                 ):
+                    try:
+                        torch.stack(
+                            tensordicts,
+                            self._tensordict_out.ndim - 1,
+                            out=self._tensordict_out[: t + 1],
+                        )
+                    except RuntimeError:
+                        with self._tensordict_out.unlock_():
+                            torch.stack(
+                                tensordicts,
+                                self._tensordict_out.ndim - 1,
+                                out=self._tensordict_out[: t + 1],
+                            )
                     break

+            else:
+                try:
+                    self._tensordict_out = torch.stack(
+                        tensordicts,
+                        self._tensordict_out.ndim - 1,
+                        out=self._tensordict_out,
+                    )
+                except RuntimeError:
+                    with self._tensordict_out.unlock_():
+                        self._tensordict_out = torch.stack(
+                            tensordicts,
+                            self._tensordict_out.ndim - 1,
+                            out=self._tensordict_out,
+                        )
         return self._tensordict_out

     def reset(self, index=None, **kwargs) -> None:
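
Note: the rollout rewrite is where the speedup comes from. The old loop wrote every step into the preallocated output (self._tensordict_out[..., j] = ...), one indexed copy per key per step, with lock/unlock fallbacks whenever new keys appeared. The new loop appends a per-step snapshot to a plain Python list and materializes the batch with a single torch.stack at the end, passing out=self._tensordict_out to reuse the preallocated storage; when the interruptor stops collection early, only the t + 1 steps gathered so far are stacked into a slice of that buffer. A minimal, self-contained sketch of the append-then-stack pattern (keys and shapes are illustrative):

    import torch
    from tensordict import TensorDict

    steps = []
    td = TensorDict({"obs": torch.zeros(4, 3)}, batch_size=[4])
    for _ in range(10):
        td = td.clone()  # snapshot: later in-place edits won't leak backwards
        td["obs"] += 1  # stand-in for policy(...) and env.step(...)
        steps.append(td)

    # One stack at the end instead of ten indexed writes into a
    # preallocated buffer; the time dimension is appended last.
    batch = torch.stack(steps, 1)
    print(batch.batch_size)  # torch.Size([4, 10])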
