Skip to content

Commit

Permalink
[Refactor] Deprecate direct usage of memmap tensors (pytorch#1684)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Nov 15, 2023
1 parent e1eb69d commit 0badd6e
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 43 deletions.
6 changes: 3 additions & 3 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def test_vecenvs_wrapper(self, envname):
["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"]
+ (["FetchReach-v2"] if _has_gym_robotics else []),
)
@pytest.mark.flaky(reruns=3, reruns_delay=1)
@pytest.mark.flaky(reruns=8, reruns_delay=1)
def test_vecenvs_env(self, envname):
from _utils_internal import rollout_consistency_assertion

Expand Down Expand Up @@ -1897,10 +1897,10 @@ def test_direct_download(self, task):
assert len(keys)
assert_allclose_td(
data_direct._storage._storage.select(*keys).apply(
lambda t: t.as_tensor().float()
lambda t: t.float()
),
data_d4rl._storage._storage.select(*keys).apply(
lambda t: t.as_tensor().float()
lambda t: t.float()
),
)

Expand Down
5 changes: 3 additions & 2 deletions test/test_rb_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import argparse
import os

import sys
import time

Expand All @@ -22,10 +23,10 @@


class ReplayBufferNode(RemoteTensorDictReplayBuffer):
def __init__(self, capacity: int):
def __init__(self, capacity: int, scratch_dir=None):
super().__init__(
storage=LazyMemmapStorage(
max_size=capacity, scratch_dir="/tmp/", device=torch.device("cpu")
max_size=capacity, scratch_dir=scratch_dir, device=torch.device("cpu")
),
sampler=RandomSampler(),
writer=RoundRobinWriter(),
Expand Down
11 changes: 8 additions & 3 deletions test/test_rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
import torch.nn.functional as F

from _utils_internal import get_default_devices
from tensordict import is_tensor_collection, MemmapTensor, TensorDict, TensorDictBase
from tensordict import (
is_tensor_collection,
MemoryMappedTensor,
TensorDict,
TensorDictBase,
)
from tensordict.nn import TensorDictModule
from torchrl.data.rlhf import TensorDictTokenizer
from torchrl.data.rlhf.dataset import (
Expand Down Expand Up @@ -188,8 +193,8 @@ def test_dataset_to_tensordict(tmpdir, suffix):
else:
assert ("c", "d", "a") in td.keys(True)
assert ("c", "d", "b") in td.keys(True)
assert isinstance(td.get((suffix, "a")), MemmapTensor)
assert isinstance(td.get((suffix, "b")), MemmapTensor)
assert isinstance(td.get((suffix, "a")), MemoryMappedTensor)
assert isinstance(td.get((suffix, "b")), MemoryMappedTensor)


@pytest.mark.skipif(
Expand Down
50 changes: 31 additions & 19 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import torch
from tensordict import is_tensorclass
from tensordict.memmap import MemmapTensor
from tensordict.memmap import MemmapTensor, MemoryMappedTensor
from tensordict.tensordict import is_tensor_collection, TensorDict, TensorDictBase
from tensordict.utils import expand_right

Expand Down Expand Up @@ -482,7 +482,7 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
if self.device == "auto":
self.device = data.device
if isinstance(data, torch.Tensor):
# if Tensor, we just create a MemmapTensor of the desired shape, device and dtype
# if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
out = torch.empty(
self.max_size,
*data.shape,
Expand Down Expand Up @@ -531,12 +531,12 @@ class LazyMemmapStorage(LazyTensorStorage):
>>> storage.get(0)
TensorDict(
fields={
some data: MemmapTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
some data: MemoryMappedTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
some: TensorDict(
fields={
nested: TensorDict(
fields={
data: MemmapTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False)},
data: MemoryMappedTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([11]),
device=cpu,
is_shared=False)},
Expand All @@ -560,8 +560,8 @@ class LazyMemmapStorage(LazyTensorStorage):
>>> storage.set(range(10), data)
>>> storage.get(0)
MyClass(
bar=MemmapTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False),
foo=MemmapTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
bar=MemoryMappedTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False),
foo=MemoryMappedTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
batch_size=torch.Size([11]),
device=cpu,
is_shared=False)
Expand Down Expand Up @@ -603,7 +603,12 @@ def load_state_dict(self, state_dict):
if isinstance(self._storage, torch.Tensor):
_mem_map_tensor_as_tensor(self._storage).copy_(_storage)
elif self._storage is None:
self._storage = MemmapTensor(_storage)
self._storage = _make_memmap(
_storage,
path=self.scratch_dir + "/tensor.memmap"
if self.scratch_dir is not None
else None,
)
else:
raise RuntimeError(
f"Cannot copy a storage of type {type(_storage)} onto another of type {type(self._storage)}"
Expand Down Expand Up @@ -657,9 +662,13 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
)
else:
# If not a tensorclass/tensordict, it must be a tensor(-like)
# if Tensor, we just create a MemmapTensor of the desired shape, device and dtype
out = MemmapTensor(
self.max_size, *data.shape, device=self.device, dtype=data.dtype
# if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
out = _make_empty_memmap(
(self.max_size, *data.shape),
dtype=data.dtype,
path=self.scratch_dir + "/tensor.memmap"
if self.scratch_dir is not None
else None,
)
if VERBOSE:
filesize = os.path.getsize(out.filename) / 1024 / 1024
Expand All @@ -685,6 +694,7 @@ def _mem_map_tensor_as_tensor(mem_map_tensor: MemmapTensor) -> torch.Tensor:
f"Supported backends are {_CKPT_BACKEND.backends}"
)
if isinstance(mem_map_tensor, torch.Tensor):
# This will account for MemoryMappedTensors
return mem_map_tensor
if _CKPT_BACKEND == "torchsnapshot":
# TorchSnapshot doesn't know how to stream MemmapTensor, so we view MemmapTensor
Expand Down Expand Up @@ -745,25 +755,27 @@ def _collate_list_tensordict(x):
return out


def _collate_contiguous(x):
def _collate_id(x):
return x


def _collate_as_tensor(x):
return x.as_tensor()


def _get_default_collate(storage, _is_tensordict=False):
if isinstance(storage, ListStorage):
if _is_tensordict:
return _collate_list_tensordict
else:
return torch.utils.data._utils.collate.default_collate
elif isinstance(storage, LazyMemmapStorage):
return _collate_as_tensor
elif isinstance(storage, (TensorStorage,)):
return _collate_contiguous
elif isinstance(storage, TensorStorage):
return _collate_id
else:
raise NotImplementedError(
f"Could not find a default collate_fn for storage {type(storage)}."
)


def _make_memmap(tensor, path):
return MemoryMappedTensor.from_tensor(tensor, filename=path)


def _make_empty_memmap(shape, dtype, path):
return MemoryMappedTensor.empty(shape=shape, dtype=dtype, filename=path)
8 changes: 4 additions & 4 deletions torchrl/data/rlhf/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ class TokenizedDatasetLoader:
>>> print(dataset)
TensorDict(
fields={
attention_mask: MemmapTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids: MemmapTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False)},
attention_mask: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([185068]),
device=None,
is_shared=False)
Expand Down Expand Up @@ -270,8 +270,8 @@ def dataset_to_tensordict(
fields={
prefix: TensorDict(
fields={
labels: MemmapTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.float32, is_shared=False),
tokens: MemmapTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.int64, is_shared=False)},
labels: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.float32, is_shared=False),
tokens: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([10]),
device=None,
is_shared=False)},
Expand Down
8 changes: 4 additions & 4 deletions torchrl/data/rlhf/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ def from_dataset(
>>> data = PromptData.from_dataset("train")
>>> print(data)
PromptDataTLDR(
attention_mask=MemmapTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids=MemmapTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
prompt_rindex=MemmapTensor(shape=torch.Size([116722]), device=cpu, dtype=torch.int64, is_shared=False),
labels=MemmapTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
attention_mask=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
prompt_rindex=MemoryMappedTensor(shape=torch.Size([116722]), device=cpu, dtype=torch.int64, is_shared=False),
labels=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
logits=None,
loss=None,
batch_size=torch.Size([116722]),
Expand Down
16 changes: 8 additions & 8 deletions torchrl/data/rlhf/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,16 @@ class PairwiseDataset:
>>> print(data)
PairwiseDataset(
chosen_data=RewardData(
attention_mask=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
rewards=None,
end_scores=None,
batch_size=torch.Size([92534]),
device=None,
is_shared=False),
rejected_data=RewardData(
attention_mask=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
rewards=None,
end_scores=None,
batch_size=torch.Size([92534]),
Expand Down Expand Up @@ -97,16 +97,16 @@ def from_dataset(
>>> print(data)
PairwiseDataset(
chosen_data=RewardData(
attention_mask=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
rewards=None,
end_scores=None,
batch_size=torch.Size([92534]),
device=None,
is_shared=False),
rejected_data=RewardData(
attention_mask=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
rewards=None,
end_scores=None,
batch_size=torch.Size([92534]),
Expand Down

0 comments on commit 0badd6e

Please sign in to comment.