Skip to content

Commit

Permalink
[Feature] Crop Transform (pytorch#2336)
Browse files Browse the repository at this point in the history
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
  • Loading branch information
albertbou92 and vmoens authored Jul 30, 2024
1 parent b3f99b3 commit bf91ff6
Show file tree
Hide file tree
Showing 5 changed files with 278 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,7 @@ to be able to create this other composition:
CenterCrop
ClipTransform
Compose
Crop
DTypeCastTransform
DeviceCastTransform
DiscreteActionProjection
Expand Down
208 changes: 208 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
CenterCrop,
ClipTransform,
Compose,
Crop,
DeviceCastTransform,
DiscreteActionProjection,
DMControlEnv,
Expand Down Expand Up @@ -2135,6 +2136,213 @@ def test_transform_inverse(self):
raise pytest.skip("No inverse for CatTensors")


@pytest.mark.skipif(not _has_tv, reason="no torchvision")
class TestCrop(TransformBase):
@pytest.mark.parametrize("nchannels", [1, 3])
@pytest.mark.parametrize("batch", [[], [2], [2, 4]])
@pytest.mark.parametrize("h", [None, 21])
@pytest.mark.parametrize(
"keys", [["observation", ("some_other", "nested_key")], ["observation_pixels"]]
)
@pytest.mark.parametrize("device", get_default_devices())
def test_transform_no_env(self, keys, h, nchannels, batch, device):
torch.manual_seed(0)
dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device)
crop = Crop(w=20, h=h, in_keys=keys)
if h is None:
h = 20
td = TensorDict(
{
key: torch.randn(*batch, nchannels, 16, 16, device=device)
for key in keys
},
batch,
device=device,
)
td.set("dont touch", dont_touch.clone())
crop(td)
for key in keys:
assert td.get(key).shape[-2:] == torch.Size([20, h])
assert (td.get("dont touch") == dont_touch).all()

if len(keys) == 1:
observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16))
observation_spec = crop.transform_observation_spec(observation_spec)
assert observation_spec.shape == torch.Size([nchannels, 20, h])
else:
observation_spec = CompositeSpec(
{key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys}
)
observation_spec = crop.transform_observation_spec(observation_spec)
for key in keys:
assert observation_spec[key].shape == torch.Size([nchannels, 20, h])

@pytest.mark.parametrize("nchannels", [3])
@pytest.mark.parametrize("batch", [[2]])
@pytest.mark.parametrize("h", [None])
@pytest.mark.parametrize("keys", [["observation_pixels"]])
@pytest.mark.parametrize("device", get_default_devices())
def test_transform_model(self, keys, h, nchannels, batch, device):
torch.manual_seed(0)
dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device)
crop = Crop(w=20, h=h, in_keys=keys)
if h is None:
h = 20
td = TensorDict(
{
key: torch.randn(*batch, nchannels, 16, 16, device=device)
for key in keys
},
batch,
device=device,
)
td.set("dont touch", dont_touch.clone())
model = nn.Sequential(crop, nn.Identity())
model(td)
for key in keys:
assert td.get(key).shape[-2:] == torch.Size([20, h])
assert (td.get("dont touch") == dont_touch).all()

@pytest.mark.parametrize("nchannels", [3])
@pytest.mark.parametrize("batch", [[2]])
@pytest.mark.parametrize("h", [None])
@pytest.mark.parametrize("keys", [["observation_pixels"]])
@pytest.mark.parametrize("device", get_default_devices())
def test_transform_compose(self, keys, h, nchannels, batch, device):
torch.manual_seed(0)
dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device)
crop = Crop(w=20, h=h, in_keys=keys)
if h is None:
h = 20
td = TensorDict(
{
key: torch.randn(*batch, nchannels, 16, 16, device=device)
for key in keys
},
batch,
device=device,
)
td.set("dont touch", dont_touch.clone())
model = Compose(crop)
tdc = model(td.clone())
for key in keys:
assert tdc.get(key).shape[-2:] == torch.Size([20, h])
assert (tdc.get("dont touch") == dont_touch).all()
tdc = model._call(td.clone())
for key in keys:
assert tdc.get(key).shape[-2:] == torch.Size([20, h])
assert (tdc.get("dont touch") == dont_touch).all()

@pytest.mark.parametrize("nchannels", [3])
@pytest.mark.parametrize("batch", [[2]])
@pytest.mark.parametrize("h", [None])
@pytest.mark.parametrize("keys", [["observation_pixels"]])
@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
def test_transform_rb(
self,
rbclass,
keys,
h,
nchannels,
batch,
):
torch.manual_seed(0)
dont_touch = torch.randn(
*batch,
nchannels,
16,
16,
)
crop = Crop(w=20, h=h, in_keys=keys)
if h is None:
h = 20
td = TensorDict(
{
key: torch.randn(
*batch,
nchannels,
16,
16,
)
for key in keys
},
batch,
)
td.set("dont touch", dont_touch.clone())
rb = rbclass(storage=LazyTensorStorage(10))
rb.append_transform(crop)
rb.extend(td)
td = rb.sample(10)
for key in keys:
assert td.get(key).shape[-2:] == torch.Size([20, h])

