Skip to content

Commit

Permalink
[Features] Make image_size a cfg param (pytorch#430)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolas-dufour authored Sep 12, 2022
1 parent 59007c3 commit 486a11b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
1 change: 1 addition & 0 deletions test/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def test_recorder():
args.record_frames = 24 // args.frame_skip
args.record_interval = 2
args.catframes = 4
args.image_size = 84
args.collector_devices = ["cpu"]

N = 8
Expand Down
5 changes: 3 additions & 2 deletions torchrl/trainers/helpers/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def make_env_transforms(
env.append_transform(ToTensorImage())
if cfg.center_crop:
env.append_transform(CenterCrop(*cfg.center_crop))
env.append_transform(Resize(84, 84))
env.append_transform(Resize(cfg.image_size, cfg.image_size))
if cfg.grayscale:
env.append_transform(GrayScale())
env.append_transform(FlattenObservation(first_dim=batch_dims))
Expand Down Expand Up @@ -441,4 +441,5 @@ class EnvConfig:
max_frames_per_traj: int = 1000
# Number of steps before a reset of the environment is called (if it has not been flagged as done before).
batch_transform: bool = False
# if True, the transforms will be applied to the parallel env, and not to each individual env.
# if True, the transforms will be applied to the parallel env, and not to each individual env.\
image_size: int = 84

0 comments on commit 486a11b

Please sign in to comment.