[Feature] Masking actions #1421

Merged
merged 32 commits into from
Sep 3, 2023

vmoens
Contributor

@vmoens vmoens commented Jul 27, 2023

Description

Addresses #1404

A quick example of the functionality
https://gist.github.com/vmoens/95d6427fcb5fa5714291b3dbfa7daa15#file-action_mask-ipynb

TODO:

  • DQN compatibility
  • ActionMask tests
    • focus on reset
    • focus on exceptions
    • focus on compose (do we change a clone or the actual spec?); what about nested compose?
  • test exceptions of masking specs
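
For the DQN item, one plausible reading is that greedy action selection has to ignore masked-out actions. Below is a minimal plain-Python sketch of that idea, not the TorchRL implementation; the helper name `masked_argmax` is hypothetical.

```python
import math


def masked_argmax(q_values, mask):
    """Return the index of the largest Q-value among actions whose mask is True."""
    if len(q_values) != len(mask):
        raise ValueError("q_values and mask must have the same length")
    if not any(mask):
        raise ValueError("mask excludes every action")
    # Replace masked-out values with -inf so they can never win the argmax.
    masked = [q if m else -math.inf for q, m in zip(q_values, mask)]
    return max(range(len(masked)), key=masked.__getitem__)


q = [0.1, 2.5, 0.7, 1.9]
mask = [True, False, True, True]
best = masked_argmax(q, mask)  # action 1 has the highest value but is masked
```

Here the unmasked argmax would pick action 1; with the mask applied, the best remaining action is selected instead.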

@facebook-github-bot added the CLA Signed label Jul 27, 2023
@vmoens added the enhancement label Jul 27, 2023
@vmoens requested a review from xiaomengy July 27, 2023 15:03
@github-actions

github-actions bot commented Jul 27, 2023

$\color{#D29922}\textsf{\Large⚠\kern{0.2cm}\normalsize Warning}$ Result of CPU Benchmark Tests

Total Benchmarks: 89. Improved: $\large\color{#35bf28}7$. Worsened: $\large\color{#d91a1a}5$.

Expand to view detailed results
| Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
| --- | --- | --- | --- | --- | --- |
| test_single | 0.1548s | 0.1520s | 6.5772 Ops/s | 6.5386 Ops/s | $\color{#35bf28}+0.59\%$ |
| test_sync | 0.1575s | 85.9564ms | 11.6338 Ops/s | 10.9943 Ops/s | $\textbf{\color{#35bf28}+5.82\%}$ |
| test_async | 0.2196s | 78.9957ms | 12.6589 Ops/s | 11.8781 Ops/s | $\textbf{\color{#35bf28}+6.57\%}$ |
| test_simple | 0.7483s | 0.6761s | 1.4790 Ops/s | 1.4680 Ops/s | $\color{#35bf28}+0.75\%$ |
| test_transformed | 1.8586s | 1.8076s | 0.5532 Ops/s | 0.5518 Ops/s | $\color{#35bf28}+0.27\%$ |
| test_serial | 1.9170s | 1.8669s | 0.5357 Ops/s | 0.5312 Ops/s | $\color{#35bf28}+0.84\%$ |
| test_parallel | 1.6259s | 1.5795s | 0.6331 Ops/s | 0.6273 Ops/s | $\color{#35bf28}+0.93\%$ |
| test_step_mdp_speed[True-True-True-True-True] | 0.1405ms | 50.2774μs | 19.8897 KOps/s | 19.8566 KOps/s | $\color{#35bf28}+0.17\%$ |
| test_step_mdp_speed[True-True-True-True-False] | 0.1062ms | 28.4863μs | 35.1046 KOps/s | 35.0612 KOps/s | $\color{#35bf28}+0.12\%$ |
| test_step_mdp_speed[True-True-True-False-True] | 83.5030μs | 35.0919μs | 28.4966 KOps/s | 28.1979 KOps/s | $\color{#35bf28}+1.06\%$ |
| test_step_mdp_speed[True-True-True-False-False] | 70.8030μs | 19.3828μs | 51.5921 KOps/s | 50.8936 KOps/s | $\color{#35bf28}+1.37\%$ |
| test_step_mdp_speed[True-True-False-True-True] | 88.4040μs | 51.9179μs | 19.2612 KOps/s | 18.9809 KOps/s | $\color{#35bf28}+1.48\%$ |
| test_step_mdp_speed[True-True-False-True-False] | 59.4030μs | 30.8485μs | 32.4165 KOps/s | 32.7999 KOps/s | $\color{#d91a1a}-1.17\%$ |
| test_step_mdp_speed[True-True-False-False-True] | 0.1078ms | 37.2733μs | 26.8289 KOps/s | 26.6865 KOps/s | $\color{#35bf28}+0.53\%$ |
| test_step_mdp_speed[True-True-False-False-False] | 73.8030μs | 21.5821μs | 46.3348 KOps/s | 45.9610 KOps/s | $\color{#35bf28}+0.81\%$ |
| test_step_mdp_speed[True-False-True-True-True] | 0.1880ms | 54.8654μs | 18.2264 KOps/s | 18.1899 KOps/s | $\color{#35bf28}+0.20\%$ |
| test_step_mdp_speed[True-False-True-True-False] | 84.8030μs | 32.7533μs | 30.5313 KOps/s | 30.4183 KOps/s | $\color{#35bf28}+0.37\%$ |
| test_step_mdp_speed[True-False-True-False-True] | 75.5020μs | 37.6253μs | 26.5779 KOps/s | 26.4269 KOps/s | $\color{#35bf28}+0.57\%$ |
| test_step_mdp_speed[True-False-True-False-False] | 52.0020μs | 21.8457μs | 45.7757 KOps/s | 46.4769 KOps/s | $\color{#d91a1a}-1.51\%$ |
| test_step_mdp_speed[True-False-False-True-True] | 0.1175ms | 56.0874μs | 17.8293 KOps/s | 17.9417 KOps/s | $\color{#d91a1a}-0.63\%$ |
| test_step_mdp_speed[True-False-False-True-False] | 77.6020μs | 34.5877μs | 28.9120 KOps/s | 30.0169 KOps/s | $\color{#d91a1a}-3.68\%$ |
| test_step_mdp_speed[True-False-False-False-True] | 68.9020μs | 39.6769μs | 25.2036 KOps/s | 26.5459 KOps/s | $\textbf{\color{#d91a1a}-5.06\%}$ |
| test_step_mdp_speed[True-False-False-False-False] | 82.7030μs | 23.7256μs | 42.1485 KOps/s | 43.8593 KOps/s | $\color{#d91a1a}-3.90\%$ |
| test_step_mdp_speed[False-True-True-True-True] | 93.8030μs | 54.8913μs | 18.2178 KOps/s | 18.5242 KOps/s | $\color{#d91a1a}-1.65\%$ |
| test_step_mdp_speed[False-True-True-True-False] | 63.6020μs | 33.0259μs | 30.2792 KOps/s | 30.6757 KOps/s | $\color{#d91a1a}-1.29\%$ |
| test_step_mdp_speed[False-True-True-False-True] | 79.6030μs | 42.4741μs | 23.5437 KOps/s | 23.5128 KOps/s | $\color{#35bf28}+0.13\%$ |
| test_step_mdp_speed[False-True-True-False-False] | 3.3140ms | 24.7898μs | 40.3391 KOps/s | 40.8173 KOps/s | $\color{#d91a1a}-1.17\%$ |
| test_step_mdp_speed[False-True-False-True-True] | 92.8030μs | 56.1048μs | 17.8238 KOps/s | 17.6596 KOps/s | $\color{#35bf28}+0.93\%$ |
| test_step_mdp_speed[False-True-False-True-False] | 0.1091ms | 34.5739μs | 28.9235 KOps/s | 28.9078 KOps/s | $\color{#35bf28}+0.05\%$ |
| test_step_mdp_speed[False-True-False-False-True] | 77.4030μs | 43.4178μs | 23.0320 KOps/s | 22.4527 KOps/s | $\color{#35bf28}+2.58\%$ |
| test_step_mdp_speed[False-True-False-False-False] | 91.4020μs | 26.5849μs | 37.6154 KOps/s | 38.8586 KOps/s | $\color{#d91a1a}-3.20\%$ |
| test_step_mdp_speed[False-False-True-True-True] | 0.1317ms | 57.8524μs | 17.2854 KOps/s | 16.8159 KOps/s | $\color{#35bf28}+2.79\%$ |
| test_step_mdp_speed[False-False-True-True-False] | 91.0030μs | 36.9147μs | 27.0895 KOps/s | 27.2419 KOps/s | $\color{#d91a1a}-0.56\%$ |
| test_step_mdp_speed[False-False-True-False-True] | 76.4020μs | 43.4920μs | 22.9927 KOps/s | 22.5897 KOps/s | $\color{#35bf28}+1.78\%$ |
| test_step_mdp_speed[False-False-True-False-False] | 51.8020μs | 25.7663μs | 38.8104 KOps/s | 38.7862 KOps/s | $\color{#35bf28}+0.06\%$ |
| test_step_mdp_speed[False-False-False-True-True] | 99.8030μs | 60.4897μs | 16.5317 KOps/s | 16.6756 KOps/s | $\color{#d91a1a}-0.86\%$ |
| test_step_mdp_speed[False-False-False-True-False] | 70.8020μs | 38.1073μs | 26.2417 KOps/s | 26.0874 KOps/s | $\color{#35bf28}+0.59\%$ |
| test_step_mdp_speed[False-False-False-False-True] | 84.7030μs | 45.0547μs | 22.1952 KOps/s | 22.4442 KOps/s | $\color{#d91a1a}-1.11\%$ |
| test_step_mdp_speed[False-False-False-False-False] | 87.0030μs | 28.1296μs | 35.5498 KOps/s | 36.5420 KOps/s | $\color{#d91a1a}-2.72\%$ |
| test_values[generalized_advantage_estimate-True-True] | 16.3963ms | 14.5353ms | 68.7978 Ops/s | 67.0843 Ops/s | $\color{#35bf28}+2.55\%$ |
| test_values[vec_generalized_advantage_estimate-True-True] | 55.7828ms | 45.9331ms | 21.7708 Ops/s | 21.4599 Ops/s | $\color{#35bf28}+1.45\%$ |
| test_values[td0_return_estimate-False-False] | 0.9517ms | 0.2304ms | 4.3406 KOps/s | 4.8456 KOps/s | $\textbf{\color{#d91a1a}-10.42\%}$ |
| test_values[td1_return_estimate-False-False] | 14.6902ms | 14.1322ms | 70.7602 Ops/s | 70.9977 Ops/s | $\color{#d91a1a}-0.33\%$ |
| test_values[vec_td1_return_estimate-False-False] | 48.7537ms | 45.2858ms | 22.0820 Ops/s | 21.5807 Ops/s | $\color{#35bf28}+2.32\%$ |
| test_values[td_lambda_return_estimate-True-False] | 35.4501ms | 34.2886ms | 29.1642 Ops/s | 28.8388 Ops/s | $\color{#35bf28}+1.13\%$ |
| test_values[vec_td_lambda_return_estimate-True-False] | 52.1185ms | 46.0214ms | 21.7290 Ops/s | 21.5346 Ops/s | $\color{#35bf28}+0.90\%$ |
| test_gae_speed[generalized_advantage_estimate-False-1-512] | 13.4284ms | 12.9015ms | 77.5102 Ops/s | 75.9422 Ops/s | $\color{#35bf28}+2.06\%$ |
| test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 9.9856ms | 3.8861ms | 257.3302 Ops/s | 252.2812 Ops/s | $\color{#35bf28}+2.00\%$ |
| test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.6882ms | 0.5415ms | 1.8466 KOps/s | 1.8519 KOps/s | $\color{#d91a1a}-0.29\%$ |
| test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 70.4547ms | 63.2401ms | 15.8128 Ops/s | 16.0870 Ops/s | $\color{#d91a1a}-1.71\%$ |
| test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 8.7980ms | 3.1594ms | 316.5145 Ops/s | 306.9439 Ops/s | $\color{#35bf28}+3.12\%$ |
| test_dqn_speed | 7.8854ms | 2.0935ms | 477.6737 Ops/s | 475.4616 Ops/s | $\color{#35bf28}+0.47\%$ |
| test_ddpg_speed | 9.0624ms | 3.1706ms | 315.3941 Ops/s | 321.1452 Ops/s | $\color{#d91a1a}-1.79\%$ |
| test_sac_speed | 24.3538ms | 9.2791ms | 107.7695 Ops/s | 107.9764 Ops/s | $\color{#d91a1a}-0.19\%$ |
| test_redq_speed | 22.9642ms | 17.5445ms | 56.9979 Ops/s | 58.2493 Ops/s | $\color{#d91a1a}-2.15\%$ |
| test_redq_deprec_speed | 18.9927ms | 14.3588ms | 69.6437 Ops/s | 71.4784 Ops/s | $\color{#d91a1a}-2.57\%$ |
| test_td3_speed | 12.1458ms | 11.0296ms | 90.6652 Ops/s | 89.7260 Ops/s | $\color{#35bf28}+1.05\%$ |
| test_cql_speed | 41.2503ms | 32.8412ms | 30.4496 Ops/s | 27.2872 Ops/s | $\textbf{\color{#35bf28}+11.59\%}$ |
| test_a2c_speed | 12.2889ms | 5.7946ms | 172.5757 Ops/s | 170.7927 Ops/s | $\color{#35bf28}+1.04\%$ |
| test_ppo_speed | 11.8281ms | 6.2457ms | 160.1111 Ops/s | 160.9944 Ops/s | $\color{#d91a1a}-0.55\%$ |
| test_reinforce_speed | 10.7062ms | 4.3997ms | 227.2881 Ops/s | 220.7265 Ops/s | $\color{#35bf28}+2.97\%$ |
| test_iql_speed | 24.4331ms | 23.1197ms | 43.2531 Ops/s | 43.1471 Ops/s | $\color{#35bf28}+0.25\%$ |
| test_sample_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 3.5343ms | 2.8460ms | 351.3742 Ops/s | 361.7488 Ops/s | $\color{#d91a1a}-2.87\%$ |
| test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.1613s | 3.5102ms | 284.8855 Ops/s | 336.7689 Ops/s | $\textbf{\color{#d91a1a}-15.41\%}$ |
| test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 4.0813ms | 3.0011ms | 333.2087 Ops/s | 330.5272 Ops/s | $\color{#35bf28}+0.81\%$ |
| test_sample_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 0.2377s | 3.5574ms | 281.1078 Ops/s | 349.6158 Ops/s | $\textbf{\color{#d91a1a}-19.60\%}$ |
| test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 5.1819ms | 3.1169ms | 320.8317 Ops/s | 329.9234 Ops/s | $\color{#d91a1a}-2.76\%$ |
| test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 4.4518ms | 2.9493ms | 339.0684 Ops/s | 324.9193 Ops/s | $\color{#35bf28}+4.35\%$ |
| test_sample_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.7818ms | 2.7610ms | 362.1905 Ops/s | 343.3214 Ops/s | $\textbf{\color{#35bf28}+5.50\%}$ |
| test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 4.7622ms | 2.9850ms | 335.0031 Ops/s | 329.9058 Ops/s | $\color{#35bf28}+1.55\%$ |
| test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 5.4334ms | 2.9503ms | 338.9490 Ops/s | 328.1113 Ops/s | $\color{#35bf28}+3.30\%$ |
| test_iterate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 3.5601ms | 2.7568ms | 362.7340 Ops/s | 353.8105 Ops/s | $\color{#35bf28}+2.52\%$ |
| test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 5.5718ms | 3.0714ms | 325.5881 Ops/s | 325.9764 Ops/s | $\color{#d91a1a}-0.12\%$ |
| test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 4.2233ms | 2.9309ms | 341.1891 Ops/s | 323.2375 Ops/s | $\textbf{\color{#35bf28}+5.55\%}$ |
| test_iterate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.7276ms | 2.8266ms | 353.7776 Ops/s | 342.3646 Ops/s | $\color{#35bf28}+3.33\%$ |
| test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 5.9782ms | 2.9981ms | 333.5499 Ops/s | 324.9017 Ops/s | $\color{#35bf28}+2.66\%$ |
| test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 4.9125ms | 3.0757ms | 325.1250 Ops/s | 325.8986 Ops/s | $\color{#d91a1a}-0.24\%$ |
| test_iterate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.9711ms | 2.8636ms | 349.2096 Ops/s | 345.0337 Ops/s | $\color{#35bf28}+1.21\%$ |
| test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 5.3334ms | 3.0418ms | 328.7477 Ops/s | 270.3369 Ops/s | $\textbf{\color{#35bf28}+21.61\%}$ |
| test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 4.9188ms | 2.9669ms | 337.0520 Ops/s | 315.1893 Ops/s | $\textbf{\color{#35bf28}+6.94\%}$ |
| test_populate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.2737s | 31.1482ms | 32.1046 Ops/s | 31.4804 Ops/s | $\color{#35bf28}+1.98\%$ |
| test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 0.1434s | 30.6415ms | 32.6355 Ops/s | 31.8552 Ops/s | $\color{#35bf28}+2.45\%$ |
| test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 0.1485s | 28.4929ms | 35.0965 Ops/s | 38.2444 Ops/s | $\textbf{\color{#d91a1a}-8.23\%}$ |
| test_populate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.1426s | 28.1734ms | 35.4945 Ops/s | 34.5247 Ops/s | $\color{#35bf28}+2.81\%$ |
| test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 0.1428s | 27.8398ms | 35.9198 Ops/s | 34.8355 Ops/s | $\color{#35bf28}+3.11\%$ |
| test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 0.1418s | 30.5194ms | 32.7660 Ops/s | 31.9086 Ops/s | $\color{#35bf28}+2.69\%$ |
| test_populate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.1390s | 27.9566ms | 35.7698 Ops/s | 35.3693 Ops/s | $\color{#35bf28}+1.13\%$ |
| test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 0.1427s | 30.5358ms | 32.7484 Ops/s | 31.5584 Ops/s | $\color{#35bf28}+3.77\%$ |
| test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 0.1363s | 27.7174ms | 36.0784 Ops/s | 34.6794 Ops/s | $\color{#35bf28}+4.03\%$ |

@hyerra
Contributor

hyerra commented Jul 27, 2023

This is awesome, I think it might also be helpful for #1201 since Unity supports masked actions as well.

@matteobettini
Contributor

I also think this is awesome. It will be needed for MLAgents, as @hyerra said, for PettingZoo #1441, and for SMAC #1466.

@matteobettini
Contributor

matteobettini commented Aug 23, 2023

A few things to be aware of:

  • This works well when you have to mask some of the actions but some are still available. How do we treat the case when all actions are unavailable (e.g. an agent dies)?
  • check_env_specs also needs to be adapted: it checks that the last action is in the current action spec, but the mask of that spec has already been updated after the last step.
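
To make the first point concrete, here is a minimal plain-Python sketch of masked random sampling. It is an illustration only, not the TorchRL spec code, and `sample_masked` is a hypothetical helper. With a partial mask there is always something to draw; with an all-False mask there is nothing valid to return.

```python
import random


def sample_masked(mask, rng=random):
    """Sample uniformly among the indices whose mask entry is True."""
    allowed = [i for i, ok in enumerate(mask) if ok]
    if not allowed:
        # Nothing valid to return: this is the "all actions unavailable" case.
        raise ValueError("cannot sample: every action is masked out")
    return rng.choice(allowed)


rng = random.Random(0)
# Only indices 0 and 2 can ever be drawn with this mask.
samples = {sample_masked([True, False, True, False], rng) for _ in range(100)}

try:
    sample_masked([False, False, False, False], rng)
    all_masked_ok = True
except ValueError:
    all_masked_ok = False
```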

@vmoens
Contributor Author

vmoens commented Aug 27, 2023

A few things to be aware of:

  • This works well when you have to mask some of the actions but some are still available. How do we treat the case when all actions are unavailable (e.g. an agent dies)?

Shouldn't it be done then?
If we think about other situations, I'm not sure it makes sense to cover this explicitly. Example: if you're playing a game and you have N actions available but when your turn comes you don't have anything to do, there should be one more action N+1 "pass".
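
A minimal sketch of that idea in plain Python, assuming the mask is a list of booleans over N actions (the helper `with_pass_action` is hypothetical): appending an always-available "pass" entry guarantees that at least one action stays valid.

```python
def with_pass_action(mask):
    """Append an always-available "pass" entry to a mask over N actions."""
    return list(mask) + [True]


n_actions = 6
pass_index = n_actions  # the extra action sits at index N (0-based)

# Even when every regular action is blocked, "pass" remains available.
blocked = with_pass_action([False] * n_actions)
available = [i for i, ok in enumerate(blocked) if ok]
```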

@matteobettini
Contributor

matteobettini commented Aug 27, 2023

Example: if you're playing a game and you have N actions available but when your turn comes you don't have anything to do, there should be one more action N+1 "pass".

I'm talking about a different scenario.
For example, you are playing a game, but only one player acts at a time. It is not your turn, but you are also not done.

What I am doing right now is setting an all-False mask for all actions of that agent, but not passing it to the spec (as it would crash); it only appears in the env output. So the spec still sees all actions as available for that agent, but in reality the user should mask out that whole transition later.

Is there a better way to treat this?
Maybe passing a mask of all False should be allowed? And all specs should have mask compatibility?
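
As a sketch of the "mask out the transition later" step (plain Python, with hypothetical names; the actual training code would operate on tensors), a per-step quantity such as a loss can be averaged only over steps where the reported mask had at least one available action:

```python
def masked_mean(values, action_masks):
    """Average `values` over the steps where at least one action was available."""
    kept = [v for v, m in zip(values, action_masks) if any(m)]
    if not kept:
        raise ValueError("no valid transition to average over")
    return sum(kept) / len(kept)


# Per-step losses for one agent, with the mask the env reported at each step.
losses = [1.0, 5.0, 3.0]
masks = [
    [True, False],   # agent could act
    [False, False],  # not this agent's turn: drop the transition
    [True, True],    # agent could act
]
mean_loss = masked_mean(losses, masks)
```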

# Conflicts:
#	test/test_specs.py
@vmoens
Contributor Author

vmoens commented Aug 30, 2023

Maybe passing a mask of all false should be allowed? And all specs should have mask compatibility?

Then what should action_spec.rand() return? It's going to be confusing, no?
IMO, if your env has the option of a "do nothing" action, it should be part of the env itself; otherwise it will be surprising.
In this case, we can do a transform:

from torchrl.envs.transforms import Transform

class SkipAction(Transform):
    """Adds an action to the env that skips the step if selected."""

For instance, if action_spec has 6 possible actions, an action with value 6 will instruct the transform to skip the step.
A bit of magic will be needed to make it work with composite actions...
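
A rough plain-Python sketch of the behaviour described above, written as a wrapper around a toy environment rather than as a TorchRL `Transform`. `CountingEnv` and `SkipActionWrapper` are hypothetical names, and composite actions are not handled.

```python
class CountingEnv:
    """Toy environment: the state is a counter incremented by each real step."""

    n_actions = 6

    def __init__(self):
        self.counter = 0

    def step(self, action):
        if not 0 <= action < self.n_actions:
            raise ValueError(f"invalid action {action}")
        self.counter += 1
        return self.counter


class SkipActionWrapper:
    """Expose one extra action index that leaves the wrapped env untouched."""

    def __init__(self, env):
        self.env = env
        self.skip_action = env.n_actions  # index 6 when the env has 6 actions
        self.n_actions = env.n_actions + 1

    def step(self, action):
        if action == self.skip_action:
            # Skip: do not call the wrapped env, return its current state.
            return self.env.counter
        return self.env.step(action)


env = SkipActionWrapper(CountingEnv())
after_real_step = env.step(2)  # forwarded to the wrapped env
after_skip = env.step(6)       # intercepted, the counter does not move
```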

@matteobettini
Contributor

OK, yeah, I see.
We can keep this PR to the current scope then.

@vmoens vmoens marked this pull request as ready for review August 30, 2023 14:45
@vmoens vmoens linked an issue Aug 30, 2023 that may be closed by this pull request
Contributor

@matteobettini matteobettini left a comment

LGTM

This does not work for multi discrete and multi one hot, right?
I think it would be useful to support that.

@matteobettini
Contributor

Also we need to add this line to check_env_specs, because otherwise the check will fail when the action_mask gets updated

# Check specs
last_td = real_tensordict[..., -1]
# replace action with one sampled from last spec
# (since spec mask might be changed after last step)
last_td = env.rand_action(last_td) ## <------ this line
full_action_spec = env.input_spec["full_action_spec"]
full_state_spec = env.input_spec["full_state_spec"]
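
The failure mode can be reproduced with a toy example in plain Python (hypothetical helpers, not the TorchRL check itself): an action sampled under the mask in force before the step can fall outside the mask the spec holds after the step, whereas an action resampled from the updated mask passes the membership test.

```python
import random


def is_in(action, mask):
    """Spec-style membership test: the action must be allowed by the mask."""
    return 0 <= action < len(mask) and mask[action]


def rand_action(mask, rng):
    """Sample an action that is valid under the given mask."""
    return rng.choice([i for i, ok in enumerate(mask) if ok])


rng = random.Random(0)
mask_before_step = [True, False, False]
last_action = rand_action(mask_before_step, rng)  # necessarily action 0

# Stepping the env updates the mask, so the action just taken may now be excluded.
mask_after_step = [False, True, True]
stale_check = is_in(last_action, mask_after_step)

# Resampling from the updated mask gives an action the check accepts.
fresh_action = rand_action(mask_after_step, rng)
fresh_check = is_in(fresh_action, mask_after_step)
```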

Contributor

@matteobettini matteobettini left a comment

LGTM

Are we leaving the multi spec for another PR?

@vmoens
Contributor Author

vmoens commented Sep 1, 2023

LGTM

are we leaving the multi spec for another PR?

You mean multi-one-hot (i.e. in a single tensor) or multiple masks?

@matteobettini
Contributor

Multi one hot and multi discrete

@vmoens
Contributor Author

vmoens commented Sep 1, 2023

@matteobettini now these are supported too
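
For readers unfamiliar with masking multi-discrete actions, here is a plain-Python sketch under one assumption: the mask is a flat list formed by concatenating one boolean mask per sub-action, with sizes given by `nvec`. The thread does not describe the actual layout used in the PR, and the helper names are hypothetical.

```python
import random


def split_mask(flat_mask, nvec):
    """Cut a flat mask into one chunk per sub-action of sizes `nvec`."""
    if len(flat_mask) != sum(nvec):
        raise ValueError("mask length must equal sum(nvec)")
    chunks, start = [], 0
    for n in nvec:
        chunks.append(flat_mask[start:start + n])
        start += n
    return chunks


def sample_multi_discrete(flat_mask, nvec, rng):
    """Sample one allowed index per sub-action."""
    action = []
    for chunk in split_mask(flat_mask, nvec):
        allowed = [i for i, ok in enumerate(chunk) if ok]
        if not allowed:
            raise ValueError("a sub-action has no available value")
        action.append(rng.choice(allowed))
    return action


nvec = [3, 2]  # first sub-action has 3 values, second has 2
flat_mask = [False, True, False, True, True]
action = sample_multi_discrete(flat_mask, nvec, random.Random(0))
```

With this mask the first component can only be 1, while the second component is free to take either value.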

@vmoens vmoens merged commit 9fded1a into main Sep 3, 2023
@vmoens vmoens deleted the masked_actions branch September 5, 2023 08:42
vmoens added a commit to hyerra/rl that referenced this pull request Oct 10, 2023
Labels
CLA Signed, enhancement
Development

Successfully merging this pull request may close these issues.

[Feature Request] Action Masking
4 participants