[Refactoring] Import at root to enable vmap monkey-patching #500

Merged · 8 commits · Sep 30, 2022
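Context for the diff below: the title refers to monkey-patching functorch's vmap so that it also accepts TensorDict arguments, which requires the relevant submodules to be imported when the torchrl root package is imported. A minimal sketch of the resulting usage, adapted from the tests added in this PR (the Identity module and the functorch import are illustrative assumptions, not part of the diff):

import torch
from functorch import vmap  # assumed import; the new tests call vmap directly

import torchrl  # noqa: F401  # importing the root package is what enables the patch, per the PR title
from torchrl.data import TensorDict


class Identity(torch.nn.Module):  # hypothetical module for illustration
    def forward(self, tensordict):
        return TensorDict(
            {"a": tensordict["a"]}, tensordict.batch_size, device=tensordict.device
        )


td = TensorDict({"a": torch.randn(3)}, []).expand(4)
out = vmap(Identity(), (0,))(td)  # map over the leading batch dimension
assert out["a"].shape == torch.Size([4, 3])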
2 changes: 1 addition & 1 deletion test/_utils_internal.py
@@ -8,7 +8,7 @@
# Get relative file path
# this returns relative path from current file.
import torch.cuda
-from torchrl import seed_generator
+from torchrl._utils import seed_generator


def get_relative_path(curr_file, *path_components):
2 changes: 1 addition & 1 deletion test/mocking_classes.py
@@ -6,7 +6,7 @@

import torch
import torch.nn as nn
-from torchrl import seed_generator
+from torchrl._utils import seed_generator
from torchrl.data.tensor_specs import (
NdUnboundedContinuousTensorSpec,
NdBoundedTensorSpec,
2 changes: 1 addition & 1 deletion test/test_collector.py
@@ -18,7 +18,7 @@
MockSerialEnv,
)
from torch import nn
-from torchrl import seed_generator
+from torchrl._utils import seed_generator
from torchrl.collectors import SyncDataCollector, aSyncDataCollector
from torchrl.collectors.collectors import (
RandomPolicy,
57 changes: 55 additions & 2 deletions test/test_functorch.py
@@ -83,7 +83,7 @@ def test_vmap_tdmodule(moduletype, batch_params):
if batch_params:
params = params.expand(10, *params.batch_size).contiguous()
buffers = buffers.expand(10, *buffers.batch_size).contiguous()
-y = tdmodule(td, params=params, buffers=buffers, vmap=(0, 0, 0))
+tdmodule(td, params=params, buffers=buffers, vmap=(0, 0, 0))
else:
raise NotImplementedError
y = td["y"]
@@ -126,7 +126,7 @@ def test_vmap_tdmodule_nativebuilt(moduletype, batch_params):
if batch_params:
params = params.expand(10, *params.batch_size).contiguous()
buffers = buffers.expand(10, *buffers.batch_size).contiguous()
-y = tdmodule(td, params=params, buffers=buffers, vmap=(0, 0, 0))
+tdmodule(td, params=params, buffers=buffers, vmap=(0, 0, 0))
else:
raise NotImplementedError
y = td["y"]
@@ -241,6 +241,59 @@ def test_vmap_tdsequence_nativebuilt(moduletype, batch_params):
assert z.shape == torch.Size([10, 2, 3])


class TestNativeFunctorch:
    def test_vamp_basic(self):
        class MyModule(torch.nn.Module):
            def forward(self, tensordict):
                a = tensordict["a"]
                return TensorDict(
                    {"a": a}, tensordict.batch_size, device=tensordict.device
                )

        tensordict = TensorDict({"a": torch.randn(3)}, []).expand(4)
        out = vmap(MyModule(), (0,))(tensordict)
        assert out.shape == torch.Size([4])
        assert out["a"].shape == torch.Size([4, 3])

    def test_vamp_composed(self):
        class MyModule(torch.nn.Module):
            def forward(self, tensordict, tensor):
                a = tensordict["a"]
                return (
                    TensorDict(
                        {"a": a}, tensordict.batch_size, device=tensordict.device
                    ),
                    tensor,
                )

        tensor = torch.randn(3)
        tensordict = TensorDict({"a": torch.randn(3, 1)}, [3]).expand(4, 3)
        out = vmap(MyModule(), (0, None))(tensordict, tensor)

        assert out[0].shape == torch.Size([4, 3])
        assert out[1].shape == torch.Size([4, 3])
        assert out[0]["a"].shape == torch.Size([4, 3, 1])

    def test_vamp_composed_flipped(self):
        class MyModule(torch.nn.Module):
            def forward(self, tensordict, tensor):
                a = tensordict["a"]
                return (
                    TensorDict(
                        {"a": a}, tensordict.batch_size, device=tensordict.device
                    ),
                    tensor,
                )

        tensor = torch.randn(3).expand(4, 3)
        tensordict = TensorDict({"a": torch.randn(3, 1)}, [3])
        out = vmap(MyModule(), (None, 0))(tensordict, tensor)

        assert out[0].shape == torch.Size([4, 3])
        assert out[1].shape == torch.Size([4, 3])
        assert out[0]["a"].shape == torch.Size([4, 3, 1])


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
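A note on the in_dims tuples used in the new tests (e.g. vmap(MyModule(), (0, None))): the tuple has one entry per positional argument; an integer selects the dimension to map over, while None passes the argument unchanged to every call. A plain-tensor analogue of the (0, None) case, not part of this PR:

import torch
from functorch import vmap  # torch.func.vmap in later PyTorch releases


def add(x, y):
    return x + y


x = torch.randn(4, 3)  # batched along dim 0
y = torch.randn(3)     # shared across the batch
out = vmap(add, in_dims=(0, None))(x, y)
assert out.shape == torch.Size([4, 3])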
2 changes: 1 addition & 1 deletion test/test_tensordict.py
@@ -12,7 +12,7 @@
import torch
from _utils_internal import get_available_devices
from torch import multiprocessing as mp
-from torchrl import prod
+from torchrl._utils import prod
from torchrl.data import SavedTensorDict, TensorDict, MemmapTensor
from torchrl.data.tensordict.tensordict import (
assert_allclose_td,
2 changes: 1 addition & 1 deletion test/test_transforms.py
@@ -16,7 +16,7 @@
)
from torch import Tensor
from torch import multiprocessing as mp
-from torchrl import prod
+from torchrl._utils import prod
from torchrl.data import (
NdBoundedTensorSpec,
CompositeSpec,
104 changes: 7 additions & 97 deletions torchrl/__init__.py
@@ -3,19 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import abc
import collections
import math
import time
import typing
from typing import Optional, Type, Tuple
from warnings import warn

import numpy as np
from torch import multiprocessing as mp

from ._extension import _init_extension


try:
    from .version import __version__
except ImportError:
@@ -38,93 +32,9 @@
)


class timeit:
    """
    A dirty but easy to use decorator for profiling code
    """

    _REG = {}

    def __init__(self, name):
        self.name = name

    def __call__(self, fn):
        def decorated_fn(*args, **kwargs):
            with self:
                out = fn(*args, **kwargs)
                return out

        return decorated_fn

    def __enter__(self):
        self.t0 = time.time()

    def __exit__(self, exc_type, exc_val, exc_tb):
        t = time.time() - self.t0
        self._REG.setdefault(self.name, [0.0, 0.0, 0])

        count = self._REG[self.name][1]
        self._REG[self.name][0] = (self._REG[self.name][0] * count + t) / (count + 1)
        self._REG[self.name][1] = self._REG[self.name][1] + t
        self._REG[self.name][2] = count + 1

    @staticmethod
    def print(prefix=None):
        keys = list(timeit._REG)
        keys.sort()
        for name in keys:
            strings = []
            if prefix:
                strings.append(prefix)
            strings.append(
                f"{name} took {timeit._REG[name][0] * 1000:4.4} msec (total = {timeit._REG[name][1]} sec)"
            )
            print(" -- ".join(strings))

    @staticmethod
    def erase():
        for k in timeit._REG:
            timeit._REG[k] = [0.0, 0.0, 0]


def _check_for_faulty_process(processes):
    terminate = False
    for p in processes:
        if not p.is_alive():
            terminate = True
            for _p in processes:
                if _p.is_alive():
                    _p.terminate()
        if terminate:
            break
    if terminate:
        raise RuntimeError(
            "At least one process failed. Check for more infos in the log."
        )


def seed_generator(seed):
    max_seed_val = (
        2 ** 32 - 1
    )  # https://discuss.pytorch.org/t/what-is-the-max-seed-you-can-set-up/145688
    rng = np.random.default_rng(seed)
    seed = int.from_bytes(rng.bytes(8), "big")
    return seed % max_seed_val


class KeyDependentDefaultDict(collections.defaultdict):
    def __init__(self, fun):
        self.fun = fun
        super().__init__()

    def __missing__(self, key):
        value = self.fun(key)
        self[key] = value
        return value


def prod(sequence):
    if hasattr(math, "prod"):
        return math.prod(sequence)
    else:
        return int(np.prod(sequence))
import torchrl.collectors
import torchrl.data
import torchrl.envs
import torchrl.modules
import torchrl.objectives
import torchrl.trainers
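The practical consequence for downstream code is a one-line change of import path, mirroring the test updates above (illustrative; only names that appear in this diff are used):

# previously: from torchrl import seed_generator, prod
from torchrl._utils import seed_generator, prod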
97 changes: 97 additions & 0 deletions torchrl/_utils.py
@@ -0,0 +1,97 @@
import collections
import math
import time

import numpy as np


class timeit:
    """
    A dirty but easy to use decorator for profiling code
    """

    _REG = {}

    def __init__(self, name):
        self.name = name

    def __call__(self, fn):
        def decorated_fn(*args, **kwargs):
            with self:
                out = fn(*args, **kwargs)
                return out

        return decorated_fn

    def __enter__(self):
        self.t0 = time.time()

    def __exit__(self, exc_type, exc_val, exc_tb):
        t = time.time() - self.t0
        self._REG.setdefault(self.name, [0.0, 0.0, 0])

        count = self._REG[self.name][1]
        self._REG[self.name][0] = (self._REG[self.name][0] * count + t) / (count + 1)
        self._REG[self.name][1] = self._REG[self.name][1] + t
        self._REG[self.name][2] = count + 1

    @staticmethod
    def print(prefix=None):
        keys = list(timeit._REG)
        keys.sort()
        for name in keys:
            strings = []
            if prefix:
                strings.append(prefix)
            strings.append(
                f"{name} took {timeit._REG[name][0] * 1000:4.4} msec (total = {timeit._REG[name][1]} sec)"
            )
            print(" -- ".join(strings))

    @staticmethod
    def erase():
        for k in timeit._REG:
            timeit._REG[k] = [0.0, 0.0, 0]


def _check_for_faulty_process(processes):
    terminate = False
    for p in processes:
        if not p.is_alive():
            terminate = True
            for _p in processes:
                if _p.is_alive():
                    _p.terminate()
        if terminate:
            break
    if terminate:
        raise RuntimeError(
            "At least one process failed. Check for more infos in the log."
        )


def seed_generator(seed):
    max_seed_val = (
        2 ** 32 - 1
    )  # https://discuss.pytorch.org/t/what-is-the-max-seed-you-can-set-up/145688
    rng = np.random.default_rng(seed)
    seed = int.from_bytes(rng.bytes(8), "big")
    return seed % max_seed_val


class KeyDependentDefaultDict(collections.defaultdict):
    def __init__(self, fun):
        self.fun = fun
        super().__init__()

    def __missing__(self, key):
        value = self.fun(key)
        self[key] = value
        return value


def prod(sequence):
    if hasattr(math, "prod"):
        return math.prod(sequence)
    else:
        return int(np.prod(sequence))
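The helpers above are moved verbatim from torchrl/__init__.py; only their home module changes. A small usage sketch under the assumption that torchrl with this change is installed (the decorated function and timer names are illustrative):

import time

from torchrl._utils import prod, seed_generator, timeit


@timeit("fast_fn")  # usable as a decorator ...
def fast_fn():
    time.sleep(0.01)


fast_fn()

with timeit("block"):  # ... or as a context manager
    time.sleep(0.01)

timeit.print()  # per-name mean (msec) and total (sec), as formatted in print() above
timeit.erase()  # reset the registry

child_seed = seed_generator(42)  # deterministic derived seed below 2 ** 32 - 1
assert prod([2, 3, 4]) == 24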
2 changes: 1 addition & 1 deletion torchrl/collectors/collectors.py
@@ -19,7 +19,7 @@
from torch.utils.data import IterableDataset

from torchrl.envs.utils import set_exploration_mode, step_tensordict
-from .. import _check_for_faulty_process, prod
+from .._utils import _check_for_faulty_process, prod
from ..modules.tensordict_module import ProbabilisticTensorDictModule
from .utils import split_trajectories

2 changes: 1 addition & 1 deletion torchrl/data/tensordict/memmap.py
@@ -13,7 +13,7 @@
import numpy as np
import torch

-from torchrl import prod
+from torchrl._utils import prod
from torchrl.data.tensordict.utils import _getitem_batch_size
from torchrl.data.utils import (
DEVICE_TYPING,
4 changes: 2 additions & 2 deletions torchrl/data/tensordict/metatensor.py
@@ -14,7 +14,7 @@

from torchrl.data.utils import DEVICE_TYPING, INDEX_TYPING
from .memmap import MemmapTensor
-from .utils import _getitem_batch_size
+from .utils import _getitem_batch_size, _get_shape

META_HANDLED_FUNCTIONS = dict()

@@ -74,7 +74,7 @@ def __init__(
):
if len(shape) == 1 and not isinstance(shape[0], (Number,)):
tensor = shape[0]
-shape = tensor.shape
+shape = _get_shape(tensor)
if _is_shared is None:
_is_shared = tensor.is_shared()
if _is_memmap is None: