Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Fix discrete SAC log-prob #1750

Merged
merged 4 commits into from
Dec 17, 2023
Merged

[BugFix] Fix discrete SAC log-prob #1750

merged 4 commits into from
Dec 17, 2023

Conversation

vmoens
Copy link
Contributor

@vmoens vmoens commented Dec 16, 2023

Copy link

pytorch-bot bot commented Dec 16, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/1750

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (9 Unrelated Failures)

As of commit b1d5efd with merge base 08f0bed (image):

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 16, 2023
@vmoens vmoens added the bug Something isn't working label Dec 16, 2023
@vmoens vmoens marked this pull request as ready for review December 16, 2023 07:47
@matteobettini
Copy link
Contributor

matteobettini commented Dec 16, 2023

I don't get the fix.

We need both prob and log_prob

We also need to rerun the example training run for discrete SAC and confirm the returns curve match (if we change the logic)

Copy link

github-actions bot commented Dec 16, 2023

$\color{#D29922}\textsf{\Large⚠\kern{0.2cm}\normalsize Warning}$ Result of CPU Benchmark Tests

Total Benchmarks: 89. Improved: $\large\color{#35bf28}3$. Worsened: $\large\color{#d91a1a}9$.

Expand to view detailed results
Name Max Mean Ops Ops on Repo HEAD Change
test_single 64.2635ms 62.4888ms 16.0029 Ops/s 14.9632 Ops/s $\textbf{\color{#35bf28}+6.95\%}$
test_sync 49.0949ms 39.6367ms 25.2292 Ops/s 28.4795 Ops/s $\textbf{\color{#d91a1a}-11.41\%}$
test_async 73.7618ms 33.0609ms 30.2472 Ops/s 29.8029 Ops/s $\color{#35bf28}+1.49\%$
test_simple 0.4801s 0.4288s 2.3322 Ops/s 2.2666 Ops/s $\color{#35bf28}+2.90\%$
test_transformed 0.6397s 0.5915s 1.6905 Ops/s 1.6396 Ops/s $\color{#35bf28}+3.10\%$
test_serial 1.3595s 1.3152s 0.7604 Ops/s 0.7314 Ops/s $\color{#35bf28}+3.96\%$
test_parallel 1.3687s 1.3176s 0.7589 Ops/s 0.7578 Ops/s $\color{#35bf28}+0.15\%$
test_step_mdp_speed[True-True-True-True-True] 0.1418ms 21.6600μs 46.1681 KOps/s 44.4770 KOps/s $\color{#35bf28}+3.80\%$
test_step_mdp_speed[True-True-True-True-False] 37.7300μs 13.1427μs 76.0876 KOps/s 73.5524 KOps/s $\color{#35bf28}+3.45\%$
test_step_mdp_speed[True-True-True-False-True] 64.8610μs 12.7680μs 78.3210 KOps/s 76.9708 KOps/s $\color{#35bf28}+1.75\%$
test_step_mdp_speed[True-True-True-False-False] 30.3060μs 7.7905μs 128.3616 KOps/s 125.9262 KOps/s $\color{#35bf28}+1.93\%$
test_step_mdp_speed[True-True-False-True-True] 49.2110μs 23.3128μs 42.8949 KOps/s 42.2837 KOps/s $\color{#35bf28}+1.45\%$
test_step_mdp_speed[True-True-False-True-False] 42.8290μs 14.5190μs 68.8751 KOps/s 67.7866 KOps/s $\color{#35bf28}+1.61\%$
test_step_mdp_speed[True-True-False-False-True] 40.6250μs 14.0778μs 71.0338 KOps/s 69.4431 KOps/s $\color{#35bf28}+2.29\%$
test_step_mdp_speed[True-True-False-False-False] 50.8240μs 8.9608μs 111.5978 KOps/s 108.1746 KOps/s $\color{#35bf28}+3.16\%$
test_step_mdp_speed[True-False-True-True-True] 54.1700μs 24.3511μs 41.0659 KOps/s 39.6187 KOps/s $\color{#35bf28}+3.65\%$
test_step_mdp_speed[True-False-True-True-False] 46.2250μs 15.7628μs 63.4405 KOps/s 60.3272 KOps/s $\textbf{\color{#35bf28}+5.16\%}$
test_step_mdp_speed[True-False-True-False-True] 46.7270μs 13.9694μs 71.5852 KOps/s 69.4363 KOps/s $\color{#35bf28}+3.09\%$
test_step_mdp_speed[True-False-True-False-False] 30.4770μs 8.9031μs 112.3198 KOps/s 108.4366 KOps/s $\color{#35bf28}+3.58\%$
test_step_mdp_speed[True-False-False-True-True] 65.7620μs 25.4836μs 39.2409 KOps/s 37.8585 KOps/s $\color{#35bf28}+3.65\%$
test_step_mdp_speed[True-False-False-True-False] 48.8810μs 17.3200μs 57.7368 KOps/s 57.9699 KOps/s $\color{#d91a1a}-0.40\%$
test_step_mdp_speed[True-False-False-False-True] 48.3000μs 15.3369μs 65.2022 KOps/s 65.4564 KOps/s $\color{#d91a1a}-0.39\%$
test_step_mdp_speed[True-False-False-False-False] 28.0920μs 10.1838μs 98.1951 KOps/s 96.4682 KOps/s $\color{#35bf28}+1.79\%$
test_step_mdp_speed[False-True-True-True-True] 58.0980μs 24.4455μs 40.9074 KOps/s 39.8057 KOps/s $\color{#35bf28}+2.77\%$
test_step_mdp_speed[False-True-True-True-False] 41.7970μs 15.7826μs 63.3611 KOps/s 61.4548 KOps/s $\color{#35bf28}+3.10\%$
test_step_mdp_speed[False-True-True-False-True] 41.2460μs 16.1979μs 61.7364 KOps/s 60.0874 KOps/s $\color{#35bf28}+2.74\%$
test_step_mdp_speed[False-True-True-False-False] 39.7640μs 10.1415μs 98.6052 KOps/s 95.1077 KOps/s $\color{#35bf28}+3.68\%$
test_step_mdp_speed[False-True-False-True-True] 60.8020μs 25.5739μs 39.1023 KOps/s 38.4368 KOps/s $\color{#35bf28}+1.73\%$
test_step_mdp_speed[False-True-False-True-False] 83.6040μs 16.8794μs 59.2437 KOps/s 56.7805 KOps/s $\color{#35bf28}+4.34\%$
test_step_mdp_speed[False-True-False-False-True] 45.7450μs 17.2706μs 57.9019 KOps/s 56.3098 KOps/s $\color{#35bf28}+2.83\%$
test_step_mdp_speed[False-True-False-False-False] 46.8270μs 11.4346μs 87.4539 KOps/s 85.6275 KOps/s $\color{#35bf28}+2.13\%$
test_step_mdp_speed[False-False-True-True-True] 62.9460μs 26.6371μs 37.5417 KOps/s 36.3695 KOps/s $\color{#35bf28}+3.22\%$
test_step_mdp_speed[False-False-True-True-False] 51.9160μs 18.1658μs 55.0484 KOps/s 53.2586 KOps/s $\color{#35bf28}+3.36\%$
test_step_mdp_speed[False-False-True-False-True] 46.3550μs 17.2181μs 58.0783 KOps/s 56.2409 KOps/s $\color{#35bf28}+3.27\%$
test_step_mdp_speed[False-False-True-False-False] 42.2980μs 11.4152μs 87.6025 KOps/s 86.0459 KOps/s $\color{#35bf28}+1.81\%$
test_step_mdp_speed[False-False-False-True-True] 61.6440μs 27.6927μs 36.1106 KOps/s 34.7724 KOps/s $\color{#35bf28}+3.85\%$
test_step_mdp_speed[False-False-False-True-False] 46.4460μs 19.1266μs 52.2832 KOps/s 49.9040 KOps/s $\color{#35bf28}+4.77\%$
test_step_mdp_speed[False-False-False-False-True] 41.5060μs 18.2623μs 54.7576 KOps/s 53.5322 KOps/s $\color{#35bf28}+2.29\%$
test_step_mdp_speed[False-False-False-False-False] 38.7420μs 12.3290μs 81.1094 KOps/s 78.4739 KOps/s $\color{#35bf28}+3.36\%$
test_values[generalized_advantage_estimate-True-True] 13.4371ms 11.7565ms 85.0593 Ops/s 83.8501 Ops/s $\color{#35bf28}+1.44\%$
test_values[vec_generalized_advantage_estimate-True-True] 35.8042ms 27.6514ms 36.1645 Ops/s 35.5819 Ops/s $\color{#35bf28}+1.64\%$
test_values[td0_return_estimate-False-False] 0.2446ms 0.1787ms 5.5952 KOps/s 5.6973 KOps/s $\color{#d91a1a}-1.79\%$
test_values[td1_return_estimate-False-False] 28.0039ms 25.1326ms 39.7889 Ops/s 40.3305 Ops/s $\color{#d91a1a}-1.34\%$
test_values[vec_td1_return_estimate-False-False] 35.4438ms 27.6367ms 36.1837 Ops/s 35.9245 Ops/s $\color{#35bf28}+0.72\%$
test_values[td_lambda_return_estimate-True-False] 37.8463ms 35.2174ms 28.3951 Ops/s 28.8676 Ops/s $\color{#d91a1a}-1.64\%$
test_values[vec_td_lambda_return_estimate-True-False] 35.3175ms 27.8499ms 35.9067 Ops/s 35.8794 Ops/s $\color{#35bf28}+0.08\%$
test_gae_speed[generalized_advantage_estimate-False-1-512] 10.7644ms 7.9688ms 125.4892 Ops/s 125.7323 Ops/s $\color{#d91a1a}-0.19\%$
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 2.1310ms 1.8514ms 540.1202 Ops/s 529.3634 Ops/s $\color{#35bf28}+2.03\%$
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] 8.4367ms 0.4285ms 2.3339 KOps/s 2.3656 KOps/s $\color{#d91a1a}-1.34\%$
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] 42.6618ms 40.2227ms 24.8616 Ops/s 24.5251 Ops/s $\color{#35bf28}+1.37\%$
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 10.6486ms 2.6293ms 380.3298 Ops/s 380.5214 Ops/s $\color{#d91a1a}-0.05\%$
test_dqn_speed 10.1256ms 1.6173ms 618.3072 Ops/s 613.4327 Ops/s $\color{#35bf28}+0.79\%$
test_ddpg_speed 11.7582ms 3.6440ms 274.4261 Ops/s 270.8108 Ops/s $\color{#35bf28}+1.33\%$
test_sac_speed 18.3266ms 10.1054ms 98.9572 Ops/s 96.2836 Ops/s $\color{#35bf28}+2.78\%$
test_redq_speed 26.2030ms 18.9654ms 52.7275 Ops/s 52.5584 Ops/s $\color{#35bf28}+0.32\%$
test_redq_deprec_speed 55.0603ms 16.4628ms 60.7430 Ops/s 65.6829 Ops/s $\textbf{\color{#d91a1a}-7.52\%}$
test_td3_speed 17.6285ms 10.3270ms 96.8334 Ops/s 94.2841 Ops/s $\color{#35bf28}+2.70\%$
test_cql_speed 49.7929ms 41.2475ms 24.2439 Ops/s 24.1736 Ops/s $\color{#35bf28}+0.29\%$
test_a2c_speed 16.3872ms 8.6653ms 115.4025 Ops/s 106.8153 Ops/s $\textbf{\color{#35bf28}+8.04\%}$
test_ppo_speed 17.6925ms 8.9485ms 111.7511 Ops/s 111.6168 Ops/s $\color{#35bf28}+0.12\%$
test_reinforce_speed 15.9228ms 7.7474ms 129.0753 Ops/s 128.1444 Ops/s $\color{#35bf28}+0.73\%$
test_iql_speed 44.0745ms 35.0420ms 28.5371 Ops/s 28.4136 Ops/s $\color{#35bf28}+0.43\%$
test_sample_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 2.3605ms 1.8556ms 538.9080 Ops/s 542.2653 Ops/s $\color{#d91a1a}-0.62\%$
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 99.5453ms 2.1639ms 462.1194 Ops/s 507.5180 Ops/s $\textbf{\color{#d91a1a}-8.95\%}$
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 2.4933ms 1.9427ms 514.7603 Ops/s 510.0217 Ops/s $\color{#35bf28}+0.93\%$
test_sample_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 2.1568ms 1.8751ms 533.2977 Ops/s 544.3832 Ops/s $\color{#d91a1a}-2.04\%$
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 96.7864ms 2.1823ms 458.2387 Ops/s 515.8620 Ops/s $\textbf{\color{#d91a1a}-11.17\%}$
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 2.7504ms 1.9952ms 501.2091 Ops/s 503.6977 Ops/s $\color{#d91a1a}-0.49\%$
test_sample_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 2.2222ms 1.8723ms 534.0956 Ops/s 536.4281 Ops/s $\color{#d91a1a}-0.43\%$
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 0.1085s 2.1749ms 459.7978 Ops/s 508.9358 Ops/s $\textbf{\color{#d91a1a}-9.66\%}$
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 4.0761ms 1.9883ms 502.9426 Ops/s 507.3924 Ops/s $\color{#d91a1a}-0.88\%$
test_iterate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 2.4102ms 1.8516ms 540.0875 Ops/s 537.8427 Ops/s $\color{#35bf28}+0.42\%$
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 99.3143ms 2.1250ms 470.5928 Ops/s 519.1069 Ops/s $\textbf{\color{#d91a1a}-9.35\%}$
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 3.0809ms 2.0016ms 499.6071 Ops/s 506.5899 Ops/s $\color{#d91a1a}-1.38\%$
test_iterate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 3.2243ms 1.8804ms 531.8042 Ops/s 537.9029 Ops/s $\color{#d91a1a}-1.13\%$
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 0.1198s 2.2331ms 447.8047 Ops/s 504.0018 Ops/s $\textbf{\color{#d91a1a}-11.15\%}$
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 2.9871ms 1.9741ms 506.5717 Ops/s 502.0516 Ops/s $\color{#35bf28}+0.90\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 2.3913ms 1.8626ms 536.8811 Ops/s 536.5029 Ops/s $\color{#35bf28}+0.07\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 99.8292ms 2.1697ms 460.9009 Ops/s 504.8599 Ops/s $\textbf{\color{#d91a1a}-8.71\%}$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 3.2680ms 1.9684ms 508.0394 Ops/s 504.0233 Ops/s $\color{#35bf28}+0.80\%$
test_populate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 0.1396s 16.3396ms 61.2010 Ops/s 58.3545 Ops/s $\color{#35bf28}+4.88\%$
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 95.7047ms 15.4853ms 64.5775 Ops/s 64.5180 Ops/s $\color{#35bf28}+0.09\%$
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 99.3906ms 15.5963ms 64.1179 Ops/s 64.3261 Ops/s $\color{#d91a1a}-0.32\%$
test_populate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 95.9221ms 15.4698ms 64.6420 Ops/s 64.9871 Ops/s $\color{#d91a1a}-0.53\%$
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 96.7391ms 15.4273ms 64.8200 Ops/s 64.2664 Ops/s $\color{#35bf28}+0.86\%$
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] 96.9632ms 15.5165ms 64.4474 Ops/s 63.1760 Ops/s $\color{#35bf28}+2.01\%$
test_populate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 99.2836ms 15.6622ms 63.8482 Ops/s 63.3465 Ops/s $\color{#35bf28}+0.79\%$
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] 0.1019s 17.5400ms 57.0125 Ops/s 62.5557 Ops/s $\textbf{\color{#d91a1a}-8.86\%}$
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 0.1025s 15.6841ms 63.7590 Ops/s 63.6002 Ops/s $\color{#35bf28}+0.25\%$

Co-authored-by: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
Copy link
Contributor

@matteobettini matteobettini left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

If we want to be extra careful we should run the discrete_sac example to check

Copy link

github-actions bot commented Dec 16, 2023

$\color{#D29922}\textsf{\Large&amp;#x26A0;\kern{0.2cm}\normalsize Warning}$ Result of GPU Benchmark Tests

Total Benchmarks: 92. Improved: $\large\color{#35bf28}4$. Worsened: $\large\color{#d91a1a}3$.

Expand to view detailed results
Name Max Mean Ops Ops on Repo HEAD Change
test_single 0.1284s 0.1257s 7.9544 Ops/s 8.1627 Ops/s $\color{#d91a1a}-2.55\%$
test_sync 0.1020s 0.1017s 9.8316 Ops/s 9.7076 Ops/s $\color{#35bf28}+1.28\%$
test_async 0.2736s 99.1462ms 10.0861 Ops/s 9.9670 Ops/s $\color{#35bf28}+1.19\%$
test_single_pixels 0.1475s 0.1473s 6.7885 Ops/s 7.5393 Ops/s $\textbf{\color{#d91a1a}-9.96\%}$
test_sync_pixels 96.7164ms 95.1649ms 10.5081 Ops/s 10.4573 Ops/s $\color{#35bf28}+0.49\%$
test_async_pixels 0.2472s 91.6094ms 10.9159 Ops/s 10.9835 Ops/s $\color{#d91a1a}-0.62\%$
test_simple 1.0020s 0.9232s 1.0831 Ops/s 1.1248 Ops/s $\color{#d91a1a}-3.71\%$
test_transformed 1.2231s 1.1597s 0.8623 Ops/s 0.8798 Ops/s $\color{#d91a1a}-1.99\%$
test_serial 2.7291s 2.6498s 0.3774 Ops/s 0.3974 Ops/s $\textbf{\color{#d91a1a}-5.04\%}$
test_parallel 2.6072s 2.5094s 0.3985 Ops/s 0.4004 Ops/s $\color{#d91a1a}-0.48\%$
test_step_mdp_speed[True-True-True-True-True] 80.9510μs 32.7507μs 30.5337 KOps/s 30.6682 KOps/s $\color{#d91a1a}-0.44\%$
test_step_mdp_speed[True-True-True-True-False] 44.0420μs 19.4647μs 51.3750 KOps/s 51.1065 KOps/s $\color{#35bf28}+0.53\%$
test_step_mdp_speed[True-True-True-False-True] 41.2610μs 18.7897μs 53.2207 KOps/s 52.7940 KOps/s $\color{#35bf28}+0.81\%$
test_step_mdp_speed[True-True-True-False-False] 41.5510μs 11.2542μs 88.8555 KOps/s 91.2921 KOps/s $\color{#d91a1a}-2.67\%$
test_step_mdp_speed[True-True-False-True-True] 55.4100μs 34.4376μs 29.0380 KOps/s 29.1441 KOps/s $\color{#d91a1a}-0.36\%$
test_step_mdp_speed[True-True-False-True-False] 44.3700μs 21.1683μs 47.2405 KOps/s 47.2263 KOps/s $\color{#35bf28}+0.03\%$
test_step_mdp_speed[True-True-False-False-True] 43.9400μs 20.7194μs 48.2639 KOps/s 47.7585 KOps/s $\color{#35bf28}+1.06\%$
test_step_mdp_speed[True-True-False-False-False] 51.2720μs 13.3628μs 74.8344 KOps/s 77.9621 KOps/s $\color{#d91a1a}-4.01\%$
test_step_mdp_speed[True-False-True-True-True] 63.0100μs 36.3855μs 27.4835 KOps/s 27.8756 KOps/s $\color{#d91a1a}-1.41\%$
test_step_mdp_speed[True-False-True-True-False] 45.9310μs 23.1355μs 43.2236 KOps/s 43.2344 KOps/s $\color{#d91a1a}-0.03\%$
test_step_mdp_speed[True-False-True-False-True] 50.7710μs 20.8876μs 47.8753 KOps/s 49.0465 KOps/s $\color{#d91a1a}-2.39\%$
test_step_mdp_speed[True-False-True-False-False] 37.5200μs 13.1746μs 75.9034 KOps/s 77.8939 KOps/s $\color{#d91a1a}-2.56\%$
test_step_mdp_speed[True-False-False-True-True] 59.3500μs 38.3940μs 26.0458 KOps/s 26.3723 KOps/s $\color{#d91a1a}-1.24\%$
test_step_mdp_speed[True-False-False-True-False] 56.8110μs 25.2996μs 39.5263 KOps/s 40.3223 KOps/s $\color{#d91a1a}-1.97\%$
test_step_mdp_speed[True-False-False-False-True] 45.5610μs 22.7204μs 44.0133 KOps/s 44.4055 KOps/s $\color{#d91a1a}-0.88\%$
test_step_mdp_speed[True-False-False-False-False] 34.5420μs 14.5444μs 68.7551 KOps/s 68.1274 KOps/s $\color{#35bf28}+0.92\%$
test_step_mdp_speed[False-True-True-True-True] 60.6700μs 35.9460μs 27.8195 KOps/s 27.7170 KOps/s $\color{#35bf28}+0.37\%$
test_step_mdp_speed[False-True-True-True-False] 44.6510μs 23.3505μs 42.8256 KOps/s 43.7148 KOps/s $\color{#d91a1a}-2.03\%$
test_step_mdp_speed[False-True-True-False-True] 45.5400μs 24.6725μs 40.5310 KOps/s 41.4481 KOps/s $\color{#d91a1a}-2.21\%$
test_step_mdp_speed[False-True-True-False-False] 33.4410μs 14.6433μs 68.2907 KOps/s 67.5841 KOps/s $\color{#35bf28}+1.05\%$
test_step_mdp_speed[False-True-False-True-True] 89.2110μs 38.5272μs 25.9557 KOps/s 26.2951 KOps/s $\color{#d91a1a}-1.29\%$
test_step_mdp_speed[False-True-False-True-False] 43.7900μs 25.2314μs 39.6331 KOps/s 40.2571 KOps/s $\color{#d91a1a}-1.55\%$
test_step_mdp_speed[False-True-False-False-True] 51.4610μs 25.8508μs 38.6835 KOps/s 38.1406 KOps/s $\color{#35bf28}+1.42\%$
test_step_mdp_speed[False-True-False-False-False] 35.3910μs 16.7829μs 59.5844 KOps/s 61.1219 KOps/s $\color{#d91a1a}-2.52\%$
test_step_mdp_speed[False-False-True-True-True] 60.7800μs 40.1102μs 24.9313 KOps/s 25.4296 KOps/s $\color{#d91a1a}-1.96\%$
test_step_mdp_speed[False-False-True-True-False] 49.6800μs 27.1431μs 36.8417 KOps/s 37.4096 KOps/s $\color{#d91a1a}-1.52\%$
test_step_mdp_speed[False-False-True-False-True] 0.1021ms 26.2919μs 38.0346 KOps/s 38.4245 KOps/s $\color{#d91a1a}-1.01\%$
test_step_mdp_speed[False-False-True-False-False] 37.6800μs 16.8699μs 59.2771 KOps/s 60.9477 KOps/s $\color{#d91a1a}-2.74\%$
test_step_mdp_speed[False-False-False-True-True] 68.1610μs 40.9772μs 24.4038 KOps/s 24.2955 KOps/s $\color{#35bf28}+0.45\%$
test_step_mdp_speed[False-False-False-True-False] 60.9620μs 28.4801μs 35.1122 KOps/s 35.1922 KOps/s $\color{#d91a1a}-0.23\%$
test_step_mdp_speed[False-False-False-False-True] 58.2820μs 27.7324μs 36.0590 KOps/s 36.4793 KOps/s $\color{#d91a1a}-1.15\%$
test_step_mdp_speed[False-False-False-False-False] 39.1720μs 18.5209μs 53.9930 KOps/s 55.6885 KOps/s $\color{#d91a1a}-3.04\%$
test_values[generalized_advantage_estimate-True-True] 28.2647ms 27.1010ms 36.8990 Ops/s 38.1818 Ops/s $\color{#d91a1a}-3.36\%$
test_values[vec_generalized_advantage_estimate-True-True] 96.4090ms 3.5209ms 284.0195 Ops/s 297.3413 Ops/s $\color{#d91a1a}-4.48\%$
test_values[td0_return_estimate-False-False] 0.1023ms 67.9680μs 14.7128 KOps/s 15.0931 KOps/s $\color{#d91a1a}-2.52\%$
test_values[td1_return_estimate-False-False] 60.8722ms 58.5337ms 17.0842 Ops/s 16.5268 Ops/s $\color{#35bf28}+3.37\%$
test_values[vec_td1_return_estimate-False-False] 2.0874ms 1.7992ms 555.8107 Ops/s 556.8632 Ops/s $\color{#d91a1a}-0.19\%$
test_values[td_lambda_return_estimate-True-False] 96.1933ms 92.5187ms 10.8086 Ops/s 10.4449 Ops/s $\color{#35bf28}+3.48\%$
test_values[vec_td_lambda_return_estimate-True-False] 2.1080ms 1.8196ms 549.5798 Ops/s 551.9136 Ops/s $\color{#d91a1a}-0.42\%$
test_gae_speed[generalized_advantage_estimate-False-1-512] 25.6983ms 25.4829ms 39.2421 Ops/s 36.6769 Ops/s $\textbf{\color{#35bf28}+6.99\%}$
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 0.9186ms 0.7375ms 1.3560 KOps/s 1.2925 KOps/s $\color{#35bf28}+4.91\%$
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] 0.7467ms 0.6909ms 1.4473 KOps/s 1.3909 KOps/s $\color{#35bf28}+4.06\%$
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] 1.5652ms 1.4889ms 671.6171 Ops/s 674.4053 Ops/s $\color{#d91a1a}-0.41\%$
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 1.0680ms 0.7215ms 1.3859 KOps/s 1.4215 KOps/s $\color{#d91a1a}-2.50\%$
test_dqn_speed 7.8889ms 1.4967ms 668.1287 Ops/s 684.1421 Ops/s $\color{#d91a1a}-2.34\%$
test_ddpg_speed 4.3545ms 3.3713ms 296.6248 Ops/s 281.4170 Ops/s $\textbf{\color{#35bf28}+5.40\%}$
test_sac_speed 95.3695ms 10.3386ms 96.7246 Ops/s 108.4807 Ops/s $\textbf{\color{#d91a1a}-10.84\%}$
test_redq_speed 17.6201ms 16.8256ms 59.4334 Ops/s 60.9994 Ops/s $\color{#d91a1a}-2.57\%$
test_redq_deprec_speed 13.8919ms 13.1404ms 76.1015 Ops/s 77.1527 Ops/s $\color{#d91a1a}-1.36\%$
test_td3_speed 19.3772ms 9.7962ms 102.0808 Ops/s 105.2477 Ops/s $\color{#d91a1a}-3.01\%$
test_cql_speed 34.9763ms 33.9319ms 29.4708 Ops/s 29.8122 Ops/s $\color{#d91a1a}-1.15\%$
test_a2c_speed 8.0086ms 7.0726ms 141.3908 Ops/s 140.1266 Ops/s $\color{#35bf28}+0.90\%$
test_ppo_speed 8.3746ms 7.4332ms 134.5308 Ops/s 134.0231 Ops/s $\color{#35bf28}+0.38\%$
test_reinforce_speed 7.3177ms 6.1169ms 163.4813 Ops/s 162.4654 Ops/s $\color{#35bf28}+0.63\%$
test_iql_speed 28.5754ms 27.1849ms 36.7851 Ops/s 37.0608 Ops/s $\color{#d91a1a}-0.74\%$
test_sample_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 3.1689ms 2.5806ms 387.5018 Ops/s 399.1688 Ops/s $\color{#d91a1a}-2.92\%$
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 4.0238ms 2.7402ms 364.9383 Ops/s 333.9103 Ops/s $\textbf{\color{#35bf28}+9.29\%}$
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 4.4665ms 2.7529ms 363.2590 Ops/s 376.8130 Ops/s $\color{#d91a1a}-3.60\%$
test_sample_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 3.1598ms 2.5686ms 389.3098 Ops/s 403.0050 Ops/s $\color{#d91a1a}-3.40\%$
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 3.6987ms 2.7354ms 365.5723 Ops/s 375.9474 Ops/s $\color{#d91a1a}-2.76\%$
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 3.8128ms 2.7430ms 364.5627 Ops/s 373.7324 Ops/s $\color{#d91a1a}-2.45\%$
test_sample_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 3.2285ms 2.5634ms 390.1008 Ops/s 399.5665 Ops/s $\color{#d91a1a}-2.37\%$
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 3.9081ms 2.7414ms 364.7833 Ops/s 375.9346 Ops/s $\color{#d91a1a}-2.97\%$
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 3.8855ms 2.7408ms 364.8583 Ops/s 373.5261 Ops/s $\color{#d91a1a}-2.32\%$
test_iterate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 3.2239ms 2.5446ms 392.9889 Ops/s 398.8909 Ops/s $\color{#d91a1a}-1.48\%$
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 3.8034ms 2.7363ms 365.4569 Ops/s 375.1835 Ops/s $\color{#d91a1a}-2.59\%$
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 4.0658ms 2.7605ms 362.2478 Ops/s 374.0149 Ops/s $\color{#d91a1a}-3.15\%$
test_iterate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 3.0107ms 2.5696ms 389.1695 Ops/s 400.7725 Ops/s $\color{#d91a1a}-2.90\%$
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 3.7686ms 2.7331ms 365.8853 Ops/s 372.5148 Ops/s $\color{#d91a1a}-1.78\%$
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 3.7353ms 2.7619ms 362.0761 Ops/s 373.1918 Ops/s $\color{#d91a1a}-2.98\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 2.9621ms 2.5571ms 391.0650 Ops/s 401.8452 Ops/s $\color{#d91a1a}-2.68\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 4.0797ms 2.7512ms 363.4748 Ops/s 375.7625 Ops/s $\color{#d91a1a}-3.27\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 3.9770ms 2.7366ms 365.4169 Ops/s 374.6625 Ops/s $\color{#d91a1a}-2.47\%$
test_populate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 0.1865s 18.7759ms 53.2598 Ops/s 53.9493 Ops/s $\color{#d91a1a}-1.28\%$
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 0.1194s 17.3967ms 57.4822 Ops/s 59.0125 Ops/s $\color{#d91a1a}-2.59\%$
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 0.1195s 17.3491ms 57.6399 Ops/s 58.8827 Ops/s $\color{#d91a1a}-2.11\%$
test_populate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 0.1207s 17.2779ms 57.8774 Ops/s 58.6140 Ops/s $\color{#d91a1a}-1.26\%$
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 0.1200s 15.1831ms 65.8628 Ops/s 67.3820 Ops/s $\color{#d91a1a}-2.25\%$
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] 0.1200s 17.3900ms 57.5042 Ops/s 58.7152 Ops/s $\color{#d91a1a}-2.06\%$
test_populate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 0.1198s 17.3946ms 57.4891 Ops/s 58.6636 Ops/s $\color{#d91a1a}-2.00\%$
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] 0.1210s 17.3742ms 57.5568 Ops/s 58.5735 Ops/s $\color{#d91a1a}-1.74\%$
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 0.1229s 15.2061ms 65.7629 Ops/s 58.4777 Ops/s $\textbf{\color{#35bf28}+12.46\%}$

@vmoens vmoens merged commit 0e02132 into main Dec 17, 2023
40 of 49 checks passed
@vmoens vmoens deleted the fix-logprob branch December 17, 2023 20:55
@matteobettini
Copy link
Contributor

This fix actually changed the behavior and is breaking the CI of BenchMARL for discrete SAC with MaskedCategorical.

What happens is that if dist.probs returns [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]

dist.logits will give [ 0.0000, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]

while torch.log(torch.where(dist.probs == 0, 1e-8, dist.probs)) gives [[ 0.0000, -18.4207, -18.4207, -18.4207, -18.4207, -18.4207, -18.4207, -18.4207, -18.4207, -18.4207, -18.4207]

Can we return to the previous behavior?

@vmoens
Copy link
Contributor Author

vmoens commented Dec 21, 2023

Not sure returning to the previous behaviour is appropriate. Why is this a problem?
-18 log prob is roughly equivalent to -inf and is more numerically stable IMO

Edit: I see , you don't want the -inf but the -18. We can clamp logits if that helps

@matteobettini
Copy link
Contributor

matteobettini commented Dec 21, 2023

The -inf causes a bug in MaskedCategorical https://github.com/facebookresearch/BenchMARL/actions/runs/7278697272/job/19833445677

If you prefer to adapt how the logits are computed in MaskedCategorical, that should be a viable solution too

@matteobettini
Copy link
Contributor

I wonder what was the problem of the previous behavior tho.
We were just clamping the probs and computing the log, without calling dist.logits

@vmoens
Copy link
Contributor Author

vmoens commented Dec 21, 2023

Efficiency: we were recomputing something already computed by the dist
Stability: the log-probs should be normalised. The dist does that, we didn't.

And the code previously was seriously ugly

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants