diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index c69fc985ded..13adacd5868 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -115,7 +115,6 @@ def make(envname=envname, gym_backend=gym_backend): frames_per_batch=1024, total_frames=num_workers * 10_000, device=device, - storing_device=device, ) pbar = tqdm.tqdm(total=num_workers * 10_000) total_frames = 0 @@ -178,7 +177,6 @@ def make_env(envname=envname, gym_backend=gym_backend): frames_per_batch=1024, total_frames=num_workers * 10_000, device=device, - storing_device=device, ) pbar = tqdm.tqdm(total=num_workers * 10_000) total_frames = 0 @@ -222,7 +220,6 @@ def make_env( total_frames=num_workers * 10_000, num_sub_threads=num_workers // num_collectors, device=device, - storing_device=device, ) pbar = tqdm.tqdm(total=num_workers * 10_000) total_frames = 0 @@ -260,7 +257,6 @@ def make_env(envname=envname, gym_backend=gym_backend): frames_per_batch=1024, total_frames=num_workers * 10_000, device=device, - storing_device=device, ) pbar = tqdm.tqdm(total=num_workers * 10_000) total_frames = 0 diff --git a/test/mocking_classes.py b/test/mocking_classes.py index def8ddae1d5..9e5b2ff6879 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -6,7 +6,8 @@ import torch import torch.nn as nn -from tensordict.tensordict import TensorDict, TensorDictBase +from tensordict import TensorDict, TensorDictBase +from tensordict.nn import TensorDictModuleBase from tensordict.utils import expand_right, NestedKey from torchrl.data.tensor_specs import ( @@ -229,6 +230,7 @@ def _step(self, tensordict): "observation": n.clone(), }, batch_size=[], + device=self.device, ) def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase: @@ -240,7 +242,9 @@ def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase: done = self.counter >= self.max_val done = torch.tensor([done], dtype=torch.bool, device=self.device) return TensorDict( - {"done": done, "terminated": done.clone(), "observation": n}, [] + {"done": done, "terminated": done.clone(), "observation": n}, + [], + device=self.device, ) def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase: @@ -1374,8 +1378,9 @@ def _step( return tensordict -class HeteroCountingEnvPolicy: +class HeterogeneousCountingEnvPolicy(TensorDictModuleBase): def __init__(self, full_action_spec: TensorSpec, count: bool = True): + super().__init__() self.full_action_spec = full_action_spec self.count = count @@ -1386,7 +1391,7 @@ def __call__(self, td: TensorDictBase) -> TensorDictBase: return td.update(action_td) -class HeteroCountingEnv(EnvBase): +class HeterogeneousCountingEnv(EnvBase): """A heterogeneous, counting Env.""" def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): @@ -1569,13 +1574,14 @@ def _set_seed(self, seed: Optional[int]): torch.manual_seed(seed) -class MultiKeyCountingEnvPolicy: +class MultiKeyCountingEnvPolicy(TensorDictModuleBase): def __init__( self, full_action_spec: TensorSpec, count: bool = True, deterministic: bool = False, ): + super().__init__() if not deterministic and not count: raise ValueError("Not counting policy is always deterministic") diff --git a/test/test_collector.py b/test/test_collector.py index ce7cade5746..2e090ad8fcf 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the 
root directory of this source tree. +from __future__ import annotations import argparse import logging @@ -11,13 +12,16 @@ import numpy as np import pytest import torch + from _utils_internal import ( check_rollout_consistency_multikey_env, decorate_thread_sub_func, generate_seeds, + get_available_devices, get_default_devices, PENDULUM_VERSIONED, PONG_VERSIONED, + retry, ) from mocking_classes import ( ContinuousActionVecMockEnv, @@ -28,16 +32,15 @@ DiscreteActionConvPolicy, DiscreteActionVecMockEnv, DiscreteActionVecPolicy, - HeteroCountingEnv, - HeteroCountingEnvPolicy, + HeterogeneousCountingEnv, + HeterogeneousCountingEnvPolicy, MockSerialEnv, MultiKeyCountingEnv, MultiKeyCountingEnvPolicy, NestedCountingEnv, ) -from tensordict import LazyStackedTensorDict -from tensordict.nn import TensorDictModule, TensorDictSequential -from tensordict.tensordict import assert_allclose_td, TensorDict +from tensordict import assert_allclose_td, LazyStackedTensorDict, TensorDict +from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictSequential from torch import nn from torchrl._utils import _replace_last, prod, seed_generator @@ -95,16 +98,15 @@ def __init__(self, out_features: int): self.linear = nn.LazyLinear(out_features) def forward(self, tensordict): - return TensorDict( - {self.out_keys[0]: self.linear(tensordict.get(self.in_keys[0]))}, - [], + return tensordict.set( + self.out_keys[0], self.linear(tensordict.get(self.in_keys[0])) ) class UnwrappablePolicy(nn.Module): def __init__(self, out_features: int): super().__init__() - self.linear = nn.LazyLinear(out_features) + self.linear = nn.Linear(2, out_features) def forward(self, observation, other_stuff): return self.linear(observation), other_stuff.sum() @@ -163,110 +165,360 @@ def make_policy(env): raise NotImplementedError -def _is_consistent_device_type( - device_type, policy_device_type, storing_device_type, tensordict_device_type -): - if storing_device_type is None: - if device_type is None: - if policy_device_type is None: - return tensordict_device_type == "cpu" +# def _is_consistent_device_type( +# device_type, policy_device_type, storing_device_type, tensordict_device_type +# ): +# if storing_device_type is None: +# if device_type is None: +# if policy_device_type is None: +# return tensordict_device_type == "cpu" +# +# return tensordict_device_type == policy_device_type +# +# return tensordict_device_type == device_type +# +# return tensordict_device_type == storing_device_type - return tensordict_device_type == policy_device_type - return tensordict_device_type == device_type +class TestCollectorDevices: + class DeviceLessEnv(EnvBase): + # receives data on cpu, outputs on gpu -- tensordict has no device + def __init__(self, default_device): + self.default_device = default_device + super().__init__(device=None) + self.observation_spec = CompositeSpec( + observation=UnboundedContinuousTensorSpec((), device=default_device) + ) + self.reward_spec = UnboundedContinuousTensorSpec(1, device=default_device) + self.full_done_spec = CompositeSpec( + done=UnboundedContinuousTensorSpec( + 1, dtype=torch.bool, device=self.default_device + ), + truncated=UnboundedContinuousTensorSpec( + 1, dtype=torch.bool, device=self.default_device + ), + terminated=UnboundedContinuousTensorSpec( + 1, dtype=torch.bool, device=self.default_device + ), + ) + self.action_spec = UnboundedContinuousTensorSpec((), device=None) + assert self.device is None + assert self.full_observation_spec is not None + assert self.full_done_spec is not None + 
assert self.full_state_spec is not None + assert self.full_action_spec is not None + assert self.full_reward_spec is not None + + def _step(self, tensordict): + assert tensordict.device is None + with torch.device(self.default_device): + return TensorDict( + { + "observation": torch.zeros(()), + "reward": torch.zeros((1,)), + "done": torch.zeros((1,), dtype=torch.bool), + "terminated": torch.zeros((1,), dtype=torch.bool), + "truncated": torch.zeros((1,), dtype=torch.bool), + }, + batch_size=[], + device=None, + ) - return tensordict_device_type == storing_device_type + def _reset(self, tensordict=None): + with torch.device(self.default_device): + return TensorDict( + { + "observation": torch.zeros(()), + "done": torch.zeros((1,), dtype=torch.bool), + "terminated": torch.zeros((1,), dtype=torch.bool), + "truncated": torch.zeros((1,), dtype=torch.bool), + }, + batch_size=[], + device=None, + ) + def _set_seed(self, seed: int | None = None): + return seed -@pytest.mark.skipif( - IS_WINDOWS and PYTHON_3_10, - reason="Windows Access Violation in torch.multiprocessing / BrokenPipeError in multiprocessing.connection", -) -@pytest.mark.parametrize("num_env", [2]) -@pytest.mark.parametrize("device", ["cuda", "cpu", None]) -@pytest.mark.parametrize("policy_device", ["cuda", "cpu", None]) -@pytest.mark.parametrize("storing_device", ["cuda", "cpu", None]) -def test_output_device_consistency( - num_env, device, policy_device, storing_device, seed=40 -): - if ( - device == "cuda" or policy_device == "cuda" or storing_device == "cuda" - ) and not torch.cuda.is_available(): - pytest.skip("cuda is not available") - - if IS_WINDOWS and PYTHON_3_7: - if device == "cuda" and policy_device == "cuda" and device is None: - pytest.skip( - "BrokenPipeError in multiprocessing.connection with Python 3.7 on Windows" + class EnvWithDevice(EnvBase): + def __init__(self, default_device): + self.default_device = default_device + super().__init__(device=self.default_device) + self.observation_spec = CompositeSpec( + observation=UnboundedContinuousTensorSpec( + (), device=self.default_device + ) ) + self.reward_spec = UnboundedContinuousTensorSpec( + 1, device=self.default_device + ) + self.full_done_spec = CompositeSpec( + done=UnboundedContinuousTensorSpec( + 1, dtype=torch.bool, device=self.default_device + ), + truncated=UnboundedContinuousTensorSpec( + 1, dtype=torch.bool, device=self.default_device + ), + terminated=UnboundedContinuousTensorSpec( + 1, dtype=torch.bool, device=self.default_device + ), + device=self.default_device, + ) + self.action_spec = UnboundedContinuousTensorSpec( + (), device=self.default_device + ) + assert self.device == torch.device(self.default_device) + assert self.full_observation_spec is not None + assert self.full_done_spec is not None + assert self.full_state_spec is not None + assert self.full_action_spec is not None + assert self.full_reward_spec is not None + + def _step(self, tensordict): + assert tensordict.device == torch.device(self.default_device) + with torch.device(self.default_device): + return TensorDict( + { + "observation": torch.zeros(()), + "reward": torch.zeros((1,)), + "done": torch.zeros((1,), dtype=torch.bool), + "terminated": torch.zeros((1,), dtype=torch.bool), + "truncated": torch.zeros((1,), dtype=torch.bool), + }, + batch_size=[], + device=self.default_device, + ) - _device = "cuda:0" if device == "cuda" else device - _policy_device = "cuda:0" if policy_device == "cuda" else policy_device - _storing_device = "cuda:0" if storing_device == "cuda" else 
storing_device - - if num_env == 1: - - def env_fn(seed): - env = make_make_env("vec")() - env.set_seed(seed) - return env + def _reset(self, tensordict=None): + with torch.device(self.default_device): + return TensorDict( + { + "observation": torch.zeros(()), + "done": torch.zeros((1,), dtype=torch.bool), + "terminated": torch.zeros((1,), dtype=torch.bool), + "truncated": torch.zeros((1,), dtype=torch.bool), + }, + batch_size=[], + device=self.default_device, + ) - else: + def _set_seed(self, seed: int | None = None): + return seed - def env_fn(seed): - # 1226: faster execution - # env = ParallelEnv( - env = SerialEnv( - num_workers=num_env, - create_env_fn=make_make_env("vec"), - create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)], - ) - return env + class DeviceLessPolicy(TensorDictModuleBase): + in_keys = ["observation"] + out_keys = ["action"] - if _policy_device is None: - policy = make_policy("vec") - else: - policy = ParametricPolicy().to(torch.device(_policy_device)) + # receives data on gpu and outputs on cpu + def forward(self, tensordict): + assert tensordict.device is None + return tensordict.set("action", torch.zeros((), device="cpu")) - collector = SyncDataCollector( - create_env_fn=env_fn, - create_env_kwargs={"seed": seed}, - policy=policy, - frames_per_batch=20, - max_frames_per_traj=2000, - total_frames=20000, - device=_device, - storing_device=_storing_device, - ) - for _, d in enumerate(collector): - assert _is_consistent_device_type( - device, policy_device, storing_device, d.device.type + class PolicyWithDevice(TensorDictModuleBase): + in_keys = ["observation"] + out_keys = ["action"] + # receives and sends data on gpu + default_device = "cuda:0" if torch.cuda.device_count() else "cpu" + + def forward(self, tensordict): + assert tensordict.device == torch.device(self.default_device) + return tensordict.set("action", torch.zeros((), device=self.default_device)) + + @pytest.mark.parametrize("main_device", get_default_devices()) + @pytest.mark.parametrize("storing_device", [None, *get_default_devices()]) + def test_output_device(self, main_device, storing_device): + + # env has no device, policy is strictly on GPU + device = None + env_device = None + policy_device = main_device + env = self.DeviceLessEnv(main_device) + policy = self.PolicyWithDevice() + collector = SyncDataCollector( + env, + policy, + device=device, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + frames_per_batch=1, + total_frames=10, ) - break - assert d.names[-1] == "time" + for data in collector: # noqa: B007 + break - collector.shutdown() + assert data.device == storing_device - ccollector = aSyncDataCollector( - create_env_fn=env_fn, - create_env_kwargs={"seed": seed}, - policy=policy, - frames_per_batch=20, - max_frames_per_traj=2000, - total_frames=20000, - device=_device, - storing_device=_storing_device, - ) - - for _, d in enumerate(ccollector): - assert _is_consistent_device_type( - device, policy_device, storing_device, d.device.type + # env is on cuda, policy has no device + device = None + env_device = main_device + policy_device = None + env = self.EnvWithDevice(main_device) + policy = self.DeviceLessPolicy() + collector = SyncDataCollector( + env, + policy, + device=device, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + frames_per_batch=1, + total_frames=10, ) - break - assert d.names[-1] == "time" - - ccollector.shutdown() - del ccollector + for data in collector: # noqa: B007 + 
break + assert data.device == storing_device + + # env and policy are on device + device = main_device + env_device = None + policy_device = None + env = self.EnvWithDevice(main_device) + policy = self.PolicyWithDevice() + collector = SyncDataCollector( + env, + policy, + device=device, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + frames_per_batch=1, + total_frames=10, + ) + for data in collector: # noqa: B007 + break + assert data.device == main_device + + # same but more specific + device = None + env_device = main_device + policy_device = main_device + env = self.EnvWithDevice(main_device) + policy = self.PolicyWithDevice() + collector = SyncDataCollector( + env, + policy, + device=device, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + frames_per_batch=1, + total_frames=10, + ) + for data in collector: # noqa: B007 + break + assert data.device == main_device + + # none has a device + device = None + env_device = None + policy_device = None + env = self.DeviceLessEnv(main_device) + policy = self.DeviceLessPolicy() + collector = SyncDataCollector( + env, + policy, + device=device, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + frames_per_batch=1, + total_frames=10, + ) + for data in collector: # noqa: B007 + break + assert data.device == storing_device + + +# @pytest.mark.skipif( +# IS_WINDOWS and PYTHON_3_10, +# reason="Windows Access Violation in torch.multiprocessing / BrokenPipeError in multiprocessing.connection", +# ) +# @pytest.mark.parametrize("num_env", [2]) +# @pytest.mark.parametrize("device", ["cuda", "cpu", None]) +# @pytest.mark.parametrize("policy_device", ["cuda", "cpu", None]) +# @pytest.mark.parametrize("storing_device", ["cuda", "cpu", None]) +# def test_output_device_consistency( +# num_env, device, policy_device, storing_device, seed=40 +# ): +# if ( +# device == "cuda" or policy_device == "cuda" or storing_device == "cuda" +# ) and not torch.cuda.is_available(): +# pytest.skip("cuda is not available") +# +# if IS_WINDOWS and PYTHON_3_7: +# if device == "cuda" and policy_device == "cuda" and device is None: +# pytest.skip( +# "BrokenPipeError in multiprocessing.connection with Python 3.7 on Windows" +# ) +# +# _device = "cuda:0" if device == "cuda" else device +# _policy_device = "cuda:0" if policy_device == "cuda" else policy_device +# _storing_device = "cuda:0" if storing_device == "cuda" else storing_device +# +# if num_env == 1: +# +# def env_fn(seed): +# env = make_make_env("vec")() +# env.set_seed(seed) +# return env +# +# else: +# +# def env_fn(seed): +# # 1226: faster execution +# # env = ParallelEnv( +# env = SerialEnv( +# num_workers=num_env, +# create_env_fn=make_make_env("vec"), +# create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)], +# ) +# return env +# +# if _policy_device is None: +# policy = make_policy("vec") +# else: +# policy = ParametricPolicy().to(torch.device(_policy_device)) +# +# collector = SyncDataCollector( +# create_env_fn=env_fn, +# create_env_kwargs={"seed": seed}, +# policy=policy, +# frames_per_batch=20, +# max_frames_per_traj=2000, +# total_frames=20000, +# device=_device, +# storing_device=_storing_device, +# ) +# for _, d in enumerate(collector): +# assert _is_consistent_device_type( +# device, policy_device, storing_device, d.device.type +# ) +# break +# assert d.names[-1] == "time" +# +# collector.shutdown() +# +# ccollector = aSyncDataCollector( +# create_env_fn=env_fn, +# 
create_env_kwargs={"seed": seed}, +# policy=policy, +# frames_per_batch=20, +# max_frames_per_traj=2000, +# total_frames=20000, +# device=_device, +# storing_device=_storing_device, +# ) +# +# for _, d in enumerate(ccollector): +# assert _is_consistent_device_type( +# device, policy_device, storing_device, d.device.type +# ) +# break +# assert d.names[-1] == "time" +# +# ccollector.shutdown() +# del ccollector @pytest.mark.parametrize("num_env", [1, 2]) @@ -830,7 +1082,10 @@ def test_collector_vecnorm_envcreator(static_seed): policy = RandomPolicy(env_make.action_spec) num_data_collectors = 2 c = MultiSyncDataCollector( - [env_make] * num_data_collectors, policy=policy, total_frames=int(1e6) + [env_make] * num_data_collectors, + policy=policy, + total_frames=int(1e6), + frames_per_batch=200, ) init_seed = 0 @@ -889,8 +1144,9 @@ def create_env(): collector = collector_class( [create_env] * 3, policy=policy, - devices=[torch.device("cuda:0")] * 3, - storing_devices=[torch.device("cuda:0")] * 3, + device=[torch.device("cuda:0")] * 3, + storing_device=[torch.device("cuda:0")] * 3, + frames_per_batch=20, ) # collect state_dict state_dict = collector.state_dict() @@ -1125,10 +1381,10 @@ def env_fn(seed): frames_per_batch=20, max_frames_per_traj=2000, total_frames=20000, - devices=[ + device=[ device, ], - storing_devices=[ + storing_device=[ storing_device, ], ) @@ -1147,10 +1403,10 @@ def env_fn(seed): frames_per_batch=20, max_frames_per_traj=2000, total_frames=20000, - devices=[ + device=[ device, ], - storing_devices=[ + storing_device=[ storing_device, ], ) @@ -1170,7 +1426,7 @@ def env_fn(seed): ], ) class TestAutoWrap: - num_envs = 2 + num_envs = 1 @pytest.fixture def env_maker(self): @@ -1193,30 +1449,44 @@ def _create_collector_kwargs(self, env_maker, collector_class, policy): return collector_kwargs - @pytest.mark.parametrize("multiple_outputs", [False, True]) - def test_auto_wrap_modules(self, collector_class, multiple_outputs, env_maker): + @pytest.mark.parametrize("multiple_outputs", [True, False]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_auto_wrap_modules( + self, collector_class, multiple_outputs, env_maker, device + ): policy = WrappablePolicy( out_features=env_maker().action_spec.shape[-1], multiple_outputs=multiple_outputs, ) + # init lazy params + policy(env_maker().reset().get("observation")) + collector = collector_class( - **self._create_collector_kwargs(env_maker, collector_class, policy) + **self._create_collector_kwargs(env_maker, collector_class, policy), + device=device, ) out_keys = ["action"] if multiple_outputs: out_keys.extend(f"output{i}" for i in range(1, 4)) - if collector_class is not SyncDataCollector: - assert all( - isinstance(p, TensorDictModule) for p in collector._policy_dict.values() - ) - assert all(p.out_keys == out_keys for p in collector._policy_dict.values()) - assert all(p.module is policy for p in collector._policy_dict.values()) - else: + if collector_class is SyncDataCollector: assert isinstance(collector.policy, TensorDictModule) assert collector.policy.out_keys == out_keys - assert collector.policy.module is policy + # this does not work now that we force the device of the policy + # assert collector.policy.module is policy + + for i, data in enumerate(collector): + if i == 0: + assert (data["action"] != 0).any() + for p in policy.parameters(): + p.data.zero_() + assert p.device == torch.device("cpu") + collector.update_policy_weights_() + elif i == 4: + assert (data["action"] == 0).all() + break + 
collector.shutdown() del collector @@ -1231,28 +1501,33 @@ def test_no_wrap_compatible_module(self, collector_class, env_maker): ) if collector_class is not SyncDataCollector: - assert all( - isinstance(p, TensorDictCompatiblePolicy) - for p in collector._policy_dict.values() - ) - assert all( - p.out_keys == ["action"] for p in collector._policy_dict.values() - ) - assert all(p is policy for p in collector._policy_dict.values()) + # We now do the casting only on the remote workers + pass else: assert isinstance(collector.policy, TensorDictCompatiblePolicy) assert collector.policy.out_keys == ["action"] assert collector.policy is policy + + for i, data in enumerate(collector): + if i == 0: + assert (data["action"] != 0).any() + for p in policy.parameters(): + p.data.zero_() + assert p.device == torch.device("cpu") + collector.update_policy_weights_() + elif i == 4: + assert (data["action"] == 0).all() + break + collector.shutdown() del collector def test_auto_wrap_error(self, collector_class, env_maker): policy = UnwrappablePolicy(out_features=env_maker().action_spec.shape[-1]) - with pytest.raises( TypeError, match=(r"Arguments to policy.forward are incompatible with entries in"), - ): + ) if collector_class is SyncDataCollector else pytest.raises(EOFError): collector_class( **self._create_collector_kwargs(env_maker, collector_class, policy) ) @@ -1367,8 +1642,7 @@ def env_fn(seed): frames_per_batch=frames_per_batch, init_random_frames=-1, reset_at_each_iter=False, - devices=get_default_devices()[0], - storing_devices=get_default_devices()[0], + device=get_default_devices()[0], split_trajs=False, preemptive_threshold=0.0, # stop after one iteration ) @@ -1396,23 +1670,44 @@ def test_maxframes_error(): ) -def test_reset_heterogeneous_envs(): +@retry(AssertionError, tries=10, delay=0) +@pytest.mark.parametrize("policy_device", [None, *get_available_devices()]) +@pytest.mark.parametrize("env_device", [None, *get_available_devices()]) +@pytest.mark.parametrize("storing_device", [None, *get_available_devices()]) +def test_reset_heterogeneous_envs( + policy_device: torch.device, env_device: torch.device, storing_device: torch.device +): + if ( + policy_device is not None + and policy_device.type == "cuda" + and env_device is None + ): + env_device = torch.device("cpu") # explicit mapping + elif env_device is not None and env_device.type == "cuda" and policy_device is None: + policy_device = torch.device("cpu") env1 = lambda: TransformedEnv(CountingEnv(), StepCounter(2)) env2 = lambda: TransformedEnv(CountingEnv(), StepCounter(3)) - env = SerialEnv(2, [env1, env2]) + env = SerialEnv(2, [env1, env2], device=env_device) collector = SyncDataCollector( - env, RandomPolicy(env.action_spec), total_frames=10_000, frames_per_batch=1000 + env, + RandomPolicy(env.action_spec), + total_frames=10_000, + frames_per_batch=100, + policy_device=policy_device, + env_device=env_device, + storing_device=storing_device, ) try: for data in collector: # noqa: B007 break + data_device = storing_device if storing_device is not None else env_device assert ( data[0]["next", "truncated"].squeeze() - == torch.tensor([False, True]).repeat(250)[:500] + == torch.tensor([False, True], device=data_device).repeat(25)[:50] ).all(), data[0]["next", "truncated"][:10] assert ( data[1]["next", "truncated"].squeeze() - == torch.tensor([False, False, True]).repeat(168)[:500] + == torch.tensor([False, False, True], device=data_device).repeat(17)[:50] ).all(), data[1]["next", "truncated"][:10] finally: collector.shutdown() @@ -1554,15 
+1849,17 @@ def test_nested_env_dims(self, batch_size, nested_dim=5, frames_per_batch=20): ) -class TestHetEnvsCollector: +class TestHeterogeneousEnvsCollector: @pytest.mark.parametrize("batch_size", [(), (2,), (2, 1)]) @pytest.mark.parametrize("frames_per_batch", [4, 8, 16]) - def test_collector_het_env(self, batch_size, frames_per_batch, seed=1, max_steps=4): + def test_collector_heterogeneous_env( + self, batch_size, frames_per_batch, seed=1, max_steps=4 + ): batch_size = torch.Size(batch_size) - env = HeteroCountingEnv(max_steps=max_steps - 1, batch_size=batch_size) + env = HeterogeneousCountingEnv(max_steps=max_steps - 1, batch_size=batch_size) torch.manual_seed(seed) device = get_default_devices()[0] - policy = HeteroCountingEnvPolicy(env.input_spec["full_action_spec"]) + policy = HeterogeneousCountingEnvPolicy(env.input_spec["full_action_spec"]) ccollector = SyncDataCollector( create_env_fn=env, policy=policy, @@ -1590,14 +1887,14 @@ def test_collector_het_env(self, batch_size, frames_per_batch, seed=1, max_steps assert (_td["lazy"][..., i]["action"] == 1).all() del ccollector - def test_multi_collector_het_env_consistency( + def test_multi_collector_heterogeneous_env_consistency( self, seed=1, frames_per_batch=20, batch_dim=10 ): - env = HeteroCountingEnv(max_steps=3, batch_size=(batch_dim,)) + env = HeterogeneousCountingEnv(max_steps=3, batch_size=(batch_dim,)) torch.manual_seed(seed) env_fn = lambda: TransformedEnv(env, InitTracker()) check_env_specs(env_fn(), return_contiguous=False) - policy = HeteroCountingEnvPolicy(env.input_spec["full_action_spec"]) + policy = HeterogeneousCountingEnvPolicy(env.input_spec["full_action_spec"]) ccollector = MultiaSyncDataCollector( create_env_fn=[env_fn], @@ -1649,13 +1946,16 @@ class TestMultiKeyEnvsCollector: def test_collector(self, batch_size, frames_per_batch, max_steps, seed=1): env = MultiKeyCountingEnv(batch_size=batch_size, max_steps=max_steps) torch.manual_seed(seed) - policy = MultiKeyCountingEnvPolicy(env.input_spec["full_action_spec"]) + device = get_default_devices()[0] + policy = MultiKeyCountingEnvPolicy( + env.input_spec["full_action_spec"].to(device) + ) ccollector = SyncDataCollector( create_env_fn=env, policy=policy, frames_per_batch=frames_per_batch, total_frames=100, - device=get_default_devices()[0], + device=device, ) for _td in ccollector: @@ -1672,8 +1972,9 @@ def test_multi_collector_consistency( env = MultiKeyCountingEnv(batch_size=(batch_dim,)) env_fn = lambda: env torch.manual_seed(seed) + device = get_default_devices()[0] policy = MultiKeyCountingEnvPolicy( - env.input_spec["full_action_spec"], deterministic=True + env.input_spec["full_action_spec"].to(device), deterministic=True ) ccollector = MultiaSyncDataCollector( @@ -1681,7 +1982,7 @@ def test_multi_collector_consistency( policy=policy, frames_per_batch=frames_per_batch, total_frames=100, - device=get_default_devices()[0], + device=device, ) for i, d in enumerate(ccollector): if i == 0: @@ -1748,11 +2049,14 @@ def _step( **self.full_done_spec.zero(), }, self.batch_size, + device=self.device, ) def _reset(self, tensordict=None): self.state.zero_() - return TensorDict({"state": self.state.clone()}, self.batch_size) + return TensorDict( + {"state": self.state.clone()}, self.batch_size, device=self.device + ) def _set_seed(self, seed): return seed diff --git a/test/test_cost.py b/test/test_cost.py index 9561f6063e4..87e17eb252c 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -47,11 +47,11 @@ get_default_devices, ) from mocking_classes import 
ContinuousActionConvMockEnv -from tensordict.nn import NormalParamExtractor, TensorDictModule -from tensordict.nn.utils import Buffer # from torchrl.data.postprocs.utils import expand_as_right -from tensordict.tensordict import assert_allclose_td, TensorDict +from tensordict import assert_allclose_td, TensorDict +from tensordict.nn import NormalParamExtractor, TensorDictModule +from tensordict.nn.utils import Buffer from tensordict.utils import unravel_key from torch import autograd, nn from torchrl.data import ( diff --git a/test/test_distributions.py b/test/test_distributions.py index 30bb0288dd4..e6f228628a4 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -10,7 +10,7 @@ import torch.nn.functional as F from _utils_internal import get_default_devices -from tensordict.tensordict import TensorDictBase +from tensordict import TensorDictBase from torch import autograd, nn from torchrl.modules import ( NormalParamWrapper, diff --git a/test/test_env.py b/test/test_env.py index a6f6873f9ba..eaa31007186 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -38,8 +38,8 @@ DiscreteActionConvMockEnvNumpy, DiscreteActionVecMockEnv, DummyModelBasedEnvBase, - HeteroCountingEnv, - HeteroCountingEnvPolicy, + HeterogeneousCountingEnv, + HeterogeneousCountingEnvPolicy, MockBatchedLockedEnv, MockBatchedUnLockedEnv, MockSerialEnv, @@ -48,9 +48,13 @@ NestedCountingEnv, ) from packaging import version -from tensordict import dense_stack_tds +from tensordict import ( + assert_allclose_td, + dense_stack_tds, + LazyStackedTensorDict, + TensorDict, +) from tensordict.nn import TensorDictModuleBase -from tensordict.tensordict import assert_allclose_td, LazyStackedTensorDict, TensorDict from tensordict.utils import _unravel_key_to_tuple from torch import nn @@ -2034,12 +2038,12 @@ def test_nested_reset(self, nest_done, has_root_done, batch_size): class TestHeteroEnvs: @pytest.mark.parametrize("batch_size", [(), (32,), (1, 2)]) def test_reset(self, batch_size): - env = HeteroCountingEnv(batch_size=batch_size) + env = HeterogeneousCountingEnv(batch_size=batch_size) env.reset() @pytest.mark.parametrize("batch_size", [(), (32,), (1, 2)]) def test_rand_step(self, batch_size): - env = HeteroCountingEnv(batch_size=batch_size) + env = HeterogeneousCountingEnv(batch_size=batch_size) td = env.reset() assert (td["lazy"][..., 0]["tensor_0"] == 0).all() td = env.rand_step() @@ -2050,7 +2054,7 @@ def test_rand_step(self, batch_size): @pytest.mark.parametrize("batch_size", [(), (2,), (2, 1)]) @pytest.mark.parametrize("rollout_steps", [1, 2, 5]) def test_rollout(self, batch_size, rollout_steps, n_lazy_dim=3): - env = HeteroCountingEnv(batch_size=batch_size) + env = HeterogeneousCountingEnv(batch_size=batch_size) td = env.rollout(rollout_steps, return_contiguous=False) td = dense_stack_tds(td) @@ -2072,8 +2076,8 @@ def test_rollout(self, batch_size, rollout_steps, n_lazy_dim=3): @pytest.mark.parametrize("rollout_steps", [1, 2, 5]) @pytest.mark.parametrize("count", [True, False]) def test_rollout_policy(self, batch_size, rollout_steps, count): - env = HeteroCountingEnv(batch_size=batch_size) - policy = HeteroCountingEnvPolicy( + env = HeterogeneousCountingEnv(batch_size=batch_size) + policy = HeterogeneousCountingEnvPolicy( env.input_spec["full_action_spec"], count=count ) td = env.rollout(rollout_steps, policy=policy, return_contiguous=False) @@ -2091,14 +2095,14 @@ def test_rollout_policy(self, batch_size, rollout_steps, count): @pytest.mark.parametrize("batch_size", [(1, 2)]) 
@pytest.mark.parametrize("env_type", ["serial", "parallel"]) def test_vec_env(self, batch_size, env_type, rollout_steps=4, n_workers=2): - env_fun = lambda: HeteroCountingEnv(batch_size=batch_size) + env_fun = lambda: HeterogeneousCountingEnv(batch_size=batch_size) if env_type == "serial": vec_env = SerialEnv(n_workers, env_fun) else: vec_env = ParallelEnv(n_workers, env_fun) vec_batch_size = (n_workers,) + batch_size # check_env_specs(vec_env, return_contiguous=False) - policy = HeteroCountingEnvPolicy(vec_env.input_spec["full_action_spec"]) + policy = HeterogeneousCountingEnvPolicy(vec_env.input_spec["full_action_spec"]) vec_env.reset() td = vec_env.rollout( rollout_steps, @@ -2173,7 +2177,7 @@ def test_parallel( MockBatchedUnLockedEnv, MockSerialEnv, NestedCountingEnv, - HeteroCountingEnv, + HeterogeneousCountingEnv, MultiKeyCountingEnv, ], ) diff --git a/test/test_exploration.py b/test/test_exploration.py index 0d916e5d5e9..777f2714edb 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -14,9 +14,9 @@ NestedCountingEnv, ) from scipy.stats import ttest_1samp +from tensordict import TensorDict from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential -from tensordict.tensordict import TensorDict from torch import nn from torchrl._utils import _replace_last diff --git a/test/test_libs.py b/test/test_libs.py index 5fcf3497139..a1414948817 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -43,13 +43,12 @@ rollout_consistency_assertion, ) from packaging import version -from tensordict import LazyStackedTensorDict +from tensordict import assert_allclose_td, LazyStackedTensorDict, TensorDict from tensordict.nn import ( ProbabilisticTensorDictModule, TensorDictModule, TensorDictSequential, ) -from tensordict.tensordict import assert_allclose_td, TensorDict from torch import nn from torchrl._utils import implement_for from torchrl.collectors.collectors import RandomPolicy, SyncDataCollector diff --git a/test/test_postprocs.py b/test/test_postprocs.py index e28ee2eb592..10a559d3cac 100644 --- a/test/test_postprocs.py +++ b/test/test_postprocs.py @@ -7,7 +7,7 @@ import pytest import torch from _utils_internal import get_default_devices -from tensordict.tensordict import assert_allclose_td, TensorDict +from tensordict import assert_allclose_td, TensorDict from torchrl.collectors.utils import split_trajectories from torchrl.data.postprocs.postprocs import MultiStep diff --git a/test/test_rb.py b/test/test_rb.py index 96f392d5a22..4b9b1a5dc9f 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -19,8 +19,14 @@ from _utils_internal import get_default_devices, make_tc from packaging import version from packaging.version import parse -from tensordict import is_tensor_collection, is_tensorclass, tensorclass -from tensordict.tensordict import assert_allclose_td, TensorDict, TensorDictBase +from tensordict import ( + assert_allclose_td, + is_tensor_collection, + is_tensorclass, + tensorclass, + TensorDict, + TensorDictBase, +) from torch import multiprocessing as mp from torch.utils._pytree import tree_flatten, tree_map from torchrl.data import ( diff --git a/test/test_rb_distributed.py b/test/test_rb_distributed.py index 1d5a2398e92..a31836a4e72 100644 --- a/test/test_rb_distributed.py +++ b/test/test_rb_distributed.py @@ -13,7 +13,7 @@ import torch import torch.distributed.rpc as rpc import torch.multiprocessing as mp -from tensordict.tensordict import TensorDict +from tensordict import TensorDict from torchrl.data.replay_buffers import 
RemoteTensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import RandomSampler from torchrl.data.replay_buffers.storages import LazyMemmapStorage diff --git a/test/test_specs.py b/test/test_specs.py index cc97be11918..36f5aef65ca 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import argparse +import contextlib import numpy as np import pytest @@ -10,7 +11,7 @@ import torchrl.data.tensor_specs from _utils_internal import get_available_devices, get_default_devices, set_global_var from scipy.stats import chisquare -from tensordict.tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase +from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase from tensordict.utils import _unravel_key_to_tuple from torchrl.data.tensor_specs import ( @@ -341,7 +342,7 @@ def test_multi_discrete_conversion(ns, shape, device): @pytest.mark.parametrize("is_complete", [True, False]) -@pytest.mark.parametrize("device", get_default_devices()) +@pytest.mark.parametrize("device", [None, *get_default_devices()]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None]) @pytest.mark.parametrize("shape", [(), (2, 3)]) class TestComposite: @@ -368,6 +369,7 @@ def _composite_spec(shape, is_complete=True, device=None, dtype=None): if is_complete else None, shape=shape, + device=device, ) def test_getitem(self, shape, is_complete, device, dtype): @@ -390,18 +392,26 @@ def test_setitem_forbidden_keys(self, shape, is_complete, device, dtype): def test_setitem_matches_device(self, shape, is_complete, device, dtype, dest): ts = self._composite_spec(shape, is_complete, device, dtype) - if dest == device: - ts["good"] = UnboundedContinuousTensorSpec( + ts["good"] = UnboundedContinuousTensorSpec( + shape=shape, device=device, dtype=dtype + ) + cm = ( + contextlib.nullcontext() + if (device == dest) or (device is None) + else pytest.raises( + RuntimeError, match="All devices of CompositeSpec must match" + ) + ) + with cm: + # auto-casting is introduced since v0.3 + ts["bad"] = UnboundedContinuousTensorSpec( shape=shape, device=dest, dtype=dtype ) - assert ts["good"].device == dest - else: - with pytest.raises( - RuntimeError, match="All devices of CompositeSpec must match" - ): - ts["bad"] = UnboundedContinuousTensorSpec( - shape=shape, device=dest, dtype=dtype - ) + assert ts.device == device + assert ts["good"].device == ( + device if device is not None else torch.zeros(()).device + ) + assert ts["bad"].device == (device if device is not None else dest) def test_del(self, shape, is_complete, device, dtype): ts = self._composite_spec(shape, is_complete, device, dtype) @@ -682,9 +692,12 @@ def test_create_composite_nested(shape, device): c = CompositeSpec(_d, shape=shape) assert isinstance(c["a", "b"], UnboundedContinuousTensorSpec) assert c["a"].shape == torch.Size(shape) + assert c.device is None # device not explicitly passed + assert c["a"].device is None # device not explicitly passed + assert c["a", "b"].device == device + c = c.to(device) assert c.device == device assert c["a"].device == device - assert c["a", "b"].device == device @pytest.mark.parametrize("recurse", [True, False]) @@ -2277,7 +2290,7 @@ def test_stack(self): class TestLazyStackedCompositeSpecs: - def _get_het_specs( + def _get_heterogeneous_specs( self, batch_size=(), stack_dim: int = 0, @@ -2362,7 +2375,7 @@ def _get_het_specs( ), ] - return 
torch.stack(spec_list, dim=stack_dim) + return torch.stack(spec_list, dim=stack_dim).cpu() def test_stack_index(self): c1 = CompositeSpec(a=UnboundedContinuousTensorSpec()) @@ -2640,7 +2653,7 @@ def test_unsqueeze(self): assert c.squeeze().shape == torch.Size([2, 3]) - c = self._get_het_specs() + c = self._get_heterogeneous_specs() cu = c.unsqueeze(0) assert cu.shape == torch.Size([1, 3]) cus = cu.squeeze(0) @@ -2648,14 +2661,14 @@ def test_unsqueeze(self): @pytest.mark.parametrize("batch_size", [(), (4,), (4, 2)]) def test_len(self, batch_size): - c = self._get_het_specs(batch_size=batch_size) + c = self._get_heterogeneous_specs(batch_size=batch_size) assert len(c) == c.shape[0] assert len(c) == len(c.rand()) @pytest.mark.parametrize("batch_size", [(), (4,), (4, 2)]) def test_eq(self, batch_size): - c = self._get_het_specs(batch_size=batch_size) - c2 = self._get_het_specs(batch_size=batch_size) + c = self._get_heterogeneous_specs(batch_size=batch_size) + c2 = self._get_heterogeneous_specs(batch_size=batch_size) assert c == c2 and not c != c2 assert c == c.clone() and not c != c.clone() @@ -2663,12 +2676,12 @@ def test_eq(self, batch_size): del c2["shared"] assert not c == c2 and c != c2 - c2 = self._get_het_specs(batch_size=batch_size) + c2 = self._get_heterogeneous_specs(batch_size=batch_size) del c2[0]["lidar"] assert not c == c2 and c != c2 - c2 = self._get_het_specs(batch_size=batch_size) + c2 = self._get_heterogeneous_specs(batch_size=batch_size) c2[0]["lidar"].space.low += 1 assert not c == c2 and c != c2 @@ -2676,7 +2689,7 @@ def test_eq(self, batch_size): @pytest.mark.parametrize("include_nested", [True, False]) @pytest.mark.parametrize("leaves_only", [True, False]) def test_del(self, batch_size, include_nested, leaves_only): - c = self._get_het_specs(batch_size=batch_size) + c = self._get_heterogeneous_specs(batch_size=batch_size) td_c = c.rand() keys = list(c.keys(include_nested=include_nested, leaves_only=leaves_only)) @@ -2709,7 +2722,7 @@ def test_del(self, batch_size, include_nested, leaves_only): @pytest.mark.parametrize("batch_size", [(), (4,), (4, 2)]) def test_is_in(self, batch_size): - c = self._get_het_specs(batch_size=batch_size) + c = self._get_heterogeneous_specs(batch_size=batch_size) td_c = c.rand() assert c.is_in(td_c) @@ -2735,7 +2748,7 @@ def test_is_in(self, batch_size): assert c.is_in(td_c) def test_type_check(self): - c = self._get_het_specs() + c = self._get_heterogeneous_specs() td_c = c.rand() c.type_check(td_c) @@ -2743,7 +2756,7 @@ def test_type_check(self): @pytest.mark.parametrize("batch_size", [(), (4,), (4, 2)]) def test_project(self, batch_size): - c = self._get_het_specs(batch_size=batch_size) + c = self._get_heterogeneous_specs(batch_size=batch_size) td_c = c.rand() assert c.is_in(td_c) val = c.project(td_c) @@ -2775,7 +2788,7 @@ def test_project(self, batch_size): assert c.is_in(td_c) def test_repr(self): - c = self._get_het_specs() + c = self._get_heterogeneous_specs() expected = f"""LazyStackedCompositeSpec( fields={{ @@ -2869,7 +2882,7 @@ def test_repr(self): @pytest.mark.parametrize("batch_size", [(), (2,), (2, 1)]) def test_consolidate_spec(self, batch_size): - spec = self._get_het_specs(batch_size) + spec = self._get_heterogeneous_specs(batch_size) spec_lazy = spec.clone() assert not check_no_exclusive_keys(spec_lazy) @@ -2938,8 +2951,8 @@ def test_consolidate_spec_exclusive_lazy_stacked(self, batch_size): @pytest.mark.parametrize("batch_size", [(2,), (2, 1)]) def test_update(self, batch_size, stack_dim=0): - spec = 
self._get_het_specs(batch_size, stack_dim) - spec2 = self._get_het_specs(batch_size, stack_dim) + spec = self._get_heterogeneous_specs(batch_size, stack_dim) + spec2 = self._get_heterogeneous_specs(batch_size, stack_dim) del spec2["shared"] spec2["hetero"] = spec2["hetero"].unsqueeze(-1) @@ -2964,7 +2977,7 @@ def test_update(self, batch_size, stack_dim=0): @pytest.mark.parametrize("batch_size", [(2,), (2, 1)]) @pytest.mark.parametrize("stack_dim", [0, 1]) def test_set_item(self, batch_size, stack_dim): - spec = self._get_het_specs(batch_size, stack_dim) + spec = self._get_heterogeneous_specs(batch_size, stack_dim) new = torch.stack( [UnboundedContinuousTensorSpec(shape=(*batch_size, i)) for i in range(3)], @@ -2998,8 +3011,8 @@ def test_set_item(self, batch_size, stack_dim): stack_dim, ) spec["comp"] = comp - assert spec["comp"] == comp - assert spec["comp", "a"] == new + assert spec["comp"] == comp.to(spec.device) + assert spec["comp", "a"] == new.to(spec.device) # MultiDiscreteTensorSpec: Pending resolution of https://github.com/pytorch/pytorch/issues/100080. diff --git a/test/test_transforms.py b/test/test_transforms.py index 2bc9f36a79b..b325a1ccd99 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -42,9 +42,8 @@ MultiKeyCountingEnvPolicy, NestedCountingEnv, ) -from tensordict import unravel_key +from tensordict import TensorDict, TensorDictBase, unravel_key from tensordict.nn import TensorDictSequential -from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td from torch import multiprocessing as mp, nn, Tensor from torchrl._utils import _replace_last, prod diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 08c86ffea45..661903b784d 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -6,6 +6,11 @@ import _pickle import abc + +import contextlib + +import functools + import inspect import logging import os @@ -17,17 +22,23 @@ from copy import deepcopy from multiprocessing import connection, queues from multiprocessing.managers import SyncManager - from textwrap import indent from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Tuple, Union import numpy as np import torch import torch.nn as nn + +from tensordict import ( + LazyStackedTensorDict, + TensorDict, + TensorDictBase, + TensorDictParams, +) from tensordict.nn import TensorDictModule, TensorDictModuleBase -from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import multiprocessing as mp +from torch.utils._pytree import tree_map from torch.utils.data import IterableDataset from torchrl._utils import ( @@ -77,7 +88,8 @@ class RandomPolicy: """ def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"): - self.action_spec = action_spec + super().__init__() + self.action_spec = action_spec.clone() self.action_key = action_key def __call__(self, td: TensorDictBase) -> TensorDictBase: @@ -142,10 +154,16 @@ def recursive_map_to_cpu(dictionary: OrderedDict) -> OrderedDict: def _policy_is_tensordict_compatible(policy: nn.Module): - sig = inspect.signature(policy.forward) + if isinstance(policy, _NonParametricPolicyWrapper) and isinstance( + policy.policy, RandomPolicy + ): + return True if isinstance(policy, TensorDictModuleBase): return True + + sig = inspect.signature(policy.forward) + if ( len(sig.parameters) == 1 and hasattr(policy, "in_keys") @@ -184,6 +202,74 @@ class 
DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): _iterator = None + def _make_compatible_policy(self, policy, observation_spec=None): + if policy is None: + if not hasattr(self, "env") or self.env is None: + raise ValueError( + "env must be provided to _get_policy_and_device if policy is None" + ) + policy = RandomPolicy(self.env.input_spec["full_action_spec"]) + # make sure policy is an nn.Module + policy = _NonParametricPolicyWrapper(policy) + if not _policy_is_tensordict_compatible(policy): + # policy is a nn.Module that doesn't operate on tensordicts directly + # so we attempt to auto-wrap policy with TensorDictModule + if observation_spec is None: + raise ValueError( + "Unable to read observation_spec from the environment. This is " + "required to check compatibility of the environment and policy " + "since the policy is a nn.Module that operates on tensors " + "rather than a TensorDictModule or a nn.Module that accepts a " + "TensorDict as input and defines in_keys and out_keys." + ) + + try: + # signature modified by make_functional + sig = policy.forward.__signature__ + except AttributeError: + sig = inspect.signature(policy.forward) + required_kwargs = { + str(k) for k, p in sig.parameters.items() if p.default is inspect._empty + } + next_observation = { + key: value for key, value in observation_spec.rand().items() + } + # we check if all the mandatory params are there + if set(sig.parameters) == {"tensordict"} or set(sig.parameters) == {"td"}: + pass + elif not required_kwargs.difference(set(next_observation)): + in_keys = [str(k) for k in sig.parameters if k in next_observation] + if not hasattr(self, "env") or self.env is None: + out_keys = ["action"] + else: + out_keys = list(self.env.action_keys) + for p in policy.parameters(): + policy_device = p.device + break + else: + policy_device = None + if policy_device: + next_observation = tree_map( + lambda x: x.to(policy_device), next_observation + ) + output = policy(**next_observation) + + if isinstance(output, tuple): + out_keys.extend(f"output{i + 1}" for i in range(len(output) - 1)) + + policy = TensorDictModule(policy, in_keys=in_keys, out_keys=out_keys) + else: + raise TypeError( + f"""Arguments to policy.forward are incompatible with entries in +env.observation_spec (got incongruent signatures: fun signature is {set(sig.parameters)} vs specs {set(next_observation)}). +If you want TorchRL to automatically wrap your policy with a TensorDictModule +then the arguments to policy.forward must correspond one-to-one with entries +in env.observation_spec that are prefixed with 'next_'. For more complex +behaviour and more control you can consider writing your own TensorDictModule. +""" + ) + return policy + def _get_policy_and_device( self, policy: Optional[ @@ -192,110 +278,46 @@ def _get_policy_and_device( Callable[[TensorDictBase], TensorDictBase], ] ] = None, - device: Optional[DEVICE_TYPING] = None, observation_spec: TensorSpec = None, - ) -> Tuple[TensorDictModule, torch.device, Union[None, Callable[[], dict]]]: + ) -> Tuple[TensorDictModule, Union[None, Callable[[], dict]]]: """Util method to get a policy and its device given the collector __init__ inputs. - From a policy and a device, assigns the self.device attribute to - the desired device and maps the policy onto it or (if the device is - ommitted) assigns the self.device attribute to the policy device. 
- Args: create_env_fn (Callable or list of callables): an env creator function (or a list of creators) create_env_kwargs (dictionary): kwargs for the env creator policy (TensorDictModule, optional): a policy to be used - device (int, str or torch.device, optional): device where to place - the policy observation_spec (TensorSpec, optional): spec of the observations """ - if policy is None: - if not hasattr(self, "env") or self.env is None: - raise ValueError( - "env must be provided to _get_policy_and_device if policy is None" - ) - policy = RandomPolicy(self.env.input_spec["full_action_spec"]) - elif isinstance(policy, nn.Module): - # TODO: revisit these checks when we have determined whether arbitrary - # callables should be supported as policies. - if not _policy_is_tensordict_compatible(policy): - # policy is a nn.Module that doesn't operate on tensordicts directly - # so we attempt to auto-wrap policy with TensorDictModule - if observation_spec is None: - raise ValueError( - "Unable to read observation_spec from the environment. This is " - "required to check compatibility of the environment and policy " - "since the policy is a nn.Module that operates on tensors " - "rather than a TensorDictModule or a nn.Module that accepts a " - "TensorDict as input and defines in_keys and out_keys." - ) + policy = self._make_compatible_policy(policy, observation_spec) + param_and_buf = TensorDict.from_module(policy, as_module=True) - try: - # signature modified by make_functional - sig = policy.forward.__signature__ - except AttributeError: - sig = inspect.signature(policy.forward) - required_params = { - str(k) - for k, p in sig.parameters.items() - if p.default is inspect._empty - } - next_observation = { - key: value for key, value in observation_spec.rand().items() - } - # we check if all the mandatory params are there - if not required_params.difference(set(next_observation)): - in_keys = [str(k) for k in sig.parameters if k in next_observation] - if not hasattr(self, "env") or self.env is None: - out_keys = ["action"] - else: - out_keys = self.env.action_keys - output = policy(**next_observation) + def get_weights_fn(param_and_buf=param_and_buf): + return param_and_buf.data - if isinstance(output, tuple): - out_keys.extend(f"output{i+1}" for i in range(len(output) - 1)) + if self.policy_device: + # create a stateless policy and populate it with params + def _map_to_device_params(param, device): + is_param = isinstance(param, nn.Parameter) - policy = TensorDictModule( - policy, in_keys=in_keys, out_keys=out_keys - ) - else: - raise TypeError( - f"""Arguments to policy.forward are incompatible with entries in -env.observation_spec (got incongruent signatures: fun signature is {set(sig.parameters)} vs specs {set(next_observation)}). -If you want TorchRL to automatically wrap your policy with a TensorDictModule -then the arguments to policy.forward must correspond one-to-one with entries -in env.observation_spec that are prefixed with 'next_'. For more complex -behaviour and more control you can consider writing your own TensorDictModule. 
-""" - ) + pd = param.detach().to(device, non_blocking=True) - try: - policy_device = next(policy.parameters()).device - except Exception: - policy_device = ( - torch.device(device) if device is not None else torch.device("cpu") - ) + if is_param: + pd = nn.Parameter(pd, requires_grad=False) + return pd - device = torch.device(device) if device is not None else policy_device - get_weights_fn = None - if policy_device != device: - param_and_buf = TensorDict.from_module(policy, as_module=True) + # Create a stateless policy, then populate this copy with params on device + with param_and_buf.apply( + functools.partial(_map_to_device_params, device="meta") + ).to_module(policy): + policy = deepcopy(policy) - def get_weights_fn(param_and_buf=param_and_buf): - return param_and_buf.data + param_and_buf.apply( + functools.partial(_map_to_device_params, device=self.policy_device) + ).to_module(policy) - policy_cast = deepcopy(policy).requires_grad_(False).to(device) - # here things may break bc policy.to("cuda") gives us weights on cuda:0 (same - # but different) - try: - device = next(policy_cast.parameters()).device - except StopIteration: # noqa - pass - else: - policy_cast = policy - return policy_cast, device, get_weights_fn + return policy, get_weights_fn def update_policy_weights_( self, policy_weights: Optional[TensorDictBase] = None @@ -363,6 +385,8 @@ class SyncDataCollector(DataCollectorBase): If ``None`` is provided, the policy used will be a :class:`~torchrl.collectors.RandomPolicy` instance with the environment ``action_spec``. + + Keyword Args: frames_per_batch (int): A keyword-only argument representing the total number of elements in a batch. total_frames (int): A keyword-only argument representing the total @@ -370,28 +394,47 @@ class SyncDataCollector(DataCollectorBase): during its lifespan. If the ``total_frames`` is not divisible by ``frames_per_batch``, an exception is raised. Endless collectors can be created by passing ``total_frames=-1``. - device (int, str or torch.device, optional): The device on which the - policy will be placed. - If it differs from the input policy device, the - :meth:`~.update_policy_weights_` method should be queried - at appropriate times during the training loop to accommodate for - the lag between parameter configuration at various times. - Defaults to ``None`` (i.e. policy is kept on its original device). + Defaults to ``-1`` (endless collector). + device (int, str or torch.device, optional): The generic device of the + collector. The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). storing_device (int, str or torch.device, optional): The device on which - the output :class:`tensordict.TensorDict` will be stored. For long - trajectories, it may be necessary to store the data on a different + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different device than the one where the policy and env are executed. - Defaults to ``"cpu"``. + Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). 
+ env_device (int, str or torch.device, optional): The device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). Defaults to ``None``. + policy_device (int, str or torch.device, optional): The device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. create_env_kwargs (dict, optional): Dictionary of kwargs for ``create_env_fn``. max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span over multiple batches (unless + Note that a trajectory can span across multiple batches (unless ``reset_at_each_iter`` is set to ``True``, see below). Once a trajectory reaches ``n_steps``, the environment is reset. If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. - Defaults to ``None`` (i.e. no maximum number of steps). + Defaults to ``None`` (i.e., no maximum number of steps). init_random_frames (int, optional): Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a @@ -411,15 +454,15 @@ class SyncDataCollector(DataCollectorBase): information. Defaults to ``False``. exploration_type (ExplorationType, optional): interaction mode to be used when - collecting data. Must be one of ``ExplorationType.RANDOM``, ``ExplorationType.MODE`` or - ``ExplorationType.MEAN``. - Defaults to ``ExplorationType.RANDOM`` + collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``, + ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``. + Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``. return_same_td (bool, optional): if ``True``, the same TensorDict will be returned at each iteration, with its values updated. This feature should be used cautiously: if the same tensordict is added to a replay buffer for instance, the whole content of the buffer will be identical. - Default is False. + Default is ``False``. interruptor (_Interruptor, optional): An _Interruptor object that can be used from outside the class to control rollout collection. 
The _Interruptor class has methods ´start_collection´ and ´stop_collection´, which allow to implement @@ -496,9 +539,11 @@ def __init__( ], *, frames_per_batch: int, - total_frames: int, + total_frames: int = -1, device: DEVICE_TYPING = None, storing_device: DEVICE_TYPING = None, + policy_device: DEVICE_TYPING = None, + env_device: DEVICE_TYPING = None, create_env_kwargs: dict | None = None, max_frames_per_traj: int | None = None, init_random_frames: int | None = None, @@ -532,29 +577,43 @@ def __init__( ) env.update_kwargs(create_env_kwargs) - if storing_device is None: - if device is not None: - storing_device = device - elif policy is not None: - try: - policy_device = next(policy.parameters()).device - except (AttributeError, StopIteration): - policy_device = torch.device("cpu") - storing_device = policy_device - else: - storing_device = torch.device("cpu") + ########################## + # Setting devices: + # The rule is the following: + # - If no device is passed, all devices are assumed to work OOB. + # The tensordict used for output is not on any device (i.e., actions and observations + # can be on a different device). + # - If the ``device`` is passed, it is used for all devices (storing, env and policy) + # unless overridden by another kwarg. + # - The rest of the kwargs control the respective device. + storing_device, policy_device, env_device = self._get_devices( + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + device=device, + ) + + self.storing_device = storing_device + self.env_device = env_device + self.policy_device = policy_device + self.device = device + # Check if we need to cast things from device to device + # If the policy has a None device and the env too, no need to cast (we don't know + # and assume the user knows what she's doing). + # If the devices match we're happy too. + # Only if the values differ do we need to cast + self._cast_to_policy_device = self.policy_device != self.env_device - self.storing_device = torch.device(storing_device) self.env: EnvBase = env + del env self.closed = False if not reset_when_done: raise ValueError("reset_when_done is deprectated.") self.reset_when_done = reset_when_done self.n_env = self.env.batch_size.numel() - (self.policy, self.device, self.get_weights_fn,) = self._get_policy_and_device( + (self.policy, self.get_weights_fn,) = self._get_policy_and_device( policy=policy, - device=device, observation_spec=self.env.observation_spec, ) @@ -563,7 +622,12 @@ def __init__( else: self.policy_weights = TensorDict({}, []) - self.env: EnvBase = self.env.to(self.device) + if self.env_device: + self.env: EnvBase = self.env.to(self.env_device) + elif self.env.device is not None: + # if we did not receive an env device, we use the device of the env + self.env_device = self.env.device + self.max_frames_per_traj = ( int(max_frames_per_traj) if max_frames_per_traj is not None else 0 ) @@ -580,7 +644,7 @@ def __init__( "Possible solutions: Set max_frames_per_traj to 0 or " "remove the StepCounter limit from the environment transforms."
) - env = self.env = TransformedEnv( + self.env = TransformedEnv( self.env, StepCounter(max_steps=self.max_frames_per_traj) ) @@ -614,7 +678,11 @@ def __init__( ) self.postproc = postproc - if self.postproc is not None and hasattr(self.postproc, "to"): + if ( + self.postproc is not None + and hasattr(self.postproc, "to") + and self.storing_device + ): self.postproc.to(self.storing_device) if frames_per_batch % self.n_env != 0 and RL_WARNINGS: warnings.warn( @@ -630,64 +698,139 @@ def __init__( ) self.return_same_td = return_same_td - self._tensordict = env.reset() - traj_ids = torch.arange(self.n_env, device=env.device).view(self.env.batch_size) - self._tensordict.set( + # The shuttle is a deviceless tensordict that just carries data from env to policy and policy to env + self._shuttle = self.env.reset() + if self.policy_device != self.env_device or self.env_device is None: + self._shuttle_has_no_device = True + self._shuttle.clear_device_() + else: + self._shuttle_has_no_device = False + + traj_ids = torch.arange(self.n_env, device=self.storing_device).view( + self.env.batch_size + ) + self._shuttle.set( ("collector", "traj_ids"), traj_ids, ) - with torch.no_grad(): - self._tensordict_out = self.env.fake_tensordict() + self._final_rollout = self.env.fake_tensordict() + + # If storing device is not None, we use this to cast the storage. + # If it is None and the env and policy are on the same device, + # the storing device is already the same as those, so we don't need + # to consider this use case. + # In all other cases, we can't really put a device on the storage, + # since at least one data source has a device that is not clear. + if self.storing_device: + self._final_rollout = self._final_rollout.to( + self.storing_device, non_blocking=True + ) + else: + # erase all devices + self._final_rollout.clear_device_() + # If the policy has a valid spec, we use it + self._policy_output_keys = set() if ( hasattr(self.policy, "spec") and self.policy.spec is not None and all(v is not None for v in self.policy.spec.values(True, True)) ): if any( - key not in self._tensordict_out.keys(isinstance(key, tuple)) + key not in self._final_rollout.keys(isinstance(key, tuple)) for key in self.policy.spec.keys(True, True) ): # if policy spec is non-empty, all the values are not None and the keys # match the out_keys we assume the user has given all relevant information # the policy could have more keys than the env: policy_spec = self.policy.spec - if policy_spec.ndim < self._tensordict_out.ndim: - policy_spec = policy_spec.expand(self._tensordict_out.shape) + if policy_spec.ndim < self._final_rollout.ndim: + policy_spec = policy_spec.expand(self._final_rollout.shape) for key, spec in policy_spec.items(True, True): - if key in self._tensordict_out.keys(isinstance(key, tuple)): + self._policy_output_keys.add(key) + if key in self._final_rollout.keys(True): continue - self._tensordict_out.set(key, spec.zero()) + self._final_rollout.set(key, spec.zero()) else: # otherwise, we perform a small number of steps with the policy to - # determine the relevant keys with which to pre-populate _tensordict_out. + # determine the relevant keys with which to pre-populate _final_rollout. # This is the safest thing to do if the spec has None fields or if there is # no spec at all. # See #505 for additional context.
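(Editorial note, not part of the patch: the comment above refers to discovering the policy's output keys by running it once and keeping only the entries it created or modified. Below is a minimal stand-alone illustration of that filtering with a toy TensorDictModule policy; the variable names are made up for the example and only approximate the logic the new code performs.)

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# A toy policy that reads "observation" and writes a new "action" entry.
policy = TensorDictModule(torch.nn.Linear(3, 2), in_keys=["observation"], out_keys=["action"])

shuttle = TensorDict({"observation": torch.randn(3)}, [])
policy_input = shuttle.copy()       # shallow copy: shares the input tensors
input_copy = policy_input.copy()    # keeps references to the original tensors
input_clone = policy_input.clone()  # deep copy, to catch in-place modifications
policy_output = policy(policy_input)

policy_keys = [
    key
    for key in policy_output.keys(True, True)
    if key not in input_copy.keys(True, True)                          # new entry
    or policy_output.get(key) is not input_copy.get(key)               # replaced tensor
    or not torch.equal(policy_output.get(key), input_clone.get(key))   # modified in place
]
print(policy_keys)  # ['action'] -- "observation" is unchanged, so it is filtered out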
- self._tensordict_out.update(self._tensordict) + self._final_rollout.update(self._shuttle.copy()) with torch.no_grad(): - self._tensordict_out = self.policy(self._tensordict_out.to(self.device)) + policy_input = self._shuttle.copy() + if self.policy_device: + policy_input = policy_input.to(self.policy_device) + # we cast to policy device, we'll deal with the device later + policy_input_copy = policy_input.copy() + policy_input_clone = ( + policy_input.clone() + ) # to test if values have changed in-place + policy_output = self.policy(policy_input) + + # check that we don't have exclusive keys, because they don't appear in keys + def check_exclusive(val): + if ( + isinstance(val, LazyStackedTensorDict) + and val._has_exclusive_keys + ): + raise RuntimeError( + "LazyStackedTensorDict with exclusive keys are not permitted in collectors. " + "Consider using a placeholder for missing keys." + ) - self._tensordict_out = ( - self._tensordict_out.unsqueeze(-1) - .expand(*env.batch_size, self.frames_per_batch) + policy_output._fast_apply(check_exclusive, call_on_nested=True) + # Use apply, because it works well with lazy stacks + # Edge-case of this approach: the policy may change the values in-place and only by a tiny bit + # or occasionally. In these cases, the keys will be missed (we can't detect if the policy has + # changed them here). + # This will cause a failure to update entries when policy and env device mismatch and + # casting is necessary. + filtered_policy_output = policy_output.apply( + lambda value_output, value_input, value_input_clone: value_output + if (value_input is None) + or (value_output is not value_input) + or ~torch.isclose(value_output, value_input_clone).any() + else None, + policy_input_copy, + policy_input_clone, + default=None, + ) + self._policy_output_keys = list( + self._policy_output_keys.union( + set(filtered_policy_output.keys(True, True)) + ) + ) + self._final_rollout.update( + policy_output.select(*self._policy_output_keys) + ) + del filtered_policy_output, policy_output, policy_input + + _env_output_keys = [] + for spec in ["full_observation_spec", "full_done_spec", "full_reward_spec"]: + _env_output_keys += list(self.env.output_spec[spec].keys(True, True)) + self._env_output_keys = _env_output_keys + self._final_rollout = ( + self._final_rollout.unsqueeze(-1) + .expand(*self.env.batch_size, self.frames_per_batch) .clone() .zero_() ) + # in addition to outputs of the policy, we add traj_ids to - # _tensordict_out which will be collected during rollout - self._tensordict_out = self._tensordict_out.to(self.storing_device) - self._tensordict_out.set( + # _final_rollout which will be collected during rollout + self._final_rollout.set( ("collector", "traj_ids"), torch.zeros( - *self._tensordict_out.batch_size, + *self._final_rollout.batch_size, dtype=torch.int64, device=self.storing_device, ), ) - self._tensordict_out.refine_names(..., "time") + self._final_rollout.refine_names(..., "time") if split_trajs is None: split_trajs = False @@ -697,6 +840,23 @@ def __init__( self._frames = 0 self._iter = -1 + @classmethod + def _get_devices( + cls, + *, + storing_device: torch.device, + policy_device: torch.device, + env_device: torch.device, + device: torch.device, + ): + device = torch.device(device) if device else device + storing_device = torch.device(storing_device) if storing_device else device + policy_device = torch.device(policy_device) if policy_device else device + env_device = torch.device(env_device) if env_device else device + if storing_device is None 
and (env_device == policy_device): + storing_device = env_device + return storing_device, policy_device, env_device + # for RPC def next(self): return super().next() @@ -739,13 +899,32 @@ def iterator(self) -> Iterator[TensorDictBase]: Yields: TensorDictBase objects containing (chunks of) trajectories """ - if self.storing_device.type == "cuda": + if self.storing_device and self.storing_device.type == "cuda": stream = torch.cuda.Stream(self.storing_device, priority=-1) event = stream.record_event() + streams = [stream] + events = [event] + elif self.storing_device is None: + streams = [] + events = [] + # this way of checking cuda is robust to lazy stacks with mismatching shapes + cuda_devices = set() + + def cuda_check(tensor: torch.Tensor): + if tensor.is_cuda: + cuda_devices.add(tensor.device) + + self._final_rollout.apply(cuda_check) + for device in cuda_devices: + streams.append(torch.cuda.Stream(device, priority=-1)) + events.append(streams[-1].record_event()) else: - event = None - stream = None - with torch.cuda.stream(stream): + streams = [] + events = [] + with contextlib.ExitStack() as stack: + for stream in streams: + stack.enter_context(torch.cuda.stream(stream)) + total_frames = self.total_frames while self._frames < self.total_frames: @@ -781,9 +960,10 @@ def is_private(key): if self.return_same_td: # This is used with multiprocessed collectors to use the buffers # stored in the tensordict. - if event is not None: - event.record() - event.synchronize() + if events: + for event in events: + event.record() + event.synchronize() yield tensordict_out else: # we must clone the values, as the tensordict is updated in-place. @@ -804,12 +984,15 @@ def _update_traj_ids(self, tensordict) -> None: tensordict.get("next"), done_keys=self.env.done_keys ) if traj_sop.any(): - traj_ids = self._tensordict.get(("collector", "traj_ids")) - traj_ids = traj_ids.clone() + traj_ids = self._shuttle.get(("collector", "traj_ids")) + traj_sop = traj_sop.to(self.storing_device) + traj_ids = traj_ids.clone().to(self.storing_device) traj_ids[traj_sop] = traj_ids.max() + torch.arange( - 1, traj_sop.sum() + 1, device=traj_ids.device + 1, + traj_sop.sum() + 1, + device=self.storing_device, ) - self._tensordict.set(("collector", "traj_ids"), traj_ids) + self._shuttle.set(("collector", "traj_ids"), traj_ids) @torch.no_grad() def rollout(self) -> TensorDictBase: @@ -820,10 +1003,10 @@ def rollout(self) -> TensorDictBase: """ if self.reset_at_each_iter: - self._tensordict.update(self.env.reset()) + self._shuttle.update(self.env.reset()) - # self._tensordict.fill_(("collector", "step_count"), 0) - self._tensordict_out.fill_(("collector", "traj_ids"), -1) + # self._shuttle.fill_(("collector", "step_count"), 0) + self._final_rollout.fill_(("collector", "traj_ids"), -1) tensordicts = [] with set_exploration_type(self.exploration_type): for t in range(self.frames_per_batch): @@ -831,20 +1014,64 @@ def rollout(self) -> TensorDictBase: self.init_random_frames is not None and self._frames < self.init_random_frames ): - self.env.rand_action(self._tensordict) + self.env.rand_action(self._shuttle) else: - self.policy(self._tensordict) - tensordict, tensordict_ = self.env.step_and_maybe_reset( - self._tensordict - ) - self._tensordict = tensordict_.set( - "collector", tensordict.get("collector").clone(False) - ) - tensordicts.append( - tensordict.to(self.storing_device, non_blocking=True) - ) + if self._cast_to_policy_device: + if self.policy_device is not None: + policy_input = self._shuttle.to( + self.policy_device, 
non_blocking=True + ) + elif self.policy_device is None: + # we know the tensordict has a device otherwise we would not be here + # we can pass this, clear_device_ must have been called earlier + # policy_input = self._shuttle.clear_device_() + policy_input = self._shuttle + else: + policy_input = self._shuttle + # we still do the assignment for security + policy_output = self.policy(policy_input) + if self._shuttle is not policy_output: + # ad-hoc update shuttle + self._shuttle.update( + policy_output, keys_to_update=self._policy_output_keys + ) + + if self._cast_to_policy_device: + if self.env_device is not None: + env_input = self._shuttle.to(self.env_device, non_blocking=True) + elif self.env_device is None: + # we know the tensordict has a device otherwise we would not be here + # we can pass this, clear_device_ must have been called earlier + # env_input = self._shuttle.clear_device_() + env_input = self._shuttle + else: + env_input = self._shuttle + env_output, env_next_output = self.env.step_and_maybe_reset(env_input) + + if self._shuttle is not env_output: + # ad-hoc update shuttle + next_data = env_output.get("next") + if self._shuttle_has_no_device: + # Make sure + next_data.clear_device_() + self._shuttle.set("next", next_data) + + if self.storing_device is not None: + tensordicts.append( + self._shuttle.to(self.storing_device, non_blocking=True) + ) + else: + tensordicts.append(self._shuttle) + + # carry over collector data without messing up devices + collector_data = self._shuttle.get("collector").copy() + self._shuttle = env_next_output + if self._shuttle_has_no_device: + self._shuttle.clear_device_() + self._shuttle.set("collector", collector_data) + + self._update_traj_ids(env_output) - self._update_traj_ids(tensordict) if ( self.interruptor is not None and self.interruptor.collection_stopped() @@ -852,37 +1079,47 @@ def rollout(self) -> TensorDictBase: try: torch.stack( tensordicts, - self._tensordict_out.ndim - 1, - out=self._tensordict_out[: t + 1], + self._final_rollout.ndim - 1, + out=self._final_rollout[: t + 1], ) except RuntimeError: - with self._tensordict_out.unlock_(): + with self._final_rollout.unlock_(): torch.stack( tensordicts, - self._tensordict_out.ndim - 1, - out=self._tensordict_out[: t + 1], + self._final_rollout.ndim - 1, + out=self._final_rollout[: t + 1], ) break else: try: - self._tensordict_out = torch.stack( + self._final_rollout = torch.stack( tensordicts, - self._tensordict_out.ndim - 1, - out=self._tensordict_out, + self._final_rollout.ndim - 1, + out=self._final_rollout, ) except RuntimeError: - with self._tensordict_out.unlock_(): - self._tensordict_out = torch.stack( + with self._final_rollout.unlock_(): + self._final_rollout = torch.stack( tensordicts, - self._tensordict_out.ndim - 1, - out=self._tensordict_out, + self._final_rollout.ndim - 1, + out=self._final_rollout, ) - return self._tensordict_out + return self._final_rollout + + @staticmethod + def _update_device_wise(tensor0, tensor1): + # given 2 tensors, returns tensor0 if their identity matches, + # or a copy of tensor1 on the device of tensor0 otherwise + if tensor1 is None or tensor1 is tensor0: + return tensor0 + if tensor1.device == tensor0.device: + return tensor1 + return tensor1.to(tensor0.device, non_blocking=True) def reset(self, index=None, **kwargs) -> None: """Resets the environments to a new initial state.""" # metadata - md = self._tensordict.get("collector").clone() + collector_metadata = self._shuttle.get("collector").clone() if index is not None: # check that the 
env supports partial reset if prod(self.env.batch_size) == 0: @@ -896,20 +1133,22 @@ def reset(self, index=None, **kwargs) -> None: device=self.env.device, ) _reset[index] = 1 - self._tensordict.set(reset_key, _reset) + self._shuttle.set(reset_key, _reset) else: _reset = None - self._tensordict.zero_() + self._shuttle.zero_() - self._tensordict.update(self.env.reset(**kwargs)) - md["traj_ids"] = md["traj_ids"] - md["traj_ids"].min() - self._tensordict["collector"] = md + self._shuttle.update(self.env.reset(**kwargs), inplace=True) + collector_metadata["traj_ids"] = ( + collector_metadata["traj_ids"] - collector_metadata["traj_ids"].min() + ) + self._shuttle["collector"] = collector_metadata def shutdown(self) -> None: """Shuts down all workers and/or closes the local environment.""" if not self.closed: self.closed = True - del self._tensordict, self._tensordict_out + del self._shuttle, self._final_rollout if not self.env.is_closed: self.env.close() del self.env @@ -974,7 +1213,7 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None: def __repr__(self) -> str: env_str = indent(f"env={self.env}", 4 * " ") policy_str = indent(f"policy={self.policy}", 4 * " ") - td_out_str = indent(f"td_out={self._tensordict_out}", 4 * " ") + td_out_str = indent(f"td_out={self._final_rollout}", 4 * " ") string = ( f"{self.__class__.__name__}(" f"\n{env_str}," @@ -994,38 +1233,61 @@ class _MultiDataCollector(DataCollectorBase): policy (Callable, optional): Instance of TensorDictModule class. Must accept TensorDictBase object as input. If ``None`` is provided, the policy used will be a - :class:`RandomPolicy` instance with the environment + :class:`~torchrl.collectors.RandomPolicy` instance with the environment ``action_spec``. + + Keyword Args: frames_per_batch (int): A keyword-only argument representing the total number of elements in a batch. - total_frames (int): A keyword-only argument representing the + total_frames (int, optional): A keyword-only argument representing the total number of frames returned by the collector during its lifespan. If the ``total_frames`` is not divisible by ``frames_per_batch``, an exception is raised. Endless collectors can be created by passing ``total_frames=-1``. - device (int, str, torch.device or sequence of such, optional): - The device on which the policy will be placed. - If it differs from the input policy device, the - :meth:`~.update_policy_weights_` method should be queried - at appropriate times during the training loop to accommodate for - the lag between parameter configuration at various times. - If necessary, a list of devices can be passed in which case each - element will correspond to the designated device of a sub-collector. - Defaults to ``None`` (i.e. policy is kept on its original device). - storing_device (int, str, torch.device or sequence of such, optional): - The device on which the output :class:`tensordict.TensorDict` will - be stored. For long trajectories, it may be necessary to store the - data on a different device than the one where the policy and env - are executed. - If necessary, a list of devices can be passed in which case each - element will correspond to the designated storing device of a - sub-collector. - Defaults to ``"cpu"``. + Defaults to ``-1`` (never ending collector). + device (int, str or torch.device, optional): The generic device of the + collector. 
The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + storing_device (int, str or torch.device, optional): The device on which + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different + device than the one where the policy and env are executed. + Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + env_device (int, str or torch.device, optional): The device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). Defaults to ``None``. + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + policy_device (int, str or torch.device, optional): The device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. create_env_kwargs (dict, optional): A dictionary with the keyword arguments used to create an environment. If a list is provided, each of its elements will be assigned to a sub-collector. max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span over multiple batches (unless + Note that a trajectory can span across multiple batches (unless ``reset_at_each_iter`` is set to ``True``, see below). Once a trajectory reaches ``n_steps``, the environment is reset. If the environment wraps multiple environments together, the number @@ -1051,15 +1313,9 @@ class _MultiDataCollector(DataCollectorBase): information. Defaults to ``False``. exploration_type (ExplorationType, optional): interaction mode to be used when - collecting data. Must be one of ``ExplorationType.RANDOM``, ``ExplorationType.MODE`` or - ``ExplorationType.MEAN``. - Defaults to ``ExplorationType.RANDOM`` - return_same_td (bool, optional): if ``True``, the same TensorDict - will be returned at each iteration, with its values - updated. 
This feature should be used cautiously: if the same - tensordict is added to a replay buffer for instance, - the whole content of the buffer will be identical. - Default is ``False``. + collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``, + ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``. + Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``. reset_when_done (bool, optional): if ``True`` (default), an environment that return a ``True`` value in its ``"done"`` or ``"truncated"`` entry will be reset at the corresponding indices. @@ -1088,10 +1344,12 @@ def __init__( ] ], *, - frames_per_batch: int = 200, + frames_per_batch: int, total_frames: Optional[int] = -1, - device: DEVICE_TYPING = None, - storing_device: Optional[Union[DEVICE_TYPING, Sequence[DEVICE_TYPING]]] = None, + device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, create_env_kwargs: Optional[Sequence[dict]] = None, max_frames_per_traj: int | None = None, init_random_frames: int | None = None, @@ -1101,10 +1359,8 @@ def __init__( exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, exploration_mode=None, reset_when_done: bool = True, - preemptive_threshold: float = None, update_at_each_batch: bool = False, - devices=None, - storing_devices=None, + preemptive_threshold: float = None, num_threads: int = None, num_sub_threads: int = 1, ): @@ -1134,99 +1390,86 @@ def __init__( # To go around this, we do the copies of the policy in the server # (this object) to each possible device, and send to all the # processes their copy of the policy. - if devices is not None: - if device is not None: - raise ValueError("Cannot pass both devices and device") - warnings.warn( - "`devices` keyword argument will soon be deprecated from multiprocessed collectors. " - "Please use `device` instead." - ) - device = devices - if storing_devices is not None: - if storing_device is not None: - raise ValueError("Cannot pass both storing_devices and storing_device") - warnings.warn( - "`storing_devices` keyword argument will soon be deprecated from multiprocessed collectors. " - "Please use `storing_device` instead." - ) - storing_device = storing_devices - - def device_err_msg(device_name, devices_list): - return ( - f"The length of the {device_name} argument should match the " - f"number of workers of the collector. Got len(" - f"create_env_fn)={self.num_workers} and len(" - f"storing_device)={len(devices_list)}" - ) - if isinstance(device, (str, int, torch.device)): - device = [torch.device(device) for _ in range(self.num_workers)] - elif device is None: - device = [None for _ in range(self.num_workers)] - elif isinstance(device, Sequence): - if len(device) != self.num_workers: - raise RuntimeError(device_err_msg("devices", device)) - device = [torch.device(_device) for _device in device] - else: - raise ValueError( - "devices should be either None, a torch.device or equivalent " - "or an iterable of devices. " - f"Found {type(device)} instead." 
- ) - self._policy_dict = {} - self._policy_weights_dict = {} - self._get_weights_fn_dict = {} + storing_devices, policy_devices, env_devices = self._get_devices( + storing_device=storing_device, + env_device=env_device, + policy_device=policy_device, + device=device, + ) - for i, (_device, create_env, kwargs) in enumerate( - zip(device, self.create_env_fn, self.create_env_kwargs) - ): - if _device in self._policy_dict: - device[i] = _device - continue + # to avoid confusion + self.storing_device = storing_devices + self.policy_device = policy_devices + self.env_device = env_devices - if hasattr(create_env, "observation_spec"): - observation_spec = create_env.observation_spec - else: - try: - observation_spec = create_env(**kwargs).observation_spec - except: # noqa - observation_spec = None + del storing_device, env_device, policy_device, device - _policy, _device, _get_weight_fn = self._get_policy_and_device( - policy=policy, device=_device, observation_spec=observation_spec - ) - self._policy_dict[_device] = _policy - if isinstance(_policy, nn.Module): - self._policy_weights_dict[_device] = TensorDict.from_module( - _policy, as_module=True - ) - else: - self._policy_weights_dict[_device] = TensorDict({}, []) + _policy_weights_dict = {} + _get_weights_fn_dict = {} - self._get_weights_fn_dict[_device] = _get_weight_fn - device[i] = _device - self.device = device + policy = _NonParametricPolicyWrapper(policy) + policy_weights = TensorDict.from_module(policy, as_module=True) - if storing_device is None: - self.storing_device = self.device - else: - if isinstance(storing_device, (str, int, torch.device)): - self.storing_device = [ - torch.device(storing_device) for _ in range(self.num_workers) - ] - elif isinstance(storing_device, Sequence): - if len(storing_device) != self.num_workers: - raise RuntimeError( - device_err_msg("storing_devices", storing_device) - ) - self.storing_device = [ - torch.device(_storing_device) for _storing_device in storing_device - ] + # store a stateless policy + + with policy_weights.apply(_make_meta_params).to_module(policy): + self.policy = deepcopy(policy) + + for policy_device in policy_devices: + # if we have already mapped onto that device, get that value + if policy_device in _policy_weights_dict: + continue + # If policy device is None, the only thing we need to do is + # make sure that the weights are shared. + if policy_device is None: + + def map_weight( + weight, + ): + is_param = isinstance(weight, nn.Parameter) + weight = weight.data + if weight.device.type in ("cpu", "mps"): + weight = weight.share_memory_() + if is_param: + weight = nn.Parameter(weight, requires_grad=False) + return weight + + # in other cases, we need to cast the policy if and only if not all the weights + # are on the appropriate device else: - raise ValueError( - "storing_devices should be either a torch.device or equivalent or an iterable of devices. " - f"Found {type(storing_device)} instead." 
- ) + # check the weights devices + has_different_device = [False] + + def map_weight( + weight, + policy_device=policy_device, + has_different_device=has_different_device, + ): + is_param = isinstance(weight, nn.Parameter) + weight = weight.data + if weight.device != policy_device: + has_different_device[0] = True + weight = weight.to(policy_device) + elif weight.device.type in ("cpu", "mps"): + weight = weight.share_memory_() + if is_param: + weight = nn.Parameter(weight, requires_grad=False) + return weight + + local_policy_weights = TensorDictParams(policy_weights.apply(map_weight)) + + def _get_weight_fn(weights=policy_weights): + # This function will give local_policy_weights the original weights. + # see self.update_policy_weights_ to see how this is used + return weights + + # We lock the weights to be able to cache a bunch of ops and to avoid modifying them + _policy_weights_dict[policy_device] = local_policy_weights.lock_() + _get_weights_fn_dict[policy_device] = _get_weight_fn + + self._policy_weights_dict = _policy_weights_dict + self._get_weights_fn_dict = _get_weights_fn_dict if total_frames is None or total_frames < 0: total_frames = float("inf") @@ -1278,18 +1521,74 @@ def device_err_msg(device_name, devices_list): self._frames = 0 self._iter = -1 + def _get_devices( + self, + *, + storing_device: torch.device, + policy_device: torch.device, + env_device: torch.device, + device: torch.device, + ): + # convert all devices to lists + if not isinstance(storing_device, (list, tuple)): + storing_device = [ + storing_device, + ] * self.num_workers + if not isinstance(policy_device, (list, tuple)): + policy_device = [ + policy_device, + ] * self.num_workers + if not isinstance(env_device, (list, tuple)): + env_device = [ + env_device, + ] * self.num_workers + if not isinstance(device, (list, tuple)): + device = [ + device, + ] * self.num_workers + if not ( + len(device) + == len(storing_device) + == len(policy_device) + == len(env_device) + == self.num_workers + ): + raise RuntimeError( + f"The length of the devices does not match the number of workers: {self.num_workers}."
+ ) + storing_device, policy_device, env_device = zip( + *[ + SyncDataCollector._get_devices( + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + device=device, + ) + for (storing_device, policy_device, env_device, device) in zip( + storing_device, policy_device, env_device, device + ) + ] + ) + return storing_device, policy_device, env_device + @property def frames_per_batch_worker(self): raise NotImplementedError def update_policy_weights_(self, policy_weights=None) -> None: - for _device in self._policy_dict: + for _device in self._policy_weights_dict: if policy_weights is not None: + if isinstance(policy_weights, TensorDictParams): + policy_weights = policy_weights.data self._policy_weights_dict[_device].data.update_(policy_weights) elif self._get_weights_fn_dict[_device] is not None: - self._policy_weights_dict[_device].data.update_( - self._get_weights_fn_dict[_device]() - ) + original_weights = self._get_weights_fn_dict[_device]() + if original_weights is None: + # if the weights match in identity, we can spare a call to update_ + continue + if isinstance(original_weights, TensorDictParams): + original_weights = original_weights.data + self._policy_weights_dict[_device].data.update_(original_weights) @property def _queue_len(self) -> int: @@ -1303,53 +1602,58 @@ def _run_processes(self) -> None: for i, (env_fun, env_fun_kwargs) in enumerate( zip(self.create_env_fn, self.create_env_kwargs) ): - _device = self.device[i] - _storing_device = self.storing_device[i] pipe_parent, pipe_child = mp.Pipe() # send messages to procs if env_fun.__class__.__name__ != "EnvCreator" and not isinstance( env_fun, EnvBase ): # to avoid circular imports env_fun = CloudpickleWrapper(env_fun) - kwargs = { - "pipe_parent": pipe_parent, - "pipe_child": pipe_child, - "queue_out": queue_out, - "create_env_fn": env_fun, - "create_env_kwargs": env_fun_kwargs, - "policy": self._policy_dict[_device], - "max_frames_per_traj": self.max_frames_per_traj, - "frames_per_batch": self.frames_per_batch_worker, - "reset_at_each_iter": self.reset_at_each_iter, - "device": _device, - "storing_device": _storing_device, - "exploration_type": self.exploration_type, - "reset_when_done": self.reset_when_done, - "idx": i, - "interruptor": self.interruptor, - } - proc = _ProcessNoWarn( - target=_main_async_collector, - num_threads=self.num_sub_threads, - kwargs=kwargs, - ) - # proc.daemon can't be set as daemonic processes may be launched by the process itself - try: - proc.start() - except _pickle.PicklingError as err: - if "" in str(err): - raise RuntimeError( - """Can't open a process with doubly cloud-pickled lambda function. 
+ # Create a policy on the right device + policy_device = self.policy_device[i] + storing_device = self.storing_device[i] + env_device = self.env_device[i] + policy = self.policy + with self._policy_weights_dict[policy_device].to_module(policy): + kwargs = { + "pipe_parent": pipe_parent, + "pipe_child": pipe_child, + "queue_out": queue_out, + "create_env_fn": env_fun, + "create_env_kwargs": env_fun_kwargs, + "policy": policy, + "max_frames_per_traj": self.max_frames_per_traj, + "frames_per_batch": self.frames_per_batch_worker, + "reset_at_each_iter": self.reset_at_each_iter, + "policy_device": policy_device, + "storing_device": storing_device, + "env_device": env_device, + "exploration_type": self.exploration_type, + "reset_when_done": self.reset_when_done, + "idx": i, + "interruptor": self.interruptor, + } + proc = _ProcessNoWarn( + target=_main_async_collector, + num_threads=self.num_sub_threads, + kwargs=kwargs, + ) + # proc.daemon can't be set as daemonic processes may be launched by the process itself + try: + proc.start() + except _pickle.PicklingError as err: + if "" in str(err): + raise RuntimeError( + """Can't open a process with doubly cloud-pickled lambda function. This error is likely due to an attempt to use a ParallelEnv in a multiprocessed data collector. To do this, consider wrapping your lambda function in an `torchrl.envs.EnvCreator` wrapper as follows: `env = ParallelEnv(N, EnvCreator(my_lambda_function))`. This will not only ensure that your lambda function is cloud-pickled once, but also that the state dict is synchronised across processes if needed.""" - ) from err - pipe_child.close() - self.procs.append(proc) - self.pipes.append(pipe_parent) + ) from err + pipe_child.close() + self.procs.append(proc) + self.pipes.append(pipe_parent) for pipe_parent in self.pipes: msg = pipe_parent.recv() if msg != "instantiated": @@ -1971,48 +2275,102 @@ class aSyncDataCollector(MultiaSyncDataCollector): create_env_fn (Callabled): Callable returning an instance of EnvBase policy (Callable, optional): Instance of TensorDictModule class. Must accept TensorDictBase object as input. - total_frames (int): lower bound of the total number of frames returned - by the collector. In parallel settings, the actual number of - frames may well be greater than this as the closing signals are - sent to the workers only once the total number of frames has - been collected on the server. - create_env_kwargs (dict, optional): A dictionary with the arguments - used to create an environment - max_frames_per_traj: Maximum steps per trajectory. Note that a - trajectory can span over multiple batches (unless - reset_at_each_iter is set to True, see below). Once a trajectory - reaches n_steps, the environment is reset. If the - environment wraps multiple environments together, the number of - steps is tracked for each environment independently. Negative + + Keyword Args: + frames_per_batch (int): A keyword-only argument representing the + total number of elements in a batch. + total_frames (int, optional): A keyword-only argument representing the + total number of frames returned by the collector + during its lifespan. If the ``total_frames`` is not divisible by + ``frames_per_batch``, an exception is raised. + Endless collectors can be created by passing ``total_frames=-1``. + Defaults to ``-1`` (never ending collector). + device (int, str or torch.device, optional): The generic device of the + collector. 
The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + storing_device (int, str or torch.device, optional): The device on which + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different + device than the one where the policy and env are executed. + Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + env_device (int, str or torch.device, optional): The device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). Defaults to ``None``. + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + policy_device (int, str or torch.device, optional): The device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + create_env_kwargs (dict, optional): A dictionary with the + keyword arguments used to create an environment. If a list is + provided, each of its elements will be assigned to a sub-collector. + max_frames_per_traj (int, optional): Maximum steps per trajectory. + Note that a trajectory can span across multiple batches (unless + ``reset_at_each_iter`` is set to ``True``, see below). + Once a trajectory reaches ``n_steps``, the environment is reset. + If the environment wraps multiple environments together, the number + of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. - Defaults to ``None`` (i.e. no maximum number of steps) - frames_per_batch (int): Time-length of a batch. - reset_at_each_iter and frames_per_batch == n_steps are equivalent configurations. - Defaults to ``200`` - init_random_frames (int): Number of frames for which the policy is ignored before it is called. 
- This feature is mainly intended to be used in offline/model-based settings, where a batch of random - trajectories can be used to initialize training. - Defaults to ``None`` (i.e. no random frames) - reset_at_each_iter (bool): whether environments should be reset for each batch. - default=False. - postproc (callable, optional): A PostProcessor is an object that will read a batch of data and process it in a - useful format for training. - default: None. - split_trajs (bool): Boolean indicating whether the resulting TensorDict should be split according to the trajectories. - See utils.split_trajectories for more information. - device (int, str, torch.device, optional): The device on which the - policy will be placed. If it differs from the input policy - device, the update_policy_weights_() method should be queried - at appropriate times during the training loop to accommodate for - the lag between parameter configuration at various times. - Default is `None` (i.e. policy is kept on its original device) - storing_device (int, str, torch.device, optional): The device on which - the output TensorDict will be stored. For long trajectories, - it may be necessary to store the data on a different. - device than the one where the policy is stored. Default is None. - update_at_each_batch (bool): if ``True``, the policy weights will be updated every time a batch of trajectories - is collected. - default=False + Defaults to ``None`` (i.e. no maximum number of steps). + init_random_frames (int, optional): Number of frames for which the + policy is ignored before it is called. This feature is mainly + intended to be used in offline/model-based settings, where a + batch of random trajectories can be used to initialize training. + If provided, it will be rounded up to the closest multiple of frames_per_batch. + Defaults to ``None`` (i.e. no random frames). + reset_at_each_iter (bool, optional): Whether environments should be reset + at the beginning of a batch collection. + Defaults to ``False``. + postproc (Callable, optional): A post-processing transform, such as + a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep` + instance. + Defaults to ``None``. + split_trajs (bool, optional): Boolean indicating whether the resulting + TensorDict should be split according to the trajectories. + See :func:`~torchrl.collectors.utils.split_trajectories` for more + information. + Defaults to ``False``. + exploration_type (ExplorationType, optional): interaction mode to be used when + collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``, + ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``. + Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``. + reset_when_done (bool, optional): if ``True`` (default), an environment + that returns a ``True`` value in its ``"done"`` or ``"truncated"`` + entry will be reset at the corresponding indices. + update_at_each_batch (bool, optional): if ``True``, :meth:`~.update_policy_weights_()` + will be called before (sync) or after (async) each data collection. + Defaults to ``False``. + preemptive_threshold (float, optional): a value between 0.0 and 1.0 that specifies the ratio of workers + that will be allowed to finish collecting their rollout before the rest are forced to end early. + num_threads (int, optional): number of threads for this process. + Defaults to the number of workers. + num_sub_threads (int, optional): number of threads of the subprocesses.
+ Should be equal to one plus the number of processes launched within + each subprocess (or one if a single process is launched). + Defaults to 1 for safety: if none is indicated, launching multiple + workers may charge the cpu load too much and harm performance. """ @@ -2024,19 +2382,27 @@ def __init__( TensorDictModule, Callable[[TensorDictBase], TensorDictBase], ] - ] = None, + ], + *, + frames_per_batch: int, total_frames: Optional[int] = -1, - create_env_kwargs: Optional[dict] = None, + device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + create_env_kwargs: Optional[Sequence[dict]] = None, max_frames_per_traj: int | None = None, - frames_per_batch: int = 200, init_random_frames: int | None = None, reset_at_each_iter: bool = False, postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, split_trajs: Optional[bool] = None, - device: Optional[Union[int, str, torch.device]] = None, - storing_device: Optional[Union[int, str, torch.device]] = None, - seed: Optional[int] = None, - pin_memory: bool = False, + exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, + exploration_mode=None, + reset_when_done: bool = True, + update_at_each_batch: bool = False, + preemptive_threshold: float = None, + num_threads: int = None, + num_sub_threads: int = 1, **kwargs, ): super().__init__( @@ -2050,9 +2416,17 @@ def __init__( init_random_frames=init_random_frames, postproc=postproc, split_trajs=split_trajs, - devices=[device] if device is not None else None, - storing_devices=[storing_device] if storing_device is not None else None, - **kwargs, + device=device, + policy_device=policy_device, + env_device=env_device, + storing_device=storing_device, + exploration_type=exploration_type, + exploration_mode=exploration_mode, + reset_when_done=reset_when_done, + update_at_each_batch=update_at_each_batch, + preemptive_threshold=preemptive_threshold, + num_threads=num_threads, + num_sub_threads=num_sub_threads, ) # for RPC @@ -2086,8 +2460,9 @@ def _main_async_collector( max_frames_per_traj: int, frames_per_batch: int, reset_at_each_iter: bool, - device: Optional[Union[torch.device, str, int]], storing_device: Optional[Union[torch.device, str, int]], + env_device: Optional[Union[torch.device, str, int]], + policy_device: Optional[Union[torch.device, str, int]], idx: int = 0, exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, reset_when_done: bool = True, @@ -2098,15 +2473,6 @@ def _main_async_collector( # init variables that will be cleared when closing tensordict = data = d = data_in = inner_collector = dc_iter = None - # send the policy to device - try: - policy = policy.to(device) - except Exception: - if RL_WARNINGS: - warnings.warn( - "Couldn't cast the policy onto the desired device on remote process. " - "If your policy is not a nn.Module instance you can probably ignore this warning." 
- ) inner_collector = SyncDataCollector( create_env_fn, create_env_kwargs=create_env_kwargs, @@ -2117,8 +2483,9 @@ def _main_async_collector( reset_at_each_iter=reset_at_each_iter, postproc=None, split_trajs=False, - device=device, storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, exploration_type=exploration_type, reset_when_done=reset_when_done, return_same_td=True, @@ -2190,7 +2557,27 @@ def _main_async_collector( raise RuntimeError( f"expected device to be {storing_device} but got {tensordict.device}" ) - tensordict.share_memory_() + # If policy and env are on cpu, we put the data in shared mem, + # if policy is on cuda and env on cuda, we are fine with this. + # If policy is on cuda and env on cpu (or opposite) we put tensors that + # are on cpu in shared mem. + if tensordict.device is not None: + # placeholder in case we need different behaviours + if tensordict.device.type in ("cpu", "mps"): + tensordict.share_memory_() + elif tensordict.device.type == "cuda": + tensordict.share_memory_() + else: + raise NotImplementedError( + f"Device {tensordict.device} is not supported in multi-collectors yet." + ) + else: + # make sure each cpu tensor is shared - assuming non-cpu devices are shared + tensordict.apply( + lambda x: x.share_memory_() + if x.device.type in ("cpu", "mps") + else x + ) data = (tensordict, idx) else: if d is not tensordict: @@ -2258,3 +2645,64 @@ def _main_async_collector( else: raise Exception(f"Unrecognized message {msg}") + + +class _PolicyMetaClass(abc.ABCMeta): + def __call__(cls, *args, **kwargs): + # no kwargs + if isinstance(args[0], nn.Module): + return args[0] + return super().__call__(*args) + + +class _NonParametricPolicyWrapper(nn.Module, metaclass=_PolicyMetaClass): + """A wrapper for non-parametric policies.""" + + def __init__(self, policy): + super().__init__() + self.policy = policy + + @property + def forward(self): + forward = self.__dict__.get("_forward", None) + if forward is None: + + @functools.wraps(self.policy) + def forward(*input, **kwargs): + return self.policy.__call__(*input, **kwargs) + + self.__dict__["_forward"] = forward + return forward + + def __getattr__(self, attr: str) -> Any: + if attr in self.__dir__(): + return self.__getattribute__( + attr + ) # make sure that appropriate exceptions are raised + + elif attr.startswith("__"): + raise AttributeError( + "passing built-in private methods is " + f"not permitted with type {type(self)}. " + f"Got attribute {attr}." + ) + + elif "policy" in self.__dir__(): + policy = self.__getattribute__("policy") + return getattr(policy, attr) + try: + super().__getattr__(attr) + except Exception: + raise AttributeError( + f"policy not set in {self.__class__.__name__}, cannot access {attr}." + ) + + +def _make_meta_params(param): + is_param = isinstance(param, nn.Parameter) + + pd = param.detach().to("meta") + + if is_param: + pd = nn.Parameter(pd, requires_grad=False) + return pd diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index f213f73d160..073d2f445ab 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree.
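(Editorial note, not part of the patch: the `_make_meta_params` helper above is used to build a weight-less copy of the policy that each worker later re-populates with device-specific parameters. Below is a minimal sketch of that pattern, relying only on the `tensordict` calls already used in this diff; the toy `nn.Linear` policy is illustrative.)

import copy

import torch
from tensordict import TensorDict
from torch import nn

def _make_meta_params(param):
    # detach and move to the "meta" device: shape/dtype survive, storage does not
    is_param = isinstance(param, nn.Parameter)
    pd = param.detach().to("meta")
    if is_param:
        pd = nn.Parameter(pd, requires_grad=False)
    return pd

policy = nn.Linear(4, 2)
weights = TensorDict.from_module(policy, as_module=True)

# Temporarily swap meta weights in while deep-copying: the copy is a stateless skeleton.
with weights.apply(_make_meta_params).to_module(policy):
    stateless_policy = copy.deepcopy(policy)

print(stateless_policy.weight.device)  # meta

# Real (possibly per-device) weights can then be re-attached on each worker:
with weights.to_module(stateless_policy):
    out = stateless_policy(torch.randn(4))
print(out.shape)  # torch.Size([2])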
r"""Generic distributed data-collector using torch.distributed backend.""" +from __future__ import annotations import logging import os @@ -11,7 +12,7 @@ import warnings from copy import copy, deepcopy from datetime import timedelta -from typing import OrderedDict +from typing import Callable, List, OrderedDict, Type import torch.cuda from tensordict import TensorDict @@ -261,6 +262,8 @@ class DistributedDataCollector(DataCollectorBase): If ``None`` is provided, the policy used will be a :class:`RandomPolicy` instance with the environment ``action_spec``. + + Keyword Args: frames_per_batch (int): A keyword-only argument representing the total number of elements in a batch. total_frames (int): A keyword-only argument representing the total @@ -268,19 +271,55 @@ class DistributedDataCollector(DataCollectorBase): during its lifespan. If the ``total_frames`` is not divisible by ``frames_per_batch``, an exception is raised. Endless collectors can be created by passing ``total_frames=-1``. + Defaults to ``-1`` (endless collector). + device (int, str or torch.device, optional): The generic device of the + collector. The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). + Lists of devices are supported. + storing_device (int, str or torch.device, optional): The *remote* device on which + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different + device than the one where the policy and env are executed. + Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). + Lists of devices are supported. + env_device (int, str or torch.device, optional): The *remote* device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). Defaults to ``None``. + Lists of devices are supported. + policy_device (int, str or torch.device, optional): The *remote* device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. + Lists of devices are supported. max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span over multiple batches (unless + Note that a trajectory can span across multiple batches (unless ``reset_at_each_iter`` is set to ``True``, see below). Once a trajectory reaches ``n_steps``, the environment is reset. 
If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. - Defaults to ``-1`` (i.e. no maximum number of steps). + Defaults to ``None`` (i.e., no maximum number of steps). init_random_frames (int, optional): Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. - Defaults to ``-1`` (i.e. no random frames). + If provided, it will be rounded up to the closest multiple of frames_per_batch. + Defaults to ``None`` (i.e. no random frames). reset_at_each_iter (bool, optional): Whether environments should be reset at the beginning of a batch collection. Defaults to ``False``. @@ -293,13 +332,10 @@ class DistributedDataCollector(DataCollectorBase): See :func:`~torchrl.collectors.utils.split_trajectories` for more information. Defaults to ``False``. - exploration_type (str, optional): interaction mode to be used when - collecting data. Must be one of ``"random"``, ``"mode"`` or - ``"mean"``. - Defaults to ``"random"`` - reset_when_done (bool, optional): if ``True`` (default), an environment - that return a ``True`` value in its ``"done"`` or ``"truncated"`` - entry will be reset at the corresponding indices. + exploration_type (ExplorationType, optional): interaction mode to be used when + collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``, + ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``. + Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``. collector_class (type or str, optional): a collector class for the remote node. Can be :class:`~torchrl.collectors.SyncDataCollector`, :class:`~torchrl.collectors.MultiSyncDataCollector`, @@ -328,8 +364,6 @@ class DistributedDataCollector(DataCollectorBase): is one of ``"gloo"``, ``"mpi"``, ``"nccl"`` or ``"ucc"``. See the torch.distributed documentation for more information. Defaults to ``"gloo"``. - storing_device (torch.device or compatible, optional): the device where the - data will be delivered. Defaults to ``"cpu"``. update_after_each_batch (bool, optional): if ``True``, the weights will be updated after each collection. For ``sync=True``, this means that all workers will see their weights updated. 
For ``sync=False``, @@ -367,27 +401,29 @@ def __init__( create_env_fn, policy, *, - frames_per_batch, - total_frames, - max_frames_per_traj=-1, - init_random_frames=-1, - reset_at_each_iter=False, - postproc=None, - split_trajs=False, - exploration_type=DEFAULT_EXPLORATION_TYPE, - exploration_mode=None, - reset_when_done=True, - collector_class=SyncDataCollector, - collector_kwargs=None, - num_workers_per_collector=1, - sync=False, - slurm_kwargs=None, - backend="gloo", - storing_device="cpu", - update_after_each_batch=False, - max_weight_update_interval=-1, - launcher="submitit", - tcp_port=None, + frames_per_batch: int, + total_frames: int = -1, + device: torch.device | List[torch.device] = None, + storing_device: torch.device | List[torch.device] = None, + env_device: torch.device | List[torch.device] = None, + policy_device: torch.device | List[torch.device] = None, + max_frames_per_traj: int = -1, + init_random_frames: int = -1, + reset_at_each_iter: bool = False, + postproc: Callable | None = None, + split_trajs: bool = False, + exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa + exploration_mode: str = None, + collector_class: Type = SyncDataCollector, + collector_kwargs: dict = None, + num_workers_per_collector: int = 1, + sync: bool = False, + slurm_kwargs: dict | None = None, + backend: str = "gloo", + update_after_each_batch: bool = False, + max_weight_update_interval: int = -1, + launcher: str = "submitit", + tcp_port: int = None, ): exploration_type = _convert_exploration_type( exploration_mode=exploration_mode, exploration_type=exploration_type @@ -410,7 +446,12 @@ def __init__( self.policy_weights = policy_weights self.num_workers = len(create_env_fn) self.frames_per_batch = frames_per_batch + + self.device = device self.storing_device = storing_device + self.env_device = env_device + self.policy_device = policy_device + # make private to avoid changes from users during collection self._sync = sync self.update_after_each_batch = update_after_each_batch @@ -450,7 +491,7 @@ def __init__( ) # update collector kwargs - for collector_kwarg in self.collector_kwargs: + for i, collector_kwarg in enumerate(self.collector_kwargs): collector_kwarg["max_frames_per_traj"] = max_frames_per_traj collector_kwarg["init_random_frames"] = ( init_random_frames // self.num_workers @@ -465,12 +506,12 @@ def __init__( ) collector_kwarg["reset_at_each_iter"] = reset_at_each_iter collector_kwarg["exploration_type"] = exploration_type - collector_kwarg["reset_when_done"] = reset_when_done + collector_kwarg["device"] = self.device[i] + collector_kwarg["storing_device"] = self.storing_device[i] + collector_kwarg["env_device"] = self.env_device[i] + collector_kwarg["policy_device"] = self.policy_device[i] - if postproc is not None and hasattr(postproc, "to"): - self.postproc = postproc.to(self.storing_device) - else: - self.postproc = postproc + self.postproc = postproc self.split_trajs = split_trajs self.backend = backend @@ -480,6 +521,66 @@ def __init__( self._init_workers() self._make_container() + @property + def device(self) -> List[torch.device]: + return self._device + + @property + def storing_device(self) -> List[torch.device]: + return self._storing_device + + @property + def env_device(self) -> List[torch.device]: + return self._env_device + + @property + def policy_device(self) -> List[torch.device]: + return self._policy_device + + @device.setter + def device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( 
+ "The number of devices passed to the collector must match the number of workers." + ) + self._device = value + else: + self._device = [value] * self.num_workers + + @storing_device.setter + def storing_device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." + ) + self._storing_device = value + else: + self._storing_device = [value] * self.num_workers + + @env_device.setter + def env_device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." + ) + self._env_device = value + else: + self._env_device = [value] * self.num_workers + + @policy_device.setter + def policy_device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." + ) + self._policy_device = value + else: + self._policy_device = [value] * self.num_workers + def _init_master_dist( self, world_size, @@ -530,20 +631,7 @@ def _make_container(self): if self._VERBOSE: logging.info("got data", _data) logging.info("expanding...") - if not issubclass(self.collector_class, SyncDataCollector): - # Multi-data collectors - self._tensordict_out = ( - _data.expand((self.num_workers, *_data.shape)) - .to_tensordict() - .to(self.storing_device) - ) - else: - # Multi-data collectors - self._tensordict_out = ( - _data.expand((self.num_workers, *_data.shape)) - .to_tensordict() - .to(self.storing_device) - ) + self._tensordict_out = _data.expand((self.num_workers, *_data.shape)) if self._VERBOSE: logging.info("locking") if self._sync: diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index 11f94e4ea64..fa2d8e8191e 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -1,11 +1,12 @@ +from __future__ import annotations + import logging import warnings from typing import Callable, Dict, Iterator, List, OrderedDict, Union import torch import torch.nn as nn -from tensordict import TensorDict -from tensordict.tensordict import TensorDictBase +from tensordict import TensorDict, TensorDictBase from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import ( DataCollectorBase, @@ -119,25 +120,64 @@ class RayCollector(DataCollectorBase): instance of :class:`~torchrl.envs.EnvBase`. policy (Callable): Instance of TensorDictModule class. Must accept TensorDictBase object as input. + + Keyword Args: frames_per_batch (int): A keyword-only argument representing the total number of elements in a batch. total_frames (int, Optional): lower bound of the total number of frames returned by the collector. The iterator will stop once the total number of frames equates or exceeds the total number of frames passed to the collector. Default value is -1, which mean no target total number of frames (i.e. the collector will run indefinitely). - max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span over multiple batches (unless + device (int, str or torch.device, optional): The generic device of the + collector. 
The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). + Lists of devices are supported. + storing_device (int, str or torch.device, optional): The *remote* device on which + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different + device than the one where the policy and env are executed. + Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). + Lists of devices are supported. + env_device (int, str or torch.device, optional): The *remote* device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). Defaults to ``None``. + Lists of devices are supported. + policy_device (int, str or torch.device, optional): The *remote* device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. + Lists of devices are supported. + create_env_kwargs (dict, optional): Dictionary of kwargs for + ``create_env_fn``. + max_frames_per_traj (int, optional): Maximum steps per trajectory. + Note that a trajectory can span across multiple batches (unless ``reset_at_each_iter`` is set to ``True``, see below). Once a trajectory reaches ``n_steps``, the environment is reset. If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. - Defaults to ``-1`` (i.e. no maximum number of steps). + Defaults to ``None`` (i.e., no maximum number of steps). init_random_frames (int, optional): Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. - Defaults to ``-1`` (i.e. no random frames). + If provided, it will be rounded up to the closest multiple of frames_per_batch. + Defaults to ``None`` (i.e. no random frames). reset_at_each_iter (bool, optional): Whether environments should be reset at the beginning of a batch collection. Defaults to ``False``. @@ -150,13 +190,10 @@ class RayCollector(DataCollectorBase): See :func:`~torchrl.collectors.utils.split_trajectories` for more information. Defaults to ``False``. - exploration_type (str, optional): interaction mode to be used when - collecting data. 
Must be one of ``ExplorationType.RANDOM``, ``ExplorationType.MODE`` or - ``ExplorationType.MEAN``. - Defaults to ``"random"`` - reset_when_done (bool, optional): if ``True`` (default), an environment - that return a ``True`` value in its ``"done"`` or ``"truncated"`` - entry will be reset at the corresponding indices. + exploration_type (ExplorationType, optional): interaction mode to be used when + collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``, + ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``. + Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``. collector_class (Python class): a collector class to be remotely instantiated. Can be :class:`~torchrl.collectors.SyncDataCollector`, :class:`~torchrl.collectors.MultiSyncDataCollector`, @@ -182,8 +219,6 @@ class RayCollector(DataCollectorBase): tensordicts collected on each node. If ``False`` (default), each tensordict results from a separate node in a "first-ready, first-served" fashion. - storing_device (torch.device, optional): if specified, collected tensordicts will be moved - to these devices before returning them to the user. update_after_each_batch (bool, optional): if ``True``, the weights will be updated after each collection. For ``sync=True``, this means that all workers will see their weights updated. For ``sync=False``, @@ -237,13 +272,16 @@ def __init__( *, frames_per_batch: int, total_frames: int = -1, + device: torch.device | List[torch.device] = None, + storing_device: torch.device | List[torch.device] = None, + env_device: torch.device | List[torch.device] = None, + policy_device: torch.device | List[torch.device] = None, max_frames_per_traj=-1, init_random_frames=-1, reset_at_each_iter=False, postproc=None, split_trajs=False, exploration_type=DEFAULT_EXPLORATION_TYPE, - reset_when_done=True, collector_class: Callable[[TensorDict], TensorDict] = SyncDataCollector, collector_kwargs: Union[Dict, List[Dict]] = None, num_workers_per_collector: int = 1, @@ -251,7 +289,6 @@ def __init__( ray_init_config: Dict = None, remote_configs: Union[Dict, List[Dict]] = None, num_collectors: int = None, - storing_device: torch.device = "cpu", update_after_each_batch=False, max_weight_update_interval=-1, ): @@ -358,7 +395,10 @@ def check_list_length_consistency(*lists): self.collector_kwargs = ( collector_kwargs if collector_kwargs is not None else [{}] ) + self.device = device self.storing_device = storing_device + self.env_device = env_device + self.policy_device = policy_device self._batches_since_weight_update = [0 for _ in range(self.num_collectors)] self._sync = sync @@ -373,7 +413,7 @@ def check_list_length_consistency(*lists): self._frames_per_batch_corrected = frames_per_batch # update collector kwargs - for collector_kwarg in self.collector_kwargs: + for i, collector_kwarg in enumerate(self.collector_kwargs): collector_kwarg["max_frames_per_traj"] = max_frames_per_traj collector_kwarg["init_random_frames"] = ( init_random_frames // self.num_collectors @@ -388,14 +428,14 @@ def check_list_length_consistency(*lists): ) collector_kwarg["reset_at_each_iter"] = reset_at_each_iter collector_kwarg["exploration_type"] = exploration_type - collector_kwarg["reset_when_done"] = reset_when_done collector_kwarg["split_trajs"] = False collector_kwarg["frames_per_batch"] = self._frames_per_batch_corrected + collector_kwarg["device"] = self.device[i] + collector_kwarg["storing_device"] = self.storing_device[i] + collector_kwarg["env_device"] = self.env_device[i] + 
collector_kwarg["policy_device"] = self.policy_device[i] - if postproc is not None and hasattr(postproc, "to"): - self.postproc = postproc.to(self.storing_device) - else: - self.postproc = postproc + self.postproc = postproc # Create remote instances of the collector class self._remote_collectors = [] @@ -414,6 +454,54 @@ def check_list_length_consistency(*lists): ] ray.wait(object_refs=pending_samples) + @property + def num_workers(self): + return self.num_collectors + + @property + def device(self) -> List[torch.device]: + return self._device + + @property + def storing_device(self) -> List[torch.device]: + return self._storing_device + + @property + def env_device(self) -> List[torch.device]: + return self._env_device + + @property + def policy_device(self) -> List[torch.device]: + return self._policy_device + + @device.setter + def device(self, value): + if isinstance(value, (tuple, list)): + self._device = value + else: + self._device = [value] * self.num_collectors + + @storing_device.setter + def storing_device(self, value): + if isinstance(value, (tuple, list)): + self._storing_device = value + else: + self._storing_device = [value] * self.num_collectors + + @env_device.setter + def env_device(self, value): + if isinstance(value, (tuple, list)): + self._env_device = value + else: + self._env_device = [value] * self.num_collectors + + @policy_device.setter + def policy_device(self, value): + if isinstance(value, (tuple, list)): + self._policy_device = value + else: + self._policy_device = [value] * self.num_collectors + @staticmethod def _make_collector(cls, env_maker, policy, other_params): """Create a single collector instance.""" @@ -512,7 +600,7 @@ def _sync_iterator(self) -> Iterator[TensorDictBase]: self.collected_frames += out_td.numel() - yield out_td.to(self.storing_device) + yield out_td if self.max_weight_update_interval > -1: for j in range(self.num_collectors): @@ -549,7 +637,7 @@ def _async_iterator(self) -> Iterator[TensorDictBase]: ) # should not be necessary, deleted automatically when ref count is down to 0 self.collected_frames += out_td.numel() - yield out_td.to(self.storing_device) + yield out_td for j in range(self.num_collectors): self._batches_since_weight_update[j] += 1 diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 98228d15f7b..50729038b4a 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -4,6 +4,8 @@ # LICENSE file in the root directory of this source tree. r"""Generic distributed data-collector using torch.distributed.rpc backend.""" +from __future__ import annotations + import collections import logging import os @@ -11,7 +13,7 @@ import time import warnings from copy import copy, deepcopy -from typing import OrderedDict +from typing import Callable, List, OrderedDict from torchrl.collectors.distributed import DEFAULT_SLURM_CONF from torchrl.collectors.distributed.default_configs import ( @@ -96,31 +98,69 @@ class RPCDataCollector(DataCollectorBase): Args: create_env_fn (Callable or List[Callabled]): list of Callables, each returning an instance of :class:`~torchrl.envs.EnvBase`. - policy (Callable, optional): Instance of TensorDictModule class. - Must accept TensorDictBase object as input. + policy (Callable): Policy to be executed in the environment. + Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. If ``None`` is provided, the policy used will be a :class:`RandomPolicy` instance with the environment ``action_spec``. 
- frames_per_batch (int): A keyword-only argument representing the - total number of elements in a batch. - total_frames (int): A keyword-only argument representing the - total number of frames returned by the collector + + Keyword Args: + frames_per_batch (int): A keyword-only argument representing the total + number of elements in a batch. + total_frames (int): A keyword-only argument representing the total + number of frames returned by the collector during its lifespan. If the ``total_frames`` is not divisible by ``frames_per_batch``, an exception is raised. Endless collectors can be created by passing ``total_frames=-1``. + Defaults to ``-1`` (endless collector). + device (int, str or torch.device, optional): The generic device of the + collector. The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). + Lists of devices are supported. + storing_device (int, str or torch.device, optional): The *remote* device on which + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different + device than the one where the policy and env are executed. + Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). + Lists of devices are supported. + env_device (int, str or torch.device, optional): The *remote* device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). Defaults to ``None``. + Lists of devices are supported. + policy_device (int, str or torch.device, optional): The *remote* device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. + Lists of devices are supported. max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span over multiple batches (unless + Note that a trajectory can span across multiple batches (unless ``reset_at_each_iter`` is set to ``True``, see below). Once a trajectory reaches ``n_steps``, the environment is reset. If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. - Defaults to ``-1`` (i.e. no maximum number of steps). + Defaults to ``None`` (i.e., no maximum number of steps). init_random_frames (int, optional): Number of frames for which the policy is ignored before it is called. 
This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. - Defaults to ``-1`` (i.e. no random frames). + If provided, it will be rounded up to the closest multiple of frames_per_batch. + Defaults to ``None`` (i.e. no random frames). reset_at_each_iter (bool, optional): Whether environments should be reset at the beginning of a batch collection. Defaults to ``False``. @@ -133,14 +173,10 @@ class RPCDataCollector(DataCollectorBase): See :func:`~torchrl.collectors.utils.split_trajectories` for more information. Defaults to ``False``. - exploration_type (str, optional): interaction mode to be used when - collecting data. Must be one of ``ExplorationType.RANDOM``, - ``ExplorationType.MODE`` or - ``ExplorationType.MEAN``. - Defaults to ``ExplorationType.RANDOM`` - reset_when_done (bool, optional): if ``True`` (default), an environment - that return a ``True`` value in its ``"done"`` or ``"truncated"`` - entry will be reset at the corresponding indices. + exploration_type (ExplorationType, optional): interaction mode to be used when + collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``, + ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``. + Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``. collector_class (type or str, optional): a collector class for the remote node. Can be :class:`~torchrl.collectors.SyncDataCollector`, :class:`~torchrl.collectors.MultiSyncDataCollector`, @@ -156,7 +192,6 @@ class RPCDataCollector(DataCollectorBase): should always be preferred. If multiple simultaneous environment need to be executed on a single node, consider using a :class:`~torchrl.envs.ParallelEnv` instance. - collector_kwargs (dict or list, optional): a dictionary of parameters to be passed to the remote data-collector. If a list is provided, each element will correspond to an individual set of keyword arguments for the @@ -174,9 +209,6 @@ class RPCDataCollector(DataCollectorBase): first-served" fashion. slurm_kwargs (dict): a dictionary of parameters to be passed to the submitit executor. - storing_device (int, str or torch.device, optional): the device where - data will be stored and delivered by the iterator. Defaults to - ``"cpu"``. update_after_each_batch (bool, optional): if ``True``, the weights will be updated after each collection. For ``sync=True``, this means that all workers will see their weights updated. 
For ``sync=False``, @@ -217,22 +249,24 @@ def __init__( create_env_fn, policy, *, - frames_per_batch, - total_frames, - max_frames_per_traj=-1, - init_random_frames=-1, - reset_at_each_iter=False, - postproc=None, - split_trajs=False, - exploration_type=DEFAULT_EXPLORATION_TYPE, - exploration_mode=None, - reset_when_done=True, + frames_per_batch: int, + total_frames: int = -1, + device: torch.device | List[torch.device] = None, + storing_device: torch.device | List[torch.device] = None, + env_device: torch.device | List[torch.device] = None, + policy_device: torch.device | List[torch.device] = None, + max_frames_per_traj: int = -1, + init_random_frames: int = -1, + reset_at_each_iter: bool = False, + postproc: Callable | None = None, + split_trajs: bool = False, + exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa + exploration_mode: str = None, collector_class=SyncDataCollector, collector_kwargs=None, num_workers_per_collector=1, sync=False, slurm_kwargs=None, - storing_device="cpu", update_after_each_batch=False, max_weight_update_interval=-1, launcher="submitit", @@ -259,6 +293,12 @@ def __init__( self.policy_weights = policy_weights self.num_workers = len(create_env_fn) self.frames_per_batch = frames_per_batch + + self.device = device + self.storing_device = storing_device + self.env_device = env_device + self.policy_device = policy_device + self.storing_device = storing_device # make private to avoid changes from users during collection self._sync = sync @@ -300,7 +340,7 @@ def __init__( ) # update collector kwargs - for collector_kwarg in self.collector_kwargs: + for i, collector_kwarg in enumerate(self.collector_kwargs): collector_kwarg["max_frames_per_traj"] = max_frames_per_traj collector_kwarg["init_random_frames"] = ( init_random_frames // self.num_workers @@ -315,12 +355,12 @@ def __init__( ) collector_kwarg["reset_at_each_iter"] = reset_at_each_iter collector_kwarg["exploration_type"] = exploration_type - collector_kwarg["reset_when_done"] = reset_when_done + collector_kwarg["device"] = self.device[i] + collector_kwarg["storing_device"] = self.storing_device[i] + collector_kwarg["env_device"] = self.env_device[i] + collector_kwarg["policy_device"] = self.policy_device[i] - if postproc is not None and hasattr(postproc, "to"): - self.postproc = postproc.to(self.storing_device) - else: - self.postproc = postproc + self.postproc = postproc self.split_trajs = split_trajs if tensorpipe_options is None: @@ -331,6 +371,66 @@ def __init__( ) self._init() + @property + def device(self) -> List[torch.device]: + return self._device + + @property + def storing_device(self) -> List[torch.device]: + return self._storing_device + + @property + def env_device(self) -> List[torch.device]: + return self._env_device + + @property + def policy_device(self) -> List[torch.device]: + return self._policy_device + + @device.setter + def device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." + ) + self._device = value + else: + self._device = [value] * self.num_workers + + @storing_device.setter + def storing_device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." 
+ ) + self._storing_device = value + else: + self._storing_device = [value] * self.num_workers + + @env_device.setter + def env_device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." + ) + self._env_device = value + else: + self._env_device = [value] * self.num_workers + + @policy_device.setter + def policy_device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." + ) + self._policy_device = value + else: + self._policy_device = [value] * self.num_workers + def _init_master_rpc( self, world_size, @@ -585,7 +685,7 @@ def _next_async_rpc(self): args=(self.collector_rrefs[i],), ) self.futures.append((future, i)) - return data.to(self.storing_device) + return data self.futures.append((future, i)) def _next_sync_rpc(self): @@ -612,7 +712,7 @@ def _next_sync_rpc(self): ) else: self.futures.append((future, i)) - data = torch.cat(data).to(self.storing_device) + data = torch.cat(data) traj_ids = data.get(("collector", "traj_ids"), None) if traj_ids is not None: for i in range(1, self.num_workers): diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 8d3afa488d4..d7a5c94487d 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -4,13 +4,14 @@ # LICENSE file in the root directory of this source tree. r"""Generic distributed data-collector using torch.distributed backend.""" +from __future__ import annotations import logging import os import socket from copy import copy, deepcopy from datetime import timedelta -from typing import OrderedDict +from typing import Callable, List, OrderedDict import torch.cuda from tensordict import TensorDict @@ -142,6 +143,8 @@ class DistributedSyncDataCollector(DataCollectorBase): If ``None`` is provided, the policy used will be a :class:`RandomPolicy` instance with the environment ``action_spec``. + + Keyword Args: frames_per_batch (int): A keyword-only argument representing the total number of elements in a batch. total_frames (int): A keyword-only argument representing the total @@ -149,19 +152,55 @@ class DistributedSyncDataCollector(DataCollectorBase): during its lifespan. If the ``total_frames`` is not divisible by ``frames_per_batch``, an exception is raised. Endless collectors can be created by passing ``total_frames=-1``. + Defaults to ``-1`` (endless collector). + device (int, str or torch.device, optional): The generic device of the + collector. The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). + Lists of devices are supported. + storing_device (int, str or torch.device, optional): The *remote* device on which + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different + device than the one where the policy and env are executed. + Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). 
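The ``None`` default for ``storing_device`` means the collected tensordict is not forced onto any single device: the container's ``device`` attribute is ``None`` and every leaf stays where it was created. A small illustration of that convention with a bare :class:`~tensordict.TensorDict` (assuming a CUDA device is available)::

    import torch
    from tensordict import TensorDict

    if torch.cuda.is_available():
        td = TensorDict(
            {
                "observation": torch.zeros(3),              # created on CPU, stays on CPU
                "action": torch.zeros(3, device="cuda:0"),  # created on CUDA, stays on CUDA
            },
            batch_size=[],
        )
        assert td.device is None                 # no global device is enforced
        assert td["observation"].device.type == "cpu"
        assert td["action"].device.type == "cuda"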
+ Lists of devices are supported. + env_device (int, str or torch.device, optional): The *remote* device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). Defaults to ``None``. + Lists of devices are supported. + policy_device (int, str or torch.device, optional): The *remote* device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. + Lists of devices are supported. max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span over multiple batches (unless + Note that a trajectory can span across multiple batches (unless ``reset_at_each_iter`` is set to ``True``, see below). Once a trajectory reaches ``n_steps``, the environment is reset. If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. - Defaults to ``-1`` (i.e. no maximum number of steps). + Defaults to ``None`` (i.e., no maximum number of steps). init_random_frames (int, optional): Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. - Defaults to ``-1`` (i.e. no random frames). + If provided, it will be rounded up to the closest multiple of frames_per_batch. + Defaults to ``None`` (i.e. no random frames). reset_at_each_iter (bool, optional): Whether environments should be reset at the beginning of a batch collection. Defaults to ``False``. @@ -174,13 +213,10 @@ class DistributedSyncDataCollector(DataCollectorBase): See :func:`~torchrl.collectors.utils.split_trajectories` for more information. Defaults to ``False``. - exploration_type (str, optional): interaction mode to be used when - collecting data. Must be one of ``"random"``, ``"mode"`` or - ``"mean"``. - Defaults to ``"random"`` - reset_when_done (bool, optional): if ``True`` (default), an environment - that return a ``True`` value in its ``"done"`` or ``"truncated"`` - entry will be reset at the corresponding indices. + exploration_type (ExplorationType, optional): interaction mode to be used when + collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``, + ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``. + Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``. collector_class (type or str, optional): a collector class for the remote node. 
Can be :class:`~torchrl.collectors.SyncDataCollector`, :class:`~torchrl.collectors.MultiSyncDataCollector`, @@ -205,8 +241,14 @@ class DistributedSyncDataCollector(DataCollectorBase): is one of ``"gloo"``, ``"mpi"``, ``"nccl"`` or ``"ucc"``. See the torch.distributed documentation for more information. Defaults to ``"gloo"``. - storing_device (torch.device or compatible, optional): the device where the - data will be delivered. Defaults to ``"cpu"``. + max_weight_update_interval (int, optional): the maximum number of + batches that can be collected before the policy weights of a worker + is updated. + For sync collections, this parameter is overwritten by ``update_after_each_batch``. + For async collections, it may be that one worker has not seen its + parameters being updated for a certain time even if ``update_after_each_batch`` + is turned on. + Defaults to -1 (no forced update). update_interval (int, optional): the frequency at which the policy is updated. Defaults to 1. launcher (str, optional): how jobs should be launched. @@ -225,22 +267,24 @@ def __init__( create_env_fn, policy, *, - frames_per_batch, - total_frames, - max_frames_per_traj=-1, - init_random_frames=-1, - reset_at_each_iter=False, - postproc=None, - split_trajs=False, - exploration_type=DEFAULT_EXPLORATION_TYPE, - exploration_mode=None, - reset_when_done=True, + frames_per_batch: int, + total_frames: int = -1, + device: torch.device | List[torch.device] = None, + storing_device: torch.device | List[torch.device] = None, + env_device: torch.device | List[torch.device] = None, + policy_device: torch.device | List[torch.device] = None, + max_frames_per_traj: int = -1, + init_random_frames: int = -1, + reset_at_each_iter: bool = False, + postproc: Callable | None = None, + split_trajs: bool = False, + exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa + exploration_mode: str = None, collector_class=SyncDataCollector, collector_kwargs=None, num_workers_per_collector=1, slurm_kwargs=None, backend="gloo", - storing_device="cpu", max_weight_update_interval=-1, update_interval=1, launcher="submitit", @@ -267,6 +311,12 @@ def __init__( self.policy_weights = policy_weights self.num_workers = len(create_env_fn) self.frames_per_batch = frames_per_batch + + self.device = device + self.storing_device = storing_device + self.env_device = env_device + self.policy_device = policy_device + self.storing_device = storing_device # make private to avoid changes from users during collection self.update_interval = update_interval @@ -304,19 +354,19 @@ def __init__( ) # update collector kwargs - for collector_kwarg in self.collector_kwargs: + for i, collector_kwarg in enumerate(self.collector_kwargs): collector_kwarg["max_frames_per_traj"] = max_frames_per_traj collector_kwarg["init_random_frames"] = ( init_random_frames // self.num_workers ) collector_kwarg["reset_at_each_iter"] = reset_at_each_iter collector_kwarg["exploration_type"] = exploration_type - collector_kwarg["reset_when_done"] = reset_when_done + collector_kwarg["device"] = self.device[i] + collector_kwarg["storing_device"] = self.storing_device[i] + collector_kwarg["env_device"] = self.env_device[i] + collector_kwarg["policy_device"] = self.policy_device[i] - if postproc is not None and hasattr(postproc, "to"): - self.postproc = postproc.to(self.storing_device) - else: - self.postproc = postproc + self.postproc = postproc self.split_trajs = split_trajs self.backend = backend @@ -326,6 +376,66 @@ def __init__( self._init_workers() self._make_container() + 
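Each distributed collector now exposes the four devices as properties backed by per-worker lists, as the property block below shows for ``DistributedSyncDataCollector``. A condensed, standalone sketch of that broadcast-and-validate behaviour (the class and attribute names here are illustrative, not part of the patch)::

    from __future__ import annotations

    from typing import List, Union

    import torch

    DeviceLike = Union[str, int, torch.device, None]

    class _PerWorkerDevices:
        """Toy container mimicking the collector's device setters."""

        def __init__(self, num_workers: int, device: DeviceLike | List[DeviceLike] = None):
            self.num_workers = num_workers
            self.device = device  # goes through the setter below

        @property
        def device(self) -> List[DeviceLike]:
            return self._device

        @device.setter
        def device(self, value):
            if isinstance(value, (tuple, list)):
                if len(value) != self.num_workers:
                    raise RuntimeError(
                        "The number of devices passed to the collector must match the number of workers."
                    )
                self._device = list(value)
            else:
                # a single value (or None) is broadcast to every worker
                self._device = [value] * self.num_workers

    assert _PerWorkerDevices(3, "cpu").device == ["cpu", "cpu", "cpu"]
    assert _PerWorkerDevices(2, ["cpu", "cuda:0"]).device == ["cpu", "cuda:0"]

The same pattern is repeated for ``storing_device``, ``env_device`` and ``policy_device`` so that ``collector_kwargs[i]`` can be filled with one entry per worker.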
@property + def device(self) -> List[torch.device]: + return self._device + + @property + def storing_device(self) -> List[torch.device]: + return self._storing_device + + @property + def env_device(self) -> List[torch.device]: + return self._env_device + + @property + def policy_device(self) -> List[torch.device]: + return self._policy_device + + @device.setter + def device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." + ) + self._device = value + else: + self._device = [value] * self.num_workers + + @storing_device.setter + def storing_device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." + ) + self._storing_device = value + else: + self._storing_device = [value] * self.num_workers + + @env_device.setter + def env_device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." + ) + self._env_device = value + else: + self._env_device = [value] * self.num_workers + + @policy_device.setter + def policy_device(self, value): + if isinstance(value, (tuple, list)): + if len(value) != self.num_workers: + raise RuntimeError( + "The number of devices passed to the collector must match the number of workers." + ) + self._policy_device = value + else: + self._policy_device = [value] * self.num_workers + def _init_master_dist( self, world_size, @@ -353,20 +463,7 @@ def _make_container(self): ) for _data in pseudo_collector: break - if not issubclass(self.collector_class, SyncDataCollector): - # Multi-data collectors - self._tensordict_out = ( - _data.expand((self.num_workers, *_data.shape)) - .to_tensordict() - .to(self.storing_device) - ) - else: - # Multi-data collectors - self._tensordict_out = ( - _data.expand((self.num_workers, *_data.shape)) - .to_tensordict() - .to(self.storing_device) - ) + self._tensordict_out = _data.expand((self.num_workers, *_data.shape)) self._single_tds = self._tensordict_out.unbind(0) self._tensordict_out.lock_() pseudo_collector.shutdown() diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index eee3b3e4a98..b8db47f412d 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -7,8 +7,7 @@ import torch -from tensordict import set_lazy_legacy -from tensordict.tensordict import pad, TensorDictBase +from tensordict import pad, set_lazy_legacy, TensorDictBase def _stack_output(fun) -> Callable: diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index adf6317e679..10b9767de8e 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -19,8 +19,7 @@ import torch -from tensordict import PersistentTensorDict, TensorDict -from tensordict.tensordict import make_tensordict +from tensordict import make_tensordict, PersistentTensorDict, TensorDict from torchrl.collectors.utils import split_trajectories from torchrl.data.datasets.d4rl_infos import D4RL_DATASETS diff --git a/torchrl/data/datasets/openml.py b/torchrl/data/datasets/openml.py index 0070c86d534..d0acd37822e 100644 --- a/torchrl/data/datasets/openml.py +++ b/torchrl/data/datasets/openml.py @@ -9,7 +9,7 @@ from typing import Callable import numpy as np -from tensordict.tensordict import TensorDict +from 
tensordict import TensorDict from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers import ( diff --git a/torchrl/data/postprocs/postprocs.py b/torchrl/data/postprocs/postprocs.py index 21f51115d6c..d7b2db3f15a 100644 --- a/torchrl/data/postprocs/postprocs.py +++ b/torchrl/data/postprocs/postprocs.py @@ -6,7 +6,7 @@ from __future__ import annotations import torch -from tensordict.tensordict import TensorDictBase +from tensordict import TensorDictBase from tensordict.utils import expand_right from torch import nn diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 1381c6d2383..79bf3b9b180 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -17,14 +17,15 @@ import torch -from tensordict import is_tensorclass, unravel_key -from tensordict.nn.utils import _set_dispatch_td_nn_modules -from tensordict.tensordict import ( +from tensordict import ( is_tensor_collection, + is_tensorclass, LazyStackedTensorDict, TensorDict, TensorDictBase, + unravel_key, ) +from tensordict.nn.utils import _set_dispatch_td_nn_modules from tensordict.utils import expand_as_right, expand_right from torch import Tensor diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 5357b9a835f..c37cac634e4 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -18,9 +18,8 @@ import numpy as np import tensordict import torch -from tensordict import is_tensorclass +from tensordict import is_tensor_collection, is_tensorclass, TensorDict, TensorDictBase from tensordict.memmap import MemmapTensor, MemoryMappedTensor -from tensordict.tensordict import is_tensor_collection, TensorDict, TensorDictBase from tensordict.utils import _STRDTYPE2DTYPE, expand_right from torch import multiprocessing as mp diff --git a/torchrl/data/rlhf/dataset.py b/torchrl/data/rlhf/dataset.py index 35f4e99914c..9c6b3d1e58a 100644 --- a/torchrl/data/rlhf/dataset.py +++ b/torchrl/data/rlhf/dataset.py @@ -15,7 +15,7 @@ from tensordict import TensorDict, TensorDictBase -from tensordict.tensordict import NestedKey +from tensordict.utils import NestedKey from torchrl.data.replay_buffers import ( SamplerWithoutReplacement, TensorDictReplayBuffer, diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 4aad6a7b3c1..b4d628a9051 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -31,8 +31,7 @@ import numpy as np import torch -from tensordict import unravel_key -from tensordict.tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase +from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase, unravel_key from tensordict.utils import _getitem_batch_size, NestedKey from torchrl._utils import get_binary_env_var @@ -81,12 +80,14 @@ def _default_dtype_and_device( dtype: Union[None, torch.dtype], device: Union[None, str, int, torch.device], -) -> Tuple[torch.dtype, torch.device]: + allow_none_device: bool = False, +) -> Tuple[torch.dtype, torch.device | None]: if dtype is None: dtype = torch.get_default_dtype() - if device is None: - device = torch.device("cpu") - device = torch.device(device) + if device is not None: + device = torch.device(device) + elif not allow_none_device: + device = torch.zeros(()).device return dtype, device @@ -354,7 +355,7 @@ class ContinuousBox(Box): _low: torch.Tensor _high: torch.Tensor - device: torch.device = None + device: 
torch.device | None = None # We store the tensors on CPU to avoid overloading CUDA with tensors that are rarely used. @property @@ -522,7 +523,7 @@ class TensorSpec: shape: torch.Size space: Union[None, Box] - device: torch.device = torch.device("cpu") + device: torch.device | None = None dtype: torch.dtype = torch.float domain: str = "" @@ -539,6 +540,10 @@ def decorator(func): return decorator + def clear_device_(self): + """A no-op for all leaf specs (which must have a device).""" + pass + def encode( self, val: Union[np.ndarray, torch.Tensor], *, ignore_device=False ) -> torch.Tensor: @@ -761,6 +766,14 @@ def zero(self, shape=None) -> torch.Tensor: def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> "TensorSpec": raise NotImplementedError + def cpu(self): + return self.to("cpu") + + def cuda(self, device=None): + if device is None: + return self.to("cuda") + return self.to(f"cuda:{device}") + @abc.abstractmethod def clone(self) -> "TensorSpec": raise NotImplementedError @@ -809,6 +822,11 @@ def __init__(self, *specs: tuple[T, ...], dim: int) -> None: if self.dim < 0: self.dim = len(self.shape) + self.dim + def clear_device_(self): + """Clears the device of the CompositeSpec.""" + for spec in self._specs: + spec.clear_device_() + def __getitem__(self, item): is_key = isinstance(item, str) or ( isinstance(item, tuple) and all(isinstance(_item, str) for _item in item) @@ -918,6 +936,8 @@ def rand(self, shape=None) -> TensorDictBase: return torch.stack([spec.rand(shape) for spec in self._specs], dim) def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> T: + if dest is None: + return self return torch.stack([spec.to(dest) for spec in self._specs], self.dim) def unbind(self, dim: int): @@ -1160,7 +1180,7 @@ class OneHotDiscreteTensorSpec(TensorSpec): shape: torch.Size space: DiscreteBox - device: torch.device = torch.device("cpu") + device: torch.device | None = None dtype: torch.dtype = torch.float domain: str = "" @@ -1187,7 +1207,9 @@ def __init__( f"The last value of the shape must match n for transform of type {self.__class__}. " f"Got n={space.n} and shape={shape}." 
) - super().__init__(shape, space, device, dtype, "discrete") + super().__init__( + shape=shape, space=space, device=device, dtype=dtype, domain="discrete" + ) self.update_mask(mask) @property @@ -1205,6 +1227,8 @@ def update_mask(self, mask): self.mask = mask def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + if dest is None: + return self if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -1529,8 +1553,6 @@ def __init__( dtype, device = _default_dtype_and_device(dtype, device) if dtype is None: dtype = torch.get_default_dtype() - if device is None: - device = torch._get_default_device() if not isinstance(low, torch.Tensor): low = torch.tensor(low, dtype=dtype, device=device) @@ -1592,7 +1614,11 @@ def __init__( self.shape = shape super().__init__( - shape, ContinuousBox(low, high, device=device), device, dtype, domain=domain + shape=shape, + space=ContinuousBox(low, high, device=device), + device=device, + dtype=dtype, + domain=domain, ) def __eq__(self, other): @@ -1750,6 +1776,8 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device + elif dest is None: + return self else: dest_dtype = self.dtype dest_device = torch.device(dest) @@ -1845,6 +1873,8 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device + elif dest is None: + return self else: dest_dtype = self.dtype dest_device = torch.device(dest) @@ -1859,7 +1889,9 @@ def rand(self, shape=None) -> torch.Tensor: if shape is None: shape = torch.Size([]) shape = [*shape, *self.shape] - return torch.randn(shape, device=self.device, dtype=self.dtype) + if self.dtype.is_floating_point: + return torch.randn(shape, device=self.device, dtype=self.dtype) + return torch.empty(shape, device=self.device, dtype=self.dtype).random_() def is_in(self, val: torch.Tensor) -> bool: return True @@ -1979,6 +2011,8 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device + elif dest is None: + return self else: dest_dtype = self.dtype dest_device = torch.device(dest) @@ -2167,6 +2201,8 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device + elif dest is None: + return self else: dest_dtype = self.dtype dest_device = torch.device(dest) @@ -2474,7 +2510,7 @@ class DiscreteTensorSpec(TensorSpec): shape: torch.Size space: DiscreteBox - device: torch.device = torch.device("cpu") + device: torch.device | None = None dtype: torch.dtype = torch.float domain: str = "" @@ -2492,7 +2528,9 @@ def __init__( shape = torch.Size([]) dtype, device = _default_dtype_and_device(dtype, device) space = DiscreteBox(n) - super().__init__(shape, space, device, dtype, domain="discrete") + super().__init__( + shape=shape, space=space, device=device, dtype=dtype, domain="discrete" + ) self.update_mask(mask) @property @@ -2690,6 +2728,8 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device + elif dest is None: + return self else: dest_dtype = self.dtype dest_device = torch.device(dest) @@ -2796,6 +2836,8 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = 
self.device + elif dest is None: + return self else: dest_dtype = self.dtype dest_device = torch.device(dest) @@ -2908,6 +2950,8 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device + elif dest is None: + return self else: dest_dtype = self.dtype dest_device = torch.device(dest) @@ -3205,6 +3249,15 @@ class CompositeSpec(TensorSpec): to be ``True`` for the corresponding tensors, and :obj:`project()` will have no effect. `spec.encode` cannot be used with missing values. + Attributes: + device (torch.device or None): if not specified, the device of the composite + spec is ``None`` (as it is the case for TensorDicts). A non-none device + constraints all leaves to be of the same device. On the other hand, + a ``None`` device allows leaves to have different devices. Defaults + to ``None``. + shape (torch.Size): the leading shape of all the leaves. Equivalent + to the batch-size of the corresponding tensordicts. + Examples: >>> pixels_spec = BoundedTensorSpec( ... torch.zeros(3,32,32), @@ -3237,7 +3290,6 @@ class CompositeSpec(TensorSpec): device=None, is_shared=False) - Examples: >>> # we can build a nested composite spec using unnamed arguments >>> print(CompositeSpec({("a", "b"): None, ("a", "c"): None})) @@ -3264,7 +3316,7 @@ class CompositeSpec(TensorSpec): @classmethod def __new__(cls, *args, **kwargs): - cls._device = torch.device("cpu") + cls._device = None cls._locked = False return super().__new__(cls) @@ -3330,19 +3382,13 @@ def __init__(self, *args, shape=None, device=None, **kwargs): for key, item in self.items(): if item is None: continue - - try: - item_device = item.device - except RuntimeError as err: - cond1 = DEVICE_ERR_MSG in str(err) - if cond1: - item_device = _device - else: - raise err - - if _device is None: - _device = item_device - elif item_device != _device: + if ( + isinstance(item, CompositeSpec) + and item.device is None + and _device is not None + ): + item = item.clone().to(_device) + elif (_device is not None) and (item.device != _device): raise RuntimeError( f"Setting a new attribute ({key}) on another device " f"({item.device} against {_device}). All devices of " @@ -3361,40 +3407,27 @@ def __init__(self, *args, shape=None, device=None, **kwargs): ) for k, item in argdict.items(): if isinstance(item, dict): - item = CompositeSpec(item, shape=shape) - if item is not None: - if self._device is None: - try: - self._device = item.device - except RuntimeError as err: - if DEVICE_ERR_MSG in str(err): - self._device = item._device - else: - raise err + item = CompositeSpec(item, shape=shape, device=_device) self[k] = item @property def device(self) -> DEVICE_TYPING: - if self._device is None: - # try to replace device by the true device - _device = None - for value in self.values(): - if value is not None: - _device = value.device - if _device is None: - raise RuntimeError( - "device of empty CompositeSpec is not defined. " - "You can set it directly by calling " - "`spec.device = device`." - ) - self._device = _device return self._device @device.setter def device(self, device: DEVICE_TYPING): + if device is None and self._device is not None: + raise RuntimeError( + "To erase the device of a composite spec, call " "spec.clear_device_()." 
+ ) device = torch.device(device) self.to(device) + def clear_device_(self): + """Clears the device of the CompositeSpec.""" + for spec in self._specs: + spec.clear_device_() + def __getitem__(self, idx): """Indexes the current CompositeSpec based on the provided index.""" if isinstance(idx, (str, tuple)): @@ -3456,7 +3489,7 @@ def get(self, item, default=NO_DEFAULT): def __setitem__(self, key, value): if isinstance(key, tuple) and len(key) > 1: if key[0] not in self.keys(True): - self[key[0]] = CompositeSpec(shape=self.shape) + self[key[0]] = CompositeSpec(shape=self.shape, device=self.device) self[key[0]][key[1:]] = value return elif isinstance(key, tuple): @@ -3466,34 +3499,25 @@ def __setitem__(self, key, value): raise TypeError(f"Got key of type {type(key)} when a string was expected.") if key in {"shape", "device", "dtype", "space"}: raise AttributeError(f"CompositeSpec[{key}] cannot be set") - try: - if value is not None and value.device != self.device: + if isinstance(value, dict): + value = CompositeSpec(value, device=self._device, shape=self.shape) + if ( + value is not None + and self.device is not None + and value.device != self.device + ): + if isinstance(value, CompositeSpec) and value.device is None: + value = value.clone().to(self.device) + else: raise RuntimeError( f"Setting a new attribute ({key}) on another device ({value.device} against {self.device}). " f"All devices of CompositeSpec must match." ) - except RuntimeError as err: - cond1 = DEVICE_ERR_MSG in str(err) - cond2 = self._device is None - if cond1 and cond2: - try: - device_val = value.device - self.to(device_val) - except RuntimeError as suberr: - if DEVICE_ERR_MSG in str(suberr): - pass - else: - raise suberr - elif cond1: - pass - else: - raise err self.set(key, value) def __iter__(self): - for k in self._specs: - yield k + yield from self._specs def __delitem__(self, key: str) -> None: if isinstance(key, tuple) and len(key) > 1: @@ -3668,6 +3692,8 @@ def __len__(self): return len(self.keys()) def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + if dest is None: + return self if not isinstance(dest, (str, int, torch.device)): raise ValueError( "Only device casting is allowed with specs of type CompositeSpec." 
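Under the new convention, a ``CompositeSpec`` built without an explicit device has ``device=None`` and does not constrain its leaves, while casting to ``None`` is a no-op; ``clear_device_`` is the in-place way of erasing a previously set device. A brief sketch of this behaviour, assuming the semantics introduced in this patch::

    import torch
    from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

    # Device-less composite spec: no common device is enforced on the leaves.
    spec = CompositeSpec(
        observation=UnboundedContinuousTensorSpec(3),
        reward=UnboundedContinuousTensorSpec(1),
    )
    assert spec.device is None
    assert spec.to(None) is spec          # casting to ``None`` is a no-op

    # An explicit device constrains every leaf to live on that device.
    cpu_spec = spec.clone().to("cpu")
    assert cpu_spec.device == torch.device("cpu")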
@@ -4123,7 +4149,23 @@ def __setitem__(self, key: NestedKey, value): @property def device(self) -> DEVICE_TYPING: - return self._specs[0].device + device = self.__dict__.get("_device", NO_DEFAULT) + if device is NO_DEFAULT: + devices = {spec.device for spec in self._specs} + if len(devices) == 1: + device = list(devices)[0] + elif len(devices) == 2: + device0, device1 = devices + if device0 is None: + device = device1 + elif device1 is None: + device = device0 + else: + device = None + else: + device = None + self.__dict__["_device"] = device + return device @property def ndim(self): @@ -4232,7 +4274,18 @@ def _stack_composite_specs(list_of_spec, dim, out=None): raise ValueError("Cannot stack an empty list of specs.") spec0 = list_of_spec[0] if isinstance(spec0, CompositeSpec): - device = spec0.device + devices = {spec.device for spec in list_of_spec} + if len(devices) == 1: + device = list(devices)[0] + elif len(devices) == 2: + device0, device1 = devices + if device0 is None: + device = device1 + elif device1 is None: + device = device0 + else: + device = None + all_equal = True for spec in list_of_spec[1:]: if not isinstance(spec, CompositeSpec): @@ -4240,8 +4293,9 @@ def _stack_composite_specs(list_of_spec, dim, out=None): "Stacking specs cannot occur: Found more than one type of spec in " "the list." ) - if device != spec.device: - raise RuntimeError(f"Devices differ, got {device} and {spec.device}") + if device != spec.device and device is not None: + # spec.device must be None + spec = spec.to(device) if spec.shape != spec0.shape: raise RuntimeError(f"Shapes differ, got {spec.shape} and {spec0.shape}") all_equal = all_equal and spec == spec0 diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index a8d399ea08b..bc253cd3ac7 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -5,6 +5,7 @@ from __future__ import annotations +import gc import logging import os @@ -18,9 +19,8 @@ import torch -from tensordict import TensorDict +from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase from tensordict._tensordict import _unravel_key_to_tuple, unravel_key -from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase from torch import multiprocessing as mp from torchrl._utils import _check_for_faulty_process, _ProcessNoWarn, VERBOSE from torchrl.data.tensor_specs import CompositeSpec @@ -410,6 +410,15 @@ def _check_for_empty_spec(specs: CompositeSpec): self._dummy_env_str = meta_data.env_str self._env_tensordict = meta_data.tensordict + if device is None: # In other cases, the device will be mapped later + self._env_tensordict.clear_device_() + device_map = meta_data.device_map + + def map_device(key, value, device_map=device_map): + return value.to(device_map[key]) + + self._env_tensordict.named_apply(map_device, nested_keys=True) + self._batch_locked = meta_data.batch_locked else: self._batch_size = torch.Size([self.num_workers, *meta_data[0].batch_size]) @@ -576,11 +585,10 @@ def _create_td(self) -> None: # Multi-task: we share tensordict that *may* have different keys # LazyStacked already stores this so we don't need to do anything self.shared_tensordicts = self.shared_tensordict_parent - if self.shared_tensordict_parent.device.type == "cpu": - if self._share_memory: - self.shared_tensordict_parent.share_memory_() - elif self._memmap: - self.shared_tensordict_parent.memmap_() + if self._share_memory: + self.shared_tensordict_parent.share_memory_() + elif self._memmap: + self.shared_tensordict_parent.memmap_() 
             else:
                 if self._share_memory:
                     self.shared_tensordict_parent.share_memory_()
@@ -676,6 +684,8 @@ def _start_workers(self) -> None:
         for idx in range(_num_workers):
             env = self.create_env_fn[idx](**self.create_env_kwargs[idx])
+            if self.device is not None:
+                env = env.to(self.device)
             self._envs.append(env)
         self.is_closed = False

@@ -766,6 +776,8 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
             )
         if out.device == device:
             out = out.clone()
+        elif device is None:
+            out = out.clone().clear_device_()
         else:
             out = out.to(device, non_blocking=True)
         return out
@@ -807,6 +819,8 @@ def _step(
         out = next_td.select(*self._selected_step_keys, strict=False)
         if out.device == device:
             out = out.clone()
+        elif device is None:
+            out = out.clone().clear_device_()
         else:
             out = out.to(device, non_blocking=True)
         return out
@@ -850,8 +864,7 @@ def to(self, device: DEVICE_TYPING):
             return self
         super().to(device)
         if not self.is_closed:
-            for env in self._envs:
-                env.to(device)
+            self._envs = [env.to(device) for env in self._envs]
         return self

@@ -1006,7 +1019,17 @@ def _start_workers(self) -> None:
         self.parent_channels = []
         self._workers = []
         func = _run_worker_pipe_shared_mem
-        if self.shared_tensordict_parent.device.type == "cuda":
+        # We look for cuda tensors through the leaves
+        # because the shared tensordict could be partially on cuda
+        # and some leaves may be inaccessible through get (e.g., LazyStacked)
+        has_cuda = [False]
+
+        def look_for_cuda(tensor, has_cuda=has_cuda):
+            has_cuda[0] = has_cuda[0] or tensor.is_cuda
+
+        self.shared_tensordict_parent.apply(look_for_cuda)
+        has_cuda = has_cuda[0]
+        if has_cuda:
             self.event = torch.cuda.Event()
         else:
             self.event = None
@@ -1123,9 +1146,12 @@ def step_and_maybe_reset(
         if self.shared_tensordict_parent.device == device:
             next_td = next_td.clone()
             tensordict_ = tensordict_.clone()
-        else:
+        elif device is not None:
             next_td = next_td.to(device, non_blocking=True)
             tensordict_ = tensordict_.to(device, non_blocking=True)
+        else:
+            next_td = next_td.clone().clear_device_()
+            tensordict_ = tensordict_.clone().clear_device_()
         tensordict.set("next", next_td)
         return tensordict, tensordict_

@@ -1255,6 +1281,8 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
             )
         if out.device == device:
             out = out.clone()
+        elif device is None:
+            out = out.clear_device_().clone()
         else:
             out = out.to(device, non_blocking=True)
         return out
@@ -1379,7 +1407,18 @@ def _run_worker_pipe_shared_mem(
     verbose: bool = False,
 ) -> None:
     device = shared_tensordict.device
-    if device.type == "cuda":
+    if device is None or device.type != "cuda":
+        # Check if some tensors are shared on cuda
+        has_cuda = [False]
+
+        def look_for_cuda(tensor, has_cuda=has_cuda):
+            has_cuda[0] = has_cuda[0] or tensor.is_cuda
+
+        shared_tensordict.apply(look_for_cuda)
+        has_cuda = has_cuda[0]
+    else:
+        has_cuda = device.type == "cuda"
+    if has_cuda:
         event = torch.cuda.Event()
     else:
         event = None
@@ -1403,7 +1442,7 @@ def _run_worker_pipe_shared_mem(
     initialized = False

     child_pipe.send("started")
-
+    next_shared_tensordict, root_shared_tensordict = (None,) * 2
     while True:
         try:
             if child_pipe.poll(_timeout):
@@ -1452,42 +1491,49 @@ def _run_worker_pipe_shared_mem(
                 event.record()
                 event.synchronize()
             mp_event.set()
+            del cur_td
         elif cmd == "step":
             if not initialized:
                 raise RuntimeError("called 'init' before step")
             i += 1
-            env_input = shared_tensordict
-            next_td = env._step(env_input)
+            next_td = env._step(shared_tensordict)
             next_shared_tensordict.update_(next_td)
             if event is not None:
                 event.record()
                 event.synchronize()
             mp_event.set()
+            del next_td
         elif cmd == "step_and_maybe_reset":
             if not initialized:
                 raise RuntimeError("called 'init' before step")
             i += 1
-            env_input = shared_tensordict
-            td, root_next_td = env.step_and_maybe_reset(env_input)
+            td, root_next_td = env.step_and_maybe_reset(shared_tensordict)
             next_shared_tensordict.update_(td.get("next"))
             root_shared_tensordict.update_(root_next_td)
             if event is not None:
                 event.record()
                 event.synchronize()
             mp_event.set()
+            del td, root_next_td
         elif cmd == "close":
-            del shared_tensordict, data
             if not initialized:
                 raise RuntimeError("call 'init' before closing")
             env.close()
-            del env
+            del (
+                env,
+                shared_tensordict,
+                data,
+                next_shared_tensordict,
+                root_shared_tensordict,
+            )
             mp_event.set()
             child_pipe.close()
             if verbose:
                 logging.info(f"{pid} closed")
+            gc.collect()
             break

         elif cmd == "load_state_dict":
@@ -1498,6 +1544,7 @@ def _run_worker_pipe_shared_mem(
             state_dict = _recursively_strip_locks_from_state_dict(env.state_dict())
             msg = "state_dict"
             child_pipe.send((msg, state_dict))
+            del state_dict

         else:
             err_msg = f"{cmd} from env"
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index 39484ac355a..61cd211b6ae 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -14,8 +14,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from tensordict import LazyStackedTensorDict, unravel_key
-from tensordict.tensordict import TensorDictBase
+from tensordict import LazyStackedTensorDict, TensorDictBase, unravel_key
 from tensordict.utils import NestedKey

 from torchrl._utils import _replace_last, implement_for, prod, seed_generator
@@ -59,6 +58,7 @@ def __init__(
         env_str: str,
         device: torch.device,
         batch_locked: bool = True,
+        device_map: dict = None,
     ):
         self.device = device
         self.tensordict = tensordict
@@ -66,6 +66,7 @@ def __init__(
         self.batch_size = batch_size
         self.env_str = env_str
         self.batch_locked = batch_locked
+        self.device_map = device_map

     @property
     def tensordict(self):
@@ -100,7 +101,16 @@ def metadata_from_env(env) -> EnvMetaData:
         device = env.device
         specs = specs.to("cpu")
         batch_locked = env.batch_locked
-        return EnvMetaData(tensordict, specs, batch_size, env_str, device, batch_locked)
+        # we need to save the device map, as the tensordict will be placed on cpu
+        device_map = {}
+
+        def fill_device_map(name, val, device_map=device_map):
+            device_map[name] = val.device
+
+        tensordict.named_apply(fill_device_map, nested_keys=True)
+        return EnvMetaData(
+            tensordict, specs, batch_size, env_str, device, batch_locked, device_map
+        )

     def expand(self, *size: int) -> EnvMetaData:
         tensordict = self.tensordict.expand(*size).clone()
@@ -112,6 +122,7 @@ def expand(self, *size: int) -> EnvMetaData:
             self.env_str,
             self.device,
             self.batch_locked,
+            self.device_map,
         )

     def clone(self):
@@ -122,13 +133,23 @@ def clone(self):
             deepcopy(self.env_str),
             self.device,
             self.batch_locked,
+            self.device_map,
         )

     def to(self, device: DEVICE_TYPING) -> EnvMetaData:
+        if device is not None:
+            device = torch.device(device)
+        device_map = {key: device for key in self.device_map}
         tensordict = self.tensordict.contiguous().to(device)
         specs = self.specs.to(device)
         return EnvMetaData(
-            tensordict, specs, self.batch_size, self.env_str, device, self.batch_locked
+            tensordict,
+            specs,
+            self.batch_size,
+            self.env_str,
+            device,
+            self.batch_locked,
+            device_map,
         )

@@ -149,6 +170,51 @@ def __call__(cls, *args, **kwargs):
 class EnvBase(nn.Module, metaclass=_EnvPostInit):
     """Abstract environment parent class.

+    Keyword Args:
+        device (torch.device): The device of the environment. Deviceless environments
+            are allowed (device=None). If not ``None``, all specs will be cast
+            on that device and it is expected that all inputs and outputs will
+            live on that device.
+            Defaults to ``None``.
+        dtype (deprecated): dtype of the observations. Will be deprecated in v0.4.
+        batch_size (torch.Size or equivalent, optional): batch-size of the environment.
+            Corresponds to the leading dimension of all the input and output
+            tensordicts the environment reads and writes. Defaults to an empty batch-size.
+        run_type_checks (bool, optional): If ``True``, type-checks will occur
+            at every reset and every step. Defaults to ``False``.
+        allow_done_after_reset (bool, optional): if ``True``, an environment can
+            be done after a call to :meth:`~.reset` is made. Defaults to ``False``.
+
+    Attributes:
+        done_spec (CompositeSpec): equivalent to ``full_done_spec`` as all
+            ``done_specs`` contain at least a ``"done"`` and a ``"terminated"`` entry
+        action_spec (TensorSpec): the spec of the action. Links to the spec of the leaf
+            action if only one action tensor is to be expected. Otherwise links to
+            ``full_action_spec``.
+        observation_spec (CompositeSpec): equivalent to ``full_observation_spec``.
+        reward_spec (TensorSpec): the spec of the reward. Links to the spec of the leaf
+            reward if only one reward tensor is to be expected. Otherwise links to
+            ``full_reward_spec``.
+        state_spec (CompositeSpec): equivalent to ``full_state_spec``.
+        full_done_spec (CompositeSpec): a composite spec such that ``full_done_spec.zero()``
+            returns a tensordict containing only the leaves encoding the done status of the
+            environment.
+        full_action_spec (CompositeSpec): a composite spec such that ``full_action_spec.zero()``
+            returns a tensordict containing only the leaves encoding the action of the
+            environment.
+        full_observation_spec (CompositeSpec): a composite spec such that ``full_observation_spec.zero()``
+            returns a tensordict containing only the leaves encoding the observation of the
+            environment.
+        full_reward_spec (CompositeSpec): a composite spec such that ``full_reward_spec.zero()``
+            returns a tensordict containing only the leaves encoding the reward of the
+            environment.
+        full_state_spec (CompositeSpec): a composite spec such that ``full_state_spec.zero()``
+            returns a tensordict containing only the leaves encoding the inputs (actions
+            excluded) of the environment.
+        batch_size (torch.Size): The batch-size of the environment.
+        device (torch.device): the device where the input/outputs of the environment
+            are to be expected. Can be ``None``.
+
     Methods:
         step (TensorDictBase -> TensorDictBase): step in the environment
         reset (TensorDictBase, optional -> TensorDictBase): reset the environment
@@ -158,6 +224,15 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit):
            steps if no policy is provided)

    Examples:
+        >>> from torchrl.envs import EnvBase
+        >>> class CounterEnv(EnvBase):
+        ...     def __init__(self, batch_size=(), device=None, **kwargs):
+        ...         self.observation_spec = CompositeSpec(
+        ...             count=UnboundedContinuousTensorSpec(batch_size, device=device, dtype=torch.int64))
+        ...         self.action_spec = UnboundedContinuousTensorSpec(batch_size, device=device, dtype=torch.int8)
+        ...         # done spec and reward spec are set automatically
+        ...     def _step(self, tensordict):
+        ...
        >>> from torchrl.envs.libs.gym import GymEnv
        >>> env = GymEnv("Pendulum-v1")
        >>> env.batch_size  # how many envs are run at once
@@ -238,23 +313,30 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit):

     def __init__(
         self,
+        *,
         device: DEVICE_TYPING = None,
         dtype: Optional[Union[torch.dtype, np.dtype]] = None,
         batch_size: Optional[torch.Size] = None,
         run_type_checks: bool = False,
         allow_done_after_reset: bool = False,
     ):
-        if device is None:
-            device = torch.device("cpu")
         self.__dict__.setdefault("_batch_size", None)
         if device is not None:
             self.__dict__["_device"] = torch.device(device)
             output_spec = self.__dict__.get("_output_spec", None)
             if output_spec is not None:
-                self.__dict__["_output_spec"] = output_spec.to(self.device)
+                self.__dict__["_output_spec"] = (
+                    output_spec.to(self.device)
+                    if self.device is not None
+                    else output_spec
+                )
             input_spec = self.__dict__.get("_input_spec", None)
             if input_spec is not None:
-                self.__dict__["_input_spec"] = input_spec.to(self.device)
+                self.__dict__["_input_spec"] = (
+                    input_spec.to(self.device)
+                    if self.device is not None
+                    else input_spec
+                )

         super().__init__()
         self.dtype = dtype_map.get(dtype, dtype)
@@ -360,8 +442,6 @@ def batch_size(self, value: torch.Size) -> None:
     @property
     def device(self) -> torch.device:
         device = self.__dict__.get("_device", None)
-        if device is None:
-            device = self.__dict__["_device"] = torch.device("cpu")
         return device

     @device.setter
@@ -618,7 +698,7 @@ def action_spec(self) -> TensorSpec:
     def action_spec(self, value: TensorSpec) -> None:
         try:
             self.input_spec.unlock_()
-            device = self.input_spec.device
+            device = self.input_spec._device
             try:
                 delattr(self, "_action_keys")
             except AttributeError:
@@ -806,7 +886,7 @@ def reward_spec(self) -> TensorSpec:
     def reward_spec(self, value: TensorSpec) -> None:
         try:
             self.output_spec.unlock_()
-            device = self.output_spec.device
+            device = self.output_spec._device
             try:
                 delattr(self, "_reward_keys")
             except AttributeError:
@@ -873,7 +953,7 @@ def full_reward_spec(self) -> CompositeSpec:

     @full_reward_spec.setter
     def full_reward_spec(self, spec: CompositeSpec) -> None:
-        self.reward_spec = spec
+        self.reward_spec = spec.to(self.device) if self.device is not None else spec

     # done spec
     @property
@@ -938,7 +1018,7 @@ def full_done_spec(self) -> CompositeSpec:

     @full_done_spec.setter
     def full_done_spec(self, spec: CompositeSpec) -> None:
-        self.done_spec = spec
+        self.done_spec = spec.to(self.device) if self.device is not None else spec

     # Done spec: done specs belong to output_spec
     @property
@@ -1168,7 +1248,6 @@ def observation_spec(self) -> CompositeSpec:
     def observation_spec(self, value: TensorSpec) -> None:
         try:
             self.output_spec.unlock_()
-            device = self.output_spec.device
             if not isinstance(value, CompositeSpec):
                 raise TypeError("The type of an observation_spec must be Composite.")
             elif value.shape[: len(self.batch_size)] != self.batch_size:
@@ -1179,7 +1258,10 @@ def observation_spec(self, value: TensorSpec) -> None:
                 raise ValueError(
                     f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
                 )
-            self.output_spec["full_observation_spec"] = value.to(device)
+            device = self.output_spec._device
+            self.output_spec["full_observation_spec"] = (
+                value.to(device) if device is not None else value
+            )
         finally:
             self.output_spec.lock_()

@@ -1253,7 +1335,9 @@ def state_spec(self, value: CompositeSpec) -> None:
                 raise ValueError(
                     f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
                 )
-            self.input_spec["full_state_spec"] = value.to(device)
+            self.input_spec["full_state_spec"] = (
+                value.to(device) if device is not None else value
+            )
         finally:
             self.input_spec.lock_()

@@ -2274,10 +2358,13 @@ def rollout(
             [None, 'time']
         """
-        try:
-            policy_device = next(policy.parameters()).device
-        except (StopIteration, AttributeError):
-            policy_device = self.device
+        if auto_cast_to_device:
+            try:
+                policy_device = next(policy.parameters()).device
+            except (StopIteration, AttributeError):
+                policy_device = None
+        else:
+            policy_device = None

         env_device = self.device

@@ -2330,10 +2417,16 @@ def _rollout_stop_early(
         tensordicts = []
         for i in range(max_steps):
             if auto_cast_to_device:
-                tensordict = tensordict.to(policy_device, non_blocking=True)
+                if policy_device is not None:
+                    tensordict = tensordict.to(policy_device, non_blocking=True)
+                else:
+                    tensordict.clear_device_()
             tensordict = policy(tensordict)
             if auto_cast_to_device:
-                tensordict = tensordict.to(env_device, non_blocking=True)
+                if env_device is not None:
+                    tensordict = tensordict.to(env_device, non_blocking=True)
+                else:
+                    tensordict.clear_device_()
             tensordict = self.step(tensordict)
             tensordicts.append(tensordict.clone(False))

@@ -2378,10 +2471,16 @@ def _rollout_nonstop(
         tensordict_ = tensordict
         for i in range(max_steps):
             if auto_cast_to_device:
-                tensordict_ = tensordict_.to(policy_device, non_blocking=True)
+                if policy_device is not None:
+                    tensordict_ = tensordict_.to(policy_device, non_blocking=True)
+                else:
+                    tensordict_.clear_device_()
             tensordict_ = policy(tensordict_)
             if auto_cast_to_device:
-                tensordict_ = tensordict_.to(env_device, non_blocking=True)
+                if env_device is not None:
+                    tensordict_ = tensordict_.to(env_device, non_blocking=True)
+                else:
+                    tensordict_.clear_device_()
             tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
             tensordicts.append(tensordict)
             if i == max_steps - 1:
diff --git a/torchrl/envs/env_creator.py b/torchrl/envs/env_creator.py
index 9053b42f7f6..28c9e00c42a 100644
--- a/torchrl/envs/env_creator.py
+++ b/torchrl/envs/env_creator.py
@@ -10,7 +10,7 @@
 from typing import Callable, Dict, Optional, Union

 import torch
-from tensordict.tensordict import TensorDictBase
+from tensordict import TensorDictBase

 from torchrl.data.utils import CloudpickleWrapper
 from torchrl.envs.common import EnvBase, EnvMetaData
diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py
index 7eca6f5a1db..60cb026c658 100644
--- a/torchrl/envs/gym_like.py
+++ b/torchrl/envs/gym_like.py
@@ -13,8 +13,7 @@
 import numpy as np
 import torch
-from tensordict import TensorDict
-from tensordict.tensordict import TensorDictBase
+from tensordict import TensorDict, TensorDictBase

 from torchrl.data.tensor_specs import (
     CompositeSpec,
diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py
index c50d1189e59..c10813d4fc3 100644
--- a/torchrl/envs/libs/brax.py
+++ b/torchrl/envs/libs/brax.py
@@ -7,7 +7,7 @@
 from typing import Dict, Optional, Union

 import torch
-from tensordict.tensordict import TensorDict, TensorDictBase
+from tensordict import TensorDict, TensorDictBase

 from torchrl.data.tensor_specs import (
     BoundedTensorSpec,
diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py
index 7915ed91338..c6590d344e3 100644
--- a/torchrl/envs/libs/gym.py
+++ b/torchrl/envs/libs/gym.py
@@ -212,7 +212,7 @@ def gym_backend(submodule=None):
 def _gym_to_torchrl_spec_transform(
     spec,
     dtype=None,
-    device="cpu",
+    device=None,
     categorical_action_encoding=False,
     remap_state_to_observation: bool = True,
     batch_size: tuple = (),
@@ -224,7 +224,7 @@ def _gym_to_torchrl_spec_transform(
         dtype (torch.dtype): a dtype to use for the spec.
            Defaults to`spec.dtype`.
         device (torch.device): the device for the spec.
-            Defaults to ``"cpu"``.
+            Defaults to ``None`` (no device for composite and default device for specs).
         categorical_action_encoding (bool): whether discrete spaces should be mapped to categorical
            or one-hot. Defaults to ``False`` (one-hot).
         remap_state_to_observation (bool): whether to rename the 'state' key of
@@ -349,7 +349,7 @@ def _gym_to_torchrl_spec_transform(
                 remap_state_to_observation=remap_state_to_observation,
             )
         # the batch-size must be set later
-        return CompositeSpec(spec_out)
+        return CompositeSpec(spec_out, device=device)
     elif isinstance(spec, gym_spaces.dict.Dict):
         return _gym_to_torchrl_spec_transform(
             spec.spaces,
diff --git a/torchrl/envs/libs/jax_utils.py b/torchrl/envs/libs/jax_utils.py
index 68437d07d35..95c64183b7d 100644
--- a/torchrl/envs/libs/jax_utils.py
+++ b/torchrl/envs/libs/jax_utils.py
@@ -11,7 +11,7 @@
 import torch

 # from jax import dlpack as jax_dlpack, numpy as jnp
-from tensordict.tensordict import make_tensordict, TensorDictBase
+from tensordict import make_tensordict, TensorDictBase
 from torch.utils import dlpack as torch_dlpack
 from torchrl.data.tensor_specs import (
     CompositeSpec,
diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py
index 8f8e63a6e59..42c32b3547f 100644
--- a/torchrl/envs/libs/jumanji.py
+++ b/torchrl/envs/libs/jumanji.py
@@ -8,7 +8,7 @@
 import numpy as np
 import torch
-from tensordict.tensordict import TensorDict, TensorDictBase
+from tensordict import TensorDict, TensorDictBase

 from torchrl.envs.utils import _classproperty

 _has_jumanji = importlib.util.find_spec("jumanji") is not None
diff --git a/torchrl/envs/libs/openml.py b/torchrl/envs/libs/openml.py
index 1c3927e6d0f..0aa5aa99313 100644
--- a/torchrl/envs/libs/openml.py
+++ b/torchrl/envs/libs/openml.py
@@ -5,7 +5,7 @@
 import importlib.util

 import torch
-from tensordict.tensordict import TensorDict, TensorDictBase
+from tensordict import TensorDict, TensorDictBase

 from torchrl.data.datasets.openml import OpenMLExperienceReplay
 from torchrl.data.replay_buffers import SamplerWithoutReplacement
diff --git a/torchrl/envs/libs/pettingzoo.py b/torchrl/envs/libs/pettingzoo.py
index 73c293186ac..14e45eb4bc4 100644
--- a/torchrl/envs/libs/pettingzoo.py
+++ b/torchrl/envs/libs/pettingzoo.py
@@ -10,7 +10,7 @@
 from typing import Dict, List, Tuple, Union

 import torch
-from tensordict.tensordict import TensorDictBase
+from tensordict import TensorDictBase

 from torchrl.data.tensor_specs import (
     CompositeSpec,
diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py
index ee43e72ffe0..4d4998eb721 100644
--- a/torchrl/envs/libs/robohive.py
+++ b/torchrl/envs/libs/robohive.py
@@ -11,8 +11,7 @@
 import numpy as np
 import torch
-from tensordict import TensorDict
-from tensordict.tensordict import make_tensordict
+from tensordict import make_tensordict, TensorDict

 from torchrl._utils import implement_for
 from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec
 from torchrl.envs.libs.gym import _AsyncMeta, _gym_to_torchrl_spec_transform, GymEnv
diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py
index 1a5d0e2ce15..51d3970fded 100644
--- a/torchrl/envs/libs/vmas.py
+++ b/torchrl/envs/libs/vmas.py
@@ -9,7 +9,7 @@
 from typing import Dict, List, Optional, Union

 import torch
-from tensordict.tensordict import TensorDict, TensorDictBase
+from tensordict import TensorDict, TensorDictBase
 from torchrl.data.tensor_specs import (
     BoundedTensorSpec,
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 07bc29d1b59..cf67cdc0cc3 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -21,12 +21,13 @@
     is_tensor_collection,
     NonTensorData,
     set_lazy_legacy,
+    TensorDict,
+    TensorDictBase,
     unravel_key,
     unravel_key_list,
 )
 from tensordict._tensordict import _unravel_key_to_tuple
 from tensordict.nn import dispatch, TensorDictModuleBase
-from tensordict.tensordict import TensorDict, TensorDictBase
 from tensordict.utils import expand_as_right, NestedKey
 from torch import nn, Tensor
 from torch.utils._pytree import tree_map
diff --git a/torchrl/envs/transforms/vip.py b/torchrl/envs/transforms/vip.py
index 289dd60f053..ed288fdea9e 100644
--- a/torchrl/envs/transforms/vip.py
+++ b/torchrl/envs/transforms/vip.py
@@ -5,8 +5,7 @@
 from typing import List, Optional, Union

 import torch
-from tensordict import set_lazy_legacy, TensorDict
-from tensordict.tensordict import TensorDictBase
+from tensordict import set_lazy_legacy, TensorDict, TensorDictBase
 from torch.hub import load_state_dict_from_url

 from torchrl.data.tensor_specs import (
diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py
index 82f0c2d21fb..0b978e5ef68 100644
--- a/torchrl/envs/utils.py
+++ b/torchrl/envs/utils.py
@@ -15,7 +15,12 @@

 import torch

-from tensordict import is_tensor_collection, TensorDictBase, unravel_key
+from tensordict import (
+    is_tensor_collection,
+    LazyStackedTensorDict,
+    TensorDictBase,
+    unravel_key,
+)
 from tensordict.nn.probabilistic import (  # noqa
     # Note: the `set_interaction_mode` and their associated arg `default_interaction_mode` are being deprecated!
     # Please use the `set_/interaction_type` ones above with the InteractionType enum instead.
@@ -26,7 +31,7 @@
     set_interaction_mode as set_exploration_mode,
     set_interaction_type as set_exploration_type,
 )
-from tensordict.tensordict import LazyStackedTensorDict, NestedKey
+from tensordict.utils import NestedKey

 from torchrl._utils import _replace_last
 from torchrl.data.tensor_specs import (
diff --git a/torchrl/modules/models/recipes/impala.py b/torchrl/modules/models/recipes/impala.py
index f80524a0f9f..5a59bc55fa1 100644
--- a/torchrl/modules/models/recipes/impala.py
+++ b/torchrl/modules/models/recipes/impala.py
@@ -6,7 +6,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from tensordict.tensordict import TensorDictBase
+from tensordict import TensorDictBase


 # TODO: code small architecture ref in Impala paper
diff --git a/torchrl/modules/planners/cem.py b/torchrl/modules/planners/cem.py
index 1a3fdac7387..6d9e6fb3b49 100644
--- a/torchrl/modules/planners/cem.py
+++ b/torchrl/modules/planners/cem.py
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 import torch
-from tensordict.tensordict import TensorDict, TensorDictBase
+from tensordict import TensorDict, TensorDictBase

 from torchrl.envs.common import EnvBase
 from torchrl.modules.planners.common import MPCPlannerBase
diff --git a/torchrl/modules/planners/common.py b/torchrl/modules/planners/common.py
index 66fd1bb9e1f..3ddb9012139 100644
--- a/torchrl/modules/planners/common.py
+++ b/torchrl/modules/planners/common.py
@@ -6,7 +6,7 @@
 from typing import Optional

 import torch
-from tensordict.tensordict import TensorDictBase
+from tensordict import TensorDictBase

 from torchrl.envs.common import EnvBase
 from torchrl.modules import SafeModule
diff --git a/torchrl/modules/planners/mppi.py b/torchrl/modules/planners/mppi.py
index b390d05fad6..c65b81eb11d 100644
--- a/torchrl/modules/planners/mppi.py
+++ b/torchrl/modules/planners/mppi.py
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 import torch
-from tensordict.tensordict import TensorDict, TensorDictBase
+from tensordict import TensorDict, TensorDictBase
 from torch import nn

 from torchrl.envs.common import EnvBase
diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py
index 22786519681..221ba3cde8d 100644
--- a/torchrl/modules/tensordict_module/common.py
+++ b/torchrl/modules/tensordict_module/common.py
@@ -13,10 +13,9 @@

 import torch

-from tensordict import unravel_key_list
+from tensordict import TensorDictBase, unravel_key_list

 from tensordict.nn import TensorDictModule, TensorDictModuleBase
-from tensordict.tensordict import TensorDictBase
 from tensordict.utils import NestedKey

 from torch import nn
diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py
index 5c8ae799061..f641fdfef88 100644
--- a/torchrl/modules/tensordict_module/exploration.py
+++ b/torchrl/modules/tensordict_module/exploration.py
@@ -7,13 +7,13 @@
 import numpy as np
 import torch

+from tensordict import TensorDictBase
 from tensordict.nn import (
     TensorDictModule,
     TensorDictModuleBase,
     TensorDictModuleWrapper,
 )
-from tensordict.tensordict import TensorDictBase
 from tensordict.utils import expand_as_right, expand_right, NestedKey

 from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py
index fe970c292be..b05cbd55356 100644
--- a/torchrl/modules/tensordict_module/rnn.py
+++ b/torchrl/modules/tensordict_module/rnn.py
@@ -9,9 +9,9 @@
 import torch.nn.functional as F

 from tensordict import TensorDictBase, unravel_key_list
-from tensordict.nn import TensorDictModuleBase as ModuleBase
+from tensordict.base import NO_DEFAULT

-from tensordict.tensordict import NO_DEFAULT
+from tensordict.nn import TensorDictModuleBase as ModuleBase
 from tensordict.utils import expand_as_right, prod, set_lazy_legacy
 from torch import nn, Tensor
diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py
index 2cdb7af2553..c32a795a2a0 100644
--- a/torchrl/objectives/a2c.py
+++ b/torchrl/objectives/a2c.py
@@ -10,8 +10,8 @@
 from typing import Tuple

 import torch
+from tensordict import TensorDict, TensorDictBase
 from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
-from tensordict.tensordict import TensorDict, TensorDictBase
 from tensordict.utils import NestedKey
 from torch import distributions as d

@@ -89,7 +89,7 @@ class A2CLoss(LossModule):
         >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
         >>> from torchrl.modules.tensordict_module.common import SafeModule
         >>> from torchrl.objectives.a2c import A2CLoss
-        >>> from tensordict.tensordict import TensorDict
+        >>> from tensordict import TensorDict
         >>> n_act, n_obs = 4, 3
         >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
         >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py
index 8213fa7044f..3d90d0174b9 100644
--- a/torchrl/objectives/cql.py
+++ b/torchrl/objectives/cql.py
@@ -12,8 +12,8 @@
 import numpy as np
 import torch
 import torch.nn as nn
+from tensordict import TensorDict, TensorDictBase
 from tensordict.nn import dispatch, TensorDictModule
-from tensordict.tensordict import TensorDict, TensorDictBase
 from tensordict.utils import NestedKey, unravel_key
 from torch import Tensor

@@ -90,7 +90,7 @@ class CQLLoss(LossModule):
         >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
         >>> from torchrl.modules.tensordict_module.common import SafeModule
         >>> from torchrl.objectives.cql import CQLLoss
-        >>> from tensordict.tensordict import TensorDict
+        >>> from tensordict import TensorDict
         >>> n_act, n_obs = 4, 3
         >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
         >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py
index 3b4debe6259..70239ea62e9 100644
--- a/torchrl/objectives/ddpg.py
+++ b/torchrl/objectives/ddpg.py
@@ -11,8 +11,8 @@
 from typing import Tuple

 import torch
+from tensordict import TensorDict, TensorDictBase
 from tensordict.nn import dispatch, TensorDictModule
-from tensordict.tensordict import TensorDict, TensorDictBase
 from tensordict.utils import NestedKey, unravel_key

 from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
@@ -49,7 +49,7 @@ class DDPGLoss(LossModule):
         >>> from torchrl.data import BoundedTensorSpec
         >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
         >>> from torchrl.objectives.ddpg import DDPGLoss
-        >>> from tensordict.tensordict import TensorDict
+        >>> from tensordict import TensorDict
         >>> n_act, n_obs = 4, 3
         >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
         >>> actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act))
diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py
index 52339d583dd..a24aa4a1271 100644
--- a/torchrl/objectives/decision_transformer.py
+++ b/torchrl/objectives/decision_transformer.py
@@ -8,8 +8,8 @@
 from typing import Union

 import torch
+from tensordict import TensorDict, TensorDictBase
 from tensordict.nn import dispatch
-from tensordict.tensordict import TensorDict, TensorDictBase
 from tensordict.utils import NestedKey

 from torch import distributions as d
diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py
index ea329a2b726..e920bc83960 100644
--- a/torchrl/objectives/deprecated.py
+++ b/torchrl/objectives/deprecated.py
@@ -11,9 +11,8 @@
 import numpy as np
 import torch
-from tensordict import TensorDict
+from tensordict import TensorDict, TensorDictBase
 from tensordict.nn import dispatch, TensorDictModule
-from tensordict.tensordict import TensorDictBase
 from tensordict.utils import NestedKey
 from torch import Tensor

diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py
index aa0ada74801..1fd48675cb4 100644
--- a/torchrl/objectives/iql.py
+++ b/torchrl/objectives/iql.py
@@ -7,8 +7,8 @@
 from typing import Optional, Tuple, Union

 import torch
+from tensordict import TensorDict, TensorDictBase
 from tensordict.nn import dispatch, TensorDictModule
-from tensordict.tensordict import TensorDict, TensorDictBase
 from tensordict.utils import NestedKey
 from torch import Tensor
 from torchrl.data.tensor_specs import TensorSpec
@@ -63,7 +63,7 @@ class IQLLoss(LossModule):
         >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
         >>> from torchrl.modules.tensordict_module.common import SafeModule
         >>> from torchrl.objectives.iql import IQLLoss
-        >>> from tensordict.tensordict import TensorDict
+        >>> from tensordict import TensorDict
         >>> n_act, n_obs = 4, 3
         >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
         >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
@@ -538,7 +538,7 @@ class DiscreteIQLLoss(IQLLoss):
         >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
         >>> from torchrl.modules.tensordict_module.common import SafeModule
         >>> from torchrl.objectives.iql import DiscreteIQLLoss
-        >>> from tensordict.tensordict import TensorDict
+        >>> from tensordict import TensorDict
         >>> n_act, n_obs = 4, 3
         >>> spec = OneHotDiscreteTensorSpec(n_act)
         >>> module = TensorDictModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"])
@@ -597,7 +597,7 @@ class DiscreteIQLLoss(IQLLoss):
         >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
         >>> from torchrl.modules.tensordict_module.common import SafeModule
         >>> from torchrl.objectives.iql import DiscreteIQLLoss
-        >>> from tensordict.tensordict import TensorDict
+        >>> from tensordict import TensorDict
         >>> _ = torch.manual_seed(42)
         >>> n_act, n_obs = 4, 3
         >>> spec = OneHotDiscreteTensorSpec(n_act)
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index 542877f8f20..0f7ea835949 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -14,8 +14,8 @@
 from typing import Tuple

 import torch
+from tensordict import TensorDict, TensorDictBase
 from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
-from tensordict.tensordict import TensorDict, TensorDictBase
 from tensordict.utils import NestedKey
 from torch import distributions as d

@@ -136,7 +136,7 @@ class PPOLoss(LossModule):
         >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
         >>> from torchrl.modules.tensordict_module.common import SafeModule
         >>> from torchrl.objectives.ppo import PPOLoss
-        >>> from tensordict.tensordict import TensorDict
+        >>> from tensordict import TensorDict
         >>> n_act, n_obs = 4, 3
         >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
         >>> base_layer = nn.Linear(n_obs, 5)
diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py
index cac829964fc..af0a94cbc96 100644
--- a/torchrl/objectives/redq.py
+++ b/torchrl/objectives/redq.py
@@ -9,9 +9,9 @@
 from typing import Union

 import torch
+from tensordict import TensorDict, TensorDictBase
 from tensordict.nn import dispatch, TensorDictModule, TensorDictSequential
-from tensordict.tensordict import TensorDict, TensorDictBase
 from tensordict.utils import NestedKey
 from torch import Tensor

@@ -86,7 +86,7 @@ class REDQLoss(LossModule):
         >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
         >>> from torchrl.modules.tensordict_module.common import SafeModule
         >>> from torchrl.objectives.redq import REDQLoss
-        >>> from tensordict.tensordict import TensorDict
+        >>> from tensordict import TensorDict
         >>> n_act, n_obs = 4, 3
         >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
         >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py
index 98c4d4d14d3..c9cc8f383ad 100644
--- a/torchrl/objectives/reinforce.py
+++ b/torchrl/objectives/reinforce.py
@@ -11,9 +11,9 @@
 from dataclasses import dataclass

 import torch
+from tensordict import TensorDict, TensorDictBase
 from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
-from tensordict.tensordict import TensorDict, TensorDictBase
 from tensordict.utils import NestedKey
 from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import (
@@ -95,7 +95,7 @@ class ReinforceLoss(LossModule):
         >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
         >>> from torchrl.modules.tensordict_module.common import SafeModule
         >>> from torchrl.objectives.reinforce import ReinforceLoss
-        >>> from tensordict.tensordict import TensorDict
+        >>> from tensordict import TensorDict
         >>> n_obs, n_act = 3, 5
         >>> value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])
         >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index 4da874148e7..431296e7486 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -11,9 +11,9 @@
 import numpy as np

 import torch
+from tensordict import TensorDict, TensorDictBase
 from tensordict.nn import dispatch, TensorDictModule
-from tensordict.tensordict import TensorDict, TensorDictBase
 from tensordict.utils import NestedKey
 from torch import Tensor
 from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
@@ -106,7 +106,7 @@ class SACLoss(LossModule):
         >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
         >>> from torchrl.modules.tensordict_module.common import SafeModule
         >>> from torchrl.objectives.sac import SACLoss
-        >>> from tensordict.tensordict import TensorDict
+        >>> from tensordict import TensorDict
         >>> n_act, n_obs = 4, 3
         >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
         >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py
index 7736c5cfbbf..e1aeb253681 100644
--- a/torchrl/objectives/td3.py
+++ b/torchrl/objectives/td3.py
@@ -7,9 +7,9 @@
 from typing import Optional, Tuple

 import torch
-from tensordict.nn import dispatch, TensorDictModule
-from tensordict.tensordict import TensorDict, TensorDictBase
+from tensordict import TensorDict, TensorDictBase
+from tensordict.nn import dispatch, TensorDictModule
 from tensordict.utils import NestedKey

 from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec, TensorSpec
@@ -75,7 +75,7 @@ class TD3Loss(LossModule):
         >>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator
         >>> from torchrl.modules.tensordict_module.common import SafeModule
         >>> from torchrl.objectives.td3 import TD3Loss
-        >>> from tensordict.tensordict import TensorDict
+        >>> from tensordict import TensorDict
         >>> n_act, n_obs = 4, 3
         >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
         >>> module = nn.Linear(n_obs, n_act)
diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py
index 91305a6a777..4c0b8ae67bd 100644
--- a/torchrl/objectives/utils.py
+++ b/torchrl/objectives/utils.py
@@ -10,8 +10,8 @@
 from typing import Iterable, Optional, Union

 import torch
+from tensordict import TensorDict, TensorDictBase
 from tensordict.nn import TensorDictModule
-from tensordict.tensordict import TensorDict, TensorDictBase
 from torch import nn, Tensor
 from torch.nn import functional as F
 from torch.nn.modules import dropout
diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py
index 1c43d536fe8..fc2e58a19f6 100644
--- a/torchrl/objectives/value/advantages.py
+++ b/torchrl/objectives/value/advantages.py
@@ -13,6 +13,7 @@
 from typing import Callable, List, Optional, Union

 import torch
+from tensordict import TensorDictBase
 from tensordict.nn import (
     dispatch,
     is_functional,
@@ -20,7 +21,6 @@
     TensorDictModule,
     TensorDictModuleBase,
 )
-from tensordict.tensordict import TensorDictBase
 from tensordict.utils import NestedKey

 from torch import nn, Tensor
diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py
index 1910c920a41..a6181145311 100644
--- a/torchrl/record/recorder.py
+++ b/torchrl/record/recorder.py
@@ -8,7 +8,7 @@

 import torch

-from tensordict.tensordict import TensorDictBase
+from tensordict import TensorDictBase
 from tensordict.utils import NestedKey

diff --git a/torchrl/trainers/helpers/collectors.py b/torchrl/trainers/helpers/collectors.py
index f8f9c55809b..7063fb2f1c4 100644
--- a/torchrl/trainers/helpers/collectors.py
+++ b/torchrl/trainers/helpers/collectors.py
@@ -6,8 +6,9 @@
 from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, List, Optional, Type, Union

+from tensordict import TensorDictBase
+
 from tensordict.nn import ProbabilisticTensorDictSequential, TensorDictModuleWrapper
-from tensordict.tensordict import TensorDictBase

 from torchrl.collectors.collectors import (
     DataCollectorBase,
diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py
index 6985037d17c..c8629be7f15 100644
--- a/torchrl/trainers/trainers.py
+++ b/torchrl/trainers/trainers.py
@@ -16,8 +16,8 @@
 import numpy as np

 import torch.nn
+from tensordict import pad, TensorDictBase
 from tensordict.nn import TensorDictModule
-from tensordict.tensordict import pad, TensorDictBase
 from tensordict.utils import expand_right
 from torch import nn, optim

diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py
index 1f69651a3b4..85590c545fa 100644
--- a/tutorials/sphinx-tutorials/coding_ddpg.py
+++ b/tutorials/sphinx-tutorials/coding_ddpg.py
@@ -340,7 +340,7 @@ def _loss_value(
 # value and actor loss, collect the cost values and write them in a tensordict
 # delivered to the user.

-from tensordict.tensordict import TensorDict, TensorDictBase
+from tensordict import TensorDict, TensorDictBase


 def _forward(self, input_tensordict: TensorDictBase) -> TensorDict:
diff --git a/tutorials/sphinx-tutorials/pendulum.py b/tutorials/sphinx-tutorials/pendulum.py
index b72d2ff0f92..12c8bdc3193 100644
--- a/tutorials/sphinx-tutorials/pendulum.py
+++ b/tutorials/sphinx-tutorials/pendulum.py
@@ -96,8 +96,8 @@
 import numpy as np
 import torch
 import tqdm
+from tensordict import TensorDict, TensorDictBase
 from tensordict.nn import TensorDictModule
-from tensordict.tensordict import TensorDict, TensorDictBase
 from torch import nn
 from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
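
Editorial note (not part of the patch): the tensor_specs.py hunks above repeatedly apply the same device-resolution rule when a composite spec or a stack of specs mixes devices: a single device wins, a None device defers to the other spec's device, and genuinely different devices resolve to None. The standalone sketch below mirrors that rule for illustration only; the helper name `resolve_device` is an assumption made for this example and does not exist in torchrl.

# Illustrative sketch only -- mirrors the device-resolution logic in the
# LazyStackedCompositeSpec.device property and _stack_composite_specs hunks above.
from typing import Optional, Set

import torch


def resolve_device(devices: Set[Optional[torch.device]]) -> Optional[torch.device]:
    """Resolve a set of (possibly None) devices the way the patch above does."""
    if len(devices) == 1:
        # All specs agree: keep that device (which may itself be None).
        return next(iter(devices))
    if len(devices) == 2:
        # A None device defers to the concrete one.
        device0, device1 = devices
        if device0 is None:
            return device1
        if device1 is None:
            return device0
    # Truly conflicting (or more than two) devices: no common device.
    return None


if __name__ == "__main__":
    print(resolve_device({torch.device("cpu")}))                         # cpu
    print(resolve_device({None, torch.device("cpu")}))                   # cpu
    print(resolve_device({torch.device("cpu"), torch.device("cuda")}))   # None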