Skip to content

Commit

Permalink
[BugFix,Feature,Doc] Fix replay buffers sampling info, docstrings and…
Browse files Browse the repository at this point in the history
… iteration (pytorch#1003)
  • Loading branch information
vmoens authored Mar 29, 2023
1 parent 0b2d2d8 commit e26d148
Show file tree
Hide file tree
Showing 10 changed files with 484 additions and 104 deletions.
6 changes: 3 additions & 3 deletions examples/discrete_sac/discrete_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ def make_replay_buffer(
batch_size=256,
buffer_scratch_dir="/tmp/",
device="cpu",
make_replay_buffer=3,
prefetch=3,
):
if prb:
replay_buffer = TensorDictPrioritizedReplayBuffer(
alpha=0.7,
beta=0.5,
pin_memory=False,
batch_size=batch_size,
prefetch=make_replay_buffer,
prefetch=prefetch,
storage=LazyMemmapStorage(
buffer_size,
scratch_dir=buffer_scratch_dir,
Expand All @@ -63,7 +63,7 @@ def make_replay_buffer(
replay_buffer = TensorDictReplayBuffer(
pin_memory=False,
batch_size=batch_size,
prefetch=make_replay_buffer,
prefetch=prefetch,
storage=LazyMemmapStorage(
buffer_size,
scratch_dir=buffer_scratch_dir,
Expand Down
13 changes: 9 additions & 4 deletions examples/iql/iql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,33 +36,36 @@ def env_maker(env_name, frame_skip=1, device="cpu", from_pixels=False):


def make_replay_buffer(
batch_size,
prb=False,
buffer_size=1000000,
buffer_scratch_dir="/tmp/",
device="cpu",
make_replay_buffer=3,
prefetch=3,
):
if prb:
replay_buffer = TensorDictPrioritizedReplayBuffer(
alpha=0.7,
beta=0.5,
pin_memory=False,
prefetch=make_replay_buffer,
prefetch=prefetch,
storage=LazyMemmapStorage(
buffer_size,
scratch_dir=buffer_scratch_dir,
device=device,
),
batch_size=batch_size,
)
else:
replay_buffer = TensorDictReplayBuffer(
pin_memory=False,
prefetch=make_replay_buffer,
prefetch=prefetch,
storage=LazyMemmapStorage(
buffer_size,
scratch_dir=buffer_scratch_dir,
device=device,
),
batch_size=batch_size,
)
return replay_buffer

Expand Down Expand Up @@ -218,7 +221,9 @@ def env_factory(num_workers):
collector.set_seed(cfg.seed)

# Make Replay Buffer
replay_buffer = make_replay_buffer(buffer_size=cfg.buffer_size, device=device)
replay_buffer = make_replay_buffer(
buffer_size=cfg.buffer_size, device=device, batch_size=cfg.batch_size
)

# Optimizers
params = list(loss_module.parameters())
Expand Down
14 changes: 10 additions & 4 deletions examples/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,33 +60,36 @@ def apply_env_transforms(env, reward_scaling=1.0):


def make_replay_buffer(
batch_size,
prb=False,
buffer_size=1000000,
buffer_scratch_dir="/tmp/",
device="cpu",
make_replay_buffer=3,
prefetch=3,
):
if prb:
replay_buffer = TensorDictPrioritizedReplayBuffer(
alpha=0.7,
beta=0.5,
pin_memory=False,
prefetch=make_replay_buffer,
prefetch=prefetch,
storage=LazyMemmapStorage(
buffer_size,
scratch_dir=buffer_scratch_dir,
device=device,
),
batch_size=batch_size,
)
else:
replay_buffer = TensorDictReplayBuffer(
pin_memory=False,
prefetch=make_replay_buffer,
prefetch=prefetch,
storage=LazyMemmapStorage(
buffer_size,
scratch_dir=buffer_scratch_dir,
device=device,
),
batch_size=batch_size,
)
return replay_buffer

Expand Down Expand Up @@ -239,7 +242,10 @@ def main(cfg: "DictConfig"): # noqa: F821

# Make Replay Buffer
replay_buffer = make_replay_buffer(
prb=cfg.prb, buffer_size=cfg.buffer_size, device=device
batch_size=cfg.batch_size,
prb=cfg.prb,
buffer_size=cfg.buffer_size,
device=device,
)

# Optimizers
Expand Down
48 changes: 48 additions & 0 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,54 @@ def test_index(self, rbtype, storage, size, prefetch):
assert b


def test_multi_loops():
"""Tests that one can iterate multiple times over a buffer without rep."""
rb = ReplayBuffer(
batch_size=5, storage=ListStorage(10), sampler=SamplerWithoutReplacement()
)
rb.extend(torch.zeros(10))
for i, d in enumerate(rb): # noqa: B007
assert (d == 0).all()
assert i == 1
for i, d in enumerate(rb): # noqa: B007
assert (d == 0).all()
assert i == 1


def test_batch_errors():
"""Tests error messages related to batch-size"""
rb = ReplayBuffer(
storage=ListStorage(10), sampler=SamplerWithoutReplacement(drop_last=False)
)
rb.extend(torch.zeros(10))
rb.sample(3) # that works
with pytest.raises(
RuntimeError,
match="Cannot iterate over the replay buffer. Batch_size was not specified",
):
for _ in rb:
pass
with pytest.raises(RuntimeError, match="batch_size not specified"):
rb.sample()
with pytest.raises(ValueError, match="Samplers with drop_last=True"):
ReplayBuffer(
storage=ListStorage(10), sampler=SamplerWithoutReplacement(drop_last=True)
)
# that works
ReplayBuffer(
storage=ListStorage(10),
)
rb = ReplayBuffer(
storage=ListStorage(10),
sampler=SamplerWithoutReplacement(drop_last=False),
batch_size=3,
)
rb.extend(torch.zeros(10))
for _ in rb:
pass
rb.sample()


@pytest.mark.parametrize("priority_key", ["pk", "td_error"])
@pytest.mark.parametrize("contiguous", [True, False])
@pytest.mark.parametrize("device", get_available_devices())
Expand Down
2 changes: 1 addition & 1 deletion test/test_rb_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def sample_from_buffer_remotely_returns_correct_tensordict_test(rank, name, worl
_, inserted = _add_random_tensor_dict_to_buffer(buffer)
sampled = _sample_from_buffer(buffer, 1)
assert type(sampled) is type(inserted) is TensorDict
assert (sampled == inserted)["a"].item()
assert (sampled["a"] == inserted["a"]).all()


@pytest.mark.skipif(
Expand Down
18 changes: 10 additions & 8 deletions test/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,9 @@ def test_rb_trainer(self, prioritized):
S = 100
storage = ListStorage(S)
if prioritized:
replay_buffer = TensorDictPrioritizedReplayBuffer(1.1, 0.9, storage=storage)
replay_buffer = TensorDictPrioritizedReplayBuffer(
alpha=1.1, beta=0.9, storage=storage
)
else:
replay_buffer = TensorDictReplayBuffer(storage=storage)

Expand Down Expand Up @@ -260,8 +262,8 @@ def test_rb_trainer_state_dict(self, prioritized, storage_type):

if prioritized:
replay_buffer = TensorDictPrioritizedReplayBuffer(
1.1,
0.9,
alpha=1.1,
beta=0.9,
storage=storage,
)
else:
Expand Down Expand Up @@ -293,7 +295,7 @@ def test_rb_trainer_state_dict(self, prioritized, storage_type):
trainer2 = mocking_trainer()
if prioritized:
replay_buffer2 = TensorDictPrioritizedReplayBuffer(
1.1, 0.9, storage=storage
alpha=1.1, beta=0.9, storage=storage
)
else:
replay_buffer2 = TensorDictReplayBuffer(storage=storage)
Expand Down Expand Up @@ -398,8 +400,8 @@ def make_storage():
storage = make_storage()
if prioritized:
replay_buffer = TensorDictPrioritizedReplayBuffer(
1.1,
0.9,
alpha=1.1,
beta=0.9,
storage=storage,
)
else:
Expand Down Expand Up @@ -430,8 +432,8 @@ def make_storage():
storage2 = make_storage()
if prioritized:
replay_buffer2 = TensorDictPrioritizedReplayBuffer(
1.1,
0.9,
alpha=1.1,
beta=0.9,
storage=storage2,
)
else:
Expand Down
Loading

0 comments on commit e26d148

Please sign in to comment.