From 32339daf0f32baf07bff825718ba8c42480fcf3f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 25 Apr 2023 21:20:13 +0100 Subject: [PATCH] [Refactor] Refactor DQN (#1085) --- test/test_actors.py | 85 ++++++++++- test/test_cost.py | 41 +++--- torchrl/modules/tensordict_module/actors.py | 147 +++++++++++++++----- torchrl/modules/utils/__init__.py | 1 + torchrl/modules/utils/utils.py | 36 +++++ torchrl/objectives/dqn.py | 39 +++++- 6 files changed, 290 insertions(+), 59 deletions(-) create mode 100644 torchrl/modules/utils/utils.py diff --git a/test/test_actors.py b/test/test_actors.py index 9dbe7ab0733..a2dfa37fdaf 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -12,9 +12,15 @@ from tensordict.nn import TensorDictModule from torch import nn -from torchrl.data import DiscreteTensorSpec, OneHotDiscreteTensorSpec +from torchrl.data import ( + CompositeSpec, + DiscreteTensorSpec, + MultiOneHotDiscreteTensorSpec, + OneHotDiscreteTensorSpec, +) from torchrl.modules import MLP, SafeModule from torchrl.modules.tensordict_module.actors import ( + _process_action_space_spec, ActorValueOperator, DistributionalQValueActor, DistributionalQValueHook, @@ -29,14 +35,44 @@ class TestQValue: def test_qvalue_hook_wrong_action_space(self): - with pytest.raises(ValueError) as exc: + with pytest.raises( + ValueError, match="action_space was not specified/not compatible" + ): QValueHook(action_space="wrong_value") - assert "action_space must be one of" in str(exc.value) def test_distributional_qvalue_hook_wrong_action_space(self): - with pytest.raises(ValueError) as exc: + with pytest.raises( + ValueError, match="action_space was not specified/not compatible" + ): DistributionalQValueHook(action_space="wrong_value", support=None) - assert "action_space must be one of" in str(exc.value) + + def test_distributional_qvalue_hook_conflicting_spec(self): + spec = OneHotDiscreteTensorSpec(3) + _process_action_space_spec("one-hot", spec) + _process_action_space_spec("one_hot", spec) + _process_action_space_spec("one_hot", None) + _process_action_space_spec(None, spec) + with pytest.raises( + ValueError, match="The action spec and the action space do not match" + ): + _process_action_space_spec("multi-one-hot", spec) + spec = MultiOneHotDiscreteTensorSpec([3, 3]) + _process_action_space_spec("multi-one-hot", spec) + _process_action_space_spec(spec, spec) + with pytest.raises( + ValueError, match="Passing an action_space as a TensorSpec and a spec" + ): + _process_action_space_spec(OneHotDiscreteTensorSpec(3), spec) + with pytest.raises( + ValueError, match="action_space cannot be of type CompositeSpec" + ): + _process_action_space_spec(CompositeSpec(), spec) + with pytest.raises(KeyError, match="action could not be found in the spec"): + _process_action_space_spec(None, CompositeSpec()) + with pytest.raises( + ValueError, match="Neither action_space nor spec was defined" + ): + _process_action_space_spec(None, None) @pytest.mark.parametrize( "action_space, expected_action", @@ -406,6 +442,45 @@ def make_net(): assert (action.sum(-1) == 1).all() +@pytest.mark.parametrize( + "spec", [None, OneHotDiscreteTensorSpec(3), MultiOneHotDiscreteTensorSpec([3, 2])] +) +@pytest.mark.parametrize( + "action_space", [None, "one-hot", "one_hot", "mult-one-hot", "mult_one_hot"] +) +def test_qvalactor_construct( + spec, + action_space, +): + kwargs = {} + if spec is not None: + kwargs["spec"] = spec + if action_space is not None: + kwargs["action_space"] = action_space + kwargs["module"] = TensorDictModule( + lambda x: 
x, in_keys=["x"], out_keys=["action_value"] + ) + if spec is None and action_space is None: + with pytest.raises( + ValueError, match="Neither action_space nor spec was defined" + ): + QValueActor(**kwargs) + return + if ( + type(spec) is MultiOneHotDiscreteTensorSpec + and action_space not in ("mult-one-hot", "mult_one_hot", None) + ) or ( + type(spec) is OneHotDiscreteTensorSpec + and action_space not in ("one-hot", "one_hot", None) + ): + with pytest.raises( + ValueError, match="The action spec and the action space do not match" + ): + QValueActor(**kwargs) + return + QValueActor(**kwargs) + + @pytest.mark.parametrize("device", get_available_devices()) def test_value_based_policy_categorical(device): torch.manual_seed(0) diff --git a/test/test_cost.py b/test/test_cost.py index b73382a1711..d475afe37ce 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -145,10 +145,10 @@ def _create_mock_actor( action_spec = OneHotDiscreteTensorSpec(action_dim) elif action_spec_type == "categorical": action_spec = DiscreteTensorSpec(action_dim) - elif action_spec_type == "nd_bounded": - action_spec = BoundedTensorSpec( - -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) - ) + # elif action_spec_type == "nd_bounded": + # action_spec = BoundedTensorSpec( + # -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) + # ) else: raise ValueError(f"Wrong {action_spec_type}") @@ -162,6 +162,7 @@ def _create_mock_actor( chosen_action_value=None, shape=[], ), + action_space=action_spec_type, module=module, ).to(device) return actor @@ -178,8 +179,12 @@ def _create_mock_distributional_actor( is_nn_module=False, ): # Actor + var_nums = None if action_spec_type == "mult_one_hot": - action_spec = MultiOneHotDiscreteTensorSpec([atoms] * action_dim) + action_spec = MultiOneHotDiscreteTensorSpec( + [action_dim // 2, action_dim // 2] + ) + var_nums = action_spec.nvec elif action_spec_type == "one_hot": action_spec = OneHotDiscreteTensorSpec(action_dim) elif action_spec_type == "categorical": @@ -201,9 +206,8 @@ def _create_mock_distributional_actor( ), module=module, support=support, - action_space="categorical" - if isinstance(action_spec, DiscreteTensorSpec) - else "one_hot", + action_space=action_spec_type, + var_nums=var_nums, ) return actor @@ -230,7 +234,7 @@ def _create_mock_data_dqn( if action_spec_type == "categorical": action_value = torch.max(action_value, -1, keepdim=True)[0] - action = torch.argmax(action, -1, keepdim=True) + action = torch.argmax(action, -1, keepdim=False) reward = torch.randn(batch, 1) done = torch.zeros(batch, 1, dtype=torch.bool) td = TensorDict( @@ -274,13 +278,16 @@ def _create_seq_mock_data_dqn( action_value = torch.randn(batch, T, action_dim, device=device) action = (action_value == action_value.max(-1, True)[0]).to(torch.long) - if action_spec_type == "categorical": - action_value = torch.max(action_value, -1, keepdim=True)[0] - action = torch.argmax(action, -1, keepdim=True) # action_value = action_value.unsqueeze(-1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) + if action_spec_type == "categorical": + action_value = torch.max(action_value, -1, keepdim=True)[0] + action = torch.argmax(action, -1, keepdim=False) + action = action.masked_fill_(~mask, 0.0) + else: + action = action.masked_fill_(~mask.unsqueeze(-1), 0.0) td = TensorDict( batch_size=(batch, T), source={ @@ -291,7 +298,7 @@ def _create_seq_mock_data_dqn( "reward": 
reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
                },
                "collector": {"mask": mask},
