Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Aug 9, 2024
2 parents 766cfbc + a3532a1 commit 732d13b
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 79 deletions.
6 changes: 1 addition & 5 deletions knowledge_base/VIDEO_CUSTOMISATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,5 @@ as advised by the documentation.
We can improve the video quality by appending all our desired settings
(as keyword arguments) to `recorder` like so:
```python
# The arguments' types don't appear to matter too much, as long as they are
# appropriate for Python.
# For example, this would work as well:
# logger = CSVLogger(exp_name="my_exp", crf=17, preset="slow")
logger = CSVLogger(exp_name="my_exp", crf="17", preset="slow")
recorder = VideoRecorder(logger, tag = "my_video", options = {"crf": "17", "preset": "slow"})
```
148 changes: 76 additions & 72 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3344,88 +3344,92 @@ def test_pendulum_env(self):
@pytest.mark.parametrize("env_device", [None, *get_default_devices()])
class TestPartialSteps:
@pytest.mark.parametrize("use_buffers", [False, True])
def test_parallel_partial_steps(self, use_buffers, device, env_device):
penv = ParallelEnv(
4,
lambda: CountingEnv(max_steps=10, start_val=2, device=env_device),
mp_start_method=mp_ctx,
use_buffers=use_buffers,
device=device,
)
td = penv.reset()
psteps = torch.zeros(4, dtype=torch.bool)
psteps[[1, 3]] = True
td.set("_partial_steps", psteps)

td.set("action", penv.action_spec.one())
td = penv.step(td)
assert (td[0].get("next") == 0).all()
assert (td[1].get("next") != 0).any()
assert (td[2].get("next") == 0).all()
assert (td[3].get("next") != 0).any()
def test_parallel_partial_steps(
self, use_buffers, device, env_device, maybe_fork_ParallelEnv
):
with torch.device(device):
penv = maybe_fork_ParallelEnv(
4,
lambda: CountingEnv(max_steps=10, start_val=2, device=env_device),
use_buffers=use_buffers,
device=device,
)
td = penv.reset()
psteps = torch.zeros(4, dtype=torch.bool)
psteps[[1, 3]] = True
td.set("_partial_steps", psteps)

td.set("action", penv.action_spec.one())
td = penv.step(td)
assert (td[0].get("next") == 0).all()
assert (td[1].get("next") != 0).any()
assert (td[2].get("next") == 0).all()
assert (td[3].get("next") != 0).any()

@pytest.mark.parametrize("use_buffers", [False, True])
def test_parallel_partial_step_and_maybe_reset(
self, use_buffers, device, env_device
self, use_buffers, device, env_device, maybe_fork_ParallelEnv
):
penv = ParallelEnv(
4,
lambda: CountingEnv(max_steps=10, start_val=2, device=env_device),
mp_start_method=mp_ctx,
use_buffers=use_buffers,
device=device,
)
td = penv.reset()
psteps = torch.zeros(4, dtype=torch.bool)
psteps[[1, 3]] = True
td.set("_partial_steps", psteps)

td.set("action", penv.action_spec.one())
td, tdreset = penv.step_and_maybe_reset(td)
assert (td[0].get("next") == 0).all()
assert (td[1].get("next") != 0).any()
assert (td[2].get("next") == 0).all()
assert (td[3].get("next") != 0).any()
with torch.device(device):
penv = maybe_fork_ParallelEnv(
4,
lambda: CountingEnv(max_steps=10, start_val=2, device=env_device),
use_buffers=use_buffers,
device=device,
)
td = penv.reset()
psteps = torch.zeros(4, dtype=torch.bool)
psteps[[1, 3]] = True
td.set("_partial_steps", psteps)

td.set("action", penv.action_spec.one())
td, tdreset = penv.step_and_maybe_reset(td)
assert (td[0].get("next") == 0).all()
assert (td[1].get("next") != 0).any()
assert (td[2].get("next") == 0).all()
assert (td[3].get("next") != 0).any()

@pytest.mark.parametrize("use_buffers", [False, True])
def test_serial_partial_steps(self, use_buffers, device, env_device):
penv = SerialEnv(
4,
lambda: CountingEnv(max_steps=10, start_val=2, device=env_device),
use_buffers=use_buffers,
device=device,
)
td = penv.reset()
psteps = torch.zeros(4, dtype=torch.bool)
psteps[[1, 3]] = True
td.set("_partial_steps", psteps)

