From 225f92fe8a5423a94c18852e321a6578d13a5229 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 2 Jul 2022 21:33:40 +0100 Subject: [PATCH] Bugfix: Double2Float default behaviour (#242) --- torchrl/envs/gym_like.py | 26 ++++++++++++++++++-------- torchrl/envs/libs/gym.py | 12 +++++++++++- torchrl/envs/transforms/transforms.py | 9 +++++++-- 3 files changed, 36 insertions(+), 11 deletions(-) diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 83c9941e12c..50860f56892 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from typing import Optional, Union, Tuple import numpy as np @@ -38,6 +39,11 @@ def __init__(self, keys=None): self.keys = keys def __call__(self, info_dict: dict, tensordict: _TensorDict) -> _TensorDict: + if not isinstance(info_dict, dict) and len(self.keys): + warnings.warn( + f"Found an info_dict of type {type(info_dict)} " + f"but expected type or subtype `dict`." + ) for key in self.keys: if key in info_dict: tensordict[key] = info_dict[key] @@ -67,6 +73,11 @@ class GymLikeEnv(_EnvWrapper): It is also expected that env.reset() returns an observation similar to the one observed after a step is completed. """ + @classmethod + def __new__(cls, *args, **kwargs): + cls._info_dict_reader = None + return super().__new__(cls, *args, **kwargs) + def _step(self, tensordict: _TensorDict) -> _TensorDict: action = tensordict.get("action") action_np = self.action_spec.to_numpy(action, safe=False) @@ -98,7 +109,8 @@ def _step(self, tensordict: _TensorDict) -> _TensorDict: ) tensordict_out.set("reward", reward) tensordict_out.set("done", done) - self.info_dict_reader(info, tensordict_out) + if self.info_dict_reader is not None: + self.info_dict_reader(*info, tensordict_out) return tensordict_out @@ -156,17 +168,15 @@ def set_info_dict_reader(self, info_dict_reader: callable) -> GymLikeEnv: self.info_dict_reader = info_dict_reader return self + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(env={self._env}, batch_size={self.batch_size})" + ) + @property def info_dict_reader(self): - if "_info_dict_reader" not in self.__dir__(): - self._info_dict_reader = default_info_dict_reader() return self._info_dict_reader @info_dict_reader.setter def info_dict_reader(self, value: callable): self._info_dict_reader = value - - def __repr__(self) -> str: - return ( - f"{self.__class__.__name__}(env={self._env}, batch_size={self.batch_size})" - ) diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 3f865aa5b0f..b77bc9da248 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -19,7 +19,7 @@ UnboundedContinuousTensorSpec, ) from ...data.utils import numpy_to_torch_dtype_dict -from ..gym_like import GymLikeEnv +from ..gym_like import GymLikeEnv, default_info_dict_reader from ..utils import classproperty try: @@ -226,6 +226,16 @@ def rebuild_with_kwargs(self, **new_kwargs): self._env = self._build_env(**self._constructor_kwargs) self._make_specs(self._env) + @property + def info_dict_reader(self): + if self._info_dict_reader is None: + self._info_dict_reader = default_info_dict_reader() + return self._info_dict_reader + + @info_dict_reader.setter + def info_dict_reader(self, value: callable): + self._info_dict_reader = value + class GymEnv(GymWrapper): """ diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index d11e63bd628..40223215d4e 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1245,8 +1245,6 @@ def __init__( keys_in: Optional[Sequence[str]] = None, keys_inv_in: Optional[Sequence[str]] = None, ): - if keys_inv_in is None: - keys_inv_in = ["action"] super().__init__(keys_in=keys_in, keys_inv_in=keys_inv_in) def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor: @@ -1286,6 +1284,13 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec self._transform_spec(observation_spec) return observation_spec + def __repr__(self) -> str: + s = ( + f"{self.__class__.__name__}(keys_in={self.keys_in}, keys_out={self.keys_out}," + f"keys_inv_in={self.keys_inv_in}, keys_inv_out={self.keys_inv_out})" + ) + return s + class CatTensors(Transform): """