-
Notifications
You must be signed in to change notification settings - Fork 327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BugFix] Fix Gym Categorical/One-hot issues #1482
Conversation
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_single | 0.1102s | 0.1095s | 9.1290 Ops/s | 9.0666 Ops/s | |
test_sync | 0.1302s | 62.0725ms | 16.1102 Ops/s | 16.0113 Ops/s | |
test_async | 0.1652s | 58.3687ms | 17.1325 Ops/s | 17.1627 Ops/s | |
test_simple | 0.5674s | 0.5133s | 1.9480 Ops/s | 1.9324 Ops/s | |
test_transformed | 1.3218s | 1.2744s | 0.7847 Ops/s | 0.7762 Ops/s | |
test_serial | 1.4943s | 1.4490s | 0.6901 Ops/s | 0.6875 Ops/s | |
test_parallel | 1.3335s | 1.2725s | 0.7859 Ops/s | 0.7496 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 1.5282ms | 42.7074μs | 23.4151 KOps/s | 23.8854 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 93.0990μs | 24.2290μs | 41.2728 KOps/s | 41.9321 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 74.9980μs | 29.4744μs | 33.9278 KOps/s | 34.4680 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 46.2990μs | 16.3861μs | 61.0272 KOps/s | 61.1041 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 69.9990μs | 43.8239μs | 22.8186 KOps/s | 23.1304 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 80.9990μs | 25.9944μs | 38.4698 KOps/s | 38.9563 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 0.1447ms | 31.8004μs | 31.4461 KOps/s | 32.1993 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 9.8956ms | 18.6512μs | 53.6158 KOps/s | 53.8915 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 78.3990μs | 46.3751μs | 21.5633 KOps/s | 22.2949 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 54.1990μs | 27.7127μs | 36.0845 KOps/s | 36.4132 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 62.8980μs | 31.2597μs | 31.9901 KOps/s | 32.5799 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 43.3990μs | 18.2779μs | 54.7109 KOps/s | 54.4959 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 0.1066ms | 48.0050μs | 20.8311 KOps/s | 21.5331 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 55.0990μs | 29.0359μs | 34.4401 KOps/s | 34.5332 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 0.2070ms | 33.1712μs | 30.1467 KOps/s | 31.0371 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 49.7990μs | 20.0546μs | 49.8639 KOps/s | 50.6537 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 0.1008ms | 46.5299μs | 21.4916 KOps/s | 22.3730 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 51.2990μs | 27.8096μs | 35.9588 KOps/s | 36.5608 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 81.7980μs | 34.2357μs | 29.2093 KOps/s | 29.3084 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 2.9407ms | 19.8683μs | 50.3315 KOps/s | 50.6344 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 71.6980μs | 47.4572μs | 21.0716 KOps/s | 21.3675 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 54.4990μs | 29.2968μs | 34.1334 KOps/s | 34.3499 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 91.7990μs | 35.7134μs | 28.0007 KOps/s | 28.2532 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 0.4163ms | 21.6275μs | 46.2374 KOps/s | 45.3184 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 94.0980μs | 48.9869μs | 20.4136 KOps/s | 20.7242 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 64.1990μs | 31.1682μs | 32.0840 KOps/s | 32.2594 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 82.3990μs | 35.3023μs | 28.3268 KOps/s | 28.2456 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 53.0990μs | 21.6686μs | 46.1498 KOps/s | 46.7219 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 79.6990μs | 50.5331μs | 19.7890 KOps/s | 20.0389 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 63.0990μs | 32.8532μs | 30.4384 KOps/s | 30.6337 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 94.6980μs | 36.7892μs | 27.1819 KOps/s | 27.4430 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 69.8990μs | 23.0578μs | 43.3693 KOps/s | 42.6179 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 16.4732ms | 13.7711ms | 72.6156 Ops/s | 73.1976 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 56.5103ms | 51.5568ms | 19.3961 Ops/s | 19.4600 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.3135ms | 0.2077ms | 4.8135 KOps/s | 4.3200 KOps/s | |
test_values[td1_return_estimate-False-False] | 13.6491ms | 13.4489ms | 74.3555 Ops/s | 72.9317 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 56.9017ms | 51.6954ms | 19.3441 Ops/s | 19.3452 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 32.2558ms | 31.9904ms | 31.2593 Ops/s | 31.0846 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 57.1222ms | 51.2620ms | 19.5076 Ops/s | 19.4730 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 12.1765ms | 12.0861ms | 82.7395 Ops/s | 82.6357 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 7.0431ms | 2.4177ms | 413.6238 Ops/s | 415.5445 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 5.0844ms | 0.4140ms | 2.4156 KOps/s | 2.4239 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 57.0903ms | 53.1296ms | 18.8219 Ops/s | 19.0123 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 4.7819ms | 3.7588ms | 266.0436 Ops/s | 266.1431 Ops/s | |
test_dqn_speed | 5.8706ms | 1.7173ms | 582.3053 Ops/s | 578.4364 Ops/s | |
test_ddpg_speed | 7.5033ms | 2.4548ms | 407.3693 Ops/s | 390.2642 Ops/s | |
test_sac_speed | 16.8155ms | 8.0114ms | 124.8214 Ops/s | 127.6408 Ops/s | |
test_redq_speed | 31.2155ms | 15.5108ms | 64.4711 Ops/s | 64.9979 Ops/s | |
test_redq_deprec_speed | 18.2706ms | 13.1974ms | 75.7728 Ops/s | 78.4537 Ops/s | |
test_td3_speed | 17.6213ms | 9.9566ms | 100.4354 Ops/s | 102.8943 Ops/s | |
test_cql_speed | 36.0455ms | 30.6579ms | 32.6180 Ops/s | 38.1358 Ops/s | |
test_a2c_speed | 11.0218ms | 5.5491ms | 180.2100 Ops/s | 182.7844 Ops/s | |
test_ppo_speed | 11.0680ms | 5.8456ms | 171.0695 Ops/s | 167.9898 Ops/s | |
test_reinforce_speed | 9.2236ms | 4.2535ms | 235.0980 Ops/s | 238.5584 Ops/s | |
test_iql_speed | 31.0680ms | 22.2708ms | 44.9018 Ops/s | 46.0064 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 3.1820ms | 2.5867ms | 386.5937 Ops/s | 370.6991 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 5.3843ms | 2.7509ms | 363.5177 Ops/s | 360.9264 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 5.3515ms | 2.7006ms | 370.2816 Ops/s | 358.9336 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 0.1888s | 3.0933ms | 323.2795 Ops/s | 377.3961 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 4.4698ms | 2.6994ms | 370.4553 Ops/s | 359.6077 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 4.8889ms | 2.7372ms | 365.3346 Ops/s | 319.8752 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.0180ms | 2.7122ms | 368.7097 Ops/s | 374.2409 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 4.6739ms | 2.7450ms | 364.2981 Ops/s | 369.2402 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 5.0467ms | 2.7498ms | 363.6626 Ops/s | 370.1911 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 3.0124ms | 2.6321ms | 379.9265 Ops/s | 391.7751 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 4.7617ms | 2.7200ms | 367.6449 Ops/s | 369.0693 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 4.3874ms | 2.7238ms | 367.1295 Ops/s | 368.9615 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.2076ms | 2.5783ms | 387.8549 Ops/s | 393.2668 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 4.8989ms | 2.7266ms | 366.7559 Ops/s | 356.5545 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 5.1270ms | 2.7036ms | 369.8810 Ops/s | 357.1350 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.3879ms | 2.6171ms | 382.1043 Ops/s | 368.9005 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 4.3614ms | 2.7229ms | 367.2619 Ops/s | 360.3607 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 4.7285ms | 2.7324ms | 365.9839 Ops/s | 354.8254 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.2228s | 26.5227ms | 37.7036 Ops/s | 37.9018 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 0.1169s | 26.1540ms | 38.2350 Ops/s | 38.3001 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 0.1183s | 24.1444ms | 41.4175 Ops/s | 41.4172 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.1181s | 26.2229ms | 38.1346 Ops/s | 41.3961 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 0.1210s | 24.4216ms | 40.9474 Ops/s | 41.6724 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 0.1183s | 26.1193ms | 38.2859 Ops/s | 38.6997 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.1208s | 22.2808ms | 44.8816 Ops/s | 41.6920 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 0.1169s | 25.9157ms | 38.5866 Ops/s | 38.5538 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 0.1143s | 23.8145ms | 41.9912 Ops/s | 38.5787 Ops/s |
and action.size == 1 | ||
): | ||
# some envs require an integer for indexing | ||
action = int(action) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tiny: You may want to use action.item() as it is more efficient
scalar = np.array([1])
timeit.timeit(lambda: scalar.item())
0.05105495895259082
timeit.timeit(lambda: int(scalar))
0.34526545903645456
Moreover, it seems this functionality will be deprecated in future:
DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)
Description
Makes sure that one-hot and categorical work ok when a gym env expects an integer.