-
Notifications
You must be signed in to change notification settings - Fork 327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Masking actions #1421
Conversation
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_single | 0.1548s | 0.1520s | 6.5772 Ops/s | 6.5386 Ops/s | |
test_sync | 0.1575s | 85.9564ms | 11.6338 Ops/s | 10.9943 Ops/s | |
test_async | 0.2196s | 78.9957ms | 12.6589 Ops/s | 11.8781 Ops/s | |
test_simple | 0.7483s | 0.6761s | 1.4790 Ops/s | 1.4680 Ops/s | |
test_transformed | 1.8586s | 1.8076s | 0.5532 Ops/s | 0.5518 Ops/s | |
test_serial | 1.9170s | 1.8669s | 0.5357 Ops/s | 0.5312 Ops/s | |
test_parallel | 1.6259s | 1.5795s | 0.6331 Ops/s | 0.6273 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.1405ms | 50.2774μs | 19.8897 KOps/s | 19.8566 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 0.1062ms | 28.4863μs | 35.1046 KOps/s | 35.0612 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 83.5030μs | 35.0919μs | 28.4966 KOps/s | 28.1979 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 70.8030μs | 19.3828μs | 51.5921 KOps/s | 50.8936 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 88.4040μs | 51.9179μs | 19.2612 KOps/s | 18.9809 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 59.4030μs | 30.8485μs | 32.4165 KOps/s | 32.7999 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 0.1078ms | 37.2733μs | 26.8289 KOps/s | 26.6865 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 73.8030μs | 21.5821μs | 46.3348 KOps/s | 45.9610 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 0.1880ms | 54.8654μs | 18.2264 KOps/s | 18.1899 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 84.8030μs | 32.7533μs | 30.5313 KOps/s | 30.4183 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 75.5020μs | 37.6253μs | 26.5779 KOps/s | 26.4269 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 52.0020μs | 21.8457μs | 45.7757 KOps/s | 46.4769 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 0.1175ms | 56.0874μs | 17.8293 KOps/s | 17.9417 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 77.6020μs | 34.5877μs | 28.9120 KOps/s | 30.0169 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 68.9020μs | 39.6769μs | 25.2036 KOps/s | 26.5459 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 82.7030μs | 23.7256μs | 42.1485 KOps/s | 43.8593 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 93.8030μs | 54.8913μs | 18.2178 KOps/s | 18.5242 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 63.6020μs | 33.0259μs | 30.2792 KOps/s | 30.6757 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 79.6030μs | 42.4741μs | 23.5437 KOps/s | 23.5128 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 3.3140ms | 24.7898μs | 40.3391 KOps/s | 40.8173 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 92.8030μs | 56.1048μs | 17.8238 KOps/s | 17.6596 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 0.1091ms | 34.5739μs | 28.9235 KOps/s | 28.9078 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 77.4030μs | 43.4178μs | 23.0320 KOps/s | 22.4527 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 91.4020μs | 26.5849μs | 37.6154 KOps/s | 38.8586 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 0.1317ms | 57.8524μs | 17.2854 KOps/s | 16.8159 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 91.0030μs | 36.9147μs | 27.0895 KOps/s | 27.2419 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 76.4020μs | 43.4920μs | 22.9927 KOps/s | 22.5897 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 51.8020μs | 25.7663μs | 38.8104 KOps/s | 38.7862 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 99.8030μs | 60.4897μs | 16.5317 KOps/s | 16.6756 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 70.8020μs | 38.1073μs | 26.2417 KOps/s | 26.0874 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 84.7030μs | 45.0547μs | 22.1952 KOps/s | 22.4442 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 87.0030μs | 28.1296μs | 35.5498 KOps/s | 36.5420 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 16.3963ms | 14.5353ms | 68.7978 Ops/s | 67.0843 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 55.7828ms | 45.9331ms | 21.7708 Ops/s | 21.4599 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.9517ms | 0.2304ms | 4.3406 KOps/s | 4.8456 KOps/s | |
test_values[td1_return_estimate-False-False] | 14.6902ms | 14.1322ms | 70.7602 Ops/s | 70.9977 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 48.7537ms | 45.2858ms | 22.0820 Ops/s | 21.5807 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 35.4501ms | 34.2886ms | 29.1642 Ops/s | 28.8388 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 52.1185ms | 46.0214ms | 21.7290 Ops/s | 21.5346 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 13.4284ms | 12.9015ms | 77.5102 Ops/s | 75.9422 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 9.9856ms | 3.8861ms | 257.3302 Ops/s | 252.2812 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.6882ms | 0.5415ms | 1.8466 KOps/s | 1.8519 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 70.4547ms | 63.2401ms | 15.8128 Ops/s | 16.0870 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 8.7980ms | 3.1594ms | 316.5145 Ops/s | 306.9439 Ops/s | |
test_dqn_speed | 7.8854ms | 2.0935ms | 477.6737 Ops/s | 475.4616 Ops/s | |
test_ddpg_speed | 9.0624ms | 3.1706ms | 315.3941 Ops/s | 321.1452 Ops/s | |
test_sac_speed | 24.3538ms | 9.2791ms | 107.7695 Ops/s | 107.9764 Ops/s | |
test_redq_speed | 22.9642ms | 17.5445ms | 56.9979 Ops/s | 58.2493 Ops/s | |
test_redq_deprec_speed | 18.9927ms | 14.3588ms | 69.6437 Ops/s | 71.4784 Ops/s | |
test_td3_speed | 12.1458ms | 11.0296ms | 90.6652 Ops/s | 89.7260 Ops/s | |
test_cql_speed | 41.2503ms | 32.8412ms | 30.4496 Ops/s | 27.2872 Ops/s | |
test_a2c_speed | 12.2889ms | 5.7946ms | 172.5757 Ops/s | 170.7927 Ops/s | |
test_ppo_speed | 11.8281ms | 6.2457ms | 160.1111 Ops/s | 160.9944 Ops/s | |
test_reinforce_speed | 10.7062ms | 4.3997ms | 227.2881 Ops/s | 220.7265 Ops/s | |
test_iql_speed | 24.4331ms | 23.1197ms | 43.2531 Ops/s | 43.1471 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 3.5343ms | 2.8460ms | 351.3742 Ops/s | 361.7488 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.1613s | 3.5102ms | 284.8855 Ops/s | 336.7689 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 4.0813ms | 3.0011ms | 333.2087 Ops/s | 330.5272 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 0.2377s | 3.5574ms | 281.1078 Ops/s | 349.6158 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 5.1819ms | 3.1169ms | 320.8317 Ops/s | 329.9234 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 4.4518ms | 2.9493ms | 339.0684 Ops/s | 324.9193 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.7818ms | 2.7610ms | 362.1905 Ops/s | 343.3214 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 4.7622ms | 2.9850ms | 335.0031 Ops/s | 329.9058 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 5.4334ms | 2.9503ms | 338.9490 Ops/s | 328.1113 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 3.5601ms | 2.7568ms | 362.7340 Ops/s | 353.8105 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 5.5718ms | 3.0714ms | 325.5881 Ops/s | 325.9764 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 4.2233ms | 2.9309ms | 341.1891 Ops/s | 323.2375 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.7276ms | 2.8266ms | 353.7776 Ops/s | 342.3646 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 5.9782ms | 2.9981ms | 333.5499 Ops/s | 324.9017 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 4.9125ms | 3.0757ms | 325.1250 Ops/s | 325.8986 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.9711ms | 2.8636ms | 349.2096 Ops/s | 345.0337 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 5.3334ms | 3.0418ms | 328.7477 Ops/s | 270.3369 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 4.9188ms | 2.9669ms | 337.0520 Ops/s | 315.1893 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.2737s | 31.1482ms | 32.1046 Ops/s | 31.4804 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 0.1434s | 30.6415ms | 32.6355 Ops/s | 31.8552 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 0.1485s | 28.4929ms | 35.0965 Ops/s | 38.2444 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.1426s | 28.1734ms | 35.4945 Ops/s | 34.5247 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 0.1428s | 27.8398ms | 35.9198 Ops/s | 34.8355 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 0.1418s | 30.5194ms | 32.7660 Ops/s | 31.9086 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.1390s | 27.9566ms | 35.7698 Ops/s | 35.3693 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 0.1427s | 30.5358ms | 32.7484 Ops/s | 31.5584 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 0.1363s | 27.7174ms | 36.0784 Ops/s | 34.6794 Ops/s |
This is awesome, I think it might also be helpful for #1201 since Unity supports masked actions as well. |
A few things to be aware of:
|
Shouldn't it be |
I m talking about a different scenario. What i am doing right now is that i am setting a mask of all false for all actions of that agent, but not passing that to the spec (as it would crash), just having it in the env output. So the spec will still see all actions available for that agent, but in reality that whole transition for it should be masked out by the user later. is there a better way to treat this? |
# Conflicts: # test/test_specs.py
Then what should class SkipAction(Transform):
"""Adds an action to the env that skips the step if selected.
""" For instance, if |
ok yeah I see. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This does not work for multi discrete and multi one hot right?
I think it would be useful to support that
Also we need to add this line to check_env_specs, because otherwise the check will fail when the action_mask gets updated
|
# Conflicts: # test/test_transforms.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
are we leaving the multi spec for another PR?
you mean multi-one-hot (ie in a single tensor) or multiple masks? |
Multi one hot and multi discrete |
@matteobettini now these are supported too |
# Conflicts: # test/test_libs.py
# Conflicts: # torchrl/data/tensor_specs.py
# Conflicts: # torchrl/data/tensor_specs.py
Description
Addresses #1404
A quick example of the functionality
https://gist.github.com/vmoens/95d6427fcb5fa5714291b3dbfa7daa15#file-action_mask-ipynb
TODO: