Skip to content

Commit

Permalink
[Feature] Partial steps in batched envs
Browse files Browse the repository at this point in the history
ghstack-source-id: a1a69e55cddf10290cb59dc1a3c6136bd257368a
Pull Request resolved: pytorch#2377
  • Loading branch information
vmoens committed Aug 12, 2024
1 parent a6310ae commit 430f1bd
Show file tree
Hide file tree
Showing 3 changed files with 290 additions and 62 deletions.
5 changes: 4 additions & 1 deletion test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,7 +1038,10 @@ def _step(
tensordict: TensorDictBase,
) -> TensorDictBase:
action = tensordict.get(self.action_key)
self.count += action.to(dtype=torch.int, device=self.device)
self.count += action.to(
dtype=torch.int,
device=self.action_spec.device if self.device is None else self.device,
)
tensordict = TensorDict(
source={
"observation": self.count.clone(),
Expand Down
93 changes: 93 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import argparse
import contextlib
import functools
import gc
import os.path
Expand Down Expand Up @@ -3340,6 +3341,98 @@ def test_pendulum_env(self):
assert r.shape == torch.Size((5, 10))


@pytest.mark.parametrize("device", [None, *get_default_devices()])
@pytest.mark.parametrize("env_device", [None, *get_default_devices()])
class TestPartialSteps:
    """Partial-step tests for batched environments.

    Each test builds a batch of four ``CountingEnv`` workers, writes a boolean
    ``"_step"`` mask into the input tensordict that selects workers 1 and 3,
    and then checks the ``"next"`` entry of every worker: it must be all zeros
    for the unselected workers (0 and 2) and contain at least one non-zero
    value for the selected ones.

    The class is parametrized over ``device`` (passed to the batched env and
    used as the default torch device during the test) and ``env_device``
    (passed to each ``CountingEnv``).
    """

    @staticmethod
    def _device_ctx(device):
        """Return a context manager making ``device`` the default torch device.

        A no-op context is returned when ``device`` is ``None``.
        """
        if device is None:
            return contextlib.nullcontext()
        return torch.device(device)

    @staticmethod
    def _masked_input(penv):
        """Reset ``penv`` and return an input that steps only workers 1 and 3."""
        td = penv.reset()
        psteps = torch.zeros(4, dtype=torch.bool)
        psteps[[1, 3]] = True
        td.set("_step", psteps)
        td.set("action", penv.action_spec.one())
        return td

    @staticmethod
    def _assert_only_selected_stepped(td):
        """Assert that only workers 1 and 3 carry non-zero ``"next"`` data."""
        assert (td[0].get("next") == 0).all()
        assert (td[1].get("next") != 0).any()
        assert (td[2].get("next") == 0).all()
        assert (td[3].get("next") != 0).any()

    @pytest.mark.parametrize("use_buffers", [False, True])
    def test_parallel_partial_steps(
        self, use_buffers, device, env_device, maybe_fork_ParallelEnv
    ):
        """``ParallelEnv.step`` only steps the workers selected by ``"_step"``."""
        with self._device_ctx(device):
            penv = maybe_fork_ParallelEnv(
                4,
                lambda: CountingEnv(max_steps=10, start_val=2, device=env_device),
                use_buffers=use_buffers,
                device=device,
            )
            td = penv.step(self._masked_input(penv))
            self._assert_only_selected_stepped(td)

    @pytest.mark.parametrize("use_buffers", [False, True])
    def test_parallel_partial_step_and_maybe_reset(
        self, use_buffers, device, env_device, maybe_fork_ParallelEnv
    ):
        """``ParallelEnv.step_and_maybe_reset`` honours the ``"_step"`` mask."""
        with self._device_ctx(device):
            penv = maybe_fork_ParallelEnv(
                4,
                lambda: CountingEnv(max_steps=10, start_val=2, device=env_device),
                use_buffers=use_buffers,
                device=device,
            )
            # Only the first element of the returned pair is inspected here.
            td, _ = penv.step_and_maybe_reset(self._masked_input(penv))
            self._assert_only_selected_stepped(td)

    @pytest.mark.parametrize("use_buffers", [False, True])
    def test_serial_partial_steps(self, use_buffers, device, env_device):
        """``SerialEnv.step`` only steps the workers selected by ``"_step"``."""
        with self._device_ctx(device):
            penv = SerialEnv(
                4,
                lambda: CountingEnv(max_steps=10, start_val=2, device=env_device),
                use_buffers=use_buffers,
                device=device,
            )
            td = penv.step(self._masked_input(penv))
            self._assert_only_selected_stepped(td)

    @pytest.mark.parametrize("use_buffers", [False, True])
    def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_device):
        """``SerialEnv.step_and_maybe_reset`` honours the ``"_step"`` mask."""
        with self._device_ctx(device):
            penv = SerialEnv(
                4,
                lambda: CountingEnv(max_steps=10, start_val=2, device=env_device),
                use_buffers=use_buffers,
                device=device,
            )
            # This test previously called ``penv.step`` and therefore merely
            # duplicated ``test_serial_partial_steps``; exercise the method
            # the test is named after, mirroring the parallel variant.
            td, _ = penv.step_and_maybe_reset(self._masked_input(penv))
            self._assert_only_selected_stepped(td)


if __name__ == "__main__":
    # Arguments the (empty) parser does not recognise are forwarded verbatim
    # to pytest, after the fixed options used for direct script execution.
    _, passthrough = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *passthrough])
Loading

0 comments on commit 430f1bd

Please sign in to comment.