Skip to content

Commit

Permalink
[BugFix] Fix TD3 inplace updates (#1219)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 1, 2023
1 parent 235a1fa commit 73a4408
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 83 deletions.
163 changes: 82 additions & 81 deletions .circleci/unittest/linux_examples/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,87 +29,87 @@ python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_
python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 20

# With batched environments
python .circleci/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \
env.num_envs=1 \
collector.total_frames=48 \
collector.frames_per_batch=16 \
collector.collector_device=cuda:0 \
optim.device=cuda:0 \
loss.mini_batch_size=10 \
loss.ppo_epochs=1 \
logger.backend= \
logger.log_interval=4 \
optim.lr_scheduler=False
python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
total_frames=48 \
init_random_frames=10 \
batch_size=10 \
frames_per_batch=16 \
num_workers=4 \
env_per_collector=2 \
collector_device=cuda:0 \
optim_steps_per_batch=1 \
record_video=True \
record_frames=4 \
buffer_size=120
python .circleci/unittest/helpers/coverage_run_parallel.py examples/a2c/a2c.py \
env.num_envs=1 \
collector.total_frames=48 \
collector.frames_per_batch=16 \
collector.collector_device=cuda:0 \
logger.backend= \
logger.log_interval=4 \
optim.lr_scheduler=False \
optim.device=cuda:0
python .circleci/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn.py \
total_frames=48 \
init_random_frames=10 \
batch_size=10 \
frames_per_batch=16 \
num_workers=4 \
env_per_collector=2 \
collector_device=cuda:0 \
optim_steps_per_batch=1 \
record_video=True \
record_frames=4 \
buffer_size=120
python .circleci/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \
total_frames=48 \
init_random_frames=10 \
batch_size=10 \
frames_per_batch=16 \
num_workers=4 \
env_per_collector=2 \
collector_device=cuda:0 \
optim_steps_per_batch=1 \
record_video=True \
record_frames=4 \
buffer_size=120
python .circleci/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
total_frames=48 \
init_random_frames=10 \
batch_size=10 \
frames_per_batch=16 \
num_workers=4 \
env_per_collector=2 \
collector_device=cuda:0 \
optim_steps_per_batch=1 \
record_video=True \
record_frames=4 \
buffer_size=120
python .circleci/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \
total_frames=200 \
init_random_frames=10 \
batch_size=10 \
frames_per_batch=200 \
num_workers=4 \
env_per_collector=2 \
collector_device=cuda:0 \
optim_steps_per_batch=1 \
record_video=True \
record_frames=4 \
buffer_size=120 \
rssm_hidden_dim=17
#python .circleci/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \
# env.num_envs=1 \
# collector.total_frames=48 \
# collector.frames_per_batch=16 \
# collector.collector_device=cuda:0 \
# optim.device=cuda:0 \
# loss.mini_batch_size=10 \
# loss.ppo_epochs=1 \
# logger.backend= \
# logger.log_interval=4 \
# optim.lr_scheduler=False
#python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
# total_frames=48 \
# init_random_frames=10 \
# batch_size=10 \
# frames_per_batch=16 \
# num_workers=4 \
# env_per_collector=2 \
# collector_device=cuda:0 \
# optim_steps_per_batch=1 \
# record_video=True \
# record_frames=4 \
# buffer_size=120
#python .circleci/unittest/helpers/coverage_run_parallel.py examples/a2c/a2c.py \
# env.num_envs=1 \
# collector.total_frames=48 \
# collector.frames_per_batch=16 \
# collector.collector_device=cuda:0 \
# logger.backend= \
# logger.log_interval=4 \
# optim.lr_scheduler=False \
# optim.device=cuda:0
#python .circleci/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn.py \
# total_frames=48 \
# init_random_frames=10 \
# batch_size=10 \
# frames_per_batch=16 \
# num_workers=4 \
# env_per_collector=2 \
# collector_device=cuda:0 \
# optim_steps_per_batch=1 \
# record_video=True \
# record_frames=4 \
# buffer_size=120
#python .circleci/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \
# total_frames=48 \
# init_random_frames=10 \
# batch_size=10 \
# frames_per_batch=16 \
# num_workers=4 \
# env_per_collector=2 \
# collector_device=cuda:0 \
# optim_steps_per_batch=1 \
# record_video=True \
# record_frames=4 \
# buffer_size=120
#python .circleci/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
# total_frames=48 \
# init_random_frames=10 \
# batch_size=10 \
# frames_per_batch=16 \
# num_workers=4 \
# env_per_collector=2 \
# collector_device=cuda:0 \
# optim_steps_per_batch=1 \
# record_video=True \
# record_frames=4 \
# buffer_size=120
#python .circleci/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \
# total_frames=200 \
# init_random_frames=10 \
# batch_size=10 \
# frames_per_batch=200 \
# num_workers=4 \
# env_per_collector=2 \
# collector_device=cuda:0 \
# optim_steps_per_batch=1 \
# record_video=True \
# record_frames=4 \
# buffer_size=120 \
# rssm_hidden_dim=17
python .circleci/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \
total_frames=48 \
init_random_frames=10 \
Expand All @@ -118,6 +118,7 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \
num_workers=4 \
env_per_collector=2 \
collector_device=cuda:0 \
device=cuda:0 \
mode=offline
python .circleci/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \
total_frames=48 \
Expand Down
6 changes: 6 additions & 0 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,9 @@ def test_ddpg_tensordict_run(self, td_est):
_ = loss_fn(td)


@pytest.mark.skipif(
not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}"
)
class TestTD3(LossModuleTestBase):
seed = 0

Expand Down Expand Up @@ -1721,6 +1724,9 @@ def test_sac_tensordict_keys(self, td_est, version):
self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping)


@pytest.mark.skipif(
not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}"
)
class TestDiscreteSAC(LossModuleTestBase):
seed = 0

Expand Down
3 changes: 2 additions & 1 deletion torchrl/objectives/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,8 @@ def _target_param_getter(self, network_name):
return target_params
else:
params = getattr(self, param_name)
return params.detach()
# should we clone here?
return params.detach() # .clone()

else:
raise RuntimeError(
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
next_action = (actor_output_td[1][self.tensor_keys.action] + noise).clamp(
-self.max_action, self.max_action
)
actor_output_td[1].set(self.tensor_keys.action, next_action, inplace=True)
actor_output_td[1].set(self.tensor_keys.action, next_action)
tensordict_actor.set(
self.tensor_keys.action,
actor_output_td.get(self.tensor_keys.action),
Expand Down

1 comment on commit 73a4408

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'GPU Benchmark Results'.
The benchmark result for this commit is worse than the previous benchmark result by a ratio exceeding the alert threshold of 2.

| Benchmark suite | Current: 73a4408 | Previous: 235a1fa | Ratio |
|---|---|---|---|
| `benchmarks/test_objectives_benchmarks.py::test_gae_speed[vec_generalized_advantage_estimate-False-32-512]` | `159.2509483635424` iter/sec (stddev: `0.00044593794032505115`) | `368.6238965124557` iter/sec (stddev: `0.0000672945303683467`) | `2.31` |

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.