Speedup tensordicts (pytorch#124)
vmoens authored May 13, 2022
1 parent d103545 commit 4be4246
Showing 8 changed files with 184 additions and 92 deletions.
17 changes: 13 additions & 4 deletions test/test_transforms.py
@@ -12,7 +12,7 @@
 from torch import multiprocessing as mp
 from torchrl.data import NdBoundedTensorSpec, CompositeSpec
 from torchrl.data import TensorDict
-from torchrl.envs import EnvCreator
+from torchrl.envs import EnvCreator, SerialEnv
 from torchrl.envs import GymEnv, ParallelEnv
 from torchrl.envs import (
     Resize,
@@ -278,14 +278,23 @@ def test_parallelenv_vecnorm():


 @pytest.mark.skipif(not _has_gym, reason="no gym library found")
-@pytest.mark.parametrize("parallel", [False, True])
+@pytest.mark.parametrize(
+    "parallel",
+    [
+        True,
+        False,
+        None,
+    ],
+)
 def test_vecnorm(parallel, thr=0.2, N=200):  # 10000):
     torch.manual_seed(0)

-    if parallel:
+    if parallel is None:
+        env = GymEnv("Pendulum-v1")
+    elif parallel:
         env = ParallelEnv(num_workers=5, create_env_fn=lambda: GymEnv("Pendulum-v1"))
     else:
-        env = GymEnv("Pendulum-v1")
+        env = SerialEnv(num_workers=5, create_env_fn=lambda: GymEnv("Pendulum-v1"))

     env.set_seed(0)
     t = VecNorm()
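The parametrization now covers three execution modes instead of two: a single in-process GymEnv (parallel=None), a batch of environments stepped sequentially in one process (SerialEnv, parallel=False), and a batch with one worker process per environment (ParallelEnv, parallel=True). As a minimal sketch of what each value exercises, the branching could be factored into a helper like the hypothetical make_env below; the constructors and keyword arguments are the ones used in the diff, while the helper name and the num_workers default are illustrative only.

    from torchrl.envs import GymEnv, ParallelEnv, SerialEnv


    def make_env(parallel, num_workers=5):
        # parallel=None: a single environment running in the main process
        if parallel is None:
            return GymEnv("Pendulum-v1")
        # parallel=True: num_workers environments, each in its own worker process
        if parallel:
            return ParallelEnv(
                num_workers=num_workers, create_env_fn=lambda: GymEnv("Pendulum-v1")
            )
        # parallel=False: the same batch of environments, stepped sequentially in-process
        return SerialEnv(
            num_workers=num_workers, create_env_fn=lambda: GymEnv("Pendulum-v1")
        )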
12 changes: 9 additions & 3 deletions torchrl/__init__.py
@@ -49,19 +49,25 @@ def decorated_fn(*args, **kwargs):
         return decorated_fn

     def __enter__(self):
+        return
         self.t0 = time.time()

     def __exit__(self, exc_type, exc_val, exc_tb):
+        return
         t = time.time() - self.t0
-        self._REG.setdefault(self.name, [0.0, 0])
+        self._REG.setdefault(self.name, [0.0, 0.0, 0])

         count = self._REG[self.name][1]
         self._REG[self.name][0] = (self._REG[self.name][0] * count + t) / (count + 1)
-        self._REG[self.name][1] = count + 1
+        self._REG[self.name][1] = self._REG[self.name][1] + t
+        self._REG[self.name][2] = count + 1

     @staticmethod
     def print():
+        return
         keys = list(timeit._REG)
         keys.sort()
         for name in keys:
-            print(f"{name} took {timeit._REG[name][0] * 1000:4.4} msec")
+            print(
+                f"{name} took {timeit._REG[name][0] * 1000:4.4} msec (total = {timeit._REG[name][1]} sec)"
+            )
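Two things change here: each _REG entry grows from [average, count] to [average, total, count] so that print() can report both the mean and the accumulated time, and bare return statements are added at the top of __enter__, __exit__, and print, which turns the timer into a no-op (consistent with a commit focused on removing overhead). Note that the unchanged line `count = self._REG[self.name][1]` now points at the slot holding the total; because __exit__ returns before reaching it, this is harmless. The sketch below is a hypothetical reading of the intended bookkeeping, with the early returns omitted and index 2 used for the count so the arithmetic actually runs.

    import time


    class timeit_sketch:
        """Sketch only: each registry entry holds [running average (sec), total (sec), count]."""

        _REG = {}

        def __init__(self, name):
            self.name = name

        def __enter__(self):
            self.t0 = time.time()

        def __exit__(self, exc_type, exc_val, exc_tb):
            t = time.time() - self.t0
            self._REG.setdefault(self.name, [0.0, 0.0, 0])
            count = self._REG[self.name][2]
            # running average over count + 1 measurements
            self._REG[self.name][0] = (self._REG[self.name][0] * count + t) / (count + 1)
            # accumulated total, reported by print() as "(total = ... sec)"
            self._REG[self.name][1] += t
            self._REG[self.name][2] = count + 1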
43 changes: 19 additions & 24 deletions torchrl/data/tensordict/metatensor.py
@@ -6,10 +6,10 @@
 from __future__ import annotations

 import functools
+import math
 from numbers import Number
 from typing import Callable, List, Optional, Sequence, Tuple, Union

-import numpy as np
 import torch

 from torchrl.data.utils import DEVICE_TYPING, INDEX_TYPING
@@ -65,39 +65,28 @@ def __init__(
         *shape: Union[int, torch.Tensor, "MemmapTensor"],
         device: Optional[DEVICE_TYPING] = "cpu",
         dtype: torch.dtype = torch.get_default_dtype(),
-        _is_shared: bool = False,
-        _is_memmap: bool = False,
+        _is_shared: Optional[bool] = None,
+        _is_memmap: Optional[bool] = None,
     ):

         if len(shape) == 1 and not isinstance(shape[0], (Number,)):
             tensor = shape[0]
             shape = tensor.shape
-            try:
-                _is_shared = (
-                    tensor.is_shared()
-                    if tensor.device != torch.device("meta")
-                    else _is_shared
-                )
-            except:  # noqa
-                _is_shared = False
-            _is_memmap = (
-                isinstance(tensor, MemmapTensor)
-                if tensor.device != torch.device("meta")
-                else _is_memmap
-            )
-            device = tensor.device if tensor.device != torch.device("meta") else device
+            if _is_shared is None:
+                _is_shared = tensor.is_shared()
+            if _is_memmap is None:
+                _is_memmap = isinstance(tensor, MemmapTensor)
+            device = tensor.device if not tensor.is_meta else device
             dtype = tensor.dtype
         if not isinstance(shape, torch.Size):
             shape = torch.Size(shape)
         self.shape = shape
-        self.device = (
-            torch.device(device) if not isinstance(device, torch.device) else device
-        )
+        self.device = torch.device(device)
         self.dtype = dtype
         self._ndim = len(shape)
-        self._numel = np.prod(shape)
-        self._is_shared = _is_shared
-        self._is_memmap = _is_memmap
+        self._numel = math.prod(shape)
+        self._is_shared = bool(_is_shared)
+        self._is_memmap = bool(_is_memmap)
         if _is_memmap:
             name = "MemmapTensor"
         elif _is_shared:
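The constructor is trimmed in several small ways: _is_shared and _is_memmap become Optional so that a caller that already knows these flags (as view() does below) can pass them in and skip the tensor.is_shared() query, np.prod is replaced by math.prod, and the device is normalized with an unconditional torch.device(device) call. A small illustrative snippet of the latter two points, independent of the MetaTensor API:

    import math

    import numpy as np
    import torch

    shape = torch.Size([4, 3, 2])

    # math.prod works on any iterable of numbers and returns a plain Python int,
    # avoiding the numpy round-trip (np.prod returns a numpy scalar).
    assert math.prod(shape) == 24 and isinstance(math.prod(shape), int)
    assert np.prod(shape) == 24  # numpy.int64, not a Python int

    # torch.device is idempotent, so the old isinstance(device, torch.device)
    # check can be dropped without changing behaviour.
    assert torch.device(torch.device("cpu")) == torch.device("cpu")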
@@ -241,7 +230,13 @@ def view(
         elif not isinstance(shape, torch.Size):
             shape = torch.Size(shape)
         new_shape = torch.zeros(self.shape, device="meta").view(*shape)
-        return MetaTensor(new_shape, device=self.device, dtype=self.dtype)
+        return MetaTensor(
+            new_shape,
+            device=self.device,
+            dtype=self.dtype,
+            _is_shared=self.is_shared(),
+            _is_memmap=self.is_memmap(),
+        )


 def _stack_meta(
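With _is_shared and _is_memmap forwarded explicitly, a view of a MetaTensor keeps the flags of its parent instead of re-deriving them from the intermediate meta-device tensor. A hypothetical check, assuming the import path follows the file location shown above and that MetaTensor and its view() behave as the diff suggests:

    import torch

    from torchrl.data.tensordict.metatensor import MetaTensor

    # A shared tensor's MetaTensor should still report is_shared() after view(),
    # because the flags are now passed through rather than recomputed.
    t = torch.zeros(4, 6).share_memory_()
    meta = MetaTensor(t)
    assert meta.is_shared()
    assert meta.view(2, 12).is_shared()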