
[Feature] Multithreaded env #734

Merged: 51 commits merged into pytorch:main on Jan 31, 2023

Conversation

@sgrigory (Contributor) commented Dec 8, 2022

Description

Allows creating multithreaded environments based on the EnvPool library. This is an alternative to ParallelEnv, which uses multiprocessing-based parallelism. The EnvPool-based approach is faster but less flexible: one can only create environments implemented in EnvPool, whereas ParallelEnv parallelises arbitrary TorchRL environments.

Changes:

  • Add MultiThreadedEnv environment to torchrl/envs/vec_env.py
  • Add benchmarking script to benchmarks/benchmark_batched_envs.py
  • Add an option for CompositeSpec to return only non-nested keys; otherwise step() didn't work for environments with multiple observations, which have a nested observation spec (like Cheetah, which returns position and velocity)
  • Add tests

In this PR I only added the synchronous EnvPool interface, where the batch size is equal to the number of workers; in principle, the asynchronous one could be added too.
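For illustration, a minimal usage sketch; the import path and constructor arguments (num_workers, env_name) are assumptions based on the changes listed above, not taken verbatim from this PR's diff:

```python
from torchrl.envs import MultiThreadedEnv, ParallelEnv  # import path assumed
from torchrl.envs.libs.gym import GymEnv

# EnvPool-backed: all workers run as threads inside a single process.
env = MultiThreadedEnv(num_workers=4, env_name="Pendulum-v1")
rollout = env.rollout(max_steps=1000)  # batched rollout, one slice per worker
env.close()

# Multiprocessing-based alternative, which can wrap arbitrary TorchRL envs.
parallel_env = ParallelEnv(4, lambda: GymEnv("Pendulum-v1"))
rollout = parallel_env.rollout(max_steps=1000)
parallel_env.close()
```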

Motivation and Context

Closes #591

Benchmark

The following benchmark runs env.rollout(max_steps=1000) for three types of batched environments: SerialEnv, ParallelEnv, and MultiThreadedEnv:

pip install pandas
python benchmarks/benchmark_batched_envs.py
|                   | num_workers_1_cpu | num_workers_4_cpu | num_workers_16_cpu | num_workers_1_cuda | num_workers_4_cuda | num_workers_16_cuda |
| ----------------- | ----------------- | ----------------- | ------------------ | ------------------ | ------------------ | ------------------- |
| Serial, s         | 0.217             | 0.433             | 1.357              | 0.319              | 0.645              | 1.924               |
| Parallel, s       | 0.206             | 0.234             | 0.275              | 0.296              | 0.502              | 1.428               |
| Multithreaded, s  | 0.13              | 0.153             | 0.161              | 0.187              | 0.211              | 0.242               |
| Gain, %           | 36.9              | 34.6              | 41.5               | 36.8               | 58.0               | 83.1                |

One can see that on CPU MultiThreadedEnv is roughly 35-40% faster than ParallelEnv, depending on the number of workers. Neither implementation benefits from the GPU in the current setup, or I'm not using it the right way.
Environment creation time is excluded for all benchmarks.
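For reference, a simplified sketch of the kind of timing loop used; this is not the actual benchmarks/benchmark_batched_envs.py, and helper and argument names here are illustrative:

```python
import time

from torchrl.envs import MultiThreadedEnv, ParallelEnv, SerialEnv  # import paths assumed
from torchrl.envs.libs.gym import GymEnv


def time_rollout(env, max_steps=1000, warmup=1, repeats=3):
    # The environment is created before this call, so creation time is excluded.
    for _ in range(warmup):
        env.rollout(max_steps=max_steps)  # discard warm-up rollouts
    start = time.perf_counter()
    for _ in range(repeats):
        env.rollout(max_steps=max_steps)
    return (time.perf_counter() - start) / repeats


num_workers = 4
make_env = lambda: GymEnv("Pendulum-v1")
for name, env in [
    ("Serial", SerialEnv(num_workers, make_env)),
    ("Parallel", ParallelEnv(num_workers, make_env)),
    ("Multithreaded", MultiThreadedEnv(num_workers=num_workers, env_name="Pendulum-v1")),
]:
    print(f"{name}: {time_rollout(env):.3f} s")
    env.close()
```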

Tests

Tests have been added to tests/test_libs.py::TestEnvPool and run in a separate pipeline, unittest_linux_envpool_gpu, located in .circleci/unittest/linux_libs/scripts_envpool. Many of the tests are the same as for SerialEnv and ParallelEnv.

Comparing _BatchedEnv and MultiThreadedEnv output: different RNGs

One test which is still lacking is a check that results from SerialEnv/ParallelEnv and MultiThreadedEnv are the same for the same environment configuration and seed. This is because they generate random numbers in different ways:

  • EnvPool uses the C++ random number generator std::uniform_real_distribution
  • SerialEnv/ParallelEnv with Gym envs take random numbers from Gym, which uses the numpy RNG

I googled around, but didn't find a good way to use the same random number generator on both sides.
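To make the mismatch concrete, a small sketch comparing how the two sides are seeded (assuming gym >= 0.26 and an environment available in both libraries; the draws come from unrelated generators, so they are not expected to match):

```python
import envpool
import gym

seed = 42

# EnvPool: the seed is passed at construction time and feeds C++ generators
# (std::uniform_real_distribution under the hood).
pool_env = envpool.make("Pendulum-v1", env_type="gym", num_envs=1, seed=seed)
pool_out = pool_env.reset()  # obs only, or (obs, info) depending on envpool options

# Gym: reset(seed=...) seeds a numpy-based RNG, a completely different stream.
gym_env = gym.make("Pendulum-v1")
gym_obs, _ = gym_env.reset(seed=seed)

# Even with identical seeds, the initial observations generally differ.
print(pool_out, gym_obs)
```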

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 8, 2022
@sgrigory sgrigory changed the title [WIP] Add multithread env [WIP] Add multithreaded envi Dec 8, 2022
@sgrigory sgrigory changed the title [WIP] Add multithreaded envi [WIP] Add multithreaded env Dec 8, 2022
@@ -86,3 +86,7 @@ fi
pip install pip --upgrade

conda env update --file "${this_dir}/environment.yml" --prune

if [[ $os == 'Linux' ]]; then
pip install envpool
Contributor:

For a lib such as this one, we generally do a single test in a dedicated pipeline, in the file test_libs. Have a look at the jumanji pipeline, for instance.

torchrl/envs/vec_env.py: three review threads (outdated, resolved)
@sgrigory sgrigory marked this pull request as ready for review January 23, 2023 16:15
@sgrigory sgrigory requested review from shagunsodhani and vmoens and removed request for shagunsodhani January 23, 2023 16:16
@sgrigory (Contributor Author):

@vmoens @shagunsodhani Feel free to have a look!

@shagunsodhani (Contributor) left a comment

lgtm

- hypothesis
- future
- cloudpickle
- gym
Contributor:

Maybe we should pin the gym version, unless we are sure that it works for both >=0.25 and <0.25 versions

Contributor Author:

Actually, I'm testing locally with gym==0.24.0 and CI uses gym==0.26.2, so we might be covered here. I had to set gym_reset_return_info=True in envpool.make for this to work.
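For context, that call looks roughly like this (environment name and num_envs are illustrative):

```python
import envpool

# gym_reset_return_info=True makes reset() return (obs, info), matching the
# newer gym API, so the wrapper sees a consistent return format across versions.
env = envpool.make(
    "Pendulum-v1",
    env_type="gym",
    num_envs=4,
    gym_reset_return_info=True,
)
obs, info = env.reset()
```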

):
if not _has_envpool:
raise ImportError(
f"""envpool python package was not found. Please install this dependency.
Contributor:

Nitpicking, but this error could happen either because envpool is missing or because treevalue is missing. Maybe we should mention both in the error message.

Contributor Author:

Good point, adjusted the error message

"""Library EnvPool only support setting a seed by recreating the environment."""
if seed is not None:
self.create_env_kwargs["seed"] = seed
self._env = self._build_env(
Contributor:

I am curious if this is an expensive operation and if we should highlight this behavior via logging (debug level)

Contributor Author:

Added a logging.debug message
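A rough sketch of how that could look (the method name and the _build_env arguments are assumptions; the attribute names follow the quoted diff):

```python
import logging


def _set_seed(self, seed):
    """EnvPool only supports setting a seed by recreating the environment."""
    if seed is not None:
        logging.debug(
            "Recreating the EnvPool environment to set a seed; "
            "this can be an expensive operation."
        )
        self.create_env_kwargs["seed"] = seed
        self._env = self._build_env(**self.create_env_kwargs)
    return seed
```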

@sgrigory sgrigory requested review from shagunsodhani and removed request for vmoens January 27, 2023 13:11
@shagunsodhani (Contributor) left a comment

lgtm

action = tensordict.get("action")
# Action needs to be moved to CPU and converted to numpy before being passed to envpool
action = action.to(torch.device("cpu"))
step_output = self._env.step(action.detach().numpy())
Contributor:

I think this detach is redundant since the tensor is already on CPU. Not a big deal though

Contributor:

Not sure: a tensor that is part of a graph is not detached when moved to another device.
But if we want to make sure that no grad is passed, we can simply decorate _step with @torch.no_grad().

Contributor Author:

Removed .detach() and added the decorator. One of them is needed; otherwise it throws "RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead" even on CPU.
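To illustrate the error in question with a standalone example (not the PR code itself):

```python
import torch

x = torch.randn(3, requires_grad=True)
y = x * 2.0  # y carries autograd history

# y.numpy() would raise:
#   RuntimeError: Can't call numpy() on Tensor that requires grad.
#   Use tensor.detach().numpy() instead.

arr = y.detach().numpy()  # option 1: detach explicitly before converting

with torch.no_grad():
    y_no_grad = x * 2.0  # option 2: tensors created under no_grad require no grad
arr2 = y_no_grad.numpy()
```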

@sgrigory sgrigory requested a review from vmoens January 30, 2023 10:58
@vmoens (Contributor) left a comment

LGTM! Great work.
Can you have a look at the couple of comments I left?

**kwargs,
):
if not _has_envpool:
raise ImportError(
Contributor:

Another way is to keep the error in the try/except:

IMPORT_ERR_ENVPOOL = err

and then do ImportError(...) from IMPORT_ERR_ENVPOOL

Contributor Author:

Makes sense, done
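For context, a minimal sketch of the suggested pattern (the class name and message text are illustrative; IMPORT_ERR_ENVPOOL follows the naming in the comment above):

```python
try:
    import envpool  # noqa: F401
    import treevalue  # noqa: F401

    _has_envpool = True
    IMPORT_ERR_ENVPOOL = None
except ImportError as err:
    _has_envpool = False
    IMPORT_ERR_ENVPOOL = err  # keep the original exception for chaining


class MultiThreadedEnvWrapper:
    def __init__(self, **kwargs):
        if not _has_envpool:
            raise ImportError(
                "envpool and treevalue python packages are required. "
                "Please install these dependencies."
            ) from IMPORT_ERR_ENVPOOL
```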


out = env_multithreaded.rollout(max_steps=20)
assert out.device == torch.device(device)

env_multithreaded.close()
Contributor:

We should add a simple check_env_specs(env) in a separate test. This will assess if the specs are ok.

Contributor Author:

Added a separate test, test_specs. It works for all envs except CheetahRun-v1. For that env, the observations returned by envpool seem inconsistent with the bounding box it returns in its spec: the bounding box values are strictly positive, but observations can be both positive and negative. I think I'll submit this as an issue in the envpool repo.

>>> import envpool
>>> env = envpool.make("CheetahRun-v1", env_type="gym")
>>> env.spec.observation_space
Dict(position:Box([2.22507386e-308 2.22507386e-308 2.22507386e-308 2.22507386e-308
 2.22507386e-308 2.22507386e-308 2.22507386e-308 2.22507386e-308], [1.79769313e+308 1.79769313e+308 1.79769313e+308 1.79769313e+308
 1.79769313e+308 1.79769313e+308 1.79769313e+308 1.79769313e+308], (8,), float64), velocity:Box([2.22507386e-308 2.22507386e-308 2.22507386e-308 2.22507386e-308
 2.22507386e-308 2.22507386e-308 2.22507386e-308 2.22507386e-308
 2.22507386e-308], [1.79769313e+308 1.79769313e+308 1.79769313e+308 1.79769313e+308
 1.79769313e+308 1.79769313e+308 1.79769313e+308 1.79769313e+308
 1.79769313e+308], (9,), float64))
>>> env.reset()
<TreeValue 0x7fa1cf2be8e0>
├── 'position' --> array([[-0.09139574,  0.02779732, -0.06476798, -0.01931019, -0.09387728,
│                           0.00776585, -0.00777442,  0.00666921]])
└── 'velocity' --> array([[ 0.00531753, -0.00294308,  0.00210252,  0.0133823 ,  0.01259834,
                            0.01001254, -0.00010664, -0.01476834, -0.01361733]])
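For reference, a minimal sketch of what such a spec-check test could look like (test and parameter names are illustrative; MultiThreadedEnv's constructor signature is assumed):

```python
import pytest

from torchrl.envs import MultiThreadedEnv  # import path assumed
from torchrl.envs.utils import check_env_specs


@pytest.mark.parametrize("env_name", ["Pendulum-v1", "CheetahRun-v1"])
def test_specs(env_name):
    # Build a small multithreaded env and check that reset/step outputs
    # are consistent with the declared specs.
    env = MultiThreadedEnv(num_workers=3, env_name=env_name)
    try:
        check_env_specs(env)
    finally:
        env.close()
```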

@vmoens (Contributor) commented Jan 31, 2023

I think your benchmarks could be improved for the parallel env.
Starting the workers takes a tremendous amount of time, sometimes > 30 seconds, but once they are started it should be faster.
The first few rollouts collected should be discarded from the benchmark.

@sgrigory (Contributor Author):

> I think your benchmarks could be improved for the parallel env. Starting the workers takes a tremendous amount of time, sometimes > 30 seconds, but once they are started it should be faster. The first few rollouts collected should be discarded from the benchmark.

Indeed, that was the case: after I excluded environment creation from the benchmarked time, ParallelEnv became faster than SerialEnv. I updated the benchmarking script and the PR description.

@vmoens vmoens changed the title Add multithreaded env [Feature] Add multithreaded env Jan 31, 2023
@vmoens vmoens added the enhancement New feature or request label Jan 31, 2023
@vmoens (Contributor) left a comment

LGTM

@vmoens vmoens changed the title [Feature] Add multithreaded env [Feature] Multithreaded env Jan 31, 2023
@vmoens vmoens merged commit 45cdbd1 into pytorch:main Jan 31, 2023
Labels: CLA Signed (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed), enhancement (new feature or request)
Projects: none
Development: successfully merging this pull request may close [Feature Request] Multithreaded env
4 participants