diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index 13adacd5868..8b71d00e746 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -23,11 +23,11 @@ from torchrl.collectors import ( MultiaSyncDataCollector, MultiSyncDataCollector, - RandomPolicy, SyncDataCollector, ) from torchrl.envs import EnvCreator, GymEnv, ParallelEnv from torchrl.envs.libs.gym import gym_backend as gym_bc, set_gym_backend +from torchrl.envs.utils import RandomPolicy if __name__ == "__main__": avail_devices = ("cpu",) diff --git a/benchmarks/test_collectors_benchmark.py b/benchmarks/test_collectors_benchmark.py index 1e9634f643f..1bdd26c0746 100644 --- a/benchmarks/test_collectors_benchmark.py +++ b/benchmarks/test_collectors_benchmark.py @@ -11,10 +11,10 @@ from torchrl.collectors.collectors import ( MultiaSyncDataCollector, MultiSyncDataCollector, - RandomPolicy, ) from torchrl.envs import EnvCreator, GymEnv, StepCounter, TransformedEnv from torchrl.envs.libs.dm_control import DMControlEnv +from torchrl.envs.utils import RandomPolicy def single_collector_setup(): diff --git a/examples/distributed/collectors/multi_nodes/delayed_dist.py b/examples/distributed/collectors/multi_nodes/delayed_dist.py index 9bf17b76c10..b140ee7bc67 100644 --- a/examples/distributed/collectors/multi_nodes/delayed_dist.py +++ b/examples/distributed/collectors/multi_nodes/delayed_dist.py @@ -114,9 +114,9 @@ def main(): import gym from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector - from torchrl.collectors.collectors import RandomPolicy from torchrl.data import BoundedTensorSpec from torchrl.envs.libs.gym import GymEnv, set_gym_backend + from torchrl.envs.utils import RandomPolicy collector_class = SyncDataCollector if num_workers == 1 else MultiSyncDataCollector device_str = "device" if num_workers == 1 else "devices" diff --git a/examples/distributed/collectors/multi_nodes/delayed_rpc.py b/examples/distributed/collectors/multi_nodes/delayed_rpc.py index 890968c5aae..adff8864413 100644 --- a/examples/distributed/collectors/multi_nodes/delayed_rpc.py +++ b/examples/distributed/collectors/multi_nodes/delayed_rpc.py @@ -113,9 +113,9 @@ def main(): import gym from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector - from torchrl.collectors.collectors import RandomPolicy from torchrl.data import BoundedTensorSpec from torchrl.envs.libs.gym import GymEnv, set_gym_backend + from torchrl.envs.utils import RandomPolicy collector_class = SyncDataCollector if num_workers == 1 else MultiSyncDataCollector device_str = "device" if num_workers == 1 else "devices" diff --git a/examples/distributed/collectors/multi_nodes/generic.py b/examples/distributed/collectors/multi_nodes/generic.py index 9338a0acea7..2b6ec53628a 100644 --- a/examples/distributed/collectors/multi_nodes/generic.py +++ b/examples/distributed/collectors/multi_nodes/generic.py @@ -10,14 +10,11 @@ import tqdm from torchrl._utils import logger as torchrl_logger -from torchrl.collectors.collectors import ( - MultiSyncDataCollector, - RandomPolicy, - SyncDataCollector, -) +from torchrl.collectors.collectors import MultiSyncDataCollector, SyncDataCollector from torchrl.collectors.distributed import DistributedDataCollector from torchrl.envs import EnvCreator from torchrl.envs.libs.gym import GymEnv, set_gym_backend +from torchrl.envs.utils import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git 
a/examples/distributed/collectors/multi_nodes/rpc.py b/examples/distributed/collectors/multi_nodes/rpc.py index be30b9c3668..32600707ebc 100644 --- a/examples/distributed/collectors/multi_nodes/rpc.py +++ b/examples/distributed/collectors/multi_nodes/rpc.py @@ -11,14 +11,11 @@ import tqdm from torchrl._utils import logger as torchrl_logger -from torchrl.collectors.collectors import ( - MultiSyncDataCollector, - RandomPolicy, - SyncDataCollector, -) +from torchrl.collectors.collectors import MultiSyncDataCollector, SyncDataCollector from torchrl.collectors.distributed import RPCDataCollector from torchrl.envs import EnvCreator from torchrl.envs.libs.gym import GymEnv, set_gym_backend +from torchrl.envs.utils import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git a/examples/distributed/collectors/multi_nodes/sync.py b/examples/distributed/collectors/multi_nodes/sync.py index 688090ca691..7149a4ed82d 100644 --- a/examples/distributed/collectors/multi_nodes/sync.py +++ b/examples/distributed/collectors/multi_nodes/sync.py @@ -10,14 +10,11 @@ import tqdm from torchrl._utils import logger as torchrl_logger -from torchrl.collectors.collectors import ( - MultiSyncDataCollector, - RandomPolicy, - SyncDataCollector, -) +from torchrl.collectors.collectors import MultiSyncDataCollector, SyncDataCollector from torchrl.collectors.distributed import DistributedSyncDataCollector from torchrl.envs import EnvCreator from torchrl.envs.libs.gym import GymEnv, set_gym_backend +from torchrl.envs.utils import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git a/examples/distributed/collectors/single_machine/generic.py b/examples/distributed/collectors/single_machine/generic.py index cd723a63806..95a6ddf139d 100644 --- a/examples/distributed/collectors/single_machine/generic.py +++ b/examples/distributed/collectors/single_machine/generic.py @@ -29,12 +29,12 @@ from torchrl.collectors.collectors import ( MultiaSyncDataCollector, MultiSyncDataCollector, - RandomPolicy, SyncDataCollector, ) from torchrl.collectors.distributed import DistributedDataCollector from torchrl.envs import EnvCreator, ParallelEnv from torchrl.envs.libs.gym import GymEnv, set_gym_backend +from torchrl.envs.utils import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git a/examples/distributed/collectors/single_machine/rpc.py b/examples/distributed/collectors/single_machine/rpc.py index c001a6586b1..5876c9a3868 100644 --- a/examples/distributed/collectors/single_machine/rpc.py +++ b/examples/distributed/collectors/single_machine/rpc.py @@ -26,10 +26,11 @@ import tqdm from torchrl._utils import logger as torchrl_logger -from torchrl.collectors.collectors import RandomPolicy, SyncDataCollector +from torchrl.collectors.collectors import SyncDataCollector from torchrl.collectors.distributed import RPCDataCollector from torchrl.envs import EnvCreator, ParallelEnv from torchrl.envs.libs.gym import GymEnv, set_gym_backend +from torchrl.envs.utils import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git a/examples/distributed/collectors/single_machine/sync.py b/examples/distributed/collectors/single_machine/sync.py index b5c77ebdb5b..b04a7de45c4 100644 --- a/examples/distributed/collectors/single_machine/sync.py +++ b/examples/distributed/collectors/single_machine/sync.py @@ -27,14 +27,11 @@ import tqdm from torchrl._utils import logger as torchrl_logger -from torchrl.collectors.collectors import ( - MultiSyncDataCollector, - RandomPolicy, - SyncDataCollector, -) +from 
torchrl.collectors.collectors import MultiSyncDataCollector, SyncDataCollector from torchrl.collectors.distributed import DistributedSyncDataCollector from torchrl.envs import EnvCreator, ParallelEnv from torchrl.envs.libs.gym import GymEnv, set_gym_backend +from torchrl.envs.utils import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git a/test/test_collector.py b/test/test_collector.py index 09c6ee293c3..bebbd103bc7 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -49,7 +49,6 @@ _Interruptor, MultiaSyncDataCollector, MultiSyncDataCollector, - RandomPolicy, ) from torchrl.collectors.utils import split_trajectories from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec @@ -67,6 +66,7 @@ _aggregate_end_of_traj, check_env_specs, PARTIAL_MISSING_ERR, + RandomPolicy, ) from torchrl.modules import Actor, LSTMNet, OrnsteinUhlenbeckProcessWrapper, SafeModule diff --git a/test/test_distributed.py b/test/test_distributed.py index 5f37d8bcac9..40b4f5eae44 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -33,7 +33,6 @@ from torchrl.collectors.collectors import ( MultiaSyncDataCollector, MultiSyncDataCollector, - RandomPolicy, SyncDataCollector, ) from torchrl.collectors.distributed import ( @@ -43,6 +42,7 @@ RPCDataCollector, ) from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG +from torchrl.envs.utils import RandomPolicy TIMEOUT = 200 diff --git a/test/test_env.py b/test/test_env.py index d8136ff382b..f03b35e20c4 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -2592,6 +2592,23 @@ def make_env(seed, device=device): p_env.close() +@pytest.mark.skipif(not _has_gym, reason="Gym required for this test") +def test_non_td_policy(): + env = GymEnv("CartPole-v1", categorical_action_encoding=True) + + class ArgMaxModule(nn.Module): + def forward(self, values): + return values.argmax(-1) + + policy = nn.Sequential( + nn.Linear(env.observation_spec["observation"].shape[-1], env.action_spec.n), + ArgMaxModule(), + ) + env.rollout(10, policy) + env = SerialEnv(2, lambda: GymEnv("CartPole-v1", categorical_action_encoding=True)) + env.rollout(10, policy) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_libs.py b/test/test_libs.py index 427eef522d0..608b9280fba 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -52,7 +52,7 @@ ) from torch import nn from torchrl._utils import implement_for -from torchrl.collectors.collectors import RandomPolicy, SyncDataCollector +from torchrl.collectors.collectors import SyncDataCollector from torchrl.data import ( BinaryDiscreteTensorSpec, BoundedTensorSpec, @@ -105,7 +105,12 @@ from torchrl.envs.libs.robohive import _has_robohive, RoboHiveEnv from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper -from torchrl.envs.utils import check_env_specs, ExplorationType, MarlGroupMapType +from torchrl.envs.utils import ( + check_env_specs, + ExplorationType, + MarlGroupMapType, + RandomPolicy, +) from torchrl.modules import ActorCriticOperator, MLP, SafeModule, ValueOperator _has_d4rl = importlib.util.find_spec("d4rl") is not None diff --git a/torchrl/collectors/__init__.py b/torchrl/collectors/__init__.py index a91589e71ce..d69d8c9e50c 100644 --- a/torchrl/collectors/__init__.py +++ b/torchrl/collectors/__init__.py @@ -3,11 +3,12 @@ # This source code is licensed under the MIT 
license found in the # LICENSE file in the root directory of this source tree. +from torchrl.envs.utils import RandomPolicy + from .collectors import ( aSyncDataCollector, DataCollectorBase, MultiaSyncDataCollector, MultiSyncDataCollector, - RandomPolicy, SyncDataCollector, ) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 08fa99526a1..f4b92c87d9d 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -11,7 +11,6 @@ import functools -import inspect import os import queue import sys @@ -29,16 +28,13 @@ import torch import torch.nn as nn from tensordict import ( - is_tensor_collection, LazyStackedTensorDict, TensorDict, TensorDictBase, TensorDictParams, ) -from tensordict.nn import TensorDictModule, TensorDictModuleBase -from tensordict.utils import NestedKey +from tensordict.nn import TensorDictModule from torch import multiprocessing as mp -from torch.utils._pytree import tree_map from torch.utils.data import IterableDataset from torchrl._utils import ( @@ -51,13 +47,15 @@ VERBOSE, ) from torchrl.collectors.utils import split_trajectories -from torchrl.data.tensor_specs import CompositeSpec, TensorSpec +from torchrl.data.tensor_specs import TensorSpec from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING from torchrl.envs.common import EnvBase from torchrl.envs.transforms import StepCounter, TransformedEnv from torchrl.envs.utils import ( _aggregate_end_of_traj, _convert_exploration_type, + _make_compatible_policy, + _NonParametricPolicyWrapper, ExplorationType, set_exploration_type, ) @@ -72,34 +70,6 @@ _is_osx = sys.platform.startswith("darwin") -class RandomPolicy: - """A random policy for data collectors. - - This is a wrapper around the action_spec.rand method. - - Args: - action_spec: TensorSpec object describing the action specs - - Examples: - >>> from tensordict import TensorDict - >>> from torchrl.data.tensor_specs import BoundedTensorSpec - >>> action_spec = BoundedTensorSpec(-torch.ones(3), torch.ones(3)) - >>> actor = RandomPolicy(action_spec=action_spec) - >>> td = actor(TensorDict({}, batch_size=[])) # selects a random action in the cube [-1; 1] - """ - - def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"): - super().__init__() - self.action_spec = action_spec.clone() - self.action_key = action_key - - def __call__(self, td: TensorDictBase) -> TensorDictBase: - if isinstance(self.action_spec, CompositeSpec): - return td.update(self.action_spec.rand()) - else: - return td.set(self.action_key, self.action_spec.rand()) - - class _Interruptor: """A class for managing the collection state of a process. @@ -154,129 +124,11 @@ def recursive_map_to_cpu(dictionary: OrderedDict) -> OrderedDict: ) -def _policy_is_tensordict_compatible(policy: nn.Module): - if isinstance(policy, _NonParametricPolicyWrapper) and isinstance( - policy.policy, RandomPolicy - ): - return True - - if isinstance(policy, TensorDictModuleBase): - return True - - sig = inspect.signature(policy.forward) - - if ( - len(sig.parameters) == 1 - and hasattr(policy, "in_keys") - and hasattr(policy, "out_keys") - ): - raise RuntimeError( - "Passing a policy that is not a tensordict.nn.TensorDictModuleBase subclass but has in_keys and out_keys " - "is deprecated. Users should inherit from this class (which " - "has very few restrictions) to make the experience smoother. 
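[Editor's note] The collector-side hunks above are purely an import move: ``RandomPolicy`` now lives in ``torchrl.envs.utils`` and ``torchrl/collectors/__init__.py`` re-exports it, so existing ``from torchrl.collectors import RandomPolicy`` imports should keep working. A minimal usage sketch under that assumption (it also assumes a Gym backend is installed and the usual ``SyncDataCollector`` keyword arguments):

    from torchrl.collectors import SyncDataCollector
    from torchrl.collectors import RandomPolicy as LegacyRandomPolicy  # re-exported for backward compatibility
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.envs.utils import RandomPolicy  # new canonical location

    assert RandomPolicy is LegacyRandomPolicy  # same class, two import paths

    env = GymEnv("Pendulum-v1")
    # RandomPolicy simply calls action_spec.rand() and writes the result under "action"
    policy = RandomPolicy(env.action_spec)
    collector = SyncDataCollector(env, policy, frames_per_batch=32, total_frames=128)
    for batch in collector:
        # each batch is a TensorDict holding "observation", "action", "next", ...
        print(batch["action"].shape)
    collector.shutdown()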
" - "Simply change your policy from `class Policy(nn.Module)` to `Policy(tensordict.nn.TensorDictModuleBase)` " - "and this error should disappear.", - ) - elif not hasattr(policy, "in_keys") and not hasattr(policy, "out_keys"): - # if it's not a TensorDictModule, and in_keys and out_keys are not defined then - # we assume no TensorDict compatibility and will try to wrap it. - return False - - # if in_keys or out_keys were defined but policy is not a TensorDictModule or - # accepts multiple arguments then it's likely the user is trying to do something - # that will have undetermined behaviour, we raise an error - raise TypeError( - "Received a policy that defines in_keys or out_keys and also expects multiple " - "arguments to policy.forward. If the policy is compatible with TensorDict, it " - "should take a single argument of type TensorDict to policy.forward and define " - "both in_keys and out_keys. Alternatively, policy.forward can accept " - "arbitrarily many tensor inputs and leave in_keys and out_keys undefined and " - "TorchRL will attempt to automatically wrap the policy with a TensorDictModule." - ) - - class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): """Base class for data collectors.""" _iterator = None - def _make_compatible_policy(self, policy, observation_spec=None): - if policy is None: - if not hasattr(self, "env") or self.env is None: - raise ValueError( - "env must be provided to _get_policy_and_device if policy is None" - ) - policy = RandomPolicy(self.env.input_spec["full_action_spec"]) - # make sure policy is an nn.Module - policy = _NonParametricPolicyWrapper(policy) - if not _policy_is_tensordict_compatible(policy): - # policy is a nn.Module that doesn't operate on tensordicts directly - # so we attempt to auto-wrap policy with TensorDictModule - if observation_spec is None: - raise ValueError( - "Unable to read observation_spec from the environment. This is " - "required to check compatibility of the environment and policy " - "since the policy is a nn.Module that operates on tensors " - "rather than a TensorDictModule or a nn.Module that accepts a " - "TensorDict as input and defines in_keys and out_keys." 
- ) - - try: - # signature modified by make_functional - sig = policy.forward.__signature__ - except AttributeError: - sig = inspect.signature(policy.forward) - required_kwargs = { - str(k) for k, p in sig.parameters.items() if p.default is inspect._empty - } - next_observation = { - key: value for key, value in observation_spec.rand().items() - } - # we check if all the mandatory params are there - params = list(sig.parameters.keys()) - if ( - set(sig.parameters) == {"tensordict"} - or set(sig.parameters) == {"td"} - or ( - len(params) == 1 - and is_tensor_collection(sig.parameters[params[0]].annotation) - ) - ): - pass - elif not required_kwargs.difference(set(next_observation)): - in_keys = [str(k) for k in sig.parameters if k in next_observation] - if not hasattr(self, "env") or self.env is None: - out_keys = ["action"] - else: - out_keys = list(self.env.action_keys) - for p in policy.parameters(): - policy_device = p.device - break - else: - policy_device = None - if policy_device: - next_observation = tree_map( - lambda x: x.to(policy_device), next_observation - ) - output = policy(**next_observation) - - if isinstance(output, tuple): - out_keys.extend(f"output{i + 1}" for i in range(len(output) - 1)) - - policy = TensorDictModule(policy, in_keys=in_keys, out_keys=out_keys) - else: - raise TypeError( - f"""Arguments to policy.forward are incompatible with entries in -env.observation_spec (got incongruent signatures: fun signature is {set(sig.parameters)} vs specs {set(next_observation)}). -If you want TorchRL to automatically wrap your policy with a TensorDictModule -then the arguments to policy.forward must correspond one-to-one with entries -in env.observation_spec that are prefixed with 'next_'. For more complex -behaviour and more control you can consider writing your own TensorDictModule. -Check the collector documentation to know more about accepted policies. -""" - ) - return policy - def _get_policy_and_device( self, policy: Optional[ @@ -297,7 +149,9 @@ def _get_policy_and_device( observation_spec (TensorSpec, optional): spec of the observations """ - policy = self._make_compatible_policy(policy, observation_spec) + policy = _make_compatible_policy( + policy, observation_spec, env=getattr(self, "env", None) + ) param_and_buf = TensorDict.from_module(policy, as_module=True) def get_weights_fn(param_and_buf=param_and_buf): @@ -2736,57 +2590,6 @@ def _main_async_collector( raise Exception(f"Unrecognized message {msg}") -class _PolicyMetaClass(abc.ABCMeta): - def __call__(cls, *args, **kwargs): - # no kwargs - if isinstance(args[0], nn.Module): - return args[0] - return super().__call__(*args) - - -class _NonParametricPolicyWrapper(nn.Module, metaclass=_PolicyMetaClass): - """A wrapper for non-parametric policies.""" - - def __init__(self, policy): - super().__init__() - self.policy = policy - - @property - def forward(self): - forward = self.__dict__.get("_forward", None) - if forward is None: - - @functools.wraps(self.policy) - def forward(*input, **kwargs): - return self.policy.__call__(*input, **kwargs) - - self.__dict__["_forward"] = forward - return forward - - def __getattr__(self, attr: str) -> Any: - if attr in self.__dir__(): - return self.__getattribute__( - attr - ) # make sure that appropriate exceptions are raised - - elif attr.startswith("__"): - raise AttributeError( - "passing built-in private methods is " - f"not permitted with type {type(self)}. " - f"Got attribute {attr}." 
-             )
-
-        elif "policy" in self.__dir__():
-            policy = self.__getattribute__("policy")
-            return getattr(policy, attr)
-        try:
-            super().__getattr__(attr)
-        except Exception:
-            raise AttributeError(
-                f"policy not set in {self.__class__.__name__}, cannot access {attr}."
-            )
-
-
 def _make_meta_params(param):
     is_param = isinstance(param, nn.Parameter)
diff --git a/torchrl/collectors/distributed/utils.py b/torchrl/collectors/distributed/utils.py
index 5ae8fc3f60f..aeee573f8dc 100644
--- a/torchrl/collectors/distributed/utils.py
+++ b/torchrl/collectors/distributed/utils.py
@@ -51,8 +51,8 @@ class submitit_delayed_launcher:
 >>> num_jobs=2
 >>> @submitit_delayed_launcher(num_jobs=num_jobs)
 ... def main():
- ...     from torchrl.envs.libs.gym import GymEnv
- ...     from torchrl.collectors.collectors import RandomPolicy
+ ...     from torchrl.envs.utils import RandomPolicy
+ ...     from torchrl.envs.libs.gym import GymEnv
 ...     from torchrl.data import BoundedTensorSpec
 ...     collector = DistributedDataCollector(
 ...         [EnvCreator(lambda: GymEnv("Pendulum-v1"))] * num_jobs,
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index 0b69167fee4..06a3ddfee3f 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -26,6 +26,7 @@ )
 from torchrl.data.utils import DEVICE_TYPING
 from torchrl.envs.utils import (
+    _make_compatible_policy,
     _repr_by_depth,
     _terminated_or_truncated,
     _update_during_reset,
@@ -2260,9 +2261,11 @@ def rollout(
 Args:
     max_steps (int): maximum number of steps to be executed. The actual number of steps can be smaller if the environment reaches a done state before max_steps have been executed.
-    policy (callable, optional): callable to be called to compute the desired action. If no policy is provided,
-        actions will be called using :obj:`env.rand_step()`
-        default = None
+    policy (callable, optional): callable used to compute the desired action.
+        If no policy is provided, actions will be sampled using :obj:`env.rand_step()`.
+        The policy can be any callable that reads either a tensordict or
+        the full sequence of observation entries, sorted as in ``env.observation_spec.keys()``.
+        Defaults to `None`.
     callback (callable, optional): function to be called at each iteration with the given TensorDict.
     auto_reset (bool, optional): if ``True``, resets automatically the environment if it is in a done state when the rollout is initiated.
@@ -2283,7 +2286,11 @@ def rollout(
 The data returned will be marked with a "time" dimension name for the last dimension of the tensordict (at the ``env.ndim`` index).
+ ``rollout`` comes in handy to display what the data structure of the
+ environment looks like.
+
 Examples:
+     >>> # Using rollout without a policy
     >>> from torchrl.envs.libs.gym import GymEnv
     >>> from torchrl.envs.transforms import TransformedEnv, StepCounter
     >>> env = TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(max_steps=20))
@@ -2339,43 +2346,106 @@
     >>> print(rollout.names)
     [None, 'time']
+ Using a policy (a regular :class:`~torch.nn.Module` or a :class:`~tensordict.nn.TensorDictModule`)
+ is also easy:
+
+ Examples:
+     >>> from torch import nn
+     >>> env = GymEnv("CartPole-v1", categorical_action_encoding=True)
+     >>> class ArgMaxModule(nn.Module):
+     ...     def forward(self, values):
+     ...         return values.argmax(-1)
+     >>> n_obs = env.observation_spec["observation"].shape[-1]
+     >>> n_act = env.action_spec.n
+     >>> # A deterministic policy
+     >>> policy = nn.Sequential(
+     ...     nn.Linear(n_obs, n_act),
+     ...
ArgMaxModule()) + >>> env.rollout(max_steps=10, policy=policy) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False), + done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([10]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([10]), + device=cpu, + is_shared=False) + >>> # Under the hood, rollout will wrap the policy in a TensorDictModule + >>> # To speed things up we can do that ourselves + >>> from tensordict.nn import TensorDictModule + >>> policy = TensorDictModule(policy, in_keys=list(env.observation_spec.keys()), out_keys=["action"]) + >>> env.rollout(max_steps=10, policy=policy) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False), + done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([10]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([10]), + device=cpu, + is_shared=False) + + In some instances, contiguous tensordict cannot be obtained because they cannot be stacked. This can happen when the data returned at each step may have a different shape, or when different environments are executed together. 
In that case, ``return_contiguous=False`` will cause the returned tensordict to be a lazy stack of tensordicts: - Examples: + Examples of non-contiguous rollout: >>> rollout = env.rollout(4, return_contiguous=False) >>> print(rollout) - LazyStackedTensorDict( - fields={ - action: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False), - done: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False), - next: LazyStackedTensorDict( - fields={ - done: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False), - observation: Tensor(shape=torch.Size([3, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([3, 4]), - device=cpu, - is_shared=False), - observation: Tensor(shape=torch.Size([3, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([3, 4]), - device=cpu, - is_shared=False) - >>> print(rollout.names) - [None, 'time'] + LazyStackedTensorDict( + fields={ + action: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: LazyStackedTensorDict( + fields={ + done: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([3, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([3, 4]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([3, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([3, 4]), + device=cpu, + is_shared=False) + >>> print(rollout.names) + [None, 'time'] Rollouts can be used in a loop to emulate data collection. To do so, you need to pass as input the last tensordict coming from the previous rollout after calling :func:`~torchrl.envs.utils.step_mdp` on it. - Examples: + Examples of data collection rollouts: >>> from torchrl.envs import GymEnv, step_mdp >>> env = GymEnv("CartPole-v1") >>> epochs = 10 @@ -2392,12 +2462,19 @@ def rollout( ... 
) """ - if auto_cast_to_device: - try: - policy_device = next(policy.parameters()).device - except (StopIteration, AttributeError): + if policy is not None: + policy = _make_compatible_policy( + policy, self.observation_spec, env=self, fast_wrap=True + ) + if auto_cast_to_device: + try: + policy_device = next(policy.parameters()).device + except (StopIteration, AttributeError): + policy_device = None + else: policy_device = None else: + policy = self.rand_action policy_device = None env_device = self.device @@ -2413,10 +2490,6 @@ def rollout( else: tensordict = self.maybe_reset(tensordict) - if policy is None: - - policy = self.rand_action - kwargs = { "tensordict": tensordict, "auto_cast_to_device": auto_cast_to_device, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 1273589a56f..e61de872a19 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2637,8 +2637,8 @@ class CatFrames(ObservationTransform): gives the complete picture, together with the usage of a :class:`torchrl.data.ReplayBuffer`: Examples: - >>> from torchrl.envs import UnsqueezeTransform, CatFrames - >>> from torchrl.collectors import SyncDataCollector, RandomPolicy + >>> from torchrl.envs.utils import RandomPolicy >>> from torchrl.envs import UnsqueezeTransform, CatFrames + >>> from torchrl.collectors import SyncDataCollector >>> # Create a transformed environment with CatFrames: notice the usage of UnsqueezeTransform to create an extra dimension >>> env = TransformedEnv( ... GymEnv("CartPole-v1", from_pixels=True), @@ -6210,7 +6210,7 @@ class Reward2GoTransform(Transform): append the `inv` method of the transform. Examples: - >>> from torchrl.collectors import SyncDataCollector, RandomPolicy + >>> from torchrl.envs.utils import RandomPolicy >>> from torchrl.collectors import SyncDataCollector >>> from torchrl.envs.libs.gym import GymEnv >>> t = Reward2GoTransform(gamma=0.99, out_keys=["reward_to_go"]) >>> env = GymEnv("Pendulum-v1") diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 71b15d1dfae..fa3d28848a8 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -4,13 +4,16 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +import abc import contextlib +import functools import importlib.util +import inspect import os import re from enum import Enum -from typing import Dict, List, Union +from typing import Any, Dict, List, Union import torch @@ -20,6 +23,7 @@ TensorDictBase, unravel_key, ) +from tensordict.nn import TensorDictModule, TensorDictModuleBase from tensordict.nn.probabilistic import ( # noqa # Note: the `set_interaction_mode` and their associated arg `default_interaction_mode` are being deprecated! # Please use the `set_/interaction_type` ones above with the InteractionType enum instead. 
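[Editor's note] With the ``rollout`` change above, the policy is made TensorDict-compatible first (wrapping plain callables) and only then probed for a parameter device when ``auto_cast_to_device=True``. A purely illustrative sketch of what this enables, assuming a CUDA device is available; none of the names below are introduced by this diff:

    import torch
    from tensordict.nn import TensorDictModule
    from torch import nn
    from torchrl.envs.libs.gym import GymEnv

    env = GymEnv("Pendulum-v1")  # env data stays on CPU
    policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
    if torch.cuda.is_available():
        policy = policy.to("cuda")
    # rollout reads next(policy.parameters()).device and casts data to the
    # policy device before each call, then back to the env device
    rollout = env.rollout(5, policy, auto_cast_to_device=True)
    print(rollout["action"].device)  # expected: cpu (cast back to the env device)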
@@ -31,6 +35,8 @@ set_interaction_type as set_exploration_type, ) from tensordict.utils import NestedKey +from torch import nn as nn +from torch.utils._pytree import tree_map from torchrl._utils import _replace_last, _rng_decorator, logger as torchrl_logger from torchrl.data.tensor_specs import ( @@ -1115,3 +1121,208 @@ def _repr_by_depth(key): return (0, key) else: return (len(key) - 1, ".".join(key)) + + +def _make_compatible_policy(policy, observation_spec, env=None, fast_wrap=False): + if policy is None: + if env is None: + raise ValueError( + "env must be provided to _get_policy_and_device if policy is None" + ) + policy = RandomPolicy(env.input_spec["full_action_spec"]) + # make sure policy is an nn.Module + policy = _NonParametricPolicyWrapper(policy) + if not _policy_is_tensordict_compatible(policy): + # policy is a nn.Module that doesn't operate on tensordicts directly + # so we attempt to auto-wrap policy with TensorDictModule + if observation_spec is None: + raise ValueError( + "Unable to read observation_spec from the environment. This is " + "required to check compatibility of the environment and policy " + "since the policy is a nn.Module that operates on tensors " + "rather than a TensorDictModule or a nn.Module that accepts a " + "TensorDict as input and defines in_keys and out_keys." + ) + + try: + # signature modified by make_functional + sig = policy.forward.__signature__ + except AttributeError: + sig = inspect.signature(policy.forward) + # we check if all the mandatory params are there + params = list(sig.parameters.keys()) + if ( + set(sig.parameters) == {"tensordict"} + or set(sig.parameters) == {"td"} + or ( + len(params) == 1 + and is_tensor_collection(sig.parameters[params[0]].annotation) + ) + ): + return policy + if fast_wrap: + in_keys = list(observation_spec.keys()) + out_keys = list(env.action_keys) + return TensorDictModule(policy, in_keys=in_keys, out_keys=out_keys) + + required_kwargs = { + str(k) for k, p in sig.parameters.items() if p.default is inspect._empty + } + next_observation = { + key: value for key, value in observation_spec.rand().items() + } + if not required_kwargs.difference(set(next_observation)): + in_keys = [str(k) for k in sig.parameters if k in next_observation] + if env is None: + out_keys = ["action"] + else: + out_keys = list(env.action_keys) + for p in policy.parameters(): + policy_device = p.device + break + else: + policy_device = None + if policy_device: + next_observation = tree_map( + lambda x: x.to(policy_device), next_observation + ) + + output = policy(**next_observation) + + if isinstance(output, tuple): + out_keys.extend(f"output{i + 1}" for i in range(len(output) - 1)) + + policy = TensorDictModule(policy, in_keys=in_keys, out_keys=out_keys) + else: + raise TypeError( + f"""Arguments to policy.forward are incompatible with entries in + env.observation_spec (got incongruent signatures: fun signature is {set(sig.parameters)} vs specs {set(next_observation)}). + If you want TorchRL to automatically wrap your policy with a TensorDictModule + then the arguments to policy.forward must correspond one-to-one with entries + in env.observation_spec. + For more complex behaviour and more control you can consider writing your + own TensorDictModule. + Check the collector documentation to know more about accepted policies. 
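[Editor's note] For reviewers, a quick sketch of how the relocated helper behaves; it is a private function, so the call below is illustrative and subject to change. With ``fast_wrap=True`` it skips the signature probing and simply pairs ``observation_spec.keys()`` with ``env.action_keys``, which is the path ``EnvBase.rollout`` now takes:

    from tensordict.nn import TensorDictModule
    from torch import nn
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.envs.utils import _make_compatible_policy

    env = GymEnv("CartPole-v1", categorical_action_encoding=True)
    policy = nn.Linear(env.observation_spec["observation"].shape[-1], env.action_spec.n)

    wrapped = _make_compatible_policy(policy, env.observation_spec, env=env, fast_wrap=True)
    assert isinstance(wrapped, TensorDictModule)
    print(wrapped.in_keys, wrapped.out_keys)  # ['observation'] ['action']

    # a TensorDictModule (or any TensorDictModuleBase subclass) is returned untouched
    already_ok = TensorDictModule(policy, in_keys=["observation"], out_keys=["action"])
    assert _make_compatible_policy(already_ok, env.observation_spec, env=env) is already_ok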
+ """ + ) + return policy + + +def _policy_is_tensordict_compatible(policy: nn.Module): + if isinstance(policy, _NonParametricPolicyWrapper) and isinstance( + policy.policy, RandomPolicy + ): + return True + + if isinstance(policy, TensorDictModuleBase): + return True + + sig = inspect.signature(policy.forward) + + if ( + len(sig.parameters) == 1 + and hasattr(policy, "in_keys") + and hasattr(policy, "out_keys") + ): + raise RuntimeError( + "Passing a policy that is not a tensordict.nn.TensorDictModuleBase subclass but has in_keys and out_keys " + "is deprecated. Users should inherit from this class (which " + "has very few restrictions) to make the experience smoother. " + "Simply change your policy from `class Policy(nn.Module)` to `Policy(tensordict.nn.TensorDictModuleBase)` " + "and this error should disappear.", + ) + elif not hasattr(policy, "in_keys") and not hasattr(policy, "out_keys"): + # if it's not a TensorDictModule, and in_keys and out_keys are not defined then + # we assume no TensorDict compatibility and will try to wrap it. + return False + + # if in_keys or out_keys were defined but policy is not a TensorDictModule or + # accepts multiple arguments then it's likely the user is trying to do something + # that will have undetermined behaviour, we raise an error + raise TypeError( + "Received a policy that defines in_keys or out_keys and also expects multiple " + "arguments to policy.forward. If the policy is compatible with TensorDict, it " + "should take a single argument of type TensorDict to policy.forward and define " + "both in_keys and out_keys. Alternatively, policy.forward can accept " + "arbitrarily many tensor inputs and leave in_keys and out_keys undefined and " + "TorchRL will attempt to automatically wrap the policy with a TensorDictModule." + ) + + +class RandomPolicy: + """A random policy for data collectors. + + This is a wrapper around the action_spec.rand method. 
+ + Args: + action_spec: TensorSpec object describing the action specs + + Examples: + >>> from tensordict import TensorDict + >>> from torchrl.data.tensor_specs import BoundedTensorSpec + >>> action_spec = BoundedTensorSpec(-torch.ones(3), torch.ones(3)) + >>> actor = RandomPolicy(action_spec=action_spec) + >>> td = actor(TensorDict({}, batch_size=[])) # selects a random action in the cube [-1; 1] + """ + + def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"): + super().__init__() + self.action_spec = action_spec.clone() + self.action_key = action_key + + def __call__(self, td: TensorDictBase) -> TensorDictBase: + if isinstance(self.action_spec, CompositeSpec): + return td.update(self.action_spec.rand()) + else: + return td.set(self.action_key, self.action_spec.rand()) + + +class _PolicyMetaClass(abc.ABCMeta): + def __call__(cls, *args, **kwargs): + # no kwargs + if isinstance(args[0], nn.Module): + return args[0] + return super().__call__(*args) + + +class _NonParametricPolicyWrapper(nn.Module, metaclass=_PolicyMetaClass): + """A wrapper for non-parametric policies.""" + + def __init__(self, policy): + super().__init__() + self.policy = policy + + @property + def forward(self): + forward = self.__dict__.get("_forward", None) + if forward is None: + + @functools.wraps(self.policy) + def forward(*input, **kwargs): + return self.policy.__call__(*input, **kwargs) + + self.__dict__["_forward"] = forward + return forward + + def __getattr__(self, attr: str) -> Any: + if attr in self.__dir__(): + return self.__getattribute__( + attr + ) # make sure that appropriate exceptions are raised + + elif attr.startswith("__"): + raise AttributeError( + "passing built-in private methods is " + f"not permitted with type {type(self)}. " + f"Got attribute {attr}." + ) + + elif "policy" in self.__dir__(): + policy = self.__getattribute__("policy") + return getattr(policy, attr) + try: + super().__getattr__(attr) + except Exception: + raise AttributeError( + f"policy not set in {self.__class__.__name__}, cannot access {attr}." + ) diff --git a/tutorials/sphinx-tutorials/getting-started-3.py b/tutorials/sphinx-tutorials/getting-started-3.py index 829b22cf061..cf80b47f859 100644 --- a/tutorials/sphinx-tutorials/getting-started-3.py +++ b/tutorials/sphinx-tutorials/getting-started-3.py @@ -59,8 +59,9 @@ torch.manual_seed(0) -from torchrl.collectors import RandomPolicy, SyncDataCollector +from torchrl.collectors import SyncDataCollector from torchrl.envs import GymEnv +from torchrl.envs.utils import RandomPolicy env = GymEnv("CartPole-v1") env.set_seed(0) diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py index 5b6effd5cdd..3c0bad89e70 100644 --- a/tutorials/sphinx-tutorials/rb_tutorial.py +++ b/tutorials/sphinx-tutorials/rb_tutorial.py @@ -646,7 +646,7 @@ def assert0(x): # transformations can be recycled in the replay buffer: -from torchrl.collectors import RandomPolicy, SyncDataCollector +from torchrl.collectors import SyncDataCollector from torchrl.envs.libs.gym import GymEnv from torchrl.envs.transforms import ( Compose, @@ -655,6 +655,7 @@ def assert0(x): ToTensorImage, TransformedEnv, ) +from torchrl.envs.utils import RandomPolicy env = TransformedEnv( GymEnv("CartPole-v1", from_pixels=True),
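[Editor's note] The tutorial hunks at the end only touch the import line; the surrounding collector-plus-replay-buffer pattern is unchanged. For completeness, a sketch of that loop with the new import path (buffer sizes and frame counts are illustrative, not taken from this diff):

    from torchrl.collectors import SyncDataCollector
    from torchrl.data import LazyTensorStorage, ReplayBuffer
    from torchrl.envs import GymEnv
    from torchrl.envs.utils import RandomPolicy

    env = GymEnv("CartPole-v1")
    policy = RandomPolicy(env.action_spec)
    collector = SyncDataCollector(env, policy, frames_per_batch=64, total_frames=256)
    buffer = ReplayBuffer(storage=LazyTensorStorage(1000))

    for batch in collector:
        buffer.extend(batch)        # store whole batches of transitions
        sample = buffer.sample(32)  # sample for a (hypothetical) training step
        print(sample["next", "reward"].mean())
    collector.shutdown()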