diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 3277158af57..16ce7d0d534 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -3039,6 +3039,7 @@ def __init__( self._constructor_kwargs = kwargs self._check_kwargs(kwargs) + self._convert_actions_to_numpy = kwargs.pop("convert_actions_to_numpy", True) self._env = self._build_env(**kwargs) # writes the self._env attribute self._make_specs(self._env) # writes the self._env attribute self.is_closed = False diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index d2b6e0f23fa..82f42180913 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -172,6 +172,7 @@ class GymLikeEnv(_EnvWrapper): def __new__(cls, *args, **kwargs): self = super().__new__(cls, *args, _batch_locked=True, **kwargs) self._info_dict_reader = [] + return self def read_action(self, action): @@ -289,7 +290,8 @@ def read_obs( def _step(self, tensordict: TensorDictBase) -> TensorDictBase: action = tensordict.get(self.action_key) - action_np = self.read_action(action) + if self._convert_actions_to_numpy: + action = self.read_action(action) reward = 0 for _ in range(self.wrapper_frame_skip): @@ -300,7 +302,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: truncated, done, info_dict, - ) = self._output_transform(self._env.step(action_np)) + ) = self._output_transform(self._env.step(action)) if _reward is not None: reward = reward + _reward diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 8431d155ee2..34af87b75f9 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -645,6 +645,11 @@ class GymWrapper(GymLikeEnv, metaclass=_AsyncMeta): allow_done_after_reset (bool, optional): if ``True``, it is tolerated for envs to be ``done`` just after :meth:`~.reset` is called. Defaults to ``False``. 
+ convert_actions_to_numpy (bool, optional): if ``True``, actions will be + converted from tensors to numpy arrays and moved to CPU before being passed to the + env step function. Set this to ``False`` for environments that run on GPU + and take tensor actions directly (e.g. IsaacLab). + Defaults to ``True``. Attributes: available_envs (List[str]): a list of environments to build.