Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Crop Transform #2336

Merged
merged 10 commits into from
Jul 30, 2024
Prev Previous commit
minor fixes
  • Loading branch information
albertbou92 committed Jul 30, 2024
commit 07e4e1a5cdc93e6ae0685f351c0b7ed14a4ecf49
2 changes: 1 addition & 1 deletion docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -790,10 +790,10 @@ to be able to create this other composition:
BurnInTransform
CatFrames
CatTensors
Crop
CenterCrop
ClipTransform
Compose
Crop
DTypeCastTransform
DeviceCastTransform
DiscreteActionProjection
Expand Down
42 changes: 6 additions & 36 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2178,18 +2178,8 @@ def test_transform_no_env(self, keys, h, nchannels, batch, device):
assert observation_spec[key].shape == torch.Size([nchannels, 20, h])

@pytest.mark.parametrize("nchannels", [3])
@pytest.mark.parametrize(
"batch",
[
[2]
]
)
@pytest.mark.parametrize(
"h",
[
None
]
)
@pytest.mark.parametrize("batch", [[2]])
@pytest.mark.parametrize("h", [None])
@pytest.mark.parametrize("keys", [["observation_pixels"]])
@pytest.mark.parametrize("device", get_default_devices())
def test_transform_model(self, keys, h, nchannels, batch, device):
Expand All @@ -2214,18 +2204,8 @@ def test_transform_model(self, keys, h, nchannels, batch, device):
assert (td.get("dont touch") == dont_touch).all()

@pytest.mark.parametrize("nchannels", [3])
@pytest.mark.parametrize(
"batch",
[
[2]
]
)
@pytest.mark.parametrize(
"h",
[
None
]
)
@pytest.mark.parametrize("batch", [[2]])
@pytest.mark.parametrize("h", [None])
@pytest.mark.parametrize("keys", [["observation_pixels"]])
@pytest.mark.parametrize("device", get_default_devices())
def test_transform_compose(self, keys, h, nchannels, batch, device):
Expand Down Expand Up @@ -2254,18 +2234,8 @@ def test_transform_compose(self, keys, h, nchannels, batch, device):
assert (tdc.get("dont touch") == dont_touch).all()

@pytest.mark.parametrize("nchannels", [3])
@pytest.mark.parametrize(
"batch",
[
[2]
]
)
@pytest.mark.parametrize(
"h",
[
None
]
)
@pytest.mark.parametrize("batch", [[2]])
@pytest.mark.parametrize("h", [None])
@pytest.mark.parametrize("keys", [["observation_pixels"]])
@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
def test_transform_rb(
Expand Down
Loading