From 545a28cc12efd251cdeec939062e17d52f5cbe3a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 3 Sep 2023 13:28:45 +0100 Subject: [PATCH] [BugFix] Fix Gym Categorical/One-hot issues (#1482) --- test/mocking_classes.py | 3 ++- test/test_env.py | 2 +- test/test_libs.py | 19 +++++++++++++++++++ test/test_specs.py | 2 +- torchrl/data/tensor_specs.py | 23 +++++++++++------------ torchrl/envs/libs/gym.py | 10 ++++++++++ torchrl/envs/transforms/transforms.py | 2 +- 7 files changed, 45 insertions(+), 16 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 504634917bc..f3193ce2099 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1696,7 +1696,8 @@ def _step( done = self.output_spec["full_done_spec"].zero() td = self.observation_spec.zero() - one_hot_action = tensordict["action"].argmax(-1).unsqueeze(-1) + one_hot_action = tensordict["action"] + one_hot_action = one_hot_action.long().argmax(-1).unsqueeze(-1) reward["reward"] += one_hot_action.to(torch.float) self.count += one_hot_action.to(torch.int) td["observation"] += expand_right(self.count, td["observation"].shape) diff --git a/test/test_env.py b/test/test_env.py index 6ad03208e3d..f5d03ba366c 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -2013,7 +2013,7 @@ def check_rollout_consistency(td: TensorDict, max_steps: int): == td["next", "observation"][index_batch_size][:-1][~next_is_done] ).all() # Check observation and reward update with count action for root - action_is_count = td["action"].argmax(-1).to(torch.bool) + action_is_count = td["action"].long().argmax(-1).to(torch.bool) assert ( td["next", "observation"][action_is_count] == td["observation"][action_is_count] + 1 diff --git a/test/test_libs.py b/test/test_libs.py index dc03c092c68..cce2d3f75ab 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -292,6 +292,25 @@ def info_reader(info, tensordict): env.rand_step() env.rollout(3) + @implement_for("gymnasium", "0.27.0", None) + def test_one_hot_and_categorical(self): + # tests that one-hot and categorical work ok when an integer is expected as action + cliff_walking = GymEnv("CliffWalking-v0", categorical_action_encoding=True) + cliff_walking.rollout(10) + check_env_specs(cliff_walking) + + cliff_walking = GymEnv("CliffWalking-v0", categorical_action_encoding=False) + cliff_walking.rollout(10) + check_env_specs(cliff_walking) + + @implement_for("gym", None, "0.27.0") + def test_one_hot_and_categorical(self): # noqa: F811 + # we do not skip (bc we may want to make sure nothing is skipped) + # but CliffWalking-v0 in earlier Gym versions uses np.bool, which + # was deprecated after np 1.20, and we don't want to install multiple np + # versions. + return + @implement_for("gym", None, "0.26") def _make_gym_environment(env_name): # noqa: F811 diff --git a/test/test_specs.py b/test/test_specs.py index 6e84e50c103..594f9e8fba7 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -1168,7 +1168,7 @@ def test_one_hot_discrete_action_spec_rand(self): sample = action_spec.rand((100000,)) - sample_list = sample.argmax(-1) + sample_list = sample.long().argmax(-1) sample_list = [sum(sample_list == i).item() for i in range(10)] assert chisquare(sample_list).pvalue > 0.1 diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 00b8418ab2f..99cbf8cd78b 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1117,7 +1117,7 @@ def __init__( n: int, shape: Optional[torch.Size] = None, device: Optional[DEVICE_TYPING] = None, - dtype: Optional[Union[str, torch.dtype]] = torch.long, + dtype: Optional[Union[str, torch.dtype]] = torch.bool, use_register: bool = False, ): dtype, device = _default_dtype_and_device(dtype, device) @@ -1229,9 +1229,9 @@ def encode( ) -> torch.Tensor: if not isinstance(val, torch.Tensor): if ignore_device: - val = torch.tensor(val, dtype=self.dtype) + val = torch.tensor(val) else: - val = torch.tensor(val, dtype=self.dtype, device=self.device) + val = torch.tensor(val, device=self.device) if space is None: space = self.space @@ -1244,7 +1244,7 @@ def encode( if (val >= space.n).any(): raise AssertionError("Value must be less than action space.") - val = torch.nn.functional.one_hot(val.long(), space.n) + val = torch.nn.functional.one_hot(val.long(), space.n).to(self.dtype) return val def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray: @@ -1254,7 +1254,7 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray: if not isinstance(val, torch.Tensor): raise NotImplementedError self.assert_is_in(val) - val = val.argmax(-1).cpu().numpy() + val = val.long().argmax(-1).cpu().numpy() if self.use_register: inv_reg = self.space.register.inverse() vals = [] @@ -1323,14 +1323,13 @@ def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: safe = _CHECK_SPEC_ENCODE if safe: self.assert_is_in(val) - return val.argmax(-1) + return val.long().argmax(-1) def to_categorical_spec(self) -> DiscreteTensorSpec: """Converts the spec to the equivalent categorical spec.""" return DiscreteTensorSpec( self.space.n, device=self.device, - dtype=self.dtype, shape=self.shape[:-1], ) @@ -1801,7 +1800,7 @@ def __init__( nvec: Sequence[int], shape: Optional[torch.Size] = None, device=None, - dtype=torch.long, + dtype=torch.bool, use_register=False, ): self.nvec = nvec @@ -1943,14 +1942,13 @@ def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: if safe: self.assert_is_in(val) vals = self._split(val) - return torch.stack([val.argmax(-1) for val in vals], -1) + return torch.stack([val.long().argmax(-1) for val in vals], -1) def to_categorical_spec(self) -> MultiDiscreteTensorSpec: """Converts the spec to the equivalent categorical spec.""" return MultiDiscreteTensorSpec( [_space.n for _space in self.space], device=self.device, - dtype=self.dtype, shape=[*self.shape[:-1], len(self.space)], ) @@ -2122,7 +2120,9 @@ def to_one_hot_spec(self) -> OneHotDiscreteTensorSpec: """Converts the spec to the equivalent one-hot spec.""" shape = [*self.shape, self.space.n] return OneHotDiscreteTensorSpec( - n=self.space.n, shape=shape, device=self.device, dtype=self.dtype + n=self.space.n, + shape=shape, + device=self.device, ) def expand(self, *shape): @@ -2443,7 +2443,6 @@ def to_one_hot_spec(self) -> MultiOneHotDiscreteTensorSpec: return MultiOneHotDiscreteTensorSpec( nvec, device=self.device, - dtype=self.dtype, shape=[*self.shape[:-1], sum(nvec)], ) diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 2a71cd8f7a1..8e58c0915a6 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -419,6 +419,16 @@ def _build_env( env = self._build_gym_env(env, pixels_only) return env + def read_action(self, action): + action = super().read_action(action) + if ( + isinstance(self.action_spec, (OneHotDiscreteTensorSpec, DiscreteTensorSpec)) + and action.size == 1 + ): + # some envs require an integer for indexing + action = int(action) + return action + @implement_for("gym", None, "0.19.0") def _build_gym_env(self, env, pixels_only): # noqa: F811 from .utils import GymPixelObservationWrapper as PixelObservationWrapper diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 6a726883971..c22aac89866 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3050,7 +3050,7 @@ def _inv_apply_transform(self, action: torch.Tensor) -> torch.Tensor: raise RuntimeError( f"action.shape[-1]={action.shape[-1]} must match self.max_actions={self.max_actions}." ) - action = action.argmax(-1) # bool to int + action = action.long().argmax(-1) # bool to int idx = action >= self.num_actions_effective if idx.any(): action[idx] = torch.randint(self.num_actions_effective, (idx.sum(),))