Skip to content

Commit

Permalink
Bugfix: Double2Float default behaviour (pytorch#242)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jul 2, 2022
1 parent 5ed32a9 commit 225f92f
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 11 deletions.
26 changes: 18 additions & 8 deletions torchrl/envs/gym_like.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from typing import Optional, Union, Tuple

import numpy as np
Expand Down Expand Up @@ -38,6 +39,11 @@ def __init__(self, keys=None):
self.keys = keys

def __call__(self, info_dict: dict, tensordict: _TensorDict) -> _TensorDict:
if not isinstance(info_dict, dict) and len(self.keys):
warnings.warn(
f"Found an info_dict of type {type(info_dict)} "
f"but expected type or subtype `dict`."
)
for key in self.keys:
if key in info_dict:
tensordict[key] = info_dict[key]
Expand Down Expand Up @@ -67,6 +73,11 @@ class GymLikeEnv(_EnvWrapper):
It is also expected that env.reset() returns an observation similar to the one observed after a step is completed.
"""

@classmethod
def __new__(cls, *args, **kwargs):
cls._info_dict_reader = None
return super().__new__(cls, *args, **kwargs)

def _step(self, tensordict: _TensorDict) -> _TensorDict:
action = tensordict.get("action")
action_np = self.action_spec.to_numpy(action, safe=False)
Expand Down Expand Up @@ -98,7 +109,8 @@ def _step(self, tensordict: _TensorDict) -> _TensorDict:
)
tensordict_out.set("reward", reward)
tensordict_out.set("done", done)
self.info_dict_reader(info, tensordict_out)
if self.info_dict_reader is not None:
self.info_dict_reader(*info, tensordict_out)

return tensordict_out

Expand Down Expand Up @@ -156,17 +168,15 @@ def set_info_dict_reader(self, info_dict_reader: callable) -> GymLikeEnv:
self.info_dict_reader = info_dict_reader
return self

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(env={self._env}, batch_size={self.batch_size})"
)

@property
def info_dict_reader(self):
if "_info_dict_reader" not in self.__dir__():
self._info_dict_reader = default_info_dict_reader()
return self._info_dict_reader

@info_dict_reader.setter
def info_dict_reader(self, value: callable):
self._info_dict_reader = value

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(env={self._env}, batch_size={self.batch_size})"
)
12 changes: 11 additions & 1 deletion torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
UnboundedContinuousTensorSpec,
)
from ...data.utils import numpy_to_torch_dtype_dict
from ..gym_like import GymLikeEnv
from ..gym_like import GymLikeEnv, default_info_dict_reader
from ..utils import classproperty

try:
Expand Down Expand Up @@ -226,6 +226,16 @@ def rebuild_with_kwargs(self, **new_kwargs):
self._env = self._build_env(**self._constructor_kwargs)
self._make_specs(self._env)

@property
def info_dict_reader(self):
if self._info_dict_reader is None:
self._info_dict_reader = default_info_dict_reader()
return self._info_dict_reader

@info_dict_reader.setter
def info_dict_reader(self, value: callable):
self._info_dict_reader = value


class GymEnv(GymWrapper):
"""
Expand Down
9 changes: 7 additions & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,8 +1245,6 @@ def __init__(
keys_in: Optional[Sequence[str]] = None,
keys_inv_in: Optional[Sequence[str]] = None,
):
if keys_inv_in is None:
keys_inv_in = ["action"]
super().__init__(keys_in=keys_in, keys_inv_in=keys_inv_in)

def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -1286,6 +1284,13 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
self._transform_spec(observation_spec)
return observation_spec

def __repr__(self) -> str:
s = (
f"{self.__class__.__name__}(keys_in={self.keys_in}, keys_out={self.keys_out},"
f"keys_inv_in={self.keys_inv_in}, keys_inv_out={self.keys_inv_out})"
)
return s


class CatTensors(Transform):
"""
Expand Down

0 comments on commit 225f92f

Please sign in to comment.