-                "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
+                "action": action,
                 "action_value": action_value.masked_fill_(~mask.unsqueeze(-1), 0.0),
             },
         )
@@ -299,9 +306,7 @@
 
     @pytest.mark.parametrize("delay_value", (False, True))
     @pytest.mark.parametrize("device", get_available_devices())
-    @pytest.mark.parametrize(
-        "action_spec_type", ("nd_bounded", "one_hot", "categorical")
-    )
+    @pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
     def test_dqn(self, delay_value, device, action_spec_type, td_est):
         torch.manual_seed(self.seed)
@@ -344,9 +349,7 @@ def test_dqn(self, delay_value, device, action_spec_type, td_est):
     @pytest.mark.parametrize("n", range(4))
     @pytest.mark.parametrize("delay_value", (False, True))
     @pytest.mark.parametrize("device", get_available_devices())
-    @pytest.mark.parametrize(
-        "action_spec_type", ("nd_bounded", "one_hot", "categorical")
-    )
+    @pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
     def test_dqn_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9):
         torch.manual_seed(self.seed)
         actor = self._create_mock_actor(
diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py
index 1c97473aaec..89445c8ef0d 100644
--- a/torchrl/modules/tensordict_module/actors.py
+++ b/torchrl/modules/tensordict_module/actors.py
@@ -23,6 +23,7 @@
     SafeProbabilisticTensorDictSequential,
 )
 from torchrl.modules.tensordict_module.sequence import SafeSequential
