From a0c12cd8a74d6bbae0990b2a6694564bf1c8a405 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 13 Aug 2024 12:58:02 -0700 Subject: [PATCH] [Feature] Pass replay buffers to MultiaSyncDataCollector ghstack-source-id: 7275208e2f02560229ca83c999cd9b0ae68aaf4f Pull Request resolved: https://github.com/pytorch/rl/pull/2387 --- test/test_collector.py | 25 +++++++++++++++++++++++++ torchrl/collectors/collectors.py | 21 +++++++++++++-------- 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 4f12e445bf3..e914c283966 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -2830,6 +2830,31 @@ def test_collector_rb_multisync(self): collector.shutdown() assert len(rb) == 256 + @pytest.mark.skipif(not _has_gym, reason="requires gym.") + def test_collector_rb_multiasync(self): + env = GymEnv(CARTPOLE_VERSIONED()) + env.set_seed(0) + + rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5) + rb.add(env.rand_step(env.reset())) + rb.empty() + + collector = MultiaSyncDataCollector( + [lambda: env, lambda: env], + RandomPolicy(env.action_spec), + replay_buffer=rb, + total_frames=256, + frames_per_batch=16, + ) + torch.manual_seed(0) + pred_len = 0 + for c in collector: + pred_len += 16 + assert c is None + assert len(rb) >= pred_len + collector.shutdown() + assert len(rb) == 256 + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 3a8686f2893..0292c539b9b 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -16,7 +16,7 @@ import sys import time import warnings -from collections import OrderedDict +from collections import defaultdict, OrderedDict from copy import deepcopy from multiprocessing import connection, queues from multiprocessing.managers import SyncManager @@ -2433,7 +2433,7 @@ class MultiaSyncDataCollector(_MultiDataCollector): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.out_tensordicts = {} + self.out_tensordicts = defaultdict(lambda: None) self.running = False if self.postprocs is not None: @@ -2478,7 +2478,9 @@ def frames_per_batch_worker(self): def _get_from_queue(self, timeout=None) -> Tuple[int, int, TensorDictBase]: new_data, j = self.queue_out.get(timeout=timeout) use_buffers = self._use_buffers - if j == 0 or not use_buffers: + if self.replay_buffer is not None: + idx = new_data + elif j == 0 or not use_buffers: try: data, idx = new_data self.out_tensordicts[idx] = data @@ -2493,7 +2495,7 @@ def _get_from_queue(self, timeout=None) -> Tuple[int, int, TensorDictBase]: else: idx = new_data out = self.out_tensordicts[idx] - if j == 0 or use_buffers: + if not self.replay_buffer and (j == 0 or use_buffers): # we clone the data to make sure that we'll be working with a fixed copy out = out.clone() return idx, j, out @@ -2518,9 +2520,12 @@ def iterator(self) -> Iterator[TensorDictBase]: _check_for_faulty_process(self.procs) self._iter += 1 idx, j, out = self._get_from_queue() - worker_frames = out.numel() - if self.split_trajs: - out = split_trajectories(out, prefix="collector") + if self.replay_buffer is None: + worker_frames = out.numel() + if self.split_trajs: + out = split_trajectories(out, prefix="collector") + else: + worker_frames = self.frames_per_batch_worker self._frames += worker_frames workers_frames[idx] = workers_frames[idx] + worker_frames if self.postprocs: @@ -2536,7 +2541,7 @@ def iterator(self) -> Iterator[TensorDictBase]: else: msg = "continue" self.pipes[idx].send((idx, msg)) - if self._exclude_private_keys: + if out is not None and self._exclude_private_keys: excluded_keys = [key for key in out.keys() if key.startswith("_")] out = out.exclude(*excluded_keys) yield out