diff --git a/examples/td3/config.yaml b/examples/td3/config.yaml
index 71a36dcb69f..642ec65037c 100644
--- a/examples/td3/config.yaml
+++ b/examples/td3/config.yaml
@@ -1,21 +1,15 @@
 # Environment
 env_name: HalfCheetah-v4
-env_task: ""
 exp_name: "debugging"
-env_library: gym
-record_video: 0
-normalize_rewards_online: 0
-normalize_rewards_online_scale: 5
-normalize_rewards_online_decay: 0.99
 total_frames: 1000000
-frames_per_batch: 1000
+frames_per_batch: 200
 max_frames_per_traj: 1000
 frame_skip: 1
 from_pixels: 0
 seed: 0

 # Collection
-init_random_frames: 25000
+init_random_frames: 10000
 init_env_steps: 10000
 record_interval: 10
 record_frames: 10000
@@ -35,7 +29,6 @@ loss_function: smooth_l1
 lr: 3e-4
 weight_decay: 0.0
 lr_scheduler: ""
-optim_steps_per_batch: 128
 batch_size: 256
 target_update_polyak: 0.995

@@ -55,4 +48,4 @@ mode: online
 batch_transform: 1
 buffer_prefetch: 64
 norm_stats: 1
-device: "cpu"
+device: cuda:0
diff --git a/examples/td3/td3.py b/examples/td3/td3.py
index ede0360f65a..9eaa3d7759e 100644
--- a/examples/td3/td3.py
+++ b/examples/td3/td3.py
@@ -13,7 +13,7 @@
 from tensordict.nn import InteractionType
 from torch import nn, optim

-from torchrl.collectors import MultiSyncDataCollector
+from torchrl.collectors import MultiaSyncDataCollector
 from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer
 from torchrl.data.replay_buffers.storages import LazyMemmapStorage

@@ -110,7 +110,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
     np.random.seed(cfg.seed)

     parallel_env = ParallelEnv(
-        cfg.env_per_collector, EnvCreator(lambda: env_maker(task=cfg.env_name))
+        cfg.env_per_collector,
+        EnvCreator(lambda: env_maker(task=cfg.env_name, device=device)),
     )
     parallel_env.set_seed(cfg.seed)

@@ -124,7 +125,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
     eval_env = TransformedEnv(
         ParallelEnv(
-            cfg.env_per_collector, EnvCreator(lambda: env_maker(task=cfg.env_name))
+            cfg.env_per_collector,
+            EnvCreator(lambda: env_maker(task=cfg.env_name, device=device)),
         ),
         train_env.transform.clone(),
     )

@@ -205,15 +207,21 @@ def main(cfg: "DictConfig"):  # noqa: F821
         sigma_init=1,
         sigma_end=1,
         mean=0,
-        std=0.01,
+        std=0.1,
     ).to(device)

     # Create TD3 loss
+    if cfg.loss == "double":
+        double = True
+    elif cfg.loss == "single":
+        double = False
+    else:
+        raise NotImplementedError
     loss_module = TD3Loss(
         actor_network=model[0],
         qvalue_network=model[1],
-        num_qvalue_nets=2,
-        loss_function="smooth_l1",
+        num_qvalue_nets=2 if double else 1,
+        loss_function=cfg.loss_function,
     )
     loss_module.make_value_estimator(gamma=cfg.gamma)

@@ -221,11 +229,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
     target_net_updater = SoftUpdate(loss_module, cfg.target_update_polyak)

     # Make Off-Policy Collector
-    collector = MultiSyncDataCollector(
-        # we'll just run one ParallelEnvironment. Adding elements to the list would increase the number of envs run in parallel
-        [
-            train_env,
-        ],
+    collector = MultiaSyncDataCollector(
+        [train_env],
         actor_model_explore,
         frames_per_batch=cfg.frames_per_batch,
         max_frames_per_traj=cfg.max_frames_per_traj,
@@ -270,13 +275,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
         pbar.update(tensordict.numel())

         # extend the replay buffer with the new data
-        if ("collector", "mask") in tensordict.keys(True):
-            # if multi-step, a mask is present to help filter padded values
-            current_frames = tensordict["collector", "mask"].sum()
-            tensordict = tensordict[tensordict.get(("collector", "mask")).squeeze(-1)]
-        else:
-            tensordict = tensordict.view(-1)
-            current_frames = tensordict.numel()
+        tensordict = tensordict.view(-1)
+        current_frames = tensordict.numel()
         replay_buffer.extend(tensordict.cpu())
         collected_frames += current_frames

@@ -298,11 +298,12 @@ def main(cfg: "DictConfig"):  # noqa: F821
                 q_loss = loss_td["loss_qvalue"]

                 optimizer_critic.zero_grad()
-                q_loss.backward(retain_graph=True)
+                update_actor = i % cfg.policy_update_delay == 0
+                q_loss.backward(retain_graph=update_actor)
                 optimizer_critic.step()
                 q_losses.append(q_loss.item())

-                if i % cfg.policy_update_delay == 0:
+                if update_actor:
                     optimizer_actor.zero_grad()
                     actor_loss.backward()
                     optimizer_actor.step()
diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py
index 605bc0d7af7..e05bf698f6e 100644
--- a/torchrl/objectives/td3.py
+++ b/torchrl/objectives/td3.py
@@ -53,10 +53,10 @@ class TD3Loss(LossModule):
             ``"l1"``, Default is ``"smooth_l1"``.
         delay_actor (bool, optional): whether to separate the target actor
             networks from the actor networks used for
-            data collection. Default is ``False``.
+            data collection. Default is ``True``.
         delay_qvalue (bool, optional): Whether to separate the target Q value
             networks from the Q value networks used
-            for data collection. Default is ``False``.
+            for data collection. Default is ``True``.
     """

     default_value_estimator = ValueEstimators.TD0
@@ -71,8 +71,8 @@ def __init__(
         noise_clip: float = 0.5,
         priority_key: str = "td_error",
         loss_function: str = "smooth_l1",
-        delay_actor: bool = False,
-        delay_qvalue: bool = False,
+        delay_actor: bool = True,
+        delay_qvalue: bool = True,
         gamma: float = None,
     ) -> None:
         if not _has_functorch:
@@ -132,10 +132,11 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             actor_params,
         )
         # add noise to target policy
+        action = actor_output_td[1].get("action")
         noise = torch.normal(
-            mean=torch.zeros(actor_output_td[1]["action"].shape),
-            std=torch.ones(actor_output_td[1]["action"].shape) * self.policy_noise,
-        ).to(actor_output_td[1].device)
+            mean=torch.zeros(action.shape),
+            std=torch.full(action.shape, self.policy_noise),
+        ).to(action.device)
         noise = noise.clamp(-self.noise_clip, self.noise_clip)
         next_action = (actor_output_td[1]["action"] + noise).clamp(
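Note on the training-loop hunk in td3.py above: the critic is now stepped on every optimization iteration, while the actor is stepped only every `cfg.policy_update_delay` iterations, and `retain_graph` is set only when that delayed actor backward pass will still need the graph. The sketch below is a minimal, standalone illustration of that delayed-policy-update pattern in plain PyTorch; the toy `actor`/`critic` modules, dimensions, and random data are assumptions for illustration and are not the objects built in td3.py.

```python
# Minimal sketch of TD3-style delayed policy updates (illustrative only;
# toy networks and random data stand in for the real actor/critic and
# replay-buffer batches).
import torch
from torch import nn, optim

obs_dim, act_dim, batch_size = 17, 6, 256
actor = nn.Linear(obs_dim, act_dim)
critic = nn.Linear(obs_dim + act_dim, 1)
optimizer_actor = optim.Adam(actor.parameters(), lr=3e-4)
optimizer_critic = optim.Adam(critic.parameters(), lr=3e-4)
policy_update_delay = 2  # mirrors cfg.policy_update_delay

for i in range(10):
    obs = torch.randn(batch_size, obs_dim)
    td_target = torch.randn(batch_size, 1)  # placeholder for the bootstrapped TD target

    # Critic update: runs on every iteration.
    q = critic(torch.cat([obs, actor(obs).detach()], dim=-1))
    q_loss = nn.functional.smooth_l1_loss(q, td_target)
    optimizer_critic.zero_grad()
    q_loss.backward()
    optimizer_critic.step()

    # Actor update: runs only every `policy_update_delay` iterations,
    # ascending the critic's value of the actor's own action.
    if i % policy_update_delay == 0:
        actor_loss = -critic(torch.cat([obs, actor(obs)], dim=-1)).mean()
        optimizer_actor.zero_grad()
        actor_loss.backward()
        optimizer_actor.step()
```

Unlike the example script, which reuses one `loss_module` forward for both losses (hence the `retain_graph=update_actor` coupling), this sketch simply recomputes the actor loss when it is needed; the update schedule is the same.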