def test_single_trans_env_check(self):
keys = ["pixels"]
ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys))
env = TransformedEnv(DiscreteActionConvMockEnvNumpy(), ct)
check_env_specs(env)

def test_serial_trans_env_check(self):
keys = ["pixels"]

def make_env():
ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys))
return TransformedEnv(DiscreteActionConvMockEnvNumpy(), ct)

env = SerialEnv(2, make_env)
check_env_specs(env)

def test_parallel_trans_env_check(self):
keys = ["pixels"]

def make_env():
ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys))
return TransformedEnv(DiscreteActionConvMockEnvNumpy(), ct)

env = ParallelEnv(2, make_env)
try:
check_env_specs(env)
finally:
try:
env.close()
except RuntimeError:
pass

def test_trans_serial_env_check(self):
keys = ["pixels"]
ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys))
env = TransformedEnv(SerialEnv(2, DiscreteActionConvMockEnvNumpy), ct)
check_env_specs(env)

def test_trans_parallel_env_check(self):
keys = ["pixels"]
ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys))
env = TransformedEnv(ParallelEnv(2, DiscreteActionConvMockEnvNumpy), ct)
try:
check_env_specs(env)
finally:
try:
env.close()
except RuntimeError:
pass

@pytest.mark.skipif(not _has_gym, reason="No Gym detected")
@pytest.mark.parametrize("out_key", [None, ["outkey"], [("out", "key")]])
def test_transform_env(self, out_key):
keys = ["pixels"]
ct = Compose(ToTensorImage(), Crop(out_keys=out_key, w=20, h=20, in_keys=keys))
env = TransformedEnv(GymEnv(PONG_VERSIONED()), ct)
td = env.reset()
if out_key is None:
assert td["pixels"].shape == torch.Size([3, 20, 20])
else:
assert td[out_key[0]].shape == torch.Size([3, 20, 20])
check_env_specs(env)

def test_transform_inverse(self):
raise pytest.skip("Crop does not have an inverse method.")


@pytest.mark.skipif(not _has_tv, reason="no torchvision")
class TestCenterCrop(TransformBase):
@pytest.mark.parametrize("nchannels", [1, 3])
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
CenterCrop,
ClipTransform,
Compose,
Crop,
DeviceCastTransform,
DiscreteActionProjection,
DoubleToFloat,
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
CenterCrop,
ClipTransform,
Compose,
Crop,
DeviceCastTransform,
DiscreteActionProjection,
DoubleToFloat,
Expand Down
67 changes: 67 additions & 0 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1913,6 +1913,73 @@ def _reset(
return tensordict_reset


class Crop(ObservationTransform):
"""Crops the input image at the specified location and output size.
Args:
w (int): resulting width
h (int, optional): resulting height. If None, then w is used (square crop).
top (int, optional): top pixel coordinate to start cropping. Default is 0, i.e. top of the image.
left (int, optional): left pixel coordinate to start cropping. Default is 0, i.e. left of the image.
in_keys (sequence of NestedKey, optional): the entries to crop. If none is provided,
``["pixels"]`` is assumed.
out_keys (sequence of NestedKey, optional): the cropped images keys. If none is
provided, ``in_keys`` is assumed.
"""

def __init__(
self,
w: int,
h: int = None,
top: int = 0,
left: int = 0,
in_keys: Sequence[NestedKey] | None = None,
out_keys: Sequence[NestedKey] | None = None,
):
if in_keys is None:
in_keys = IMAGE_KEYS # default
if out_keys is None:
out_keys = copy(in_keys)
super().__init__(in_keys=in_keys, out_keys=out_keys)
self.w = w
self.h = h if h else w
self.top = top
self.left = left

def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
from torchvision.transforms.functional import crop

observation = crop(observation, self.top, self.left, self.w, self.h)
return observation

def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
with _set_missing_tolerance(self, True):
tensordict_reset = self._call(tensordict_reset)
return tensordict_reset

@_apply_to_composite
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
space = observation_spec.space
if isinstance(space, ContinuousBox):
space.low = self._apply_transform(space.low)
space.high = self._apply_transform(space.high)
observation_spec.shape = space.low.shape
else:
observation_spec.shape = self._apply_transform(
torch.zeros(observation_spec.shape)
).shape
return observation_spec

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}("
f"w={float(self.w):4.4f}, h={float(self.h):4.4f}, top={float(self.top):4.4f}, left={float(self.left):4.4f}, "
)


class CenterCrop(ObservationTransform):
"""Crops the center of an image.
Expand Down

0 comments on commit bf91ff6

Please sign in to comment.