-
Notifications
You must be signed in to change notification settings - Fork 327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BugFix] Fix discrete SAC log-prob #1750
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/1750
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (9 Unrelated Failures)As of commit b1d5efd with merge base 08f0bed (): FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
BROKEN TRUNK - The following job failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
I don't get the fix. We need both prob and log_prob We also need to rerun the example training run for discrete SAC and confirm the returns curve match (if we change the logic) |
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_single | 64.2635ms | 62.4888ms | 16.0029 Ops/s | 14.9632 Ops/s | |
test_sync | 49.0949ms | 39.6367ms | 25.2292 Ops/s | 28.4795 Ops/s | |
test_async | 73.7618ms | 33.0609ms | 30.2472 Ops/s | 29.8029 Ops/s | |
test_simple | 0.4801s | 0.4288s | 2.3322 Ops/s | 2.2666 Ops/s | |
test_transformed | 0.6397s | 0.5915s | 1.6905 Ops/s | 1.6396 Ops/s | |
test_serial | 1.3595s | 1.3152s | 0.7604 Ops/s | 0.7314 Ops/s | |
test_parallel | 1.3687s | 1.3176s | 0.7589 Ops/s | 0.7578 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.1418ms | 21.6600μs | 46.1681 KOps/s | 44.4770 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 37.7300μs | 13.1427μs | 76.0876 KOps/s | 73.5524 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 64.8610μs | 12.7680μs | 78.3210 KOps/s | 76.9708 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 30.3060μs | 7.7905μs | 128.3616 KOps/s | 125.9262 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 49.2110μs | 23.3128μs | 42.8949 KOps/s | 42.2837 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 42.8290μs | 14.5190μs | 68.8751 KOps/s | 67.7866 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 40.6250μs | 14.0778μs | 71.0338 KOps/s | 69.4431 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 50.8240μs | 8.9608μs | 111.5978 KOps/s | 108.1746 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 54.1700μs | 24.3511μs | 41.0659 KOps/s | 39.6187 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 46.2250μs | 15.7628μs | 63.4405 KOps/s | 60.3272 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 46.7270μs | 13.9694μs | 71.5852 KOps/s | 69.4363 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 30.4770μs | 8.9031μs | 112.3198 KOps/s | 108.4366 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 65.7620μs | 25.4836μs | 39.2409 KOps/s | 37.8585 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 48.8810μs | 17.3200μs | 57.7368 KOps/s | 57.9699 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 48.3000μs | 15.3369μs | 65.2022 KOps/s | 65.4564 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 28.0920μs | 10.1838μs | 98.1951 KOps/s | 96.4682 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 58.0980μs | 24.4455μs | 40.9074 KOps/s | 39.8057 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 41.7970μs | 15.7826μs | 63.3611 KOps/s | 61.4548 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 41.2460μs | 16.1979μs | 61.7364 KOps/s | 60.0874 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 39.7640μs | 10.1415μs | 98.6052 KOps/s | 95.1077 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 60.8020μs | 25.5739μs | 39.1023 KOps/s | 38.4368 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 83.6040μs | 16.8794μs | 59.2437 KOps/s | 56.7805 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 45.7450μs | 17.2706μs | 57.9019 KOps/s | 56.3098 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 46.8270μs | 11.4346μs | 87.4539 KOps/s | 85.6275 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 62.9460μs | 26.6371μs | 37.5417 KOps/s | 36.3695 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 51.9160μs | 18.1658μs | 55.0484 KOps/s | 53.2586 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 46.3550μs | 17.2181μs | 58.0783 KOps/s | 56.2409 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 42.2980μs | 11.4152μs | 87.6025 KOps/s | 86.0459 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 61.6440μs | 27.6927μs | 36.1106 KOps/s | 34.7724 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 46.4460μs | 19.1266μs | 52.2832 KOps/s | 49.9040 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 41.5060μs | 18.2623μs | 54.7576 KOps/s | 53.5322 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 38.7420μs | 12.3290μs | 81.1094 KOps/s | 78.4739 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 13.4371ms | 11.7565ms | 85.0593 Ops/s | 83.8501 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 35.8042ms | 27.6514ms | 36.1645 Ops/s | 35.5819 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.2446ms | 0.1787ms | 5.5952 KOps/s | 5.6973 KOps/s | |
test_values[td1_return_estimate-False-False] | 28.0039ms | 25.1326ms | 39.7889 Ops/s | 40.3305 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 35.4438ms | 27.6367ms | 36.1837 Ops/s | 35.9245 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 37.8463ms | 35.2174ms | 28.3951 Ops/s | 28.8676 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 35.3175ms | 27.8499ms | 35.9067 Ops/s | 35.8794 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 10.7644ms | 7.9688ms | 125.4892 Ops/s | 125.7323 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 2.1310ms | 1.8514ms | 540.1202 Ops/s | 529.3634 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 8.4367ms | 0.4285ms | 2.3339 KOps/s | 2.3656 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 42.6618ms | 40.2227ms | 24.8616 Ops/s | 24.5251 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 10.6486ms | 2.6293ms | 380.3298 Ops/s | 380.5214 Ops/s | |
test_dqn_speed | 10.1256ms | 1.6173ms | 618.3072 Ops/s | 613.4327 Ops/s | |
test_ddpg_speed | 11.7582ms | 3.6440ms | 274.4261 Ops/s | 270.8108 Ops/s | |
test_sac_speed | 18.3266ms | 10.1054ms | 98.9572 Ops/s | 96.2836 Ops/s | |
test_redq_speed | 26.2030ms | 18.9654ms | 52.7275 Ops/s | 52.5584 Ops/s | |
test_redq_deprec_speed | 55.0603ms | 16.4628ms | 60.7430 Ops/s | 65.6829 Ops/s | |
test_td3_speed | 17.6285ms | 10.3270ms | 96.8334 Ops/s | 94.2841 Ops/s | |
test_cql_speed | 49.7929ms | 41.2475ms | 24.2439 Ops/s | 24.1736 Ops/s | |
test_a2c_speed | 16.3872ms | 8.6653ms | 115.4025 Ops/s | 106.8153 Ops/s | |
test_ppo_speed | 17.6925ms | 8.9485ms | 111.7511 Ops/s | 111.6168 Ops/s | |
test_reinforce_speed | 15.9228ms | 7.7474ms | 129.0753 Ops/s | 128.1444 Ops/s | |
test_iql_speed | 44.0745ms | 35.0420ms | 28.5371 Ops/s | 28.4136 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 2.3605ms | 1.8556ms | 538.9080 Ops/s | 542.2653 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 99.5453ms | 2.1639ms | 462.1194 Ops/s | 507.5180 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 2.4933ms | 1.9427ms | 514.7603 Ops/s | 510.0217 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 2.1568ms | 1.8751ms | 533.2977 Ops/s | 544.3832 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 96.7864ms | 2.1823ms | 458.2387 Ops/s | 515.8620 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 2.7504ms | 1.9952ms | 501.2091 Ops/s | 503.6977 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 2.2222ms | 1.8723ms | 534.0956 Ops/s | 536.4281 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 0.1085s | 2.1749ms | 459.7978 Ops/s | 508.9358 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 4.0761ms | 1.9883ms | 502.9426 Ops/s | 507.3924 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 2.4102ms | 1.8516ms | 540.0875 Ops/s | 537.8427 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 99.3143ms | 2.1250ms | 470.5928 Ops/s | 519.1069 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 3.0809ms | 2.0016ms | 499.6071 Ops/s | 506.5899 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.2243ms | 1.8804ms | 531.8042 Ops/s | 537.9029 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.1198s | 2.2331ms | 447.8047 Ops/s | 504.0018 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 2.9871ms | 1.9741ms | 506.5717 Ops/s | 502.0516 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 2.3913ms | 1.8626ms | 536.8811 Ops/s | 536.5029 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 99.8292ms | 2.1697ms | 460.9009 Ops/s | 504.8599 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 3.2680ms | 1.9684ms | 508.0394 Ops/s | 504.0233 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.1396s | 16.3396ms | 61.2010 Ops/s | 58.3545 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 95.7047ms | 15.4853ms | 64.5775 Ops/s | 64.5180 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 99.3906ms | 15.5963ms | 64.1179 Ops/s | 64.3261 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 95.9221ms | 15.4698ms | 64.6420 Ops/s | 64.9871 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 96.7391ms | 15.4273ms | 64.8200 Ops/s | 64.2664 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 96.9632ms | 15.5165ms | 64.4474 Ops/s | 63.1760 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 99.2836ms | 15.6622ms | 63.8482 Ops/s | 63.3465 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 0.1019s | 17.5400ms | 57.0125 Ops/s | 62.5557 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 0.1025s | 15.6841ms | 63.7590 Ops/s | 63.6002 Ops/s |
Co-authored-by: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
If we want to be extra careful we should run the discrete_sac example to check
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_single | 0.1284s | 0.1257s | 7.9544 Ops/s | 8.1627 Ops/s | |
test_sync | 0.1020s | 0.1017s | 9.8316 Ops/s | 9.7076 Ops/s | |
test_async | 0.2736s | 99.1462ms | 10.0861 Ops/s | 9.9670 Ops/s | |
test_single_pixels | 0.1475s | 0.1473s | 6.7885 Ops/s | 7.5393 Ops/s | |
test_sync_pixels | 96.7164ms | 95.1649ms | 10.5081 Ops/s | 10.4573 Ops/s | |
test_async_pixels | 0.2472s | 91.6094ms | 10.9159 Ops/s | 10.9835 Ops/s | |
test_simple | 1.0020s | 0.9232s | 1.0831 Ops/s | 1.1248 Ops/s | |
test_transformed | 1.2231s | 1.1597s | 0.8623 Ops/s | 0.8798 Ops/s | |
test_serial | 2.7291s | 2.6498s | 0.3774 Ops/s | 0.3974 Ops/s | |
test_parallel | 2.6072s | 2.5094s | 0.3985 Ops/s | 0.4004 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 80.9510μs | 32.7507μs | 30.5337 KOps/s | 30.6682 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 44.0420μs | 19.4647μs | 51.3750 KOps/s | 51.1065 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 41.2610μs | 18.7897μs | 53.2207 KOps/s | 52.7940 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 41.5510μs | 11.2542μs | 88.8555 KOps/s | 91.2921 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 55.4100μs | 34.4376μs | 29.0380 KOps/s | 29.1441 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 44.3700μs | 21.1683μs | 47.2405 KOps/s | 47.2263 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 43.9400μs | 20.7194μs | 48.2639 KOps/s | 47.7585 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 51.2720μs | 13.3628μs | 74.8344 KOps/s | 77.9621 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 63.0100μs | 36.3855μs | 27.4835 KOps/s | 27.8756 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 45.9310μs | 23.1355μs | 43.2236 KOps/s | 43.2344 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 50.7710μs | 20.8876μs | 47.8753 KOps/s | 49.0465 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 37.5200μs | 13.1746μs | 75.9034 KOps/s | 77.8939 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 59.3500μs | 38.3940μs | 26.0458 KOps/s | 26.3723 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 56.8110μs | 25.2996μs | 39.5263 KOps/s | 40.3223 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 45.5610μs | 22.7204μs | 44.0133 KOps/s | 44.4055 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 34.5420μs | 14.5444μs | 68.7551 KOps/s | 68.1274 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 60.6700μs | 35.9460μs | 27.8195 KOps/s | 27.7170 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 44.6510μs | 23.3505μs | 42.8256 KOps/s | 43.7148 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 45.5400μs | 24.6725μs | 40.5310 KOps/s | 41.4481 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 33.4410μs | 14.6433μs | 68.2907 KOps/s | 67.5841 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 89.2110μs | 38.5272μs | 25.9557 KOps/s | 26.2951 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 43.7900μs | 25.2314μs | 39.6331 KOps/s | 40.2571 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 51.4610μs | 25.8508μs | 38.6835 KOps/s | 38.1406 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 35.3910μs | 16.7829μs | 59.5844 KOps/s | 61.1219 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 60.7800μs | 40.1102μs | 24.9313 KOps/s | 25.4296 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 49.6800μs | 27.1431μs | 36.8417 KOps/s | 37.4096 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 0.1021ms | 26.2919μs | 38.0346 KOps/s | 38.4245 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 37.6800μs | 16.8699μs | 59.2771 KOps/s | 60.9477 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 68.1610μs | 40.9772μs | 24.4038 KOps/s | 24.2955 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 60.9620μs | 28.4801μs | 35.1122 KOps/s | 35.1922 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 58.2820μs | 27.7324μs | 36.0590 KOps/s | 36.4793 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 39.1720μs | 18.5209μs | 53.9930 KOps/s | 55.6885 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 28.2647ms | 27.1010ms | 36.8990 Ops/s | 38.1818 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 96.4090ms | 3.5209ms | 284.0195 Ops/s | 297.3413 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.1023ms | 67.9680μs | 14.7128 KOps/s | 15.0931 KOps/s | |
test_values[td1_return_estimate-False-False] | 60.8722ms | 58.5337ms | 17.0842 Ops/s | 16.5268 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 2.0874ms | 1.7992ms | 555.8107 Ops/s | 556.8632 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 96.1933ms | 92.5187ms | 10.8086 Ops/s | 10.4449 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 2.1080ms | 1.8196ms | 549.5798 Ops/s | 551.9136 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 25.6983ms | 25.4829ms | 39.2421 Ops/s | 36.6769 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 0.9186ms | 0.7375ms | 1.3560 KOps/s | 1.2925 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.7467ms | 0.6909ms | 1.4473 KOps/s | 1.3909 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 1.5652ms | 1.4889ms | 671.6171 Ops/s | 674.4053 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 1.0680ms | 0.7215ms | 1.3859 KOps/s | 1.4215 KOps/s | |
test_dqn_speed | 7.8889ms | 1.4967ms | 668.1287 Ops/s | 684.1421 Ops/s | |
test_ddpg_speed | 4.3545ms | 3.3713ms | 296.6248 Ops/s | 281.4170 Ops/s | |
test_sac_speed | 95.3695ms | 10.3386ms | 96.7246 Ops/s | 108.4807 Ops/s | |
test_redq_speed | 17.6201ms | 16.8256ms | 59.4334 Ops/s | 60.9994 Ops/s | |
test_redq_deprec_speed | 13.8919ms | 13.1404ms | 76.1015 Ops/s | 77.1527 Ops/s | |
test_td3_speed | 19.3772ms | 9.7962ms | 102.0808 Ops/s | 105.2477 Ops/s | |
test_cql_speed | 34.9763ms | 33.9319ms | 29.4708 Ops/s | 29.8122 Ops/s | |
test_a2c_speed | 8.0086ms | 7.0726ms | 141.3908 Ops/s | 140.1266 Ops/s | |
test_ppo_speed | 8.3746ms | 7.4332ms | 134.5308 Ops/s | 134.0231 Ops/s | |
test_reinforce_speed | 7.3177ms | 6.1169ms | 163.4813 Ops/s | 162.4654 Ops/s | |
test_iql_speed | 28.5754ms | 27.1849ms | 36.7851 Ops/s | 37.0608 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 3.1689ms | 2.5806ms | 387.5018 Ops/s | 399.1688 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 4.0238ms | 2.7402ms | 364.9383 Ops/s | 333.9103 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 4.4665ms | 2.7529ms | 363.2590 Ops/s | 376.8130 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.1598ms | 2.5686ms | 389.3098 Ops/s | 403.0050 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 3.6987ms | 2.7354ms | 365.5723 Ops/s | 375.9474 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 3.8128ms | 2.7430ms | 364.5627 Ops/s | 373.7324 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.2285ms | 2.5634ms | 390.1008 Ops/s | 399.5665 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 3.9081ms | 2.7414ms | 364.7833 Ops/s | 375.9346 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 3.8855ms | 2.7408ms | 364.8583 Ops/s | 373.5261 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 3.2239ms | 2.5446ms | 392.9889 Ops/s | 398.8909 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 3.8034ms | 2.7363ms | 365.4569 Ops/s | 375.1835 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 4.0658ms | 2.7605ms | 362.2478 Ops/s | 374.0149 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.0107ms | 2.5696ms | 389.1695 Ops/s | 400.7725 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 3.7686ms | 2.7331ms | 365.8853 Ops/s | 372.5148 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 3.7353ms | 2.7619ms | 362.0761 Ops/s | 373.1918 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 2.9621ms | 2.5571ms | 391.0650 Ops/s | 401.8452 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 4.0797ms | 2.7512ms | 363.4748 Ops/s | 375.7625 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 3.9770ms | 2.7366ms | 365.4169 Ops/s | 374.6625 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.1865s | 18.7759ms | 53.2598 Ops/s | 53.9493 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 0.1194s | 17.3967ms | 57.4822 Ops/s | 59.0125 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 0.1195s | 17.3491ms | 57.6399 Ops/s | 58.8827 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.1207s | 17.2779ms | 57.8774 Ops/s | 58.6140 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 0.1200s | 15.1831ms | 65.8628 Ops/s | 67.3820 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 0.1200s | 17.3900ms | 57.5042 Ops/s | 58.7152 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.1198s | 17.3946ms | 57.4891 Ops/s | 58.6636 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 0.1210s | 17.3742ms | 57.5568 Ops/s | 58.5735 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 0.1229s | 15.2061ms | 65.7629 Ops/s | 58.4777 Ops/s |
This fix actually changed the behavior and is breaking the CI of BenchMARL for discrete SAC with MaskedCategorical. What happens is that if
while Can we return to the previous behavior? |
Not sure returning to the previous behaviour is appropriate. Why is this a problem? Edit: I see , you don't want the -inf but the -18. We can clamp logits if that helps |
The -inf causes a bug in MaskedCategorical https://github.com/facebookresearch/BenchMARL/actions/runs/7278697272/job/19833445677 If you prefer to adapt how the logits are computed in |
I wonder what was the problem of the previous behavior tho. |
Efficiency: we were recomputing something already computed by the dist And the code previously was seriously ugly |
cc @matteobettini