Skip to content

Commit

Permalink
[BugFix] Fix Gym Categorical/One-hot issues (pytorch#1482)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Sep 3, 2023
1 parent 2982515 commit 545a28c
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 16 deletions.
3 changes: 2 additions & 1 deletion test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1696,7 +1696,8 @@ def _step(
done = self.output_spec["full_done_spec"].zero()
td = self.observation_spec.zero()

one_hot_action = tensordict["action"].argmax(-1).unsqueeze(-1)
one_hot_action = tensordict["action"]
one_hot_action = one_hot_action.long().argmax(-1).unsqueeze(-1)
reward["reward"] += one_hot_action.to(torch.float)
self.count += one_hot_action.to(torch.int)
td["observation"] += expand_right(self.count, td["observation"].shape)
Expand Down
2 changes: 1 addition & 1 deletion test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2013,7 +2013,7 @@ def check_rollout_consistency(td: TensorDict, max_steps: int):
== td["next", "observation"][index_batch_size][:-1][~next_is_done]
).all()
# Check observation and reward update with count action for root
action_is_count = td["action"].argmax(-1).to(torch.bool)
action_is_count = td["action"].long().argmax(-1).to(torch.bool)
assert (
td["next", "observation"][action_is_count]
== td["observation"][action_is_count] + 1
Expand Down
19 changes: 19 additions & 0 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,25 @@ def info_reader(info, tensordict):
env.rand_step()
env.rollout(3)

@implement_for("gymnasium", "0.27.0", None)
def test_one_hot_and_categorical(self):
    """Check categorical and one-hot action encodings on an integer-action env.

    CliffWalking-v0 expects an integer action, so both encodings must be
    converted properly for a rollout and the spec check to succeed.
    """
    # Categorical encoding first, then one-hot.
    for use_categorical in (True, False):
        env = GymEnv(
            "CliffWalking-v0", categorical_action_encoding=use_categorical
        )
        env.rollout(10)
        check_env_specs(env)

@implement_for("gym", None, "0.27.0")
def test_one_hot_and_categorical(self):  # noqa: F811
    """Intentional no-op variant of the test for earlier Gym versions.

    The test is not skipped (we may want to make sure nothing is skipped),
    but CliffWalking-v0 in earlier Gym versions uses ``np.bool``, which was
    deprecated after NumPy 1.20, and we don't want to install multiple NumPy
    versions.
    """
    return None


@implement_for("gym", None, "0.26")
def _make_gym_environment(env_name): # noqa: F811
Expand Down
2 changes: 1 addition & 1 deletion test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,7 +1168,7 @@ def test_one_hot_discrete_action_spec_rand(self):

sample = action_spec.rand((100000,))

sample_list = sample.argmax(-1)
sample_list = sample.long().argmax(-1)
sample_list = [sum(sample_list == i).item() for i in range(10)]
assert chisquare(sample_list).pvalue > 0.1

Expand Down
23 changes: 11 additions & 12 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,7 +1117,7 @@ def __init__(
n: int,
shape: Optional[torch.Size] = None,
device: Optional[DEVICE_TYPING] = None,
dtype: Optional[Union[str, torch.dtype]] = torch.long,
dtype: Optional[Union[str, torch.dtype]] = torch.bool,
use_register: bool = False,
):
dtype, device = _default_dtype_and_device(dtype, device)
Expand Down Expand Up @@ -1229,9 +1229,9 @@ def encode(
) -> torch.Tensor:
if not isinstance(val, torch.Tensor):
if ignore_device:
val = torch.tensor(val, dtype=self.dtype)
val = torch.tensor(val)
else:
val = torch.tensor(val, dtype=self.dtype, device=self.device)
val = torch.tensor(val, device=self.device)

if space is None:
space = self.space
Expand All @@ -1244,7 +1244,7 @@ def encode(
if (val >= space.n).any():
raise AssertionError("Value must be less than action space.")

val = torch.nn.functional.one_hot(val.long(), space.n)
val = torch.nn.functional.one_hot(val.long(), space.n).to(self.dtype)
return val

def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
Expand All @@ -1254,7 +1254,7 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
if not isinstance(val, torch.Tensor):
raise NotImplementedError
self.assert_is_in(val)
val = val.argmax(-1).cpu().numpy()
val = val.long().argmax(-1).cpu().numpy()
if self.use_register:
inv_reg = self.space.register.inverse()
vals = []
Expand Down Expand Up @@ -1323,14 +1323,13 @@ def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor:
safe = _CHECK_SPEC_ENCODE
if safe:
self.assert_is_in(val)
return val.argmax(-1)
return val.long().argmax(-1)

def to_categorical_spec(self) -> DiscreteTensorSpec:
    """Converts the spec to the equivalent categorical spec.

    The categorical spec has ``self.space.n`` categories, lives on the same
    device and drops the trailing (one-hot) dimension of ``self.shape``.

    ``self.dtype`` is intentionally not forwarded: a one-hot encoding may use
    a boolean dtype, which cannot represent category indices, so the
    categorical spec falls back to its own default dtype.

    Returns:
        DiscreteTensorSpec: the categorical counterpart of this spec.
    """
    return DiscreteTensorSpec(
        self.space.n,
        device=self.device,
        shape=self.shape[:-1],
    )

Expand Down Expand Up @@ -1801,7 +1800,7 @@ def __init__(
nvec: Sequence[int],
shape: Optional[torch.Size] = None,
device=None,
dtype=torch.long,
dtype=torch.bool,
use_register=False,
):
self.nvec = nvec
Expand Down Expand Up @@ -1943,14 +1942,13 @@ def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor:
if safe:
self.assert_is_in(val)
vals = self._split(val)
return torch.stack([val.argmax(-1) for val in vals], -1)
return torch.stack([val.long().argmax(-1) for val in vals], -1)

def to_categorical_spec(self) -> MultiDiscreteTensorSpec:
    """Converts the spec to the equivalent categorical spec.

    Each entry of ``self.space`` contributes its cardinality ``n``; the
    trailing (concatenated one-hot) dimension of ``self.shape`` is replaced
    by the number of sub-spaces.

    ``self.dtype`` is intentionally not forwarded: a one-hot encoding may use
    a boolean dtype, which cannot represent category indices, so the
    categorical spec falls back to its own default dtype.

    Returns:
        MultiDiscreteTensorSpec: the categorical counterpart of this spec.
    """
    return MultiDiscreteTensorSpec(
        [_space.n for _space in self.space],
        device=self.device,
        shape=[*self.shape[:-1], len(self.space)],
    )

Expand Down Expand Up @@ -2122,7 +2120,9 @@ def to_one_hot_spec(self) -> OneHotDiscreteTensorSpec:
"""Converts the spec to the equivalent one-hot spec."""
shape = [*self.shape, self.space.n]
return OneHotDiscreteTensorSpec(
n=self.space.n, shape=shape, device=self.device, dtype=self.dtype
n=self.space.n,
shape=shape,
device=self.device,
)

def expand(self, *shape):
Expand Down Expand Up @@ -2443,7 +2443,6 @@ def to_one_hot_spec(self) -> MultiOneHotDiscreteTensorSpec:
return MultiOneHotDiscreteTensorSpec(
nvec,
device=self.device,
dtype=self.dtype,
shape=[*self.shape[:-1], sum(nvec)],
)

Expand Down
10 changes: 10 additions & 0 deletions torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,16 @@ def _build_env(
env = self._build_gym_env(env, pixels_only)
return env

def read_action(self, action):
    """Convert an action via the parent class, unwrapping scalars to ``int``.

    Args:
        action: the action to convert; forwarded unchanged to
            ``super().read_action``.

    Returns:
        The parent's converted action, or a Python ``int`` when the action
        spec is a ``OneHotDiscreteTensorSpec`` or ``DiscreteTensorSpec`` and
        the converted action holds exactly one element.
    """
    action = super().read_action(action)
    is_discrete = isinstance(
        self.action_spec, (OneHotDiscreteTensorSpec, DiscreteTensorSpec)
    )
    # NOTE(review): assumes the parent returns a NumPy array or NumPy scalar
    # (both expose ``.size`` and ``.item()``) -- confirm against the base class.
    if is_discrete and action.size == 1:
        # some envs require an integer for indexing. ``.item()`` extracts the
        # single element whatever the number of dimensions, whereas calling
        # ``int()`` directly on an array with ndim > 0 is deprecated in
        # recent NumPy releases.
        action = int(action.item())
    return action

@implement_for("gym", None, "0.19.0")
def _build_gym_env(self, env, pixels_only): # noqa: F811
from .utils import GymPixelObservationWrapper as PixelObservationWrapper
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3050,7 +3050,7 @@ def _inv_apply_transform(self, action: torch.Tensor) -> torch.Tensor:
raise RuntimeError(
f"action.shape[-1]={action.shape[-1]} must match self.max_actions={self.max_actions}."
)
action = action.argmax(-1) # bool to int
action = action.long().argmax(-1) # bool to int
idx = action >= self.num_actions_effective
if idx.any():
action[idx] = torch.randint(self.num_actions_effective, (idx.sum(),))
Expand Down

0 comments on commit 545a28c

Please sign in to comment.