Skip to content

Commit

Permalink
[Deprecation] Deprecate ambiguous device for memmap replay buffer (py…
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Oct 10, 2023
1 parent 5e81445 commit 70c650e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
16 changes: 14 additions & 2 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,8 +586,14 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
data.clone()
.expand(self.max_size, *data.shape)
.memmap_like(prefix=self.scratch_dir)
.to(self.device)
)
if self.device.type != "cpu":
warnings.warn(
"Support for Memmap device other than CPU will be deprecated in v0.4.0.",
category=DeprecationWarning,
)
out = out.to(self.device).memmap_()

for key, tensor in sorted(
out.items(include_nested=True, leaves_only=True), key=str
):
Expand All @@ -603,8 +609,14 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
data.clone()
.expand(self.max_size, *data.shape)
.memmap_like(prefix=self.scratch_dir)
.to(self.device)
)
if self.device.type != "cpu":
warnings.warn(
"Support for Memmap device other than CPU will be deprecated in v0.4.0.",
category=DeprecationWarning,
)
out = out.to(self.device).memmap_()

for key, tensor in sorted(
out.items(include_nested=True, leaves_only=True), key=str
):
Expand Down
2 changes: 1 addition & 1 deletion tutorials/sphinx-tutorials/pretrained_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
#
from torchrl.data import LazyMemmapStorage, ReplayBuffer

storage = LazyMemmapStorage(1000, device=device)
storage = LazyMemmapStorage(1000)
rb = ReplayBuffer(storage=storage, transform=r3m)

##############################################################################
Expand Down

0 comments on commit 70c650e

Please sign in to comment.