From bf91ff6433ff170714fe898728f66f55195735b6 Mon Sep 17 00:00:00 2001 From: Albert Bou Date: Tue, 30 Jul 2024 20:27:15 +0200 Subject: [PATCH] [Feature] Crop Transform (#2336) Co-authored-by: Vincent Moens --- docs/source/reference/envs.rst | 1 + test/test_transforms.py | 208 ++++++++++++++++++++++++++ torchrl/envs/__init__.py | 1 + torchrl/envs/transforms/__init__.py | 1 + torchrl/envs/transforms/transforms.py | 67 +++++++++ 5 files changed, 278 insertions(+) diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index e9f6f21c644..11a5bb041a6 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -793,6 +793,7 @@ to be able to create this other composition: CenterCrop ClipTransform Compose + Crop DTypeCastTransform DeviceCastTransform DiscreteActionProjection diff --git a/test/test_transforms.py b/test/test_transforms.py index bdafba648eb..94ec8b2716c 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -70,6 +70,7 @@ CenterCrop, ClipTransform, Compose, + Crop, DeviceCastTransform, DiscreteActionProjection, DMControlEnv, @@ -2135,6 +2136,213 @@ def test_transform_inverse(self): raise pytest.skip("No inverse for CatTensors") +@pytest.mark.skipif(not _has_tv, reason="no torchvision") +class TestCrop(TransformBase): + @pytest.mark.parametrize("nchannels", [1, 3]) + @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) + @pytest.mark.parametrize("h", [None, 21]) + @pytest.mark.parametrize( + "keys", [["observation", ("some_other", "nested_key")], ["observation_pixels"]] + ) + @pytest.mark.parametrize("device", get_default_devices()) + def test_transform_no_env(self, keys, h, nchannels, batch, device): + torch.manual_seed(0) + dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device) + crop = Crop(w=20, h=h, in_keys=keys) + if h is None: + h = 20 + td = TensorDict( + { + key: torch.randn(*batch, nchannels, 16, 16, device=device) + for key in keys + }, + batch, + device=device, + ) + td.set("dont touch", dont_touch.clone()) + crop(td) + for key in keys: + assert td.get(key).shape[-2:] == torch.Size([20, h]) + assert (td.get("dont touch") == dont_touch).all() + + if len(keys) == 1: + observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) + observation_spec = crop.transform_observation_spec(observation_spec) + assert observation_spec.shape == torch.Size([nchannels, 20, h]) + else: + observation_spec = CompositeSpec( + {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys} + ) + observation_spec = crop.transform_observation_spec(observation_spec) + for key in keys: + assert observation_spec[key].shape == torch.Size([nchannels, 20, h]) + + @pytest.mark.parametrize("nchannels", [3]) + @pytest.mark.parametrize("batch", [[2]]) + @pytest.mark.parametrize("h", [None]) + @pytest.mark.parametrize("keys", [["observation_pixels"]]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_transform_model(self, keys, h, nchannels, batch, device): + torch.manual_seed(0) + dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device) + crop = Crop(w=20, h=h, in_keys=keys) + if h is None: + h = 20 + td = TensorDict( + { + key: torch.randn(*batch, nchannels, 16, 16, device=device) + for key in keys + }, + batch, + device=device, + ) + td.set("dont touch", dont_touch.clone()) + model = nn.Sequential(crop, nn.Identity()) + model(td) + for key in keys: + assert td.get(key).shape[-2:] == torch.Size([20, h]) + assert (td.get("dont touch") == dont_touch).all() + + @pytest.mark.parametrize("nchannels", [3]) + @pytest.mark.parametrize("batch", [[2]]) + @pytest.mark.parametrize("h", [None]) + @pytest.mark.parametrize("keys", [["observation_pixels"]]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_transform_compose(self, keys, h, nchannels, batch, device): + torch.manual_seed(0) + dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device) + crop = Crop(w=20, h=h, in_keys=keys) + if h is None: + h = 20 + td = TensorDict( + { + key: torch.randn(*batch, nchannels, 16, 16, device=device) + for key in keys + }, + batch, + device=device, + ) + td.set("dont touch", dont_touch.clone()) + model = Compose(crop) + tdc = model(td.clone()) + for key in keys: + assert tdc.get(key).shape[-2:] == torch.Size([20, h]) + assert (tdc.get("dont touch") == dont_touch).all() + tdc = model._call(td.clone()) + for key in keys: + assert tdc.get(key).shape[-2:] == torch.Size([20, h]) + assert (tdc.get("dont touch") == dont_touch).all() + + @pytest.mark.parametrize("nchannels", [3]) + @pytest.mark.parametrize("batch", [[2]]) + @pytest.mark.parametrize("h", [None]) + @pytest.mark.parametrize("keys", [["observation_pixels"]]) + @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) + def test_transform_rb( + self, + rbclass, + keys, + h, + nchannels, + batch, + ): + torch.manual_seed(0) + dont_touch = torch.randn( + *batch, + nchannels, + 16, + 16, + ) + crop = Crop(w=20, h=h, in_keys=keys) + if h is None: + h = 20 + td = TensorDict( + { + key: torch.randn( + *batch, + nchannels, + 16, + 16, + ) + for key in keys + }, + batch, + ) + td.set("dont touch", dont_touch.clone()) + rb = rbclass(storage=LazyTensorStorage(10)) + rb.append_transform(crop) + rb.extend(td) + td = rb.sample(10) + for key in keys: + assert td.get(key).shape[-2:] == torch.Size([20, h]) + + def test_single_trans_env_check(self): + keys = ["pixels"] + ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys)) + env = TransformedEnv(DiscreteActionConvMockEnvNumpy(), ct) + check_env_specs(env) + + def test_serial_trans_env_check(self): + keys = ["pixels"] + + def make_env(): + ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys)) + return TransformedEnv(DiscreteActionConvMockEnvNumpy(), ct) + + env = SerialEnv(2, make_env) + check_env_specs(env) + + def test_parallel_trans_env_check(self): + keys = ["pixels"] + + def make_env(): + ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys)) + return TransformedEnv(DiscreteActionConvMockEnvNumpy(), ct) + + env = ParallelEnv(2, make_env) + try: + check_env_specs(env) + finally: + try: + env.close() + except RuntimeError: + pass + + def test_trans_serial_env_check(self): + keys = ["pixels"] + ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys)) + env = TransformedEnv(SerialEnv(2, DiscreteActionConvMockEnvNumpy), ct) + check_env_specs(env) + + def test_trans_parallel_env_check(self): + keys = ["pixels"] + ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys)) + env = TransformedEnv(ParallelEnv(2, DiscreteActionConvMockEnvNumpy), ct) + try: + check_env_specs(env) + finally: + try: + env.close() + except RuntimeError: + pass + + @pytest.mark.skipif(not _has_gym, reason="No Gym detected") + @pytest.mark.parametrize("out_key", [None, ["outkey"], [("out", "key")]]) + def test_transform_env(self, out_key): + keys = ["pixels"] + ct = Compose(ToTensorImage(), Crop(out_keys=out_key, w=20, h=20, in_keys=keys)) + env = TransformedEnv(GymEnv(PONG_VERSIONED()), ct) + td = env.reset() + if out_key is None: + assert td["pixels"].shape == torch.Size([3, 20, 20]) + else: + assert td[out_key[0]].shape == torch.Size([3, 20, 20]) + check_env_specs(env) + + def test_transform_inverse(self): + raise pytest.skip("Crop does not have an inverse method.") + + @pytest.mark.skipif(not _has_tv, reason="no torchvision") class TestCenterCrop(TransformBase): @pytest.mark.parametrize("nchannels", [1, 3]) diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index cb1c9813ba0..ced185d7e00 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -51,6 +51,7 @@ CenterCrop, ClipTransform, Compose, + Crop, DeviceCastTransform, DiscreteActionProjection, DoubleToFloat, diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 5204b8a19d8..64a25b94e37 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -20,6 +20,7 @@ CenterCrop, ClipTransform, Compose, + Crop, DeviceCastTransform, DiscreteActionProjection, DoubleToFloat, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index afc2313abdd..1a66ee489a6 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1913,6 +1913,73 @@ def _reset( return tensordict_reset +class Crop(ObservationTransform): + """Crops the input image at the specified location and output size. + + Args: + w (int): resulting width + h (int, optional): resulting height. If None, then w is used (square crop). + top (int, optional): top pixel coordinate to start cropping. Default is 0, i.e. top of the image. + left (int, optional): left pixel coordinate to start cropping. Default is 0, i.e. left of the image. + in_keys (sequence of NestedKey, optional): the entries to crop. If none is provided, + ``["pixels"]`` is assumed. + out_keys (sequence of NestedKey, optional): the cropped images keys. If none is + provided, ``in_keys`` is assumed. + + """ + + def __init__( + self, + w: int, + h: int = None, + top: int = 0, + left: int = 0, + in_keys: Sequence[NestedKey] | None = None, + out_keys: Sequence[NestedKey] | None = None, + ): + if in_keys is None: + in_keys = IMAGE_KEYS # default + if out_keys is None: + out_keys = copy(in_keys) + super().__init__(in_keys=in_keys, out_keys=out_keys) + self.w = w + self.h = h if h else w + self.top = top + self.left = left + + def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: + from torchvision.transforms.functional import crop + + observation = crop(observation, self.top, self.left, self.w, self.h) + return observation + + def _reset( + self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase + ) -> TensorDictBase: + with _set_missing_tolerance(self, True): + tensordict_reset = self._call(tensordict_reset) + return tensordict_reset + + @_apply_to_composite + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + space = observation_spec.space + if isinstance(space, ContinuousBox): + space.low = self._apply_transform(space.low) + space.high = self._apply_transform(space.high) + observation_spec.shape = space.low.shape + else: + observation_spec.shape = self._apply_transform( + torch.zeros(observation_spec.shape) + ).shape + return observation_spec + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"w={float(self.w):4.4f}, h={float(self.h):4.4f}, top={float(self.top):4.4f}, left={float(self.left):4.4f}, " + ) + + class CenterCrop(ObservationTransform): """Crops the center of an image.