Skip to content

Commit

Permalink
[BugFix] QValue modules and nested action (pytorch#1351)
Browse files Browse the repository at this point in the history
Signed-off-by: Matteo Bettini <matbet@meta.com>
  • Loading branch information
matteobettini authored Jul 5, 2023
1 parent 75a45be commit e09d2b3
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 63 deletions.
62 changes: 61 additions & 1 deletion test/test_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
import torch

from _utils_internal import get_default_devices

from mocking_classes import NestedCountingEnv
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import nn

from torchrl.data import (
BinaryDiscreteTensorSpec,
CompositeSpec,
DiscreteTensorSpec,
MultiOneHotDiscreteTensorSpec,
Expand Down Expand Up @@ -76,6 +78,64 @@ def test_distributional_qvalue_hook_conflicting_spec(self):
):
_process_action_space_spec(None, None)

@pytest.mark.parametrize("nested_action", [True, False])
@pytest.mark.parametrize("batch_size", [(), (32,), (32, 1)])
def test_nested_keys(self, nested_action, batch_size, nested_dim=5):
    """Check action-space/spec resolution and QValueModule construction with nested keys.

    ``_process_action_space_spec(action_space, spec)`` accepts an
    ``action_space`` (a string or a non-composite spec) and a ``spec``.
    This test feeds it both the env's input action spec and the env's leaf
    action spec and verifies that:

    * the returned spec is the one passed as ``spec`` (or the leaf spec when
      only ``action_space`` is given),
    * the resolved action-space string is ``"binary"``,
    * mismatching ``action_space``/``spec`` pairs raise ``ValueError``,
    * passing the env's input action spec as ``action_space`` is rejected
      as a ``CompositeSpec``,
    * a ``QValueModule`` can be built with nested out-keys and that spec.
    """
    env = NestedCountingEnv(
        nest_obs_action=nested_action, batch_size=batch_size, nested_dim=nested_dim
    )
    # The spec stored under "_action_spec" is rejected below as a
    # CompositeSpec, whereas env.action_spec is accepted as action_space.
    action_spec = env._input_spec["_action_spec"]
    leaf_action_spec = env.action_spec

    space_str, spec = _process_action_space_spec(None, action_spec)
    assert spec == action_spec
    assert space_str == "binary"

    space_str, spec = _process_action_space_spec(None, leaf_action_spec)
    assert spec == leaf_action_spec
    assert space_str == "binary"

    space_str, spec = _process_action_space_spec(leaf_action_spec, None)
    assert spec == leaf_action_spec
    assert space_str == "binary"

    space_str, spec = _process_action_space_spec(leaf_action_spec, action_spec)
    assert spec == action_spec  # Spec wins
    assert space_str == "binary"

    space_str, spec = _process_action_space_spec("binary", action_spec)
    assert spec == action_spec
    assert space_str == "binary"

    space_str, spec = _process_action_space_spec("binary", leaf_action_spec)
    assert spec == leaf_action_spec
    assert space_str == "binary"

    # Each mismatching pair gets its own ``pytest.raises`` context: inside a
    # shared context the first raising call would end the block and the
    # second call would never be executed.
    for mismatching_spec in (action_spec, leaf_action_spec):
        with pytest.raises(
            ValueError,
            match="Passing an action_space as a TensorSpec and a spec isn't allowed, unless they match.",
        ):
            _process_action_space_spec(
                BinaryDiscreteTensorSpec(n=1), mismatching_spec
            )

    with pytest.raises(
        ValueError, match="action_space cannot be of type CompositeSpec"
    ):
        _process_action_space_spec(action_spec, None)

    # Construction is the check here: building the module with nested keys
    # and the env's input action spec must not raise.
    QValueModule(
        action_value_key=("data", "action_value"),
        out_keys=[
            env.action_key,
            ("data", "action_value"),
            ("data", "chosen_action_value"),
        ],
        action_space=None,
        spec=action_spec,
    )

@pytest.mark.parametrize(
"action_space, expected_action",
(
Expand Down
16 changes: 11 additions & 5 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,13 @@ def _create_mock_actor(
return module.to(device)
actor = QValueActor(
spec=CompositeSpec(
action=action_spec,
action_value=None,
chosen_action_value=None,
{
"action": action_spec,
"action_value"
if action_value_key is None
else action_value_key: None,
"chosen_action_value": None,
},
shape=[],
),
action_space=action_spec_type,
Expand Down Expand Up @@ -285,8 +289,10 @@ def _create_mock_distributional_actor(
# return module
actor = DistributionalQValueActor(
spec=CompositeSpec(
action=action_spec,
action_value=None,
{
"action": action_spec,
action_value_key: None,
},
shape=[],
),
module=module,
Expand Down
140 changes: 84 additions & 56 deletions torchrl/modules/tensordict_module/actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Sequence, Tuple, Union
import warnings
from typing import Optional, Sequence, Tuple, Union

import torch

Expand All @@ -15,6 +15,7 @@
TensorDictModuleWrapper,
TensorDictSequential,
)
from tensordict.utils import NestedKey
from torch import nn
from torch.distributions import Categorical

Expand Down Expand Up @@ -92,8 +93,8 @@ class Actor(SafeModule):
def __init__(
self,
module: nn.Module,
in_keys: Optional[Sequence[str]] = None,
out_keys: Optional[Sequence[str]] = None,
in_keys: Optional[Sequence[NestedKey]] = None,
out_keys: Optional[Sequence[NestedKey]] = None,
*,
spec: Optional[TensorSpec] = None,
**kwargs,
Expand Down Expand Up @@ -214,8 +215,8 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential):
def __init__(
self,
module: TensorDictModule,
in_keys: Union[str, Sequence[str]],
out_keys: Optional[Sequence[str]] = None,
in_keys: Union[NestedKey, Sequence[NestedKey]],
out_keys: Optional[Sequence[NestedKey]] = None,
*,
spec: Optional[TensorSpec] = None,
**kwargs,
Expand Down Expand Up @@ -295,8 +296,8 @@ class ValueOperator(TensorDictModule):
def __init__(
self,
module: nn.Module,
in_keys: Optional[Sequence[str]] = None,
out_keys: Optional[Sequence[str]] = None,
in_keys: Optional[Sequence[NestedKey]] = None,
out_keys: Optional[Sequence[NestedKey]] = None,
) -> None:

if in_keys is None:
Expand All @@ -321,13 +322,10 @@ class QValueModule(TensorDictModuleBase):
It works with both tensordict and regular tensors.
Args:
action_space (str or TensorSpec, optional): Action space. Must be one of
``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``,
or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`,
:class:`torchrl.data.MultiOneHotDiscreteTensorSpec`,
:class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`).
This is argumets is exclusive with ``spec``, since the ``action_spec``
conditions the action spec.
action_space (str, optional): Action space. Must be one of
``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``.
This argument is exclusive with ``spec``, since ``spec``
conditions the action_space.
action_value_key (str or tuple of str, optional): The input key
representing the action value. Defaults to ``"action_value"``.
out_keys (list of str or tuple of str, optional): The output keys
Expand Down Expand Up @@ -379,13 +377,19 @@ class QValueModule(TensorDictModuleBase):

def __init__(
self,
action_space: Optional[Union[str, TensorSpec]],
action_value_key: Union[List[str], List[Tuple[str]]] = None,
out_keys: Union[List[str], List[Tuple[str]]] = None,
action_space: Optional[str],
action_value_key: Optional[NestedKey] = None,
out_keys: Optional[Sequence[NestedKey]] = None,
var_nums: Optional[int] = None,
spec: Optional[TensorSpec] = None,
safe: bool = False,
):
if isinstance(action_space, TensorSpec):
warnings.warn(
"Using specs in action_space will be deprecated soon,"
" please use the 'spec' argument if you want to provide an action spec",
category=DeprecationWarning,
)
action_space, spec = _process_action_space_spec(action_space, spec)
self.action_space = action_space
self.var_nums = var_nums
Expand Down Expand Up @@ -512,13 +516,10 @@ class DistributionalQValueModule(QValueModule):
https://arxiv.org/pdf/1707.06887.pdf
Args:
action_space (str or TensorSpec, optional): Action space. Must be one of
``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``,
or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`,
:class:`torchrl.data.MultiOneHotDiscreteTensorSpec`,
:class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`).
This is argumets is exclusive with ``spec``, since the ``action_spec``
conditions the action spec.
action_space (str, optional): Action space. Must be one of
``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``.
This argument is exclusive with ``spec``, since ``spec``
conditions the action_space.
support (torch.Tensor): support of the action values.
action_value_key (str or tuple of str, optional): The input key
representing the action value. Defaults to ``"action_value"``.
Expand Down Expand Up @@ -574,10 +575,10 @@ class DistributionalQValueModule(QValueModule):

def __init__(
self,
action_space: str,
action_space: Optional[str],
support: torch.Tensor,
action_value_key: Union[List[str], List[Tuple[str]]] = None,
out_keys: Union[List[str], List[Tuple[str]]] = None,
action_value_key: Optional[NestedKey] = None,
out_keys: Optional[Sequence[NestedKey]] = None,
var_nums: Optional[int] = None,
spec: TensorSpec = None,
safe: bool = False,
Expand Down Expand Up @@ -676,17 +677,27 @@ def _binary(self, value: torch.Tensor) -> torch.Tensor:


def _process_action_space_spec(action_space, spec):
nest_action = False
original_spec = spec
composite_spec = False
if isinstance(spec, CompositeSpec):
# this will break whenever our action is more complex than a single tensor
try:
# this will break whenever our action is more complex than a single tensor
spec = spec["action"]
nest_action = True
if "action" in spec.keys():
_key = "action"
else:
# the first key is the action
for _key in spec.keys(True, True):
if isinstance(_key, tuple) and _key[-1] == "action":
break
else:
raise KeyError
spec = spec[_key]
composite_spec = True
except KeyError:
raise KeyError(
"action could not be found in the spec. Make sure "
"you pass a spec that is either a native action spec or a composite action spec "
"with an 'action' entry. Otherwise, simply remove the spec and use the action_space only."
"with a leaf 'action' entry. Otherwise, simply remove the spec and use the action_space only."
)
if action_space is not None:
if isinstance(action_space, CompositeSpec):
Expand All @@ -713,8 +724,8 @@ def _process_action_space_spec(action_space, spec):
raise ValueError(
"Neither action_space nor spec was defined. The action space cannot be inferred."
)
if nest_action:
spec = CompositeSpec(action=spec)
if composite_spec:
spec = original_spec
return action_space, spec


Expand Down Expand Up @@ -769,9 +780,15 @@ def __init__(
self,
action_space: str,
var_nums: Optional[int] = None,
action_value_key: Union[str, Tuple[str]] = None,
out_keys: Union[List[str], List[Tuple[str]]] = None,
action_value_key: Optional[NestedKey] = None,
out_keys: Optional[Sequence[NestedKey]] = None,
):
if isinstance(action_space, TensorSpec):
warnings.warn(
"Using specs in action_space will be deprecated soon,"
" please use the 'spec' argument if you want to provide an action spec",
category=DeprecationWarning,
)
action_space, _ = _process_action_space_spec(action_space, None)

self.qvalue_model = QValueModule(
Expand Down Expand Up @@ -853,9 +870,15 @@ def __init__(
action_space: str,
support: torch.Tensor,
var_nums: Optional[int] = None,
action_value_key: Union[str, Tuple[str]] = None,
out_keys: Union[List[str], List[Tuple[str]]] = None,
action_value_key: Optional[NestedKey] = None,
out_keys: Optional[Sequence[NestedKey]] = None,
):
if isinstance(action_space, TensorSpec):
warnings.warn(
"Using specs in action_space will be deprecated soon,"
" please use the 'spec' argument if you want to provide an action spec",
category=DeprecationWarning,
)
action_space, _ = _process_action_space_spec(action_space, None)
self.qvalue_model = DistributionalQValueModule(
action_space=action_space,
Expand Down Expand Up @@ -901,13 +924,10 @@ class QValueActor(SafeSequential):
issues. If this value is out of bounds, it is projected back onto the
desired space using the :obj:`TensorSpec.project`
method. Default is ``False``.
action_space (str or TensorSpec, optional): Action space. Must be one of
``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``,
or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`,
:class:`torchrl.data.MultiOneHotDiscreteTensorSpec`,
:class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`).
This is argumets is exclusive with ``spec``, since the ``action_spec``
conditions the action spec.
action_space (str, optional): Action space. Must be one of
``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``.
This argument is exclusive with ``spec``, since ``spec``
conditions the action_space.
action_value_key (str or tuple of str, optional): if the input module
is a :class:`tensordict.nn.TensorDictModuleBase` instance, it must
match one of its output keys. Otherwise, this string represents
Expand Down Expand Up @@ -968,9 +988,15 @@ def __init__(
in_keys=None,
spec=None,
safe=False,
action_space: str = None,
action_space: Optional[str] = None,
action_value_key=None,
):
if isinstance(action_space, TensorSpec):
warnings.warn(
"Using specs in action_space will be deprecated soon,"
" please use the 'spec' argument if you want to provide an action spec",
category=DeprecationWarning,
)
action_space, spec = _process_action_space_spec(action_space, spec)

self.action_space = action_space
Expand Down Expand Up @@ -1050,13 +1076,10 @@ class DistributionalQValueActor(QValueActor):
this value represents the cardinality of each
action component.
support (torch.Tensor): support of the action values.
action_space (str or TensorSpec, optional): Action space. Must be one of
``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``,
or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`,
:class:`torchrl.data.MultiOneHotDiscreteTensorSpec`,
:class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`).
This is argumets is exclusive with ``spec``, since the ``action_spec``
conditions the action spec.
action_space (str, optional): Action space. Must be one of
``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``.
This argument is exclusive with ``spec``, since ``spec``
conditions the action_space.
make_log_softmax (bool, optional): if ``True`` and if the module is not
of type :class:`torchrl.modules.DistributionalDQNnet`, a log-softmax
operation will be applied along dimension -2 of the action value tensor.
Expand Down Expand Up @@ -1102,11 +1125,16 @@ def __init__(
spec=None,
safe=False,
var_nums: Optional[int] = None,
action_space: str = None,
action_space: Optional[str] = None,
action_value_key: str = "action_value",
make_log_softmax: bool = True,
):

if isinstance(action_space, TensorSpec):
warnings.warn(
"Using specs in action_space will be deprecated soon,"
" please use the 'spec' argument if you want to provide an action spec",
category=DeprecationWarning,
)
action_space, spec = _process_action_space_spec(action_space, spec)
self.action_space = action_space
self.action_value_key = action_value_key
Expand Down
Loading

0 comments on commit e09d2b3

Please sign in to comment.