From 3781512439ece966d3c32d274233c0f8a89f80fb Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 18 Jul 2022 19:48:03 -0400
Subject: [PATCH] [BugFix]: Fix examples (#290)

---
 examples/EXAMPLES.md                       | 22 +++++++++++++
 examples/ddpg/config.yaml                  | 36 ++++++++++++++++++++++
 examples/ddpg/configs/cheetah.txt          | 10 ------
 examples/ddpg/configs/halfcheetah.txt      | 10 ------
 examples/ddpg/configs/humanoid.txt         | 18 -----------
 examples/ddpg/configs/humanoid_pixels.txt  | 22 -------------
 examples/ddpg/ddpg.py                      | 20 +++---------
 examples/dqn/config.yaml                   | 34 ++++++++++++++++++++
 examples/dqn/configs/pong.txt              | 20 ------------
 examples/dqn/configs/pong_smoketest.txt    | 24 ---------------
 examples/dqn/dqn.py                        | 20 +++---------
 examples/ppo/config.yaml                   | 26 ++++++++++++++++
 examples/ppo/configs/cheetah.txt           | 15 ---------
 examples/ppo/configs/cheetah_pixels.txt    | 12 --------
 examples/ppo/configs/cheetah_smoketest.txt | 17 ----------
 examples/ppo/configs/humanoid.txt          | 13 --------
 examples/ppo/ppo.py                        |  6 ++--
 examples/redq/config.yaml                  | 36 ++++++++++++++++++++++
 examples/redq/configs/cheetah.txt          | 16 ----------
 examples/redq/configs/humanoid.txt         | 19 ------------
 examples/redq/configs/humanoid_pixels.txt  | 23 --------------
 examples/redq/redq.py                      | 20 +++---------
 examples/sac/config.yaml                   | 36 ++++++++++++++++++++++
 examples/sac/configs/cheetah.txt           | 16 ----------
 examples/sac/configs/humanoid.txt          | 19 ------------
 examples/sac/sac.py                        | 20 +++---------
 torchrl/trainers/helpers/envs.py           |  9 ++++--
 torchrl/trainers/trainers.py               |  4 +--
 28 files changed, 219 insertions(+), 324 deletions(-)
 create mode 100644 examples/EXAMPLES.md
 create mode 100644 examples/ddpg/config.yaml
 delete mode 100644 examples/ddpg/configs/cheetah.txt
 delete mode 100644 examples/ddpg/configs/halfcheetah.txt
 delete mode 100644 examples/ddpg/configs/humanoid.txt
 delete mode 100644 examples/ddpg/configs/humanoid_pixels.txt
 create mode 100644 examples/dqn/config.yaml
 delete mode 100644 examples/dqn/configs/pong.txt
 delete mode 100644 examples/dqn/configs/pong_smoketest.txt
 create mode 100644 examples/ppo/config.yaml
 delete mode 100644 examples/ppo/configs/cheetah.txt
 delete mode 100644 examples/ppo/configs/cheetah_pixels.txt
 delete mode 100644 examples/ppo/configs/cheetah_smoketest.txt
 delete mode 100644 examples/ppo/configs/humanoid.txt
 create mode 100644 examples/redq/config.yaml
 delete mode 100644 examples/redq/configs/cheetah.txt
 delete mode 100644 examples/redq/configs/humanoid.txt
 delete mode 100644 examples/redq/configs/humanoid_pixels.txt
 create mode 100644 examples/sac/config.yaml
 delete mode 100644 examples/sac/configs/cheetah.txt
 delete mode 100644 examples/sac/configs/humanoid.txt

diff --git a/examples/EXAMPLES.md b/examples/EXAMPLES.md
new file mode 100644
index 00000000000..ec91746a5c3
--- /dev/null
+++ b/examples/EXAMPLES.md
@@ -0,0 +1,22 @@
+# Examples
+
+We provide examples to train the following algorithms:
+- [DQN](dqn/dqn.py)
+- [DDPG](ddpg/ddpg.py)
+- [SAC](sac/sac.py)
+- [REDQ](redq/redq.py)
+- [PPO](ppo/ppo.py)
+
+To run these examples, make sure you have installed hydra:
+```
+pip install hydra-core
+```
+
+Then, go to the directory that interests you and run
+```
+python sac.py
+```
+or similar. Hyperparameters can be easily changed by passing overrides to hydra:
+```
+python sac.py frames_per_batch=63
+```
diff --git a/examples/ddpg/config.yaml b/examples/ddpg/config.yaml
new file mode 100644
index 00000000000..5ad3912c0ef
--- /dev/null
+++ b/examples/ddpg/config.yaml
@@ -0,0 +1,36 @@
+env_name: HalfCheetah-v4
+env_task: ""
+env_library: gym
+async_collection: 1
+record_video: 0
+normalize_rewards_online: 1
+normalize_rewards_online_scale: 5
+frame_skip: 1
+frames_per_batch: 1024
+optim_steps_per_batch: 128
+batch_size: 256
+total_frames: 1000000
+prb: 1
+lr: 3e-4
+ou_exploration: 1
+multi_step: 1
+init_random_frames: 25000
+activation: elu
+gSDE: 0
+from_pixels: 0
+#collector_devices: [cuda:1,cuda:1,cuda:1,cuda:1]
+collector_devices: [cpu,cpu,cpu,cpu]
+env_per_collector: 8
+num_workers: 32
+lr_scheduler: ""
+value_network_update_interval: 200
+record_interval: 10
+max_frames_per_traj: -1
+weight_decay: 0.0
+annealing_frames: 1000000
+init_env_steps: 10000
+record_frames: 10000
+loss_function: smooth_l1
+batch_transform: 1
+buffer_prefetch: 64
+norm_stats: 1
diff --git a/examples/ddpg/configs/cheetah.txt b/examples/ddpg/configs/cheetah.txt
deleted file mode 100644
index 010b65e3833..00000000000
--- a/examples/ddpg/configs/cheetah.txt
+++ /dev/null
@@ -1,10 +0,0 @@
-env_name=cheetah
-env_task=run
-env_library=dm_control
-async_collection
-record_video
-normalize_rewards_online
-frames_per_batch=256
-frame_skip=4
-optim_steps_per_batch=4
-batch_size=64
diff --git a/examples/ddpg/configs/halfcheetah.txt b/examples/ddpg/configs/halfcheetah.txt
deleted file mode 100644
index 55f2b59e85e..00000000000
--- a/examples/ddpg/configs/halfcheetah.txt
+++ /dev/null
@@ -1,10 +0,0 @@
-env_name=HalfCheetah-v2
-env_task=
-env_library=gym
-async_collection
-record_video
-normalize_rewards_online
-frames_per_batch=256
-frame_skip=4
-optim_steps_per_batch=4
-batch_size=64
diff --git a/examples/ddpg/configs/humanoid.txt b/examples/ddpg/configs/humanoid.txt
deleted file mode 100644
index 3dc45568499..00000000000
--- a/examples/ddpg/configs/humanoid.txt
+++ /dev/null
@@ -1,18 +0,0 @@
-env_name=humanoid
-env_task=walk
-env_library=dm_control
-async_collection
-record_video
-prb
-normalize_rewards_online
-normalize_rewards_online_scale=10
-exp_name=humanoid
-
-num_workers=8
-env_per_collector=2
-
-frame_skip=2
-frames_per_batch=500
-optim_steps_per_batch=80
-batch_size=128
-total_frames=5000000
diff --git a/examples/ddpg/configs/humanoid_pixels.txt b/examples/ddpg/configs/humanoid_pixels.txt
deleted file mode 100644
index 19ef5a91629..00000000000
--- a/examples/ddpg/configs/humanoid_pixels.txt
+++ /dev/null
@@ -1,22 +0,0 @@
-env_name=humanoid
-env_task=walk
-env_library=dm_control
-record_video
-prb
-exp_name=humanoid
-
-num_workers=4
-env_per_collector=1
-
-frame_skip=2
-frames_per_batch=500
-optim_steps_per_batch=80
-batch_size=128
-total_frames=5000000
-
-from_pixels
-activation=elu
-lr=0.0002
-weight_decay=2e-5
-
-catframes=4
diff --git a/examples/ddpg/ddpg.py b/examples/ddpg/ddpg.py
index 196930471ae..52921524a88 100644
--- a/examples/ddpg/ddpg.py
+++ b/examples/ddpg/ddpg.py
@@ -66,7 +66,7 @@
 }
 
 
