From 788710f9317298fc68c5165ed0f0952055a2ae26 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 5 Aug 2024 20:02:40 -0400 Subject: [PATCH] [BugFix] Use a RL-specific NO_DEFAULT instead of TD's one (#2367) --- torchrl/data/tensor_specs.py | 10 ++++++++++ torchrl/envs/utils.py | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index c1a6d831115..c16dd5f9ec6 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -6,6 +6,7 @@ from __future__ import annotations import abc +import enum import math import warnings from collections.abc import Iterable @@ -81,6 +82,15 @@ ) +# Akin to TD's NO_DEFAULT but won't raise a KeyError when found in a TD or used as default +class _NoDefault(enum.IntEnum): + ZERO = 0 + ONE = 1 + + +NO_DEFAULT_RL = _NoDefault.ONE + + def _default_dtype_and_device( dtype: Union[None, torch.dtype], device: Union[None, str, int, torch.device], diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index ee7649fabe4..de31ac99162 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -48,7 +48,7 @@ from torchrl.data.tensor_specs import ( CompositeSpec, - NO_DEFAULT, + NO_DEFAULT_RL as NO_DEFAULT, TensorSpec, UnboundedContinuousTensorSpec, )