[Feature] Device transform #1472
Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_single | 0.1706s | 0.1669s | 5.9899 Ops/s | 5.8858 Ops/s | |
test_sync | 0.1746s | 96.2884ms | 10.3855 Ops/s | 10.3557 Ops/s | |
test_async | 0.1965s | 91.2416ms | 10.9599 Ops/s | 11.2341 Ops/s | |
test_simple | 0.7990s | 0.7267s | 1.3761 Ops/s | 1.3444 Ops/s | |
test_transformed | 2.0001s | 1.9342s | 0.5170 Ops/s | 0.5138 Ops/s | |
test_serial | 2.2181s | 2.2111s | 0.4523 Ops/s | 0.4400 Ops/s | |
test_parallel | 2.0208s | 1.9084s | 0.5240 Ops/s | 0.5193 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.2062ms | 51.3868μs | 19.4602 KOps/s | 18.6945 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 58.4000μs | 29.0685μs | 34.4016 KOps/s | 33.9803 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 80.0010μs | 35.8117μs | 27.9238 KOps/s | 27.1530 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 52.2010μs | 19.4511μs | 51.4109 KOps/s | 49.8145 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 0.1924ms | 53.2941μs | 18.7638 KOps/s | 18.3483 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 59.6010μs | 31.0975μs | 32.1569 KOps/s | 31.5577 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 77.7010μs | 38.0246μs | 26.2988 KOps/s | 25.5693 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 49.4010μs | 21.7991μs | 45.8734 KOps/s | 45.0826 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 0.1116ms | 54.3931μs | 18.3847 KOps/s | 17.5988 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 0.1011ms | 32.6938μs | 30.5869 KOps/s | 29.4644 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 85.3000μs | 37.7121μs | 26.5167 KOps/s | 25.2090 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 0.3260ms | 21.7206μs | 46.0392 KOps/s | 44.4008 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 0.1138ms | 57.7712μs | 17.3097 KOps/s | 16.6760 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 91.5010μs | 34.5483μs | 28.9450 KOps/s | 27.4407 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 98.4010μs | 40.4414μs | 24.7271 KOps/s | 23.6901 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 61.4010μs | 24.0787μs | 41.5304 KOps/s | 40.6046 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 0.1383ms | 54.5561μs | 18.3298 KOps/s | 17.8572 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 76.7010μs | 32.3930μs | 30.8709 KOps/s | 29.4453 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 80.0010μs | 43.6621μs | 22.9032 KOps/s | 21.3792 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 57.6010μs | 24.4134μs | 40.9611 KOps/s | 39.4962 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 0.1360ms | 56.6163μs | 17.6628 KOps/s | 17.1870 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 1.9990ms | 35.0609μs | 28.5218 KOps/s | 28.2080 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 79.6010μs | 46.6426μs | 21.4396 KOps/s | 20.9367 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 0.6048ms | 27.1760μs | 36.7972 KOps/s | 36.3694 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 93.0010μs | 57.5965μs | 17.3622 KOps/s | 16.3340 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 99.4010μs | 36.6149μs | 27.3113 KOps/s | 25.9249 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 0.1027ms | 46.2558μs | 21.6189 KOps/s | 20.7751 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 0.1035ms | 26.5307μs | 37.6922 KOps/s | 36.5885 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 0.1085ms | 60.7695μs | 16.4556 KOps/s | 15.7390 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 73.1000μs | 39.1460μs | 25.5454 KOps/s | 24.7423 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 0.1069ms | 48.0814μs | 20.7981 KOps/s | 19.9347 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 90.0010μs | 28.3642μs | 35.2556 KOps/s | 33.6793 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 20.4198ms | 15.3463ms | 65.1621 Ops/s | 64.4635 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 59.0384ms | 48.9144ms | 20.4439 Ops/s | 20.5920 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.3849ms | 0.2340ms | 4.2739 KOps/s | 4.2871 KOps/s | |
test_values[td1_return_estimate-False-False] | 15.3761ms | 14.9407ms | 66.9311 Ops/s | 66.5584 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 60.2913ms | 50.9271ms | 19.6359 Ops/s | 20.6701 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 38.2264ms | 35.8200ms | 27.9174 Ops/s | 27.5790 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 57.4065ms | 48.5279ms | 20.6067 Ops/s | 20.8179 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 13.8016ms | 13.3419ms | 74.9516 Ops/s | 73.1015 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 6.1508ms | 3.9497ms | 253.1861 Ops/s | 257.6675 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 7.1712ms | 0.5684ms | 1.7594 KOps/s | 1.7519 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 70.2361ms | 63.8463ms | 15.6626 Ops/s | 15.7069 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 4.0284ms | 3.2687ms | 305.9304 Ops/s | 293.2661 Ops/s | |
test_dqn_speed | 5.7149ms | 2.1868ms | 457.2837 Ops/s | 391.0891 Ops/s | |
test_ddpg_speed | 9.6905ms | 3.2701ms | 305.8012 Ops/s | 292.2160 Ops/s | |
test_sac_speed | 15.3859ms | 9.6827ms | 103.2775 Ops/s | 103.0539 Ops/s | |
test_redq_speed | 26.1642ms | 18.6491ms | 53.6220 Ops/s | 52.2202 Ops/s | |
test_redq_deprec_speed | 24.8650ms | 16.1853ms | 61.7844 Ops/s | 65.2020 Ops/s | |
test_td3_speed | 21.8242ms | 12.0606ms | 82.9144 Ops/s | 82.4932 Ops/s | |
test_cql_speed | 47.9024ms | 44.0649ms | 22.6938 Ops/s | 26.4305 Ops/s | |
test_a2c_speed | 13.1931ms | 6.3226ms | 158.1620 Ops/s | 156.6433 Ops/s | |
test_ppo_speed | 16.5614ms | 6.9434ms | 144.0223 Ops/s | 136.5848 Ops/s | |
test_reinforce_speed | 13.6830ms | 5.0145ms | 199.4201 Ops/s | 198.7375 Ops/s | |
test_iql_speed | 36.1572ms | 26.7018ms | 37.4506 Ops/s | 37.5658 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 3.8738ms | 3.1655ms | 315.9098 Ops/s | 313.1811 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 6.5373ms | 3.4157ms | 292.7661 Ops/s | 298.2416 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 5.5589ms | 3.3167ms | 301.5002 Ops/s | 301.1843 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.5864ms | 3.1307ms | 319.4217 Ops/s | 243.4505 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 5.5815ms | 3.3614ms | 297.4982 Ops/s | 300.1275 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 5.6793ms | 3.3209ms | 301.1205 Ops/s | 293.2555 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 4.4857ms | 3.1962ms | 312.8671 Ops/s | 316.5940 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 6.2774ms | 3.3359ms | 299.7680 Ops/s | 300.1600 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.1632s | 3.8027ms | 262.9731 Ops/s | 298.4259 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 4.1782ms | 3.0862ms | 324.0265 Ops/s | 305.2262 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 5.7881ms | 3.3166ms | 301.5157 Ops/s | 293.5156 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 5.7985ms | 3.3784ms | 295.9952 Ops/s | 293.9190 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 4.3294ms | 3.0976ms | 322.8313 Ops/s | 305.3118 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 4.9160ms | 3.3900ms | 294.9852 Ops/s | 296.8634 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 7.0524ms | 3.3567ms | 297.9141 Ops/s | 293.9346 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 4.1296ms | 3.0831ms | 324.3459 Ops/s | 316.4863 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 6.6937ms | 3.3646ms | 297.2080 Ops/s | 296.2106 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 6.7701ms | 3.3587ms | 297.7371 Ops/s | 295.3183 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.2695s | 35.4243ms | 28.2292 Ops/s | 28.8679 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 0.1582s | 30.0272ms | 33.3031 Ops/s | 29.3111 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 0.1658s | 33.1355ms | 30.1792 Ops/s | 32.4944 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.1565s | 29.5430ms | 33.8489 Ops/s | 29.2707 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 0.1580s | 32.7299ms | 30.5531 Ops/s | 29.8215 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 0.1698s | 32.9321ms | 30.3655 Ops/s | 32.4205 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.1633s | 30.4084ms | 32.8856 Ops/s | 29.4354 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 0.1609s | 32.9677ms | 30.3327 Ops/s | 32.2089 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 0.1639s | 30.6880ms | 32.5860 Ops/s | 29.5246 Ops/s |
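
The Change column is empty in this capture. As a rough sketch, a relative change can be recomputed from the two throughput columns, assuming "Ops" is the PR build and "Ops on Repo HEAD" is the baseline (the helper name `relative_change` is hypothetical, and the formula is an assumption about what the column would show):

```python
def relative_change(ops_pr: float, ops_head: float) -> float:
    """Percent change in throughput of the PR relative to repo HEAD."""
    return (ops_pr - ops_head) / ops_head * 100.0


# (Ops, Ops on Repo HEAD) pairs copied from three rows of the table above.
rows = {
    "test_dqn_speed": (457.2837, 391.0891),
    "test_cql_speed": (22.6938, 26.4305),
    "test_sac_speed": (103.2775, 103.0539),
}

changes = {name: relative_change(pr, head) for name, (pr, head) in rows.items()}

for name, pct in changes.items():
    # Positive means the PR is faster than HEAD, negative means slower.
    print(f"{name}: {pct:+.2f}%")
```

Most rows differ from HEAD by only a few percent, so single-run differences of that size are probably noise rather than an effect of the PR.
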
@@ -2708,6 +2738,87 @@ def __init__(
        super().__init__(torch.double, torch.float, in_keys, in_keys_inv)


class DeviceCastTransform(Transform):
    """Casts the env device.
Is this really casting the env device?
Isn't it transforming the device of the data?
I have a question:
In environments like VMAS, where there is an internal state that sits on a device, I thought the way to move the environment to another device was env.to(device).
If I understand correctly, here we are casting the env data and its specs to another device, but not its internal state (we do not call env.to()).
Does it make sense to use this with a parent environment? When is this transform preferred to env.to(device)?
The only case I can imagine is when we need to keep an env on a specific device, apply some transforms there, and then move the data to another device.
This is to address #1198, where the issue is that if the env naturally sits on MPS we can't use float64. So you must first transform the data to float32 and then cast it to the device. Doing env.to(device) will not work, but this transform will. |
Could you please explain exactly how to do this? I am confused. I am following this tutorial https://pytorch.org/tutorials/intermediate/reinforcement_ppo.html and get the MPS float64 error when running line: |
Can I ask what the value of device is in your case? |
device="mps". I believe I solved it with this (after quite a few hours of trying different things!!!):
Could you please kindly confirm if what I've done is correct? I am on an Apple M2 Max trying to use MPS. |
Everything was progressing smoothly through the tutorial: https://pytorch.org/tutorials/intermediate/reinforcement_ppo.html however at this code I get an error again about MPS. Please could you kindly advise syntax to solve this issue? I double checked everything seems to be on mps, so I don't understand where the error is coming from.
|
@EkaterinaAbramova sorry you had this terrible experience, we should document things better for MPS. Regarding the error, I will have a look into it; it shouldn't be too difficult to solve. Have you tried moving the
env = TransformedEnv(
    base_env,
    Compose(
        DoubleToFloat(),
        DeviceCastTransform(device=device, orig_device="cpu"),
        ObservationNorm(in_keys=["observation"]),  # normalise observations (make them roughly standard normal)
        StepCounter(),  # count the number of steps before the environment is terminated
    ),
)
Like this the buffers in the |
@vmoens thank you for swiftly helping me, this issue is quite urgent, so very glad that you had the suggestion. It makes sense and I tried it, however this way around I get this error |
You need to call init_stats on the obs norm transform, env.transform[-2].init_stats(...), because the transform has changed place. |
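To see why the index becomes -2, here is a minimal sketch that uses plain strings in place of the real transform objects. The `tutorial_order` list is an assumption about how the tutorial composes its transforms (with ObservationNorm first); `suggested_order` mirrors the Compose suggested in this thread.

```python
# Strings stand in for the transform objects; only the ordering matters here.
tutorial_order = ["ObservationNorm", "DoubleToFloat", "StepCounter"]  # assumed
suggested_order = [
    "DoubleToFloat",
    "DeviceCastTransform",
    "ObservationNorm",
    "StepCounter",
]

# Position of the observation norm before and after the reordering.
old_idx = tutorial_order.index("ObservationNorm")
new_idx = suggested_order.index("ObservationNorm")

# Equivalent negative index, counted from the end of the composed list.
neg_idx = new_idx - len(suggested_order)

print(old_idx, new_idx, neg_idx)
```

Under that assumption, code that called init_stats on the first transform would now reach DoubleToFloat instead, which is why the call has to target the second-to-last entry.
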
Ok, so that's an interesting bug, which basically boils down to some internal machinery within rollout, resets and transforms.
# Collect statistics with an env that has no ObservationNorm yet.
simple_env = TransformedEnv(
    base_env,
    Compose(
        DoubleToFloat(),
        DeviceCastTransform(device=device, orig_device="cpu"),
    ),
)
td0 = simple_env.rollout(100)
loc = td0["observation"].mean(dim=0)
scale = td0["observation"].std(dim=0)
# Build the final env with the precomputed statistics.
env = TransformedEnv(
    base_env,
    Compose(
        DoubleToFloat(),
        DeviceCastTransform(device=device, orig_device="cpu"),
        # standard_normal=True applies (obs - loc) / scale, which matches a mean/std pair
        ObservationNorm(in_keys=["observation"], loc=loc, scale=scale, standard_normal=True),
        StepCounter(),
    ),
)
Hopefully that should help! |
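The statistics in that workaround are a per-feature mean and standard deviation over the rollout's time dimension. As a rough illustration in plain Python (nested lists instead of tensors; the helper names `column_stats` and `standardize` and the sample data are hypothetical), assuming the goal is the usual (obs - loc) / scale standardization:

```python
from statistics import mean, stdev


def column_stats(rows):
    """Per-feature mean and sample standard deviation over the first axis."""
    columns = list(zip(*rows))
    loc = [mean(col) for col in columns]
    # Sample (Bessel-corrected) deviation, the same default torch.std uses.
    scale = [stdev(col) for col in columns]
    return loc, scale


def standardize(row, loc, scale):
    """Apply (x - loc) / scale feature by feature."""
    return [(x - m) / s for x, m, s in zip(row, loc, scale)]


# Three made-up observations with two features each, standing in for a rollout.
rollout = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
loc, scale = column_stats(rollout)
normalized = [standardize(row, loc, scale) for row in rollout]
print(loc, scale, normalized)
```

A feature that never changes during the rollout would give a zero scale, so a real setup may want to clamp the scale to a small positive value before dividing.
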
#1589 will solve your problem! |
I have read through that thread, but I don't understand what I am supposed to do. Sorry!! Should I download the packages again from fresh? Or download something particular from that page? |
Just wait until we merge it, and then you can reinstall from git and things should work ok |
@EkaterinaAbramova you should be good to go now! |
I can reproduce your issue, let me push a fix! |
Description
Adds a DeviceCastTransform transform to move environment data from one device to another.
As part of this PR, transforms can now transform the device of the parent env through transform.transform_env_device.