-@hydra.main(version_base=None, config_path=None, config_name="config")
+@hydra.main(version_base=None, config_path=".", config_name="config")
 def main(cfg: "DictConfig"):
     from torchrl.trainers.loggers import TensorboardLogger
 
@@ -96,7 +96,9 @@ def main(cfg: "DictConfig"):
     if not cfg.vecnorm and cfg.norm_stats:
         proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)()
         stats = get_stats_random_rollout(
-            cfg, proof_env, key="next_pixels" if cfg.from_pixels else None
+            cfg,
+            proof_env,
+            key="next_pixels" if cfg.from_pixels else "next_observation_vector",
         )
         # make sure proof_env is closed
         proof_env.close()
@@ -199,20 +201,6 @@ def main(cfg: "DictConfig"):
         cfg,
     )
 
-    def select_keys(batch):
-        return batch.select(
-            "reward",
-            "done",
-            "steps_to_next_obs",
-            "pixels",
-            "next_pixels",
-            "observation_vector",
-            "next_observation_vector",
-            "action",
-        )
-
-    trainer.register_op("batch_process", select_keys)
-
     final_seed = collector.set_seed(cfg.seed)
     print(f"init seed: {cfg.seed}, final seed: {final_seed}")
 
diff --git a/examples/dqn/config.yaml b/examples/dqn/config.yaml
new file mode 100644
index 00000000000..bc475863db1
--- /dev/null
+++ b/examples/dqn/config.yaml
@@ -0,0 +1,34 @@
+env_name: ALE/Pong-v5
+env_task: ""
+env_library: gym
+async_collection: 1
+record_video: 1
+normalize_rewards_online: 1
+normalize_rewards_online_scale: 5
+frame_skip: 4
+noops: 30
+frames_per_batch: 1024
+optim_steps_per_batch: 128
+batch_size: 256
+total_frames: 1000000
+prb: 1
+lr: 3e-4
+multi_step: 1
+init_random_frames: 25000
+from_pixels: 1
+#collector_devices: [cuda:1,cuda:1,cuda:1,cuda:1]
+collector_devices: [cpu,cpu,cpu,cpu]
+env_per_collector: 8
+num_workers: 32
+lr_scheduler: ""
+value_network_update_interval: 200
+record_interval: 10
+max_frames_per_traj: -1
+weight_decay: 0.0
+annealing_frames: 1000000
+init_env_steps: 10000
+record_frames: 50000
+loss_function: smooth_l1
+batch_transform: 1
+buffer_prefetch: 64
+catframes: 4
diff --git a/examples/dqn/configs/pong.txt b/examples/dqn/configs/pong.txt
deleted file mode 100644
index 65502156f3b..00000000000
--- a/examples/dqn/configs/pong.txt
+++ /dev/null
@@ -1,20 +0,0 @@
-frames_per_batch=256
-frame_skip=4
-optim_steps_per_batch=4
-batch_size=64
-env_library=gym
-env_name=ALE/Pong-v5
-noops=30
-max_frames_per_traj=-1
-exp_name=pong
-record_interval=10000
-async_collection
-distributional
-prb
-multi_step
-annealing_frames=50000000
-record_frames=50000
-normalize_rewards_online
-from_pixels
-record_video
-catframes=4
diff --git a/examples/dqn/configs/pong_smoketest.txt b/examples/dqn/configs/pong_smoketest.txt
deleted file mode 100644
index fa43fbda879..00000000000
--- a/examples/dqn/configs/pong_smoketest.txt
+++ /dev/null
@@ -1,24 +0,0 @@
-frames_per_batch=32
-frame_skip=4
-optim_steps_per_batch=3
-env_library=gym
-env_name=ALE/Pong-v5
-noops=30
-max_frames_per_traj=-1
-exp_name=pong
-record_interval=23
-batch_size=32
-async_collection
-distributional
-prb
-multi_step
-annealing_frames=500
-total_frames=500
-record_frames=30
-normalize_rewards_online
-from_pixels
-record_video
-num_workers=4
-env_per_collector=2
-init_random_frames=7
-catframes=4
diff --git a/examples/dqn/dqn.py b/examples/dqn/dqn.py
index 445b7f8b77d..cb26f9884e2 100644
--- a/examples/dqn/dqn.py
+++ b/examples/dqn/dqn.py
@@ -56,7 +56,7 @@
 cs.store(name="config", node=Config)
 
 
