diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 4587be88ddc..bcc688b0a6d 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -36,7 +36,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/de optim.pretrain_gradient_steps=55 \ optim.updates_per_episode=3 \ optim.warmup_steps=10 \ - optim.device=cuda:0 \ logger.backend= \ env.backend=gymnasium \ env.name=HalfCheetah-v4 @@ -44,16 +43,13 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/de optim.pretrain_gradient_steps=55 \ optim.updates_per_episode=3 \ optim.warmup_steps=10 \ - optim.device=cuda:0 \ env.backend=gymnasium \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iql/iql_offline.py \ optim.gradient_steps=55 \ - optim.device=cuda:0 \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_offline.py \ optim.gradient_steps=55 \ - optim.device=cuda:0 \ logger.backend= # ==================================================================================== # @@ -86,8 +82,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dd optim.batch_size=10 \ collector.frames_per_batch=16 \ collector.env_per_collector=2 \ - collector.device=cuda:0 \ - network.device=cuda:0 \ optim.utd_ratio=1 \ replay_buffer.size=120 \ env.name=Pendulum-v1 \ @@ -112,7 +106,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dq collector.init_random_frames=10 \ collector.frames_per_batch=16 \ buffer.batch_size=10 \ - device=cuda:0 \ loss.num_updates=1 \ logger.backend= \ buffer.buffer_size=120 @@ -122,7 +115,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cq optim.batch_size=10 \ collector.frames_per_batch=16 \ collector.env_per_collector=2 \ - collector.device=cuda:0 \ replay_buffer.size=120 \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/redq/redq.py \ @@ -131,7 +123,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/re collector.init_random_frames=10 \ collector.frames_per_batch=16 \ collector.env_per_collector=2 \ - collector.device=cuda:0 \ buffer.batch_size=10 \ optim.steps_per_batch=1 \ logger.record_video=True \ @@ -143,22 +134,18 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/sa collector.init_random_frames=10 \ collector.frames_per_batch=16 \ collector.env_per_collector=2 \ - collector.device=cuda:0 \ optim.batch_size=10 \ optim.utd_ratio=1 \ replay_buffer.size=120 \ env.name=Pendulum-v1 \ - network.device=cuda:0 \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/discrete_sac/discrete_sac.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ collector.frames_per_batch=16 \ collector.env_per_collector=1 \ - collector.device=cuda:0 \ optim.batch_size=10 \ optim.utd_ratio=1 \ - network.device=cuda:0 \ optim.batch_size=10 \ optim.utd_ratio=1 \ replay_buffer.size=120 \ @@ -185,9 +172,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td collector.frames_per_batch=16 \ collector.num_workers=4 \ collector.env_per_collector=2 \ - collector.device=cuda:0 \ - collector.device=cuda:0 \ - network.device=cuda:0 \ logger.mode=offline \ env.name=Pendulum-v1 \ logger.backend= @@ 
-196,8 +180,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq optim.batch_size=10 \ collector.frames_per_batch=16 \ env.train_num_envs=2 \ - optim.device=cuda:0 \ - collector.device=cuda:0 \ logger.mode=offline \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iql/discrete_iql.py \ @@ -205,8 +187,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq optim.batch_size=10 \ collector.frames_per_batch=16 \ env.train_num_envs=2 \ - optim.device=cuda:0 \ - collector.device=cuda:0 \ logger.mode=offline \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_online.py \ @@ -214,8 +194,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq optim.batch_size=10 \ collector.frames_per_batch=16 \ env.train_num_envs=2 \ - collector.device=cuda:0 \ - optim.device=cuda:0 \ logger.mode=offline \ logger.backend= @@ -238,8 +216,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dd optim.batch_size=10 \ collector.frames_per_batch=16 \ collector.env_per_collector=1 \ - collector.device=cuda:0 \ - network.device=cuda:0 \ optim.utd_ratio=1 \ replay_buffer.size=120 \ env.name=Pendulum-v1 \ @@ -251,7 +227,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dq collector.init_random_frames=10 \ collector.frames_per_batch=16 \ buffer.batch_size=10 \ - device=cuda:0 \ loss.num_updates=1 \ logger.backend= \ buffer.buffer_size=120 @@ -262,7 +237,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/re collector.frames_per_batch=16 \ collector.env_per_collector=1 \ buffer.batch_size=10 \ - collector.device=cuda:0 \ optim.steps_per_batch=1 \ logger.record_video=True \ logger.record_frames=4 \ @@ -274,8 +248,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq collector.frames_per_batch=16 \ env.train_num_envs=1 \ logger.mode=offline \ - optim.device=cuda:0 \ - collector.device=cuda:0 \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_online.py \ collector.total_frames=48 \ @@ -283,8 +255,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cq collector.frames_per_batch=16 \ collector.env_per_collector=1 \ logger.mode=offline \ - optim.device=cuda:0 \ - collector.device=cuda:0 \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \ collector.total_frames=48 \ @@ -292,11 +262,9 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td collector.frames_per_batch=16 \ collector.num_workers=2 \ collector.env_per_collector=1 \ - collector.device=cuda:0 \ logger.mode=offline \ optim.batch_size=10 \ env.name=Pendulum-v1 \ - network.device=cuda:0 \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/multiagent/mappo_ippo.py \ collector.n_iters=2 \ diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index 7ad39ed43e5..775dcfe206d 100644 --- a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. 
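Note on the override removals above: the smoke-test script can drop the explicit device=cuda:0, collector.device=cuda:0 and network.device=cuda:0 arguments because the configs touched later in this diff now default to device: null, which each entry point resolves at runtime. A minimal sketch of why a null YAML value reaches the new fallback branch (OmegaConf is the config backend Hydra uses here):

from omegaconf import OmegaConf

cfg = OmegaConf.create("optim:\n  device: null")
assert cfg.optim.device is None  # None is caught by the new `device in ("", None)` branch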
import hydra from torchrl._utils import logger as torchrl_logger +from torchrl.record import VideoRecorder @hydra.main(config_path="", config_name="config_atari", version_base="1.1") @@ -104,6 +105,14 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create test environment test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True) + test_env.set_seed(0) + if cfg.logger.video: + test_env = test_env.insert_transform( + 0, + VideoRecorder( + logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"] + ), + ) test_env.eval() # Main loop diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py index 7b4a153e150..0276039058f 100644 --- a/sota-implementations/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import hydra from torchrl._utils import logger as torchrl_logger +from torchrl.record import VideoRecorder @hydra.main(config_path="", config_name="config_mujoco", version_base="1.1") @@ -89,7 +90,15 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create test environment - test_env = make_env(cfg.env.env_name, device) + test_env = make_env(cfg.env.env_name, device, from_pixels=cfg.logger.video) + test_env.set_seed(0) + if cfg.logger.video: + test_env = test_env.insert_transform( + 0, + VideoRecorder( + logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"] + ), + ) test_env.eval() # Main loop @@ -178,9 +187,10 @@ def main(cfg: "DictConfig"): # noqa: F821 # Get test rewards with torch.no_grad(), set_exploration_type(ExplorationType.MODE): - if ((i - 1) * frames_in_batch) // cfg.logger.test_interval < ( - i * frames_in_batch - ) // cfg.logger.test_interval: + prev_test_frame = ((i - 1) * frames_in_batch) // cfg.logger.test_interval + cur_test_frame = (i * frames_in_batch) // cfg.logger.test_interval + final = collected_frames >= collector.total_frames + if prev_test_frame < cur_test_frame or final: actor.eval() eval_start = time.time() test_rewards = eval_model( diff --git a/sota-implementations/a2c/config_atari.yaml b/sota-implementations/a2c/config_atari.yaml index 8c94f62fb93..dd0f43b52cb 100644 --- a/sota-implementations/a2c/config_atari.yaml +++ b/sota-implementations/a2c/config_atari.yaml @@ -16,6 +16,7 @@ logger: exp_name: Atari_Schulman17 test_interval: 40_000_000 num_test_episodes: 3 + video: False # Optim optim: diff --git a/sota-implementations/a2c/config_mujoco.yaml b/sota-implementations/a2c/config_mujoco.yaml index b30b7304f61..03a0bde32c5 100644 --- a/sota-implementations/a2c/config_mujoco.yaml +++ b/sota-implementations/a2c/config_mujoco.yaml @@ -15,6 +15,7 @@ logger: exp_name: Mujoco_Schulman17 test_interval: 1_000_000 num_test_episodes: 5 + video: False # Optim optim: diff --git a/sota-implementations/a2c/utils_atari.py b/sota-implementations/a2c/utils_atari.py index 0ddcd79123e..240ebac96d2 100644 --- a/sota-implementations/a2c/utils_atari.py +++ b/sota-implementations/a2c/utils_atari.py @@ -36,6 +36,8 @@ TanhNormal, ValueOperator, ) +from torchrl.record import VideoRecorder + # ==================================================================== # Environment utils @@ -201,6 +203,11 @@ def make_ppo_models(env_name): # -------------------------------------------------------------------- +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() + + def eval_model(actor, test_env, num_episodes=3): test_rewards = [] for _ in range(num_episodes): @@ -213,5 +220,6 @@ def eval_model(actor, test_env, 
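The video-logging changes above follow one pattern across the trainers: render to a "pixels" entry, wrap the evaluation env with a VideoRecorder transform when logger.video is set, and flush the recorder after each evaluation rollout. A minimal self-contained sketch of that pattern (CSVLogger and CartPole-v1 are illustrative stand-ins, not what the scripts use):

from torchrl.envs import GymEnv, TransformedEnv
from torchrl.record import VideoRecorder
from torchrl.record.loggers import CSVLogger

logger = CSVLogger(exp_name="demo", video_format="pt")
recorder = VideoRecorder(logger, tag="rendered/CartPole", in_keys=["pixels"])
env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True, pixels_only=False),
    recorder,  # the scripts insert this at index 0 of an existing transform stack
)
env.rollout(50)
recorder.dump()  # the utils modules do this via test_env.apply(dump_video)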
num_episodes=3): ) reward = td_test["next", "episode_reward"][td_test["next", "done"]] test_rewards = np.append(test_rewards, reward.cpu().numpy()) + test_env.apply(dump_video) del td_test return test_rewards.mean() diff --git a/sota-implementations/a2c/utils_mujoco.py b/sota-implementations/a2c/utils_mujoco.py index 50780a9d086..178678e4457 100644 --- a/sota-implementations/a2c/utils_mujoco.py +++ b/sota-implementations/a2c/utils_mujoco.py @@ -20,14 +20,20 @@ ) from torchrl.envs.libs.gym import GymEnv from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator +from torchrl.record import VideoRecorder + # ==================================================================== # Environment utils # -------------------------------------------------------------------- -def make_env(env_name="HalfCheetah-v4", device="cpu"): - env = GymEnv(env_name, device=device) +def make_env( + env_name="HalfCheetah-v4", device="cpu", from_pixels=False, pixels_only=False +): + env = GymEnv( + env_name, device=device, from_pixels=from_pixels, pixels_only=pixels_only + ) env = TransformedEnv(env) env.append_transform(RewardSum()) env.append_transform(StepCounter()) @@ -125,6 +131,11 @@ def make_ppo_models(env_name): # -------------------------------------------------------------------- +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() + + def eval_model(actor, test_env, num_episodes=3): test_rewards = [] for _ in range(num_episodes): @@ -137,5 +148,6 @@ def eval_model(actor, test_env, num_episodes=3): ) reward = td_test["next", "episode_reward"][td_test["next", "done"]] test_rewards = np.append(test_rewards, reward.cpu().numpy()) + test_env.apply(dump_video) del td_test return test_rewards.mean() diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py index 441cb3555e2..59b574090f9 100644 --- a/sota-implementations/cql/cql_offline.py +++ b/sota-implementations/cql/cql_offline.py @@ -20,6 +20,7 @@ from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + dump_video, log_metrics, make_continuous_cql_optimizer, make_continuous_loss, @@ -49,16 +50,25 @@ def main(cfg: "DictConfig"): # noqa: F821 # Set seeds torch.manual_seed(cfg.env.seed) np.random.seed(cfg.env.seed) - device = torch.device(cfg.optim.device) + device = cfg.optim.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) # Create env - train_env, eval_env = make_environment(cfg, cfg.logger.eval_envs) + train_env, eval_env = make_environment( + cfg, train_num_envs=1, eval_num_envs=cfg.logger.eval_envs, logger=logger + ) # Create replay buffer replay_buffer = make_offline_replay_buffer(cfg.replay_buffer) # Create agent model = make_cql_model(cfg, train_env, eval_env, device) + del train_env # Create loss loss_module, target_net_updater = make_continuous_loss(cfg.loss, model) @@ -144,6 +154,7 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_td = eval_env.rollout( max_steps=eval_steps, policy=model[0], auto_cast_to_device=True ) + eval_env.apply(dump_video) eval_reward = eval_td["next", "reward"].sum(1).mean().item() to_log["evaluation_reward"] = eval_reward diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py index a70f9091cb6..dc9bd512285 100644 --- a/sota-implementations/cql/cql_online.py +++ b/sota-implementations/cql/cql_online.py @@ -23,6 +23,7 @@ from torchrl.record.loggers import generate_exp_name, get_logger 
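The same device-resolution fallback is inlined in every entry point touched by this diff; written as a stand-alone helper it amounts to the following (the function name is illustrative, the scripts keep the logic inline):

import torch

def resolve_device(device=None):
    # a null or empty config value means: use cuda:0 when available, otherwise cpu
    if device in ("", None):
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    return torch.device(device)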
from utils import ( + dump_video, log_metrics, make_collector, make_continuous_cql_optimizer, @@ -54,13 +55,20 @@ def main(cfg: "DictConfig"): # noqa: F821 # Set seeds torch.manual_seed(cfg.env.seed) np.random.seed(cfg.env.seed) - device = torch.device(cfg.optim.device) + device = cfg.optim.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) # Create env train_env, eval_env = make_environment( cfg, cfg.env.train_num_envs, cfg.env.eval_num_envs, + logger=logger, ) # Create replay buffer @@ -99,12 +107,12 @@ def main(cfg: "DictConfig"): # noqa: F821 * cfg.optim.utd_ratio ) prb = cfg.replay_buffer.prb - eval_iter = cfg.logger.eval_iter frames_per_batch = cfg.collector.frames_per_batch - eval_rollout_steps = cfg.collector.max_frames_per_traj + evaluation_interval = cfg.logger.log_interval + eval_rollout_steps = cfg.logger.eval_steps sampling_start = time.time() - for tensordict in collector: + for i, tensordict in enumerate(collector): sampling_time = time.time() - sampling_start pbar.update(tensordict.numel()) # update weights of the inference policy @@ -191,7 +199,11 @@ def main(cfg: "DictConfig"): # noqa: F821 metrics_to_log["train/training_time"] = training_time # Evaluation - if abs(collected_frames % eval_iter) < frames_per_batch: + + prev_test_frame = ((i - 1) * frames_per_batch) // evaluation_interval + cur_test_frame = (i * frames_per_batch) // evaluation_interval + final = current_frames >= collector.total_frames + if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: with set_exploration_type(ExplorationType.MODE), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( @@ -202,6 +214,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() + eval_env.apply(dump_video) metrics_to_log["eval/reward"] = eval_reward metrics_to_log["eval/time"] = eval_time diff --git a/sota-implementations/cql/discrete_cql_config.yaml b/sota-implementations/cql/discrete_cql_config.yaml index 807479d45bd..644b8ec624e 100644 --- a/sota-implementations/cql/discrete_cql_config.yaml +++ b/sota-implementations/cql/discrete_cql_config.yaml @@ -30,6 +30,7 @@ logger: eval_steps: 200 mode: online eval_iter: 1000 + video: False # Buffer replay_buffer: @@ -41,7 +42,7 @@ replay_buffer: # Optimization optim: utd_ratio: 1 - device: cuda:0 + device: null lr: 1e-3 weight_decay: 0.0 batch_size: 256 diff --git a/sota-implementations/cql/discrete_cql_online.py b/sota-implementations/cql/discrete_cql_online.py index fd07684774d..4b6f14cd058 100644 --- a/sota-implementations/cql/discrete_cql_online.py +++ b/sota-implementations/cql/discrete_cql_online.py @@ -35,7 +35,13 @@ @hydra.main(version_base="1.1", config_path="", config_name="discrete_cql_config") def main(cfg: "DictConfig"): # noqa: F821 - device = torch.device(cfg.optim.device) + device = cfg.optim.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) # Create logger exp_name = generate_exp_name("DiscreteCQL", cfg.logger.exp_name) diff --git a/sota-implementations/cql/offline_config.yaml b/sota-implementations/cql/offline_config.yaml index 0047b74d14c..bf213d4e3c5 100644 --- a/sota-implementations/cql/offline_config.yaml +++ b/sota-implementations/cql/offline_config.yaml @@ -13,10 +13,12 @@ logger: project_name: torchrl_example_cql group_name: null exp_name: 
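The evaluation trigger rewritten in cql_online.py (and mirrored in discrete_sac and the DQN scripts further down) replaces the modulo check with interval counting plus a final-batch flag; as a pure function it reads roughly as follows (a sketch, names illustrative):

def should_evaluate(i, frames_per_batch, eval_interval, collected_frames, total_frames):
    prev_intervals = ((i - 1) * frames_per_batch) // eval_interval
    cur_intervals = (i * frames_per_batch) // eval_interval
    final = collected_frames >= total_frames
    # evaluate whenever a new interval boundary was crossed, and always on the last batch
    return (i >= 1 and prev_intervals < cur_intervals) or final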
cql_${replay_buffer.dataset} + # eval iter in gradient steps eval_iter: 5000 eval_steps: 1000 mode: online eval_envs: 5 + video: False # replay buffer replay_buffer: @@ -25,7 +27,7 @@ replay_buffer: # optimization optim: - device: cuda:0 + device: null actor_lr: 3e-4 critic_lr: 3e-4 weight_decay: 0.0 diff --git a/sota-implementations/cql/online_config.yaml b/sota-implementations/cql/online_config.yaml index 9b3e5b5bf24..00db1d6bb62 100644 --- a/sota-implementations/cql/online_config.yaml +++ b/sota-implementations/cql/online_config.yaml @@ -26,9 +26,9 @@ logger: group_name: null exp_name: cql_${env.name} log_interval: 5000 # record interval in frames - eval_steps: 1000 mode: online - eval_iter: 1000 + eval_steps: 1000 + video: False # Buffer replay_buffer: @@ -39,7 +39,7 @@ replay_buffer: # Optimization optim: utd_ratio: 1 - device: cuda:0 + device: null actor_lr: 3e-4 critic_lr: 3e-4 weight_decay: 0.0 diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py index 350b105b441..46b84ee434b 100644 --- a/sota-implementations/cql/utils.py +++ b/sota-implementations/cql/utils.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import functools + import torch.nn import torch.optim from tensordict.nn import TensorDictModule, TensorDictSequential @@ -37,6 +39,7 @@ ValueOperator, ) from torchrl.objectives import CQLLoss, DiscreteCQLLoss, SoftUpdate +from torchrl.record import VideoRecorder from torchrl.trainers.helpers.models import ACTIVATIONS @@ -45,16 +48,17 @@ # ----------------- -def env_maker(cfg, device="cpu"): +def env_maker(cfg, device="cpu", from_pixels=False): lib = cfg.env.backend if lib in ("gym", "gymnasium"): with set_gym_backend(lib): return GymEnv( - cfg.env.name, - device=device, + cfg.env.name, device=device, from_pixels=from_pixels, pixels_only=False ) elif lib == "dm_control": - env = DMControlEnv(cfg.env.name, cfg.env.task) + env = DMControlEnv( + cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False + ) return TransformedEnv( env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") ) @@ -75,25 +79,32 @@ def apply_env_transforms( return transformed_env -def make_environment(cfg, train_num_envs=1, eval_num_envs=1): +def make_environment(cfg, train_num_envs=1, eval_num_envs=1, logger=None): """Make environments for training and evaluation.""" + maker = functools.partial(env_maker, cfg) parallel_env = ParallelEnv( train_num_envs, - EnvCreator(lambda cfg=cfg: env_maker(cfg)), + EnvCreator(maker), serial_for_single=True, ) parallel_env.set_seed(cfg.env.seed) train_env = apply_env_transforms(parallel_env) + maker = functools.partial(env_maker, cfg, from_pixels=cfg.logger.video) eval_env = TransformedEnv( ParallelEnv( eval_num_envs, - EnvCreator(lambda cfg=cfg: env_maker(cfg)), + EnvCreator(maker), serial_for_single=True, ), train_env.transform.clone(), ) + eval_env.set_seed(0) + if cfg.logger.video: + eval_env = eval_env.insert_transform( + 0, VideoRecorder(logger=logger, tag="rendered", in_keys=["pixels"]) + ) return train_env, eval_env @@ -373,3 +384,8 @@ def log_metrics(logger, metrics, step): if logger is not None: for metric_name, metric_value in metrics.items(): logger.log_scalar(metric_name, metric_value, step) + + +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() diff --git a/sota-implementations/ddpg/config.yaml b/sota-implementations/ddpg/config.yaml index 7d17038330b..43cb5093c09 100644 
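make_environment in cql/utils.py above also swaps the lambda closures handed to EnvCreator for functools.partial objects and only turns pixel rendering on for the evaluation workers. The shape of that change, reduced to a sketch (env_maker stands in for the per-algorithm helpers):

import functools
from torchrl.envs import EnvCreator, ParallelEnv

def build_envs(cfg, env_maker, train_num_envs=1, eval_num_envs=1, video=False):
    train_env = ParallelEnv(
        train_num_envs,
        EnvCreator(functools.partial(env_maker, cfg)),
        serial_for_single=True,
    )
    eval_env = ParallelEnv(
        eval_num_envs,
        EnvCreator(functools.partial(env_maker, cfg, from_pixels=video)),
        serial_for_single=True,
    )
    return train_env, eval_env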
--- a/sota-implementations/ddpg/config.yaml +++ b/sota-implementations/ddpg/config.yaml @@ -32,12 +32,12 @@ optim: weight_decay: 1e-4 batch_size: 256 target_update_polyak: 0.995 + device: null # network network: hidden_sizes: [256, 256] activation: relu - device: "cuda:0" noise_type: "ou" # ou or gaussian # logging @@ -48,3 +48,5 @@ logger: exp_name: ${env.name}_DDPG mode: online eval_iter: 25000 + video: False + num_eval_envs: 1 diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py index e8313e6c342..eb0b88c26f7 100644 --- a/sota-implementations/ddpg/ddpg.py +++ b/sota-implementations/ddpg/ddpg.py @@ -23,6 +23,7 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + dump_video, log_metrics, make_collector, make_ddpg_agent, @@ -35,7 +36,13 @@ @hydra.main(version_base="1.1", config_path="", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 - device = torch.device(cfg.network.device) + device = cfg.optim.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) # Create logger exp_name = generate_exp_name("DDPG", cfg.logger.exp_name) @@ -58,7 +65,7 @@ def main(cfg: "DictConfig"): # noqa: F821 np.random.seed(cfg.env.seed) # Create environments - train_env, eval_env = make_environment(cfg) + train_env, eval_env = make_environment(cfg, logger=logger) # Create agent model, exploration_policy = make_ddpg_agent(cfg, train_env, eval_env, device) @@ -186,6 +193,7 @@ def main(cfg: "DictConfig"): # noqa: F821 auto_cast_to_device=True, break_when_any_done=True, ) + eval_env.apply(dump_video) eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward diff --git a/sota-implementations/ddpg/utils.py b/sota-implementations/ddpg/utils.py index 4006fc27b38..45c6da7a342 100644 --- a/sota-implementations/ddpg/utils.py +++ b/sota-implementations/ddpg/utils.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import functools + import torch from torch import nn, optim @@ -34,6 +36,7 @@ from torchrl.objectives import SoftUpdate from torchrl.objectives.ddpg import DDPGLoss +from torchrl.record import VideoRecorder # ==================================================================== @@ -41,16 +44,17 @@ # ----------------- -def env_maker(cfg, device="cpu"): +def env_maker(cfg, device="cpu", from_pixels=False): lib = cfg.env.library if lib in ("gym", "gymnasium"): with set_gym_backend(lib): return GymEnv( - cfg.env.name, - device=device, + cfg.env.name, device=device, from_pixels=from_pixels, pixels_only=False ) elif lib == "dm_control": - env = DMControlEnv(cfg.env.name, cfg.env.task) + env = DMControlEnv( + cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False + ) return TransformedEnv( env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") ) @@ -71,11 +75,12 @@ def apply_env_transforms(env, max_episode_steps=1000): return transformed_env -def make_environment(cfg): +def make_environment(cfg, logger): """Make environments for training and evaluation.""" + maker = functools.partial(env_maker, cfg, from_pixels=False) parallel_env = ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda cfg=cfg: env_maker(cfg)), + EnvCreator(maker), serial_for_single=True, ) parallel_env.set_seed(cfg.env.seed) @@ -84,14 +89,20 @@ def make_environment(cfg): parallel_env, max_episode_steps=cfg.env.max_episode_steps ) + maker = functools.partial(env_maker, cfg, from_pixels=cfg.logger.video) eval_env = TransformedEnv( ParallelEnv( - cfg.collector.env_per_collector, - EnvCreator(lambda cfg=cfg: env_maker(cfg)), + cfg.logger.num_eval_envs, + EnvCreator(maker), serial_for_single=True, ), train_env.transform.clone(), ) + eval_env.set_seed(0) + if cfg.logger.video: + eval_env = eval_env.append_transform( + VideoRecorder(logger, tag="rendered", in_keys=["pixels"]) + ) return train_env, eval_env @@ -290,3 +301,8 @@ def get_activation(cfg): return nn.LeakyReLU else: raise NotImplementedError + + +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() diff --git a/sota-implementations/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py index a79c0037205..59dbcafd8c9 100644 --- a/sota-implementations/decision_transformer/dt.py +++ b/sota-implementations/decision_transformer/dt.py @@ -17,8 +17,10 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper +from torchrl.record import VideoRecorder from utils import ( + dump_video, log_metrics, make_dt_loss, make_dt_model, @@ -34,6 +36,12 @@ def main(cfg: "DictConfig"): # noqa: F821 set_gym_backend(cfg.env.backend).set() model_device = cfg.optim.device + if model_device in ("", None): + if torch.cuda.is_available(): + model_device = "cuda:0" + else: + model_device = "cpu" + model_device = torch.device(model_device) # Set seeds torch.manual_seed(cfg.env.seed) @@ -48,7 +56,11 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create test environment - test_env = make_env(cfg.env, obs_loc, obs_std) + test_env = make_env(cfg.env, obs_loc, obs_std, from_pixels=cfg.logger.video) + if cfg.logger.video: + test_env = test_env.append_transform( + VideoRecorder(logger, tag="rendered", in_keys=["pixels"]) + ) # Create policy model actor = make_dt_model(cfg) @@ -109,6 +121,7 @@ def main(cfg: "DictConfig"): # noqa: F821 policy=inference_policy, auto_cast_to_device=True, ) + test_env.apply(dump_video) 
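All of the env_maker variants above now accept a from_pixels flag and pass pixels_only=False, so the rendered frame is added next to the state observation instead of replacing it. A quick check of what that produces (Pendulum-v1 as an illustrative env):

from torchrl.envs.libs.gym import GymEnv

env = GymEnv("Pendulum-v1", from_pixels=True, pixels_only=False)
td = env.reset()
print(td["pixels"].shape)       # rendered frame, consumed by VideoRecorder
print(td["observation"].shape)  # state vector, still consumed by the policy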
to_log["eval/reward"] = ( eval_td["next", "reward"].sum(1).mean().item() / reward_scaling ) diff --git a/sota-implementations/decision_transformer/dt_config.yaml b/sota-implementations/decision_transformer/dt_config.yaml index b42d8b58d35..4805785a62c 100644 --- a/sota-implementations/decision_transformer/dt_config.yaml +++ b/sota-implementations/decision_transformer/dt_config.yaml @@ -27,6 +27,7 @@ logger: pretrain_log_interval: 500 # record interval in frames fintune_log_interval: 1 eval_steps: 1000 + video: False # replay buffer replay_buffer: @@ -42,7 +43,7 @@ replay_buffer: # optimization optim: - device: cuda:0 + device: null lr: 1.0e-4 weight_decay: 5.0e-4 batch_size: 64 diff --git a/sota-implementations/decision_transformer/odt_config.yaml b/sota-implementations/decision_transformer/odt_config.yaml index f06972fd46b..eec2b455fb3 100644 --- a/sota-implementations/decision_transformer/odt_config.yaml +++ b/sota-implementations/decision_transformer/odt_config.yaml @@ -25,8 +25,9 @@ logger: exp_name: oDT-HalfCheetah-medium-v2 model_name: oDT pretrain_log_interval: 500 # record interval in frames - fintune_log_interval: 1 + finetune_log_interval: 1 eval_steps: 1000 + video: False # replay buffer replay_buffer: @@ -37,12 +38,11 @@ replay_buffer: buffer_prefetch: 64 capacity: 1_000_000 scratch_dir: - device: cuda:0 prefetch: 3 # optimizer optim: - device: cuda:0 + device: null lr: 1.0e-4 weight_decay: 5.0e-4 batch_size: 256 diff --git a/sota-implementations/decision_transformer/online_dt.py b/sota-implementations/decision_transformer/online_dt.py index 427b5d8eaa3..5cb297e5c0b 100644 --- a/sota-implementations/decision_transformer/online_dt.py +++ b/sota-implementations/decision_transformer/online_dt.py @@ -17,8 +17,10 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper +from torchrl.record import VideoRecorder from utils import ( + dump_video, log_metrics, make_env, make_logger, @@ -34,6 +36,12 @@ def main(cfg: "DictConfig"): # noqa: F821 set_gym_backend(cfg.env.backend).set() model_device = cfg.optim.device + if model_device in ("", None): + if torch.cuda.is_available(): + model_device = "cuda:0" + else: + model_device = "cpu" + model_device = torch.device(model_device) # Set seeds torch.manual_seed(cfg.env.seed) @@ -48,7 +56,11 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create test environment - test_env = make_env(cfg.env, obs_loc, obs_std) + test_env = make_env(cfg.env, obs_loc, obs_std, from_pixels=cfg.logger.video) + if cfg.logger.video: + test_env = test_env.append_transform( + VideoRecorder(logger, tag="rendered", in_keys=["pixels"]) + ) # Create policy model actor = make_odt_model(cfg) @@ -123,6 +135,7 @@ def main(cfg: "DictConfig"): # noqa: F821 auto_cast_to_device=True, break_when_any_done=False, ) + test_env.apply(dump_video) inference_policy.train() to_log["eval/reward"] = ( eval_td["next", "reward"].sum(1).mean().item() / reward_scaling diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index 9d479a8118d..a87b3cd8d9f 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -48,6 +48,7 @@ ) from torchrl.objectives import DTLoss, OnlineDTLoss +from torchrl.record import VideoRecorder from torchrl.record.loggers import generate_exp_name, get_logger from torchrl.trainers.helpers.envs import LIBS @@ -56,7 +57,7 @@ # ----------------- -def 
make_base_env(env_cfg): +def make_base_env(env_cfg, from_pixels=False): set_gym_backend(env_cfg.backend).set() env_library = LIBS[env_cfg.library] @@ -66,6 +67,8 @@ def make_base_env(env_cfg): env_kwargs = { "env_name": env_name, "frame_skip": frame_skip, + "from_pixels": from_pixels, + "pixels_only": False, } if env_library is DMControlEnv: env_task = env_cfg.task @@ -131,7 +134,7 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False): return transformed_env -def make_parallel_env(env_cfg, obs_loc, obs_std, train=False): +def make_parallel_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False): if train: num_envs = env_cfg.num_train_envs else: @@ -139,7 +142,7 @@ def make_parallel_env(env_cfg, obs_loc, obs_std, train=False): def make_env(): with set_gym_backend(env_cfg.backend): - return make_base_env(env_cfg) + return make_base_env(env_cfg, from_pixels=from_pixels) env = make_transformed_env( ParallelEnv(num_envs, EnvCreator(make_env), serial_for_single=True), @@ -151,8 +154,10 @@ def make_env(): return env -def make_env(env_cfg, obs_loc, obs_std, train=False): - env = make_parallel_env(env_cfg, obs_loc, obs_std, train=train) +def make_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False): + env = make_parallel_env( + env_cfg, obs_loc, obs_std, train=train, from_pixels=from_pixels + ) return env @@ -517,3 +522,8 @@ def make_logger(cfg): def log_metrics(logger, metrics, step): for metric_name, metric_value in metrics.items(): logger.log_scalar(metric_name, metric_value, step) + + +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() diff --git a/sota-implementations/discrete_sac/config.yaml b/sota-implementations/discrete_sac/config.yaml index df26c835ef0..aa852ca1fc3 100644 --- a/sota-implementations/discrete_sac/config.yaml +++ b/sota-implementations/discrete_sac/config.yaml @@ -14,7 +14,7 @@ collector: init_env_steps: 1000 frames_per_batch: 500 reset_at_each_iter: False - device: cuda:0 + device: null env_per_collector: 1 num_workers: 1 @@ -42,7 +42,7 @@ optim: network: hidden_sizes: [256, 256] activation: relu - device: "cuda:0" + device: null # logging logger: @@ -52,3 +52,4 @@ logger: exp_name: ${env.name}_DiscreteSAC mode: online eval_iter: 5000 + video: False diff --git a/sota-implementations/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py index 40d9a1743c2..6e100f92dc3 100644 --- a/sota-implementations/discrete_sac/discrete_sac.py +++ b/sota-implementations/discrete_sac/discrete_sac.py @@ -23,6 +23,7 @@ from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + dump_video, log_metrics, make_collector, make_environment, @@ -35,7 +36,13 @@ @hydra.main(version_base="1.1", config_path="", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 - device = torch.device(cfg.network.device) + device = cfg.network.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) # Create logger exp_name = generate_exp_name("DiscreteSAC", cfg.logger.exp_name) @@ -58,7 +65,7 @@ def main(cfg: "DictConfig"): # noqa: F821 np.random.seed(cfg.env.seed) # Create environments - train_env, eval_env = make_environment(cfg) + train_env, eval_env = make_environment(cfg, logger=logger) # Create agent model = make_sac_agent(cfg, train_env, eval_env, device) @@ -100,7 +107,7 @@ def main(cfg: "DictConfig"): # noqa: F821 frames_per_batch = cfg.collector.frames_per_batch sampling_start = 
time.time() - for tensordict in collector: + for i, tensordict in enumerate(collector): sampling_time = time.time() - sampling_start # Update weights of the inference policy @@ -193,7 +200,10 @@ def main(cfg: "DictConfig"): # noqa: F821 metrics_to_log["train/training_time"] = training_time # Evaluation - if abs(collected_frames % eval_iter) < frames_per_batch: + prev_test_frame = ((i - 1) * frames_per_batch) // eval_iter + cur_test_frame = (i * frames_per_batch) // eval_iter + final = current_frames >= collector.total_frames + if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: with set_exploration_type(ExplorationType.MODE), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( @@ -202,6 +212,7 @@ def main(cfg: "DictConfig"): # noqa: F821 auto_cast_to_device=True, break_when_any_done=True, ) + eval_env.apply(dump_video) eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward diff --git a/sota-implementations/discrete_sac/utils.py b/sota-implementations/discrete_sac/utils.py index 5821ed53465..ddffffc2a8e 100644 --- a/sota-implementations/discrete_sac/utils.py +++ b/sota-implementations/discrete_sac/utils.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import functools import tempfile from contextlib import nullcontext @@ -36,22 +37,25 @@ from torchrl.modules.tensordict_module.actors import ProbabilisticActor from torchrl.objectives import SoftUpdate from torchrl.objectives.sac import DiscreteSACLoss +from torchrl.record import VideoRecorder + # ==================================================================== # Environment utils # ----------------- -def env_maker(cfg, device="cpu"): +def env_maker(cfg, device="cpu", from_pixels=False): lib = cfg.env.library if lib in ("gym", "gymnasium"): with set_gym_backend(lib): return GymEnv( - cfg.env.name, - device=device, + cfg.env.name, device=device, from_pixels=from_pixels, pixels_only=False ) elif lib == "dm_control": - env = DMControlEnv(cfg.env.name, cfg.env.task) + env = DMControlEnv( + cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False + ) return TransformedEnv( env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") ) @@ -72,11 +76,12 @@ def apply_env_transforms(env, max_episode_steps): return transformed_env -def make_environment(cfg): +def make_environment(cfg, logger=None): """Make environments for training and evaluation.""" + maker = functools.partial(env_maker, cfg) parallel_env = ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda cfg=cfg: env_maker(cfg)), + EnvCreator(maker), serial_for_single=True, ) parallel_env.set_seed(cfg.env.seed) @@ -85,14 +90,19 @@ def make_environment(cfg): parallel_env, max_episode_steps=cfg.env.max_episode_steps ) + maker = functools.partial(env_maker, cfg, from_pixels=cfg.logger.video) eval_env = TransformedEnv( ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda cfg=cfg: env_maker(cfg)), + EnvCreator(maker), serial_for_single=True, ), train_env.transform.clone(), ) + if cfg.logger.video: + eval_env = eval_env.insert_transform( + 0, VideoRecorder(logger, tag="rendered", in_keys=["pixels"]) + ) return train_env, eval_env @@ -103,6 +113,13 @@ def make_environment(cfg): def make_collector(cfg, train_env, actor_model_explore): """Make collector.""" + device = cfg.collector.device + if device in ("", None): + if 
torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) collector = SyncDataCollector( train_env, actor_model_explore, @@ -110,7 +127,8 @@ def make_collector(cfg, train_env, actor_model_explore): frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, reset_at_each_iter=cfg.collector.reset_at_each_iter, - device=cfg.collector.device, + device=device, + storing_device="cpu", ) collector.set_seed(cfg.env.seed) return collector @@ -288,3 +306,8 @@ def get_activation(cfg): return nn.LeakyReLU else: raise NotImplementedError + + +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() diff --git a/sota-implementations/dqn/config_atari.yaml b/sota-implementations/dqn/config_atari.yaml index 691fb4ff626..50e374cef14 100644 --- a/sota-implementations/dqn/config_atari.yaml +++ b/sota-implementations/dqn/config_atari.yaml @@ -1,4 +1,4 @@ -device: cuda:0 +device: null # Environment env: @@ -27,6 +27,7 @@ logger: exp_name: DQN test_interval: 1_000_000 num_test_episodes: 3 + video: False # Optim optim: diff --git a/sota-implementations/dqn/config_cartpole.yaml b/sota-implementations/dqn/config_cartpole.yaml index 1ebeba42f8c..9a69762d6bd 100644 --- a/sota-implementations/dqn/config_cartpole.yaml +++ b/sota-implementations/dqn/config_cartpole.yaml @@ -1,4 +1,4 @@ -device: cuda:0 +device: null # Environment env: @@ -26,6 +26,7 @@ logger: exp_name: DQN test_interval: 50_000 num_test_episodes: 5 + video: False # Optim optim: diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py index ba5f7cbf761..90f93551d4d 100644 --- a/sota-implementations/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -22,6 +22,7 @@ from torchrl.envs import ExplorationType, set_exploration_type from torchrl.modules import EGreedyModule from torchrl.objectives import DQNLoss, HardUpdate +from torchrl.record import VideoRecorder from torchrl.record.loggers import generate_exp_name, get_logger from utils_atari import eval_model, make_dqn_model, make_env @@ -29,7 +30,13 @@ @hydra.main(config_path="", config_name="config_atari", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 - device = torch.device(cfg.device) + device = cfg.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) # Correct for frame_skip frame_skip = 4 @@ -111,6 +118,13 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create the test environment test_env = make_env(cfg.env.env_name, frame_skip, device, is_test=True) + if cfg.logger.video: + test_env.insert_transform( + 0, + VideoRecorder( + logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"] + ), + ) test_env.eval() # Main loop @@ -122,7 +136,7 @@ def main(cfg: "DictConfig"): # noqa: F821 num_test_episodes = cfg.logger.num_test_episodes q_losses = torch.zeros(num_updates, device=device) pbar = tqdm.tqdm(total=total_frames) - for data in collector: + for i, data in enumerate(collector): log_info = {} sampling_time = time.time() - sampling_start @@ -186,9 +200,10 @@ def main(cfg: "DictConfig"): # noqa: F821 # Get and log evaluation rewards and eval time with torch.no_grad(), set_exploration_type(ExplorationType.MODE): - if (collected_frames - frames_per_batch) // test_interval < ( - collected_frames // test_interval - ): + prev_test_frame = ((i - 1) * frames_per_batch) // test_interval + cur_test_frame = (i * frames_per_batch) // test_interval + final = 
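make_collector in discrete_sac/utils.py resolves the collector device with the same null-means-auto rule and additionally pins the stored batches to CPU; reduced to a sketch (env_fn and policy are placeholders):

import torch
from torchrl.collectors import SyncDataCollector

def make_collector(env_fn, policy, cfg_device=None, frames_per_batch=500, total_frames=1_000_000):
    device = cfg_device
    if device in ("", None):
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    return SyncDataCollector(
        env_fn,
        policy,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        device=torch.device(device),
        storing_device="cpu",  # keep collected batches on CPU even when the policy runs on GPU
    )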
current_frames >= collector.total_frames + if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: model.eval() eval_start = time.time() test_rewards = eval_model( diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py index cfe734173f5..ac3f17a9203 100644 --- a/sota-implementations/dqn/dqn_cartpole.py +++ b/sota-implementations/dqn/dqn_cartpole.py @@ -16,6 +16,7 @@ from torchrl.envs import ExplorationType, set_exploration_type from torchrl.modules import EGreedyModule from torchrl.objectives import DQNLoss, HardUpdate +from torchrl.record import VideoRecorder from torchrl.record.loggers import generate_exp_name, get_logger from utils_cartpole import eval_model, make_dqn_model, make_env @@ -23,7 +24,13 @@ @hydra.main(config_path="", config_name="config_cartpole", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 - device = torch.device(cfg.device) + device = cfg.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) # Make the components model = make_dqn_model(cfg.env.env_name) @@ -93,7 +100,14 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create the test environment - test_env = make_env(cfg.env.env_name, "cpu") + test_env = make_env(cfg.env.env_name, "cpu", from_pixels=cfg.logger.video) + if cfg.logger.video: + test_env.insert_transform( + 0, + VideoRecorder( + logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"] + ), + ) # Main loop collected_frames = 0 @@ -108,7 +122,7 @@ def main(cfg: "DictConfig"): # noqa: F821 sampling_start = time.time() q_losses = torch.zeros(num_updates, device=device) - for data in collector: + for i, data in enumerate(collector): log_info = {} sampling_time = time.time() - sampling_start @@ -167,9 +181,10 @@ def main(cfg: "DictConfig"): # noqa: F821 # Get and log evaluation rewards and eval time with torch.no_grad(), set_exploration_type(ExplorationType.MODE): - if (collected_frames - frames_per_batch) // test_interval < ( - collected_frames // test_interval - ): + prev_test_frame = ((i - 1) * frames_per_batch) // test_interval + cur_test_frame = (i * frames_per_batch) // test_interval + final = current_frames >= collector.total_frames + if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: model.eval() eval_start = time.time() test_rewards = eval_model(model, test_env, num_test_episodes) diff --git a/sota-implementations/dqn/utils_atari.py b/sota-implementations/dqn/utils_atari.py index b9805659e63..3dbbfe87af4 100644 --- a/sota-implementations/dqn/utils_atari.py +++ b/sota-implementations/dqn/utils_atari.py @@ -23,6 +23,7 @@ ) from torchrl.modules import ConvNet, MLP, QValueActor +from torchrl.record import VideoRecorder # ==================================================================== @@ -111,7 +112,13 @@ def eval_model(actor, test_env, num_episodes=3): break_when_any_done=True, max_steps=10_000_000, ) + test_env.apply(dump_video) reward = td_test["next", "episode_reward"][td_test["next", "done"]] test_rewards[i] = reward.sum() del td_test return test_rewards.mean() + + +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() diff --git a/sota-implementations/dqn/utils_cartpole.py b/sota-implementations/dqn/utils_cartpole.py index 8d2ec5fab06..2df280a04b4 100644 --- a/sota-implementations/dqn/utils_cartpole.py +++ b/sota-implementations/dqn/utils_cartpole.py @@ -9,14 +9,16 @@ from torchrl.envs import RewardSum, StepCounter, TransformedEnv from torchrl.envs.libs.gym 
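Each utils module grows the same dump_video helper; it works because a TransformedEnv is an nn.Module, so env.apply(...) visits every transform and flushes any VideoRecorder it finds without the evaluation code holding a reference to it:

from torchrl.record import VideoRecorder

def dump_video(module):
    if isinstance(module, VideoRecorder):
        module.dump()

# after an evaluation rollout:
#   test_env.rollout(max_steps=eval_steps, policy=actor)
#   test_env.apply(dump_video)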
import GymEnv from torchrl.modules import MLP, QValueActor +from torchrl.record import VideoRecorder + # ==================================================================== # Environment utils # -------------------------------------------------------------------- -def make_env(env_name="CartPole-v1", device="cpu"): - env = GymEnv(env_name, device=device) +def make_env(env_name="CartPole-v1", device="cpu", from_pixels=False): + env = GymEnv(env_name, device=device, from_pixels=from_pixels, pixels_only=False) env = TransformedEnv(env) env.append_transform(RewardSum()) env.append_transform(StepCounter()) @@ -74,7 +76,13 @@ def eval_model(actor, test_env, num_episodes=3): break_when_any_done=True, max_steps=10_000_000, ) + test_env.apply(dump_video) reward = td_test["next", "episode_reward"][td_test["next", "done"]] test_rewards[i] = reward.sum() del td_test return test_rewards.mean() + + +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py index c0101f1c941..33513dd3973 100644 --- a/sota-implementations/iql/discrete_iql.py +++ b/sota-implementations/iql/discrete_iql.py @@ -24,6 +24,7 @@ from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + dump_video, log_metrics, make_collector, make_discrete_iql_model, @@ -57,13 +58,20 @@ def main(cfg: "DictConfig"): # noqa: F821 # Set seeds torch.manual_seed(cfg.env.seed) np.random.seed(cfg.env.seed) - device = torch.device(cfg.optim.device) + device = cfg.optim.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) # Create environments train_env, eval_env = make_environment( cfg, cfg.env.train_num_envs, cfg.env.eval_num_envs, + logger=logger, ) # Create replay buffer @@ -186,6 +194,7 @@ def main(cfg: "DictConfig"): # noqa: F821 auto_cast_to_device=True, break_when_any_done=True, ) + eval_env.apply(dump_video) eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward diff --git a/sota-implementations/iql/discrete_iql.yaml b/sota-implementations/iql/discrete_iql.yaml index c21a320e375..9245d4c4832 100644 --- a/sota-implementations/iql/discrete_iql.yaml +++ b/sota-implementations/iql/discrete_iql.yaml @@ -28,6 +28,7 @@ logger: eval_steps: 200 mode: online eval_iter: 1000 + video: False # replay buffer replay_buffer: @@ -38,7 +39,7 @@ replay_buffer: # optimization optim: utd_ratio: 1 - device: cuda:0 + device: null lr: 3e-4 weight_decay: 0.0 batch_size: 256 diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py index 66c6d206c3d..d98724e1371 100644 --- a/sota-implementations/iql/iql_offline.py +++ b/sota-implementations/iql/iql_offline.py @@ -22,6 +22,7 @@ from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + dump_video, log_metrics, make_environment, make_iql_model, @@ -54,10 +55,20 @@ def main(cfg: "DictConfig"): # noqa: F821 # Set seeds torch.manual_seed(cfg.env.seed) np.random.seed(cfg.env.seed) - device = torch.device(cfg.optim.device) + device = cfg.optim.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) # Creante env - train_env, eval_env = make_environment(cfg, cfg.logger.eval_envs) + train_env, eval_env = make_environment( + cfg, + cfg.logger.eval_envs, + 
logger=logger, + ) # Create replay buffer replay_buffer = make_offline_replay_buffer(cfg.replay_buffer) @@ -123,6 +134,7 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_td = eval_env.rollout( max_steps=eval_steps, policy=model[0], auto_cast_to_device=True ) + eval_env.apply(dump_video) eval_reward = eval_td["next", "reward"].sum(1).mean().item() to_log["evaluation_reward"] = eval_reward if logger is not None: diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index 307f6df5e2b..b66c6f9dcf2 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -24,6 +24,7 @@ from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + dump_video, log_metrics, make_collector, make_environment, @@ -57,13 +58,20 @@ def main(cfg: "DictConfig"): # noqa: F821 # Set seeds torch.manual_seed(cfg.env.seed) np.random.seed(cfg.env.seed) - device = torch.device(cfg.optim.device) + device = cfg.optim.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) # Create environments train_env, eval_env = make_environment( cfg, cfg.env.train_num_envs, cfg.env.eval_num_envs, + logger=logger, ) # Create replay buffer @@ -184,6 +192,7 @@ def main(cfg: "DictConfig"): # noqa: F821 auto_cast_to_device=True, break_when_any_done=True, ) + eval_env.apply(dump_video) eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward diff --git a/sota-implementations/iql/offline_config.yaml b/sota-implementations/iql/offline_config.yaml index f7486708c5a..5f34fa5651a 100644 --- a/sota-implementations/iql/offline_config.yaml +++ b/sota-implementations/iql/offline_config.yaml @@ -17,6 +17,7 @@ logger: eval_steps: 1000 mode: online eval_envs: 5 + video: False # replay buffer replay_buffer: @@ -25,7 +26,7 @@ replay_buffer: # optimization optim: - device: cuda:0 + device: null lr: 3e-4 weight_decay: 0.0 gradient_steps: 50000 diff --git a/sota-implementations/iql/online_config.yaml b/sota-implementations/iql/online_config.yaml index 511d77ec365..1f7bb361e6c 100644 --- a/sota-implementations/iql/online_config.yaml +++ b/sota-implementations/iql/online_config.yaml @@ -28,6 +28,7 @@ logger: eval_steps: 200 mode: online eval_iter: 1000 + video: False # replay buffer replay_buffer: @@ -38,7 +39,7 @@ replay_buffer: # optimization optim: utd_ratio: 1 - device: cuda:0 + device: null lr: 3e-4 weight_decay: 0.0 batch_size: 256 diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index 8b594d3a60c..2d5aee80ce2 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import functools + import torch.nn import torch.optim from tensordict.nn import InteractionType, TensorDictModule @@ -39,6 +41,7 @@ ValueOperator, ) from torchrl.objectives import DiscreteIQLLoss, HardUpdate, IQLLoss, SoftUpdate +from torchrl.record import VideoRecorder from torchrl.trainers.helpers.models import ACTIVATIONS @@ -48,16 +51,17 @@ # ----------------- -def env_maker(cfg, device="cpu"): +def env_maker(cfg, device="cpu", from_pixels=False): lib = cfg.env.backend if lib in ("gym", "gymnasium"): with set_gym_backend(lib): return GymEnv( - cfg.env.name, - device=device, + cfg.env.name, device=device, from_pixels=from_pixels, pixels_only=False ) elif lib == "dm_control": - env = DMControlEnv(cfg.env.name, cfg.env.task) + env = DMControlEnv( + cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False + ) return TransformedEnv( env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") ) @@ -79,25 +83,31 @@ def apply_env_transforms( return transformed_env -def make_environment(cfg, train_num_envs=1, eval_num_envs=1): +def make_environment(cfg, train_num_envs=1, eval_num_envs=1, logger=None): """Make environments for training and evaluation.""" + maker = functools.partial(env_maker, cfg) parallel_env = ParallelEnv( train_num_envs, - EnvCreator(lambda: env_maker(cfg)), + EnvCreator(maker), serial_for_single=True, ) parallel_env.set_seed(cfg.env.seed) train_env = apply_env_transforms(parallel_env) + maker = functools.partial(env_maker, cfg, from_pixels=cfg.logger.video) eval_env = TransformedEnv( ParallelEnv( eval_num_envs, - EnvCreator(lambda: env_maker(cfg)), + EnvCreator(maker), serial_for_single=True, ), train_env.transform.clone(), ) + if cfg.logger.video: + eval_env.insert_transform( + 0, VideoRecorder(logger, tag="rendered", in_keys=["pixels"]) + ) return train_env, eval_env @@ -417,3 +427,8 @@ def log_metrics(logger, metrics, step): if logger is not None: for metric_name, metric_value in metrics.items(): logger.log_scalar(metric_name, metric_value, step) + + +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() diff --git a/sota-implementations/ppo/config_atari.yaml b/sota-implementations/ppo/config_atari.yaml index d6ec35ab5f2..31e6f13a58c 100644 --- a/sota-implementations/ppo/config_atari.yaml +++ b/sota-implementations/ppo/config_atari.yaml @@ -16,6 +16,7 @@ logger: exp_name: Atari_Schulman17 test_interval: 40_000_000 num_test_episodes: 3 + video: False # Optim optim: diff --git a/sota-implementations/ppo/config_mujoco.yaml b/sota-implementations/ppo/config_mujoco.yaml index 3320837ae3d..2dd3c6cc229 100644 --- a/sota-implementations/ppo/config_mujoco.yaml +++ b/sota-implementations/ppo/config_mujoco.yaml @@ -15,6 +15,7 @@ logger: exp_name: Mujoco_Schulman17 test_interval: 1_000_000 num_test_episodes: 5 + video: False # Optim optim: diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py index 69468e133a8..908cb7924a3 100644 --- a/sota-implementations/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -9,6 +9,7 @@ """ import hydra from torchrl._utils import logger as torchrl_logger +from torchrl.record import VideoRecorder @hydra.main(config_path="", config_name="config_atari", version_base="1.1") @@ -104,9 +105,16 @@ def main(cfg: "DictConfig"): # noqa: F821 "group": cfg.logger.group_name, }, ) + logger_video = cfg.logger.video + else: + logger_video = False # Create test environment test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True) + if logger_video: 
+ test_env = test_env.append_transform( + VideoRecorder(logger, tag="rendering/test", in_keys=["pixels_int"]) + ) test_env.eval() # Main loop diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index ae4ba9ea9e5..e3e74971a49 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -9,6 +9,7 @@ """ import hydra from torchrl._utils import logger as torchrl_logger +from torchrl.record import VideoRecorder @hydra.main(config_path="", config_name="config_mujoco", version_base="1.1") @@ -96,9 +97,16 @@ def main(cfg: "DictConfig"): # noqa: F821 "group": cfg.logger.group_name, }, ) + logger_video = cfg.logger.video + else: + logger_video = False # Create test environment - test_env = make_env(cfg.env.env_name, device) + test_env = make_env(cfg.env.env_name, device, from_pixels=logger_video) + if logger_video: + test_env = test_env.append_transform( + VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"]) + ) test_env.eval() # Main loop diff --git a/sota-implementations/ppo/utils_atari.py b/sota-implementations/ppo/utils_atari.py index 5cb838cac47..f2e4ae8cebf 100644 --- a/sota-implementations/ppo/utils_atari.py +++ b/sota-implementations/ppo/utils_atari.py @@ -18,6 +18,7 @@ GymEnv, NoopResetEnv, ParallelEnv, + RenameTransform, Resize, RewardSum, SignTransform, @@ -35,6 +36,8 @@ TanhNormal, ValueOperator, ) +from torchrl.record import VideoRecorder + # ==================================================================== # Environment utils @@ -64,7 +67,8 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): device=device, ) env = TransformedEnv(env) - env.append_transform(ToTensorImage()) + env.append_transform(RenameTransform(in_keys=["pixels"], out_keys=["pixels_int"])) + env.append_transform(ToTensorImage(in_keys=["pixels_int"], out_keys=["pixels"])) env.append_transform(GrayScale()) env.append_transform(Resize(84, 84)) env.append_transform(CatFrames(N=4, dim=-3)) @@ -198,6 +202,11 @@ def make_ppo_models(env_name): # -------------------------------------------------------------------- +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() + + def eval_model(actor, test_env, num_episodes=3): test_rewards = [] for _ in range(num_episodes): @@ -208,6 +217,7 @@ def eval_model(actor, test_env, num_episodes=3): break_when_any_done=True, max_steps=10_000_000, ) + test_env.apply(dump_video) reward = td_test["next", "episode_reward"][td_test["next", "done"]] test_rewards.append(reward.cpu()) del td_test diff --git a/sota-implementations/ppo/utils_mujoco.py b/sota-implementations/ppo/utils_mujoco.py index 7be234b322d..eefd8bebb6b 100644 --- a/sota-implementations/ppo/utils_mujoco.py +++ b/sota-implementations/ppo/utils_mujoco.py @@ -19,14 +19,16 @@ ) from torchrl.envs.libs.gym import GymEnv from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator +from torchrl.record import VideoRecorder + # ==================================================================== # Environment utils # -------------------------------------------------------------------- -def make_env(env_name="HalfCheetah-v4", device="cpu"): - env = GymEnv(env_name, device=device) +def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False): + env = GymEnv(env_name, device=device, from_pixels=from_pixels, pixels_only=False) env = TransformedEnv(env) env.append_transform(VecNorm(in_keys=["observation"], decay=0.99999, eps=1e-2)) 
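The Atari pipeline change in ppo/utils_atari.py keeps the raw uint8 frame alive under pixels_int for video while ToTensorImage writes the float copy back to pixels for the GrayScale/Resize/CatFrames chain, which is why the PPO VideoRecorder reads in_keys=["pixels_int"]. A minimal illustration of the key shuffle (CartPole-v1 stands in for the Atari env):

from torchrl.envs import GymEnv, RenameTransform, ToTensorImage, TransformedEnv

env = TransformedEnv(GymEnv("CartPole-v1", from_pixels=True, pixels_only=False))
env.append_transform(RenameTransform(in_keys=["pixels"], out_keys=["pixels_int"]))
env.append_transform(ToTensorImage(in_keys=["pixels_int"], out_keys=["pixels"]))
td = env.reset()
print(td["pixels_int"].dtype, td["pixels"].dtype)  # uint8 frame for the recorder, float for the net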
diff --git a/sota-implementations/ppo/utils_mujoco.py b/sota-implementations/ppo/utils_mujoco.py
index 7be234b322d..eefd8bebb6b 100644
--- a/sota-implementations/ppo/utils_mujoco.py
+++ b/sota-implementations/ppo/utils_mujoco.py
@@ -19,14 +19,16 @@
 )
 from torchrl.envs.libs.gym import GymEnv
 from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
+from torchrl.record import VideoRecorder
+

 # ====================================================================
 # Environment utils
 # --------------------------------------------------------------------


-def make_env(env_name="HalfCheetah-v4", device="cpu"):
-    env = GymEnv(env_name, device=device)
+def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False):
+    env = GymEnv(env_name, device=device, from_pixels=from_pixels, pixels_only=False)
     env = TransformedEnv(env)
     env.append_transform(VecNorm(in_keys=["observation"], decay=0.99999, eps=1e-2))
     env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10))
@@ -126,6 +128,11 @@ def make_ppo_models(env_name):
 # --------------------------------------------------------------------


+def dump_video(module):
+    if isinstance(module, VideoRecorder):
+        module.dump()
+
+
 def eval_model(actor, test_env, num_episodes=3):
     test_rewards = []
     for _ in range(num_episodes):
@@ -138,5 +145,6 @@ def eval_model(actor, test_env, num_episodes=3):
         )
         reward = td_test["next", "episode_reward"][td_test["next", "done"]]
         test_rewards.append(reward.cpu())
+        test_env.apply(dump_video)
         del td_test
     return torch.cat(test_rewards, 0).mean()
diff --git a/sota-implementations/redq/config.yaml b/sota-implementations/redq/config.yaml
index c67543716dc..e60191c0f93 100644
--- a/sota-implementations/redq/config.yaml
+++ b/sota-implementations/redq/config.yaml
@@ -43,11 +43,11 @@ logger:
   project_name: torchrl_example_redq
   group_name: null
   exp_name: cheetah
-  record_video: 0
   record_interval: 10
   record_frames: 10000
   mode: online
   recorder_log_keys:
+  video: False

 optim:
   optimizer: adam
diff --git a/sota-implementations/redq/redq.py b/sota-implementations/redq/redq.py
index d9aef64b525..d6b1668aadf 100644
--- a/sota-implementations/redq/redq.py
+++ b/sota-implementations/redq/redq.py
@@ -63,18 +63,20 @@ def main(cfg: "DictConfig"):  # noqa: F821
         ]
     )

-    logger = get_logger(
-        logger_type=cfg.logger.backend,
-        logger_name="redq_logging",
-        experiment_name=exp_name,
-        wandb_kwargs={
-            "mode": cfg.logger.mode,
-            "config": dict(cfg),
-            "project": cfg.logger.project_name,
-            "group": cfg.logger.group_name,
-        },
-    )
-    video_tag = exp_name if cfg.logger.record_video else ""
+    if cfg.logger.backend:
+        logger = get_logger(
+            logger_type=cfg.logger.backend,
+            logger_name="redq_logging",
+            experiment_name=exp_name,
+            wandb_kwargs={
+                "mode": cfg.logger.mode,
+                "config": dict(cfg),
+                "project": cfg.logger.project_name,
+                "group": cfg.logger.group_name,
+            },
+        )
+    else:
+        logger = ""

     key, init_env_steps, stats = None, None, None
     if not cfg.env.vecnorm and cfg.env.norm_stats:
@@ -146,7 +148,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     recorder = transformed_env_constructor(
         cfg,
-        video_tag=video_tag,
+        video_tag="rendering/test",
         norm_obs_only=True,
         obs_norm_state_dict=obs_norm_state_dict,
         logger=logger,
@@ -162,8 +164,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
         recorder.transform = create_env_fn.transform.clone()
     else:
         raise NotImplementedError(f"Unsupported env type {type(create_env_fn)}")
-    if logger is not None and video_tag:
-        recorder.insert_transform(0, VideoRecorder(logger=logger, tag=video_tag))
+    if logger is not None and cfg.logger.video:
+        recorder.insert_transform(0, VideoRecorder(logger=logger, tag="rendering/test"))

     # reset reward scaling
     for t in recorder.transform:
diff --git a/sota-implementations/redq/utils.py b/sota-implementations/redq/utils.py
index 37e7da91b4a..0d2e53b9cb1 100644
--- a/sota-implementations/redq/utils.py
+++ b/sota-implementations/redq/utils.py
@@ -282,7 +282,7 @@ def make_trainer(
         rb_trainer = ReplayBufferTrainer(
             replay_buffer,
             cfg.buffer.batch_size,
-            flatten_tensordicts=False,
+            flatten_tensordicts=True,
             memmap=False,
             device=device,
         )
@@ -1044,7 +1044,6 @@ def make_replay_buffer(
         storage=LazyMemmapStorage(
             cfg.buffer.size,
             scratch_dir=cfg.buffer.scratch_dir,
-            # device=device,  # when using prefetch, this can overload the GPU memory
         ),
         sampler=sampler,
         pin_memory=device != torch.device("cpu"),
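Building the logger only when cfg.logger.backend is non-empty is what lets the example scripts run with logger.backend= without touching any logging service. Roughly the same guard in isolation (the names handed to get_logger are illustrative):

    from torchrl.record.loggers import generate_exp_name, get_logger

    backend = ""  # e.g. "csv", "tensorboard", "wandb", or empty to disable logging
    exp_name = generate_exp_name("REDQ", "cheetah")
    if backend:
        logger = get_logger(
            logger_type=backend,
            logger_name="redq_logging",
            experiment_name=exp_name,
        )
    else:
        # The script above falls back to an empty string; the point is simply
        # that no logging backend is instantiated.
        logger = None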
diff --git a/sota-implementations/sac/config.yaml b/sota-implementations/sac/config.yaml
index 6546f1e30b7..29586f2e9a7 100644
--- a/sota-implementations/sac/config.yaml
+++ b/sota-implementations/sac/config.yaml
@@ -40,7 +40,7 @@ network:
   activation: relu
   default_policy_scale: 1.0
   scale_lb: 0.1
-  device: "cuda:0"
+  device:

 # logging
 logger:
@@ -50,3 +50,4 @@ logger:
   exp_name: ${env.name}_SAC
   mode: online
   eval_iter: 25000
+  video: False
diff --git a/sota-implementations/sac/sac.py b/sota-implementations/sac/sac.py
index 576de96394d..f7a399cda72 100644
--- a/sota-implementations/sac/sac.py
+++ b/sota-implementations/sac/sac.py
@@ -24,6 +24,7 @@
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils import (
+    dump_video,
     log_metrics,
     make_collector,
     make_environment,
@@ -36,7 +37,13 @@
 @hydra.main(version_base="1.1", config_path="", config_name="config")
 def main(cfg: "DictConfig"):  # noqa: F821
-    device = torch.device(cfg.network.device)
+    device = cfg.network.device
+    if device in ("", None):
+        if torch.cuda.is_available():
+            device = torch.device("cuda:0")
+        else:
+            device = torch.device("cpu")
+    device = torch.device(device)

     # Create logger
     exp_name = generate_exp_name("SAC", cfg.logger.exp_name)
@@ -58,7 +65,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     np.random.seed(cfg.env.seed)

     # Create environments
-    train_env, eval_env = make_environment(cfg)
+    train_env, eval_env = make_environment(cfg, logger=logger)

     # Create agent
     model, exploration_policy = make_sac_agent(cfg, train_env, eval_env, device)
@@ -198,6 +205,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
                     auto_cast_to_device=True,
                     break_when_any_done=True,
                 )
+                eval_env.apply(dump_video)
                 eval_time = time.time() - eval_start
                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                 metrics_to_log["eval/reward"] = eval_reward
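With network.device now left empty in config.yaml, the script falls back to CUDA when available and CPU otherwise; the selection logic reduces to:

    import torch

    cfg_device = None  # stands in for an empty cfg.network.device
    if cfg_device in ("", None):
        cfg_device = "cuda:0" if torch.cuda.is_available() else "cpu"
    device = torch.device(cfg_device)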
diff --git a/sota-implementations/sac/utils.py b/sota-implementations/sac/utils.py
index afb731dcc95..d190769772c 100644
--- a/sota-implementations/sac/utils.py
+++ b/sota-implementations/sac/utils.py
@@ -2,6 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+import functools

 import torch
 from tensordict.nn import InteractionType, TensorDictModule
@@ -26,6 +27,7 @@
 from torchrl.modules.distributions import TanhNormal
 from torchrl.objectives import SoftUpdate
 from torchrl.objectives.sac import SACLoss
+from torchrl.record import VideoRecorder

 # ====================================================================
@@ -33,16 +35,20 @@ # -----------------


-def env_maker(cfg, device="cpu"):
+def env_maker(cfg, device="cpu", from_pixels=False):
     lib = cfg.env.library
     if lib in ("gym", "gymnasium"):
         with set_gym_backend(lib):
             return GymEnv(
                 cfg.env.name,
                 device=device,
+                from_pixels=from_pixels,
+                pixels_only=False,
             )
     elif lib == "dm_control":
-        env = DMControlEnv(cfg.env.name, cfg.env.task)
+        env = DMControlEnv(
+            cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False
+        )
         return TransformedEnv(
             env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation")
         )
@@ -63,24 +69,31 @@ def apply_env_transforms(env, max_episode_steps=1000):
     return transformed_env


-def make_environment(cfg):
+def make_environment(cfg, logger=None):
     """Make environments for training and evaluation."""
+    partial = functools.partial(env_maker, cfg=cfg)
     parallel_env = ParallelEnv(
         cfg.collector.env_per_collector,
-        EnvCreator(lambda cfg=cfg: env_maker(cfg)),
+        EnvCreator(partial),
         serial_for_single=True,
     )
     parallel_env.set_seed(cfg.env.seed)

     train_env = apply_env_transforms(parallel_env, cfg.env.max_episode_steps)

+    partial = functools.partial(env_maker, cfg=cfg, from_pixels=cfg.logger.video)
+    trsf_clone = train_env.transform.clone()
+    if cfg.logger.video:
+        trsf_clone.insert(
+            0, VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"])
+        )
     eval_env = TransformedEnv(
         ParallelEnv(
             cfg.collector.env_per_collector,
-            EnvCreator(lambda cfg=cfg: env_maker(cfg)),
+            EnvCreator(partial),
             serial_for_single=True,
         ),
-        train_env.transform.clone(),
+        trsf_clone,
     )
     return train_env, eval_env
@@ -211,13 +224,10 @@ def make_sac_agent(cfg, train_env, eval_env, device):
     # init nets
     with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
-        td = eval_env.reset()
+        td = eval_env.fake_tensordict()
         td = td.to(device)
         for net in model:
             net(td)
-        del td
-    eval_env.close()
-
     return model, model[0]
@@ -298,3 +308,8 @@ def get_activation(cfg):
         return nn.LeakyReLU
     else:
         raise NotImplementedError
+
+
+def dump_video(module):
+    if isinstance(module, VideoRecorder):
+        module.dump()
diff --git a/sota-implementations/td3/config.yaml b/sota-implementations/td3/config.yaml
index e94a5b6b774..7f7854b68b3 100644
--- a/sota-implementations/td3/config.yaml
+++ b/sota-implementations/td3/config.yaml
@@ -41,7 +41,7 @@ optim:
 network:
   hidden_sizes: [256, 256]
   activation: relu
-  device: "cuda:0"
+  device: null

 # logging
 logger:
@@ -51,3 +51,4 @@ logger:
   exp_name: ${env.name}_TD3
   mode: online
   eval_iter: 25000
+  video: False
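make_sac_agent above now materializes the lazy layers from eval_env.fake_tensordict() rather than an env reset, so no simulator interaction (or env.close()) is needed just to infer shapes. A small sketch with a stand-in lazy module (Pendulum-v1 is illustrative):

    import torch
    from tensordict.nn import TensorDictModule
    from torchrl.envs.libs.gym import GymEnv

    env = GymEnv("Pendulum-v1")
    # LazyLinear resolves its in_features on the first forward pass.
    net = TensorDictModule(
        torch.nn.LazyLinear(env.action_spec.shape[-1]),
        in_keys=["observation"],
        out_keys=["action"],
    )
    with torch.no_grad():
        # fake_tensordict() builds a zeroed tensordict that matches the env specs,
        # so the lazy parameters can be initialized without resetting the env.
        td = env.fake_tensordict()
        net(td)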
diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py
index 6b1ee046d55..97fd039c238 100644
--- a/sota-implementations/td3/td3.py
+++ b/sota-implementations/td3/td3.py
@@ -23,6 +23,7 @@
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils import (
+    dump_video,
     log_metrics,
     make_collector,
     make_environment,
@@ -35,7 +36,13 @@
 @hydra.main(version_base="1.1", config_path="", config_name="config")
 def main(cfg: "DictConfig"):  # noqa: F821
-    device = torch.device(cfg.network.device)
+    device = cfg.network.device
+    if device in ("", None):
+        if torch.cuda.is_available():
+            device = torch.device("cuda:0")
+        else:
+            device = torch.device("cpu")
+    device = torch.device(device)

     # Create logger
     exp_name = generate_exp_name("TD3", cfg.logger.exp_name)
@@ -58,7 +65,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     np.random.seed(cfg.env.seed)

     # Create environments
-    train_env, eval_env = make_environment(cfg)
+    train_env, eval_env = make_environment(cfg, logger=logger)

     # Create agent
     model, exploration_policy = make_td3_agent(cfg, train_env, eval_env, device)
@@ -196,6 +203,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
                     auto_cast_to_device=True,
                     break_when_any_done=True,
                 )
+                eval_env.apply(dump_video)
                 eval_time = time.time() - eval_start
                 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                 metrics_to_log["eval/reward"] = eval_reward
diff --git a/sota-implementations/td3/utils.py b/sota-implementations/td3/utils.py
index fed055f98bf..c597ae205a2 100644
--- a/sota-implementations/td3/utils.py
+++ b/sota-implementations/td3/utils.py
@@ -2,6 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+import functools

 import tempfile
 from contextlib import nullcontext
@@ -36,6 +37,7 @@
 from torchrl.objectives import SoftUpdate
 from torchrl.objectives.td3 import TD3Loss
+from torchrl.record import VideoRecorder

 # ====================================================================
@@ -43,16 +45,20 @@ # -----------------


-def env_maker(cfg, device="cpu"):
+def env_maker(cfg, device="cpu", from_pixels=False):
     lib = cfg.env.library
     if lib in ("gym", "gymnasium"):
         with set_gym_backend(lib):
             return GymEnv(
                 cfg.env.name,
                 device=device,
+                from_pixels=from_pixels,
+                pixels_only=False,
             )
     elif lib == "dm_control":
-        env = DMControlEnv(cfg.env.name, cfg.env.task)
+        env = DMControlEnv(
+            cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False
+        )
         return TransformedEnv(
             env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation")
         )
@@ -73,26 +79,31 @@ def apply_env_transforms(env, max_episode_steps):
     return transformed_env


-def make_environment(cfg):
+def make_environment(cfg, logger=None):
     """Make environments for training and evaluation."""
+    partial = functools.partial(env_maker, cfg=cfg)
     parallel_env = ParallelEnv(
         cfg.collector.env_per_collector,
-        EnvCreator(lambda cfg=cfg: env_maker(cfg)),
+        EnvCreator(partial),
         serial_for_single=True,
     )
     parallel_env.set_seed(cfg.env.seed)

-    train_env = apply_env_transforms(
-        parallel_env, max_episode_steps=cfg.env.max_episode_steps
-    )
+    train_env = apply_env_transforms(parallel_env, cfg.env.max_episode_steps)

+    partial = functools.partial(env_maker, cfg=cfg, from_pixels=cfg.logger.video)
+    trsf_clone = train_env.transform.clone()
+    if cfg.logger.video:
+        trsf_clone.insert(
+            0, VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"])
+        )
     eval_env = TransformedEnv(
         ParallelEnv(
             cfg.collector.env_per_collector,
-            EnvCreator(lambda cfg=cfg: env_maker(cfg)),
+            EnvCreator(partial),
             serial_for_single=True,
         ),
-        train_env.transform.clone(),
+        trsf_clone,
     )
     return train_env, eval_env
@@ -297,3 +308,8 @@ def get_activation(cfg):
         return nn.LeakyReLU
     else:
         raise NotImplementedError
+
+
+def dump_video(module):
+    if isinstance(module, VideoRecorder):
+        module.dump()
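Both the SAC and TD3 utilities above build the eval env from a clone of the training transforms and, when cfg.logger.video is set, prepend a VideoRecorder so frames are captured before any observation post-processing; env.apply(dump_video) then works because transforms are registered as nn.Modules. A condensed sketch with a CSV logger and Pendulum as stand-ins (mp4 output presumably needs torchvision):

    import functools

    from torchrl.envs import Compose, EnvCreator, ParallelEnv, RewardSum, TransformedEnv
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.record import VideoRecorder
    from torchrl.record.loggers.csv import CSVLogger

    logger = CSVLogger(exp_name="demo", log_dir="./logs", video_format="mp4")
    maker = functools.partial(GymEnv, "Pendulum-v1", from_pixels=True, pixels_only=False)

    train_env = TransformedEnv(
        ParallelEnv(1, EnvCreator(maker), serial_for_single=True), Compose(RewardSum())
    )
    # Clone the training transforms and record frames before they are modified.
    trsf_clone = train_env.transform.clone()
    trsf_clone.insert(0, VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"]))
    eval_env = TransformedEnv(
        ParallelEnv(1, EnvCreator(maker), serial_for_single=True), trsf_clone
    )
    # After eval_env.rollout(...), eval_env.apply(dump_video) flushes the recorder.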
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 24cf2819eab..ccab829d480 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -929,7 +929,20 @@ def __getattr__(self, attr: str) -> Any:
                 attr
             )  # make sure that appropriate exceptions are raised
         except AttributeError as err:
-            if attr.endswith("_spec"):
+            if attr in (
+                "action_spec",
+                "done_spec",
+                "full_action_spec",
+                "full_done_spec",
+                "full_observation_spec",
+                "full_reward_spec",
+                "full_state_spec",
+                "input_spec",
+                "observation_spec",
+                "output_spec",
+                "reward_spec",
+                "state_spec",
+            ):
                 raise AttributeError(
                     f"Could not get {attr} because an internal error was raised. To find what this error "
                     f"is, call env.transform.transform__spec(env.base_env.spec)."
@@ -3511,9 +3524,10 @@ def func(name, item):
                 item = self._apply_transform(item)
                 tensordict.set(name, item)

-            return tensordict._fast_apply(
+            tensordict._fast_apply(
                 func, named=True, nested_keys=True, filter_empty=True
             )
+            return tensordict
         else:
             # we made sure that if in_keys is not None, out_keys is not None either
             for in_key, out_key in zip(in_keys, out_keys):
diff --git a/torchrl/record/loggers/csv.py b/torchrl/record/loggers/csv.py
index 3f188a02a61..dc3aff2ad6b 100644
--- a/torchrl/record/loggers/csv.py
+++ b/torchrl/record/loggers/csv.py
@@ -43,7 +43,8 @@ def add_scalar(self, name: str, value: float, global_step: Optional[int] = None)
         if not os.path.isfile(filepath):
             os.makedirs(Path(filepath).parent, exist_ok=True)
         if filepath not in self.files:
-            self.files[filepath] = open(filepath, "a")
+            os.makedirs(Path(filepath).parent, exist_ok=True)
+            self.files[filepath] = open(filepath, "a+")
         fd = self.files[filepath]
         fd.write(",".join([str(global_step), str(value)]) + "\n")
         fd.flush()
diff --git a/torchrl/record/loggers/utils.py b/torchrl/record/loggers/utils.py
index ec7321f5bbd..226135f333f 100644
--- a/torchrl/record/loggers/utils.py
+++ b/torchrl/record/loggers/utils.py
@@ -44,7 +44,9 @@ def get_logger(
     elif logger_type == "csv":
         from torchrl.record.loggers.csv import CSVLogger

-        logger = CSVLogger(log_dir=logger_name, exp_name=experiment_name)
+        logger = CSVLogger(
+            log_dir=logger_name, exp_name=experiment_name, video_format="mp4"
+        )
     elif logger_type == "wandb":
         from torchrl.record.loggers.wandb import WandbLogger
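Finally, get_logger now configures the CSV backend to store recorded clips as .mp4 files instead of serialized tensors. The CSV logger needs no external service, which makes it convenient for local runs; for instance (log directory and names are illustrative):

    from torchrl.record.loggers import get_logger

    logger = get_logger(
        logger_type="csv",
        logger_name="./logs",  # used as the CSV log directory
        experiment_name="demo",
    )
    logger.log_scalar("eval/reward", 1.23, step=0)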