
Commit b3f4b30
CenterCrop transform (pytorch#177)
vmoens authored Jun 1, 2022
1 parent fa4febf commit b3f4b30
Showing 10 changed files with 128 additions and 7 deletions.
6 changes: 5 additions & 1 deletion examples/ddpg/ddpg.py
@@ -8,6 +8,7 @@
 
 from torchrl.envs import ParallelEnv, EnvCreator
 from torchrl.envs.utils import set_exploration_mode
+from torchrl.record import VideoRecorder
 
 try:
     import configargparse as argparse
@@ -178,7 +179,10 @@ def main(args):
 
     # remove video recorder from recorder to have matching state_dict keys
     if args.record_video:
-        recorder_rm = TransformedEnv(recorder.env, recorder.transform[1:])
+        recorder_rm = TransformedEnv(recorder.env)
+        for transform in recorder.transform:
+            if not isinstance(transform, VideoRecorder):
+                recorder_rm.append_transform(transform)
     else:
         recorder_rm = recorder
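
The same recorder cleanup is repeated in the dqn, ppo, redq and sac scripts below, and it relies on the Compose.__iter__ method added to torchrl/envs/transforms/transforms.py further down. A minimal sketch of the pattern in isolation (the helper name and stand-alone form are illustrative, not part of the commit):

    from torchrl.envs.transforms import TransformedEnv
    from torchrl.record import VideoRecorder

    def strip_video_recorder(recorder: TransformedEnv) -> TransformedEnv:
        # Rebuild the recording env with every transform except VideoRecorder,
        # so its state_dict keys line up with those of the training env.
        recorder_rm = TransformedEnv(recorder.env)
        for transform in recorder.transform:  # iterating a Compose uses the new __iter__
            if not isinstance(transform, VideoRecorder):
                recorder_rm.append_transform(transform)
        return recorder_rm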
6 changes: 5 additions & 1 deletion examples/dqn/dqn.py
@@ -8,6 +8,7 @@
 
 from torchrl.envs import ParallelEnv, EnvCreator
 from torchrl.envs.utils import set_exploration_mode
+from torchrl.record import VideoRecorder
 
 try:
     import configargparse as argparse
@@ -153,7 +154,10 @@ def main(args):
 
     # remove video recorder from recorder to have matching state_dict keys
    if args.record_video:
-        recorder_rm = TransformedEnv(recorder.env, recorder.transform[1:])
+        recorder_rm = TransformedEnv(recorder.env)
+        for transform in recorder.transform:
+            if not isinstance(transform, VideoRecorder):
+                recorder_rm.append_transform(transform)
     else:
         recorder_rm = recorder
6 changes: 5 additions & 1 deletion examples/ppo/ppo.py
@@ -8,6 +8,7 @@
 
 from torchrl.envs import ParallelEnv, EnvCreator
 from torchrl.envs.utils import set_exploration_mode
+from torchrl.record import VideoRecorder
 
 try:
     import configargparse as argparse
@@ -148,7 +149,10 @@ def main(args):
 
     # remove video recorder from recorder to have matching state_dict keys
     if args.record_video:
-        recorder_rm = TransformedEnv(recorder.env, recorder.transform[1:])
+        recorder_rm = TransformedEnv(recorder.env)
+        for transform in recorder.transform:
+            if not isinstance(transform, VideoRecorder):
+                recorder_rm.append_transform(transform)
     else:
         recorder_rm = recorder
6 changes: 5 additions & 1 deletion examples/redq/redq.py
@@ -8,6 +8,7 @@
 
 from torchrl.envs import ParallelEnv, EnvCreator
 from torchrl.envs.utils import set_exploration_mode
+from torchrl.record import VideoRecorder
 
 try:
     import configargparse as argparse
@@ -177,7 +178,10 @@ def main(args):
 
     # remove video recorder from recorder to have matching state_dict keys
     if args.record_video:
-        recorder_rm = TransformedEnv(recorder.env, recorder.transform[1:])
+        recorder_rm = TransformedEnv(recorder.env)
+        for transform in recorder.transform:
+            if not isinstance(transform, VideoRecorder):
+                recorder_rm.append_transform(transform)
     else:
         recorder_rm = recorder
6 changes: 5 additions & 1 deletion examples/sac/sac.py
@@ -8,6 +8,7 @@
 
 from torchrl.envs import ParallelEnv, EnvCreator
 from torchrl.envs.utils import set_exploration_mode
+from torchrl.record import VideoRecorder
 
 try:
     import configargparse as argparse
@@ -172,7 +173,10 @@ def main(args):
 
     # remove video recorder from recorder to have matching state_dict keys
     if args.record_video:
-        recorder_rm = TransformedEnv(recorder.env, recorder.transform[1:])
+        recorder_rm = TransformedEnv(recorder.env)
+        for transform in recorder.transform:
+            if not isinstance(transform, VideoRecorder):
+                recorder_rm.append_transform(transform)
     else:
         recorder_rm = recorder
1 change: 1 addition & 0 deletions test/test_trainer.py
@@ -230,6 +230,7 @@ def test_recorder():
     args.env_task = ""
     args.env_library = "gym"
     args.frame_skip = 1
+    args.center_crop = []
     args.from_pixels = False
     args.vecnorm = False
     args.norm_rewards = False
69 changes: 69 additions & 0 deletions torchrl/envs/transforms/transforms.py
@@ -14,6 +14,7 @@
 
 try:
     _has_tv = True
+    from torchvision.transforms.functional import center_crop
     from torchvision.transforms.functional_tensor import (
         resize,
     )  # as of now resize is imported from torchvision
@@ -41,6 +42,7 @@
     "TransformedEnv",
     "RewardClipping",
     "Resize",
+    "CenterCrop",
     "GrayScale",
     "Compose",
     "ToTensorImage",
@@ -512,6 +514,9 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Compose:
             t.to(dest)
         return super().to(dest)
 
+    def __iter__(self):
+        return iter(self.transforms)
+
     def __len__(self):
         return len(self.transforms)
 
@@ -759,6 +764,70 @@ def __repr__(self) -> str:
         )
 
 
+class CenterCrop(ObservationTransform):
+    """Crops the center of an image
+    Args:
+        w (int): resulting width
+        h (int, optional): resulting height. If None, then w is used (square crop).
+    """
+
+    inplace = False
+
+    def __init__(
+        self,
+        w: int,
+        h: int = None,
+        keys: Optional[Sequence[str]] = None,
+    ):
+        if not _has_tv:
+            raise ImportError(
+                "Torchvision not found. The Resize transform relies on "
+                "torchvision implementation. "
+                "Consider installing this dependency."
+            )
+        if keys is None:
+            keys = IMAGE_KEYS  # default
+        super().__init__(keys=keys)
+        self.w = w
+        self.h = h if h else w
+
+    def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
+        observation = center_crop(observation, [self.w, self.h])
+        return observation
+
+    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
+        if isinstance(observation_spec, CompositeSpec):
+            return CompositeSpec(
+                **{
+                    key: self.transform_observation_spec(_obs_spec)
+                    if key in self.keys
+                    else _obs_spec
+                    for key, _obs_spec in observation_spec._specs.items()
+                }
+            )
+        else:
+            _observation_spec = observation_spec
+        space = _observation_spec.space
+        if isinstance(space, ContinuousBox):
+            space.minimum = self._apply_transform(space.minimum)
+            space.maximum = self._apply_transform(space.maximum)
+            _observation_spec.shape = space.minimum.shape
+        else:
+            _observation_spec.shape = self._apply_transform(
+                torch.zeros(_observation_spec.shape)
+            ).shape
+
+        observation_spec = _observation_spec
+        return observation_spec
+
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}("
+            f"w={float(self.w):4.4f}, h={float(self.h):4.4f}, "
+        )
+
+
 class GrayScale(ObservationTransform):
     """
     Turns a pixel observation to grayscale.
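
A brief usage sketch for the new transform (a sketch only: the environment, module paths and crop sizes are illustrative assumptions, and torchvision must be installed):

    from torchrl.envs.libs.gym import GymEnv
    from torchrl.envs.transforms import TransformedEnv, ToTensorImage, CenterCrop, Resize, GrayScale

    # Build a pixel-based env and crop a 60x60 window from the center of each frame
    # before the usual resize/grayscale pipeline, mirroring the helper change below.
    env = TransformedEnv(GymEnv("Pendulum-v1", from_pixels=True))
    env.append_transform(ToTensorImage())
    env.append_transform(CenterCrop(60))   # square crop: w == h == 60
    env.append_transform(Resize(84, 84))
    env.append_transform(GrayScale())
    td = env.reset()                       # pixel entries now have the cropped-then-resized shape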
16 changes: 16 additions & 0 deletions torchrl/record/recorder.py
@@ -7,6 +7,11 @@
 
 import torch
 
+try:
+    from torchvision.transforms.functional import center_crop as center_crop_fn
+except ImportError:
+    center_crop_fn = None
+
 from torchrl.data.tensordict.tensordict import _TensorDict
 from torchrl.envs.transforms import ObservationTransform, Transform
 
@@ -27,6 +32,7 @@ class VideoRecorder(ObservationTransform):
             Default is `"next_pixels"`.
         skip (int): frame interval in the output video.
             Default is 2.
+        center_crop (int, optional): value of square center crop.
     """
 
     def __init__(
@@ -35,6 +41,7 @@ def __init__(
         tag: str,
         keys: Optional[Sequence[str]] = None,
         skip: int = 2,
+        center_crop: Optional[int] = None,
        **kwargs,
     ) -> None:
         if keys is None:
@@ -49,6 +56,11 @@ def __init__(
         self.writer = writer
         self.tag = tag
         self.count = 0
+        self.center_crop = center_crop
+        if center_crop and not center_crop_fn:
+            raise ImportError(
+                "Could not load center_crop from torchvision. Make sure torchvision is installed."
+            )
         self.obs = []
         try:
             import moviepy  # noqa
@@ -75,6 +87,10 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
                 f"got {observation_trsf.ndimension()} instead"
             )
         observation_trsf = observation_trsf.permute(2, 0, 1)
+        if self.center_crop:
+            observation_trsf = center_crop_fn(
+                observation_trsf, [self.center_crop, self.center_crop]
+            )
         self.obs.append(observation_trsf.cpu().to(torch.uint8))
         return observation
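
A hedged sketch of how the new argument might be used directly (the environment, writer configuration and tag are illustrative, and moviepy is needed to actually write the video):

    from torch.utils.tensorboard import SummaryWriter
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.envs.transforms import TransformedEnv
    from torchrl.record import VideoRecorder

    writer = SummaryWriter(log_dir="videos")
    recorder = TransformedEnv(GymEnv("Pendulum-v1", from_pixels=True))
    recorder.append_transform(
        VideoRecorder(
            writer=writer,
            tag="eval_video",
            center_crop=60,  # new in this commit: 60x60 square crop of every recorded frame
        )
    )
    # Roll the env out as usual; frames are accumulated (and cropped) inside the transform
    # and written to the SummaryWriter when the recorder's dump step runs.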
15 changes: 15 additions & 0 deletions torchrl/trainers/helpers/envs.py
@@ -24,6 +24,7 @@
     ToTensorImage,
     TransformedEnv,
     VecNorm,
+    CenterCrop,
 )
 from torchrl.envs.transforms.transforms import gSDENoise
 from torchrl.record.recorder import VideoRecorder
@@ -99,10 +100,14 @@ def make_env_transforms(
     reward_loc = args.reward_loc
 
     if len(video_tag):
+        center_crop = args.center_crop
+        if center_crop:
+            center_crop = center_crop[0]
         env.append_transform(
             VideoRecorder(
                 writer=writer,
                 tag=f"{video_tag}_{env_name}_video",
+                center_crop=center_crop,
             ),
         )
@@ -115,6 +120,8 @@ def make_env_transforms(
                 "when pixels are being used."
             )
         env.append_transform(ToTensorImage())
+        if args.center_crop:
+            env.append_transform(CenterCrop(*args.center_crop))
         env.append_transform(Resize(84, 84))
         env.append_transform(GrayScale())
         env.append_transform(CatFrames(N=args.catframes, keys=["next_pixels"]))
@@ -444,6 +451,14 @@ def parser_env_args(parser: ArgumentParser) -> ArgumentParser:
         default=0,
         help="Number of frames to concatenate through time. Default is 0 (do not use CatFrames).",
     )
+    parser.add_argument(
+        "--center_crop",
+        "--center-crop",
+        type=int,
+        nargs="+",
+        default=[],
+        help="center crop size.",
+    )
     parser.add_argument(
         "--max_frames_per_traj",
         "--max-frames-per-traj",
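
To illustrate how the new flag is consumed (the command-line values are hypothetical and torchvision must be installed):

    # e.g. appended to an example script's command line (pixel-based configs only):
    #     --center_crop 60
    #     --center_crop 60 80
    from torchrl.envs.transforms import CenterCrop

    # nargs="+" collects one or two ints; make_env_transforms unpacks them:
    CenterCrop(*[60])      # square 60x60 crop
    CenterCrop(*[60, 80])  # w=60, h=80
    # For the VideoRecorder, only the first value is used, as a square crop.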
4 changes: 2 additions & 2 deletions torchrl/trainers/helpers/recorder.py
@@ -42,9 +42,9 @@ def parser_recorder_args(parser: ArgumentParser) -> ArgumentParser:
         "--record_interval",
         "--record-interval",
         type=int,
-        default=50,
+        default=1000,
         help="number of batch collections in between two collections of validation rollouts. "
-        "Default=10000.",
+        "Default=1000.",
     )
     parser.add_argument(
         "--record_frames",
