[Feature] ActionDiscretizer #2247

Merged
merged 4 commits on Jun 26, 2024
Changes from 1 commit
amend
vmoens committed Jun 25, 2024
commit 66f86cfb8742ad39b97b6370c6b69aeaf8d828e6
13 changes: 7 additions & 6 deletions docs/source/reference/envs.rst
@@ -768,6 +768,7 @@ to be able to create this other composition:

Transform
TransformedEnv
ActionDiscretizer
ActionMask
AutoResetEnv
AutoResetTransform
@@ -779,17 +780,16 @@ to be able to create this other composition:
CenterCrop
ClipTransform
Compose
DTypeCastTransform
DeviceCastTransform
DiscreteActionProjection
DoubleToFloat
DTypeCastTransform
EndOfLifeTransform
ExcludeTransform
FiniteTensorDictCheck
FlattenObservation
FrameSkipTransform
GrayScale
gSDENoise
InitTracker
KLRewardTransform
NoopResetEnv
@@ -799,13 +799,13 @@ to be able to create this other composition:
PinMemoryTransform
R3MTransform
RandomCropTensorDict
RemoveEmptySpecs
RenameTransform
Resize
Reward2GoTransform
RewardClipping
RewardScaling
RewardSum
Reward2GoTransform
RemoveEmptySpecs
SelectTransform
SignTransform
SqueezeTransform
@@ -815,11 +815,12 @@ to be able to create this other composition:
TimeMaxPool
ToTensorImage
UnsqueezeTransform
VecGymEnvTransform
VecNorm
VC1Transform
VIPRewardTransform
VIPTransform
VecGymEnvTransform
VecNorm
gSDENoise

Environments with masked actions
--------------------------------
118 changes: 118 additions & 0 deletions test/test_transforms.py
@@ -119,6 +119,7 @@
from torchrl.envs.transforms.rlhf import KLRewardTransform
from torchrl.envs.transforms.transforms import (
_has_tv,
ActionDiscretizer,
BatchSizeTransform,
FORWARD_NOT_IMPLEMENTED,
Transform,
@@ -10985,6 +10986,123 @@ def test_transform_inverse(self):
return


class TestActionDiscretizer(TransformBase):
@pytest.mark.parametrize("categorical", [True, False])
def test_single_trans_env_check(self, categorical):
base_env = ContinuousActionVecMockEnv()
env = base_env.append_transform(
ActionDiscretizer(num_intervals=5, categorical=categorical)
)
check_env_specs(env)

@pytest.mark.parametrize("categorical", [True, False])
def test_serial_trans_env_check(self, categorical):
def make_env():
base_env = ContinuousActionVecMockEnv()
return base_env.append_transform(
ActionDiscretizer(num_intervals=5, categorical=categorical)
)

env = SerialEnv(2, make_env)
check_env_specs(env)

@pytest.mark.parametrize("categorical", [True, False])
def test_parallel_trans_env_check(self, categorical):
def make_env():
base_env = ContinuousActionVecMockEnv()
env = base_env.append_transform(
ActionDiscretizer(num_intervals=5, categorical=categorical)
)
return env

env = ParallelEnv(2, make_env, mp_start_method="fork")
check_env_specs(env)

@pytest.mark.parametrize("categorical", [True, False])
def test_trans_serial_env_check(self, categorical):
env = SerialEnv(2, ContinuousActionVecMockEnv).append_transform(
ActionDiscretizer(num_intervals=5, categorical=categorical)
)
check_env_specs(env)

@pytest.mark.parametrize("categorical", [True, False])
def test_trans_parallel_env_check(self, categorical):
env = ParallelEnv(
2, ContinuousActionVecMockEnv, mp_start_method="fork"
).append_transform(ActionDiscretizer(num_intervals=5, categorical=categorical))
check_env_specs(env)

def test_transform_no_env(self):
categorical = True
with pytest.raises(RuntimeError, match="Cannot execute transform"):
ActionDiscretizer(num_intervals=5, categorical=categorical)._init()

def test_transform_compose(self):
categorical = True
env = SerialEnv(2, ContinuousActionVecMockEnv).append_transform(
Compose(ActionDiscretizer(num_intervals=5, categorical=categorical))
)
check_env_specs(env)

@pytest.mark.skipif(not _has_gym, reason="gym required for this test")
@pytest.mark.parametrize("envname", ["cheetah", "pendulum"])
@pytest.mark.parametrize("interval_as_tensor", [False, True])
@pytest.mark.parametrize("categorical", [True, False])
@pytest.mark.parametrize(
"sampling",
[
None,
ActionDiscretizer.SamplingStrategy.MEDIAN,
ActionDiscretizer.SamplingStrategy.LOW,
ActionDiscretizer.SamplingStrategy.HIGH,
ActionDiscretizer.SamplingStrategy.RANDOM,
],
)
def test_transform_env(self, envname, interval_as_tensor, categorical, sampling):
base_env = GymEnv(
HALFCHEETAH_VERSIONED() if envname == "cheetah" else PENDULUM_VERSIONED()
)
if interval_as_tensor:
num_intervals = torch.arange(5, 11 if envname == "cheetah" else 6)
else:
num_intervals = 5
t = ActionDiscretizer(
num_intervals=num_intervals,
categorical=categorical,
sampling=sampling,
out_action_key="action_disc",
)
env = base_env.append_transform(t)
check_env_specs(env)
r = env.rollout(4)
assert r["action"].dtype == torch.float
if categorical:
assert r["action_disc"].dtype == torch.int64
else:
assert r["action_disc"].dtype == torch.bool
if t.sampling in (
t.SamplingStrategy.LOW,
t.SamplingStrategy.MEDIAN,
t.SamplingStrategy.RANDOM,
):
assert (r["action"] < base_env.action_spec.high).all()
if t.sampling in (
t.SamplingStrategy.HIGH,
t.SamplingStrategy.MEDIAN,
t.SamplingStrategy.RANDOM,
):
assert (r["action"] > base_env.action_spec.low).all()

def test_transform_model(self):
pytest.skip("Tested elsewhere")

def test_transform_rb(self):
pytest.skip("Tested elsewhere")

def test_transform_inverse(self):
pytest.skip("Tested elsewhere")


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
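
For reference, a minimal usage sketch distilled from the tests above (the env name "Pendulum-v1" and the rollout length are illustrative, and a local gym install is assumed):

import torch

from torchrl.envs import GymEnv
from torchrl.envs.transforms import ActionDiscretizer

# Wrap a continuous-control env so the agent interacts through a discrete
# action space. num_intervals=5 splits each action dimension into 5 bins;
# categorical=True exposes an integer index instead of a one-hot encoding.
base_env = GymEnv("Pendulum-v1")
env = base_env.append_transform(
    ActionDiscretizer(
        num_intervals=5,
        categorical=True,
        sampling=ActionDiscretizer.SamplingStrategy.MEDIAN,
        out_action_key="action_disc",
    )
)

rollout = env.rollout(4)
# The discrete action is written under "action_disc"; the transform maps it
# back to a continuous "action" before it reaches the base env.
assert rollout["action_disc"].dtype == torch.int64
assert rollout["action"].dtype == torch.float32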
6 changes: 5 additions & 1 deletion torchrl/data/tensor_specs.py
@@ -3312,6 +3312,8 @@ class MultiDiscreteTensorSpec(DiscreteTensorSpec):
device (str, int or torch.device, optional): device of
the tensors.
dtype (str or torch.dtype, optional): dtype of the tensors.
remove_singleton (bool, optional): if ``True``, singleton samples (of size [1])
will be squeezed. Defaults to ``True``.

Examples:
>>> ts = MultiDiscreteTensorSpec((3, 2, 3))
@@ -3330,6 +3332,7 @@ def __init__(
device: Optional[DEVICE_TYPING] = None,
dtype: Optional[Union[str, torch.dtype]] = torch.int64,
mask: torch.Tensor | None = None,
remove_singleton: bool = True,
):
if not isinstance(nvec, torch.Tensor):
nvec = torch.tensor(nvec)
@@ -3354,6 +3357,7 @@ def __init__(
shape, space, device, dtype, domain="discrete"
)
self.update_mask(mask)
self.remove_singleton = remove_singleton

def update_mask(self, mask):
if mask is not None:
@@ -3442,7 +3446,7 @@ def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor:
*self.shape[:-1],
)
x = self._rand(space=self.space, shape=shape, i=self.nvec.ndim)
if self.shape == torch.Size([1]):
if self.remove_singleton and self.shape == torch.Size([1]):
x = x.squeeze(-1)
return x

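A short sketch of what the new remove_singleton flag changes, assuming the default shape inference from nvec (values are illustrative):

import torch

from torchrl.data import MultiDiscreteTensorSpec

# Default behaviour (remove_singleton=True): a spec of shape [1] yields
# scalar samples, matching the previous behaviour.
spec = MultiDiscreteTensorSpec((4,))
assert spec.rand().shape == torch.Size([])

# With remove_singleton=False the trailing singleton dimension is preserved,
# keeping the last action dimension explicit.
spec = MultiDiscreteTensorSpec((4,), remove_singleton=False)
assert spec.rand().shape == torch.Size([1])
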
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
@@ -38,6 +38,7 @@
)
from .model_based import DreamerDecoder, DreamerEnv, ModelBasedEnvBase
from .transforms import (
ActionDiscretizer,
ActionMask,
AutoResetEnv,
AutoResetTransform,
1 change: 1 addition & 0 deletions torchrl/envs/transforms/__init__.py
@@ -8,6 +8,7 @@
from .rb_transforms import MultiStepTransform
from .rlhf import KLRewardTransform
from .transforms import (
ActionDiscretizer,
ActionMask,
AutoResetEnv,
AutoResetTransform,