+from torchrl.modules.utils.utils import _find_action_space
 
 
 class Actor(SafeModule):
@@ -317,18 +318,24 @@ class QValueModule(TensorDictModuleBase):
     It works with both tensordict and regular tensors.
 
     Args:
-        action_space (str): Action space. Must be one of
-            ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``.
+        action_space (str or TensorSpec, optional): Action space. Must be one of
+            ``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``,
+            or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`,
+            :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`,
+            :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`).
+            This argument is exclusive with ``spec``, since ``spec``
+            conditions the action space.
         action_value_key (str or tuple of str, optional): The input key
             representing the action value. Defaults to ``"action_value"``.
         out_keys (list of str or tuple of str, optional): The output keys
             representing the actions, action values and chosen action
            value. Defaults to ``["action", "action_value", "chosen_action_value"]``.
-        var_nums (int, optional): if ``action_space = "mult_one_hot"``,
+        var_nums (int, optional): if ``action_space = "mult-one-hot"``,
            this value represents the cardinality of each
            action component.
         spec (TensorSpec, optional): if provided, the specs of the action (and/or
-            other outputs).
+            other outputs). This is exclusive with ``action_space``, as the spec
+            conditions the action space.
         safe (bool): if ``True``, the value of the output is checked against the
            input spec. Out-of-domain sampling can occur because of exploration policies
            or numerical under/overflow issues.
@@ -369,13 +376,14 @@ class QValueModule(TensorDictModuleBase):
 
     def __init__(
         self,
-        action_space: str,
+        action_space: Optional[Union[str, TensorSpec]],
         action_value_key: Union[List[str], List[Tuple[str]]] = None,
         out_keys: Union[List[str], List[Tuple[str]]] = None,
         var_nums: Optional[int] = None,
-        spec: TensorSpec = None,
+        spec: Optional[TensorSpec] = None,
         safe: bool = False,
     ):
+        action_space, spec = _process_action_space_spec(action_space, spec)
         self.action_space = action_space
         self.var_nums = var_nums
         self.action_func_mapping = {
@@ -389,7 +397,7 @@ def __init__(
         }
         if action_space not in self.action_func_mapping:
             raise ValueError(
-                f"action_space must be one of {list(self.action_func_mapping.keys())}"
+                f"action_space must be one of {list(self.action_func_mapping.keys())}, got {action_space}"
             )
         if action_value_key is None:
             action_value_key = "action_value"
@@ -401,6 +409,9 @@ def __init__(
                 f"Expected the action-value key to be '{action_value_key}' but got {out_keys[1]} instead."
             )
         self.out_keys = out_keys
+        action_key = out_keys[0]
+        if not isinstance(spec, CompositeSpec):
+            spec = CompositeSpec({action_key: spec})
         super().__init__()
         self.register_spec(safe=safe, spec=spec)
 
@@ -476,10 +487,11 @@ def _default_action_value(
     def _categorical_action_value(
         values: torch.Tensor, action: torch.Tensor
     ) -> torch.Tensor:
-        if len(values.shape) == 1:
-            return values[action].unsqueeze(-1)
-        batch_size = values.size(0)
-        return values[range(batch_size), action].unsqueeze(-1)
+        return values.gather(-1, action.unsqueeze(-1))
+        # if values.ndim == 1:
+        #     return values[action].unsqueeze(-1)
+        # batch_size = values.size(0)
+        # return values[range(batch_size), action].unsqueeze(-1)
 
 
 class DistributionalQValueModule(QValueModule):
@@ -497,19 +509,25 @@ class DistributionalQValueModule(QValueModule):
     https://arxiv.org/pdf/1707.06887.pdf
 
     Args:
-        action_space (str): Action space. Must be one of
-            ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``.
+        action_space (str or TensorSpec, optional): Action space. Must be one of
+            ``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``,
+            or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`,
+            :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`,
+            :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`).
+            This argument is exclusive with ``spec``, since ``spec``
+            conditions the action space.
         support (torch.Tensor): support of the action values.
         action_value_key (str or tuple of str, optional): The input key
            representing the action value. Defaults to ``"action_value"``.
         out_keys (list of str or tuple of str, optional): The output keys
            representing the actions and action values.
            Defaults to ``["action", "action_value"]``.
