[Feature] Pass replay buffers to SyncDataCollector
ghstack-source-id: 452d429b153284ebc06e89225eed0f6a7b6ad37b
Pull Request resolved: pytorch#2384
vmoens committed Aug 13, 2024
1 parent 2b975da commit 9627e8a
Showing 6 changed files with 270 additions and 87 deletions.
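For context, the sketch below shows the usage pattern the new tests exercise: a ReplayBuffer is handed to the collector, which extends it in place, and the collector's iterator yields None. The environment name, buffer sizes, and the RandomPolicy import path are illustrative assumptions drawn from the test code, not part of this diff.

# Hedged sketch of the new replay_buffer argument, based on test_collector_rb_sync
# below; the env name, sizes, and import paths are assumptions, not part of the commit.
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import RandomPolicy  # import path assumed

env = GymEnv("CartPole-v1")
rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
collector = SyncDataCollector(
    env,
    RandomPolicy(env.action_spec),
    replay_buffer=rb,  # new keyword: the collector writes into the buffer itself
    total_frames=256,
    frames_per_batch=16,
)
for batch in collector:
    assert batch is None  # nothing is yielded; the data lands in the buffer
    sample = rb.sample()  # sample from the buffer as usual
collector.shutdown()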
84 changes: 82 additions & 2 deletions test/test_collector.py
@@ -2585,8 +2585,15 @@ def test_unique_traj_sync(self, cat_results):
buffer.extend(d)
assert c._use_buffers
traj_ids = buffer[:].get(("collector", "traj_ids"))
# check that we have as many trajs as expected (no skip)
assert traj_ids.unique().numel() == traj_ids.max() + 1
# Ideally, we'd like (sorted_traj.values == sorted_traj.indices).all(),
# but in practice one env can reach the end of the rollout and do a reset
# (which we don't want to prevent), incrementing the global traj count
# while the others have not finished yet. In that case, that traj number
# will never appear.
# sorted_traj = traj_ids.unique().sort()
# assert (sorted_traj.values == sorted_traj.indices).all()
# assert traj_ids.unique().numel() == traj_ids.max() + 1

# check that trajs are not overlapping
if stack_results:
sets = [
@@ -2751,6 +2758,79 @@ def test_async(self, use_buffers):
del collector


class TestCollectorRB:
    @pytest.mark.skipif(not _has_gym, reason="requires gym.")
    def test_collector_rb_sync(self):
        env = SerialEnv(8, lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp))
        env.set_seed(0)
        rb = ReplayBuffer(storage=LazyTensorStorage(256, ndim=2), batch_size=5)
        collector = SyncDataCollector(
            env,
            RandomPolicy(env.action_spec),
            replay_buffer=rb,
            total_frames=256,
            frames_per_batch=16,
        )
        torch.manual_seed(0)

        for c in collector:
            assert c is None
            rb.sample()
        rbdata0 = rb[:].clone()
        collector.shutdown()
        if not env.is_closed:
            env.close()
        del collector, env

        env = SerialEnv(8, lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp))
        env.set_seed(0)
        rb = ReplayBuffer(storage=LazyTensorStorage(256, ndim=2), batch_size=5)
        collector = SyncDataCollector(
            env, RandomPolicy(env.action_spec), total_frames=256, frames_per_batch=16
        )
        torch.manual_seed(0)

        for i, c in enumerate(collector):
            rb.extend(c)
            torch.testing.assert_close(
                rbdata0[:, : (i + 1) * 2]["observation"], rb[:]["observation"]
            )
            assert c is not None
            rb.sample()

        rbdata1 = rb[:].clone()
        collector.shutdown()
        if not env.is_closed:
            env.close()
        del collector, env
        assert assert_allclose_td(rbdata0, rbdata1)

    @pytest.mark.skipif(not _has_gym, reason="requires gym.")
    def test_collector_rb_multisync(self):
        env = GymEnv(CARTPOLE_VERSIONED())
        env.set_seed(0)

        rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
        rb.add(env.rand_step(env.reset()))
        rb.empty()

        collector = MultiSyncDataCollector(
            [lambda: env, lambda: env],
            RandomPolicy(env.action_spec),
            replay_buffer=rb,
            total_frames=256,
            frames_per_batch=16,
        )
        torch.manual_seed(0)
        pred_len = 0
        for c in collector:
            pred_len += 16
            assert c is None
            assert len(rb) == pred_len
        collector.shutdown()
        assert len(rb) == 256


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
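One detail worth noting in test_collector_rb_multisync above: the buffer is extended with a single dummy transition and then emptied before being passed to MultiSyncDataCollector. A plausible reading (an assumption, not stated in the diff) is that this materializes the lazy storage up front so it can be shared with the worker processes. A minimal sketch of that warm-up step, with illustrative names:

# Warm-up pattern mirroring the test above; "CartPole-v1" and the sizes are
# illustrative, and the shared-storage rationale is an assumption.
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs.libs.gym import GymEnv

env = GymEnv("CartPole-v1")
rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
rb.add(env.rand_step(env.reset()))  # allocate the lazy storage with one real transition
rb.empty()                          # drop the dummy data but keep the allocated storage
# rb is now ready to be passed to MultiSyncDataCollector(..., replay_buffer=rb)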
12 changes: 4 additions & 8 deletions test/test_rb.py
@@ -2064,13 +2064,16 @@ def exec_multiproc_rb(
init=True,
writer_type=TensorDictRoundRobinWriter,
sampler_type=RandomSampler,
device=None,
):
rb = TensorDictReplayBuffer(
storage=storage_type(21), writer=writer_type(), sampler=sampler_type()
)
if init:
td = TensorDict(
{"a": torch.zeros(10), "next": {"reward": torch.ones(10)}}, [10]
{"a": torch.zeros(10), "next": {"reward": torch.ones(10)}},
[10],
device=device,
)
rb.extend(td)
q0 = mp.Queue(1)
@@ -2098,13 +2101,6 @@ def test_error_list(self):
with pytest.raises(RuntimeError, match="Cannot share a storage of type"):
self.exec_multiproc_rb(storage_type=ListStorage)

def test_error_nonshared(self):
# non shared tensor storage cannot be shared
with pytest.raises(
RuntimeError, match="The storage must be place in shared memory"
):
self.exec_multiproc_rb(storage_type=LazyTensorStorage)

def test_error_maxwriter(self):
# TensorDictMaxValueWriter cannot be shared
with pytest.raises(RuntimeError, match="cannot be shared between processes"):