Skip to content

Commit

Permalink
[Feature] ActionDiscretizer (#2247)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 26, 2024
1 parent 559b729 commit 387dd0e
Show file tree
Hide file tree
Showing 6 changed files with 414 additions and 8 deletions.
13 changes: 7 additions & 6 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,7 @@ to be able to create this other composition:

Transform
TransformedEnv
ActionDiscretizer
ActionMask
AutoResetEnv
AutoResetTransform
Expand All @@ -779,17 +780,16 @@ to be able to create this other composition:
CenterCrop
ClipTransform
Compose
DTypeCastTransform
DeviceCastTransform
DiscreteActionProjection
DoubleToFloat
DTypeCastTransform
EndOfLifeTransform
ExcludeTransform
FiniteTensorDictCheck
FlattenObservation
FrameSkipTransform
GrayScale
gSDENoise
InitTracker
KLRewardTransform
NoopResetEnv
Expand All @@ -799,13 +799,13 @@ to be able to create this other composition:
PinMemoryTransform
R3MTransform
RandomCropTensorDict
RemoveEmptySpecs
RenameTransform
Resize
Reward2GoTransform
RewardClipping
RewardScaling
RewardSum
Reward2GoTransform
RemoveEmptySpecs
SelectTransform
SignTransform
SqueezeTransform
Expand All @@ -815,11 +815,12 @@ to be able to create this other composition:
TimeMaxPool
ToTensorImage
UnsqueezeTransform
VecGymEnvTransform
VecNorm
VC1Transform
VIPRewardTransform
VIPTransform
VecGymEnvTransform
VecNorm
gSDENoise

Environments with masked actions
--------------------------------
Expand Down
120 changes: 120 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
from torchrl.envs.transforms.rlhf import KLRewardTransform
from torchrl.envs.transforms.transforms import (
_has_tv,
ActionDiscretizer,
BatchSizeTransform,
FORWARD_NOT_IMPLEMENTED,
Transform,
Expand Down Expand Up @@ -10985,6 +10986,125 @@ def test_transform_inverse(self):
return


class TestActionDiscretizer(TransformBase):
    """Test-suite for the ``ActionDiscretizer`` transform.

    Each test appends an ``ActionDiscretizer`` to a continuous-action
    environment and validates the resulting specs with ``check_env_specs``.
    Most tests are parametrized over ``categorical`` so both discrete
    encodings of the transform are exercised.
    """

    @pytest.mark.parametrize("categorical", [True, False])
    def test_single_trans_env_check(self, categorical):
        """Specs are consistent when the transform wraps a single env."""
        base_env = ContinuousActionVecMockEnv()
        env = base_env.append_transform(
            ActionDiscretizer(num_intervals=5, categorical=categorical)
        )
        check_env_specs(env)

    @pytest.mark.parametrize("categorical", [True, False])
    def test_serial_trans_env_check(self, categorical):
        """Specs are consistent for a SerialEnv of transformed envs."""

        def make_env():
            base_env = ContinuousActionVecMockEnv()
            return base_env.append_transform(
                ActionDiscretizer(num_intervals=5, categorical=categorical)
            )

        env = SerialEnv(2, make_env)
        check_env_specs(env)

    @pytest.mark.parametrize("categorical", [True, False])
    def test_parallel_trans_env_check(self, categorical):
        """Specs are consistent for a ParallelEnv of transformed envs."""

        def make_env():
            base_env = ContinuousActionVecMockEnv()
            env = base_env.append_transform(
                ActionDiscretizer(num_intervals=5, categorical=categorical)
            )
            return env

        env = ParallelEnv(2, make_env, mp_start_method="fork")
        try:
            check_env_specs(env)
        finally:
            # Always release the worker processes, even if the check fails,
            # so they do not outlive this test.
            env.close()

    @pytest.mark.parametrize("categorical", [True, False])
    def test_trans_serial_env_check(self, categorical):
        """Specs are consistent when the transform wraps a SerialEnv."""
        env = SerialEnv(2, ContinuousActionVecMockEnv).append_transform(
            ActionDiscretizer(num_intervals=5, categorical=categorical)
        )
        check_env_specs(env)

    @pytest.mark.parametrize("categorical", [True, False])
    def test_trans_parallel_env_check(self, categorical):
        """Specs are consistent when the transform wraps a ParallelEnv."""
        env = ParallelEnv(
            2, ContinuousActionVecMockEnv, mp_start_method="fork"
        ).append_transform(ActionDiscretizer(num_intervals=5, categorical=categorical))
        try:
            check_env_specs(env)
        finally:
            # Always release the worker processes, even if the check fails,
            # so they do not outlive this test.
            env.close()

    def test_transform_no_env(self):
        """Calling ``_init`` on a transform with no parent env must raise."""
        categorical = True
        with pytest.raises(RuntimeError, match="Cannot execute transform"):
            ActionDiscretizer(num_intervals=5, categorical=categorical)._init()

    def test_transform_compose(self):
        """Specs are consistent when the transform is nested in a Compose."""
        categorical = True
        env = SerialEnv(2, ContinuousActionVecMockEnv).append_transform(
            Compose(ActionDiscretizer(num_intervals=5, categorical=categorical))
        )
        check_env_specs(env)

    @pytest.mark.skipif(not _has_gym, reason="gym required for this test")
    @pytest.mark.parametrize("envname", ["cheetah", "pendulum"])
    @pytest.mark.parametrize("interval_as_tensor", [False, True])
    @pytest.mark.parametrize("categorical", [True, False])
    @pytest.mark.parametrize(
        "sampling",
        [
            None,
            ActionDiscretizer.SamplingStrategy.MEDIAN,
            ActionDiscretizer.SamplingStrategy.LOW,
            ActionDiscretizer.SamplingStrategy.HIGH,
            ActionDiscretizer.SamplingStrategy.RANDOM,
        ],
    )
    def test_transform_env(self, envname, interval_as_tensor, categorical, sampling):
        """Roll out a gym env through the transform and check dtypes and bounds.

        Args:
            envname: ``"cheetah"`` or ``"pendulum"``; selects the gym env.
            interval_as_tensor: if ``True``, ``num_intervals`` is passed as a
                tensor instead of a single int.
            categorical: forwarded to ``ActionDiscretizer``; selects the dtype
                expected for the discrete action entry.
            sampling: sampling strategy forwarded to the transform (``None``
                lets the transform pick its default).
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        base_env = GymEnv(
            HALFCHEETAH_VERSIONED() if envname == "cheetah" else PENDULUM_VERSIONED(),
            device=device,
        )
        if interval_as_tensor:
            # 6 entries for cheetah, 1 for pendulum -- presumably one interval
            # count per action dimension of the respective env (to confirm).
            num_intervals = torch.arange(5, 11 if envname == "cheetah" else 6)
        else:
            num_intervals = 5
        t = ActionDiscretizer(
            num_intervals=num_intervals,
            categorical=categorical,
            sampling=sampling,
            out_action_key="action_disc",
        )
        env = base_env.append_transform(t)
        check_env_specs(env)
        r = env.rollout(4)
        # The base env still receives a float action under "action", while the
        # discrete action is stored under the custom out key.
        assert r["action"].dtype == torch.float
        if categorical:
            assert r["action_disc"].dtype == torch.int64
        else:
            assert r["action_disc"].dtype == torch.bool
        # Read the strategy from the transform (not the parameter) so that the
        # ``sampling=None`` case is checked against whatever default was set.
        if t.sampling in (
            t.SamplingStrategy.LOW,
            t.SamplingStrategy.MEDIAN,
            t.SamplingStrategy.RANDOM,
        ):
            # These strategies are expected to stay strictly below the upper bound.
            assert (r["action"] < base_env.action_spec.high).all()
        if t.sampling in (
            t.SamplingStrategy.HIGH,
            t.SamplingStrategy.MEDIAN,
            t.SamplingStrategy.RANDOM,
        ):
            # These strategies are expected to stay strictly above the lower bound.
            assert (r["action"] > base_env.action_spec.low).all()

    def test_transform_model(self):
        pytest.skip("Tested elsewhere")

    def test_transform_rb(self):
        pytest.skip("Tested elsewhere")

    def test_transform_inverse(self):
        pytest.skip("Tested elsewhere")


if __name__ == "__main__":
    # Forward any command-line options argparse does not recognise straight to
    # pytest; the parsed namespace itself is not needed.
    _, passthrough = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *passthrough])
6 changes: 5 additions & 1 deletion torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3312,6 +3312,8 @@ class MultiDiscreteTensorSpec(DiscreteTensorSpec):
device (str, int or torch.device, optional): device of
the tensors.
dtype (str or torch.dtype, optional): dtype of the tensors.
remove_singleton (bool, optional): if ``True``, singleton samples (of size [1])
will be squeezed. Defaults to ``True``.
Examples:
>>> ts = MultiDiscreteTensorSpec((3, 2, 3))
Expand All @@ -3330,6 +3332,7 @@ def __init__(
device: Optional[DEVICE_TYPING] = None,
dtype: Optional[Union[str, torch.dtype]] = torch.int64,
mask: torch.Tensor | None = None,
remove_singleton: bool = True,
):
if not isinstance(nvec, torch.Tensor):
nvec = torch.tensor(nvec)
Expand All @@ -3354,6 +3357,7 @@ def __init__(
shape, space, device, dtype, domain="discrete"
)
self.update_mask(mask)
self.remove_singleton = remove_singleton

def update_mask(self, mask):
if mask is not None:
Expand Down Expand Up @@ -3442,7 +3446,7 @@ def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor:
*self.shape[:-1],
)
x = self._rand(space=self.space, shape=shape, i=self.nvec.ndim)
if self.shape == torch.Size([1]):
if self.remove_singleton and self.shape == torch.Size([1]):
x = x.squeeze(-1)
return x

Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from .model_based import DreamerDecoder, DreamerEnv, ModelBasedEnvBase
from .transforms import (
ActionDiscretizer,
ActionMask,
AutoResetEnv,
AutoResetTransform,
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .rb_transforms import MultiStepTransform
from .rlhf import KLRewardTransform
from .transforms import (
ActionDiscretizer,
ActionMask,
AutoResetEnv,
AutoResetTransform,
Expand Down
Loading

0 comments on commit 387dd0e

Please sign in to comment.