td.set("action", penv.action_spec.one())
td = penv.step(td)
assert (td[0].get("next") == 0).all()
assert (td[1].get("next") != 0).any()
assert (td[2].get("next") == 0).all()
assert (td[3].get("next") != 0).any()
with torch.device(device):
penv = SerialEnv(
4,
lambda: CountingEnv(max_steps=10, start_val=2, device=env_device),
use_buffers=use_buffers,
device=device,
)
td = penv.reset()
psteps = torch.zeros(4, dtype=torch.bool)
psteps[[1, 3]] = True
td.set("_partial_steps", psteps)

td.set("action", penv.action_spec.one())
td = penv.step(td)
assert (td[0].get("next") == 0).all()
assert (td[1].get("next") != 0).any()
assert (td[2].get("next") == 0).all()
assert (td[3].get("next") != 0).any()

@pytest.mark.parametrize("use_buffers", [False, True])
def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_device):
penv = SerialEnv(
4,
lambda: CountingEnv(max_steps=10, start_val=2, device=env_device),
use_buffers=use_buffers,
device=device,
)
td = penv.reset()
psteps = torch.zeros(4, dtype=torch.bool)
psteps[[1, 3]] = True
td.set("_partial_steps", psteps)

td.set("action", penv.action_spec.one())
td = penv.step(td)
assert (td[0].get("next") == 0).all()
assert (td[1].get("next") != 0).any()
assert (td[2].get("next") == 0).all()
assert (td[3].get("next") != 0).any()
with torch.device(device):
penv = SerialEnv(
4,
lambda: CountingEnv(max_steps=10, start_val=2, device=env_device),
use_buffers=use_buffers,
device=device,
)
td = penv.reset()
psteps = torch.zeros(4, dtype=torch.bool)
psteps[[1, 3]] = True
td.set("_partial_steps", psteps)

td.set("action", penv.action_spec.one())
td = penv.step(td)
assert (td[0].get("next") == 0).all()
assert (td[1].get("next") != 0).any()
assert (td[2].get("next") == 0).all()
assert (td[3].get("next") != 0).any()


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3087,6 +3087,7 @@ def __init__(

self._constructor_kwargs = kwargs
self._check_kwargs(kwargs)
self._convert_actions_to_numpy = kwargs.pop("convert_actions_to_numpy", True)
self._env = self._build_env(**kwargs) # writes the self._env attribute
self._make_specs(self._env) # builds the specs from self._env
self.is_closed = False
Expand Down
6 changes: 4 additions & 2 deletions torchrl/envs/gym_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ class GymLikeEnv(_EnvWrapper):
def __new__(cls, *args, **kwargs):
self = super().__new__(cls, *args, _batch_locked=True, **kwargs)
self._info_dict_reader = []

return self

def read_action(self, action):
Expand Down Expand Up @@ -289,7 +290,8 @@ def read_obs(

def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
action = tensordict.get(self.action_key)
action_np = self.read_action(action)
if self._convert_actions_to_numpy:
action = self.read_action(action)

reward = 0
for _ in range(self.wrapper_frame_skip):
Expand All @@ -300,7 +302,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
truncated,
done,
info_dict,
) = self._output_transform(self._env.step(action_np))
) = self._output_transform(self._env.step(action))

if _reward is not None:
reward = reward + _reward
Expand Down
5 changes: 5 additions & 0 deletions torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,11 @@ class GymWrapper(GymLikeEnv, metaclass=_AsyncMeta):
allow_done_after_reset (bool, optional): if ``True``, it is tolerated
for envs to be ``done`` just after :meth:`~.reset` is called.
Defaults to ``False``.
convert_actions_to_numpy (bool, optional): if ``True``, actions will be
converted from tensors to numpy arrays and moved to CPU before being passed to the
env step function. Set this to ``False`` for environments that run
on GPU and expect tensor actions, such as IsaacLab.
Defaults to ``True``.
Attributes:
available_envs (List[str]): a list of environments to build.
Expand Down

0 comments on commit 732d13b

Please sign in to comment.