[Feature] Crop Transform #2336

Merged: 10 commits, Jul 30, 2024
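This PR adds a Crop transform that crops image observations at a fixed (top, left) location and output size, together with its docs entry, package exports, and tests. As a quick orientation, here is a minimal usage sketch (the "pixels" key is the transform's default; the input shape is illustrative and not part of the PR):

import torch
from tensordict import TensorDict
from torchrl.envs.transforms import Crop

# a single 3-channel, 32x32 frame under the default "pixels" key
td = TensorDict({"pixels": torch.randn(3, 32, 32)}, [])
crop = Crop(w=20, h=20, top=0, left=0)
td = crop(td)
print(td["pixels"].shape)  # torch.Size([3, 20, 20])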
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
@@ -793,6 +793,7 @@ to be able to create this other composition:
    CenterCrop
    ClipTransform
    Compose
    Crop
    DTypeCastTransform
    DeviceCastTransform
    DiscreteActionProjection
208 changes: 208 additions & 0 deletions test/test_transforms.py
@@ -70,6 +70,7 @@
    CenterCrop,
    ClipTransform,
    Compose,
    Crop,
    DeviceCastTransform,
    DiscreteActionProjection,
    DMControlEnv,
@@ -2135,6 +2136,213 @@ def test_transform_inverse(self):
        raise pytest.skip("No inverse for CatTensors")


@pytest.mark.skipif(not _has_tv, reason="no torchvision")
class TestCrop(TransformBase):
@pytest.mark.parametrize("nchannels", [1, 3])
@pytest.mark.parametrize("batch", [[], [2], [2, 4]])
@pytest.mark.parametrize("h", [None, 21])
@pytest.mark.parametrize(
"keys", [["observation", ("some_other", "nested_key")], ["observation_pixels"]]
)
@pytest.mark.parametrize("device", get_default_devices())
def test_transform_no_env(self, keys, h, nchannels, batch, device):
torch.manual_seed(0)
dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device)
crop = Crop(w=20, h=h, in_keys=keys)
if h is None:
h = 20
td = TensorDict(
{
key: torch.randn(*batch, nchannels, 16, 16, device=device)
for key in keys
},
batch,
device=device,
)
td.set("dont touch", dont_touch.clone())
crop(td)
for key in keys:
assert td.get(key).shape[-2:] == torch.Size([20, h])
assert (td.get("dont touch") == dont_touch).all()

        if len(keys) == 1:
            observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16))
            observation_spec = crop.transform_observation_spec(observation_spec)
            assert observation_spec.shape == torch.Size([nchannels, 20, h])
        else:
            observation_spec = CompositeSpec(
                {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys}
            )
            observation_spec = crop.transform_observation_spec(observation_spec)
            for key in keys:
                assert observation_spec[key].shape == torch.Size([nchannels, 20, h])

@pytest.mark.parametrize("nchannels", [3])
@pytest.mark.parametrize("batch", [[2]])
@pytest.mark.parametrize("h", [None])
@pytest.mark.parametrize("keys", [["observation_pixels"]])
@pytest.mark.parametrize("device", get_default_devices())
def test_transform_model(self, keys, h, nchannels, batch, device):
torch.manual_seed(0)
dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device)
crop = Crop(w=20, h=h, in_keys=keys)
if h is None:
h = 20
td = TensorDict(
{
key: torch.randn(*batch, nchannels, 16, 16, device=device)
for key in keys
},
batch,
device=device,
)
td.set("dont touch", dont_touch.clone())
model = nn.Sequential(crop, nn.Identity())
model(td)
for key in keys:
assert td.get(key).shape[-2:] == torch.Size([20, h])
assert (td.get("dont touch") == dont_touch).all()

@pytest.mark.parametrize("nchannels", [3])
@pytest.mark.parametrize("batch", [[2]])
@pytest.mark.parametrize("h", [None])
@pytest.mark.parametrize("keys", [["observation_pixels"]])
@pytest.mark.parametrize("device", get_default_devices())
def test_transform_compose(self, keys, h, nchannels, batch, device):
torch.manual_seed(0)
dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device)
crop = Crop(w=20, h=h, in_keys=keys)
if h is None:
h = 20
td = TensorDict(
{
key: torch.randn(*batch, nchannels, 16, 16, device=device)
for key in keys
},
batch,
device=device,
)
td.set("dont touch", dont_touch.clone())
model = Compose(crop)
tdc = model(td.clone())
for key in keys:
assert tdc.get(key).shape[-2:] == torch.Size([20, h])
assert (tdc.get("dont touch") == dont_touch).all()
tdc = model._call(td.clone())
for key in keys:
assert tdc.get(key).shape[-2:] == torch.Size([20, h])
assert (tdc.get("dont touch") == dont_touch).all()

@pytest.mark.parametrize("nchannels", [3])
@pytest.mark.parametrize("batch", [[2]])
@pytest.mark.parametrize("h", [None])
@pytest.mark.parametrize("keys", [["observation_pixels"]])
@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
def test_transform_rb(
self,
rbclass,
keys,
h,
nchannels,
batch,
):
torch.manual_seed(0)
dont_touch = torch.randn(
*batch,
nchannels,
16,
16,
)
crop = Crop(w=20, h=h, in_keys=keys)
if h is None:
h = 20
td = TensorDict(
{
key: torch.randn(
*batch,
nchannels,
16,
16,
)
for key in keys
},
batch,
)
td.set("dont touch", dont_touch.clone())
rb = rbclass(storage=LazyTensorStorage(10))
rb.append_transform(crop)
rb.extend(td)
td = rb.sample(10)
for key in keys:
assert td.get(key).shape[-2:] == torch.Size([20, h])

    def test_single_trans_env_check(self):
        keys = ["pixels"]
        ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys))
        env = TransformedEnv(DiscreteActionConvMockEnvNumpy(), ct)
        check_env_specs(env)

    def test_serial_trans_env_check(self):
        keys = ["pixels"]

        def make_env():
            ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys))
            return TransformedEnv(DiscreteActionConvMockEnvNumpy(), ct)

        env = SerialEnv(2, make_env)
        check_env_specs(env)

    def test_parallel_trans_env_check(self):
        keys = ["pixels"]

        def make_env():
            ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys))
            return TransformedEnv(DiscreteActionConvMockEnvNumpy(), ct)

        env = ParallelEnv(2, make_env)
        try:
            check_env_specs(env)
        finally:
            try:
                env.close()
            except RuntimeError:
                pass

    def test_trans_serial_env_check(self):
        keys = ["pixels"]
        ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys))
        env = TransformedEnv(SerialEnv(2, DiscreteActionConvMockEnvNumpy), ct)
        check_env_specs(env)

    def test_trans_parallel_env_check(self):
        keys = ["pixels"]
        ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys))
        env = TransformedEnv(ParallelEnv(2, DiscreteActionConvMockEnvNumpy), ct)
        try:
            check_env_specs(env)
        finally:
            try:
                env.close()
            except RuntimeError:
                pass

    @pytest.mark.skipif(not _has_gym, reason="No Gym detected")
    @pytest.mark.parametrize("out_key", [None, ["outkey"], [("out", "key")]])
    def test_transform_env(self, out_key):
        keys = ["pixels"]
        ct = Compose(ToTensorImage(), Crop(out_keys=out_key, w=20, h=20, in_keys=keys))
        env = TransformedEnv(GymEnv(PONG_VERSIONED()), ct)
        td = env.reset()
        if out_key is None:
            assert td["pixels"].shape == torch.Size([3, 20, 20])
        else:
            assert td[out_key[0]].shape == torch.Size([3, 20, 20])
        check_env_specs(env)

    def test_transform_inverse(self):
        raise pytest.skip("Crop does not have an inverse method.")


@pytest.mark.skipif(not _has_tv, reason="no torchvision")
class TestCenterCrop(TransformBase):
@pytest.mark.parametrize("nchannels", [1, 3])
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
@@ -51,6 +51,7 @@
    CenterCrop,
    ClipTransform,
    Compose,
    Crop,
    DeviceCastTransform,
    DiscreteActionProjection,
    DoubleToFloat,
1 change: 1 addition & 0 deletions torchrl/envs/transforms/__init__.py
@@ -20,6 +20,7 @@
    CenterCrop,
    ClipTransform,
    Compose,
    Crop,
    DeviceCastTransform,
    DiscreteActionProjection,
    DoubleToFloat,
67 changes: 67 additions & 0 deletions torchrl/envs/transforms/transforms.py
@@ -1913,6 +1913,73 @@ def _reset(
        return tensordict_reset


class Crop(ObservationTransform):
"""Crops the input image at the specified location and output size.

Args:
w (int): resulting width
h (int, optional): resulting height. If None, then w is used (square crop).
top (int, optional): top pixel coordinate to start cropping. Default is 0, i.e. top of the image.
left (int, optional): left pixel coordinate to start cropping. Default is 0, i.e. left of the image.
in_keys (sequence of NestedKey, optional): the entries to crop. If none is provided,
``["pixels"]`` is assumed.
out_keys (sequence of NestedKey, optional): the cropped images keys. If none is
provided, ``in_keys`` is assumed.

"""

    def __init__(
        self,
        w: int,
        h: int | None = None,
        top: int = 0,
        left: int = 0,
        in_keys: Sequence[NestedKey] | None = None,
        out_keys: Sequence[NestedKey] | None = None,
    ):
        if in_keys is None:
            in_keys = IMAGE_KEYS  # default: ["pixels"]
        if out_keys is None:
            out_keys = copy(in_keys)
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self.w = w
        self.h = h if h is not None else w
        self.top = top
        self.left = left

    def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
        # local import keeps torchvision an optional dependency
        from torchvision.transforms.functional import crop

        # torchvision's crop signature is (img, top, left, height, width)
        observation = crop(observation, self.top, self.left, self.w, self.h)
        return observation

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        with _set_missing_tolerance(self, True):
            tensordict_reset = self._call(tensordict_reset)
        return tensordict_reset

    @_apply_to_composite
    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        space = observation_spec.space
        if isinstance(space, ContinuousBox):
            space.low = self._apply_transform(space.low)
            space.high = self._apply_transform(space.high)
            observation_spec.shape = space.low.shape
        else:
            observation_spec.shape = self._apply_transform(
                torch.zeros(observation_spec.shape)
            ).shape
        return observation_spec

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"w={self.w}, h={self.h}, top={self.top}, left={self.left})"
        )
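
The transform composes with envs like any other ObservationTransform. A sketch patterned on test_transform_env above, assuming a pixel-based Gym env is installed (the env id below is illustrative):

import torch
from torchrl.envs import Compose, Crop, GymEnv, ToTensorImage, TransformedEnv
from torchrl.envs.utils import check_env_specs

env = TransformedEnv(
    GymEnv("ALE/Pong-v5"),  # illustrative id; any pixel-observation env works
    Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=["pixels"])),
)
td = env.reset()
assert td["pixels"].shape == torch.Size([3, 20, 20])
check_env_specs(env)  # validates the transformed specs, as the tests do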


class CenterCrop(ObservationTransform):
"""Crops the center of an image.