-@hydra.main(version_base=None, config_path=None, config_name="config")
+@hydra.main(version_base=None, config_path=".", config_name="config")
 def main(cfg: "DictConfig"):
     from torchrl.trainers.loggers import TensorboardLogger
 
@@ -87,7 +87,9 @@ def main(cfg: "DictConfig"):
     if not cfg.vecnorm and cfg.norm_stats:
         proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)()
         stats = get_stats_random_rollout(
-            cfg, proof_env, key="next_pixels" if cfg.from_pixels else None
+            cfg,
+            proof_env,
+            key="next_pixels" if cfg.from_pixels else "next_observation_vector",
         )
         # make sure proof_env is closed
         proof_env.close()
 
@@ -169,20 +171,6 @@ def main(cfg: "DictConfig"):
         cfg,
     )
 
-    def select_keys(batch):
-        return batch.select(
-            "reward",
-            "done",
-            "steps_to_next_obs",
-            "pixels",
-            "next_pixels",
-            "observation_vector",
-            "next_observation_vector",
-            "action",
-        )
-
-    trainer.register_op("batch_process", select_keys)
-
     final_seed = collector.set_seed(cfg.seed)
     print(f"init seed: {cfg.seed}, final seed: {final_seed}")
 
diff --git a/examples/ppo/config.yaml b/examples/ppo/config.yaml
new file mode 100644
index 00000000000..b8167c519a3
--- /dev/null
+++ b/examples/ppo/config.yaml
@@ -0,0 +1,26 @@
+env_name: HalfCheetah-v4
+env_task: ""
+env_library: gym
+async_collection: 0
+record_video: 1
+normalize_rewards_online: 1
+normalize_rewards_online_scale: 5
+frame_skip: 1
+frames_per_batch: 1000
+optim_steps_per_batch: 10
+batch_size: 256
+total_frames: 1000000
+lr: 3e-4
+from_pixels: 0
+#collector_devices: [cuda:1]
+collector_devices: [cpu]
+env_per_collector: 4
+num_workers: 4
+lr_scheduler: ""
+record_interval: 100
+max_frames_per_traj: -1
+weight_decay: 0.0
+init_env_steps: 10000
+record_frames: 50000
+loss_function: smooth_l1
+batch_transform: 1
diff --git a/examples/ppo/configs/cheetah.txt b/examples/ppo/configs/cheetah.txt
deleted file mode 100644
index 0928c576093..00000000000
--- a/examples/ppo/configs/cheetah.txt
+++ /dev/null
@@ -1,15 +0,0 @@
-env_name=cheetah
-env_task=run
-env_library=dm_control
-optim_steps_per_batch=10
-lmbda=0.95
-normalize_rewards_online
-record_video
-max_frames_per_traj=1000
-record_interval=200
-lr=3e-4
-tanh_loc
-entropy_coef=0.1
-clip_norm=1000.0
-frames_per_batch=3200
-frame_skip=4
diff --git a/examples/ppo/configs/cheetah_pixels.txt b/examples/ppo/configs/cheetah_pixels.txt
deleted file mode 100644
index ab8eee0d187..00000000000
--- a/examples/ppo/configs/cheetah_pixels.txt
+++ /dev/null
@@ -1,12 +0,0 @@
-env_name=cheetah
-env_task=run
-env_library=dm_control
-optim_steps_per_batch=10
-lmbda=0.95
-normalize_rewards_online
-record_video
-max_frames_per_traj=1000
-record_interval=200
-tanh_loc
-init_with_lag
-catframes=4
diff --git a/examples/ppo/configs/cheetah_smoketest.txt b/examples/ppo/configs/cheetah_smoketest.txt
deleted file mode 100644
index 3c63618801b..00000000000
--- a/examples/ppo/configs/cheetah_smoketest.txt
+++ /dev/null
@@ -1,17 +0,0 @@
-env_name=cheetah
-env_task=run
-env_library=dm_control
-optim_steps_per_batch=3
-frames_per_batch=500
-lamda=0.95
-normalize_rewards_online
-record_video
-max_frames_per_traj=100
-record_interval=4
-lr=3e-4
-tanh_loc
-init_with_lag
-frame_skip=4
-num_workers=4
-env_per_collector=2
-total_frames=5000
diff --git a/examples/ppo/configs/humanoid.txt b/examples/ppo/configs/humanoid.txt
deleted file mode 100644
index 0a736bead8f..00000000000
--- a/examples/ppo/configs/humanoid.txt
+++ /dev/null
@@ -1,13 +0,0 @@
-env_name=humanoid
-env_task=walk
-env_library=dm_control
-optim_steps_per_batch=10
-lmbda=0.95
-normalize_rewards_online
-record_video
-max_frames_per_traj=1000
-record_interval=200
-lr=3e-4
-entropy_coef=1e-4
-frame_skip=4
-tanh_loc
diff --git a/examples/ppo/ppo.py b/examples/ppo/ppo.py
index 0aa9b093d7e..eb524618a31 100644
--- a/examples/ppo/ppo.py
+++ b/examples/ppo/ppo.py
@@ -51,7 +51,7 @@
 cs.store(name="config", node=Config)
 
 