-        var_nums (int, optional): if ``action_space = "mult_one_hot"``,
+        var_nums (int, optional): if ``action_space = "mult-one-hot"``,
            this value represents the cardinality of each
            action component.
         spec (TensorSpec, optional): if provided, the specs of the action (and/or
-            other outputs).
+            other outputs). This is exclusive with ``action_space``, as the spec
+            conditions the action space.
         safe (bool): if ``True``, the value of the output is checked against the
            input spec. Out-of-domain sampling can occur because of exploration policies
            or numerical under/overflow issues.
@@ -654,6 +672,49 @@ def _binary(self, value: torch.Tensor) -> torch.Tensor: ) +def _process_action_space_spec(action_space, spec): + nest_action = False + if isinstance(spec, CompositeSpec): + try: + # this will break whenever our action is more complex than a single tensor + spec = spec["action"] + nest_action = True + except KeyError: + raise KeyError( + "action could not be found in the spec. Make sure " + "you pass a spec that is either a native action spec or a composite action spec " + "with an 'action' entry. Otherwise, simply remove the spec and use the action_space only." + ) + if action_space is not None: + if isinstance(action_space, CompositeSpec): + raise ValueError("action_space cannot be of type CompositeSpec.") + if ( + spec is not None + and isinstance(action_space, TensorSpec) + and action_space is not spec + ): + raise ValueError( + "Passing an action_space as a TensorSpec and a spec isn't allowed, unless they match." + ) + if isinstance(action_space, TensorSpec): + spec = action_space + action_space = _find_action_space(action_space) + # check that the spec and action_space match + if spec is not None and _find_action_space(spec) != action_space: + raise ValueError( + f"The action spec and the action space do not match: got action_space={action_space} and spec={spec}." + ) + elif spec is not None: + action_space = _find_action_space(spec) + else: + raise ValueError( + "Neither action_space nor spec was defined. The action space cannot be inferred." + ) + if nest_action: + spec = CompositeSpec(action=spec) + return action_space, spec + + class QValueHook: """Q-Value hook for Q-value policies. @@ -664,8 +725,8 @@ class QValueHook: Args: action_space (str): Action space. Must be one of - ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``. - var_nums (int, optional): if ``action_space = "mult_one_hot"``, + ``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``. + var_nums (int, optional): if ``action_space = "mult-one-hot"``, this value represents the cardinality of each action component. action_value_key (str or tuple of str, optional): to be used when hooked on @@ -708,6 +769,8 @@ def __init__( action_value_key: Union[str, Tuple[str]] = None, out_keys: Union[List[str], List[Tuple[str]]] = None, ): + action_space, _ = _process_action_space_spec(action_space, None) + self.qvalue_model = QValueModule( action_space=action_space, var_nums=var_nums, @@ -740,9 +803,9 @@ class DistributionalQValueHook(QValueHook): Args: action_space (str): Action space. Must be one of - ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``. + ``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``. support (torch.Tensor): support of the action values. - var_nums (int, optional): if ``action_space = "mult_one_hot"``, this + var_nums (int, optional): if ``action_space = "mult-one-hot"``, this value represents the cardinality of each action component. @@ -790,6 +853,7 @@ def __init__( action_value_key: Union[str, Tuple[str]] = None, out_keys: Union[List[str], List[Tuple[str]]] = None, ): + action_space, _ = _process_action_space_spec(action_space, None) self.qvalue_model = DistributionalQValueModule( action_space=action_space, var_nums=var_nums, @@ -816,6 +880,8 @@ class QValueActor(SafeSequential): with :class:`tensordict.nn.TensorDictModuleBase`, it will be wrapped in a :class:`tensordict.nn.TensorDictModule` with ``in_keys`` indicated by the following keyword argument. 
+
+    Keyword Args:
         in_keys (iterable of str, optional): If the class provided is not
             compatible with :class:`tensordict.nn.TensorDictModuleBase`, this
             list of keys indicates what observations need to be passed to the
