diff --git a/test/test_libs.py b/test/test_libs.py index f1715a550f4..1e4d2a7d871 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -400,7 +400,7 @@ def test_vecenvs_wrapper(self, envname): ["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"] + (["FetchReach-v2"] if _has_gym_robotics else []), ) - @pytest.mark.flaky(reruns=3, reruns_delay=1) + @pytest.mark.flaky(reruns=8, reruns_delay=1) def test_vecenvs_env(self, envname): from _utils_internal import rollout_consistency_assertion @@ -1897,10 +1897,10 @@ def test_direct_download(self, task): assert len(keys) assert_allclose_td( data_direct._storage._storage.select(*keys).apply( - lambda t: t.as_tensor().float() + lambda t: t.float() ), data_d4rl._storage._storage.select(*keys).apply( - lambda t: t.as_tensor().float() + lambda t: t.float() ), ) diff --git a/test/test_rb_distributed.py b/test/test_rb_distributed.py index 8a46b1a006d..548f04dc41d 100644 --- a/test/test_rb_distributed.py +++ b/test/test_rb_distributed.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import argparse import os + import sys import time @@ -22,10 +23,10 @@ class ReplayBufferNode(RemoteTensorDictReplayBuffer): - def __init__(self, capacity: int): + def __init__(self, capacity: int, scratch_dir=None): super().__init__( storage=LazyMemmapStorage( - max_size=capacity, scratch_dir="/tmp/", device=torch.device("cpu") + max_size=capacity, scratch_dir=scratch_dir, device=torch.device("cpu") ), sampler=RandomSampler(), writer=RoundRobinWriter(), diff --git a/test/test_rlhf.py b/test/test_rlhf.py index 2abb9a6d386..31ef96681df 100644 --- a/test/test_rlhf.py +++ b/test/test_rlhf.py @@ -14,7 +14,12 @@ import torch.nn.functional as F from _utils_internal import get_default_devices -from tensordict import is_tensor_collection, MemmapTensor, TensorDict, TensorDictBase +from tensordict import ( + is_tensor_collection, + MemoryMappedTensor, + TensorDict, + TensorDictBase, +) from tensordict.nn import TensorDictModule from torchrl.data.rlhf import TensorDictTokenizer from torchrl.data.rlhf.dataset import ( @@ -188,8 +193,8 @@ def test_dataset_to_tensordict(tmpdir, suffix): else: assert ("c", "d", "a") in td.keys(True) assert ("c", "d", "b") in td.keys(True) - assert isinstance(td.get((suffix, "a")), MemmapTensor) - assert isinstance(td.get((suffix, "b")), MemmapTensor) + assert isinstance(td.get((suffix, "a")), MemoryMappedTensor) + assert isinstance(td.get((suffix, "b")), MemoryMappedTensor) @pytest.mark.skipif( diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index bacb5713492..9c8417b9c97 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -12,7 +12,7 @@ import torch from tensordict import is_tensorclass -from tensordict.memmap import MemmapTensor +from tensordict.memmap import MemmapTensor, MemoryMappedTensor from tensordict.tensordict import is_tensor_collection, TensorDict, TensorDictBase from tensordict.utils import expand_right @@ -482,7 +482,7 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: if self.device == "auto": self.device = data.device if isinstance(data, torch.Tensor): - # if Tensor, we just create a MemmapTensor of the desired shape, device and dtype + # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype out = torch.empty( self.max_size, *data.shape, @@ -531,12 +531,12 @@ class LazyMemmapStorage(LazyTensorStorage): >>> storage.get(0) TensorDict( fields={ - some data: MemmapTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False), + some data: MemoryMappedTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False), some: TensorDict( fields={ nested: TensorDict( fields={ - data: MemmapTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False)}, + data: MemoryMappedTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([11]), device=cpu, is_shared=False)}, @@ -560,8 +560,8 @@ class LazyMemmapStorage(LazyTensorStorage): >>> storage.set(range(10), data) >>> storage.get(0) MyClass( - bar=MemmapTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False), - foo=MemmapTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False), + bar=MemoryMappedTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False), + foo=MemoryMappedTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False), batch_size=torch.Size([11]), device=cpu, is_shared=False) @@ -603,7 +603,12 @@ def load_state_dict(self, state_dict): if isinstance(self._storage, torch.Tensor): _mem_map_tensor_as_tensor(self._storage).copy_(_storage) elif self._storage is None: - self._storage = MemmapTensor(_storage) + self._storage = _make_memmap( + _storage, + path=self.scratch_dir + "/tensor.memmap" + if self.scratch_dir is not None + else None, + ) else: raise RuntimeError( f"Cannot copy a storage of type {type(_storage)} onto another of type {type(self._storage)}" @@ -657,9 +662,13 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: ) else: # If not a tensorclass/tensordict, it must be a tensor(-like) - # if Tensor, we just create a MemmapTensor of the desired shape, device and dtype - out = MemmapTensor( - self.max_size, *data.shape, device=self.device, dtype=data.dtype + # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype + out = _make_empty_memmap( + (self.max_size, *data.shape), + dtype=data.dtype, + path=self.scratch_dir + "/tensor.memmap" + if self.scratch_dir is not None + else None, ) if VERBOSE: filesize = os.path.getsize(out.filename) / 1024 / 1024 @@ -685,6 +694,7 @@ def _mem_map_tensor_as_tensor(mem_map_tensor: MemmapTensor) -> torch.Tensor: f"Supported backends are {_CKPT_BACKEND.backends}" ) if isinstance(mem_map_tensor, torch.Tensor): + # This will account for MemoryMappedTensors return mem_map_tensor if _CKPT_BACKEND == "torchsnapshot": # TorchSnapshot doesn't know how to stream MemmapTensor, so we view MemmapTensor @@ -745,25 +755,27 @@ def _collate_list_tensordict(x): return out -def _collate_contiguous(x): +def _collate_id(x): return x -def _collate_as_tensor(x): - return x.as_tensor() - - def _get_default_collate(storage, _is_tensordict=False): if isinstance(storage, ListStorage): if _is_tensordict: return _collate_list_tensordict else: return torch.utils.data._utils.collate.default_collate - elif isinstance(storage, LazyMemmapStorage): - return _collate_as_tensor - elif isinstance(storage, (TensorStorage,)): - return _collate_contiguous + elif isinstance(storage, TensorStorage): + return _collate_id else: raise NotImplementedError( f"Could not find a default collate_fn for storage {type(storage)}." ) + + +def _make_memmap(tensor, path): + return MemoryMappedTensor.from_tensor(tensor, filename=path) + + +def _make_empty_memmap(shape, dtype, path): + return MemoryMappedTensor.empty(shape=shape, dtype=dtype, filename=path) diff --git a/torchrl/data/rlhf/dataset.py b/torchrl/data/rlhf/dataset.py index db2b6a418d6..adc2ddcf0d7 100644 --- a/torchrl/data/rlhf/dataset.py +++ b/torchrl/data/rlhf/dataset.py @@ -77,8 +77,8 @@ class TokenizedDatasetLoader: >>> print(dataset) TensorDict( fields={ - attention_mask: MemmapTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False), - input_ids: MemmapTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False)}, + attention_mask: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False), + input_ids: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([185068]), device=None, is_shared=False) @@ -270,8 +270,8 @@ def dataset_to_tensordict( fields={ prefix: TensorDict( fields={ - labels: MemmapTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.float32, is_shared=False), - tokens: MemmapTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.int64, is_shared=False)}, + labels: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.float32, is_shared=False), + tokens: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([10]), device=None, is_shared=False)}, diff --git a/torchrl/data/rlhf/prompt.py b/torchrl/data/rlhf/prompt.py index d534a95379e..d50653c9967 100644 --- a/torchrl/data/rlhf/prompt.py +++ b/torchrl/data/rlhf/prompt.py @@ -74,10 +74,10 @@ def from_dataset( >>> data = PromptData.from_dataset("train") >>> print(data) PromptDataTLDR( - attention_mask=MemmapTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False), - input_ids=MemmapTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False), - prompt_rindex=MemmapTensor(shape=torch.Size([116722]), device=cpu, dtype=torch.int64, is_shared=False), - labels=MemmapTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False), + attention_mask=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False), + input_ids=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False), + prompt_rindex=MemoryMappedTensor(shape=torch.Size([116722]), device=cpu, dtype=torch.int64, is_shared=False), + labels=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False), logits=None, loss=None, batch_size=torch.Size([116722]), diff --git a/torchrl/data/rlhf/reward.py b/torchrl/data/rlhf/reward.py index e7843e02f46..20f379ef659 100644 --- a/torchrl/data/rlhf/reward.py +++ b/torchrl/data/rlhf/reward.py @@ -41,16 +41,16 @@ class PairwiseDataset: >>> print(data) PairwiseDataset( chosen_data=RewardData( - attention_mask=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), - input_ids=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), + attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), + input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), rewards=None, end_scores=None, batch_size=torch.Size([92534]), device=None, is_shared=False), rejected_data=RewardData( - attention_mask=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), - input_ids=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), + attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), + input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), rewards=None, end_scores=None, batch_size=torch.Size([92534]), @@ -97,16 +97,16 @@ def from_dataset( >>> print(data) PairwiseDataset( chosen_data=RewardData( - attention_mask=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), - input_ids=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), + attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), + input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), rewards=None, end_scores=None, batch_size=torch.Size([92534]), device=None, is_shared=False), rejected_data=RewardData( - attention_mask=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), - input_ids=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), + attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), + input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), rewards=None, end_scores=None, batch_size=torch.Size([92534]),