Skip to content

Commit

Permalink
[Feature] torch.distributed collectors (#934)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Mar 16, 2023
1 parent ee58306 commit da88aad
Show file tree
Hide file tree
Showing 28 changed files with 3,939 additions and 62 deletions.
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,10 @@ And it is `functorch` and `torch.compile` compatible!
```
</details>

- multiprocess [data collectors](torchrl/collectors/collectors.py)<sup>(2)</sup> that work synchronously or asynchronously.
Through the use of TensorDict, TorchRL's training loops are made very similar to regular training loops in supervised
- multiprocess and distributed [data collectors](torchrl/collectors/collectors.py)<sup>(2)</sup>
that work synchronously or asynchronously.
Through the use of TensorDict, TorchRL's training loops are made very similar
to regular training loops in supervised
learning (although the "dataloader" -- read data collector -- is modified on-the-fly):
<details>
<summary>Code</summary>
Expand All @@ -302,6 +304,9 @@ And it is `functorch` and `torch.compile` compatible!
```
</details>

Check our [distributed collector examples](examples/distributed/collectors) to
learn more about ultra-fast data collection with TorchRL.

- efficient<sup>(2)</sup> and generic<sup>(1)</sup> [replay buffers](torchrl/data/replay_buffers/replay_buffers.py) with modularized storage:
<details>
<summary>Code</summary>
Expand Down
54 changes: 49 additions & 5 deletions docs/source/reference/collectors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ TorchRL's data collectors accept two main arguments: an environment (or a list o
environment constructors) and a policy. They will iteratively execute an environment
step and a policy query over a defined number of steps before delivering a stack of
the data collected to the user. Environments will be reset whenever they reach a done
state, and/or after a predifined number of steps.
state, and/or after a predefined number of steps.

Because data collection is a potentially compute heavy process, it is crucial to
configure the execution hyperparameters appropriately.
Expand All @@ -21,7 +21,7 @@ class will execute the data collection on the training worker. The :obj:`MultiSy
will split the workload across an number of workers and aggregate the results that
will be delivered to the training worker. Finally, the :obj:`MultiaSyncDataCollector` will
execute the data collection on several workers and deliver the first batch of results
that it can gather. This execution will occur continuously and concomittantly with
that it can gather. This execution will occur continuously and concomitantly with
the training of the networks: this implies that the weights of the policy that
is used for the data collection may slightly lag the configuration of the policy
on the training worker. Therefore, although this class may be the fastest to collect
Expand All @@ -35,7 +35,7 @@ by setting `update_at_each_batch=True` in the constructor.
The second parameter to consider (in the remote settings) is the device where the
data will be collected and the device where the environment and policy operations
will be executed. For instance, a policy executed on CPU may be slower than one
executed on CUDA. When multiple inference workers run concomittantly, dispatching
executed on CUDA. When multiple inference workers run concomitantly, dispatching
the compute workload across the available devices may speed up the collection or
avoid OOM errors. Finally, the choice of the batch size and passing device (ie the
device where the data will be stored while waiting to be passed to the collection
Expand All @@ -58,8 +58,8 @@ Besides those compute parameters, users may choose to configure the following pa
- reset_when_done: whether environments should be reset when reaching a done state.


Data collectors
---------------
Single node data collectors
---------------------------
.. currentmodule:: torchrl.collectors.collectors

.. autosummary::
Expand All @@ -73,6 +73,50 @@ Data collectors
aSyncDataCollector


Distributed data collectors
---------------------------
TorchRL provides a set of distributed data collectors. These tools support
multiple backends (``'gloo'``, ``'nccl'``, ``'mpi'`` with the :class:`~.DistributedDataCollector`
or PyTorch RPC with :class:`~.RPCDataCollector`) and launchers (``'ray'``,
``submitit`` or ``torch.multiprocessing``).
They can be efficiently used in synchronous or asynchronous mode, on a single
node or across multiple nodes.

*Resources*: Find examples for these collectors in the
`dedicated folder <https://github.com/pytorch/rl/examples/distributed/collectors>`_.

.. note::
*Choosing the sub-collector*: All distributed collectors support the various single machine collectors.
One may wonder why using a :class:`MultiSyncDataCollector` or a :class:`torchrl.envs.ParallelEnv`
instead. In general, multiprocessed collectors have a lower IO footprint than
parallel environments which need to communicate at each step. Yet, the model specs
play a role in the opposite direction, since using parallel environments will
result in a faster execution of the policy (and/or transforms) since these
operations will be vectorized.

.. note::
*Choosing the device of a collector (or a parallel environment)*: Sharing data
among processes is achieved via shared-memory buffers with parallel environment
and multiprocessed environments executed on CPU. Depending on the capabilities
of the machine being used, this may be prohibitively slow compared to sharing
data on GPU which is natively supported by cuda drivers.
In practice, this means that using the ``device="cpu"`` keyword argument when
building a parallel environment or collector can result in a slower collection
than using ``device="cuda"`` when available.


.. currentmodule:: torchrl.collectors.distributed

.. autosummary::
:toctree: generated/
:template: rl_template.rst

DistributedDataCollector
RPCDataCollector
DistributedSyncDataCollector
submitit_delayed_launcher


Helper functions
----------------

Expand Down
8 changes: 4 additions & 4 deletions docs/source/reference/trainers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,10 @@ Loggers
:template: rl_template_fun.rst

Logger
CSVLogger
MLFlowLogger
TensorboardLogger
WandbLogger
csv.CSVLogger
mlflow.MLFlowLogger
tensorboard.TensorboardLogger
wandb.WandbLogger
get_logger
generate_exp_name