@@ -832,9 +898,13 @@ class QValueActor(SafeSequential):
             issues. If this value is out of bounds, it is projected back onto the
             desired space using the :obj:`TensorSpec.project` method. Default is
             ``False``.
-        action_space (str, optional): The action space to be considered.
-            Must be one of
-            ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``.
+        action_space (str or TensorSpec, optional): Action space. Must be one of
+            ``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``,
+            or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`,
+            :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`,
+            :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`).
+            This argument is exclusive with ``spec``, since ``spec``
+            conditions the action space.
         action_value_key (str or tuple of str, optional): if the input module
             is a :class:`tensordict.nn.TensorDictModuleBase` instance, it must
             match one of its output keys. Otherwise, this string represents
@@ -891,12 +961,15 @@ class QValueActor(SafeSequential):
     def __init__(
         self,
         module,
+        *,
         in_keys=None,
         spec=None,
         safe=False,
-        action_space: str = "one_hot",
+        action_space: Optional[Union[str, TensorSpec]] = None,
         action_value_key=None,
     ):
+        action_space, spec = _process_action_space_spec(action_space, spec)
+
         self.action_space = action_space
         self.action_value_key = action_value_key
         if action_value_key is None:
@@ -952,16 +1025,13 @@ class DistributionalQValueActor(QValueActor):
     operation is applied to the action value tensor along dimension ``-2``.
     This can be deactivated by turning off the ``make_log_softmax``
     keyword argument.
+
+    Keyword Args:
         in_keys (iterable of str, optional): keys to be read from input tensordict
             and passed to the module. If it
             contains more than one element, the values will be passed in the
             order given by the in_keys iterable.
             Defaults to ``["observation"]``.
-        out_keys (iterable of str): keys to be written to the input tensordict.
-            The length of out_keys must match the
-            number of tensors returned by the embedded module. Using "_" as a
-            key avoid writing tensor to output.
-            Defaults to ``["action"]``.
         spec (TensorSpec, optional): Keyword-only argument.
             Specs of the output tensor. If the module outputs multiple output tensors,
             spec characterize the space of the first output tensor.
@@ -970,13 +1040,20 @@ class DistributionalQValueActor(QValueActor):
         safe (bool): Keyword-only argument.
             If ``True``, the value of the output is checked against the
             input spec. Out-of-domain sampling can occur because of exploration policies
             or numerical under/overflow issues. If this value is out of bounds, it is
             projected back onto the desired space using the :obj:`TensorSpec.project`
             method. Default is ``False``.
+        var_nums (int, optional): if ``action_space = "mult-one-hot"``,
+            this value represents the cardinality of each
+            action component.
         support (torch.Tensor): support of the action values.
-        action_space (str, optional): The action space to be considered.
-            Must be one of
-            ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``.
+        action_space (str or TensorSpec, optional): Action space. Must be one of
+            ``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``,
+            or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`,
+            :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`,
+            :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`).
+            This argument is exclusive with ``spec``, since ``spec``
+            conditions the action space.
         make_log_softmax (bool, optional): if ``True`` and if the module is not
             of type :class:`torchrl.modules.DistributionalDQNnet`, a log-softmax
             operation will be applied along dimension -2 of the action value tensor.
@@ -1021,10 +1098,13 @@ def __init__(
         in_keys=None,
         spec=None,
         safe=False,
-        action_space: str = "one_hot",
+        var_nums: Optional[int] = None,
+        action_space: Optional[Union[str, TensorSpec]] = None,
         action_value_key: str = "action_value",
         make_log_softmax: bool = True,
     ):
