Skip to content

Commit

Permalink
Disable Grayscale in arguments (pytorch#178)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 1, 2022
1 parent b3f4b30 commit a0888c0
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
1 change: 1 addition & 0 deletions test/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def test_recorder():
args = Namespace()
args.env_name = "ALE/Pong-v5"
args.env_task = ""
args.grayscale = True
args.env_library = "gym"
args.frame_skip = 1
args.center_crop = []
Expand Down
8 changes: 4 additions & 4 deletions torchrl/record/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
f"got {observation_trsf.ndimension()} instead"
)
observation_trsf = observation_trsf.permute(2, 0, 1)
if self.center_crop:
observation_trsf = center_crop_fn(
observation_trsf, [self.center_crop, self.center_crop]
)
if self.center_crop:
observation_trsf = center_crop_fn(
observation_trsf, [self.center_crop, self.center_crop]
)
self.obs.append(observation_trsf.cpu().to(torch.uint8))
return observation

Expand Down
10 changes: 9 additions & 1 deletion torchrl/trainers/helpers/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ def make_env_transforms(
if args.center_crop:
env.append_transform(CenterCrop(*args.center_crop))
env.append_transform(Resize(84, 84))
env.append_transform(GrayScale())
if args.grayscale:
env.append_transform(GrayScale())
env.append_transform(CatFrames(N=args.catframes, keys=["next_pixels"]))
if stats is None:
obs_stats = {"loc": 0.0, "scale": 1.0}
Expand Down Expand Up @@ -459,6 +460,13 @@ def parser_env_args(parser: ArgumentParser) -> ArgumentParser:
default=[],
help="center crop size.",
)
parser.add_argument(
"--no_grayscale",
"--no-grayscale",
action="store_false",
dest="grayscale",
help="Disables grayscale transform.",
)
parser.add_argument(
"--max_frames_per_traj",
"--max-frames-per-traj",
Expand Down

0 comments on commit a0888c0

Please sign in to comment.