Expand Down
12 changes: 12 additions & 0 deletions examples/distributed/collectors/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Distributed data collection examples

If your algorithm is bound by the data collection speed, you may consider using
distributed data collector to make your training faster.
TorchRL offers a bunch of distributed data collectors that you can use
to increase the collection speed tenfold or more.

These examples are divided in a single machine and a multi-node series.

Refer to the [documentation](https://pytorch.org/rl/reference/collectors.html)
for more insight on what you can expect do
and how these tools should be used.
151 changes: 151 additions & 0 deletions examples/distributed/collectors/multi_nodes/delayed_dist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Multi-node distributed data collection with submitit in contexts where jobs can't launch other jobs.
The default configuration will ask for 8 nodes with 1 GPU each and 32 procs / node.
It should reach a collection speed of roughly 15-25K fps, or better depending
on the cluster specs.
The logic of the script is the following: we create a `main()` function that
executes or code (in this case just a data collection but in practice a training
loop should be present).
Since this `main()` function cannot launch sub-jobs by design, we launch the script
from the jump host and pass the slurm specs to submitit.
*Note*:
Although we don't go in much details into this in this script, the specs of the training
node and the specs of the inference nodes can differ (look at the DEFAULT_SLURM_CONF
and DEFAULT_SLURM_CONF_MAIN dictionaries below).
"""
import time
from argparse import ArgumentParser

import tqdm
from torchrl.collectors.distributed import submitit_delayed_launcher

from torchrl.collectors.distributed.default_configs import (
DEFAULT_SLURM_CONF,
DEFAULT_SLURM_CONF_MAIN,
)
from torchrl.collectors.distributed.generic import DistributedDataCollector
from torchrl.envs import EnvCreator

parser = ArgumentParser()
parser.add_argument("--partition", "-p", help="slurm partition to use")
parser.add_argument("--num_jobs", type=int, default=8, help="Number of jobs")
parser.add_argument("--tcp_port", type=int, default=1234, help="TCP port")
parser.add_argument(
"--num_workers", type=int, default=8, help="Number of workers per node"
)
parser.add_argument(
"--gpus_per_node",
"--gpus-per-node",
"-G",
type=int,
default=1,
help="Number of GPUs per node. If greater than 0, the backend used will be NCCL.",
)
parser.add_argument(
"--cpus_per_task",
"--cpus-per-task",
"-c",
type=int,
default=32,
help="Number of CPUs per node.",
)
parser.add_argument(
"--sync", action="store_true", help="Use --sync to collect data synchronously."
)
parser.add_argument(
"--frames_per_batch",
"--frames-per-batch",
default=4000,
type=int,
help="Number of frames in each batch of data. Must be "
"divisible by the product of nodes and workers if sync, by the number of "
"workers otherwise.",
)
parser.add_argument(
"--total_frames",
"--total-frames",
default=10_000_000,
type=int,
help="Total number of frames collected by the collector.",
)
parser.add_argument(
"--time",
"-t",
default="1:00:00",
help="Timeout for the nodes",
)

args = parser.parse_args()

slurm_gpus_per_node = args.gpus_per_node
slurm_time = args.time

DEFAULT_SLURM_CONF["slurm_gpus_per_node"] = slurm_gpus_per_node
DEFAULT_SLURM_CONF["slurm_time"] = slurm_time
DEFAULT_SLURM_CONF["slurm_cpus_per_task"] = args.cpus_per_task
DEFAULT_SLURM_CONF["slurm_partition"] = args.partition
DEFAULT_SLURM_CONF_MAIN["slurm_partition"] = args.partition
DEFAULT_SLURM_CONF_MAIN["slurm_time"] = slurm_time

num_jobs = args.num_jobs
tcp_port = args.tcp_port
num_workers = args.num_workers
sync = args.sync
total_frames = args.total_frames
frames_per_batch = args.frames_per_batch


@submitit_delayed_launcher(
num_jobs=num_jobs,
backend="nccl" if slurm_gpus_per_node else "gloo",
tcpport=tcp_port,
)
def main():
from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
from torchrl.collectors.collectors import RandomPolicy
from torchrl.data import BoundedTensorSpec
from torchrl.envs.libs.gym import GymEnv

collector_class = SyncDataCollector if num_workers == 1 else MultiSyncDataCollector
device_str = "device" if num_workers == 1 else "devices"
collector = DistributedDataCollector(
[EnvCreator(lambda: GymEnv("ALE/Pong-v5"))] * num_jobs,
policy=RandomPolicy(BoundedTensorSpec(-1, 1, shape=(1,))),
launcher="submitit_delayed",
frames_per_batch=frames_per_batch,
total_frames=total_frames,
tcp_port=tcp_port,
collector_class=collector_class,
num_workers_per_collector=args.num_workers,
collector_kwargs={device_str: "cuda:0" if slurm_gpus_per_node else "cpu"},
storing_device="cuda:0" if slurm_gpus_per_node else "cpu",
backend="nccl" if slurm_gpus_per_node else "gloo",
sync=sync,
)
counter = 0
pbar = tqdm.tqdm(total=collector.total_frames)
for i, data in enumerate(collector):
pbar.update(data.numel())
pbar.set_description(f"data shape: {data.shape}, data device: {data.device}")
if i >= 10:
counter += data.numel()
if i == 10:
t0 = time.time()
t1 = time.time()
print(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps")
collector.shutdown()
exit()


if __name__ == "__main__":
main()
Loading

0 comments on commit da88aad

Please sign in to comment.