diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index c1a6d831115..c16dd5f9ec6 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -6,6 +6,7 @@ from __future__ import annotations import abc +import enum import math import warnings from collections.abc import Iterable @@ -81,6 +82,15 @@ ) +# Akin to TD's NO_DEFAULT but won't raise a KeyError when found in a TD or used as default +class _NoDefault(enum.IntEnum): + ZERO = 0 + ONE = 1 + + +NO_DEFAULT_RL = _NoDefault.ONE + + def _default_dtype_and_device( dtype: Union[None, torch.dtype], device: Union[None, str, int, torch.device], diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index ee7649fabe4..de31ac99162 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -48,7 +48,7 @@ from torchrl.data.tensor_specs import ( CompositeSpec, - NO_DEFAULT, + NO_DEFAULT_RL as NO_DEFAULT, TensorSpec, UnboundedContinuousTensorSpec, )