From 387dd0e34edd5638ac41ef436af5a427180f83fb Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 26 Jun 2024 12:40:33 +0100 Subject: [PATCH] [Feature] ActionDiscretizer (#2247) --- docs/source/reference/envs.rst | 13 +- test/test_transforms.py | 120 +++++++++++ torchrl/data/tensor_specs.py | 6 +- torchrl/envs/__init__.py | 1 + torchrl/envs/transforms/__init__.py | 1 + torchrl/envs/transforms/transforms.py | 281 +++++++++++++++++++++++++- 6 files changed, 414 insertions(+), 8 deletions(-) diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 4c4a1e2b8e5..c7b0eba35c0 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -768,6 +768,7 @@ to be able to create this other composition: Transform TransformedEnv + ActionDiscretizer ActionMask AutoResetEnv AutoResetTransform @@ -779,17 +780,16 @@ to be able to create this other composition: CenterCrop ClipTransform Compose + DTypeCastTransform DeviceCastTransform DiscreteActionProjection DoubleToFloat - DTypeCastTransform EndOfLifeTransform ExcludeTransform FiniteTensorDictCheck FlattenObservation FrameSkipTransform GrayScale - gSDENoise InitTracker KLRewardTransform NoopResetEnv @@ -799,13 +799,13 @@ to be able to create this other composition: PinMemoryTransform R3MTransform RandomCropTensorDict + RemoveEmptySpecs RenameTransform Resize + Reward2GoTransform RewardClipping RewardScaling RewardSum - Reward2GoTransform - RemoveEmptySpecs SelectTransform SignTransform SqueezeTransform @@ -815,11 +815,12 @@ to be able to create this other composition: TimeMaxPool ToTensorImage UnsqueezeTransform - VecGymEnvTransform - VecNorm VC1Transform VIPRewardTransform VIPTransform + VecGymEnvTransform + VecNorm + gSDENoise Environments with masked actions -------------------------------- diff --git a/test/test_transforms.py b/test/test_transforms.py index 9c44b50f98e..78cebdbebab 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -119,6 +119,7 @@ from torchrl.envs.transforms.rlhf import KLRewardTransform from torchrl.envs.transforms.transforms import ( _has_tv, + ActionDiscretizer, BatchSizeTransform, FORWARD_NOT_IMPLEMENTED, Transform, @@ -10985,6 +10986,125 @@ def test_transform_inverse(self): return +class TestActionDiscretizer(TransformBase): + @pytest.mark.parametrize("categorical", [True, False]) + def test_single_trans_env_check(self, categorical): + base_env = ContinuousActionVecMockEnv() + env = base_env.append_transform( + ActionDiscretizer(num_intervals=5, categorical=categorical) + ) + check_env_specs(env) + + @pytest.mark.parametrize("categorical", [True, False]) + def test_serial_trans_env_check(self, categorical): + def make_env(): + base_env = ContinuousActionVecMockEnv() + return base_env.append_transform( + ActionDiscretizer(num_intervals=5, categorical=categorical) + ) + + env = SerialEnv(2, make_env) + check_env_specs(env) + + @pytest.mark.parametrize("categorical", [True, False]) + def test_parallel_trans_env_check(self, categorical): + def make_env(): + base_env = ContinuousActionVecMockEnv() + env = base_env.append_transform( + ActionDiscretizer(num_intervals=5, categorical=categorical) + ) + return env + + env = ParallelEnv(2, make_env, mp_start_method="fork") + check_env_specs(env) + + @pytest.mark.parametrize("categorical", [True, False]) + def test_trans_serial_env_check(self, categorical): + env = SerialEnv(2, ContinuousActionVecMockEnv).append_transform( + ActionDiscretizer(num_intervals=5, categorical=categorical) + ) + check_env_specs(env) + + 
@pytest.mark.parametrize("categorical", [True, False]) + def test_trans_parallel_env_check(self, categorical): + env = ParallelEnv( + 2, ContinuousActionVecMockEnv, mp_start_method="fork" + ).append_transform(ActionDiscretizer(num_intervals=5, categorical=categorical)) + check_env_specs(env) + + def test_transform_no_env(self): + categorical = True + with pytest.raises(RuntimeError, match="Cannot execute transform"): + ActionDiscretizer(num_intervals=5, categorical=categorical)._init() + + def test_transform_compose(self): + categorical = True + env = SerialEnv(2, ContinuousActionVecMockEnv).append_transform( + Compose(ActionDiscretizer(num_intervals=5, categorical=categorical)) + ) + check_env_specs(env) + + @pytest.mark.skipif(not _has_gym, reason="gym required for this test") + @pytest.mark.parametrize("envname", ["cheetah", "pendulum"]) + @pytest.mark.parametrize("interval_as_tensor", [False, True]) + @pytest.mark.parametrize("categorical", [True, False]) + @pytest.mark.parametrize( + "sampling", + [ + None, + ActionDiscretizer.SamplingStrategy.MEDIAN, + ActionDiscretizer.SamplingStrategy.LOW, + ActionDiscretizer.SamplingStrategy.HIGH, + ActionDiscretizer.SamplingStrategy.RANDOM, + ], + ) + def test_transform_env(self, envname, interval_as_tensor, categorical, sampling): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + base_env = GymEnv( + HALFCHEETAH_VERSIONED() if envname == "cheetah" else PENDULUM_VERSIONED(), + device=device, + ) + if interval_as_tensor: + num_intervals = torch.arange(5, 11 if envname == "cheetah" else 6) + else: + num_intervals = 5 + t = ActionDiscretizer( + num_intervals=num_intervals, + categorical=categorical, + sampling=sampling, + out_action_key="action_disc", + ) + env = base_env.append_transform(t) + check_env_specs(env) + r = env.rollout(4) + assert r["action"].dtype == torch.float + if categorical: + assert r["action_disc"].dtype == torch.int64 + else: + assert r["action_disc"].dtype == torch.bool + if t.sampling in ( + t.SamplingStrategy.LOW, + t.SamplingStrategy.MEDIAN, + t.SamplingStrategy.RANDOM, + ): + assert (r["action"] < base_env.action_spec.high).all() + if t.sampling in ( + t.SamplingStrategy.HIGH, + t.SamplingStrategy.MEDIAN, + t.SamplingStrategy.RANDOM, + ): + assert (r["action"] > base_env.action_spec.low).all() + + def test_transform_model(self): + pytest.skip("Tested elsewhere") + + def test_transform_rb(self): + pytest.skip("Tested elsewhere") + + def test_transform_inverse(self): + pytest.skip("Tested elsewhere") + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 002ca9c5fde..6f75f207293 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -3312,6 +3312,8 @@ class MultiDiscreteTensorSpec(DiscreteTensorSpec): device (str, int or torch.device, optional): device of the tensors. dtype (str or torch.dtype, optional): dtype of the tensors. + remove_singleton (bool, optional): if ``True``, singleton samples (of size [1]) + will be squeezed. Defaults to ``True``. 
Examples: >>> ts = MultiDiscreteTensorSpec((3, 2, 3)) @@ -3330,6 +3332,7 @@ def __init__( device: Optional[DEVICE_TYPING] = None, dtype: Optional[Union[str, torch.dtype]] = torch.int64, mask: torch.Tensor | None = None, + remove_singleton: bool = True, ): if not isinstance(nvec, torch.Tensor): nvec = torch.tensor(nvec) @@ -3354,6 +3357,7 @@ def __init__( shape, space, device, dtype, domain="discrete" ) self.update_mask(mask) + self.remove_singleton = remove_singleton def update_mask(self, mask): if mask is not None: @@ -3442,7 +3446,7 @@ def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: *self.shape[:-1], ) x = self._rand(space=self.space, shape=shape, i=self.nvec.ndim) - if self.shape == torch.Size([1]): + if self.remove_singleton and self.shape == torch.Size([1]): x = x.squeeze(-1) return x diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 84676728aa5..8475979a3ba 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -38,6 +38,7 @@ ) from .model_based import DreamerDecoder, DreamerEnv, ModelBasedEnvBase from .transforms import ( + ActionDiscretizer, ActionMask, AutoResetEnv, AutoResetTransform, diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 11c338a0f7b..5204b8a19d8 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -8,6 +8,7 @@ from .rb_transforms import MultiStepTransform from .rlhf import KLRewardTransform from .transforms import ( + ActionDiscretizer, ActionMask, AutoResetEnv, AutoResetTransform, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 7802fba368d..5185c3c1ac6 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -10,6 +10,7 @@ import multiprocessing as mp import warnings from copy import copy +from enum import IntEnum from functools import wraps from textwrap import indent from typing import ( @@ -341,7 +342,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: def inv(self, tensordict: TensorDictBase) -> TensorDictBase: def clone(data): try: - # we priviledge speed for tensordicts + # we privilege speed for tensordicts return data.clone(recurse=False) except AttributeError: return tree_map(lambda x: x, data) @@ -8174,3 +8175,281 @@ def _replace_auto_reset_vals(self, *, tensordict_reset): _dest.set(key, val_set_reg) delattr(self, "_saved_td_autorest") return tensordict_reset + + +class ActionDiscretizer(Transform): + """A transform to discretize a continuous action space. + + This transform makes it possible to use an algorithm designed for discrete + action spaces such as DQN over environments with a continuous action space. + + Args: + num_intervals (int or torch.Tensor): the number of discrete values + for each element of the action space. If a single integer is provided, + all action items are sliced with the same number of elements. + If a tensor is provided, it must have the same number of elements + as the action space (ie, the length of the ``num_intervals`` tensor + must match the last dimension of the action space). + action_key (NestedKey, optional): the action key to use. Points to + the action of the parent env (the floating point action). + Defaults to ``"action"``. + out_action_key (NestedKey, optional): the key where the discrete + action should be written. If ``None`` is provided, it defaults to + the value of ``action_key``. 
If the two keys differ, the
+            continuous action spec is moved from the ``full_action_spec``
+            environment attribute to the ``full_state_spec`` container,
+            as only the discrete action should be sampled for an action to
+            be taken. Providing ``out_action_key`` ensures that the
+            floating point action remains available, e.g. to be recorded.
+        sampling (ActionDiscretizer.SamplingStrategy, optional): a member of the
+            ``ActionDiscretizer.SamplingStrategy`` ``IntEnum`` (``MEDIAN``, ``LOW``,
+            ``HIGH`` or ``RANDOM``), indicating how the continuous action should be
+            sampled within the provided interval. Defaults to ``MEDIAN``.
+        categorical (bool, optional): if ``False``, one-hot encoding is used.
+            Defaults to ``True``.
+
+    Examples:
+        >>> from torchrl.envs import GymEnv, check_env_specs
+        >>> import torch
+        >>> base_env = GymEnv("HalfCheetah-v4")
+        >>> num_intervals = torch.arange(5, 11)
+        >>> categorical = True
+        >>> sampling = ActionDiscretizer.SamplingStrategy.MEDIAN
+        >>> t = ActionDiscretizer(
+        ...     num_intervals=num_intervals,
+        ...     categorical=categorical,
+        ...     sampling=sampling,
+        ...     out_action_key="action_disc",
+        ... )
+        >>> env = base_env.append_transform(t)
+        >>> print(env)
+        TransformedEnv(
+            env=GymEnv(env=HalfCheetah-v4, batch_size=torch.Size([]), device=cpu),
+            transform=ActionDiscretizer(
+                num_intervals=tensor([ 5,  6,  7,  8,  9, 10]),
+                action_key=action,
+                out_action_key=action_disc,
+                sampling=0,
+                categorical=True))
+        >>> check_env_specs(env)
+        >>> # Produce a rollout
+        >>> r = env.rollout(4)
+        >>> print(r)
+        TensorDict(
+            fields={
+                action: Tensor(shape=torch.Size([4, 6]), device=cpu, dtype=torch.float32, is_shared=False),
+                action_disc: Tensor(shape=torch.Size([4, 6]), device=cpu, dtype=torch.int64, is_shared=False),
+                done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                next: TensorDict(
+                    fields={
+                        done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                        observation: Tensor(shape=torch.Size([4, 17]), device=cpu, dtype=torch.float64, is_shared=False),
+                        reward: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                        terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                        truncated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                    batch_size=torch.Size([4]),
+                    device=cpu,
+                    is_shared=False),
+                observation: Tensor(shape=torch.Size([4, 17]), device=cpu, dtype=torch.float64, is_shared=False),
+                terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                truncated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+            batch_size=torch.Size([4]),
+            device=cpu,
+            is_shared=False)
+        >>> assert r["action"].dtype == torch.float
+        >>> assert r["action_disc"].dtype == torch.int64
+        >>> assert (r["action"] < base_env.action_spec.high).all()
+        >>> assert (r["action"] > base_env.action_spec.low).all()
+
+    """
+
+    class SamplingStrategy(IntEnum):
+        """The sampling strategies for ActionDiscretizer."""
+
+        MEDIAN = 0
+        LOW = 1
+        HIGH = 2
+        RANDOM = 3
+
+    def __init__(
+        self,
+        num_intervals: int | torch.Tensor,
+        action_key: NestedKey = "action",
+        out_action_key: NestedKey = None,
+        sampling=None,
+        categorical: bool = True,
+    ):
+        if out_action_key is None:
+            out_action_key = action_key
+        super().__init__(in_keys_inv=[action_key], out_keys_inv=[out_action_key])
+        self.action_key = action_key
+        self.out_action_key = out_action_key
+        if not isinstance(num_intervals, torch.Tensor):
+            
self.num_intervals = num_intervals + else: + self.register_buffer("num_intervals", num_intervals) + if sampling is None: + sampling = self.SamplingStrategy.MEDIAN + self.sampling = sampling + self.categorical = categorical + + def __repr__(self): + def _indent(s): + return indent(s, 4 * " ") + + num_intervals = f"num_intervals={self.num_intervals}" + action_key = f"action_key={self.action_key}" + out_action_key = f"out_action_key={self.out_action_key}" + sampling = f"sampling={self.sampling}" + categorical = f"categorical={self.categorical}" + return ( + f"{type(self).__name__}(\n{_indent(num_intervals)},\n{_indent(action_key)}," + f"\n{_indent(out_action_key)},\n{_indent(sampling)},\n{_indent(categorical)})" + ) + + def transform_input_spec(self, input_spec): + try: + action_spec = input_spec["full_action_spec", self.in_keys_inv[0]] + if not isinstance(action_spec, BoundedTensorSpec): + raise TypeError( + f"action spec type {type(action_spec)} is not supported." + ) + + n_act = action_spec.shape + if not n_act: + n_act = 1 + else: + n_act = n_act[-1] + self.n_act = n_act + + self.dtype = action_spec.dtype + interval = (action_spec.high - action_spec.low).unsqueeze(-1) + + num_intervals = self.num_intervals + + def custom_arange(nint): + result = torch.arange( + start=0.0, + end=1.0, + step=1 / nint, + dtype=self.dtype, + device=action_spec.device, + ) + result_ = result + if self.sampling in ( + self.SamplingStrategy.HIGH, + self.SamplingStrategy.MEDIAN, + ): + result_ = (1 - result).flip(0) + if self.sampling == self.SamplingStrategy.MEDIAN: + result = (result + result_) / 2 + else: + result = result_ + return result + + if isinstance(num_intervals, int): + arange = ( + custom_arange(num_intervals).expand(n_act, num_intervals) * interval + ) + self.register_buffer( + "intervals", action_spec.low.unsqueeze(-1) + arange + ) + else: + arange = [ + custom_arange(_num_intervals) * interval + for _num_intervals, interval in zip( + num_intervals.tolist(), interval.unbind(-2) + ) + ] + self.intervals = [ + low + arange + for low, arange in zip( + action_spec.low.unsqueeze(-1).unbind(-2), arange + ) + ] + + cls = ( + functools.partial(MultiDiscreteTensorSpec, remove_singleton=False) + if self.categorical + else MultiOneHotDiscreteTensorSpec + ) + + if not isinstance(num_intervals, torch.Tensor): + nvec = torch.as_tensor(num_intervals, device=action_spec.device) + else: + nvec = num_intervals + if nvec.ndim > 1: + raise RuntimeError(f"Cannot use num_intervals with shape {nvec.shape}") + if nvec.ndim == 0 or nvec.numel() == 1: + nvec = nvec.expand(action_spec.shape[-1]) + self.register_buffer("nvec", nvec) + if self.sampling == self.SamplingStrategy.RANDOM: + # compute jitters + self.jitters = interval.squeeze(-1) / nvec + shape = ( + action_spec.shape + if self.categorical + else (*action_spec.shape[:-1], nvec.sum()) + ) + action_spec = cls(nvec=nvec, shape=shape, device=action_spec.device) + input_spec["full_action_spec", self.out_keys_inv[0]] = action_spec + + if self.out_keys_inv[0] != self.in_keys_inv[0]: + input_spec["full_state_spec", self.in_keys_inv[0]] = input_spec[ + "full_action_spec", self.in_keys_inv[0] + ].clone() + del input_spec["full_action_spec", self.in_keys_inv[0]] + return input_spec + except AttributeError as err: + # To avoid silent AttributeErrors + raise RuntimeError(str(err)) + + def _init(self): + # We just need to access the action spec for everything to be initialized + try: + _ = self.container.full_action_spec + except AttributeError: + raise RuntimeError( + f"Cannot 
execute transform {type(self).__name__} without a parent env." + ) + + def inv(self, tensordict): + if self.out_keys_inv[0] == self.in_keys_inv[0]: + return super().inv(tensordict) + # We re-write this because we don't want to clone the TD here + return self._inv_call(tensordict) + + def _inv_call(self, tensordict): + # action is categorical, map it to desired dtype + intervals = getattr(self, "intervals", None) + if intervals is None: + self._init() + return self._inv_call(tensordict) + action = tensordict.get(self.out_keys_inv[0]) + if self.categorical: + action = action.unsqueeze(-1) + if isinstance(intervals, torch.Tensor): + action = intervals.gather(index=action, dim=-1).squeeze(-1) + else: + action = torch.stack( + [ + interval.gather(index=action, dim=-1).squeeze(-1) + for interval, action in zip(intervals, action.unbind(-2)) + ], + -1, + ) + else: + nvec = self.nvec.tolist() + action = action.split(nvec, dim=-1) + if isinstance(intervals, torch.Tensor): + intervals = intervals.unbind(-2) + action = torch.stack( + [ + intervals[action].view(action.shape[:-1]) + for (intervals, action) in zip(intervals, action) + ], + -1, + ) + + if self.sampling == self.SamplingStrategy.RANDOM: + action = action + self.jitters * torch.rand_like(self.jitters) + return tensordict.set(self.in_keys_inv[0], action)
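Note on the ``remove_singleton`` argument added to ``MultiDiscreteTensorSpec``: by default,
``rand()`` squeezes samples drawn from a spec of shape ``[1]``, and ``ActionDiscretizer``
constructs its discrete action spec with ``remove_singleton=False`` so that single-dimensional
action spaces keep their trailing dimension. A minimal sketch of the intended behaviour, not
part of the patch:

    from torchrl.data import MultiDiscreteTensorSpec

    # Default behaviour: a shape-[1] spec squeezes its random samples to a scalar.
    spec_default = MultiDiscreteTensorSpec([4])
    print(spec_default.rand().shape)   # expected: torch.Size([])

    # With the new flag, the trailing singleton action dimension is preserved,
    # which is what ActionDiscretizer relies on for 1D action spaces.
    spec_keep = MultiDiscreteTensorSpec([4], remove_singleton=False)
    print(spec_keep.rand().shape)      # expected: torch.Size([1])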
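The sampling strategy controls which continuous value each discrete bin maps back to. The
following self-contained sketch, not part of the patch, mirrors the per-dimension candidate
values that ``transform_input_spec`` builds (the helper name ``candidate_values`` is
illustrative only), assuming a bounded action in ``[low, high]`` split into
``num_intervals`` bins:

    import torch

    def candidate_values(low, high, num_intervals, sampling="median"):
        # Fractions in [0, 1), one per bin (mirrors custom_arange in the transform).
        frac = torch.arange(0.0, 1.0, 1.0 / num_intervals)
        if sampling == "high":       # upper edge of each bin
            frac = (1 - frac).flip(0)
        elif sampling == "median":   # midpoint of each bin
            frac = (frac + (1 - frac).flip(0)) / 2
        # sampling == "low" keeps the lower edge of each bin.
        return low + frac * (high - low)

    print(candidate_values(-1.0, 1.0, 4, "low"))     # ~ tensor([-1.00, -0.50,  0.00,  0.50])
    print(candidate_values(-1.0, 1.0, 4, "high"))    # ~ tensor([-0.50,  0.00,  0.50,  1.00])
    print(candidate_values(-1.0, 1.0, 4, "median"))  # ~ tensor([-0.75, -0.25,  0.25,  0.75])

For ``RANDOM``, the transform starts from the ``LOW`` anchors and adds, at action time, a
uniform jitter of width ``(high - low) / num_intervals``.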
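At action time, ``_inv_call`` maps the discrete action produced by the policy back to a
continuous action by indexing a per-dimension table of candidate values. A sketch of the
categorical branch with a single integer ``num_intervals``, not part of the patch (shapes and
names are illustrative):

    import torch

    n_act, num_intervals = 3, 5
    low = torch.full((n_act,), -1.0)
    high = torch.full((n_act,), 1.0)
    frac = torch.arange(0.0, 1.0, 1.0 / num_intervals)
    # Candidate continuous values, one row per action dimension: shape (3, 5).
    intervals = low.unsqueeze(-1) + frac * (high - low).unsqueeze(-1)

    # A categorical action: one bin index per action dimension.
    action_disc = torch.tensor([0, 2, 4])

    # Gather the continuous candidate for each dimension (what _inv_call does).
    action = intervals.gather(dim=-1, index=action_disc.unsqueeze(-1)).squeeze(-1)
    print(action)  # ~ tensor([-1.0, -0.2,  0.6])

With ``categorical=False``, the one-hot action is split per dimension and used as a boolean
index into the same table. In both cases the continuous result is written back under the
original ``action_key`` so the base environment still receives a floating point action.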
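When ``out_action_key`` differs from ``action_key``, the continuous spec is moved from
``full_action_spec`` to ``full_state_spec`` so that only the discrete action is sampled,
while the continuous action stays visible (e.g. for logging). A sketch of what this looks
like, not part of the patch and assuming a gym ``Pendulum-v1`` install:

    from torchrl.envs import GymEnv
    from torchrl.envs.transforms import ActionDiscretizer

    env = GymEnv("Pendulum-v1").append_transform(
        ActionDiscretizer(num_intervals=5, categorical=True, out_action_key="action_disc")
    )
    # The discrete key replaces the continuous one in the action spec...
    print(list(env.full_action_spec.keys()))  # expected: ['action_disc']
    # ...while the continuous "action" is kept in the state spec.
    print(list(env.full_state_spec.keys()))   # expected to include 'action'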