From 486a11bb839f1635863574c6eeef4af4edbba7e5 Mon Sep 17 00:00:00 2001 From: nicolas-dufour <33259879+nicolas-dufour@users.noreply.github.com> Date: Mon, 12 Sep 2022 18:36:19 +0100 Subject: [PATCH] [Features] Make image_size a cfg param (#430) --- test/test_trainer.py | 1 + torchrl/trainers/helpers/envs.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_trainer.py b/test/test_trainer.py index 7d710d66424..9196880a92b 100644 --- a/test/test_trainer.py +++ b/test/test_trainer.py @@ -243,6 +243,7 @@ def test_recorder(): args.record_frames = 24 // args.frame_skip args.record_interval = 2 args.catframes = 4 + args.image_size = 84 args.collector_devices = ["cpu"] N = 8 diff --git a/torchrl/trainers/helpers/envs.py b/torchrl/trainers/helpers/envs.py index 9b31a80ce0d..fb05b91ad06 100644 --- a/torchrl/trainers/helpers/envs.py +++ b/torchrl/trainers/helpers/envs.py @@ -127,7 +127,7 @@ def make_env_transforms( env.append_transform(ToTensorImage()) if cfg.center_crop: env.append_transform(CenterCrop(*cfg.center_crop)) - env.append_transform(Resize(84, 84)) + env.append_transform(Resize(cfg.image_size, cfg.image_size)) if cfg.grayscale: env.append_transform(GrayScale()) env.append_transform(FlattenObservation(first_dim=batch_dims)) @@ -441,4 +441,5 @@ class EnvConfig: max_frames_per_traj: int = 1000 # Number of steps before a reset of the environment is called (if it has not been flagged as done before). batch_transform: bool = False - # if True, the transforms will be applied to the parallel env, and not to each individual env. + # if True, the transforms will be applied to the parallel env, and not to each individual env.\ + image_size: int = 84