[Feature] Multithreaded env #734
Conversation
@@ -86,3 +86,7 @@ fi
pip install pip --upgrade

conda env update --file "${this_dir}/environment.yml" --prune

if [[ $os == 'Linux' ]]; then
    pip install envpool
For libs such as this one, we generally do a single test in a dedicated pipeline of the file test_libs. Have a look at the jumanji pipeline, for instance.
@vmoens @shagunsodhani Feel free to have a look!
lgtm
- hypothesis
- future
- cloudpickle
- gym
Maybe we should pin the gym version, unless we are sure that it works for both >=0.25 and <0.25 versions
Actually I'm testing locally with gym==0.24.0 and CI uses gym-0.26.2, so we might be covered here. I had to set gym_reset_return_info=True in envpool.make for this to work.
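For context, a minimal sketch of that call; the task name and num_envs are arbitrary choices for illustration, only env_type and gym_reset_return_info come from the discussion above, and the (obs, info) return shape is assumed from the flag's purpose:

```python
import envpool

# Build a gym-style EnvPool batch whose reset() follows the newer Gym API and
# returns (obs, info) instead of just obs.
env = envpool.make(
    "CheetahRun-v1",          # any EnvPool task id; this one appears later in this thread
    env_type="gym",
    num_envs=4,
    gym_reset_return_info=True,
)
obs, info = env.reset()       # obs is a TreeValue here because the observation is a dict
```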
torchrl/envs/vec_env.py
Outdated
):
    if not _has_envpool:
        raise ImportError(
            f"""envpool python package was not found. Please install this dependency.
Nitpicking, but this error could happen either because envpool is missing or because treevalue is missing. Maybe we should mention both in the error.
Good point, adjusted the error message
"""Library EnvPool only support setting a seed by recreating the environment.""" | ||
if seed is not None: | ||
self.create_env_kwargs["seed"] = seed | ||
self._env = self._build_env( |
I am curious if this is an expensive operation and if we should highlight this behavior via logging (debug level)
Added a logging.debug message.
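A rough sketch of the seeding path with such a message, adapted from the diff above; the method name, the exact log wording, and the _build_env call are assumptions rather than the merged code:

```python
import logging


def _set_seed(self, seed):
    """Library EnvPool only supports setting a seed by recreating the environment."""
    if seed is not None:
        # Recreating the environment can be expensive, so surface it at debug level.
        logging.debug("EnvPool environment is recreated to apply a new seed.")
        self.create_env_kwargs["seed"] = seed
        self._env = self._build_env(**self.create_env_kwargs)
```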
lgtm
torchrl/envs/vec_env.py
Outdated
action = tensordict.get("action")
# Action needs to be moved to CPU and converted to numpy before being passed to envpool
action = action.to(torch.device("cpu"))
step_output = self._env.step(action.detach().numpy())
I think this detach is redundant since the tensor is already on CPU. Not a big deal though.
Not sure: a tensor that is part of a graph is not detached when moved to another device. But if we want to make sure that no grad is passed, we can simply decorate _step with @torch.no_grad().
Removed .detach() and added the decorator. One of them is needed; otherwise it throws "RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead", even on CPU.
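A condensed sketch of the resulting step path, adapted from the diff above; the real method in torchrl/envs/vec_env.py does more than this:

```python
import torch


@torch.no_grad()
def _step(self, tensordict):
    action = tensordict.get("action")
    # Action needs to be moved to CPU and converted to numpy before being
    # passed to envpool; under torch.no_grad() the .detach() is no longer needed.
    action = action.to(torch.device("cpu"))
    step_output = self._env.step(action.numpy())
    # ... conversion of step_output back to a TensorDict is omitted in this sketch
```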
LGTM! Great work.
Can you have a look at the couple of comments I left?
    **kwargs,
):
    if not _has_envpool:
        raise ImportError(
Another way is to keep the error in the try/except: IMPORT_ERR_ENVPOOL = err, and then do ImportError(...) from IMPORT_ERR_ENVPOOL.
Makes sense, done
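That is, roughly the following pattern; the variable names follow the comment above, while the exact error wording and placement in the module are assumptions:

```python
try:
    import envpool  # noqa: F401
    import treevalue  # noqa: F401

    _has_envpool = True
    IMPORT_ERR_ENVPOOL = None
except ImportError as err:
    _has_envpool = False
    IMPORT_ERR_ENVPOOL = err

# Later, where the environment is constructed:
if not _has_envpool:
    # Chaining with "from" preserves the original import error as the cause.
    raise ImportError(
        "envpool and treevalue python packages are required "
        "to run multithreaded environments."
    ) from IMPORT_ERR_ENVPOOL
```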
out = env_multithreaded.rollout(max_steps=20)
assert out.device == torch.device(device)

env_multithreaded.close()
We should add a simple check_env_specs(env) in a separate test. This will assess if the specs are ok.
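A hypothetical sketch of such a test; the MultiThreadedEnv constructor arguments and the task name are assumptions, and check_env_specs is assumed importable from torchrl.envs.utils:

```python
from torchrl.envs.utils import check_env_specs
from torchrl.envs.vec_env import MultiThreadedEnv


def test_specs():
    # Constructor arguments are illustrative; see torchrl/envs/vec_env.py for the real signature.
    env = MultiThreadedEnv(num_workers=2, env_name="Pendulum-v1")
    check_env_specs(env)
    env.close()
```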
Added a separate test test_specs. It works for all envs except for CheetahRun-v1. For that env, it seems that the observations returned by envpool are inconsistent with the bounding box it returns in the spec: the bounding box values are strictly positive, but observations can be both positive and negative. I think I'll submit this as an issue in the envpool repo.
>>> import envpool
>>> env = envpool.make("CheetahRun-v1", env_type="gym")
>>> env.spec.observation_space
Dict(position:Box([2.22507386e-308 2.22507386e-308 2.22507386e-308 2.22507386e-308
2.22507386e-308 2.22507386e-308 2.22507386e-308 2.22507386e-308], [1.79769313e+308 1.79769313e+308 1.79769313e+308 1.79769313e+308
1.79769313e+308 1.79769313e+308 1.79769313e+308 1.79769313e+308], (8,), float64), velocity:Box([2.22507386e-308 2.22507386e-308 2.22507386e-308 2.22507386e-308
2.22507386e-308 2.22507386e-308 2.22507386e-308 2.22507386e-308
2.22507386e-308], [1.79769313e+308 1.79769313e+308 1.79769313e+308 1.79769313e+308
1.79769313e+308 1.79769313e+308 1.79769313e+308 1.79769313e+308
1.79769313e+308], (9,), float64))
>>> env.reset()
<TreeValue 0x7fa1cf2be8e0>
├── 'position' --> array([[-0.09139574, 0.02779732, -0.06476798, -0.01931019, -0.09387728,
│ 0.00776585, -0.00777442, 0.00666921]])
└── 'velocity' --> array([[ 0.00531753, -0.00294308, 0.00210252, 0.0133823 , 0.01259834,
0.01001254, -0.00010664, -0.01476834, -0.01361733]])
I think your benchmarks could be improved for parallel env
Indeed, that was the case: after I excluded environment creation from the benchmarked time, the ParallelEnv results improved.
LGTM
Description
Allows creating multithreaded environments based on the EnvPool library. This is an alternative to ParallelEnv, which uses multiprocessing parallelism. The EnvPool-based approach is faster but less flexible, as one can only create environments implemented in EnvPool (whereas ParallelEnv parallelises arbitrary TorchRL environments).

Changes:
- Added MultiThreadedEnv environment to torchrl/envs/vec_env.py
- Added benchmarks/benchmark_batched_envs.py
- Changed CompositeSpec to return only non-nested keys - otherwise step didn't work for environments with multiple observations, which have a nested observation spec (like Cheetah, which returns position and velocity)

In this PR I only added the synchronous EnvPool interface, where the batch size is equal to the number of workers, but in principle it's possible to add the asynchronous one too.
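A hedged usage sketch, assuming MultiThreadedEnv takes num_workers and an EnvPool task name (the actual signature is defined in torchrl/envs/vec_env.py), shown next to the ParallelEnv equivalent:

```python
from torchrl.envs import ParallelEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.vec_env import MultiThreadedEnv

# Thread-level parallelism via EnvPool: only tasks implemented in EnvPool are available.
mt_env = MultiThreadedEnv(num_workers=4, env_name="Pendulum-v1")  # arguments are assumptions

# Process-level parallelism: wraps an arbitrary TorchRL environment factory.
par_env = ParallelEnv(4, lambda: GymEnv("Pendulum-v1"))

mt_out = mt_env.rollout(max_steps=20)
par_out = par_env.rollout(max_steps=20)

mt_env.close()
par_env.close()
```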
Motivation and Context
Closes #591
Benchmark
The following benchmark runs env.rollout(max_steps=1000) for three types of batched environments: SerialEnv, ParallelEnv, and MultiThreadedEnv.

One can see that on CPU, MultiThreadedEnv is ~40% faster than ParallelEnv, depending on the number of workers. Neither of the setups is benefiting from GPU in the current configuration, or I'm not using it in the right way. Environment creation time is excluded for all benchmarks.
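A rough sketch of the kind of timing comparison done in benchmarks/benchmark_batched_envs.py; the constructor arguments, the task, and the warm-up strategy are assumptions, and the actual script may differ:

```python
from time import perf_counter

from torchrl.envs import ParallelEnv, SerialEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.vec_env import MultiThreadedEnv

NUM_WORKERS = 4

envs = {
    "SerialEnv": SerialEnv(NUM_WORKERS, lambda: GymEnv("Pendulum-v1")),
    "ParallelEnv": ParallelEnv(NUM_WORKERS, lambda: GymEnv("Pendulum-v1")),
    "MultiThreadedEnv": MultiThreadedEnv(num_workers=NUM_WORKERS, env_name="Pendulum-v1"),
}

for name, env in envs.items():
    env.rollout(max_steps=10)  # warm-up, so environment creation stays out of the timing
    start = perf_counter()
    env.rollout(max_steps=1000)
    print(f"{name}: {perf_counter() - start:.2f}s")
    env.close()
```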
Tests
Tests have been added to tests/test_libs.py::TestEnvPool and run in a separate pipeline unittest_linux_envpool_gpu located in .circleci/unittest/linux_libs/scripts_envpool. Many of the tests are the same as for SerialEnv and ParallelEnv.
Comparing _BatchedEnv and MultiThreadedEnv output: different RNGs
One test which is lacking is a check that results from SerialEnv/ParallelEnv and MultiThreadedEnv are the same for the same environment configuration and seed. This is because they generate random numbers in different ways:
- EnvPool uses std::uniform_real_distribution (here and here)
- SerialEnv/ParallelEnv with Gym envs take random numbers from Gym, which uses the numpy RNG (here and here)

I googled around but didn't find a good way to use the same random number generator on both sides.