diff --git a/test/_utils_internal.py b/test/_utils_internal.py
index 6fcc22bc5db..fc1dff1831c 100644
--- a/test/_utils_internal.py
+++ b/test/_utils_internal.py
@@ -8,7 +8,7 @@
 # Get relative file path
 # this returns relative path from current file.
 import torch.cuda
-from torchrl import seed_generator
+from torchrl._utils import seed_generator


 def get_relative_path(curr_file, *path_components):
diff --git a/test/mocking_classes.py b/test/mocking_classes.py
index fdfc7a5323e..b5955878afb 100644
--- a/test/mocking_classes.py
+++ b/test/mocking_classes.py
@@ -6,7 +6,7 @@
 import torch
 import torch.nn as nn

-from torchrl import seed_generator
+from torchrl._utils import seed_generator
 from torchrl.data.tensor_specs import (
     NdUnboundedContinuousTensorSpec,
     NdBoundedTensorSpec,
diff --git a/test/test_collector.py b/test/test_collector.py
index 4fb71f83bd5..631995677ed 100644
--- a/test/test_collector.py
+++ b/test/test_collector.py
@@ -18,7 +18,7 @@
     MockSerialEnv,
 )
 from torch import nn
-from torchrl import seed_generator
+from torchrl._utils import seed_generator
 from torchrl.collectors import SyncDataCollector, aSyncDataCollector
 from torchrl.collectors.collectors import (
     RandomPolicy,
diff --git a/test/test_functorch.py b/test/test_functorch.py
index 6ee8136bb15..f1bb27a39b1 100644
--- a/test/test_functorch.py
+++ b/test/test_functorch.py
@@ -83,7 +83,7 @@ def test_vmap_tdmodule(moduletype, batch_params):
         if batch_params:
             params = params.expand(10, *params.batch_size).contiguous()
             buffers = buffers.expand(10, *buffers.batch_size).contiguous()
-            y = tdmodule(td, params=params, buffers=buffers, vmap=(0, 0, 0))
+            tdmodule(td, params=params, buffers=buffers, vmap=(0, 0, 0))
         else:
             raise NotImplementedError
         y = td["y"]
@@ -126,7 +126,7 @@ def test_vmap_tdmodule_nativebuilt(moduletype, batch_params):
         if batch_params:
             params = params.expand(10, *params.batch_size).contiguous()
             buffers = buffers.expand(10, *buffers.batch_size).contiguous()
-            y = tdmodule(td, params=params, buffers=buffers, vmap=(0, 0, 0))
+            tdmodule(td, params=params, buffers=buffers, vmap=(0, 0, 0))
         else:
             raise NotImplementedError
         y = td["y"]
@@ -241,6 +241,59 @@ def test_vmap_tdsequence_nativebuilt(moduletype, batch_params):
         assert z.shape == torch.Size([10, 2, 3])


+class TestNativeFunctorch:
+    def test_vamp_basic(self):
+        class MyModule(torch.nn.Module):
+            def forward(self, tensordict):
+                a = tensordict["a"]
+                return TensorDict(
+                    {"a": a}, tensordict.batch_size, device=tensordict.device
+                )
+
+        tensordict = TensorDict({"a": torch.randn(3)}, []).expand(4)
+        out = vmap(MyModule(), (0,))(tensordict)
+        assert out.shape == torch.Size([4])
+        assert out["a"].shape == torch.Size([4, 3])
+
+    def test_vamp_composed(self):
+        class MyModule(torch.nn.Module):
+            def forward(self, tensordict, tensor):
+                a = tensordict["a"]
+                return (
+                    TensorDict(
+                        {"a": a}, tensordict.batch_size, device=tensordict.device
+                    ),
+                    tensor,
+                )
+
+        tensor = torch.randn(3)
+        tensordict = TensorDict({"a": torch.randn(3, 1)}, [3]).expand(4, 3)
+        out = vmap(MyModule(), (0, None))(tensordict, tensor)
+
+        assert out[0].shape == torch.Size([4, 3])
+        assert out[1].shape == torch.Size([4, 3])
+        assert out[0]["a"].shape == torch.Size([4, 3, 1])
+
+    def test_vamp_composed_flipped(self):
+        class MyModule(torch.nn.Module):
+            def forward(self, tensordict, tensor):
+                a = tensordict["a"]
+                return (
+                    TensorDict(
+                        {"a": a}, tensordict.batch_size, device=tensordict.device
+                    ),
+                    tensor,
+                )
+
+        tensor = torch.randn(3).expand(4, 3)
+        tensordict = TensorDict({"a": torch.randn(3, 1)}, [3])
+        out = vmap(MyModule(), (None, 0))(tensordict, tensor)
+
+        assert out[0].shape == torch.Size([4, 3])
+        assert out[1].shape == torch.Size([4, 3])
+        assert out[0]["a"].shape == torch.Size([4, 3, 1])
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 72bb1202a64..8b2823575e2 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -12,7 +12,7 @@
 import torch
 from _utils_internal import get_available_devices
 from torch import multiprocessing as mp
-from torchrl import prod
+from torchrl._utils import prod
 from torchrl.data import SavedTensorDict, TensorDict, MemmapTensor
 from torchrl.data.tensordict.tensordict import (
     assert_allclose_td,
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 5ed04060b76..922a84720fd 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -16,7 +16,7 @@
 )
 from torch import Tensor
 from torch import multiprocessing as mp
-from torchrl import prod
+from torchrl._utils import prod
 from torchrl.data import (
     NdBoundedTensorSpec,
     CompositeSpec,
diff --git a/torchrl/__init__.py b/torchrl/__init__.py
index 2b165b3823f..582aac54338 100644
--- a/torchrl/__init__.py
+++ b/torchrl/__init__.py
@@ -3,19 +3,13 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-import abc
-import collections
-import math
-import time
-import typing
-from typing import Optional, Type, Tuple
 from warnings import warn

-import numpy as np
 from torch import multiprocessing as mp

 from ._extension import _init_extension

+
 try:
     from .version import __version__
 except ImportError:
@@ -38,93 +32,9 @@
     )


-class timeit:
-    """
-    A dirty but easy to use decorator for profiling code
-    """
-
-    _REG = {}
-
-    def __init__(self, name):
-        self.name = name
-
-    def __call__(self, fn):
-        def decorated_fn(*args, **kwargs):
-            with self:
-                out = fn(*args, **kwargs)
-                return out
-
-        return decorated_fn
-
-    def __enter__(self):
-        self.t0 = time.time()
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        t = time.time() - self.t0
-        self._REG.setdefault(self.name, [0.0, 0.0, 0])
-
-        count = self._REG[self.name][1]
-        self._REG[self.name][0] = (self._REG[self.name][0] * count + t) / (count + 1)
-        self._REG[self.name][1] = self._REG[self.name][1] + t
-        self._REG[self.name][2] = count + 1
-
-    @staticmethod
-    def print(prefix=None):
-        keys = list(timeit._REG)
-        keys.sort()
-        for name in keys:
-            strings = []
-            if prefix:
-                strings.append(prefix)
-            strings.append(
-                f"{name} took {timeit._REG[name][0] * 1000:4.4} msec (total = {timeit._REG[name][1]} sec)"
-            )
-            print(" -- ".join(strings))
-
-    @staticmethod
-    def erase():
-        for k in timeit._REG:
-            timeit._REG[k] = [0.0, 0.0, 0]
-
-
-def _check_for_faulty_process(processes):
-    terminate = False
-    for p in processes:
-        if not p.is_alive():
-            terminate = True
-            for _p in processes:
-                if _p.is_alive():
-                    _p.terminate()
-        if terminate:
-            break
-    if terminate:
-        raise RuntimeError(
-            "At least one process failed. Check for more infos in the log."
-        )
-
-
-def seed_generator(seed):
-    max_seed_val = (
-        2 ** 32 - 1
-    )  # https://discuss.pytorch.org/t/what-is-the-max-seed-you-can-set-up/145688
-    rng = np.random.default_rng(seed)
-    seed = int.from_bytes(rng.bytes(8), "big")
-    return seed % max_seed_val
-
-
-class KeyDependentDefaultDict(collections.defaultdict):
-    def __init__(self, fun):
-        self.fun = fun
-        super().__init__()
-
-    def __missing__(self, key):
-        value = self.fun(key)
-        self[key] = value
-        return value
-
-
-def prod(sequence):
-    if hasattr(math, "prod"):
-        return math.prod(sequence)
-    else:
-        return int(np.prod(sequence))
+import torchrl.collectors
+import torchrl.data
+import torchrl.envs
+import torchrl.modules
+import torchrl.objectives
+import torchrl.trainers
diff --git a/torchrl/_utils.py b/torchrl/_utils.py
new file mode 100644
index 00000000000..b7ed60dd160
--- /dev/null
+++ b/torchrl/_utils.py
@@ -0,0 +1,97 @@
+import collections
+import math
+import time
+
+import numpy as np
+
+
+class timeit:
+    """
+    A dirty but easy to use decorator for profiling code
+    """
+
+    _REG = {}
+
+    def __init__(self, name):
+        self.name = name
+
+    def __call__(self, fn):
+        def decorated_fn(*args, **kwargs):
+            with self:
+                out = fn(*args, **kwargs)
+                return out
+
+        return decorated_fn
+
+    def __enter__(self):
+        self.t0 = time.time()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        t = time.time() - self.t0
+        self._REG.setdefault(self.name, [0.0, 0.0, 0])
+
+        count = self._REG[self.name][1]
+        self._REG[self.name][0] = (self._REG[self.name][0] * count + t) / (count + 1)
+        self._REG[self.name][1] = self._REG[self.name][1] + t
+        self._REG[self.name][2] = count + 1
+
+    @staticmethod
+    def print(prefix=None):
+        keys = list(timeit._REG)
+        keys.sort()
+        for name in keys:
+            strings = []
+            if prefix:
+                strings.append(prefix)
+            strings.append(
+                f"{name} took {timeit._REG[name][0] * 1000:4.4} msec (total = {timeit._REG[name][1]} sec)"
+            )
+            print(" -- ".join(strings))
+
+    @staticmethod
+    def erase():
+        for k in timeit._REG:
+            timeit._REG[k] = [0.0, 0.0, 0]
+
+
+def _check_for_faulty_process(processes):
+    terminate = False
+    for p in processes:
+        if not p.is_alive():
+            terminate = True
+            for _p in processes:
+                if _p.is_alive():
+                    _p.terminate()
+        if terminate:
+            break
+    if terminate:
+        raise RuntimeError(
+            "At least one process failed. Check for more infos in the log."
+        )
+
+
+def seed_generator(seed):
+    max_seed_val = (
+        2 ** 32 - 1
+    )  # https://discuss.pytorch.org/t/what-is-the-max-seed-you-can-set-up/145688
+    rng = np.random.default_rng(seed)
+    seed = int.from_bytes(rng.bytes(8), "big")
+    return seed % max_seed_val
+
+
+class KeyDependentDefaultDict(collections.defaultdict):
+    def __init__(self, fun):
+        self.fun = fun
+        super().__init__()
+
+    def __missing__(self, key):
+        value = self.fun(key)
+        self[key] = value
+        return value
+
+
+def prod(sequence):
+    if hasattr(math, "prod"):
+        return math.prod(sequence)
+    else:
+        return int(np.prod(sequence))
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 4f2eaccbe0a..2ac32b20c52 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -19,7 +19,7 @@
 from torch.utils.data import IterableDataset

 from torchrl.envs.utils import set_exploration_mode, step_tensordict
-from .. import _check_for_faulty_process, prod
+from .._utils import _check_for_faulty_process, prod
 from ..modules.tensordict_module import ProbabilisticTensorDictModule
 from .utils import split_trajectories

diff --git a/torchrl/data/tensordict/memmap.py b/torchrl/data/tensordict/memmap.py
index 51f45f8d8ef..12ccff013e2 100644
--- a/torchrl/data/tensordict/memmap.py
+++ b/torchrl/data/tensordict/memmap.py
@@ -13,7 +13,7 @@
 import numpy as np
 import torch

-from torchrl import prod
+from torchrl._utils import prod
 from torchrl.data.tensordict.utils import _getitem_batch_size
 from torchrl.data.utils import (
     DEVICE_TYPING,
diff --git a/torchrl/data/tensordict/metatensor.py b/torchrl/data/tensordict/metatensor.py
index 142dbd9d540..0df81b4c380 100644
--- a/torchrl/data/tensordict/metatensor.py
+++ b/torchrl/data/tensordict/metatensor.py
@@ -14,7 +14,7 @@
 from torchrl.data.utils import DEVICE_TYPING, INDEX_TYPING

 from .memmap import MemmapTensor
-from .utils import _getitem_batch_size
+from .utils import _getitem_batch_size, _get_shape

 META_HANDLED_FUNCTIONS = dict()

@@ -74,7 +74,7 @@
     ):
         if len(shape) == 1 and not isinstance(shape[0], (Number,)):
             tensor = shape[0]
-            shape = tensor.shape
+            shape = _get_shape(tensor)
             if _is_shared is None:
                 _is_shared = tensor.is_shared()
             if _is_memmap is None:
diff --git a/torchrl/data/tensordict/tensordict.py b/torchrl/data/tensordict/tensordict.py
index 515339edd19..553966d1379 100644
--- a/torchrl/data/tensordict/tensordict.py
+++ b/torchrl/data/tensordict/tensordict.py
@@ -37,7 +37,9 @@
 from torch import Tensor
 from torch.jit._shape_functions import infer_size_impl

-from torchrl import KeyDependentDefaultDict, prod
+# from torch.utils._pytree import _register_pytree_node
+
+from torchrl._utils import KeyDependentDefaultDict, prod
 from torchrl.data.tensordict.memmap import MemmapTensor
 from torchrl.data.tensordict.metatensor import MetaTensor
 from torchrl.data.tensordict.utils import (
@@ -571,10 +573,11 @@ def _process_input(
             tensor = self._convert_to_tensor(input)
         else:
             tensor = input
-        if (
-            _has_functorch and isinstance(tensor, Tensor) and is_batchedtensor(tensor)
-        ):  # TODO: find a proper way of doing that
-            return tensor
+        # if (
+        #     _has_functorch and isinstance(tensor, Tensor) and is_batchedtensor(tensor)
+        # ):  # TODO: find a proper way of doing that
+        #     return tensor
+        # tensor = _unwrap_value(tensor)[0]

         if check_device and self.device is not None:
             device = self.device
@@ -1891,7 +1894,12 @@ def _make_meta(self, key: str) -> MetaTensor:
             else isinstance(proc_value, MemmapTensor)
         )
         is_shared = (
-            self._is_shared if self._is_shared is not None else proc_value.is_shared()
+            self._is_shared
+            if self._is_shared is not None
+            else proc_value.is_shared()
+            if isinstance(proc_value, (TensorDictBase, MemmapTensor))
+            or not is_batchedtensor(proc_value)
+            else False
         )
         return MetaTensor(
             proc_value,
@@ -4695,3 +4703,32 @@ def _expand_to_match_shape(parent_batch_size, tensor, self_batch_dims, self_devi
             device=self_device,
         )
     return out
+
+
+# seems like we can do without registering in pytree -- which requires us to create a new TensorDict,
+# an operation that does not come for free
+
+# def _flatten_tensordict(tensordict):
+#     return tensordict, tuple()
+#     # keys, values = list(zip(*tensordict.items()))
+#     # # represent values as batched tensors
+#     # vmap_level = 0
+#     # in_dim
+#     # values = [_add_batch_dim(value, in_dim, vmap_level)
+#     # return list(values), (list(keys), tensordict.device, tensordict.batch_size)
+#
+# def _unflatten_tensordict(values, context):
+#     return values
+#     # values = [_unwrap_value(value) for value in values]
+#     # keys, device, batch_size = context
+#     # print(values[0].shape)
+#     # return TensorDict(
+#     #     {key: value for key, value in zip(keys, values)},
+#     #     [],
+#     #     # [*new_batch_sizes[0], *batch_size],
+#     #     # new_batch_sizes[0],
+#     #     device=device
+#     # )
+#
+#
+# _register_pytree_node(TensorDict, _flatten_tensordict, _unflatten_tensordict)
diff --git a/torchrl/data/tensordict/utils.py b/torchrl/data/tensordict/utils.py
index 585f8ba4c4e..6e2f20482dd 100644
--- a/torchrl/data/tensordict/utils.py
+++ b/torchrl/data/tensordict/utils.py
@@ -11,6 +11,16 @@
 import numpy as np
 import torch

+try:
+    try:
+        from functorch._C import is_batchedtensor, get_unwrapped
+    except ImportError:
+        from torch._C._functorch import is_batchedtensor, get_unwrapped
+
+    _has_functorch = True
+except ImportError:
+    _has_functorch = False
+
 from torchrl.data.utils import INDEX_TYPING


@@ -151,3 +161,22 @@ def convert_ellipsis_to_idx(idx: Union[Tuple, Ellipsis], batch_size: List[int]):
     )

     return new_index
+
+
+def _get_shape(value):
+    # we call it "legacy code"
+    return value.shape
+
+
+def _unwrap_value(value):
+    # batch_dims = value.ndimension()
+    if not isinstance(value, torch.Tensor):
+        out = value
+    elif is_batchedtensor(value):
+        out = get_unwrapped(value)
+    else:
+        out = value
+    return out
+    # batch_dims = out.ndimension() - batch_dims
+    # batch_size = out.shape[:batch_dims]
+    # return out, batch_size
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index 69337ef500d..d14405a71ff 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -14,8 +14,8 @@
 import torch
 import torch.nn as nn

-from torchrl import seed_generator, prod
 from torchrl.data import CompositeSpec, TensorDict, TensorSpec
+from .._utils import seed_generator, prod
 from ..data.tensordict.tensordict import TensorDictBase
 from ..data.utils import DEVICE_TYPING
 from .utils import get_available_libraries, step_tensordict
diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py
index 918720f2b14..884df4b7287 100644
--- a/torchrl/envs/vec_env.py
+++ b/torchrl/envs/vec_env.py
@@ -17,7 +17,7 @@
 import torch
 from torch import multiprocessing as mp

-from torchrl import _check_for_faulty_process
+from torchrl._utils import _check_for_faulty_process
 from torchrl.data import TensorDict, TensorSpec, CompositeSpec
 from torchrl.data.tensordict.tensordict import TensorDictBase, LazyStackedTensorDict
 from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py
index 0f1b84737ef..74a8a36e364 100644
--- a/torchrl/modules/__init__.py
+++ b/torchrl/modules/__init__.py
@@ -7,3 +7,4 @@
 from .models import *
 from .tensordict_module import *
 from .planners import *
+from .functional_modules import *
diff --git a/torchrl/modules/functional_modules.py b/torchrl/modules/functional_modules.py
index 4542d6e22b2..df07a68f776 100644
--- a/torchrl/modules/functional_modules.py
+++ b/torchrl/modules/functional_modules.py
@@ -19,6 +19,7 @@
 except ImportError:
     _has_functorch = False

+# Monkey-patch functorch, mainly for cases where an "isinstance(obj, Tensor)" check is invoked
 if _has_functorch:
     from functorch._src.vmap import (
         _get_name,
@@ -28,6 +29,7 @@
         _validate_and_get_batch_size,
         _add_batch_dim,
         tree_unflatten,
+        _remove_batch_dim,
     )

     # Monkey-patches
@@ -93,10 +95,15 @@ def _process_batched_inputs(in_dims, args, func):

     def _create_batched_inputs(flat_in_dims, flat_args, vmap_level: int, args_spec):
         # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
+        # If tensordict, we remove the dim at batch_size[in_dim] such that the TensorDict can accept
+        # the batched tensors. This will be added in _unwrap_batched
         batched_inputs = [
             arg
             if in_dim is None
-            else arg.apply(lambda _arg: _add_batch_dim(_arg, in_dim, vmap_level))
+            else arg.apply(
+                lambda _arg: _add_batch_dim(_arg, in_dim, vmap_level),
+                batch_size=[b for i, b in enumerate(arg.batch_size) if i != in_dim],
+            )
             if isinstance(arg, TensorDictBase)
             else _add_batch_dim(arg, in_dim, vmap_level)
             for in_dim, arg in zip(flat_in_dims, flat_args)
@@ -105,6 +112,60 @@ def _create_batched_inputs(flat_in_dims, flat_args, vmap_level: int, args_spec):

     functorch._src.vmap._create_batched_inputs = _create_batched_inputs

+    def _unwrap_batched(
+        batched_outputs, out_dims, vmap_level: int, batch_size: int, func
+    ):
+        flat_batched_outputs, output_spec = tree_flatten(batched_outputs)
+
+        for out in flat_batched_outputs:
+            # Change here:
+            if isinstance(out, (TensorDictBase, torch.Tensor)):
+                continue
+            raise ValueError(
+                f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return "
+                f"Tensors, got type {type(out)} as a return."
+            )
+
+        def incompatible_error():
+            raise ValueError(
+                f"vmap({_get_name(func)}, ..., out_dims={out_dims})(): "
+                f"out_dims is not compatible with the structure of `outputs`. "
+                f"out_dims has structure {tree_flatten(out_dims)[1]} but outputs "
+                f"has structure {output_spec}."
+            )
+
+        # Here:
+        if isinstance(batched_outputs, (TensorDictBase, torch.Tensor)):
+            # Some weird edge case requires us to spell out the following
+            # see test_out_dims_edge_case
+            if isinstance(out_dims, int):
+                flat_out_dims = [out_dims]
+            elif isinstance(out_dims, tuple) and len(out_dims) == 1:
+                flat_out_dims = out_dims
+                out_dims = out_dims[0]
+            else:
+                incompatible_error()
+        else:
+            flat_out_dims = _broadcast_to_and_flatten(out_dims, output_spec)
+            if flat_out_dims is None:
+                incompatible_error()
+
+        flat_outputs = []
+        for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims):
+            if not isinstance(batched_output, TensorDictBase):
+                out = _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim)
+            else:
+                out = batched_output.apply(
+                    lambda x: _remove_batch_dim(x, vmap_level, batch_size, out_dim),
+                    batch_size=[batch_size, *batched_output.batch_size],
+                )
+            flat_outputs.append(out)
+        return tree_unflatten(flat_outputs, output_spec)
+
+    functorch._src.vmap._unwrap_batched = _unwrap_batched
+
+# Tensordict-compatible Functional modules
+

 class FunctionalModule(nn.Module):
     """
@@ -178,6 +239,9 @@ def forward(self, params, buffers, *args, **kwargs):
             _swap_state(self.stateless_model, old_state_buffers)


+# Some utils for these
+
+
 def extract_weights(model):
     tensordict = TensorDict({}, [])
     for name, param in list(model.named_parameters(recurse=False)):
diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py
index 73c45923254..317018e7625 100644
--- a/torchrl/modules/models/exploration.py
+++ b/torchrl/modules/models/exploration.py
@@ -13,7 +13,7 @@

 __all__ = ["NoisyLinear", "NoisyLazyLinear", "reset_noise"]

-from torchrl import prod
+from torchrl._utils import prod
 from torchrl.data.utils import DEVICE_TYPING, DEVICE_TYPING_ARGS
 from torchrl.envs.utils import exploration_mode
 from torchrl.modules.distributions.utils import _cast_transform_device
diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py
index d600666c0b5..18a32a33083 100644
--- a/torchrl/modules/models/models.py
+++ b/torchrl/modules/models/models.py
@@ -10,7 +10,7 @@
 from torch import nn
 from torch.nn import functional as F

-from torchrl import prod
+from torchrl._utils import prod
 from torchrl.data import DEVICE_TYPING
 from torchrl.modules.models.utils import (
     _find_depth,
diff --git a/torchrl/trainers/__init__.py b/torchrl/trainers/__init__.py
index 08f049c1ae9..55d68b3c6cf 100644
--- a/torchrl/trainers/__init__.py
+++ b/torchrl/trainers/__init__.py
@@ -4,4 +4,5 @@
 # LICENSE file in the root directory of this source tree.

 from .trainers import *
-from .loggers import *
+
+# from .loggers import *
diff --git a/torchrl/trainers/loggers/tensorboard.py b/torchrl/trainers/loggers/tensorboard.py
index b9d3084fbf0..aa46735fd69 100644
--- a/torchrl/trainers/loggers/tensorboard.py
+++ b/torchrl/trainers/loggers/tensorboard.py
@@ -3,19 +3,18 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import os
-from warnings import warn

 from torch import Tensor

 from .common import Logger

-_has_tb = False
+
 try:
     from torch.utils.tensorboard import SummaryWriter

     _has_tb = True
 except ImportError:
-    warn("torch.utils.tensorboard could not be imported")
+    _has_tb = False


 class TensorboardLogger(Logger):
@@ -45,6 +44,8 @@ def _create_experiment(self) -> "SummaryWriter":
             SummaryWriter: The tensorboard experiment.

         """
+        if not _has_tb:
+            raise ImportError("torch.utils.tensorboard could not be imported")
         log_dir = str(os.path.join(self.log_dir, self.exp_name))
         return SummaryWriter(log_dir=log_dir)

diff --git a/torchrl/trainers/loggers/wandb.py b/torchrl/trainers/loggers/wandb.py
index 536193a7b79..c6c50800224 100644
--- a/torchrl/trainers/loggers/wandb.py
+++ b/torchrl/trainers/loggers/wandb.py
@@ -11,22 +11,21 @@

 from .common import Logger

-_has_wandb = False
+
 try:
     import wandb

     _has_wandb = True
 except ImportError:
-    warnings.warn("wandb could not be imported")
-_has_omgaconf = False
+    _has_wandb = False
+
+
 try:
     from omegaconf import OmegaConf

     _has_omgaconf = True
 except ImportError:
-    warnings.warn(
-        "OmegaConf could not be imported. Cannot log hydra configs without OmegaConf"
-    )
+    _has_omgaconf = False


 class WandbLogger(Logger):
@@ -52,6 +51,9 @@ def __init__(
         project: str = None,
         **kwargs,
     ) -> None:
+        if not _has_wandb:
+            raise ImportError("wandb could not be imported")
+
         log_dir = kwargs.pop("log_dir", None)
         self.offline = offline
         if save_dir and log_dir:
@@ -168,6 +170,11 @@ def log_hparams(self, cfg: "DictConfig") -> None:  # noqa: F821

         """
         if type(cfg) is not dict and _has_omgaconf:
+            if not _has_omgaconf:
+                raise ImportError(
+                    "OmegaConf could not be imported. "
+                    "Cannot log hydra configs without OmegaConf."
+                )
             cfg = OmegaConf.to_container(cfg, resolve=True)
         self.experiment.config.update(cfg, allow_val_change=True)

diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py
index 7cd510ca448..277c0d80941 100644
--- a/torchrl/trainers/trainers.py
+++ b/torchrl/trainers/trainers.py
@@ -15,7 +15,7 @@
 import torch.nn
 from torch import nn, optim

-from torchrl import KeyDependentDefaultDict
+from torchrl._utils import KeyDependentDefaultDict

 try:
     from tqdm import tqdm
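
For reference, below is a minimal usage sketch of the helpers that this patch relocates from torchrl/__init__.py to the new torchrl/_utils.py. The signatures are taken from the new file shown above; note that torchrl._utils is an internal module with no stability guarantees, so importing from it directly is an assumption for illustration rather than documented public API.

    from torchrl._utils import KeyDependentDefaultDict, prod, seed_generator, timeit

    # seed_generator: derive a bounded follow-up seed from a base seed,
    # as the collectors and envs do after this change.
    next_seed = seed_generator(42)
    assert 0 <= next_seed < 2 ** 32 - 1

    # prod: product of a sequence (math.prod when available, numpy otherwise).
    assert prod([2, 3, 4]) == 24

    # KeyDependentDefaultDict: defaultdict whose default value depends on the missing key.
    squares = KeyDependentDefaultDict(lambda key: key ** 2)
    assert squares[3] == 9

    # timeit: context manager / decorator that aggregates wall-clock timings per name.
    with timeit("example/block"):
        pass
    timeit.print(prefix="demo")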