-@hydra.main(version_base=None, config_path=None, config_name="config")
+@hydra.main(version_base=None, config_path=".", config_name="config")
 def main(cfg: "DictConfig"):
     from torchrl.trainers.loggers import TensorboardLogger
 
@@ -81,7 +81,9 @@ def main(cfg: "DictConfig"):
"DictConfig"): if not cfg.vecnorm and cfg.norm_stats: proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)() stats = get_stats_random_rollout( - cfg, proof_env, key="next_pixels" if cfg.from_pixels else None + cfg, + proof_env, + key="next_pixels" if cfg.from_pixels else "next_observation_vector", ) # make sure proof_env is closed proof_env.close() diff --git a/examples/redq/config.yaml b/examples/redq/config.yaml new file mode 100644 index 00000000000..e595c3db4ff --- /dev/null +++ b/examples/redq/config.yaml @@ -0,0 +1,36 @@ +env_name: HalfCheetah-v4 +env_task: "" +env_library: gym +async_collection: 1 +record_video: 0 +normalize_rewards_online: 1 +normalize_rewards_online_scale: 5 +frame_skip: 1 +frames_per_batch: 1024 +optim_steps_per_batch: 1024 +batch_size: 256 +total_frames: 1000000 +prb: 1 +lr: 3e-4 +ou_exploration: 1 +multi_step: 1 +init_random_frames: 25000 +activation: elu +gSDE: 0 +from_pixels: 0 +#collector_devices: [cuda:1,cuda:1,cuda:1,cuda:1] +collector_devices: [cpu,cpu] +env_per_collector: 1 +num_workers: 2 +lr_scheduler: "" +value_network_update_interval: 200 +record_interval: 10 +max_frames_per_traj: -1 +weight_decay: 0.0 +annealing_frames: 1000000 +init_env_steps: 10000 +record_frames: 10000 +loss_function: smooth_l1 +batch_transform: 1 +buffer_prefetch: 64 +norm_stats: 1 diff --git a/examples/redq/configs/cheetah.txt b/examples/redq/configs/cheetah.txt deleted file mode 100644 index 5f6547b9320..00000000000 --- a/examples/redq/configs/cheetah.txt +++ /dev/null @@ -1,16 +0,0 @@ -env_name=cheetah -env_task=run -env_library=dm_control -async_collection -record_video -frame_skip=4 -frames_per_batch=256 -optim_steps_per_batch=10 -batch_size=128 -prb -exp_name=cheetah -tanh_loc -total_frames=5000000 -num_workers=8 -env_per_collector=8 -collector_devices=cuda:0 diff --git a/examples/redq/configs/humanoid.txt b/examples/redq/configs/humanoid.txt deleted file mode 100644 index a7cddfcdd22..00000000000 --- a/examples/redq/configs/humanoid.txt +++ /dev/null @@ -1,19 +0,0 @@ -env_name=humanoid -env_task=walk -env_library=dm_control -async_collection -record_video -prb -normalize_rewards_online -normalize_rewards_online_scale=10 -exp_name=humanoid -tanh_loc - -num_workers=8 -env_per_collector=2 - -frame_skip=2 -frames_per_batch=500 -optim_steps_per_batch=80 -batch_size=128 -total_frames=5000000 diff --git a/examples/redq/configs/humanoid_pixels.txt b/examples/redq/configs/humanoid_pixels.txt deleted file mode 100644 index 2dbfbf1a3b4..00000000000 --- a/examples/redq/configs/humanoid_pixels.txt +++ /dev/null @@ -1,23 +0,0 @@ -env_name=humanoid -env_task=walk -env_library=dm_control -record_video -prb -exp_name=humanoid -tanh_loc - -num_workers=4 -env_per_collector=1 - -frame_skip=2 -frames_per_batch=500 -optim_steps_per_batch=80 -batch_size=128 -total_frames=5000000 - -from_pixels -activation=elu -lr=0.0002 -weight_decay=2e-5 - -catframes=4 diff --git a/examples/redq/redq.py b/examples/redq/redq.py index 3d14615ca94..4cac0c07d62 100644 --- a/examples/redq/redq.py +++ b/examples/redq/redq.py @@ -67,7 +67,7 @@ } -@hydra.main(version_base=None, config_path=None, config_name="config") +@hydra.main(version_base=None, config_path=".", config_name="config") def main(cfg: "DictConfig"): from torchrl.trainers.loggers import ( TensorboardLogger, @@ -99,7 +99,9 @@ def main(cfg: "DictConfig"): if not cfg.vecnorm and cfg.norm_stats: proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)() stats = get_stats_random_rollout( - cfg, proof_env, key="next_pixels" 
+            cfg,
+            proof_env,
+            key="next_pixels" if cfg.from_pixels else "next_observation_vector",
         )
         # make sure proof_env is closed
         proof_env.close()
