[BugFix] Fix TD3 target net (#1186)
vmoens authored May 23, 2023
1 parent 079dee7 commit 4ece06c
Showing 3 changed files with 32 additions and 37 deletions.
examples/td3/config.yaml: 13 changes (3 additions, 10 deletions)
@@ -1,21 +1,15 @@
# Environment
env_name: HalfCheetah-v4
env_task: ""
exp_name: "debugging"
env_library: gym
record_video: 0
normalize_rewards_online: 0
normalize_rewards_online_scale: 5
normalize_rewards_online_decay: 0.99
total_frames: 1000000
- frames_per_batch: 1000
+ frames_per_batch: 200
max_frames_per_traj: 1000
frame_skip: 1
from_pixels: 0
seed: 0

# Collection
- init_random_frames: 25000
+ init_random_frames: 10000
init_env_steps: 10000
record_interval: 10
record_frames: 10000
@@ -35,7 +29,6 @@ loss_function: smooth_l1
lr: 3e-4
weight_decay: 0.0
lr_scheduler: ""
- optim_steps_per_batch: 128
batch_size: 256
target_update_polyak: 0.995

@@ -55,4 +48,4 @@ mode: online
batch_transform: 1
buffer_prefetch: 64
norm_stats: 1
device: "cpu"
device: cuda:0
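
The training script appears to parse this file via Hydra/OmegaConf (its entry point takes a DictConfig). Outside the script, a minimal way to inspect or override these hyperparameters is plain OmegaConf; this is a sketch for reference only, not part of the commit, and assumes omegaconf is installed and the repository layout shown above.

```python
# Sketch only: inspect or override the TD3 example config with OmegaConf.
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/td3/config.yaml")
print(cfg.frames_per_batch)       # 200 after this change
print(cfg.target_update_polyak)   # 0.995, consumed by SoftUpdate in td3.py

# Dot-list overrides, e.g. to fall back to CPU and a smaller batch:
cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(["device=cpu", "batch_size=128"]))
```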
examples/td3/td3.py: 41 changes (21 additions, 20 deletions)
@@ -13,7 +13,7 @@
from tensordict.nn import InteractionType

from torch import nn, optim
- from torchrl.collectors import MultiSyncDataCollector
+ from torchrl.collectors import MultiaSyncDataCollector
from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer

from torchrl.data.replay_buffers.storages import LazyMemmapStorage
@@ -110,7 +110,8 @@ def main(cfg: "DictConfig"): # noqa: F821
np.random.seed(cfg.seed)

parallel_env = ParallelEnv(
- cfg.env_per_collector, EnvCreator(lambda: env_maker(task=cfg.env_name))
+ cfg.env_per_collector,
+ EnvCreator(lambda: env_maker(task=cfg.env_name, device=device)),
)
parallel_env.set_seed(cfg.seed)

@@ -124,7 +125,8 @@ def main(cfg: "DictConfig"): # noqa: F821

eval_env = TransformedEnv(
ParallelEnv(
- cfg.env_per_collector, EnvCreator(lambda: env_maker(task=cfg.env_name))
+ cfg.env_per_collector,
+ EnvCreator(lambda: env_maker(task=cfg.env_name, device=device)),
),
train_env.transform.clone(),
)
@@ -205,27 +207,30 @@ def main(cfg: "DictConfig"): # noqa: F821
sigma_init=1,
sigma_end=1,
mean=0,
- std=0.01,
+ std=0.1,
).to(device)

# Create TD3 loss
+ if cfg.loss == "double":
+     double = True
+ elif cfg.loss == "single":
+     double = False
+ else:
+     raise NotImplementedError
loss_module = TD3Loss(
actor_network=model[0],
qvalue_network=model[1],
- num_qvalue_nets=2,
- loss_function="smooth_l1",
+ num_qvalue_nets=2 if double else 1,
+ loss_function=cfg.loss_function,
)
loss_module.make_value_estimator(gamma=cfg.gamma)

# Define Target Network Updater
target_net_updater = SoftUpdate(loss_module, cfg.target_update_polyak)

# Make Off-Policy Collector
- collector = MultiSyncDataCollector(
-     # we'll just run one ParallelEnvironment. Adding elements to the list would increase the number of envs run in parallel
-     [
-         train_env,
-     ],
+ collector = MultiaSyncDataCollector(
+     [train_env],
actor_model_explore,
frames_per_batch=cfg.frames_per_batch,
max_frames_per_traj=cfg.max_frames_per_traj,
Expand Down Expand Up @@ -270,13 +275,8 @@ def main(cfg: "DictConfig"): # noqa: F821
pbar.update(tensordict.numel())

# extend the replay buffer with the new data
if ("collector", "mask") in tensordict.keys(True):
# if multi-step, a mask is present to help filter padded values
current_frames = tensordict["collector", "mask"].sum()
tensordict = tensordict[tensordict.get(("collector", "mask")).squeeze(-1)]
else:
tensordict = tensordict.view(-1)
current_frames = tensordict.numel()
tensordict = tensordict.view(-1)
current_frames = tensordict.numel()
replay_buffer.extend(tensordict.cpu())
collected_frames += current_frames

@@ -298,11 +298,12 @@ def main(cfg: "DictConfig"): # noqa: F821
q_loss = loss_td["loss_qvalue"]

optimizer_critic.zero_grad()
- q_loss.backward(retain_graph=True)
+ update_actor = i % cfg.policy_update_delay == 0
+ q_loss.backward(retain_graph=update_actor)
optimizer_critic.step()
q_losses.append(q_loss.item())

- if i % cfg.policy_update_delay == 0:
+ if update_actor:
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()
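
Taken together, the example now relies on two standard TD3 ingredients: a soft (polyak) target-network update via SoftUpdate with target_update_polyak, and an actor update delayed by policy_update_delay critic steps. The snippet below is a minimal plain-PyTorch sketch of both ideas, not the torchrl implementation; the toy networks and the policy_update_delay value are assumptions.

```python
import torch
from torch import nn

# Toy stand-ins for the critic and its target network (illustrative only).
qnet = nn.Linear(4, 1)
qnet_target = nn.Linear(4, 1)
qnet_target.load_state_dict(qnet.state_dict())

polyak = 0.995            # mirrors target_update_polyak in the config above
policy_update_delay = 2   # assumed value for cfg.policy_update_delay


@torch.no_grad()
def soft_update(net: nn.Module, target: nn.Module, tau: float) -> None:
    # theta_target <- tau * theta_target + (1 - tau) * theta
    for p, p_t in zip(net.parameters(), target.parameters()):
        p_t.mul_(tau).add_(p, alpha=1.0 - tau)


for i in range(10):
    # ... critic loss, backward pass and optimizer step would go here ...
    update_actor = i % policy_update_delay == 0
    if update_actor:
        # ... delayed actor loss, backward pass and optimizer step would go here ...
        pass
    # soft target update after each critic step
    soft_update(qnet, qnet_target, polyak)
```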
torchrl/objectives/td3.py: 15 changes (8 additions, 7 deletions)
@@ -53,10 +53,10 @@ class TD3Loss(LossModule):
``"l1"``, Default is ``"smooth_l1"``.
delay_actor (bool, optional): whether to separate the target actor
networks from the actor networks used for
- data collection. Default is ``False``.
+ data collection. Default is ``True``.
delay_qvalue (bool, optional): Whether to separate the target Q value
networks from the Q value networks used
- for data collection. Default is ``False``.
+ for data collection. Default is ``True``.
"""

default_value_estimator = ValueEstimators.TD0
@@ -71,8 +71,8 @@ def __init__(
noise_clip: float = 0.5,
priority_key: str = "td_error",
loss_function: str = "smooth_l1",
- delay_actor: bool = False,
- delay_qvalue: bool = False,
+ delay_actor: bool = True,
+ delay_qvalue: bool = True,
gamma: float = None,
) -> None:
if not _has_functorch:
@@ -132,10 +132,11 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
actor_params,
)
# add noise to target policy
+ action = actor_output_td[1].get("action")
noise = torch.normal(
-     mean=torch.zeros(actor_output_td[1]["action"].shape),
-     std=torch.ones(actor_output_td[1]["action"].shape) * self.policy_noise,
- ).to(actor_output_td[1].device)
+     mean=torch.zeros(action.shape),
+     std=torch.full(action.shape, self.policy_noise),
+ ).to(action.device)
noise = noise.clamp(-self.noise_clip, self.noise_clip)

next_action = (actor_output_td[1]["action"] + noise).clamp(
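
The reworked noise block above is TD3's target-policy smoothing: zero-mean Gaussian noise with per-element std policy_noise is added to the target actor's action and clipped to [-noise_clip, noise_clip] before the sum is clamped to the action bounds. A self-contained sketch of the same computation follows; the noise scale, clip value and action bounds are illustrative, not read from the library.

```python
import torch

policy_noise, noise_clip = 0.2, 0.5   # illustrative values
action_low, action_high = -1.0, 1.0   # illustrative action bounds

# Stand-in for the target actor's output: a batch of 8 actions of dim 3.
action = torch.rand(8, 3) * 2 - 1

# Same construction as the patched lines: zero-mean noise with std policy_noise,
# clipped, then added to the action before clamping to the bounds.
noise = torch.normal(
    mean=torch.zeros(action.shape),
    std=torch.full(action.shape, policy_noise),
).clamp(-noise_clip, noise_clip)

next_action = (action + noise).clamp(action_low, action_high)
print(next_action.shape)  # torch.Size([8, 3])
```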

2 comments on commit 4ece06c

@github-actions

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
The benchmark result of this commit is worse than the previous result by more than the alert threshold of 2.

Benchmark suite: benchmarks/test_objectives_benchmarks.py::test_values[vec_td1_return_estimate-False-False]
Current (4ece06c): 5.112405251834303 iter/sec (stddev: 0.007928172008496546)
Previous (ae10bb8): 11.599514688182666 iter/sec (stddev: 0.0028920320909638386)
Ratio: 2.27

This comment was automatically generated by a workflow using github-action-benchmark.

CC: @vmoens

@github-actions

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'GPU Benchmark Results'.
The benchmark result of this commit is worse than the previous result by more than the alert threshold of 2.

Benchmark suite: benchmarks/test_objectives_benchmarks.py::test_values[vec_td1_return_estimate-False-False]
Current (4ece06c): 5.683931101821377 iter/sec (stddev: 0.006199355694531112)
Previous (b4cf4d7): 12.103495997917573 iter/sec (stddev: 0.002614638756833848)
Ratio: 2.13

This comment was automatically generated by a workflow using github-action-benchmark.

CC: @vmoens
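
For context, the reported ratio appears to be previous throughput divided by current throughput, which matches the numbers in both alerts; a quick sanity check, not part of the generated reports:

```python
# Reported iter/sec values copied from the two alerts above.
cpu_ratio = 11.599514688182666 / 5.112405251834303
gpu_ratio = 12.103495997917573 / 5.683931101821377
print(round(cpu_ratio, 2), round(gpu_ratio, 2))  # 2.27 2.13
```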