+        action_space, spec = _process_action_space_spec(action_space, spec)
+
         self.action_space = action_space
         self.action_value_key = action_value_key
         out_keys = [
@@ -1059,6 +1139,7 @@ def __init__(
             safe=safe,
             action_space=action_space,
             support=support,
+            var_nums=var_nums,
         )
         self.make_log_softmax = make_log_softmax
         if make_log_softmax and not isinstance(module, DistributionalDQNnet):
diff --git a/torchrl/modules/utils/__init__.py b/torchrl/modules/utils/__init__.py
index ef430b85391..b9b641e23d5 100644
--- a/torchrl/modules/utils/__init__.py
+++ b/torchrl/modules/utils/__init__.py
@@ -8,6 +8,7 @@
 import torch
 from packaging import version
 
+
 if version.parse(torch.__version__) >= version.parse("1.12.0"):
     from torch.nn.parameter import _disabled_torch_function_impl, _ParameterMeta
 else:
diff --git a/torchrl/modules/utils/utils.py b/torchrl/modules/utils/utils.py
new file mode 100644
index 00000000000..12f226acc62
--- /dev/null
+++ b/torchrl/modules/utils/utils.py
@@ -0,0 +1,36 @@
+from torchrl.data.tensor_specs import (
+    BinaryDiscreteTensorSpec,
+    CompositeSpec,
+    DiscreteTensorSpec,
+    MultiOneHotDiscreteTensorSpec,
+    OneHotDiscreteTensorSpec,
+    TensorSpec,
+)
+
+ACTION_SPACE_MAP = {}
+ACTION_SPACE_MAP[OneHotDiscreteTensorSpec] = "one_hot"
+ACTION_SPACE_MAP[MultiOneHotDiscreteTensorSpec] = "mult_one_hot"
+ACTION_SPACE_MAP[BinaryDiscreteTensorSpec] = "binary"
+ACTION_SPACE_MAP[DiscreteTensorSpec] = "categorical"
+ACTION_SPACE_MAP["one_hot"] = "one_hot"
+ACTION_SPACE_MAP["one-hot"] = "one_hot"
+ACTION_SPACE_MAP["mult_one_hot"] = "mult_one_hot"
+ACTION_SPACE_MAP["mult-one-hot"] = "mult_one_hot"
+ACTION_SPACE_MAP["multi_one_hot"] = "mult_one_hot"
+ACTION_SPACE_MAP["multi-one-hot"] = "mult_one_hot"
+ACTION_SPACE_MAP["binary"] = "binary"
+ACTION_SPACE_MAP["categorical"] = "categorical"
+
+
+def _find_action_space(action_space):
+    if isinstance(action_space, TensorSpec):
+        if isinstance(action_space, CompositeSpec):
+            action_space = action_space["action"]
+        action_space = type(action_space)
+    try:
+        action_space = ACTION_SPACE_MAP[action_space]
+    except KeyError:
+        raise ValueError(
+            f"action_space was not specified/not compatible and could not be retrieved from the value network. Got action_space={action_space}."
+ ) + return action_space diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 70957785fa7..d2370d27c5f 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -8,11 +8,17 @@ import torch from tensordict import TensorDict, TensorDictBase from torch import nn +from torchrl.data.tensor_specs import TensorSpec from torchrl.envs.utils import step_mdp -from torchrl.modules import DistributionalQValueActor, QValueActor +from torchrl.modules.tensordict_module.actors import ( + DistributionalQValueActor, + QValueActor, +) from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible +from ..modules.utils.utils import _find_action_space + from .common import LossModule from .utils import ( _GAMMA_LMBDA_DEPREC_WARNING, @@ -29,9 +35,24 @@ class DQNLoss(LossModule): Args: value_network (QValueActor or nn.Module): a Q value operator. + + Keyword Args: loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". - delay_value (bool, optional): whether to duplicate the value network into a new target value network to + priority_key (str, optional): the key at which priority is assumed to + be stored within TensorDicts added to this ReplayBuffer. + This is to be used when the sampler is of type + :class:`~torchrl.data.PrioritizedSampler`. + Defaults to ``"td_error"``. + delay_value (bool, optional): whether to duplicate the value network + into a new target value network to create a double DQN. Default is ``False``. + action_space (str or TensorSpec, optional): Action space. Must be one of + ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``, + or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`, + :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`, + :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`). + If not provided, an attempt to retrieve it from the value network + will be made. """ @@ -45,6 +66,7 @@ def __init__( priority_key: str = "td_error", delay_value: bool = False, gamma: float = None, + action_space: Union[str, TensorSpec] = None, ) -> None: super().__init__() @@ -63,7 +85,18 @@ def __init__( self.loss_function = loss_function self.priority_key = priority_key - self.action_space = self.value_network.action_space + if action_space is None: + # infer from value net + try: + action_space = value_network.spec + except AttributeError: + # let's try with action_space then + pass + try: + action_space = self.value_network.action_space + except AttributeError: + raise ValueError(self.ACTION_SPEC_ERROR) + self.action_space = _find_action_space(action_space) if gamma is not None: warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING)