Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] QValue modules and nested action #1351

Merged
merged 13 commits into from
Jul 5, 2023
Prev Previous commit
type hints
Signed-off-by: Matteo Bettini <matbet@meta.com>
  • Loading branch information
matteobettini committed Jul 5, 2023
commit 996cf6b865e8aa1c2c1e4fc36aaea120efb42d2e
31 changes: 16 additions & 15 deletions torchrl/modules/tensordict_module/actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import List, Optional, Sequence, Tuple, Union
from typing import Optional, Sequence, Tuple, Union

import torch

Expand All @@ -15,6 +15,7 @@
TensorDictModuleWrapper,
TensorDictSequential,
)
from tensordict.utils import NestedKey
from torch import nn
from torch.distributions import Categorical

Expand Down Expand Up @@ -92,8 +93,8 @@ class Actor(SafeModule):
def __init__(
self,
module: nn.Module,
in_keys: Optional[Sequence[str]] = None,
out_keys: Optional[Sequence[str]] = None,
in_keys: Optional[Sequence[NestedKey]] = None,
out_keys: Optional[Sequence[NestedKey]] = None,
*,
spec: Optional[TensorSpec] = None,
**kwargs,
Expand Down Expand Up @@ -214,8 +215,8 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential):
def __init__(
self,
module: TensorDictModule,
in_keys: Union[str, Sequence[str]],
out_keys: Optional[Sequence[str]] = None,
in_keys: Union[NestedKey, Sequence[NestedKey]],
out_keys: Optional[Sequence[NestedKey]] = None,
*,
spec: Optional[TensorSpec] = None,
**kwargs,
Expand Down Expand Up @@ -295,8 +296,8 @@ class ValueOperator(TensorDictModule):
def __init__(
self,
module: nn.Module,
in_keys: Optional[Sequence[str]] = None,
out_keys: Optional[Sequence[str]] = None,
in_keys: Optional[Sequence[NestedKey]] = None,
out_keys: Optional[Sequence[NestedKey]] = None,
) -> None:

if in_keys is None:
Expand Down Expand Up @@ -377,8 +378,8 @@ class QValueModule(TensorDictModuleBase):
def __init__(
self,
action_space: Optional[str],
action_value_key: Union[List[str], List[Tuple[str]]] = None,
out_keys: Union[List[str], List[Tuple[str]]] = None,
action_value_key: Optional[NestedKey] = None,
out_keys: Optional[Sequence[NestedKey]] = None,
var_nums: Optional[int] = None,
spec: Optional[TensorSpec] = None,
safe: bool = False,
Expand Down Expand Up @@ -576,8 +577,8 @@ def __init__(
self,
action_space: Optional[str],
support: torch.Tensor,
action_value_key: Union[List[str], List[Tuple[str]]] = None,
out_keys: Union[List[str], List[Tuple[str]]] = None,
action_value_key: Optional[NestedKey] = None,
out_keys: Optional[Sequence[NestedKey]] = None,
var_nums: Optional[int] = None,
spec: TensorSpec = None,
safe: bool = False,
Expand Down Expand Up @@ -779,8 +780,8 @@ def __init__(
self,
action_space: str,
var_nums: Optional[int] = None,
action_value_key: Union[str, Tuple[str]] = None,
out_keys: Union[List[str], List[Tuple[str]]] = None,
action_value_key: Optional[NestedKey] = None,
out_keys: Optional[Sequence[NestedKey]] = None,
):
if isinstance(action_space, TensorSpec):
warnings.warn(
Expand Down Expand Up @@ -869,8 +870,8 @@ def __init__(
action_space: str,
support: torch.Tensor,
var_nums: Optional[int] = None,
action_value_key: Union[str, Tuple[str]] = None,
out_keys: Union[List[str], List[Tuple[str]]] = None,
action_value_key: Optional[NestedKey] = None,
out_keys: Optional[Sequence[NestedKey]] = None,
):
if isinstance(action_space, TensorSpec):
warnings.warn(
Expand Down