@@ -201,20 +203,6 @@ def main(cfg: "DictConfig"):
         cfg,
     )
 
-    def select_keys(batch):
-        return batch.select(
-            "reward",
-            "done",
-            "steps_to_next_obs",
-            "pixels",
-            "next_pixels",
-            "observation_vector",
-            "next_observation_vector",
-            "action",
-        )
-
-    trainer.register_op("batch_process", select_keys)
-
     final_seed = collector.set_seed(cfg.seed)
     print(f"init seed: {cfg.seed}, final seed: {final_seed}")
 
diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml
new file mode 100644
index 00000000000..a9661541772
--- /dev/null
+++ b/examples/sac/config.yaml
@@ -0,0 +1,36 @@
+env_name: HalfCheetah-v4
+env_task: ""
+env_library: gym
+async_collection: 1
+record_video: 0
+normalize_rewards_online: 1
+normalize_rewards_online_scale: 5
+frame_skip: 1
+frames_per_batch: 1024
+optim_steps_per_batch: 128
+batch_size: 256
+total_frames: 1000000
+prb: 1
+lr: 3e-4
+ou_exploration: 1
+multi_step: 1
+init_random_frames: 25000
+activation: elu
+gSDE: 0
+from_pixels: 0
+#collector_devices: [cuda:1,cuda:1,cuda:1,cuda:1]
+collector_devices: [cpu,cpu]
+env_per_collector: 1
+num_workers: 2
+lr_scheduler: ""
+value_network_update_interval: 200
+record_interval: 10
+max_frames_per_traj: -1
+weight_decay: 0.0
+annealing_frames: 1000000
+init_env_steps: 10000
+record_frames: 10000
+loss_function: smooth_l1
+batch_transform: 1
+buffer_prefetch: 64
+norm_stats: 1
diff --git a/examples/sac/configs/cheetah.txt b/examples/sac/configs/cheetah.txt
deleted file mode 100644
index 0190a35bb15..00000000000
--- a/examples/sac/configs/cheetah.txt
+++ /dev/null
@@ -1,16 +0,0 @@
-env_name=cheetah
-env_task=run
-env_library=dm_control
-async_collection
-record_video
-frame_skip=4
-frames_per_batch=256
-optim_steps_per_batch=4
-batch_size=64
-prb
-exp_name=cheetah
-tanh_loc
-total_frames=5000000
-num_workers=8
-env_per_collector=8
-collector_devices=cuda:1
diff --git a/examples/sac/configs/humanoid.txt b/examples/sac/configs/humanoid.txt
deleted file mode 100644
index a7cddfcdd22..00000000000
--- a/examples/sac/configs/humanoid.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-env_name=humanoid
-env_task=walk
-env_library=dm_control
-async_collection
-record_video
-prb
-normalize_rewards_online
-normalize_rewards_online_scale=10
-exp_name=humanoid
-tanh_loc
-
-num_workers=8
-env_per_collector=2
-
-frame_skip=2
-frames_per_batch=500
-optim_steps_per_batch=80
-batch_size=128
-total_frames=5000000
diff --git a/examples/sac/sac.py b/examples/sac/sac.py
index 70e39c3d2ce..f66aacbe7e6 100644
--- a/examples/sac/sac.py
+++ b/examples/sac/sac.py
@@ -67,7 +67,7 @@
 }
 
 
