[Refactor] MemoryMappedTensor #541

Merged: 53 commits merged into main from memmap_tensor_refact on Nov 14, 2023.

Changes from 1 commit.

Commits (53)
d342971  amend (vmoens, Jul 21, 2023)
b16837e  init (vmoens, Jul 21, 2023)
be58682  Merge branch 'main' into memmap_tensor_refact (vmoens, Aug 11, 2023)
374f0b2  amend (vmoens, Aug 11, 2023)
7046a83  amend (vmoens, Aug 11, 2023)
5ef4f63  amend (vmoens, Aug 11, 2023)
75bc4d9  amend (vmoens, Aug 11, 2023)
a6103b8  amend (vmoens, Aug 11, 2023)
945941f  amend (vmoens, Aug 11, 2023)
624d1dc  amend (vmoens, Aug 11, 2023)
2967f01  amend (vmoens, Aug 11, 2023)
0dc76bd  amend (vmoens, Aug 11, 2023)
2c59eaa  amend (vmoens, Aug 11, 2023)
cf55bfe  amend (vmoens, Aug 12, 2023)
92c0d73  amend (vmoens, Aug 12, 2023)
cd16f38  amend (vmoens, Aug 12, 2023)
d312ff2  amend (vmoens, Aug 12, 2023)
3e68b2c  amend (vmoens, Aug 12, 2023)
1c0694f  amend (vmoens, Aug 30, 2023)
751d9ec  amend (vmoens, Oct 10, 2023)
2b70ba5  amend (vmoens, Oct 10, 2023)
4fc0bcb  amend (vmoens, Oct 10, 2023)
fd8c272  tensordict_ (vmoens, Oct 10, 2023)
90e0d3b  Merge branch 'main' into memmap_tensor_refact (vmoens, Oct 11, 2023)
b3b2856  amend (vmoens, Oct 11, 2023)
ef280f9  amend (vmoens, Oct 12, 2023)
21d5268  amend (vmoens, Oct 12, 2023)
d79f4ee  amend (vmoens, Oct 12, 2023)
11f7f32  amend (vmoens, Oct 12, 2023)
f96c05c  amend (vmoens, Oct 12, 2023)
b25485e  amend (vmoens, Oct 12, 2023)
5275f3a  amend (vmoens, Oct 12, 2023)
1e45b4a  amend (vmoens, Oct 18, 2023)
c8132be  amend (vmoens, Oct 18, 2023)
0403add  amend (vmoens, Oct 18, 2023)
7d174a4  amend (vmoens, Oct 18, 2023)
de0c028  amend (vmoens, Oct 19, 2023)
6b40de8  amend (vmoens, Oct 19, 2023)
1752bb6  amend (vmoens, Oct 19, 2023)
319e1bd  amend (vmoens, Oct 19, 2023)
af7f289  amend (vmoens, Oct 19, 2023)
3e31297  amend (vmoens, Oct 19, 2023)
a481f89  amend (vmoens, Oct 19, 2023)
1d13105  amend (vmoens, Oct 21, 2023)
19ae8e0  amend (vmoens, Oct 30, 2023)
ba931ef  Merge remote-tracking branch 'origin/main' into memmap_tensor_refact (vmoens, Nov 14, 2023)
f2e624d  amend (vmoens, Nov 14, 2023)
d7668fa  amend (vmoens, Nov 14, 2023)
0528111  amend (vmoens, Nov 14, 2023)
7413381  amend (vmoens, Nov 14, 2023)
029de18  amend (vmoens, Nov 14, 2023)
a5c5168  amend (vmoens, Nov 14, 2023)
65845b4  amend (vmoens, Nov 14, 2023)
Viewing commit d79f4ee4b7a9bdb91ef92e2e72e232f1097453ae ("amend"), committed by vmoens on Oct 12, 2023.
tensordict/memmap_refact.py (20 changes: 8 additions & 12 deletions)
@@ -37,7 +37,7 @@ class MemoryMappedTensor(torch.Tensor):
     index: Any
     parent_shape: torch.Size
 
-    def __new__(cls, tensor_or_file, handler=None, dtype=None, shape=None, index=None, device=None):
+    def __new__(cls, tensor_or_file, *, dtype=None, shape=None, index=None, device=None, handler=None):
         if device is not None and torch.device(device).type != "cpu":
             raise ValueError(f"{cls} device must be cpu!")
         if isinstance(tensor_or_file, str):
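The constructor is now keyword-only past `tensor_or_file`, and the device check above keeps the class CPU-only. A minimal sketch of what this means for callers, assuming a memmap file already written to disk (the path and shape are illustrative; only the signature and the CPU check come from this hunk):

    import torch
    from tensordict.memmap_refact import MemoryMappedTensor

    # OK: everything after `tensor_or_file` is passed by keyword.
    mt = MemoryMappedTensor("/tmp/x.memmap", dtype=torch.float32, shape=torch.Size([10]))
    # The old positional form MemoryMappedTensor(path, handler, dtype, ...) now raises
    # TypeError, and device="cuda" is rejected with ValueError by the check above.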
@@ -63,7 +63,7 @@ def __init__(self, tensor_or_file, handler=None, dtype=None, shape=None, device=
 
     @classmethod
     def from_tensor(
-        cls, tensor, dir=None, prefix=None, filename=None
+        cls, tensor, *, dir=None, filename=None
     ):
         if isinstance(tensor, MemoryMappedTensor):
             if dir is None and (
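`from_tensor` loses `prefix` and becomes keyword-only as well. A sketch of the resulting call shape (the filename is illustrative):

    import torch
    from tensordict.memmap_refact import MemoryMappedTensor

    t = torch.arange(12, dtype=torch.float32).reshape(3, 4)
    mt = MemoryMappedTensor.from_tensor(t, filename="/tmp/t.memmap")  # keyword-only
    assert (mt == t).all()  # from_tensor copies the source into the memmap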
@@ -112,13 +112,9 @@ def from_tensor(
         out.copy_(tensor)
         return out
 
-    # def __setstate__(self, state: dict[str, Any]) -> None:
-    #     filename = state["filename"]
-    #     handler = state['handler']
-    #     if filename is not None:
-    #         return self.from_filename(filename, state['dtype'], state['shape'])
-    #     else:
-    #         return self.from_handler(handler, state['dtype'], state['shape'])
+    @classmethod
+    def empty_like(cls, tensor, *, filename=None):
+        return cls.from_tensor(torch.zeros((), dtype=tensor.dtype, device=tensor.device).expand_as(tensor), filename=filename)
 
     @classmethod
     def from_filename(cls, filename, dtype, shape, index):
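The new `empty_like` rides on a stride-0 trick: `torch.zeros(()).expand_as(tensor)` is a broadcast view backed by a single-element storage, so `from_tensor` can size the memmap from it without a dense source tensor ever being allocated (and, since `from_tensor` copies the source, the result comes back zero-filled despite the name). A sketch, with an illustrative filename:

    import torch
    from tensordict.memmap_refact import MemoryMappedTensor

    t = torch.randn(1000, 3)
    out = MemoryMappedTensor.empty_like(t, filename="/tmp/like.memmap")
    assert out.shape == t.shape and out.dtype == t.dtype

    # The source view costs one element of storage, not 1000 * 3:
    src = torch.zeros((), dtype=t.dtype).expand_as(t)
    assert src.stride() == (0, 0)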
@@ -148,22 +144,22 @@ def from_handler(cls, handler, dtype, shape, index):
         return out
 
     def __reduce__(self):
-        if getattr(self, "handler", None) is not None:
+        if getattr(self, "_handler", None) is not None:
             return type(self).from_handler, (
                 self._handler,
                 self.dtype,
                 self.parent_shape,
                 self.index,
             )
-        elif getattr(self, "filename", None) is not None:
+        elif getattr(self, "_filename", None) is not None:
             return type(self).from_filename, (
                 self._filename,
                 self.dtype,
                 self.parent_shape,
                 self.index,
             )
         else:
-            raise RuntimeError
+            raise RuntimeError("Could not find handler or filename.")
 
     def __getitem__(self, item):
         out = super().__getitem__(item)
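With the attribute names corrected to the underscored slots the class actually uses elsewhere (`self._handler`, `self._filename`), `__reduce__` makes pickling cheap: only `(filename, dtype, parent_shape, index)` is serialized, and unpickling re-opens the backing file through `from_filename`. A sketch, assuming `from_tensor` records `_filename` as the filename branch implies (path illustrative):

    import pickle
    import torch
    from tensordict.memmap_refact import MemoryMappedTensor

    mt = MemoryMappedTensor.from_tensor(torch.ones(4), filename="/tmp/demo.memmap")
    payload = pickle.dumps(mt)   # metadata only, not the tensor payload
    mt2 = pickle.loads(payload)  # re-opened via from_filename
    assert (mt2 == 1).all()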
tensordict/tensordict.py (32 changes: 16 additions & 16 deletions)
@@ -40,8 +40,8 @@
 import torch
 from functorch import dim as ftdim
 from tensordict._tensordict import _unravel_key_to_tuple
-from tensordict.memmap import memmap_tensor_as_tensor, MemmapTensor
-from tensordict.memmap_refact import MemoryMappedTensor
+from tensordict.memmap import memmap_tensor_as_tensor, MemmapTensor as _MemmapTensor
+from tensordict.memmap_refact import MemoryMappedTensor as MemmapTensor
 from tensordict.utils import (
     _device,
     _dtype,
@@ -132,13 +132,13 @@ def __bool__(self):
 LAZY_TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}
 CompatibleType = Union[
     Tensor,
-    MemmapTensor,
+    _MemmapTensor,
 ]  # None? # leaves space for TensorDictBase
 
 if _has_torchrec:
     CompatibleType = Union[
         Tensor,
-        MemmapTensor,
+        _MemmapTensor,
         KeyedJaggedTensor,
     ]
 _STR_MIXED_INDEX_ERROR = "Received a mixed string-non string index. Only string-only or string-free indices are supported."
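The two-way aliasing keeps the type unions pointed at the legacy class (now `_MemmapTensor`), while every runtime use of the name `MemmapTensor` in this module silently switches to the refactored `MemoryMappedTensor`. A quick sanity sketch of the effect, using the import path as patched here:

    import torch
    from tensordict.tensordict import MemmapTensor  # resolves to MemoryMappedTensor now

    mt = MemmapTensor.from_tensor(torch.zeros(2, 2))
    assert type(mt).__name__ == "MemoryMappedTensor"
    assert isinstance(mt, torch.Tensor)  # the refactored class subclasses torch.Tensor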
@@ -1142,8 +1142,8 @@ def _send(self, dst: int, _tag: int = -1, pseudo_rand: bool = False) -> int:
             elif _is_tensor_collection(value.__class__):
                 _tag = value._send(dst, _tag=_tag, pseudo_rand=pseudo_rand)
                 continue
-            elif isinstance(value, MemmapTensor):
-                value = value.as_tensor()
+            # elif isinstance(value, MemmapTensor):
+            #     value = value.as_tensor()
             else:
                 raise NotImplementedError(f"Type {type(value)} is not supported.")
             if not pseudo_rand:
@@ -1181,8 +1181,8 @@ def _recv(self, src: int, _tag: int = -1, pseudo_rand: bool = False) -> int:
             elif _is_tensor_collection(value.__class__):
                 _tag = value._recv(src, _tag=_tag, pseudo_rand=pseudo_rand)
                 continue
-            elif isinstance(value, MemmapTensor):
-                value = value.as_tensor()
+            # elif isinstance(value, MemmapTensor):
+            #     value = value.as_tensor()
             else:
                 raise NotImplementedError(f"Type {type(value)} is not supported.")
             if not pseudo_rand:
@@ -1294,8 +1294,8 @@ def _isend(
                 continue
             elif isinstance(value, Tensor):
                 pass
-            elif isinstance(value, MemmapTensor):
-                value = value.as_tensor()
+            # elif isinstance(value, MemmapTensor):
+            #     value = value.as_tensor()
             else:
                 raise NotImplementedError(f"Type {type(value)} is not supported.")
             if not pseudo_rand:
@@ -1365,8 +1365,8 @@ def _irecv(
                     pseudo_rand=pseudo_rand,
                 )
                 continue
-            elif isinstance(value, MemmapTensor):
-                value = value.as_tensor()
+            # elif isinstance(value, MemmapTensor):
+            #     value = value.as_tensor()
             elif isinstance(value, Tensor):
                 pass
             else:
@@ -1415,8 +1415,8 @@ def _reduce(
                     _future_list=_future_list,
                 )
                 continue
-            elif isinstance(value, MemmapTensor):
-                value = value.as_tensor()
+            # elif isinstance(value, MemmapTensor):
+            #     value = value.as_tensor()
             elif isinstance(value, Tensor):
                 pass
             else:
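All five distributed helpers (`_send`, `_recv`, `_isend`, `_irecv`, `_reduce`) drop the same special case, and for the same reason: `MemoryMappedTensor` subclasses `torch.Tensor` (see the class statement in memmap_refact.py above), so memory-mapped values now take the existing `isinstance(value, Tensor)` branch and no longer need an `as_tensor()` conversion. In short:

    import torch
    from tensordict.memmap_refact import MemoryMappedTensor

    value = MemoryMappedTensor.from_tensor(torch.zeros(3))
    assert isinstance(value, torch.Tensor)  # hits the Tensor branch in _send & co.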
@@ -2631,8 +2631,8 @@ def to_h5(
         >>>
         >>> from tensordict import TensorDict, MemmapTensor
         >>> td = TensorDict({
-        ...     "a": MemmapTensor(1_000_000),
-        ...     "b": {"c": MemmapTensor(1_000_000, 3)},
+        ...     "a": MemmapTensor.from_tensor(torch.zeros(()).expand(1_000_000)),
+        ...     "b": {"c": MemmapTensor.from_tensor(torch.zeros(()).expand(1_000_000, 3))},
         ... }, [1_000_000])
         >>>
         >>> file = tempfile.NamedTemporaryFile()
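The doctest swaps the removed shape-based constructor (`MemmapTensor(1_000_000)`) for the expand idiom. The expanded scalar is a stride-0 view over one element of real storage, so the example never holds a dense million-element tensor in RAM before the data lands in the memory-mapped file. A sketch:

    import torch

    src = torch.zeros(()).expand(1_000_000)
    assert src.shape == torch.Size([1_000_000])
    assert src.stride() == (0,)  # broadcast view: one real element of storage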
test/test_memmap2.py (133 changes: 63 additions & 70 deletions)
@@ -382,7 +382,8 @@ def test_handler():
 
 
 def test_memmap_from_memmap():
-    mt2 = MemmapTensor.from_tensor(MemmapTensor(4, 3, 2, 1))
+    mt = MemmapTensor.from_tensor(torch.zeros(()).expand(4, 3, 2, 1))
+    mt2 = MemmapTensor.from_tensor(mt)
     assert mt2.squeeze(-1).shape == torch.Size([4, 3, 2])
 
 
@@ -399,51 +400,43 @@ def test_memmap_cast():
 
 @pytest.fixture
 def dummy_memmap():
-    return MemmapTensor.from_tensor(torch.zeros(10, 11))
+    return MemmapTensor.from_tensor(torch.randn(10, 11))
 
 
-@pytest.mark.parametrize("device", get_available_devices())
 class TestOps:
-    def test_eq(self, device, dummy_memmap):
-        memmap = dummy_memmap.to(device)
+    def test_eq(self, dummy_memmap):
+        memmap = dummy_memmap
         assert (memmap == memmap.clone()).all()
         assert (memmap.clone() == memmap).all()
-        if device.type == "cpu":
-            assert (memmap == memmap.as_tensor()).all()
-            assert (memmap.as_tensor() == memmap).all()
-        else:
-            assert (memmap == memmap._tensor).all()
-            assert (memmap._tensor == memmap).all()
 
-    def test_fill_(self, device, dummy_memmap):
-        memmap = dummy_memmap.to(device)
-        assert (memmap.fill_(1.0) == 1).all()
-
-    def test_copy_(self, device, dummy_memmap):
-        memmap = dummy_memmap.to(device)
-        assert (memmap.copy_(torch.ones(10, 11, device=device)) == 1).all()
-        assert (torch.ones(10, 11, device=device).copy_(memmap) == 1).all()
-
-    def test_or(self, device):
-        memmap = MemmapTensor.from_tensor(torch.ones(10, 11, dtype=torch.bool)).to(
-            device
-        )
+    def test_fill_(self, dummy_memmap):
+        memmap = dummy_memmap.fill_(1.0)
+        assert (memmap == 1).all()
+        assert isinstance(memmap, MemmapTensor)
+
+    def test_copy_(self, dummy_memmap):
+        memmap = dummy_memmap.copy_(torch.ones(10, 11))
+        assert (memmap == 1).all()
+        assert isinstance(memmap, MemmapTensor)
+        # check that memmap can be put in a tensor
+        assert (torch.ones(10, 11).copy_(memmap) == 1).all()
+
+    def test_or(self):
+        memmap = MemmapTensor.from_tensor(torch.ones(10, 11, dtype=torch.bool))
         assert (memmap | (~memmap)).all()
 
-    def test_ne(self, device):
-        memmap = MemmapTensor.from_tensor(torch.ones(10, 11, dtype=torch.bool)).to(
-            device
-        )
+    def test_ne(self):
+        memmap = MemmapTensor.from_tensor(torch.ones(10, 11, dtype=torch.bool))
         assert (memmap != ~memmap).all()
 
 
-def test_memmap_del(tmpdir):
-    t = torch.tensor([1])
-    m = MemmapTensor.from_tensor(t, filename=tmpdir / "tensor")
-    # filename = m.filename
-    assert os.path.isfile(tmpdir / "tensor")
-    del m
-    assert not os.path.isfile(tmpdir / "tensor")
+# def test_memmap_del(tmpdir):
+#     t = torch.tensor([1])
+#     m = MemmapTensor.from_tensor(t, filename=tmpdir / "tensor")
+#     # filename = m.filename
+#     assert os.path.isfile(tmpdir / "tensor")
+#     del m
+#     assert not os.path.isfile(tmpdir / "tensor")
 
 # @pytest.mark.parametrize("value", [True, False])
 # def test_memmap_ownership_2pass(value):
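Disabling `test_memmap_del` reads as a deliberate behavior change, not just housekeeping: the old test asserted that the backing file vanishes when the tensor is garbage-collected. A hypothetical check of that reading (path illustrative, and the new semantics are not confirmed by this diff):

    import os
    import torch
    from tensordict.memmap_refact import MemoryMappedTensor

    mt = MemoryMappedTensor.from_tensor(torch.ones(3), filename="/tmp/keep.memmap")
    del mt
    # The old class would have removed the file here; the disabled assertion
    # suggests the refactored class now leaves it on disk.
    print(os.path.isfile("/tmp/keep.memmap"))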
@@ -498,40 +491,40 @@ def test_memmap_del(tmpdir):
 #     finally:
 #         p.join()
 #         queue.close()
-@pytest.mark.parametrize(
-    "mode", ["r", "r+", "w+", "c", "readonly", "readwrite", "write", "copyonwrite"]
-)
-def test_mode(mode, tmp_path):
-    mt = MemmapTensor(10, dtype=torch.float32, filename=tmp_path / "test.memmap")
-    mt[:] = torch.ones(10) * 1.5
-    del mt
-
-    if mode in ("r", "readonly"):
-        with pytest.raises(ValueError, match=r"Accepted values for mode are"):
-            MemmapTensor(
-                10, dtype=torch.float32, filename=tmp_path / "test.memmap", mode=mode
-            )
-        return
-    mt = MemmapTensor(
-        10, dtype=torch.float32, filename=tmp_path / "test.memmap", mode=mode
-    )
-    if mode in ("r+", "readwrite", "c", "copyonwrite"):
-        # data in memmap persists
-        assert (mt.as_tensor() == 1.5).all()
-    elif mode in ("w+", "write"):
-        # memmap is initialized to zero
-        assert (mt.as_tensor() == 0).all()
-
-    mt[:] = torch.ones(10) * 2.5
-    assert (mt.as_tensor() == 2.5).all()
-    del mt
-
-    mt2 = MemmapTensor(10, dtype=torch.float32, filename=tmp_path / "test.memmap")
-    if mode in ("c", "copyonwrite"):
-        # tensor was only mutated in memory, not on disk
-        assert (mt2.as_tensor() == 1.5).all()
-    else:
-        assert (mt2.as_tensor() == 2.5).all()
+# @pytest.mark.parametrize(
+#     "mode", ["r", "r+", "w+", "c", "readonly", "readwrite", "write", "copyonwrite"]
+# )
+# def test_mode(mode, tmp_path):
+#     mt = MemmapTensor(10, dtype=torch.float32, filename=tmp_path / "test.memmap")
+#     mt[:] = torch.ones(10) * 1.5
+#     del mt
+#
+#     if mode in ("r", "readonly"):
+#         with pytest.raises(ValueError, match=r"Accepted values for mode are"):
+#             MemmapTensor(
+#                 10, dtype=torch.float32, filename=tmp_path / "test.memmap", mode=mode
+#             )
+#         return
+#     mt = MemmapTensor(
+#         10, dtype=torch.float32, filename=tmp_path / "test.memmap", mode=mode
+#     )
+#     if mode in ("r+", "readwrite", "c", "copyonwrite"):
+#         # data in memmap persists
+#         assert (mt.as_tensor() == 1.5).all()
+#     elif mode in ("w+", "write"):
+#         # memmap is initialized to zero
+#         assert (mt.as_tensor() == 0).all()
+#
+#     mt[:] = torch.ones(10) * 2.5
+#     assert (mt.as_tensor() == 2.5).all()
+#     del mt
+#
+#     mt2 = MemmapTensor(10, dtype=torch.float32, filename=tmp_path / "test.memmap")
+#     if mode in ("c", "copyonwrite"):
+#         # tensor was only mutated in memory, not on disk
+#         assert (mt2.as_tensor() == 1.5).all()
+#     else:
+#         assert (mt2.as_tensor() == 2.5).all()
 
 
 if __name__ == "__main__":