[Feature] RNG for RBs #2379

Merged (4 commits, Aug 8, 2024)
Changes from 1 commit
init
vmoens committed Aug 7, 2024
commit f28d76d63564bda831c7d14d361168d22f1365a6
84 changes: 84 additions & 0 deletions test/test_rb.py
@@ -109,6 +109,11 @@
".".join([str(s) for s in version.parse(str(torch.__version__)).release])
) >= version.parse("2.3.0")

ReplayBufferRNG = functools.partial(ReplayBuffer, generator=torch.Generator())
TensorDictReplayBufferRNG = functools.partial(
    TensorDictReplayBuffer, generator=torch.Generator()
)
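The two `functools.partial` aliases above pre-bind a fresh `torch.Generator` to the buffer constructors, so every parametrization below that names them exercises the dedicated-RNG code path without changing the test call sites.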


@pytest.mark.parametrize(
    "sampler",
@@ -125,17 +130,27 @@
"rb_type,storage,datatype",
[
[ReplayBuffer, ListStorage, None],
[ReplayBufferRNG, ListStorage, None],
[TensorDictReplayBuffer, ListStorage, "tensordict"],
[TensorDictReplayBufferRNG, ListStorage, "tensordict"],
[RemoteTensorDictReplayBuffer, ListStorage, "tensordict"],
[ReplayBuffer, LazyTensorStorage, "tensor"],
[ReplayBuffer, LazyTensorStorage, "tensordict"],
[ReplayBuffer, LazyTensorStorage, "pytree"],
[ReplayBufferRNG, LazyTensorStorage, "tensor"],
[ReplayBufferRNG, LazyTensorStorage, "tensordict"],
[ReplayBufferRNG, LazyTensorStorage, "pytree"],
[TensorDictReplayBuffer, LazyTensorStorage, "tensordict"],
[TensorDictReplayBufferRNG, LazyTensorStorage, "tensordict"],
[RemoteTensorDictReplayBuffer, LazyTensorStorage, "tensordict"],
[ReplayBuffer, LazyMemmapStorage, "tensor"],
[ReplayBuffer, LazyMemmapStorage, "tensordict"],
[ReplayBuffer, LazyMemmapStorage, "pytree"],
[ReplayBufferRNG, LazyMemmapStorage, "tensor"],
[ReplayBufferRNG, LazyMemmapStorage, "tensordict"],
[ReplayBufferRNG, LazyMemmapStorage, "pytree"],
[TensorDictReplayBuffer, LazyMemmapStorage, "tensordict"],
[TensorDictReplayBufferRNG, LazyMemmapStorage, "tensordict"],
[RemoteTensorDictReplayBuffer, LazyMemmapStorage, "tensordict"],
],
)
@@ -1155,17 +1170,86 @@ def test_replay_buffer_trajectories(stack, reduction, datatype):
    # sampled_td_filtered.batch_size = [3, 4]


class TestRNG:
    def test_rb_rng(self):
        state = torch.random.get_rng_state()
        rb = ReplayBufferRNG(sampler=RandomSampler(), storage=LazyTensorStorage(100))
        rb.extend(torch.arange(100))
        rb._rng.set_state(state)
        a = rb.sample(32)
        rb._rng.set_state(state)
        b = rb.sample(32)
        assert (a == b).all()
        c = rb.sample(32)
        assert (a != c).any()

    def test_prb_rng(self):
        state = torch.random.get_rng_state()
        rb = ReplayBuffer(
            sampler=PrioritizedSampler(100, 1.0, 1.0),
            storage=LazyTensorStorage(100),
            generator=torch.Generator(),
        )
        rb.extend(torch.arange(100))
        rb.update_priority(index=torch.arange(100), priority=torch.arange(1, 101))

        rb._rng.set_state(state)
        a = rb.sample(32)

        rb._rng.set_state(state)
        b = rb.sample(32)
        assert (a == b).all()

        c = rb.sample(32)
        assert (a != c).any()

    def test_slice_rng(self):
        state = torch.random.get_rng_state()
        rb = ReplayBuffer(
            sampler=SliceSampler(num_slices=4),
            storage=LazyTensorStorage(100),
            generator=torch.Generator(),
        )
        done = torch.zeros(100, 1, dtype=torch.bool)
        done[49] = 1
        done[-1] = 1
        data = TensorDict(
            {
                "data": torch.arange(100),
                ("next", "done"): done,
            },
            batch_size=[100],
        )
        rb.extend(data)

        rb._rng.set_state(state)
        a = rb.sample(32)

        rb._rng.set_state(state)
        b = rb.sample(32)
        assert (a == b).all()

        c = rb.sample(32)
        assert (a != c).any()
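These tests all rely on the same trick: snapshot an RNG state, restore it, and the next draw replays exactly. It works because `torch.random.get_rng_state()` and `torch.Generator.set_state()` operate on the same CPU generator state format. A minimal standalone sketch of the pattern:

import torch

g = torch.Generator()
state = g.get_state()  # snapshot the generator state
x = torch.rand(3, generator=g)
g.set_state(state)  # rewind to the snapshot
y = torch.rand(3, generator=g)
assert (x == y).all()  # the draw replays exactly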


@pytest.mark.parametrize(
    "rbtype,storage",
    [
        (ReplayBuffer, None),
        (ReplayBuffer, ListStorage),
        (ReplayBufferRNG, None),
        (ReplayBufferRNG, ListStorage),
        (PrioritizedReplayBuffer, None),
        (PrioritizedReplayBuffer, ListStorage),
        (TensorDictReplayBuffer, None),
        (TensorDictReplayBuffer, ListStorage),
        (TensorDictReplayBuffer, LazyTensorStorage),
        (TensorDictReplayBuffer, LazyMemmapStorage),
        (TensorDictReplayBufferRNG, None),
        (TensorDictReplayBufferRNG, ListStorage),
        (TensorDictReplayBufferRNG, LazyTensorStorage),
        (TensorDictReplayBufferRNG, LazyMemmapStorage),
        (TensorDictPrioritizedReplayBuffer, None),
        (TensorDictPrioritizedReplayBuffer, ListStorage),
        (TensorDictPrioritizedReplayBuffer, LazyTensorStorage),
37 changes: 37 additions & 0 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -120,7 +120,13 @@ class ReplayBuffer:
            >>> for d in data.unbind(1):
            ...     rb.add(d)
            >>> rb.extend(data)
        generator (torch.Generator, optional): a generator to use for sampling.
            Using a dedicated generator for the replay buffer allows fine-grained control
            over seeding, for instance keeping the global seed different but the RB seed
            identical across distributed jobs.
            Defaults to ``None`` (the global default generator).

            .. warning:: As of now, the generator has no effect on the transforms.

    Examples:
        >>> import torch
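A short usage sketch of the new argument (a hedged example; import paths assumed from `torchrl.data.replay_buffers`): seeding the buffer's dedicated generator makes sampling reproducible regardless of the global seed.

import torch
from torchrl.data.replay_buffers import LazyTensorStorage, RandomSampler, ReplayBuffer

gen = torch.Generator()
rb = ReplayBuffer(
    storage=LazyTensorStorage(100),
    sampler=RandomSampler(),
    generator=gen,  # dedicated RNG, decoupled from the global seed
)
rb.extend(torch.arange(100))
gen.manual_seed(0)
a = rb.sample(16)
gen.manual_seed(0)  # re-seeding the buffer's generator replays the draw
b = rb.sample(16)
assert (a == b).all()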
@@ -204,6 +210,7 @@ def __init__(
        batch_size: int | None = None,
        dim_extend: int | None = None,
        checkpointer: "StorageCheckpointerBase" | None = None,  # noqa: F821
        generator: torch.Generator | None = None,
    ) -> None:
        self._storage = storage if storage is not None else ListStorage(max_size=1_000)
        self._storage.attach(self)
@@ -263,6 +270,11 @@ def __init__(
        self.dim_extend = dim_extend
        self._storage.checkpointer = checkpointer

        self._rng = generator
        self._storage._rng = generator
        self._sampler._rng = generator
        self._writer._rng = generator

    @property
    def dim_extend(self):
        return self._dim_extend
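Since the very same `torch.Generator` object is attached to the buffer, storage, sampler, and writer, a single `manual_seed` call on it re-seeds every component at once. A minimal sketch of what the sharing implies (attribute names taken from the diff above):

import torch
from torchrl.data.replay_buffers import LazyTensorStorage, ReplayBuffer

gen = torch.Generator()
rb = ReplayBuffer(storage=LazyTensorStorage(10), generator=gen)
# Every component aliases one and the same generator object:
assert rb._rng is rb._storage._rng is rb._sampler._rng is rb._writer._rng
gen.manual_seed(42)  # one call re-seeds sampling everywhere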
@@ -995,6 +1007,13 @@ class TensorDictReplayBuffer(ReplayBuffer):
            >>> for d in data.unbind(1):
            ...     rb.add(d)
            >>> rb.extend(data)
        generator (torch.Generator, optional): a generator to use for sampling.
            Using a dedicated generator for the replay buffer allows fine-grained control
            over seeding, for instance keeping the global seed different but the RB seed
            identical across distributed jobs.
            Defaults to ``None`` (the global default generator).

            .. warning:: As of now, the generator has no effect on the transforms.

    Examples:
        >>> import torch
@@ -1327,6 +1346,13 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
            >>> for d in data.unbind(1):
            ...     rb.add(d)
            >>> rb.extend(data)
        generator (torch.Generator, optional): a generator to use for sampling.
            Using a dedicated generator for the replay buffer allows fine-grained control
            over seeding, for instance keeping the global seed different but the RB seed
            identical across distributed jobs.
            Defaults to ``None`` (the global default generator).

            .. warning:: As of now, the generator has no effect on the transforms.

    Examples:
        >>> import torch
@@ -1400,6 +1426,7 @@ def __init__(
reduction: str = "max",
batch_size: int | None = None,
dim_extend: int | None = None,
generator: torch.Generator | None = None,
) -> None:
if storage is None:
storage = ListStorage(max_size=1_000)
@@ -1416,6 +1443,7 @@ def __init__(
            transform=transform,
            batch_size=batch_size,
            dim_extend=dim_extend,
            generator=generator,
        )


@@ -1555,6 +1583,13 @@ class ReplayBufferEnsemble(ReplayBuffer):
            sampled according to the probabilities ``p``. Can also
            be passed to :class:`~torchrl.data.replay_buffers.samplers.SamplerEnsemble`
            if the buffer is built explicitly.
        generator (torch.Generator, optional): a generator to use for sampling.
            Using a dedicated generator for the replay buffer allows fine-grained control
            over seeding, for instance keeping the global seed different but the RB seed
            identical across distributed jobs.
            Defaults to ``None`` (the global default generator).

            .. warning:: As of now, the generator has no effect on the transforms.

    Examples:
        >>> from torchrl.envs import Compose, ToTensorImage, Resize, RenameTransform
@@ -1644,6 +1679,7 @@ def __init__(
        p: Tensor = None,
        sample_from_all: bool = False,
        num_buffer_sampled: int | None = None,
        generator: torch.Generator | None = None,
        **kwargs,
    ):

@@ -1680,6 +1716,7 @@ def __init__(
            transform=transform,
            batch_size=batch_size,
            collate_fn=collate_fn,
            generator=generator,
            **kwargs,
        )

35 changes: 30 additions & 5 deletions torchrl/data/replay_buffers/samplers.py
@@ -46,6 +46,9 @@ class Sampler(ABC):
    # need to keep track of the number of remaining batches
    _remaining_batches = int(torch.iinfo(torch.int64).max)

    # The RNG is set by the replay buffer
    _rng: torch.Generator | None = None

    @abstractmethod
    def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
        ...
@@ -192,7 +195,9 @@ def _get_sample_list(self, storage: Storage, len_storage: int, batch_size: int):
        device = storage.device if hasattr(storage, "device") else None

        if self.shuffle:
            _sample_list = torch.randperm(len_storage, device=device)
            _sample_list = torch.randperm(
                len_storage, device=device, generator=self._rng
            )
        else:
            _sample_list = torch.arange(len_storage, device=device)
        self._sample_list = _sample_list
@@ -473,7 +478,11 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
raise RuntimeError("non-positive p_min")
# For some undefined reason, only np.random works here.
# All PT attempts fail, even when subsequently transformed into numpy
mass = np.random.uniform(0.0, p_sum, size=batch_size)
if self._rng is None:
mass = np.random.uniform(0.0, p_sum, size=batch_size)
else:
mass = torch.rand(batch_size, generator=self._rng) * p_sum

# mass = torch.zeros(batch_size, dtype=torch.double).uniform_(0.0, p_sum)
# mass = torch.rand(batch_size).mul_(p_sum)
index = self._sum_tree.scan_lower_bound(mass)
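For context, the prioritized draw is an inverse-CDF lookup: each uniform "mass" in [0, p_sum) maps to the first index whose cumulative priority exceeds it, which the sum tree answers in O(log n) per query. A rough equivalent with a plain cumulative sum instead of the tree (hypothetical helper, for illustration only):

import torch

def sample_proportional(priorities, batch_size, rng):
    # Cumulative sum plays the role of the sum tree's prefix sums.
    cdf = torch.cumsum(priorities.double(), dim=0)
    p_sum = cdf[-1]
    # Uniform mass in [0, p_sum), drawn from the buffer's dedicated generator.
    mass = torch.rand(batch_size, generator=rng, dtype=torch.double) * p_sum
    # First index whose cumulative priority is strictly above each mass value.
    return torch.searchsorted(cdf, mass, right=True)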
@@ -1187,7 +1196,9 @@ def _sample_slices(
        # start_idx and stop_idx are 2d tensors organized like a non-zero

        def get_traj_idx(maxval):
            return torch.randint(maxval, (num_slices,), device=lengths.device)
            return torch.randint(
                maxval, (num_slices,), device=lengths.device, generator=self._rng
            )

        if (lengths < seq_length).any():
            if self.strict_length:
@@ -1290,7 +1301,8 @@ def _get_index(
            start_point = -span_right

        relative_starts = (
            torch.rand(num_slices, device=lengths.device) * (end_point - start_point)
            torch.rand(num_slices, device=lengths.device, generator=self._rng)
            * (end_point - start_point)
        ).floor().to(start_idx.dtype) + start_point

        if self.span[0]:
@@ -2033,6 +2045,7 @@ class SamplerEnsemble(Sampler):
    def __init__(
        self, *samplers, p=None, sample_from_all=False, num_buffer_sampled=None
    ):
        self._rng_private = None
        self._samplers = samplers
        self.sample_from_all = sample_from_all
        if sample_from_all and p is not None:
@@ -2042,6 +2055,16 @@ def __init__(
        self.p = p
        self.num_buffer_sampled = num_buffer_sampled

    @property
    def _rng(self):
        return self._rng_private

    @_rng.setter
    def _rng(self, value):
        self._rng_private = value
        for sampler in self._samplers:
            sampler._rng = value

    @property
    def p(self):
        return self._p
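This property/setter pair is what lets `ReplayBuffer.__init__` assign `self._sampler._rng = generator` uniformly: on an ensemble, the assignment transparently fans the same generator out to every sub-sampler. `StorageEnsemble` (further down) mirrors the pattern for its storages.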
@@ -2082,7 +2105,9 @@ def sample(self, storage, batch_size):
        else:
            if self.p is None:
                buffer_ids = torch.randint(
                    len(self._samplers), (self.num_buffer_sampled,)
                    len(self._samplers),
                    (self.num_buffer_sampled,),
                    generator=self._rng,
                )
            else:
                buffer_ids = torch.multinomial(self.p, self.num_buffer_sampled, True)
18 changes: 16 additions & 2 deletions torchrl/data/replay_buffers/storages.py
@@ -54,6 +54,7 @@ class Storage:
    ndim = 1
    max_size: int
    _default_checkpointer: StorageCheckpointerBase = StorageCheckpointerBase
    _rng: torch.Generator | None = None

    def __init__(
        self, max_size: int, checkpointer: StorageCheckpointerBase | None = None
@@ -142,7 +143,7 @@ def _empty(self):
    def _rand_given_ndim(self, batch_size):
        # a method to return random indices given the storage ndim
        if self.ndim == 1:
            return torch.randint(0, len(self), (batch_size,))
            return torch.randint(0, len(self), (batch_size,), generator=self._rng)
        raise RuntimeError(
            f"Random number generation is not implemented for storage of type {type(self)} with ndim {self.ndim}. "
            f"Please report this exception as well as the use case (incl. buffer construction) on github."
@@ -497,7 +498,9 @@ def _rand_given_ndim(self, batch_size):
        if self.ndim == 1:
            return super()._rand_given_ndim(batch_size)
        shape = self.shape
        return tuple(torch.randint(_dim, (batch_size,)) for _dim in shape)
        return tuple(
            torch.randint(_dim, (batch_size,), generator=self._rng) for _dim in shape
        )

    def flatten(self):
        if self.ndim == 1:
@@ -1142,13 +1145,24 @@ def __init__(
        *storages: Storage,
        transforms: List["Transform"] = None,  # noqa: F821
    ):
        self._rng_private = None
        self._storages = storages
        self._transforms = transforms
        if transforms is not None and len(transforms) != len(storages):
            raise TypeError(
                "transforms must have the same length as the storages provided."
            )

    @property
    def _rng(self):
        return self._rng_private

    @_rng.setter
    def _rng(self, value):
        self._rng_private = value
        for storage in self._storages:
            storage._rng = value

    @property
    def _attached_entities(self):
        return set()