diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 2364f70ac0c..b17a0fbe736 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -1884,28 +1884,40 @@ class MultiSyncDataCollector(_MultiDataCollector):
     trajectory and the start of the next collection.
     This class can be safely used with online RL sota-implementations.
 
+    .. note:: Python requires multiprocessed code to be instantiated within a main guard:
+
+        >>> from torchrl.collectors import MultiSyncDataCollector
+        >>> if __name__ == "__main__":
+        ...     # Create your collector here
+
+        See https://docs.python.org/3/library/multiprocessing.html for more info.
+
     Examples:
         >>> from torchrl.envs.libs.gym import GymEnv
-        >>> from torchrl.envs import StepCounter
         >>> from tensordict.nn import TensorDictModule
        >>> from torch import nn
-        >>> env_maker = lambda: TransformedEnv(GymEnv("Pendulum-v1", device="cpu"), StepCounter(max_steps=50))
-        >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
-        >>> collector = MultiSyncDataCollector(
-        ...     create_env_fn=[env_maker, env_maker],
-        ...     policy=policy,
-        ...     total_frames=2000,
-        ...     max_frames_per_traj=50,
-        ...     frames_per_batch=200,
-        ...     init_random_frames=-1,
-        ...     reset_at_each_iter=False,
-        ...     devices="cpu",
-        ...     storing_devices="cpu",
-        ... )
-        >>> for i, data in enumerate(collector):
-        ...     if i == 2:
-        ...         print(data)
-        ...         break
+        >>> from torchrl.collectors import MultiSyncDataCollector
+        >>> if __name__ == "__main__":
+        ...     env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
+        ...     policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
+        ...     collector = MultiSyncDataCollector(
+        ...         create_env_fn=[env_maker, env_maker],
+        ...         policy=policy,
+        ...         total_frames=2000,
+        ...         max_frames_per_traj=50,
+        ...         frames_per_batch=200,
+        ...         init_random_frames=-1,
+        ...         reset_at_each_iter=False,
+        ...         device="cpu",
+        ...         storing_device="cpu",
+        ...         cat_results="stack",
+        ...     )
+        ...     for i, data in enumerate(collector):
+        ...         if i == 2:
+        ...             print(data)
+        ...             break
+        ...     collector.shutdown()
+        ...     del collector
         TensorDict(
             fields={
                 action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
@@ -1932,8 +1944,6 @@ class MultiSyncDataCollector(_MultiDataCollector):
             batch_size=torch.Size([200]),
             device=cpu,
             is_shared=False)
-        >>> collector.shutdown()
-        >>> del collector
 
     """
 
@@ -2232,27 +2242,39 @@ class MultiaSyncDataCollector(_MultiDataCollector):
     the batch of rollouts is collected and the next call to the iterator.
     This class can be safely used with offline RL sota-implementations.
 
-    Examples:
+    .. note:: Python requires multiprocessed code to be instantiated within a main guard:
+
+        >>> from torchrl.collectors import MultiaSyncDataCollector
+        >>> if __name__ == "__main__":
+        ...     # Create your collector here
+
+        See https://docs.python.org/3/library/multiprocessing.html for more info.
+
+    Examples:
         >>> from torchrl.envs.libs.gym import GymEnv
         >>> from tensordict.nn import TensorDictModule
         >>> from torch import nn
-        >>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
-        >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
-        >>> collector = MultiaSyncDataCollector(
-        ...     create_env_fn=[env_maker, env_maker],
-        ...     policy=policy,
-        ...     total_frames=2000,
-        ...     max_frames_per_traj=50,
-        ...     frames_per_batch=200,
-        ...     init_random_frames=-1,
-        ...     reset_at_each_iter=False,
-        ...     devices="cpu",
-        ...     storing_devices="cpu",
-        ... )
-        >>> for i, data in enumerate(collector):
-        ...     if i == 2:
-        ...         print(data)
-        ...         break
+        >>> from torchrl.collectors import MultiaSyncDataCollector
+        >>> if __name__ == "__main__":
+        ...     env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
+        ...     policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
+        ...     collector = MultiaSyncDataCollector(
+        ...         create_env_fn=[env_maker, env_maker],
+        ...         policy=policy,
+        ...         total_frames=2000,
+        ...         max_frames_per_traj=50,
+        ...         frames_per_batch=200,
+        ...         init_random_frames=-1,
+        ...         reset_at_each_iter=False,
+        ...         device="cpu",
+        ...         storing_device="cpu",
+        ...     )
+        ...     for i, data in enumerate(collector):
+        ...         if i == 2:
+        ...             print(data)
+        ...             break
+        ...     collector.shutdown()
+        ...     del collector
         TensorDict(
             fields={
                 action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
@@ -2279,8 +2301,6 @@ class MultiaSyncDataCollector(_MultiDataCollector):
             batch_size=torch.Size([200]),
             device=cpu,
             is_shared=False)
-        >>> collector.shutdown()
-        >>> del collector
 
     """
 