-@hydra.main(version_base=None, config_path=None, config_name="config")
+@hydra.main(version_base=None, config_path=".", config_name="config")
 def main(cfg: "DictConfig"):
     from torchrl.trainers.loggers import TensorboardLogger
 
@@ -97,7 +97,9 @@ def main(cfg: "DictConfig"):
     if not cfg.vecnorm and cfg.norm_stats:
         proof_env = transformed_env_constructor(cfg=cfg, use_env_creator=False)()
         stats = get_stats_random_rollout(
-            cfg, proof_env, key="next_pixels" if cfg.from_pixels else None
+            cfg,
+            proof_env,
+            key="next_pixels" if cfg.from_pixels else "next_observation_vector",
         )
         # make sure proof_env is closed
         proof_env.close()
@@ -195,20 +197,6 @@ def main(cfg: "DictConfig"):
         cfg,
     )
 
-    def select_keys(batch):
-        return batch.select(
-            "reward",
-            "done",
-            "steps_to_next_obs",
-            "pixels",
-            "next_pixels",
-            "observation_vector",
"observation_vector", - "next_observation_vector", - "action", - ) - - trainer.register_op("batch_process", select_keys) - final_seed = collector.set_seed(cfg.seed) print(f"init seed: {cfg.seed}, final seed: {final_seed}") diff --git a/torchrl/trainers/helpers/envs.py b/torchrl/trainers/helpers/envs.py index 15f46b8e2da..9dd0ab4e876 100644 --- a/torchrl/trainers/helpers/envs.py +++ b/torchrl/trainers/helpers/envs.py @@ -348,8 +348,13 @@ def get_stats_random_rollout( while n < cfg.init_env_steps: _td_stats = proof_environment.rollout(max_steps=cfg.init_env_steps) n += _td_stats.numel() - td_stats.append(_td_stats.to_tensordict().select(key).cpu()) - del _td_stats + _td_stats_select = _td_stats.to_tensordict().select(key).cpu() + if not len(list(_td_stats_select.keys())): + raise RuntimeError( + f"key {key} not found in tensordict with keys {list(_td_stats.keys())}" + ) + td_stats.append(_td_stats_select) + del _td_stats, _td_stats_select td_stats = torch.cat(td_stats, 0) if key is None: diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 4b9c9abb283..a8003adefa5 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -441,8 +441,8 @@ def _log(self, log_pbar=False, **kwargs) -> None: else: _log = False method = LOGGER_METHODS.get(key, "log_scalar") - if _log and self.log_scalar is not None: - getattr(self.log_scalar, method)(key, item, step=collected_frames) + if _log and self.logger is not None: + getattr(self.logger, method)(key, item, step=collected_frames) if method == "log_scalar" and self.progress_bar and log_pbar: if isinstance(item, torch.Tensor): item = item.item()