[Doc] Indicate necessary context to run multiprocessed collectors in doc (pytorch#2126)

Co-authored-by: Gert-Jan Both <bothg@hhmi.org>
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
Co-authored-by: Vincent Moens <vmoens@meta.com>
4 people authored Apr 30, 2024
1 parent 68101b0 commit 741947a
Showing 1 changed file with 61 additions and 41 deletions.

torchrl/collectors/collectors.py
@@ -1884,28 +1884,40 @@ class MultiSyncDataCollector(_MultiDataCollector):
     trajectory and the start of the next collection.
     This class can be safely used with online RL sota-implementations.
 
+    .. note:: Python requires multiprocessed code to be instantiated within a main guard:
+
+        >>> from torchrl.collectors import MultiSyncDataCollector
+        >>> if __name__ == "__main__":
+        ...     # Create your collector here
+
+        See https://docs.python.org/3/library/multiprocessing.html for more info.
+
     Examples:
         >>> from torchrl.envs.libs.gym import GymEnv
         >>> from torchrl.envs import StepCounter
         >>> from tensordict.nn import TensorDictModule
         >>> from torch import nn
-        >>> env_maker = lambda: TransformedEnv(GymEnv("Pendulum-v1", device="cpu"), StepCounter(max_steps=50))
-        >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
-        >>> collector = MultiSyncDataCollector(
-        ...     create_env_fn=[env_maker, env_maker],
-        ...     policy=policy,
-        ...     total_frames=2000,
-        ...     max_frames_per_traj=50,
-        ...     frames_per_batch=200,
-        ...     init_random_frames=-1,
-        ...     reset_at_each_iter=False,
-        ...     devices="cpu",
-        ...     storing_devices="cpu",
-        ... )
-        >>> for i, data in enumerate(collector):
-        ...     if i == 2:
-        ...         print(data)
-        ...         break
+        >>> from torchrl.collectors import MultiSyncDataCollector
+        >>> if __name__ == "__main__":
+        ...     env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
+        ...     policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
+        ...     collector = MultiSyncDataCollector(
+        ...         create_env_fn=[env_maker, env_maker],
+        ...         policy=policy,
+        ...         total_frames=2000,
+        ...         max_frames_per_traj=50,
+        ...         frames_per_batch=200,
+        ...         init_random_frames=-1,
+        ...         reset_at_each_iter=False,
+        ...         device="cpu",
+        ...         storing_device="cpu",
+        ...         cat_results="stack",
+        ...     )
+        ...     for i, data in enumerate(collector):
+        ...         if i == 2:
+        ...             print(data)
+        ...             break
+        ...     collector.shutdown()
+        ...     del collector
         TensorDict(
             fields={
                 action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
@@ -1932,8 +1944,6 @@ class MultiSyncDataCollector(_MultiDataCollector):
             batch_size=torch.Size([200]),
             device=cpu,
             is_shared=False)
-        >>> collector.shutdown()
-        >>> del collector
     """

@@ -1987,7 +1997,6 @@ def _queue_len(self) -> int:
         return self.num_workers
 
     def iterator(self) -> Iterator[TensorDictBase]:
-
         cat_results = self.cat_results
         if cat_results is None:
             cat_results = 0
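
For context on the hunk above: cat_results decides how MultiSyncDataCollector merges the batches its workers deliver each iteration, and None currently falls back to 0, i.e. concatenation along the first dimension, while the updated docstring examples pass cat_results="stack". A pure-torch sketch of the resulting shapes, assuming two workers each delivering a 100-frame rollout (illustrative only, not torchrl's actual merge code):

# Illustrative tensor semantics of the two cat_results modes; not torchrl's
# merge code. Assume two workers each deliver a 100-frame rollout.
import torch

worker_batches = [torch.randn(100, 3), torch.randn(100, 3)]

# cat_results=0 (the fallback above): concatenate along dim 0.
catted = torch.cat(worker_batches, dim=0)
assert catted.shape == torch.Size([200, 3])

# cat_results="stack": stack along a new leading worker dimension.
stacked = torch.stack(worker_batches, dim=0)
assert stacked.shape == torch.Size([2, 100, 3])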
@@ -2232,27 +2241,40 @@ class MultiaSyncDataCollector(_MultiDataCollector):
     the batch of rollouts is collected and the next call to the iterator.
     This class can be safely used with offline RL sota-implementations.
 
-    Examples:
+    .. note:: Python requires multiprocessed code to be instantiated within a main guard:
+
+        >>> from torchrl.collectors import MultiaSyncDataCollector
+        >>> if __name__ == "__main__":
+        ...     # Create your collector here
+
+        See https://docs.python.org/3/library/multiprocessing.html for more info.
+
+    Examples:
         >>> from torchrl.envs.libs.gym import GymEnv
         >>> from tensordict.nn import TensorDictModule
        >>> from torch import nn
-        >>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
-        >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
-        >>> collector = MultiaSyncDataCollector(
-        ...     create_env_fn=[env_maker, env_maker],
-        ...     policy=policy,
-        ...     total_frames=2000,
-        ...     max_frames_per_traj=50,
-        ...     frames_per_batch=200,
-        ...     init_random_frames=-1,
-        ...     reset_at_each_iter=False,
-        ...     devices="cpu",
-        ...     storing_devices="cpu",
-        ... )
-        >>> for i, data in enumerate(collector):
-        ...     if i == 2:
-        ...         print(data)
-        ...         break
+        >>> from torchrl.collectors import MultiaSyncDataCollector
+        >>> if __name__ == "__main__":
+        ...     env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
+        ...     policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
+        ...     collector = MultiaSyncDataCollector(
+        ...         create_env_fn=[env_maker, env_maker],
+        ...         policy=policy,
+        ...         total_frames=2000,
+        ...         max_frames_per_traj=50,
+        ...         frames_per_batch=200,
+        ...         init_random_frames=-1,
+        ...         reset_at_each_iter=False,
+        ...         device="cpu",
+        ...         storing_device="cpu",
+        ...         cat_results="stack",
+        ...     )
+        ...     for i, data in enumerate(collector):
+        ...         if i == 2:
+        ...             print(data)
+        ...             break
+        ...     collector.shutdown()
+        ...     del collector
         TensorDict(
             fields={
                 action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
@@ -2279,8 +2301,6 @@ class MultiaSyncDataCollector(_MultiDataCollector):
             batch_size=torch.Size([200]),
             device=cpu,
             is_shared=False)
-        >>> collector.shutdown()
-        >>> del collector
     """

