Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] QValue modules and nested action #1351

Merged
merged 13 commits into from
Jul 5, 2023
Prev Previous commit
Next Next commit
warning for deprecation action_space spec
Signed-off-by: Matteo Bettini <matbet@meta.com>
  • Loading branch information
matteobettini committed Jul 4, 2023
commit 6395b79e66258d44ca5c31c940dbb3f39c472fe2
85 changes: 51 additions & 34 deletions torchrl/modules/tensordict_module/actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from typing import List, Optional, Sequence, Tuple, Union

import torch
Expand Down Expand Up @@ -321,13 +321,10 @@ class QValueModule(TensorDictModuleBase):
It works with both tensordict and regular tensors.

Args:
action_space (str or TensorSpec, optional): Action space. Must be one of
``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``,
or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`,
:class:`torchrl.data.MultiOneHotDiscreteTensorSpec`,
:class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`).
This is argumets is exclusive with ``spec``, since the ``action_spec``
conditions the action spec.
action_space (str, optional): Action space. Must be one of
``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``.
This is argumets is exclusive with ``spec``, since ``spec``
conditions the action_space.
action_value_key (str or tuple of str, optional): The input key
representing the action value. Defaults to ``"action_value"``.
out_keys (list of str or tuple of str, optional): The output keys
Expand Down Expand Up @@ -379,13 +376,19 @@ class QValueModule(TensorDictModuleBase):

def __init__(
self,
action_space: Optional[Union[str, TensorSpec]],
action_space: Optional[str],
action_value_key: Union[List[str], List[Tuple[str]]] = None,
out_keys: Union[List[str], List[Tuple[str]]] = None,
var_nums: Optional[int] = None,
spec: Optional[TensorSpec] = None,
safe: bool = False,
):
if isinstance(action_space, TensorSpec):
warnings.warn(
"Using specs in action_space will be deprecated soon,"
" please use the 'spec' argument if you want to provide an action spec",
category=DeprecationWarning,
)
action_space, spec = _process_action_space_spec(action_space, spec)
self.action_space = action_space
self.var_nums = var_nums
Expand Down Expand Up @@ -512,13 +515,10 @@ class DistributionalQValueModule(QValueModule):
https://arxiv.org/pdf/1707.06887.pdf

Args:
action_space (str or TensorSpec, optional): Action space. Must be one of
``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``,
or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`,
:class:`torchrl.data.MultiOneHotDiscreteTensorSpec`,
:class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`).
This is argumets is exclusive with ``spec``, since the ``action_spec``
conditions the action spec.
action_space (str, optional): Action space. Must be one of
``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``.
This is argumets is exclusive with ``spec``, since ``spec``
conditions the action_space.
support (torch.Tensor): support of the action values.
action_value_key (str or tuple of str, optional): The input key
representing the action value. Defaults to ``"action_value"``.
Expand Down Expand Up @@ -574,7 +574,7 @@ class DistributionalQValueModule(QValueModule):

def __init__(
self,
action_space: str,
action_space: Optional[str],
support: torch.Tensor,
action_value_key: Union[List[str], List[Tuple[str]]] = None,
out_keys: Union[List[str], List[Tuple[str]]] = None,
Expand Down Expand Up @@ -782,6 +782,12 @@ def __init__(
action_value_key: Union[str, Tuple[str]] = None,
out_keys: Union[List[str], List[Tuple[str]]] = None,
):
if isinstance(action_space, TensorSpec):
warnings.warn(
"Using specs in action_space will be deprecated soon,"
" please use the 'spec' argument if you want to provide an action spec",
category=DeprecationWarning,
)
action_space, _ = _process_action_space_spec(action_space, None)

self.qvalue_model = QValueModule(
Expand Down Expand Up @@ -866,6 +872,12 @@ def __init__(
action_value_key: Union[str, Tuple[str]] = None,
out_keys: Union[List[str], List[Tuple[str]]] = None,
):
if isinstance(action_space, TensorSpec):
warnings.warn(
"Using specs in action_space will be deprecated soon,"
" please use the 'spec' argument if you want to provide an action spec",
category=DeprecationWarning,
)
action_space, _ = _process_action_space_spec(action_space, None)
self.qvalue_model = DistributionalQValueModule(
action_space=action_space,
Expand Down Expand Up @@ -911,13 +923,10 @@ class QValueActor(SafeSequential):
issues. If this value is out of bounds, it is projected back onto the
desired space using the :obj:`TensorSpec.project`
method. Default is ``False``.
action_space (str or TensorSpec, optional): Action space. Must be one of
``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``,
or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`,
:class:`torchrl.data.MultiOneHotDiscreteTensorSpec`,
:class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`).
This is argumets is exclusive with ``spec``, since the ``action_spec``
conditions the action spec.
action_space (str, optional): Action space. Must be one of
``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``.
This is argumets is exclusive with ``spec``, since ``spec``
conditions the action_space.
action_value_key (str or tuple of str, optional): if the input module
is a :class:`tensordict.nn.TensorDictModuleBase` instance, it must
match one of its output keys. Otherwise, this string represents
Expand Down Expand Up @@ -978,9 +987,15 @@ def __init__(
in_keys=None,
spec=None,
safe=False,
action_space: str = None,
action_space: Optional[str] = None,
action_value_key=None,
):
if isinstance(action_space, TensorSpec):
warnings.warn(
"Using specs in action_space will be deprecated soon,"
" please use the 'spec' argument if you want to provide an action spec",
category=DeprecationWarning,
)
action_space, spec = _process_action_space_spec(action_space, spec)

self.action_space = action_space
Expand Down Expand Up @@ -1060,13 +1075,10 @@ class DistributionalQValueActor(QValueActor):
this value represents the cardinality of each
action component.
support (torch.Tensor): support of the action values.
action_space (str or TensorSpec, optional): Action space. Must be one of
``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``,
or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`,
:class:`torchrl.data.MultiOneHotDiscreteTensorSpec`,
:class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`).
This is argumets is exclusive with ``spec``, since the ``action_spec``
conditions the action spec.
action_space (str, optional): Action space. Must be one of
``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``.
This is argumets is exclusive with ``spec``, since ``spec``
conditions the action_space.
make_log_softmax (bool, optional): if ``True`` and if the module is not
of type :class:`torchrl.modules.DistributionalDQNnet`, a log-softmax
operation will be applied along dimension -2 of the action value tensor.
Expand Down Expand Up @@ -1112,11 +1124,16 @@ def __init__(
spec=None,
safe=False,
var_nums: Optional[int] = None,
action_space: str = None,
action_space: Optional[str] = None,
action_value_key: str = "action_value",
make_log_softmax: bool = True,
):

if isinstance(action_space, TensorSpec):
warnings.warn(
"Using specs in action_space will be deprecated soon,"
" please use the 'spec' argument if you want to provide an action spec",
category=DeprecationWarning,
)
action_space, spec = _process_action_space_spec(action_space, spec)
self.action_space = action_space
self.action_value_key = action_value_key
Expand Down