diff --git a/.circleci/unittest/linux_libs/scripts_gym/install.sh b/.circleci/unittest/linux_libs/scripts_gym/install.sh
index 9e3739fe2b2..0cdee0320c1 100755
--- a/.circleci/unittest/linux_libs/scripts_gym/install.sh
+++ b/.circleci/unittest/linux_libs/scripts_gym/install.sh
@@ -44,5 +44,9 @@ fi
 # install tensordict
 pip install git+https://github.com/pytorch-labs/tensordict
 
+# smoke test
+python -c "import tensordict"
+
 printf "* Installing torchrl\n"
 python setup.py develop
+python -c "import torchrl"
diff --git a/.circleci/unittest/linux_libs/scripts_gym/run_test.sh b/.circleci/unittest/linux_libs/scripts_gym/run_test.sh
index 3e151539d92..4e08dec58e1 100755
--- a/.circleci/unittest/linux_libs/scripts_gym/run_test.sh
+++ b/.circleci/unittest/linux_libs/scripts_gym/run_test.sh
@@ -5,6 +5,8 @@ set -e
 eval "$(./conda/bin/conda shell.bash hook)"
 conda activate ./env
 
+yum makecache && yum install libglvnd-devel mesa-libGL mesa-libGL-devel mesa-libEGL mesa-libEGL-devel glfw mesa-libOSMesa-devel glew glew-devel egl-utils freeglut xorg-x11-server-Xvfb -y
+
 export PYTORCH_TEST_WITH_SLOW='1'
 python -m torch.utils.collect_env
 # Avoid error: "fatal: unsafe repository"
diff --git a/.circleci/unittest/linux_libs/scripts_habitat/install.sh b/.circleci/unittest/linux_libs/scripts_habitat/install.sh
index af2f78de49f..e5833cd1356 100755
--- a/.circleci/unittest/linux_libs/scripts_habitat/install.sh
+++ b/.circleci/unittest/linux_libs/scripts_habitat/install.sh
@@ -41,7 +41,7 @@ fi
 pip install git+https://github.com/pytorch-labs/tensordict
 
 # smoke test
-python -c "import functorch"
+python -c "import functorch;import tensordict"
 
 printf "* Installing torchrl\n"
 pip3 install -e .
diff --git a/.circleci/unittest/linux_libs/scripts_jumanji/install.sh b/.circleci/unittest/linux_libs/scripts_jumanji/install.sh
index c0f97977649..767070f2b25 100755
--- a/.circleci/unittest/linux_libs/scripts_jumanji/install.sh
+++ b/.circleci/unittest/linux_libs/scripts_jumanji/install.sh
@@ -35,8 +35,11 @@ else
     pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall
 fi
 
+# install tensordict
+pip install git+https://github.com/pytorch-labs/tensordict
+
 # smoke test
-python -c "import functorch"
+python -c "import functorch;import tensordict"
 
 printf "* Installing torchrl\n"
 pip3 install -e .
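For orientation before the test-suite hunks below, here is a minimal sketch of the tensordict-based functional pattern this patch migrates the tests to, assembled from the hunks in test/test_tensordictmodules.py and test/test_cost.py further down. It is not part of the patch itself, and it assumes the tensordict API version targeted by this PR (make_functional extracting a module's parameters into a TensorDict that is passed back at call time).

import torch
from torch import nn
from functorch import vmap  # vmap stays optional; the tests skip when functorch is absent
from tensordict import TensorDict
from tensordict.nn.functional_modules import make_functional
from torchrl.modules import SafeModule

net = nn.Linear(3, 4)
tdmodule = SafeModule(module=net, in_keys=["in"], out_keys=["out"])
# extract the parameters into a (possibly nested) TensorDict
params = make_functional(tdmodule)

td = TensorDict({"in": torch.randn(3, 3)}, [3])
# explicit-parameter call, replacing functorch's make_functional_with_buffers
tdmodule(td, params=params)

# nested parameter TensorDicts are traversed with values()/items(),
# the way the loss tests below iterate over *_network_params leaves
leaves = list(params.values(include_nested=True, leaves_only=True))

# batched parameters are expanded and mapped over with vmap
td_out = vmap(tdmodule, (None, 0))(td, params.expand(10))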
diff --git a/.circleci/unittest/linux_olddeps/scripts_gym_0_13/install.sh b/.circleci/unittest/linux_olddeps/scripts_gym_0_13/install.sh index 9e3739fe2b2..0cdee0320c1 100755 --- a/.circleci/unittest/linux_olddeps/scripts_gym_0_13/install.sh +++ b/.circleci/unittest/linux_olddeps/scripts_gym_0_13/install.sh @@ -44,5 +44,9 @@ fi # install tensordict pip install git+https://github.com/pytorch-labs/tensordict +# smoke test +python -c "import tensordict" + printf "* Installing torchrl\n" python setup.py develop +python -c "import torchrl" diff --git a/test/test_cost.py b/test/test_cost.py index 40fdd8919f6..e1edaac8f4b 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -6,17 +6,14 @@ import argparse from copy import deepcopy -from tensordict.nn.functional_modules import FunctionalModuleWithBuffers - _has_functorch = True -FUNCTORCH_ERR = "" try: - import functorch + import functorch as ft # noqa - make_functional_with_buffers = functorch.make_functional_with_buffers + make_functional_with_buffers = ft.make_functional_with_buffers + FUNCTORCH_ERR = "" except ImportError as err: _has_functorch = False - make_functional_with_buffers = FunctionalModuleWithBuffers._create_from FUNCTORCH_ERR = str(err) import numpy as np @@ -24,9 +21,10 @@ import torch from _utils_internal import dtype_fixture, get_available_devices # noqa from mocking_classes import ContinuousActionConvMockEnv +from tensordict.nn import get_functional # from torchrl.data.postprocs.utils import expand_as_right -from tensordict.tensordict import assert_allclose_td, TensorDict, TensorDictBase +from tensordict.tensordict import assert_allclose_td, TensorDict from tensordict.utils import expand_as_right from torch import autograd, nn from torchrl.data import ( @@ -65,6 +63,7 @@ ProbabilisticActor, ValueOperator, ) +from torchrl.modules.utils import Buffer from torchrl.objectives import ( A2CLoss, ClipPPOLoss, @@ -270,56 +269,12 @@ def _create_seq_mock_data_dqn( ) return td - @pytest.mark.skipif( - not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" - ) @pytest.mark.parametrize("delay_value", (False, True)) @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize( "action_spec_type", ("nd_bounded", "one_hot", "categorical") ) - @pytest.mark.parametrize("is_nn_module", (False, True)) - def test_dqn(self, delay_value, device, action_spec_type, is_nn_module): - torch.manual_seed(self.seed) - actor = self._create_mock_actor( - action_spec_type=action_spec_type, device=device, is_nn_module=is_nn_module - ) - td = self._create_mock_data_dqn( - action_spec_type=action_spec_type, device=device - ) - loss_fn = DQNLoss(actor, gamma=0.9, loss_function="l2", delay_value=delay_value) - with _check_td_steady(td): - loss = loss_fn(td) - assert loss_fn.priority_key in td.keys() - - sum([item for _, item in loss.items()]).backward() - assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 - - # Check param update effect on targets - target_value = [p.clone() for p in loss_fn.target_value_network_params] - for p in loss_fn.parameters(): - p.data += torch.randn_like(p) - target_value2 = [p.clone() for p in loss_fn.target_value_network_params] - if loss_fn.delay_value: - assert all((p1 == p2).all() for p1, p2 in zip(target_value, target_value2)) - else: - assert not any( - (p1 == p2).any() for p1, p2 in zip(target_value, target_value2) - ) - - # check that policy is updated after parameter update - parameters = [p.clone() for p in actor.parameters()] - for p in 
loss_fn.parameters(): - p.data += torch.randn_like(p) - assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) - - @pytest.mark.skipif(_has_functorch, reason="functorch installed") - @pytest.mark.parametrize("delay_value", (False, True)) - @pytest.mark.parametrize("device", get_available_devices()) - @pytest.mark.parametrize( - "action_spec_type", ("nd_bounded", "one_hot", "categorical") - ) - def test_dqn_nofunctorch(self, delay_value, device, action_spec_type): + def test_dqn(self, delay_value, device, action_spec_type): torch.manual_seed(self.seed) actor = self._create_mock_actor( action_spec_type=action_spec_type, device=device @@ -351,9 +306,6 @@ def test_dqn_nofunctorch(self, delay_value, device, action_spec_type): p.data += torch.randn_like(p) assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) - @pytest.mark.skipif( - not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" - ) @pytest.mark.parametrize("n", range(4)) @pytest.mark.parametrize("delay_value", (False, True)) @pytest.mark.parametrize("device", get_available_devices()) @@ -395,68 +347,6 @@ def test_dqn_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9): sum([item for _, item in loss_ms.items()]).backward() assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 - # Check param update effect on targets - target_value = [p.clone() for p in loss_fn.target_value_network_params] - for p in loss_fn.parameters(): - p.data += torch.randn_like(p) - target_value2 = [p.clone() for p in loss_fn.target_value_network_params] - if loss_fn.delay_value: - assert all((p1 == p2).all() for p1, p2 in zip(target_value, target_value2)) - else: - assert not any( - (p1 == p2).any() for p1, p2 in zip(target_value, target_value2) - ) - - # check that policy is updated after parameter update - parameters = [p.clone() for p in actor.parameters()] - for p in loss_fn.parameters(): - p.data += torch.randn_like(p) - assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) - - @pytest.mark.skipif(_has_functorch, reason="functorch installed") - @pytest.mark.parametrize("n", range(4)) - @pytest.mark.parametrize("delay_value", (False, True)) - @pytest.mark.parametrize("device", get_available_devices()) - @pytest.mark.parametrize( - "action_spec_type", ("nd_bounded", "one_hot", "categorical") - ) - def test_dqn_batcher_nofunctorch( - self, n, delay_value, device, action_spec_type, gamma=0.9 - ): - torch.manual_seed(self.seed) - actor = self._create_mock_actor( - action_spec_type=action_spec_type, device=device - ) - - td = self._create_seq_mock_data_dqn( - action_spec_type=action_spec_type, device=device - ) - loss_fn = DQNLoss( - actor, gamma=gamma, loss_function="l2", delay_value=delay_value - ) - - ms = MultiStep(gamma=gamma, n_steps_max=n).to(device) - ms_td = ms(td.clone()) - - with _check_td_steady(ms_td): - loss_ms = loss_fn(ms_td) - assert loss_fn.priority_key in ms_td.keys() - - with torch.no_grad(): - loss = loss_fn(td) - if n == 0: - assert_allclose_td(td, ms_td.select(*list(td.keys()))) - _loss = sum([item for _, item in loss.items()]) - _loss_ms = sum([item for _, item in loss_ms.items()]) - assert ( - abs(_loss - _loss_ms) < 1e-3 - ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" - else: - with pytest.raises(AssertionError): - assert_allclose_td(loss, loss_ms) - sum([item for _, item in loss_ms.items()]).backward() - assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 - # 
Check param update effect on targets target_value = loss_fn.target_value_network_params.clone() for p in loss_fn.parameters(): @@ -473,62 +363,13 @@ def test_dqn_batcher_nofunctorch( p.data += torch.randn_like(p) assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) - @pytest.mark.skipif( - not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" - ) @pytest.mark.parametrize("atoms", range(4, 10)) @pytest.mark.parametrize("delay_value", (False, True)) @pytest.mark.parametrize("device", get_devices()) @pytest.mark.parametrize( "action_spec_type", ("mult_one_hot", "one_hot", "categorical") ) - @pytest.mark.parametrize("is_nn_module", (False, True)) def test_distributional_dqn( - self, atoms, delay_value, device, action_spec_type, is_nn_module, gamma=0.9 - ): - torch.manual_seed(self.seed) - actor = self._create_mock_distributional_actor( - action_spec_type=action_spec_type, atoms=atoms, is_nn_module=is_nn_module - ).to(device) - - td = self._create_mock_data_dqn( - action_spec_type=action_spec_type, atoms=atoms - ).to(device) - loss_fn = DistributionalDQNLoss(actor, gamma=gamma, delay_value=delay_value) - - with _check_td_steady(td): - loss = loss_fn(td) - assert loss_fn.priority_key in td.keys() - - sum([item for _, item in loss.items()]).backward() - assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 - - # Check param update effect on targets - target_value = [p.clone() for p in loss_fn.target_value_network_params] - for p in loss_fn.parameters(): - p.data += torch.randn_like(p) - target_value2 = [p.clone() for p in loss_fn.target_value_network_params] - if loss_fn.delay_value: - assert all((p1 == p2).all() for p1, p2 in zip(target_value, target_value2)) - else: - assert not any( - (p1 == p2).any() for p1, p2 in zip(target_value, target_value2) - ) - - # check that policy is updated after parameter update - parameters = [p.clone() for p in actor.parameters()] - for p in loss_fn.parameters(): - p.data += torch.randn_like(p) - assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) - - @pytest.mark.skipif(_has_functorch, reason="functorch installed") - @pytest.mark.parametrize("atoms", range(4, 10)) - @pytest.mark.parametrize("delay_value", (False, True)) - @pytest.mark.parametrize("device", get_devices()) - @pytest.mark.parametrize( - "action_spec_type", ("mult_one_hot", "one_hot", "categorical") - ) - def test_distributional_dqn_nofunctorch( self, atoms, delay_value, device, action_spec_type, gamma=0.9 ): torch.manual_seed(self.seed) @@ -556,7 +397,10 @@ def test_distributional_dqn_nofunctorch( if loss_fn.delay_value: assert_allclose_td(target_value, target_value2) else: - assert not (target_value == target_value2).any() + for key, val in target_value.flatten_keys(",").items(): + if key in ("support",): + continue + assert not (val == target_value2[tuple(key.split(","))]).any(), key # check that policy is updated after parameter update parameters = [p.clone() for p in actor.parameters()] @@ -678,6 +522,14 @@ def test_ddpg(self, delay_actor, delay_value, device): with _check_td_steady(td): loss = loss_fn(td) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.value_network_params.values(True, True) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values(True, True) + ) # check that losses are independent for k in loss.keys(): if not k.startswith("loss"): @@ -686,20 +538,20 @@ def test_ddpg(self, delay_actor, delay_value, device): 
if k == "loss_actor": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.value_network_params + for p in loss_fn.value_network_params.values(True, True) ) assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params + for p in loss_fn.actor_network_params.values(True, True) ) elif k == "loss_value": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params + for p in loss_fn.actor_network_params.values(True, True) ) assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.value_network_params + for p in loss_fn.value_network_params.values(True, True) ) else: raise NotImplementedError(k) @@ -712,12 +564,18 @@ def test_ddpg(self, delay_actor, delay_value, device): assert p.grad.norm() > 0.0 # Check param update effect on targets - target_actor = [p.clone() for p in loss_fn.target_actor_network_params] - target_value = [p.clone() for p in loss_fn.target_value_network_params] - for p in loss_fn.parameters(): + target_actor = [p.clone() for p in loss_fn.target_actor_network_params.values()] + target_value = [p.clone() for p in loss_fn.target_value_network_params.values()] + _i = -1 + for _i, p in enumerate(loss_fn.parameters()): p.data += torch.randn_like(p) - target_actor2 = [p.clone() for p in loss_fn.target_actor_network_params] - target_value2 = [p.clone() for p in loss_fn.target_value_network_params] + assert _i >= 0 + target_actor2 = [ + p.clone() for p in loss_fn.target_actor_network_params.values() + ] + target_value2 = [ + p.clone() for p in loss_fn.target_value_network_params.values() + ] if loss_fn.delay_actor: assert all((p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2)) else: @@ -930,54 +788,78 @@ def test_sac(self, delay_value, delay_actor, delay_qvalue, num_qvalue, device): if k == "loss_actor": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.value_network_params + for p in loss_fn.value_network_params.values( + include_nested=True, leaves_only=True + ) ) assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) ) assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) ) elif k == "loss_value": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) ) assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) ) assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.value_network_params + for p in loss_fn.value_network_params.values( + include_nested=True, leaves_only=True + ) ) elif k == "loss_qvalue": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) ) assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.value_network_params + for p in loss_fn.value_network_params.values( + include_nested=True, leaves_only=True + ) ) assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params + for p in loss_fn.qvalue_network_params.values( + 
include_nested=True, leaves_only=True + ) ) elif k == "loss_alpha": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) ) assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.value_network_params + for p in loss_fn.value_network_params.values( + include_nested=True, leaves_only=True + ) ) assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) ) else: raise NotImplementedError(k) @@ -1063,14 +945,44 @@ def test_sac_batcher( assert p.grad.norm() > 0.0, f"parameter {name} has null gradient" # Check param update effect on targets - target_actor = [p.clone() for p in loss_fn.target_actor_network_params] - target_qvalue = [p.clone() for p in loss_fn.target_qvalue_network_params] - target_value = [p.clone() for p in loss_fn.target_value_network_params] + target_actor = [ + p.clone() + for p in loss_fn.target_actor_network_params.values( + include_nested=True, leaves_only=True + ) + ] + target_qvalue = [ + p.clone() + for p in loss_fn.target_qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ] + target_value = [ + p.clone() + for p in loss_fn.target_value_network_params.values( + include_nested=True, leaves_only=True + ) + ] for p in loss_fn.parameters(): p.data += torch.randn_like(p) - target_actor2 = [p.clone() for p in loss_fn.target_actor_network_params] - target_qvalue2 = [p.clone() for p in loss_fn.target_qvalue_network_params] - target_value2 = [p.clone() for p in loss_fn.target_value_network_params] + target_actor2 = [ + p.clone() + for p in loss_fn.target_actor_network_params.values( + include_nested=True, leaves_only=True + ) + ] + target_qvalue2 = [ + p.clone() + for p in loss_fn.target_qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ] + target_value2 = [ + p.clone() + for p in loss_fn.target_value_network_params.values( + include_nested=True, leaves_only=True + ) + ] if loss_fn.delay_actor: assert all((p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2)) else: @@ -1264,29 +1176,41 @@ def test_redq(self, delay_qvalue, num_qvalue, device): if k == "loss_actor": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) ) assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) ) elif k == "loss_qvalue": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) ) assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) ) elif k == "loss_alpha": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) ) assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) ) else: raise 
NotImplementedError(k) @@ -1339,29 +1263,30 @@ def test_redq_shared(self, delay_qvalue, num_qvalue, device): if k == "loss_actor": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn._qvalue_network_params + for p in loss_fn.qvalue_network_params.values(True, True) ) assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn._actor_network_params + for p in loss_fn.actor_network_params.values(True, True) ) elif k == "loss_qvalue": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn._actor_network_params + for p in loss_fn.actor_network_params.values(True, True) ) assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn._qvalue_network_params + for p in loss_fn.qvalue_network_params.values(True, True) + if isinstance(p, nn.Parameter) ) elif k == "loss_alpha": assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn._actor_network_params + for p in loss_fn.actor_network_params.values(True, True) ) assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn._qvalue_network_params + for p in loss_fn.qvalue_network_params.values(True, True) ) else: raise NotImplementedError(k) @@ -1385,13 +1310,19 @@ def test_redq_shared(self, delay_qvalue, num_qvalue, device): p.data *= 0 counter = 0 - for p in loss_fn.qvalue_network_params: + for key, p in loss_fn.qvalue_network_params.items(True, True): + if not isinstance(key, tuple): + key = (key,) if not isinstance(p, nn.Parameter): counter += 1 - assert (p == loss_fn._param_maps[p]).all() + key = "_sep_".join(["qvalue_network", *key]) + mapped_param = next( + (k for k, val in loss_fn._param_maps.items() if val == key) + ) + assert (p == getattr(loss_fn, mapped_param)).all() assert (p == 0).all() - assert counter == len(loss_fn._actor_network_params) - assert counter == len(loss_fn.actor_network_params) + assert counter == len(loss_fn._actor_network_params.keys(True, True)) + assert counter == len(loss_fn.actor_network_params.keys(True, True)) # check that params of the original actor are those of the loss_fn for p in actor.parameters(): @@ -1497,12 +1428,20 @@ def test_redq_batcher(self, n, delay_qvalue, num_qvalue, device, gamma=0.9): assert p.grad.norm() > 0.0, f"parameter {name} has null gradient" # Check param update effect on targets - target_actor = [p.clone() for p in loss_fn.target_actor_network_params] - target_qvalue = [p.clone() for p in loss_fn.target_qvalue_network_params] + target_actor = loss_fn.target_actor_network_params.clone().values( + include_nested=True, leaves_only=True + ) + target_qvalue = loss_fn.target_qvalue_network_params.clone().values( + include_nested=True, leaves_only=True + ) for p in loss_fn.parameters(): p.data += torch.randn_like(p) - target_actor2 = [p.clone() for p in loss_fn.target_actor_network_params] - target_qvalue2 = [p.clone() for p in loss_fn.target_qvalue_network_params] + target_actor2 = loss_fn.target_actor_network_params.clone().values( + include_nested=True, leaves_only=True + ) + target_qvalue2 = loss_fn.target_qvalue_network_params.clone().values( + include_nested=True, leaves_only=True + ) if loss_fn.delay_actor: assert all((p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2)) else: @@ -1554,6 +1493,29 @@ def _create_mock_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): ) return value.to(device) + def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + # Actor + action_spec = NdBoundedTensorSpec( + -torch.ones(action_dim), torch.ones(action_dim), 
(action_dim,) + ) + base_layer = nn.Linear(obs_dim, 5) + net = NormalParamWrapper( + nn.Sequential(base_layer, nn.Linear(5, 2 * action_dim)) + ) + module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) + actor = ProbabilisticActor( + module=module, + distribution_class=TanhNormal, + dist_in_keys=["loc", "scale"], + spec=CompositeSpec(action=action_spec, loc=None, scale=None), + ) + module = nn.Sequential(base_layer, nn.Linear(5, 1)) + value = ValueOperator( + module=module, + in_keys=["observation"], + ) + return actor.to(device), value.to(device) + def _create_mock_distributional_actor( self, batch=2, obs_dim=3, action_dim=4, atoms=0, vmin=1, vmax=5 ): @@ -1657,26 +1619,105 @@ def test_ppo(self, loss_class, device, gradient_mode, advantage): loss_critic.backward(retain_graph=True) # check that grads are independent and non null named_parameters = loss_fn.named_parameters() + counter = 0 + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + counter += 1 + assert "actor" not in name + assert "critic" in name + if p.grad is None: + assert "actor" in name + assert "critic" not in name + assert counter == 2 + + value.zero_grad() + loss_objective.backward() + counter = 0 + named_parameters = loss_fn.named_parameters() + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + counter += 1 + assert "actor" in name + assert "critic" not in name + if p.grad is None: + assert "actor" not in name + assert "critic" in name + assert counter == 2 + actor.zero_grad() + + @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) + @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda")) + @pytest.mark.parametrize("device", get_available_devices()) + def test_ppo_shared(self, loss_class, device, advantage): + torch.manual_seed(self.seed) + td = self._create_seq_mock_data_ppo(device=device) + + actor, value = self._create_mock_actor_value(device=device) + if advantage == "gae": + advantage = GAE( + gamma=0.9, + lmbda=0.9, + value_network=value, + gradient_mode=False, + ) + elif advantage == "td": + advantage = TDEstimate( + gamma=0.9, + value_network=value, + gradient_mode=False, + ) + elif advantage == "td_lambda": + advantage = TDLambdaEstimate( + gamma=0.9, + lmbda=0.9, + value_network=value, + gradient_mode=False, + ) + else: + raise NotImplementedError + loss_fn = loss_class( + actor, + value, + gamma=0.9, + loss_critic_type="l2", + advantage_module=advantage, + ) + + loss = loss_fn(td) + loss_critic = loss["loss_critic"] + loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0) + loss_critic.backward(retain_graph=True) + # check that grads are independent and non null + named_parameters = loss_fn.named_parameters() + counter = 0 for name, p in named_parameters: if p.grad is not None and p.grad.norm() > 0.0: + counter += 1 assert "actor" not in name assert "critic" in name if p.grad is None: assert "actor" in name assert "critic" not in name + assert counter == 2 value.zero_grad() loss_objective.backward() named_parameters = loss_fn.named_parameters() + counter = 0 for name, p in named_parameters: if p.grad is not None and p.grad.norm() > 0.0: + counter += 1 assert "actor" in name assert "critic" not in name if p.grad is None: assert "actor" not in name assert "critic" in name actor.zero_grad() + assert counter == 4 + @pytest.mark.skipif( + not _has_functorch, reason=f"functorch not found, {FUNCTORCH_ERR}" + ) @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, 
KLPENPPOLoss)) @pytest.mark.parametrize("gradient_mode", (True, False)) @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda")) @@ -1707,58 +1748,41 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): ) floss_fn, params, buffers = make_functional_with_buffers(loss_fn) - + # fill params with zero + for p in params: + p.data.zero_() + # assert len(list(floss_fn.parameters())) == 0 loss = floss_fn(params, buffers, td) loss_critic = loss["loss_critic"] loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0) loss_critic.backward(retain_graph=True) # check that grads are independent and non null named_parameters = loss_fn.named_parameters() - if _has_functorch: - for (name, _), p in zip(named_parameters, params): - if p.grad is not None and p.grad.norm() > 0.0: - assert "actor" not in name - assert "critic" in name - if p.grad is None: - assert "actor" in name - assert "critic" not in name - else: - for key, p in params.flatten_keys(".").items(): - if p.grad is not None and p.grad.norm() > 0.0: - assert "actor" not in key - assert "value" in key or "critic" in key - if p.grad is None: - assert "actor" in key - assert "value" not in key and "critic" not in key - - if _has_functorch: - for param in params: - param.grad = None - else: - for param in params.flatten_keys(".").values(): - param.grad = None + for (name, _), p in zip(named_parameters, params): + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" not in name + assert "critic" in name + if p.grad is None: + assert "actor" in name + assert "critic" not in name + + for param in params: + param.grad = None loss_objective.backward() named_parameters = loss_fn.named_parameters() - if _has_functorch: - for (name, _), p in zip(named_parameters, params): - if p.grad is not None and p.grad.norm() > 0.0: - assert "actor" in name - assert "critic" not in name - if p.grad is None: - assert "actor" not in name - assert "critic" in name - for param in params: - param.grad = None - else: - for key, p in params.flatten_keys(".").items(): - if p.grad is not None and p.grad.norm() > 0.0: - assert "actor" in key - assert "value" not in key and "critic" not in key - if p.grad is None: - assert "actor" not in key - assert "value" in key or "critic" in key - for param in params.flatten_keys(".").values(): - param.grad = None + + for (name, other_p), p in zip(named_parameters, params): + assert other_p.shape == p.shape + assert other_p.dtype == p.dtype + assert other_p.device == p.device + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" in name + assert "critic" not in name + if p.grad is None: + assert "actor" not in name + assert "critic" in name + for param in params: + param.grad = None class TestA2C: @@ -1858,7 +1882,7 @@ def test_a2c(self, device, gradient_mode, advantage): RuntimeError, match="tensordict stored action require grad.", ): - loss = loss_fn._log_probs(td) + _ = loss_fn._log_probs(td) td["action"].requires_grad = False # Check error is raised when advantage_diff_key present and does not required grad @@ -1900,6 +1924,9 @@ def test_a2c(self, device, gradient_mode, advantage): # test reset loss_fn.reset() + @pytest.mark.skipif( + not _has_functorch, reason=f"functorch not found, {FUNCTORCH_ERR}" + ) @pytest.mark.parametrize("gradient_mode", (True, False)) @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda")) @pytest.mark.parametrize("device", get_available_devices()) @@ -1936,51 +1963,27 @@ def test_a2c_diff(self, device, gradient_mode, advantage): 
loss_critic.backward(retain_graph=True) # check that grads are independent and non null named_parameters = loss_fn.named_parameters() - if _has_functorch: - for (name, _), p in zip(named_parameters, params): - if p.grad is not None and p.grad.norm() > 0.0: - assert "actor" not in name - assert "critic" in name - if p.grad is None: - assert "actor" in name - assert "critic" not in name - else: - for key, p in params.flatten_keys(".").items(): - if p.grad is not None and p.grad.norm() > 0.0: - assert "actor" not in key - assert "value" in key or "critic" in key - if p.grad is None: - assert "actor" in key - assert "value" not in key and "critic" not in key - - if _has_functorch: - for param in params: - param.grad = None - else: - for param in params.flatten_keys(".").values(): - param.grad = None + for (name, _), p in zip(named_parameters, params): + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" not in name + assert "critic" in name + if p.grad is None: + assert "actor" in name + assert "critic" not in name + + for param in params: + param.grad = None loss_objective.backward() named_parameters = loss_fn.named_parameters() - if _has_functorch: - for (name, _), p in zip(named_parameters, params): - if p.grad is not None and p.grad.norm() > 0.0: - assert "actor" in name - assert "critic" not in name - if p.grad is None: - assert "actor" not in name - assert "critic" in name - for param in params: - param.grad = None - else: - for key, p in params.flatten_keys(".").items(): - if p.grad is not None and p.grad.norm() > 0.0: - assert "actor" in key - assert "value" not in key and "critic" not in key - if p.grad is None: - assert "actor" not in key - assert "value" in key or "critic" in key - for param in params.flatten_keys(".").values(): - param.grad = None + for (name, _), p in zip(named_parameters, params): + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" in name + assert "critic" not in name + if p.grad is None: + assert "actor" not in name + assert "critic" in name + for param in params: + param.grad = None class TestReinforce: @@ -2008,20 +2011,20 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value): advantage_module = GAE( gamma=gamma, lmbda=0.9, - value_network=value_net.make_functional_with_buffers(clone=True)[0], + value_network=get_functional(value_net), gradient_mode=gradient_mode, ) elif advantage == "td": advantage_module = TDEstimate( gamma=gamma, - value_network=value_net.make_functional_with_buffers(clone=True)[0], + value_network=get_functional(value_net), gradient_mode=gradient_mode, ) elif advantage == "td_lambda": advantage_module = TDLambdaEstimate( gamma=0.9, lmbda=0.9, - value_network=value_net.make_functional_with_buffers(clone=True)[0], + value_network=get_functional(value_net), gradient_mode=gradient_mode, ) else: @@ -2496,21 +2499,28 @@ def test_hold_out(): @pytest.mark.parametrize("mode", ["hard", "soft"]) @pytest.mark.parametrize("value_network_update_interval", [100, 1000]) @pytest.mark.parametrize("device", get_available_devices()) -def test_updater(mode, value_network_update_interval, device): +@pytest.mark.parametrize( + "dtype", + [ + torch.float64, + torch.float32, + ], +) +def test_updater(mode, value_network_update_interval, device, dtype): torch.manual_seed(100) class custom_module_error(nn.Module): def __init__(self): super().__init__() - self._target_params = [torch.randn(3, 4)] - self._target_error_params = [torch.randn(3, 4)] + self.target_params = [torch.randn(3, 4)] + self.target_error_params = 
[torch.randn(3, 4)] self.params = nn.ParameterList( [nn.Parameter(torch.randn(3, 4, requires_grad=True))] ) module = custom_module_error().to(device) with pytest.raises( - RuntimeError, match="Your module seems to have a _target tensor list " + RuntimeError, match="Your module seems to have a target tensor list " ): if mode == "hard": upd = HardUpdate(module, value_network_update_interval) @@ -2524,21 +2534,18 @@ def __init__(self): self.convert_to_functional(module1, "module1", create_target_params=True) module2 = torch.nn.BatchNorm2d(10).eval() self.module2 = module2 - if _has_functorch: - iterator_params = self.target_module1_params - iterator_buffers = self.target_module1_buffers - else: - iterator_params = self.target_module1_params.values() - iterator_buffers = self.target_module1_buffers.values() + iterator_params = self.target_module1_params.values( + include_nested=True, leaves_only=True + ) for target in iterator_params: - target.data.normal_() - for target in iterator_buffers: if target.dtype is not torch.int64: target.data.normal_() else: target.data += 10 - module = custom_module().to(device) + module = custom_module().to(device).to(dtype) + _ = module.module1_params + _ = module.target_module1_params if mode == "hard": upd = HardUpdate( module, value_network_update_interval=value_network_update_interval @@ -2546,130 +2553,79 @@ def __init__(self): elif mode == "soft": upd = SoftUpdate(module, 1 - 1 / value_network_update_interval) upd.init_() - for _, v in upd._targets.items(): - if isinstance(v, TensorDictBase): - for _v in v.values(): - if _v.dtype is not torch.int64: - _v.copy_(torch.randn_like(_v)) - else: - _v += 10 + for _, _v in upd._targets.items(True, True): + if _v.dtype is not torch.int64: + _v.copy_(torch.randn_like(_v)) else: - for _v in v: - if _v.dtype is not torch.int64: - _v.copy_(torch.randn_like(_v)) - else: - _v += 10 + _v += 10 # total dist - if _has_functorch: - d0 = sum( - [ - (target_val[0] - val[0]).norm().item() - for (_, target_val), (_, val) in zip( - upd._targets.items(), upd._sources.items() - ) - ] - ) - else: - d0 = 0.0 - for (_, target_val), (_, val) in zip( - upd._targets.items(), upd._sources.items() - ): - for key in target_val.keys(): - if target_val[key].dtype == torch.long: - continue - d0 += (target_val[key] - val[key]).norm().item() + d0 = 0.0 + for (key, source_val) in upd._sources.items(True, True): + if not isinstance(key, tuple): + key = (key,) + key = ("target_" + key[0], *key[1:]) + target_val = upd._targets[key] + assert target_val.dtype is source_val.dtype, key + assert target_val.device == source_val.device, key + if target_val.dtype == torch.long: + continue + d0 += (target_val - source_val).norm().item() assert d0 > 0 if mode == "hard": for i in range(value_network_update_interval + 1): # test that no update is occuring until value_network_update_interval - if _has_functorch: - d1 = sum( - [ - (target_val[0] - val[0]).norm().item() - for (_, target_val), (_, val) in zip( - upd._targets.items(), upd._sources.items() - ) - ] - ) - else: - d1 = 0.0 - for (_, target_val), (_, val) in zip( - upd._targets.items(), upd._sources.items() - ): - for key in target_val.keys(): - if target_val[key].dtype == torch.long: - continue - d1 += (target_val[key] - val[key]).norm().item() + d1 = 0.0 + for (key, source_val) in upd._sources.items(True, True): + if not isinstance(key, tuple): + key = (key,) + key = ("target_" + key[0], *key[1:]) + target_val = upd._targets[key] + if target_val.dtype == torch.long: + continue + d1 += (target_val - 
source_val).norm().item() assert d1 == d0, i assert upd.counter == i upd.step() assert upd.counter == 0 # test that a new update has occured - if _has_functorch: - d1 = sum( - [ - (target_val[0] - val[0]).norm().item() - for (_, target_val), (_, val) in zip( - upd._targets.items(), upd._sources.items() - ) - ] - ) - else: - d1 = 0.0 - for (_, target_val), (_, val) in zip( - upd._targets.items(), upd._sources.items() - ): - for key in target_val.keys(): - if target_val[key].dtype == torch.long: - continue - d1 += (target_val[key] - val[key]).norm().item() + d1 = 0.0 + for (key, source_val) in upd._sources.items(True, True): + if not isinstance(key, tuple): + key = (key,) + key = ("target_" + key[0], *key[1:]) + target_val = upd._targets[key] + if target_val.dtype == torch.long: + continue + d1 += (target_val - source_val).norm().item() assert d1 < d0 elif mode == "soft": upd.step() - if _has_functorch: - d1 = sum( - [ - (target_val[0] - val[0]).norm().item() - for (_, target_val), (_, val) in zip( - upd._targets.items(), upd._sources.items() - ) - ] - ) - else: - d1 = 0.0 - for (_, target_val), (_, val) in zip( - upd._targets.items(), upd._sources.items() - ): - for key in target_val.keys(): - if target_val[key].dtype == torch.long: - continue - d1 += (target_val[key] - val[key]).norm().item() + d1 = 0.0 + for (key, source_val) in upd._sources.items(True, True): + if not isinstance(key, tuple): + key = (key,) + key = ("target_" + key[0], *key[1:]) + target_val = upd._targets[key] + if target_val.dtype == torch.long: + continue + d1 += (target_val - source_val).norm().item() assert d1 < d0 upd.init_() upd.step() - if _has_functorch: - d2 = sum( - [ - (target_val[0] - val[0]).norm().item() - for (_, target_val), (_, val) in zip( - upd._targets.items(), upd._sources.items() - ) - ] - ) - else: - d2 = 0.0 - for (_, target_val), (_, val) in zip( - upd._targets.items(), upd._sources.items() - ): - for key in target_val.keys(): - if target_val[key].dtype == torch.long: - continue - d2 += (target_val[key] - val[key]).norm().item() + d2 = 0.0 + for (key, source_val) in upd._sources.items(True, True): + if not isinstance(key, tuple): + key = (key,) + key = ("target_" + key[0], *key[1:]) + target_val = upd._targets[key] + if target_val.dtype == torch.long: + continue + d2 += (target_val - source_val).norm().item() assert d2 < 1e-6 @@ -2953,12 +2909,22 @@ def __init__(self, actor_network, qvalue_network): p.data += torch.randn_like(p) assert len(list(loss.parameters())) == 6 - assert len(loss.actor_network_params) == 4 - assert len(loss.qvalue_network_params) == 4 - for p in loss.actor_network_params: - assert isinstance(p, nn.Parameter) - assert (loss.qvalue_network_params[0] == loss.actor_network_params[0]).all() - assert (loss.qvalue_network_params[1] == loss.actor_network_params[1]).all() + assert ( + len(loss.actor_network_params.keys(include_nested=True, leaves_only=True)) == 4 + ) + assert ( + len(loss.qvalue_network_params.keys(include_nested=True, leaves_only=True)) == 4 + ) + for p in loss.actor_network_params.values(include_nested=True, leaves_only=True): + assert isinstance(p, nn.Parameter) or isinstance(p, Buffer) + for i, (key, value) in enumerate( + loss.qvalue_network_params.items(include_nested=True, leaves_only=True) + ): + p1 = value + p2 = loss.actor_network_params[key] + assert (p1 == p2).all() + if i == 1: + break # map module if dest == "double": @@ -2970,16 +2936,18 @@ def __init__(self, actor_network, qvalue_network): else: loss = loss.to(dest) - for p in 
loss.actor_network_params: + for p in loss.actor_network_params.values(include_nested=True, leaves_only=True): assert isinstance(p, nn.Parameter) assert p.dtype is expected_dtype assert p.device == torch.device(expected_device) - assert loss.qvalue_network_params[0].dtype is expected_dtype - assert loss.qvalue_network_params[1].dtype is expected_dtype - assert loss.qvalue_network_params[0].device == torch.device(expected_device) - assert loss.qvalue_network_params[1].device == torch.device(expected_device) - assert (loss.qvalue_network_params[0] == loss.actor_network_params[0]).all() - assert (loss.qvalue_network_params[1] == loss.actor_network_params[1]).all() + for i, (key, qvalparam) in enumerate( + loss.qvalue_network_params.items(include_nested=True, leaves_only=True) + ): + assert qvalparam.dtype is expected_dtype, (key, qvalparam) + assert qvalparam.device == torch.device(expected_device), key + assert (qvalparam == loss.actor_network_params[key]).all(), key + if i == 1: + break if __name__ == "__main__": diff --git a/test/test_functorch.py b/test/test_functorch.py deleted file mode 100644 index 7b043968afb..00000000000 --- a/test/test_functorch.py +++ /dev/null @@ -1,323 +0,0 @@ -import argparse - -import pytest -import torch - -try: - from functorch import vmap - - _has_functorch = True -except ImportError: - _has_functorch = False -from tensordict import TensorDict -from tensordict.nn.functional_modules import ( - FunctionalModule, - FunctionalModuleWithBuffers, -) -from torch import nn -from torchrl.modules import SafeModule, SafeSequential - - -@pytest.mark.skipif( - not _has_functorch, reason="vmap can only be tested when functorch is installed" -) -@pytest.mark.parametrize( - "moduletype,batch_params", - [ - ["linear", False], - ["bn1", True], - ["linear", True], - ], -) -def test_vmap_patch(moduletype, batch_params): - if moduletype == "linear": - module = nn.Linear(3, 4) - elif moduletype == "bn1": - module = nn.BatchNorm1d(3) - else: - raise NotImplementedError - if moduletype == "linear": - fmodule, params = FunctionalModule._create_from(module) - x = torch.randn(10, 1, 3) - if batch_params: - params = params.expand(10, *params.batch_size) - y = vmap(fmodule, (0, 0))(params, x) - else: - y = vmap(fmodule, (None, 0))(params, x) - assert y.shape == torch.Size([10, 1, 4]) - elif moduletype == "bn1": - fmodule, params, buffers = FunctionalModuleWithBuffers._create_from(module) - x = torch.randn(10, 2, 3) - if batch_params: - params = params.expand(10, *params.batch_size).contiguous() - buffers = buffers.expand(10, *buffers.batch_size).contiguous() - y = vmap(fmodule, (0, 0, 0))(params, buffers, x) - else: - raise NotImplementedError - assert y.shape == torch.Size([10, 2, 3]) - - -@pytest.mark.skipif( - not _has_functorch, reason="vmap can only be tested when functorch is installed" -) -@pytest.mark.parametrize( - "moduletype,batch_params", - [ - ["linear", False], - ["bn1", True], - ["linear", True], - ], -) -def test_vmap_tdmodule(moduletype, batch_params): - if moduletype == "linear": - module = nn.Linear(3, 4) - elif moduletype == "bn1": - module = nn.BatchNorm1d(3) - else: - raise NotImplementedError - if moduletype == "linear": - fmodule, params = FunctionalModule._create_from(module) - tdmodule = SafeModule(fmodule, in_keys=["x"], out_keys=["y"]) - x = torch.randn(10, 1, 3) - td = TensorDict({"x": x}, [10]) - if batch_params: - params = params.expand(10, *params.batch_size) - tdmodule(td, params=params, vmap=(0, 0)) - else: - tdmodule(td, params=params, vmap=(None, 0)) 
- y = td["y"] - assert y.shape == torch.Size([10, 1, 4]) - elif moduletype == "bn1": - fmodule, params, buffers = FunctionalModuleWithBuffers._create_from(module) - tdmodule = SafeModule(fmodule, in_keys=["x"], out_keys=["y"]) - x = torch.randn(10, 2, 3) - td = TensorDict({"x": x}, [10]) - if batch_params: - params = params.expand(10, *params.batch_size).contiguous() - buffers = buffers.expand(10, *buffers.batch_size).contiguous() - tdmodule(td, params=params, buffers=buffers, vmap=(0, 0, 0)) - else: - raise NotImplementedError - y = td["y"] - assert y.shape == torch.Size([10, 2, 3]) - - -@pytest.mark.skipif( - not _has_functorch, reason="vmap can only be tested when functorch is installed" -) -@pytest.mark.parametrize( - "moduletype,batch_params", - [ - ["linear", False], - ["bn1", True], - ["linear", True], - ], -) -def test_vmap_tdmodule_nativebuilt(moduletype, batch_params): - if moduletype == "linear": - module = nn.Linear(3, 4) - elif moduletype == "bn1": - module = nn.BatchNorm1d(3) - else: - raise NotImplementedError - if moduletype == "linear": - tdmodule = SafeModule(module, in_keys=["x"], out_keys=["y"]) - tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers(native=True) - x = torch.randn(10, 1, 3) - td = TensorDict({"x": x}, [10]) - if batch_params: - params = params.expand(10, *params.batch_size) - buffers = buffers.expand(10, *buffers.batch_size) - tdmodule(td, params=params, buffers=buffers, vmap=(0, 0, 0)) - else: - tdmodule(td, params=params, buffers=buffers, vmap=(None, None, 0)) - y = td["y"] - assert y.shape == torch.Size([10, 1, 4]) - elif moduletype == "bn1": - tdmodule = SafeModule(module, in_keys=["x"], out_keys=["y"]) - tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers(native=True) - x = torch.randn(10, 2, 3) - td = TensorDict({"x": x}, [10]) - if batch_params: - params = params.expand(10, *params.batch_size).contiguous() - buffers = buffers.expand(10, *buffers.batch_size).contiguous() - tdmodule(td, params=params, buffers=buffers, vmap=(0, 0, 0)) - else: - raise NotImplementedError - y = td["y"] - assert y.shape == torch.Size([10, 2, 3]) - - -@pytest.mark.skipif( - not _has_functorch, reason="vmap can only be tested when functorch is installed" -) -@pytest.mark.parametrize( - "moduletype,batch_params", - [ - ["linear", False], - ["bn1", True], - ["linear", True], - ], -) -def test_vmap_tdsequence(moduletype, batch_params): - if moduletype == "linear": - module1 = nn.Linear(3, 4) - fmodule1, params1 = FunctionalModule._create_from(module1) - module2 = nn.Linear(4, 5) - fmodule2, params2 = FunctionalModule._create_from(module2) - elif moduletype == "bn1": - module1 = nn.BatchNorm1d(3) - fmodule1, params1, buffers1 = FunctionalModuleWithBuffers._create_from(module1) - module2 = nn.BatchNorm1d(3) - fmodule2, params2, buffers2 = FunctionalModuleWithBuffers._create_from(module2) - else: - raise NotImplementedError - if moduletype == "linear": - tdmodule1 = SafeModule(fmodule1, in_keys=["x"], out_keys=["y"]) - tdmodule2 = SafeModule(fmodule2, in_keys=["y"], out_keys=["z"]) - params = TensorDict({"0": params1, "1": params2}, []) - tdmodule = SafeSequential(tdmodule1, tdmodule2) - assert {"0", "1"} == set(params.keys()) - x = torch.randn(10, 1, 3) - td = TensorDict({"x": x}, [10]) - if batch_params: - params = params.expand(10, *params.batch_size) - tdmodule(td, params=params, vmap=(0, 0)) - else: - tdmodule(td, params=params, vmap=(None, 0)) - z = td["z"] - assert z.shape == torch.Size([10, 1, 5]) - elif moduletype == "bn1": - tdmodule1 = 
SafeModule(fmodule1, in_keys=["x"], out_keys=["y"]) - tdmodule2 = SafeModule(fmodule2, in_keys=["y"], out_keys=["z"]) - params = TensorDict({"0": params1, "1": params2}, []) - buffers = TensorDict({"0": buffers1, "1": buffers2}, []) - tdmodule = SafeSequential(tdmodule1, tdmodule2) - assert {"0", "1"} == set(params.keys()) - assert {"0", "1"} == set(buffers.keys()) - x = torch.randn(10, 2, 3) - td = TensorDict({"x": x}, [10]) - if batch_params: - params = params.expand(10, *params.batch_size).contiguous() - buffers = buffers.expand(10, *buffers.batch_size).contiguous() - tdmodule(td, params=params, buffers=buffers, vmap=(0, 0, 0)) - else: - raise NotImplementedError - z = td["z"] - assert z.shape == torch.Size([10, 2, 3]) - - -@pytest.mark.skipif( - not _has_functorch, reason="vmap can only be tested when functorch is installed" -) -@pytest.mark.parametrize( - "moduletype,batch_params", - [ - ["linear", False], - ["bn1", True], - ["linear", True], - ], -) -def test_vmap_tdsequence_nativebuilt(moduletype, batch_params): - if moduletype == "linear": - module1 = nn.Linear(3, 4) - module2 = nn.Linear(4, 5) - elif moduletype == "bn1": - module1 = nn.BatchNorm1d(3) - module2 = nn.BatchNorm1d(3) - else: - raise NotImplementedError - if moduletype == "linear": - tdmodule1 = SafeModule(module1, in_keys=["x"], out_keys=["y"]) - tdmodule2 = SafeModule(module2, in_keys=["y"], out_keys=["z"]) - tdmodule = SafeSequential(tdmodule1, tdmodule2) - tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers(native=True) - assert {"0", "1"} == set(params.keys()) - x = torch.randn(10, 1, 3) - td = TensorDict({"x": x}, [10]) - if batch_params: - params = params.expand(10, *params.batch_size) - buffers = buffers.expand(10, *buffers.batch_size) - tdmodule(td, params=params, buffers=buffers, vmap=(0, 0, 0)) - else: - tdmodule(td, params=params, buffers=buffers, vmap=(None, None, 0)) - z = td["z"] - assert z.shape == torch.Size([10, 1, 5]) - elif moduletype == "bn1": - tdmodule1 = SafeModule(module1, in_keys=["x"], out_keys=["y"]) - tdmodule2 = SafeModule(module2, in_keys=["y"], out_keys=["z"]) - tdmodule = SafeSequential(tdmodule1, tdmodule2) - tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers(native=True) - assert {"0", "1"} == set(params.keys()) - assert {"0", "1"} == set(buffers.keys()) - x = torch.randn(10, 2, 3) - td = TensorDict({"x": x}, [10]) - if batch_params: - params = params.expand(10, *params.batch_size).contiguous() - buffers = buffers.expand(10, *buffers.batch_size).contiguous() - tdmodule(td, params=params, buffers=buffers, vmap=(0, 0, 0)) - else: - raise NotImplementedError - z = td["z"] - assert z.shape == torch.Size([10, 2, 3]) - - -@pytest.mark.skipif( - not _has_functorch, reason="vmap can only be tested when functorch is installed" -) -class TestNativeFunctorch: - def test_vamp_basic(self): - class MyModule(torch.nn.Module): - def forward(self, tensordict): - a = tensordict["a"] - return TensorDict( - {"a": a}, tensordict.batch_size, device=tensordict.device - ) - - tensordict = TensorDict({"a": torch.randn(3)}, []).expand(4) - out = vmap(MyModule(), (0,))(tensordict) - assert out.shape == torch.Size([4]) - assert out["a"].shape == torch.Size([4, 3]) - - def test_vamp_composed(self): - class MyModule(torch.nn.Module): - def forward(self, tensordict, tensor): - a = tensordict["a"] - return ( - TensorDict( - {"a": a}, tensordict.batch_size, device=tensordict.device - ), - tensor, - ) - - tensor = torch.randn(3) - tensordict = TensorDict({"a": torch.randn(3, 1)}, 
[3]).expand(4, 3) - out = vmap(MyModule(), (0, None))(tensordict, tensor) - - assert out[0].shape == torch.Size([4, 3]) - assert out[1].shape == torch.Size([4, 3]) - assert out[0]["a"].shape == torch.Size([4, 3, 1]) - - def test_vamp_composed_flipped(self): - class MyModule(torch.nn.Module): - def forward(self, tensordict, tensor): - a = tensordict["a"] - return ( - TensorDict( - {"a": a}, tensordict.batch_size, device=tensordict.device - ), - tensor, - ) - - tensor = torch.randn(3).expand(4, 3) - tensordict = TensorDict({"a": torch.randn(3, 1)}, [3]) - out = vmap(MyModule(), (None, 0))(tensordict, tensor) - - assert out[0].shape == torch.Size([4, 3]) - assert out[1].shape == torch.Size([4, 3]) - assert out[0]["a"].shape == torch.Size([4, 3, 1]) - - -if __name__ == "__main__": - args, unknown = argparse.ArgumentParser().parse_known_args() - pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_libs.py b/test/test_libs.py index 06e09a0521b..4e4b4b811b2 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -272,15 +272,15 @@ def test_td_creation_from_spec(env_lib, env_args, env_kwargs): ) env = env_lib(*env_args, **env_kwargs) td = env.rollout(max_steps=5) - td0 = td[0].flatten_keys(".") + td0 = td[0] fake_td = env.fake_tensordict() - fake_td = fake_td.flatten_keys(".") - td = td.flatten_keys(".") - assert set(fake_td.keys()) == set(td.keys()) - for key in fake_td.keys(): + assert set(fake_td.keys(include_nested=True, leaves_only=True)) == set( + td.keys(include_nested=True, leaves_only=True) + ) + for key in fake_td.keys(include_nested=True, leaves_only=True): assert fake_td.get(key).shape == td.get(key)[0].shape - for key in fake_td.keys(): + for key in fake_td.keys(include_nested=True, leaves_only=True): assert fake_td.get(key).shape == td0.get(key).shape assert fake_td.get(key).dtype == td0.get(key).dtype assert fake_td.get(key).device == td0.get(key).device diff --git a/test/test_modules.py b/test/test_modules.py index 3a83f48c18c..2f37daab5a8 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -11,10 +11,6 @@ from mocking_classes import MockBatchedUnLockedEnv from packaging import version from tensordict import TensorDict -from tensordict.nn.functional_modules import ( - FunctionalModule, - FunctionalModuleWithBuffers, -) from torch import nn from torchrl.data.tensor_specs import ( DiscreteTensorSpec, @@ -441,34 +437,6 @@ def test_lstm_net_nobatch(device, out_features, hidden_size): torch.testing.assert_close(tds_vec["hidden1_out"][-1], tds_loop["hidden1_out"][-1]) -class TestFunctionalModules: - def test_func_seq(self): - module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 3)) - fmodule, params = FunctionalModule._create_from(module) - x = torch.randn(3) - assert (fmodule(params, x) == module(x)).all() - - def test_func_bn(self): - module = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4)) - module.eval() - fmodule, params, buffers = FunctionalModuleWithBuffers._create_from(module) - x = torch.randn(10, 3) - assert (fmodule(params, buffers, x) == module(x)).all() - - def test_func_transformer(self): - torch.manual_seed(10) - batch = ( - (10,) - if version.parse(torch.__version__) >= version.parse("1.11") - else (1, 10) - ) - module = nn.Transformer(128) - module.eval() - fmodule, params, buffers = FunctionalModuleWithBuffers._create_from(module) - x = torch.randn(*batch, 128) - torch.testing.assert_close(fmodule(params, buffers, x, x), module(x, x)) - - class TestPlanner: @pytest.mark.parametrize("device", get_available_devices()) 
@pytest.mark.parametrize("batch_size", [3, 5]) diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 60f513ae213..ef6238777e2 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -7,22 +7,8 @@ import pytest import torch -from tensordict.tensordict import TensorDictBase - -_has_functorch = False -try: - from functorch import make_functional, make_functional_with_buffers - - _has_functorch = True -except ImportError: - from tensordict.nn.functional_modules import ( - FunctionalModule, - FunctionalModuleWithBuffers, - ) - - make_functional = FunctionalModule._create_from - make_functional_with_buffers = FunctionalModuleWithBuffers._create_from from tensordict import TensorDict +from tensordict.nn.functional_modules import make_functional from torch import nn from torchrl.data.tensor_specs import ( CompositeSpec, @@ -38,6 +24,14 @@ from torchrl.modules.tensordict_module.probabilistic import SafeProbabilisticModule from torchrl.modules.tensordict_module.sequence import SafeSequential +_has_functorch = False +try: + from functorch import vmap + + _has_functorch = True +except ImportError: + pass + class TestTDModule: def test_multiple_output(self): @@ -243,7 +237,7 @@ def test_functional(self, safe, spec_type): net = nn.Linear(3, 4 * param_multiplier) - fnet, params = make_functional(net) + params = make_functional(net) if spec_type is None: spec = None @@ -260,7 +254,7 @@ def test_functional(self, safe, spec_type): ): tensordict_module = SafeModule( spec=spec, - module=fnet, + module=net, in_keys=["in"], out_keys=["out"], safe=safe, @@ -269,7 +263,7 @@ def test_functional(self, safe, spec_type): else: tensordict_module = SafeModule( spec=spec, - module=fnet, + module=net, in_keys=["in"], out_keys=["out"], safe=safe, @@ -292,75 +286,11 @@ def test_functional_probabilistic(self, safe, spec_type): torch.manual_seed(0) param_multiplier = 2 - net = nn.Linear(3, 4 * param_multiplier) - in_keys = ["in"] - net = NormalParamWrapper(net) - fnet, params = make_functional(net) - tdnet = SafeModule( - module=fnet, spec=None, in_keys=in_keys, out_keys=["loc", "scale"] - ) + net = NormalParamWrapper(nn.Linear(3, 4 * param_multiplier)) + params = make_functional(net) - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(4) - else: - raise NotImplementedError - spec = ( - CompositeSpec(out=spec, loc=None, scale=None) if spec is not None else None - ) - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - tensordict_module = SafeProbabilisticModule( - module=tdnet, - spec=spec, - dist_in_keys=["loc", "scale"], - sample_out_key=["out"], - safe=safe, - **kwargs, - ) - return - else: - tensordict_module = SafeProbabilisticModule( - module=tdnet, - spec=spec, - dist_in_keys=["loc", "scale"], - sample_out_key=["out"], - safe=safe, - **kwargs, - ) - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tensordict_module(td, params=params) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - 
@pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_probabilistic_laterconstruct(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 2 - - net = nn.Linear(3, 4 * param_multiplier) - in_keys = ["in"] - net = NormalParamWrapper(net) tdnet = SafeModule( - module=net, spec=None, in_keys=in_keys, out_keys=["loc", "scale"] + module=net, spec=None, in_keys=["in"], out_keys=["loc", "scale"] ) if spec_type is None: @@ -401,13 +331,9 @@ def test_functional_probabilistic_laterconstruct(self, safe, spec_type): safe=safe, **kwargs, ) - tensordict_module, ( - params, - buffers, - ) = tensordict_module.make_functional_with_buffers() td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td = tensordict_module(td, params=params, buffers=buffers) + tensordict_module(td, params=params) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) @@ -424,8 +350,7 @@ def test_functional_with_buffer(self, safe, spec_type): param_multiplier = 1 net = nn.BatchNorm1d(32 * param_multiplier) - - fnet, params, buffers = make_functional_with_buffers(net) + params = make_functional(net) if spec_type is None: spec = None @@ -442,7 +367,7 @@ def test_functional_with_buffer(self, safe, spec_type): ): tdmodule = SafeModule( spec=spec, - module=fnet, + module=net, in_keys=["in"], out_keys=["out"], safe=safe, @@ -451,14 +376,14 @@ def test_functional_with_buffer(self, safe, spec_type): else: tdmodule = SafeModule( spec=spec, - module=fnet, + module=net, in_keys=["in"], out_keys=["out"], safe=safe, ) td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3]) - tdmodule(td, params=params, buffers=buffers) + tdmodule(td, params=params) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 32]) @@ -474,12 +399,10 @@ def test_functional_with_buffer_probabilistic(self, safe, spec_type): torch.manual_seed(0) param_multiplier = 2 - net = nn.BatchNorm1d(32 * param_multiplier) - in_keys = ["in"] - net = NormalParamWrapper(net) - fnet, params, buffers = make_functional_with_buffers(net) + net = NormalParamWrapper(nn.BatchNorm1d(32 * param_multiplier)) + params = make_functional(net) tdnet = SafeModule( - module=fnet, spec=None, in_keys=in_keys, out_keys=["loc", "scale"] + module=net, spec=None, in_keys=["in"], out_keys=["loc", "scale"] ) if spec_type is None: @@ -522,71 +445,7 @@ def test_functional_with_buffer_probabilistic(self, safe, spec_type): ) td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3]) - tdmodule(td, params=params, buffers=buffers) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 32]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_with_buffer_probabilistic_laterconstruct(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 2 - - net = nn.BatchNorm1d(32 * param_multiplier) - in_keys = ["in"] - net = NormalParamWrapper(net) - tdnet = SafeModule( - module=net, spec=None, in_keys=in_keys, out_keys=["loc", "scale"] - ) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 32) - elif spec_type == "unbounded": - spec = 
NdUnboundedContinuousTensorSpec(32) - else: - raise NotImplementedError - spec = ( - CompositeSpec(out=spec, loc=None, scale=None) if spec is not None else None - ) - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - SafeProbabilisticModule( - module=tdnet, - spec=spec, - dist_in_keys=["loc", "scale"], - sample_out_key=["out"], - safe=safe, - **kwargs, - ) - return - else: - tdmodule = SafeProbabilisticModule( - module=tdnet, - spec=spec, - dist_in_keys=["loc", "scale"], - sample_out_key=["out"], - safe=safe, - **kwargs, - ) - tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers() - - td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3]) - tdmodule(td, params=params, buffers=buffers) + tdmodule(td, params=params) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 32]) @@ -607,8 +466,6 @@ def test_vmap(self, safe, spec_type): net = nn.Linear(3, 4 * param_multiplier) - fnet, params = make_functional(net) - if spec_type is None: spec = None elif spec_type == "bounded": @@ -624,7 +481,7 @@ def test_vmap(self, safe, spec_type): ): tdmodule = SafeModule( spec=spec, - module=fnet, + module=net, in_keys=["in"], out_keys=["out"], safe=safe, @@ -633,27 +490,25 @@ def test_vmap(self, safe, spec_type): else: tdmodule = SafeModule( spec=spec, - module=fnet, + module=net, in_keys=["in"], out_keys=["out"], safe=safe, ) + params = make_functional(tdmodule) + # vmap = True - params = [p.repeat(10, *[1 for _ in p.shape]) for p in params] + params = params.expand(10) td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_out = tdmodule(td, params=params, vmap=True) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, None) - td_out = tdmodule(td, params=params, vmap=(0, None)) + if safe and spec_type == "bounded": + with pytest.raises( + RuntimeError, match="vmap cannot be used with safe=True" + ): + td_out = vmap(tdmodule, (None, 0))(td, params) + return + else: + td_out = vmap(tdmodule, (None, 0))(td, params) assert td_out is not td assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) @@ -664,8 +519,9 @@ def test_vmap(self, safe, spec_type): assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() # vmap = (0, 0) - td_repeat = td.expand(10, *td.batch_size).clone() - td_out = tdmodule(td_repeat, params=params, vmap=(0, 0)) + td = TensorDict({"in": torch.randn(3, 3)}, [3]) + td_repeat = td.expand(10, *td.batch_size) + td_out = vmap(tdmodule, (0, 0))(td_repeat, params) assert td_out is not td assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) @@ -684,12 +540,9 @@ def test_vmap_probabilistic(self, safe, spec_type): torch.manual_seed(0) param_multiplier = 2 - net = nn.Linear(3, 4 * param_multiplier) - net = NormalParamWrapper(net) - in_keys = ["in"] - fnet, params = make_functional(net) + net = NormalParamWrapper(nn.Linear(3, 4 * param_multiplier)) tdnet = SafeModule( - module=fnet, spec=None, in_keys=in_keys, out_keys=["loc", "scale"] + module=net, spec=None, 
in_keys=["in"], out_keys=["loc", "scale"] ) if spec_type is None: @@ -731,113 +584,19 @@ def test_vmap_probabilistic(self, safe, spec_type): **kwargs, ) + params = make_functional(tdmodule) + # vmap = True - params = [p.repeat(10, *[1 for _ in p.shape]) for p in params] + params = params.expand(10) td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_out = tdmodule(td, params=params, vmap=True) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, None) - td_out = tdmodule(td, params=params, vmap=(0, None)) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, 0) - td_repeat = td.expand(10, *td.batch_size).clone() - td_out = tdmodule(td_repeat, params=params, vmap=(0, 0)) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - @pytest.mark.skipif( - not _has_functorch, reason="vmap can only be used with functorch" - ) - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_vmap_probabilistic_laterconstruct(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 2 - - net = nn.Linear(3, 4 * param_multiplier) - net = NormalParamWrapper(net) - in_keys = ["in"] - tdnet = SafeModule( - module=net, spec=None, in_keys=in_keys, out_keys=["loc", "scale"] - ) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(4) - else: - raise NotImplementedError - spec = ( - CompositeSpec(out=spec, loc=None, scale=None) if spec is not None else None - ) - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: + if safe and spec_type == "bounded": with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", + RuntimeError, match="vmap cannot be used with safe=True" ): - tdmodule = SafeProbabilisticModule( - module=tdnet, - spec=spec, - dist_in_keys=["loc", "scale"], - sample_out_key=["out"], - safe=safe, - **kwargs, - ) + td_out = vmap(tdmodule, (None, 0))(td, params) return else: - tdmodule = SafeProbabilisticModule( - module=tdnet, - spec=spec, - dist_in_keys=["loc", "scale"], - sample_out_key=["out"], - safe=safe, - **kwargs, - ) - tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers() - - # vmap = True - params = [p.repeat(10, *[1 for _ in p.shape]) for p in params] - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_out = tdmodule(td, params=params, buffers=buffers, vmap=True) - assert td_out is not td - assert 
td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, 0, None) - td_out = tdmodule(td, params=params, buffers=buffers, vmap=(0, 0, None)) + td_out = vmap(tdmodule, (None, 0))(td, params) assert td_out is not td assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) @@ -847,9 +606,10 @@ def test_vmap_probabilistic_laterconstruct(self, safe, spec_type): elif safe and spec_type == "bounded": assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - # vmap = (0, 0, 0) - td_repeat = td.expand(10, *td.batch_size).clone() - td_out = tdmodule(td_repeat, params=params, buffers=buffers, vmap=(0, 0, 0)) + # vmap = (0, 0) + td = TensorDict({"in": torch.randn(3, 3)}, [3]) + td_repeat = td.expand(10, *td.batch_size) + td_out = vmap(tdmodule, (0, 0))(td_repeat, params) assert td_out is not td assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) @@ -1046,14 +806,6 @@ def test_functional(self, safe, spec_type): dummy_net = nn.Linear(4, 4) net2 = nn.Linear(4, 4 * param_multiplier) - fnet1, params1 = make_functional(net1) - fdummy_net, _ = make_functional(dummy_net) - fnet2, params2 = make_functional(net2) - if isinstance(params1, TensorDictBase): - params = TensorDict({"0": params1, "1": params2}, []) - else: - params = list(params1) + list(params2) - if spec_type is None: spec = None elif spec_type == "bounded": @@ -1065,17 +817,17 @@ def test_functional(self, safe, spec_type): pytest.skip("safe and spec is None is checked elsewhere") else: tdmodule1 = SafeModule( - fnet1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False + net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False ) dummy_tdmodule = SafeModule( - fdummy_net, + dummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) tdmodule2 = SafeModule( - fnet2, + net2, spec=spec, in_keys=["hidden"], out_keys=["out"], @@ -1083,14 +835,18 @@ def test_functional(self, safe, spec_type): ) tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) + params = make_functional(tdmodule) + assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 tdmodule[1] = tdmodule2 + params["module", "1"] = params["module", "2"] assert len(tdmodule) == 3 assert hasattr(tdmodule, "__delitem__") assert len(tdmodule) == 3 del tdmodule[2] + del params["module", "2"] assert len(tdmodule) == 2 assert hasattr(tdmodule, "__getitem__") @@ -1098,7 +854,7 @@ def test_functional(self, safe, spec_type): assert tdmodule[1] is tdmodule2 td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td, params=params) + tdmodule(td, params) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) @@ -1122,14 +878,7 @@ def test_functional_probabilistic(self, safe, spec_type): net2 = nn.Linear(4, 4 * param_multiplier) net2 = NormalParamWrapper(net2) - fnet1, params1 = make_functional(net1) - fdummy_net, _ = make_functional(dummy_net) - fnet2, params2 = make_functional(net2) - fnet2 = SafeModule(module=fnet2, in_keys=["hidden"], out_keys=["loc", "scale"]) - if isinstance(params1, TensorDictBase): - params = TensorDict({"0": params1, "1": params2}, []) - else: - params = list(params1) + list(params2) + net2 = 
SafeModule(module=net2, in_keys=["hidden"], out_keys=["loc", "scale"]) if spec_type is None: spec = None @@ -1149,17 +898,17 @@ def test_functional_probabilistic(self, safe, spec_type): pytest.skip("safe and spec is None is checked elsewhere") else: tdmodule1 = SafeModule( - fnet1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False + net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False ) dummy_tdmodule = SafeModule( - fdummy_net, + dummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) tdmodule2 = SafeProbabilisticModule( - fnet2, + net2, spec=spec, dist_in_keys=["loc", "scale"], sample_out_key=["out"], @@ -1168,14 +917,18 @@ def test_functional_probabilistic(self, safe, spec_type): ) tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) + params = make_functional(tdmodule, funs_to_decorate=["forward", "get_dist"]) + assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 tdmodule[1] = tdmodule2 + params["module", "1"] = params["module", "2"] assert len(tdmodule) == 3 assert hasattr(tdmodule, "__delitem__") assert len(tdmodule) == 3 del tdmodule[2] + del params["module", "2"] assert len(tdmodule) == 2 assert hasattr(tdmodule, "__getitem__") @@ -1212,19 +965,6 @@ def test_functional_with_buffer( nn.Linear(7, 7 * param_multiplier), nn.BatchNorm1d(7 * param_multiplier) ) - fnet1, params1, buffers1 = make_functional_with_buffers(net1) - fdummy_net, _, _ = make_functional_with_buffers(dummy_net) - fnet2, params2, buffers2 = make_functional_with_buffers(net2) - - if isinstance(params1, TensorDictBase): - params = TensorDict({"0": params1, "1": params2}, []) - else: - params = list(params1) + list(params2) - if isinstance(buffers1, TensorDictBase): - buffers = TensorDict({"0": buffers1, "1": buffers2}, []) - else: - buffers = list(buffers1) + list(buffers2) - if spec_type is None: spec = None elif spec_type == "bounded": @@ -1236,17 +976,17 @@ def test_functional_with_buffer( pytest.skip("safe and spec is None is checked elsewhere") else: tdmodule1 = SafeModule( - fnet1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False + net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False ) dummy_tdmodule = SafeModule( - fdummy_net, + dummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) tdmodule2 = SafeModule( - fnet2, + net2, spec=spec, in_keys=["hidden"], out_keys=["out"], @@ -1254,14 +994,18 @@ def test_functional_with_buffer( ) tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) + params = make_functional(tdmodule) + assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 tdmodule[1] = tdmodule2 + params["module", "1"] = params["module", "2"] assert len(tdmodule) == 3 assert hasattr(tdmodule, "__delitem__") assert len(tdmodule) == 3 del tdmodule[2] + del params["module", "2"] assert len(tdmodule) == 2 assert hasattr(tdmodule, "__getitem__") @@ -1269,10 +1013,10 @@ def test_functional_with_buffer( assert tdmodule[1] is tdmodule2 td = TensorDict({"in": torch.randn(3, 7)}, [3]) - tdmodule(td, params=params, buffers=buffers) + tdmodule(td, params=params) with pytest.raises(RuntimeError, match="Cannot call get_dist on a sequence"): - dist, *_ = tdmodule.get_dist(td, params=params, buffers=buffers) + dist, *_ = tdmodule.get_dist(td, params=params) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 7]) @@ -1299,22 +1043,7 @@ def test_functional_with_buffer_probabilistic( nn.Linear(7, 7 * param_multiplier), nn.BatchNorm1d(7 * param_multiplier) ) net2 = 
NormalParamWrapper(net2) - - fnet1, params1, buffers1 = make_functional_with_buffers(net1) - fdummy_net, _, _ = make_functional_with_buffers(dummy_net) - # fnet2, params2, buffers2 = make_functional_with_buffers(net2) - # fnet2 = SafeModule(fnet2, in_keys=["hidden"], out_keys=["loc", "scale"]) net2 = SafeModule(net2, in_keys=["hidden"], out_keys=["loc", "scale"]) - fnet2, (params2, buffers2) = net2.make_functional_with_buffers() - - if isinstance(params1, TensorDictBase): - params = TensorDict({"0": params1, "1": params2}, []) - else: - params = list(params1) + list(params2) - if isinstance(buffers1, TensorDictBase): - buffers = TensorDict({"0": buffers1, "1": buffers2}, []) - else: - buffers = list(buffers1) + list(buffers2) if spec_type is None: spec = None @@ -1334,17 +1063,17 @@ def test_functional_with_buffer_probabilistic( pytest.skip("safe and spec is None is checked elsewhere") else: tdmodule1 = SafeModule( - fnet1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False + net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False ) dummy_tdmodule = SafeModule( - fdummy_net, + dummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) tdmodule2 = SafeProbabilisticModule( - fnet2, + net2, spec=spec, dist_in_keys=["loc", "scale"], sample_out_key=["out"], @@ -1353,14 +1082,18 @@ def test_functional_with_buffer_probabilistic( ) tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) + params = make_functional(tdmodule, ["forward", "get_dist"]) + assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 tdmodule[1] = tdmodule2 + params["module", "1"] = params["module", "2"] assert len(tdmodule) == 3 assert hasattr(tdmodule, "__delitem__") assert len(tdmodule) == 3 del tdmodule[2] + del params["module", "2"] assert len(tdmodule) == 2 assert hasattr(tdmodule, "__getitem__") @@ -1368,73 +1101,9 @@ def test_functional_with_buffer_probabilistic( assert tdmodule[1] is tdmodule2 td = TensorDict({"in": torch.randn(3, 7)}, [3]) - tdmodule(td, params=params, buffers=buffers) - - dist, *_ = tdmodule.get_dist(td, params=params, buffers=buffers) - assert dist.rsample().shape[: td.ndimension()] == td.shape - - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 7]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_with_buffer_probabilistic_laterconstruct( - self, - safe, - spec_type, - ): - torch.manual_seed(0) - param_multiplier = 2 - - net1 = nn.Sequential(nn.Linear(7, 7), nn.BatchNorm1d(7)) - net2 = nn.Sequential( - nn.Linear(7, 7 * param_multiplier), nn.BatchNorm1d(7 * param_multiplier) - ) - net2 = NormalParamWrapper(net2) - net2 = SafeModule(net2, in_keys=["hidden"], out_keys=["loc", "scale"]) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 7) - elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(7) - else: - raise NotImplementedError - spec = ( - CompositeSpec(out=spec, loc=None, scale=None) if spec is not None else None - ) - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, spec=None, 
in_keys=["in"], out_keys=["hidden"], safe=False - ) - tdmodule2 = SafeProbabilisticModule( - net2, - spec=spec, - dist_in_keys=["loc", "scale"], - sample_out_key=["out"], - safe=safe, - **kwargs, - ) - tdmodule = SafeSequential(tdmodule1, tdmodule2) - - tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers() - - td = TensorDict({"in": torch.randn(3, 7)}, [3]) - tdmodule(td, params=params, buffers=buffers) + tdmodule(td, params=params) - dist, *_ = tdmodule.get_dist(td, params=params, buffers=buffers) + dist, *_ = tdmodule.get_dist(td, params=params) assert dist.rsample().shape[: td.ndimension()] == td.shape assert td.shape == torch.Size([3]) @@ -1459,11 +1128,6 @@ def test_vmap(self, safe, spec_type): dummy_net = nn.Linear(4, 4) net2 = nn.Linear(4, 4 * param_multiplier) - fnet1, params1 = make_functional(net1) - fdummy_net, _ = make_functional(dummy_net) - fnet2, params2 = make_functional(net2) - params = params1 + params2 - if spec_type is None: spec = None elif spec_type == "bounded": @@ -1475,21 +1139,21 @@ def test_vmap(self, safe, spec_type): pytest.skip("safe and spec is None is checked elsewhere") else: tdmodule1 = SafeModule( - fnet1, + net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False, ) dummy_tdmodule = SafeModule( - fdummy_net, + dummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) tdmodule2 = SafeModule( - fnet2, + net2, spec=spec, in_keys=["hidden"], out_keys=["out"], @@ -1497,14 +1161,18 @@ def test_vmap(self, safe, spec_type): ) tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) + params = make_functional(tdmodule) + assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 tdmodule[1] = tdmodule2 + params["module", "1"] = params["module", "2"] assert len(tdmodule) == 3 assert hasattr(tdmodule, "__delitem__") assert len(tdmodule) == 3 del tdmodule[2] + del params["module", "2"] assert len(tdmodule) == 2 assert hasattr(tdmodule, "__getitem__") @@ -1512,20 +1180,17 @@ def test_vmap(self, safe, spec_type): assert tdmodule[1] is tdmodule2 # vmap = True - params = [p.repeat(10, *[1 for _ in p.shape]) for p in params] + params = params.expand(10) td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_out = tdmodule(td, params=params, vmap=True) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() + if safe and spec_type == "bounded": + with pytest.raises( + RuntimeError, match="vmap cannot be used with safe=True" + ): + td_out = vmap(tdmodule, (None, 0))(td, params) + return + else: + td_out = vmap(tdmodule, (None, 0))(td, params) - # vmap = (0, None) - td_out = tdmodule(td, params=params, vmap=(0, None)) assert td_out is not td assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) @@ -1536,9 +1201,10 @@ def test_vmap(self, safe, spec_type): assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() # vmap = (0, 0) - td_repeat = td.expand(10, *td.batch_size).clone() - td_out = tdmodule(td_repeat, params=params, vmap=(0, 0)) - assert td_out is not td + td = TensorDict({"in": torch.randn(3, 3)}, [3]) + td_repeat = td.expand(10, *td.batch_size) + td_out = vmap(tdmodule, (0, 0))(td_repeat, params) + assert td_out is not td_repeat assert 
td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) # test bounds @@ -1557,14 +1223,10 @@ def test_vmap_probabilistic(self, safe, spec_type): param_multiplier = 2 net1 = nn.Linear(3, 4) - fnet1, params1 = make_functional(net1) net2 = nn.Linear(4, 4 * param_multiplier) net2 = NormalParamWrapper(net2) - fnet2, params2 = make_functional(net2) - fnet2 = SafeModule(fnet2, in_keys=["hidden"], out_keys=["loc", "scale"]) - - params = params1 + params2 + net2 = SafeModule(net2, in_keys=["hidden"], out_keys=["loc", "scale"]) if spec_type is None: spec = None @@ -1584,14 +1246,14 @@ def test_vmap_probabilistic(self, safe, spec_type): pytest.skip("safe and spec is None is checked elsewhere") else: tdmodule1 = SafeModule( - fnet1, + net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False, ) tdmodule2 = SafeProbabilisticModule( - fnet2, + net2, spec=spec, sample_out_key=["out"], dist_in_keys=["loc", "scale"], @@ -1600,21 +1262,19 @@ def test_vmap_probabilistic(self, safe, spec_type): ) tdmodule = SafeSequential(tdmodule1, tdmodule2) + params = make_functional(tdmodule) + # vmap = True - params = [p.repeat(10, *[1 for _ in p.shape]) for p in params] + params = params.expand(10) td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_out = tdmodule(td, params=params, vmap=True) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, None) - td_out = tdmodule(td, params=params, vmap=(0, None)) + if safe and spec_type == "bounded": + with pytest.raises( + RuntimeError, match="vmap cannot be used with safe=True" + ): + td_out = vmap(tdmodule, (None, 0))(td, params) + return + else: + td_out = vmap(tdmodule, (None, 0))(td, params) assert td_out is not td assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) @@ -1625,9 +1285,10 @@ def test_vmap_probabilistic(self, safe, spec_type): assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() # vmap = (0, 0) - td_repeat = td.expand(10, *td.batch_size).clone() - td_out = tdmodule(td_repeat, params=params, vmap=(0, 0)) - assert td_out is not td + td = TensorDict({"in": torch.randn(3, 3)}, [3]) + td_repeat = td.expand(10, *td.batch_size) + td_out = vmap(tdmodule, (0, 0))(td_repeat, params) + assert td_out is not td_repeat assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) # test bounds @@ -1636,6 +1297,45 @@ def test_vmap_probabilistic(self, safe, spec_type): elif safe and spec_type == "bounded": assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() + @pytest.mark.parametrize("functional", [True, False]) + def test_submodule_sequence(self, functional): + td_module_1 = SafeModule( + nn.Linear(3, 2), + in_keys=["in"], + out_keys=["hidden"], + ) + td_module_2 = SafeModule( + nn.Linear(2, 4), + in_keys=["hidden"], + out_keys=["out"], + ) + td_module = SafeSequential(td_module_1, td_module_2) + + if functional: + td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) + sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) + params = make_functional(sub_seq_1) + sub_seq_1(td_1, params=params) + assert "hidden" in td_1.keys() + assert "out" not in td_1.keys() + td_2 = 
TensorDict({"hidden": torch.randn(5, 2)}, [5]) + sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"]) + params = make_functional(sub_seq_2) + sub_seq_2(td_2, params=params) + assert "out" in td_2.keys() + assert td_2.get("out").shape == torch.Size([5, 4]) + else: + td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) + sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) + sub_seq_1(td_1) + assert "hidden" in td_1.keys() + assert "out" not in td_1.keys() + td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5]) + sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"]) + sub_seq_2(td_2) + assert "out" in td_2.keys() + assert td_2.get("out").shape == torch.Size([5, 4]) + @pytest.mark.parametrize("stack", [True, False]) @pytest.mark.parametrize("functional", [True, False]) def test_sequential_partial(self, stack, functional): @@ -1643,29 +1343,14 @@ def test_sequential_partial(self, stack, functional): param_multiplier = 2 net1 = nn.Linear(3, 4) - if functional: - fnet1, params1 = make_functional(net1) - else: - params1 = None - fnet1 = net1 net2 = nn.Linear(4, 4 * param_multiplier) net2 = NormalParamWrapper(net2) - if functional: - fnet2, params2 = make_functional(net2) - else: - fnet2 = net2 - params2 = None - fnet2 = SafeModule(fnet2, in_keys=["b"], out_keys=["loc", "scale"]) + net2 = SafeModule(net2, in_keys=["b"], out_keys=["loc", "scale"]) net3 = nn.Linear(4, 4 * param_multiplier) net3 = NormalParamWrapper(net3) - if functional: - fnet3, params3 = make_functional(net3) - else: - fnet3 = net3 - params3 = None - fnet3 = SafeModule(fnet3, in_keys=["c"], out_keys=["loc", "scale"]) + net3 = SafeModule(net3, in_keys=["c"], out_keys=["loc", "scale"]) spec = NdBoundedTensorSpec(-0.1, 0.1, 4) spec = CompositeSpec(out=spec, loc=None, scale=None) @@ -1673,14 +1358,14 @@ def test_sequential_partial(self, stack, functional): kwargs = {"distribution_class": TanhNormal} tdmodule1 = SafeModule( - fnet1, + net1, spec=None, in_keys=["a"], out_keys=["hidden"], safe=False, ) tdmodule2 = SafeProbabilisticModule( - fnet2, + net2, spec=spec, sample_out_key=["out"], dist_in_keys=["loc", "scale"], @@ -1688,7 +1373,7 @@ def test_sequential_partial(self, stack, functional): **kwargs, ) tdmodule3 = SafeProbabilisticModule( - fnet3, + net3, spec=spec, sample_out_key=["out"], dist_in_keys=["loc", "scale"], @@ -1699,6 +1384,11 @@ def test_sequential_partial(self, stack, functional): tdmodule1, tdmodule2, tdmodule3, partial_tolerant=True ) + if functional: + params = make_functional(tdmodule) + else: + params = None + if stack: td = torch.stack( [ @@ -1708,16 +1398,6 @@ def test_sequential_partial(self, stack, functional): 0, ) if functional: - if _has_functorch: - params = params1 + params2 + params3 - else: - params = TensorDict( - { - str(i): params - for i, params in enumerate((params1, params2, params3)) - }, - [], - ) tdmodule(td, params=params) else: tdmodule(td) @@ -1732,16 +1412,6 @@ def test_sequential_partial(self, stack, functional): else: td = TensorDict({"a": torch.randn(3), "b": torch.randn(4)}, []) if functional: - if _has_functorch: - params = params1 + params2 + params3 - else: - params = TensorDict( - { - str(i): params - for i, params in enumerate((params1, params2, params3)) - }, - [], - ) tdmodule(td, params=params) else: tdmodule(td) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index a05b53cf7c5..1bba001c1c1 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -370,7 +370,9 @@ def 
_init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: .memmap_(prefix=self.scratch_dir) .to(self.device) ) - for key, tensor in sorted(out.flatten_keys(".").items()): + for key, tensor in sorted( + out.items(include_nested=True, leaves_only=True), key=str + ): filesize = os.path.getsize(tensor.filename) / 1024 / 1024 print( f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})." diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 8c6ca5d8593..5d42fbe753f 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -13,13 +13,6 @@ TanhNormal, TruncatedNormal, ) - -# from .functional_modules import ( -# FunctionalModule, -# FunctionalModuleWithBuffers, -# extract_weights, -# extract_buffers, -# ) from .models import ( ConvNet, DdpgCnnActor, diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index da136f67208..f0644174be0 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -95,7 +95,7 @@ def _call(self, x: torch.Tensor) -> torch.Tensor: def _inverse(self, y: torch.Tensor) -> torch.Tensor: eps = torch.finfo(y.dtype).eps - y.data.clamp_(-1 + eps, 1 - eps) + y = y.clamp(-1 + eps, 1 - eps) x = super()._inverse(y) return x diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index dba80fc67a5..f9661faa90e 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -76,25 +76,24 @@ class ProbabilisticActor(SafeProbabilisticModule): automatically translated into :obj:`spec = CompositeSpec(action=spec)` Examples: - >>> import functorch >>> import torch >>> from tensordict import TensorDict + >>> from tensordict.nn.functional_modules import make_functional >>> from torchrl.data import NdBoundedTensorSpec - >>> from torchrl.modules import Actor, TanhNormal, NormalParamWrapper + >>> from torchrl.modules import ProbabilisticActor, NormalParamWrapper, SafeModule, TanhNormal >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,]) >>> action_spec = NdBoundedTensorSpec(shape=torch.Size([4]), ... minimum=-1, maximum=1) >>> module = NormalParamWrapper(torch.nn.Linear(4, 8)) - >>> fmodule, params, buffers = functorch.make_functional_with_buffers( - ... module) - >>> tensordict_module = SafeModule(fmodule, in_keys=["observation"], out_keys=["loc", "scale"]) + >>> params = make_functional(module) + >>> tensordict_module = SafeModule(module, in_keys=["observation"], out_keys=["loc", "scale"]) >>> td_module = ProbabilisticActor( ... module=tensordict_module, ... spec=action_spec, ... dist_in_keys=["loc", "scale"], ... distribution_class=TanhNormal, ... ) - >>> td = td_module(td, params=params, buffers=buffers) + >>> td = td_module(td, params=params) >>> td TensorDict( fields={ @@ -143,9 +142,9 @@ class ValueOperator(SafeModule): key is part of the in_keys list). Examples: - >>> import functorch >>> import torch >>> from tensordict import TensorDict + >>> from tensordict.nn.functional_modules import make_functional >>> from torch import nn >>> from torchrl.data import NdUnboundedContinuousTensorSpec >>> from torchrl.modules import ValueOperator @@ -157,20 +156,20 @@ class ValueOperator(SafeModule): ... def forward(self, obs, action): ... return self.linear(torch.cat([obs, action], -1)) >>> module = CustomModule() - >>> fmodule, params, buffers = functorch.make_functional_with_buffers(module) >>> td_module = ValueOperator( - ... 
in_keys=["observation", "action"], - ... module=fmodule, - ... ) - >>> td_module(td, params=params, buffers=buffers) + ... in_keys=["observation", "action"], module=module + ... ) + >>> params = make_functional(td_module) + >>> td_module(td, params=params) >>> print(td) TensorDict( - fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + fields={ action: Tensor(torch.Size([3, 2]), dtype=torch.float32), + observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), state_action_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, - shared=False, batch_size=torch.Size([3]), - device=cpu) + device=None, + is_shared=False) """ @@ -210,28 +209,29 @@ class QValueHook: action component. Examples: - >>> import functorch >>> import torch >>> from tensordict import TensorDict + >>> from tensordict.nn.functional_modules import make_functional >>> from torch import nn >>> from torchrl.data import OneHotDiscreteTensorSpec >>> from torchrl.modules.tensordict_module.actors import QValueHook, Actor >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) >>> module = nn.Linear(4, 4) - >>> fmodule, params, buffers = functorch.make_functional_with_buffers(module) + >>> params = make_functional(module) >>> hook = QValueHook("one_hot") - >>> _ = fmodule.register_forward_hook(hook) + >>> module.register_forward_hook(hook) >>> action_spec = OneHotDiscreteTensorSpec(4) - >>> qvalue_actor = Actor(module=fmodule, spec=action_spec, out_keys=["action", "action_value"]) - >>> _ = qvalue_actor(td, params=params, buffers=buffers) + >>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"]) + >>> qvalue_actor(td, params=params) >>> print(td) TensorDict( - fields={observation: Tensor(torch.Size([5, 4]), dtype=torch.float32), + fields={ action: Tensor(torch.Size([5, 4]), dtype=torch.int64), - action_value: Tensor(torch.Size([5, 4]), dtype=torch.float32)}, - shared=False, + action_value: Tensor(torch.Size([5, 4]), dtype=torch.float32), + observation: Tensor(torch.Size([5, 4]), dtype=torch.float32)}, batch_size=torch.Size([5]), - device=cpu) + device=None, + is_shared=False) """ @@ -326,9 +326,9 @@ class DistributionalQValueHook(QValueHook): action component. Examples: - >>> import functorch >>> import torch >>> from tensordict import TensorDict + >>> from tensordict.nn.functional_modules import make_functional >>> from torch import nn >>> from torchrl.data import OneHotDiscreteTensorSpec >>> from torchrl.modules.tensordict_module.actors import DistributionalQValueHook, Actor @@ -343,20 +343,21 @@ class DistributionalQValueHook(QValueHook): ... return self.linear(x).view(-1, nbins, 4).log_softmax(-2) ... 
>>> module = CustomDistributionalQval() - >>> fmodule, params, buffers = functorch.make_functional_with_buffers(module) + >>> params = make_functional(module) >>> action_spec = OneHotDiscreteTensorSpec(4) >>> hook = DistributionalQValueHook("one_hot", support = torch.arange(nbins)) - >>> _ = fmodule.register_forward_hook(hook) - >>> qvalue_actor = Actor(module=fmodule, spec=action_spec, out_keys=["action", "action_value"]) - >>> _ = qvalue_actor(td, params=params, buffers=buffers) + >>> module.register_forward_hook(hook) + >>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"]) + >>> qvalue_actor(td, params=params) >>> print(td) TensorDict( - fields={observation: Tensor(torch.Size([5, 4]), dtype=torch.float32), + fields={ action: Tensor(torch.Size([5, 4]), dtype=torch.int64), - action_value: Tensor(torch.Size([5, 3, 4]), dtype=torch.float32)}, - shared=False, + action_value: Tensor(torch.Size([5, 3, 4]), dtype=torch.float32), + observation: Tensor(torch.Size([5, 4]), dtype=torch.float32)}, batch_size=torch.Size([5]), - device=cpu) + device=None, + is_shared=False) """ @@ -438,27 +439,27 @@ class QValueActor(Actor): This class hooks the module such that it returns a one-hot encoding of the argmax value. Examples: - >>> import functorch >>> import torch >>> from tensordict import TensorDict + >>> from tensordict.nn.functional_modules import make_functional >>> from torch import nn >>> from torchrl.data import OneHotDiscreteTensorSpec >>> from torchrl.modules.tensordict_module.actors import QValueActor >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) >>> module = nn.Linear(4, 4) - >>> fmodule, params, buffers = functorch.make_functional_with_buffers(module) + >>> params= make_functional(module) >>> action_spec = OneHotDiscreteTensorSpec(4) - >>> qvalue_actor = QValueActor(module=fmodule, spec=action_spec) - >>> _ = qvalue_actor(td, params=params, buffers=buffers) + >>> qvalue_actor = QValueActor(module=module, spec=action_spec) + >>> qvalue_actor(td, params=params) >>> print(td) TensorDict( fields={ - observation: Tensor(torch.Size([5, 4]), dtype=torch.float32), action: Tensor(torch.Size([5, 4]), dtype=torch.int64), action_value: Tensor(torch.Size([5, 4]), dtype=torch.float32), - chosen_action_value: Tensor(torch.Size([5, 1]), dtype=torch.float32)}, + chosen_action_value: Tensor(torch.Size([5, 1]), dtype=torch.float32), + observation: Tensor(torch.Size([5, 4]), dtype=torch.float32)}, batch_size=torch.Size([5]), - device=cpu, + device=None, is_shared=False) """ @@ -480,7 +481,6 @@ class DistributionalQValueActor(QValueActor): This class hooks the module such that it returns a one-hot encoding of the argmax value on its support. 
Examples: - >>> import functorch >>> import torch >>> from tensordict import TensorDict >>> from torch import nn @@ -491,15 +491,15 @@ class DistributionalQValueActor(QValueActor): >>> module = MLP(out_features=(nbins, 4), depth=2) >>> action_spec = OneHotDiscreteTensorSpec(4) >>> qvalue_actor = DistributionalQValueActor(module=module, spec=action_spec, support=torch.arange(nbins)) - >>> _ = qvalue_actor(td) + >>> qvalue_actor(td) >>> print(td) TensorDict( fields={ - observation: Tensor(torch.Size([5, 4]), dtype=torch.float32), action: Tensor(torch.Size([5, 4]), dtype=torch.int64), - action_value: Tensor(torch.Size([5, 3, 4]), dtype=torch.float32)}, + action_value: Tensor(torch.Size([5, 3, 4]), dtype=torch.float32), + observation: Tensor(torch.Size([5, 4]), dtype=torch.float32)}, batch_size=torch.Size([5]), - device=cpu, + device=None, is_shared=False) """ @@ -566,7 +566,7 @@ class ActorValueOperator(SafeSequential): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.modules.tensordict_module import ProbabilisticActor + >>> from torchrl.modules import ProbabilisticActor, SafeModule >>> from torchrl.data import NdUnboundedContinuousTensorSpec, NdBoundedTensorSpec >>> from torchrl.modules import ValueOperator, TanhNormal, ActorValueOperator, NormalParamWrapper >>> spec_hidden = NdUnboundedContinuousTensorSpec(4) @@ -601,34 +601,40 @@ class ActorValueOperator(SafeSequential): >>> td_clone = td_module(td.clone()) >>> print(td_clone) TensorDict( - fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), - hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32), + fields={ action: Tensor(torch.Size([3, 4]), dtype=torch.float32), + hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32), + loc: Tensor(torch.Size([3, 4]), dtype=torch.float32), + observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32), + scale: Tensor(torch.Size([3, 4]), dtype=torch.float32), state_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, - shared=False, batch_size=torch.Size([3]), - device=cpu) + device=None, + is_shared=False) >>> td_clone = td_module.get_policy_operator()(td.clone()) >>> print(td_clone) # no value TensorDict( - fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), - hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32), + fields={ action: Tensor(torch.Size([3, 4]), dtype=torch.float32), - sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, - shared=False, + hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32), + loc: Tensor(torch.Size([3, 4]), dtype=torch.float32), + observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32), + scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)}, batch_size=torch.Size([3]), - device=cpu) - + device=None, + is_shared=False) >>> td_clone = td_module.get_value_operator()(td.clone()) >>> print(td_clone) # no action TensorDict( - fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + fields={ hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32), + observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), state_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, - shared=False, batch_size=torch.Size([3]), - device=cpu) + device=None, + is_shared=False) """ @@ -692,7 +698,7 @@ class ActorCriticOperator(ActorValueOperator): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from 
torchrl.modules.tensordict_module import ProbabilisticActor + >>> from torchrl.modules import ProbabilisticActor, SafeModule >>> from torchrl.data import NdUnboundedContinuousTensorSpec, NdBoundedTensorSpec >>> from torchrl.modules import ValueOperator, TanhNormal, ActorCriticOperator, NormalParamWrapper, MLP >>> spec_hidden = NdUnboundedContinuousTensorSpec(4) @@ -734,7 +740,7 @@ class ActorCriticOperator(ActorValueOperator): scale: Tensor(torch.Size([3, 4]), dtype=torch.float32), state_action_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, batch_size=torch.Size([3]), - device=cpu, + device=None, is_shared=False) >>> td_clone = td_module.get_policy_operator()(td.clone()) >>> print(td_clone) # no value @@ -747,7 +753,7 @@ class ActorCriticOperator(ActorValueOperator): sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32), scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)}, batch_size=torch.Size([3]), - device=cpu, + device=None, is_shared=False) >>> td_clone = td_module.get_critic_operator()(td.clone()) >>> print(td_clone) # no action @@ -761,7 +767,7 @@ class ActorCriticOperator(ActorValueOperator): scale: Tensor(torch.Size([3, 4]), dtype=torch.float32), state_action_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, batch_size=torch.Size([3]), - device=cpu, + device=None, is_shared=False) """ @@ -823,13 +829,24 @@ class ActorCriticWrapper(SafeSequential): >>> import torch >>> from tensordict import TensorDict >>> from torchrl.data import NdUnboundedContinuousTensorSpec, NdBoundedTensorSpec - >>> from torchrl.modules.tensordict_module.deprec import ProbabilisticActor_deprecated - >>> from torchrl.modules import ValueOperator, TanhNormal, ActorCriticWrapper - >>> spec_action = NdBoundedTensorSpec(-1, 1, torch.Size([8])) - >>> module_action = torch.nn.Linear(4, 8) - >>> td_module_action = ProbabilisticActor_deprecated( - ... module=module_action, - ... spec=spec_action, + >>> from torchrl.modules import ( + ActorCriticWrapper, + ProbabilisticActor, + NormalParamWrapper, + SafeModule, + TanhNormal, + ValueOperator, + ) + >>> action_spec = NdBoundedTensorSpec(-1, 1, torch.Size([8])) + >>> action_module = SafeModule( + NormalParamWrapper(torch.nn.Linear(4, 8)), + in_keys=["observation"], + out_keys=["loc", "scale"], + ) + >>> td_module_action = ProbabilisticActor( + ... module=action_module, + ... spec=action_spec, + ... dist_in_keys=["loc", "scale"], ... distribution_class=TanhNormal, ... return_log_prob=True, ... 
) @@ -843,31 +860,37 @@ class ActorCriticWrapper(SafeSequential): >>> td_clone = td_module(td.clone()) >>> print(td_clone) TensorDict( - fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + fields={ action: Tensor(torch.Size([3, 4]), dtype=torch.float32), + loc: Tensor(torch.Size([3, 4]), dtype=torch.float32), + observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32), + scale: Tensor(torch.Size([3, 4]), dtype=torch.float32), state_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, - shared=False, batch_size=torch.Size([3]), - device=cpu) + device=None, + is_shared=False) >>> td_clone = td_module.get_policy_operator()(td.clone()) >>> print(td_clone) # no value TensorDict( - fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + fields={ action: Tensor(torch.Size([3, 4]), dtype=torch.float32), - sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, - shared=False, + loc: Tensor(torch.Size([3, 4]), dtype=torch.float32), + observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32), + scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)}, batch_size=torch.Size([3]), - device=cpu) - + device=None, + is_shared=False) >>> td_clone = td_module.get_value_operator()(td.clone()) >>> print(td_clone) # no action TensorDict( - fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + fields={ + observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), state_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, - shared=False, batch_size=torch.Size([3]), - device=cpu) + device=None, + is_shared=False) """ diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index c092197eb7c..d666031d028 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -6,6 +6,7 @@ from __future__ import annotations import inspect +import re import warnings from typing import Iterable, Optional, Type, Union @@ -24,10 +25,13 @@ "functional programming should work, but functionality and performance " "may be affected. Consider installing functorch and/or upgrating pytorch." ) - from tensordict.nn.functional_modules import ( - FunctionalModule, - FunctionalModuleWithBuffers, - ) + + class FunctionalModule: # noqa: D101 + pass + + class FunctionalModuleWithBuffers: # noqa: D101 + pass + from tensordict.nn import TensorDictModule from tensordict.tensordict import TensorDictBase @@ -51,33 +55,44 @@ def _check_all_str(list_of_str, first_level=True): def _forward_hook_safe_action(module, tensordict_in, tensordict_out): - spec = module.spec - if len(module.out_keys) > 1 and not isinstance(spec, CompositeSpec): - raise RuntimeError( - "safe SafeModules with multiple out_keys require a CompositeSpec with matching keys. Got " - f"keys {module.out_keys}." 
- ) - elif not isinstance(spec, CompositeSpec): - out_key = module.out_keys[0] - keys = [out_key] - values = [spec] - else: - keys = list(spec.keys()) - values = [spec[key] for key in keys] - for _spec, _key in zip(values, keys): - if _spec is None: - continue - if not _spec.is_in(tensordict_out.get(_key)): - try: - tensordict_out.set_( - _key, - _spec.project(tensordict_out.get(_key)), - ) - except RuntimeError: - tensordict_out.set( - _key, - _spec.project(tensordict_out.get(_key)), - ) + try: + spec = module.spec + if len(module.out_keys) > 1 and not isinstance(spec, CompositeSpec): + raise RuntimeError( + "safe SafeModules with multiple out_keys require a CompositeSpec with matching keys. Got " + f"keys {module.out_keys}." + ) + elif not isinstance(spec, CompositeSpec): + out_key = module.out_keys[0] + keys = [out_key] + values = [spec] + else: + keys = list(spec.keys()) + values = [spec[key] for key in keys] + for _spec, _key in zip(values, keys): + if _spec is None: + continue + if not _spec.is_in(tensordict_out.get(_key)): + try: + tensordict_out.set_( + _key, + _spec.project(tensordict_out.get(_key)), + ) + except RuntimeError: + tensordict_out.set( + _key, + _spec.project(tensordict_out.get(_key)), + ) + except RuntimeError as err: + if re.search( + "attempting to use a Tensor in some data-dependent control flow", str(err) + ): + # "_is_stateless" in module.__dict__ and module._is_stateless: + raise RuntimeError( + f"vmap cannot be used with safe=True, consider turning the safe mode off. (original error message: {err})" + ) + else: + raise err class SafeModule(TensorDictModule): @@ -103,34 +118,35 @@ class SafeModule(TensorDictModule): case, the 'params' (and 'buffers') keyword argument must be specified: Examples: - >>> import functorch >>> import torch >>> from tensordict import TensorDict + >>> from tensordict.nn.functional_modules import make_functional >>> from torchrl.data import NdUnboundedContinuousTensorSpec >>> from torchrl.modules import SafeModule >>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,]) >>> spec = NdUnboundedContinuousTensorSpec(8) >>> module = torch.nn.GRUCell(4, 8) - >>> fmodule, params, buffers = functorch.make_functional_with_buffers(module) >>> td_fmodule = SafeModule( - ... module=fmodule, + ... module=module, ... spec=spec, ... in_keys=["input", "hidden"], ... out_keys=["output"], ... ) - >>> td_functional = td_fmodule(td.clone(), params=params, buffers=buffers) + >>> params = make_functional(td_fmodule) + >>> td_functional = td_fmodule(td.clone(), params=params) >>> print(td_functional) TensorDict( - fields={input: Tensor(torch.Size([3, 4]), dtype=torch.float32), + fields={ hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32), + input: Tensor(torch.Size([3, 4]), dtype=torch.float32), output: Tensor(torch.Size([3, 8]), dtype=torch.float32)}, - shared=False, batch_size=torch.Size([3]), - device=cpu) + device=None, + is_shared=False) In the stateful case: >>> td_module = SafeModule( - ... module=module, + ... module=torch.nn.GRUCell(4, 8), ... spec=spec, ... in_keys=["input", "hidden"], ... 
out_keys=["output"], @@ -138,27 +154,29 @@ class SafeModule(TensorDictModule): >>> td_stateful = td_module(td.clone()) >>> print(td_stateful) TensorDict( - fields={input: Tensor(torch.Size([3, 4]), dtype=torch.float32), + fields={ hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32), + input: Tensor(torch.Size([3, 4]), dtype=torch.float32), output: Tensor(torch.Size([3, 8]), dtype=torch.float32)}, - shared=False, batch_size=torch.Size([3]), - device=cpu) + device=None, + is_shared=False) One can use a vmap operator to call the functional module. In this case the tensordict is expanded to match the batch size (i.e. the tensordict isn't modified in-place anymore): >>> # Model ensemble using vmap - >>> params_repeat = tuple(param.expand(4, *param.shape).contiguous().normal_() for param in params) - >>> buffers_repeat = tuple(param.expand(4, *param.shape).contiguous().normal_() for param in buffers) - >>> td_vmap = td_fmodule(td.clone(), params=params_repeat, buffers=buffers_repeat, vmap=True) + >>> from functorch import vmap + >>> params_repeat = params.expand(4, *params.shape) + >>> td_vmap = vmap(td_fmodule, (None, 0))(td.clone(), params_repeat) >>> print(td_vmap) TensorDict( - fields={input: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), + fields={ hidden: Tensor(torch.Size([4, 3, 8]), dtype=torch.float32), + input: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), output: Tensor(torch.Size([4, 3, 8]), dtype=torch.float32)}, - shared=False, batch_size=torch.Size([4, 3]), - device=cpu) + device=None, + is_shared=False) """ diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index fe3aac62df9..2dd65bb339d 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -65,7 +65,7 @@ class EGreedyWrapper(TensorDictModuleWrapper): [ 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000], - [ 0.0000, 0.0000, 0.0000, 0.0000]], grad_fn=) + [ 0.0000, 0.0000, 0.0000, 0.0000]], grad_fn=) """ @@ -285,10 +285,19 @@ class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper): >>> torch.manual_seed(0) >>> spec = NdBoundedTensorSpec(-1, 1, torch.Size([4])) >>> module = torch.nn.Linear(4, 4, bias=False) - >>> policy = Actor(spec, module=module) + >>> policy = Actor(module=module, spec=spec) >>> explorative_policy = OrnsteinUhlenbeckProcessWrapper(policy) >>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10]) >>> print(explorative_policy(td)) + TensorDict( + fields={ + _ou_prev_noise: Tensor(torch.Size([10, 4]), dtype=torch.float32), + _ou_steps: Tensor(torch.Size([10, 1]), dtype=torch.int64), + action: Tensor(torch.Size([10, 4]), dtype=torch.float32), + observation: Tensor(torch.Size([10, 4]), dtype=torch.float32)}, + batch_size=torch.Size([10]), + device=None, + is_shared=False) """ def __init__( diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 3061d1017fa..022977f378a 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -77,53 +77,57 @@ class of interest, e.g. 
:obj:`"loc"` and :obj:`"scale"` for the Normal distribut Default is 1000 Examples: - >>> import functorch >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.data import NdUnboundedContinuousTensorSpec - >>> from torchrl.modules import SafeProbabilisticModule, TanhNormal, NormalParamWrapper + >>> from tensordict.nn.functional_modules import make_functional + >>> from torchrl.data import CompositeSpec, NdUnboundedContinuousTensorSpec + >>> from torchrl.modules import ( + NormalParamWrapper, + SafeModule, + SafeProbabilisticModule, + TanhNormal, + ) >>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,]) - >>> spec = NdUnboundedContinuousTensorSpec(4) + >>> spec = CompositeSpec(action=NdUnboundedContinuousTensorSpec(4), loc=None, scale=None) >>> net = NormalParamWrapper(torch.nn.GRUCell(4, 8)) - >>> fnet, params, buffers = functorch.make_functional_with_buffers(net) - >>> module = SafeModule(fnet, in_keys=["input", "hidden"], out_keys=["loc", "scale"]) + >>> module = SafeModule(net, in_keys=["input", "hidden"], out_keys=["loc", "scale"]) >>> td_module = SafeProbabilisticModule( - ... module=module, - ... spec=spec, - ... dist_in_keys=["loc", "scale"], - ... sample_out_key=["action"], - ... distribution_class=TanhNormal, - ... return_log_prob=True, - ... ) - >>> _ = td_module(td, params=params, buffers=buffers) + ... module=module, + ... spec=spec, + ... dist_in_keys=["loc", "scale"], + ... sample_out_key=["action"], + ... distribution_class=TanhNormal, + ... return_log_prob=True, + ... ) + >>> params = make_functional(td_module) + >>> td_module(td, params=params) >>> print(td) TensorDict( fields={ - input: Tensor(torch.Size([3, 4]), dtype=torch.float32), + action: Tensor(torch.Size([3, 4]), dtype=torch.float32), hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32), + input: Tensor(torch.Size([3, 4]), dtype=torch.float32), loc: Tensor(torch.Size([3, 4]), dtype=torch.float32), - scale: Tensor(torch.Size([3, 4]), dtype=torch.float32), - action: Tensor(torch.Size([3, 4]), dtype=torch.float32), - sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, + sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32), + scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)}, batch_size=torch.Size([3]), - device=cpu, + device=None, is_shared=False) - >>> # In the vmap case, the tensordict is again expended to match the batch: - >>> params = tuple(p.expand(4, *p.shape).contiguous().normal_() for p in params) - >>> buffers = tuple(b.expand(4, *b.shape).contiguous().normal_() for p in buffers) - >>> td_vmap = td_module(td, params=params, buffers=buffers, vmap=True) + >>> from functorch import vmap + >>> params = params.expand(4, *params.shape) + >>> td_vmap = vmap(td_module, (None, 0))(td, params) >>> print(td_vmap) TensorDict( fields={ - input: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), + action: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), hidden: Tensor(torch.Size([4, 3, 8]), dtype=torch.float32), + input: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), loc: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), - scale: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), - action: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), - sample_log_prob: Tensor(torch.Size([4, 3, 1]), dtype=torch.float32)}, + sample_log_prob: Tensor(torch.Size([4, 3, 1]), dtype=torch.float32), + scale: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32)}, batch_size=torch.Size([4, 3]), - device=cpu, + device=None, is_shared=False) """ 
diff --git a/torchrl/modules/tensordict_module/sequence.py b/torchrl/modules/tensordict_module/sequence.py index bbc3323630f..5e4886b70b2 100644 --- a/torchrl/modules/tensordict_module/sequence.py +++ b/torchrl/modules/tensordict_module/sequence.py @@ -31,77 +31,74 @@ class SafeSequential(TensorDictSequential, SafeModule): TensorDictSequence supports functional, modular and vmap coding: Examples: - >>> import functorch >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.data import NdUnboundedContinuousTensorSpec - >>> from torchrl.modules import TanhNormal, SafeSequential, NormalParamWrapper + >>> from tensordict.nn.functional_modules import make_functional + >>> from torchrl.data import CompositeSpec, NdUnboundedContinuousTensorSpec + >>> from torchrl.modules import TanhNormal, SafeSequential, SafeModule, NormalParamWrapper >>> from torchrl.modules.tensordict_module import SafeProbabilisticModule >>> td = TensorDict({"input": torch.randn(3, 4)}, [3,]) - >>> spec1 = NdUnboundedContinuousTensorSpec(4) + >>> spec1 = CompositeSpec(hidden=NdUnboundedContinuousTensorSpec(4), loc=None, scale=None) >>> net1 = NormalParamWrapper(torch.nn.Linear(4, 8)) - >>> fnet1, params1, buffers1 = functorch.make_functional_with_buffers(net1) - >>> fmodule1 = SafeModule(fnet1, in_keys=["input"], out_keys=["loc", "scale"]) + >>> module1 = SafeModule(net1, in_keys=["input"], out_keys=["loc", "scale"]) >>> td_module1 = SafeProbabilisticModule( - ... module=fmodule1, - ... spec=spec1, - ... dist_in_keys=["loc", "scale"], - ... sample_out_key=["hidden"], - ... distribution_class=TanhNormal, - ... return_log_prob=True, - ... ) + ... module=module1, + ... spec=spec1, + ... dist_in_keys=["loc", "scale"], + ... sample_out_key=["hidden"], + ... distribution_class=TanhNormal, + ... return_log_prob=True, + ... ) >>> spec2 = NdUnboundedContinuousTensorSpec(8) >>> module2 = torch.nn.Linear(4, 8) - >>> fmodule2, params2, buffers2 = functorch.make_functional_with_buffers(module2) >>> td_module2 = SafeModule( - ... module=fmodule2, + ... module=module2, ... spec=spec2, ... in_keys=["hidden"], ... out_keys=["output"], ... 
) >>> td_module = SafeSequential(td_module1, td_module2) - >>> params = params1 + params2 - >>> buffers = buffers1 + buffers2 - >>> _ = td_module(td, params=params, buffers=buffers) + >>> params = make_functional(td_module) + >>> td_module(td, params=params) >>> print(td) TensorDict( fields={ + hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32), input: Tensor(torch.Size([3, 4]), dtype=torch.float32), loc: Tensor(torch.Size([3, 4]), dtype=torch.float32), - scale: Tensor(torch.Size([3, 4]), dtype=torch.float32), - hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32), + output: Tensor(torch.Size([3, 8]), dtype=torch.float32), sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32), - output: Tensor(torch.Size([3, 8]), dtype=torch.float32)}, + scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)}, batch_size=torch.Size([3]), - device=cpu, + device=None, is_shared=False) - >>> # The module spec aggregates all the input specs: >>> print(td_module.spec) CompositeSpec( hidden: NdUnboundedContinuousTensorSpec( - shape=torch.Size([4]),space=None,device=cpu,dtype=torch.float32,domain=continuous), + shape=torch.Size([4]), space=None, device=cpu, dtype=torch.float32, domain=continuous), + loc: None, + scale: None, output: NdUnboundedContinuousTensorSpec( - shape=torch.Size([8]),space=None,device=cpu,dtype=torch.float32,domain=continuous)) + shape=torch.Size([8]), space=None, device=cpu, dtype=torch.float32, domain=continuous)) In the vmap case: - >>> params = tuple(p.expand(4, *p.shape).contiguous().normal_() for p in params) - >>> buffers = tuple(b.expand(4, *b.shape).contiguous().normal_() for p in buffers) - >>> td_vmap = td_module(td, params=params, buffers=buffers, vmap=True) + >>> from functorch import vmap + >>> params = params.expand(4, *params.shape) + >>> td_vmap = vmap(td_module, (None, 0))(td, params) >>> print(td_vmap) TensorDict( fields={ + hidden: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), input: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), loc: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), - scale: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), - hidden: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), + output: Tensor(torch.Size([4, 3, 8]), dtype=torch.float32), sample_log_prob: Tensor(torch.Size([4, 3, 1]), dtype=torch.float32), - output: Tensor(torch.Size([4, 3, 8]), dtype=torch.float32)}, + scale: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32)}, batch_size=torch.Size([4, 3]), - device=cpu, + device=None, is_shared=False) - """ module: nn.ModuleList diff --git a/torchrl/modules/utils/__init__.py b/torchrl/modules/utils/__init__.py index 4af16165f7c..ef430b85391 100644 --- a/torchrl/modules/utils/__init__.py +++ b/torchrl/modules/utils/__init__.py @@ -3,4 +3,88 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from collections import OrderedDict + +import torch +from packaging import version + +if version.parse(torch.__version__) >= version.parse("1.12.0"): + from torch.nn.parameter import _disabled_torch_function_impl, _ParameterMeta +else: + from torch.nn.parameter import _disabled_torch_function_impl + + # Metaclass to combine _TensorMeta and the instance check override for Parameter. + class _ParameterMeta(torch._C._TensorMeta): + # Make `isinstance(t, Parameter)` return True for custom tensor instances that have the _is_param flag. 
+ def __instancecheck__(self, instance): + return super().__instancecheck__(instance) or ( + isinstance(instance, torch.Tensor) + and getattr(instance, "_is_param", False) + ) + + from .mappings import biased_softplus, inv_softplus, mappings + + +class Buffer(torch.Tensor, metaclass=_ParameterMeta): + r"""A kind of Tensor that is to be considered a module buffer. + + Buffers are :class:`~torch.Tensor` subclasses meant to be registered on a + :class:`Module` with :meth:`~Module.register_buffer`: they are part of the + module's state and follow it across device casts, but they do not appear in + the :meth:`~Module.parameters` iterator and do not require gradients by + default. + + Args: + data (Tensor): buffer tensor. + requires_grad (bool, optional): if the buffer requires gradient. See + :ref:`locally-disable-grad-doc` for more details. Default: `False` + """ + + def __new__(cls, data=None, requires_grad=False): + if data is None: + data = torch.empty(0) + if type(data) is torch.Tensor or type(data) is Buffer: + # For ease of BC maintenance, keep this path for standard Tensor. + # Eventually (tm), we should change the behavior for standard Tensor to match. + return torch.Tensor._make_subclass(cls, data, requires_grad) + + # Path for custom tensors: set a flag on the instance to indicate parameter-ness. + t = data.detach().requires_grad_(requires_grad) + if type(t) is not type(data): + raise RuntimeError( + f"Creating a Parameter from an instance of type {type(data).__name__} " + "requires that detach() returns an instance of the same type, but return " + f"type {type(t).__name__} was found instead. To use the type as a " + "Parameter, please correct the detach() semantics defined by " + "its __torch_dispatch__() implementation." + ) + t._is_param = True + return t + + # Note: the 3 methods below only apply to standard Tensor. Parameters of custom tensor types + are still considered that custom tensor type and these methods will not be called for them.
+ def __deepcopy__(self, memo): + if id(self) in memo: + return memo[id(self)] + else: + result = type(self)( + self.data.clone(memory_format=torch.preserve_format), self.requires_grad + ) + memo[id(self)] = result + return result + + def __repr__(self): + return "Buffer containing:\n" + super(Buffer, self).__repr__() + + def __reduce_ex__(self, proto): + # See Note [Don't serialize hooks] + return ( + torch._utils._rebuild_parameter, + (self.data, self.requires_grad, OrderedDict()), + ) + + __torch_function__ = _disabled_torch_function_impl diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index af20007b26a..e6b6d0852a6 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -54,7 +54,9 @@ def __init__( advantage_module: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, ): super().__init__() - self.convert_to_functional(actor, "actor") + self.convert_to_functional( + actor, "actor", funs_to_decorate=["forward", "get_dist"] + ) self.convert_to_functional(critic, "critic", compare_against=self.actor_params) self.advantage_key = advantage_key self.advantage_diff_key = advantage_diff_key @@ -93,7 +95,8 @@ def _log_probs( tensordict_clone = tensordict.select(*self.actor.in_keys).clone() dist, *_ = self.actor.get_dist( - tensordict_clone, params=self.actor_params, buffers=self.actor_buffers + tensordict_clone, + params=self.actor_params, ) log_prob = dist.log_prob(action) log_prob = log_prob.unsqueeze(-1) @@ -117,7 +120,6 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: value = self.critic( tensordict_select, params=self.critic_params, - buffers=self.critic_buffers, ).get("state_value") value_target = advantage + value.detach() loss_value = distance_loss( diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index a7c90521a5a..312be720433 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -5,17 +5,27 @@ from __future__ import annotations +import itertools +from copy import deepcopy from typing import Iterator, List, Optional, Tuple, Union import torch -from tensordict.nn.functional_modules import FunctionalModuleWithBuffers + +from tensordict.nn import make_functional, repopulate_module + +from tensordict.tensordict import TensorDictBase +from torch import nn, Tensor +from torch.nn import Parameter + +from torchrl.modules import SafeModule +from torchrl.modules.utils import Buffer _has_functorch = False try: - import functorch - from functorch._src.make_functional import _swap_state + import functorch as ft # noqa _has_functorch = True + FUNCTORCH_ERR = "" except ImportError: print( "failed to import functorch. TorchRL's features that do not require " @@ -24,12 +34,6 @@ ) FUNCTORCH_ERROR = "functorch not installed. Consider installing functorch to use this functionality." -from tensordict.tensordict import TensorDict, TensorDictBase -from torch import nn, Tensor -from torch.nn import Parameter - -from torchrl.modules import SafeModule - class LossModule(nn.Module): """A parent class for RL losses. @@ -44,6 +48,7 @@ class LossModule(nn.Module): def __init__(self): super().__init__() self._param_maps = {} + # self.register_forward_pre_hook(_parameters_to_tensordict) def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """It is designed to read an input TensorDict and return another tensordict with loss keys named "loss*". 
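As a reading aid for the `LossModule` contract mentioned in that docstring, a hypothetical minimal subclass (ToyLoss and its keys are illustrative, not part of this diff):

    import torch
    from tensordict import TensorDict
    from tensordict.tensordict import TensorDictBase
    from torchrl.objectives.common import LossModule

    class ToyLoss(LossModule):
        """Toy loss: every output key starting with "loss" is meant to be summed and backpropagated."""

        def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
            pred = tensordict.get("prediction")
            target = tensordict.get("target")
            return TensorDict({"loss_mse": torch.nn.functional.mse_loss(pred, target)}, batch_size=[])

    td = TensorDict({"prediction": torch.randn(8, 1), "target": torch.zeros(8, 1)}, [8])
    out = ToyLoss()(td)
    assert set(out.keys()) == {"loss_mse"}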
@@ -69,376 +74,208 @@ def convert_to_functional( expand_dim: Optional[int] = None, create_target_params: bool = False, compare_against: Optional[List[Parameter]] = None, + funs_to_decorate=None, ) -> None: - if _has_functorch: - return self._convert_to_functional_functorch( - module, - module_name, - expand_dim, - create_target_params, - compare_against, - ) - else: - return self._convert_to_functional_native( - module, - module_name, - expand_dim, - create_target_params, - compare_against, - ) - - def _convert_to_functional_functorch( - self, - module: SafeModule, - module_name: str, - expand_dim: Optional[int] = None, - create_target_params: bool = False, - compare_against: Optional[List[Parameter]] = None, - ) -> None: + if funs_to_decorate is None: + funs_to_decorate = ["forward"] # To make it robust to device casting, we must register list of # tensors as lazy calls to `getattr(self, name_of_tensor)`. # Otherwise, casting the module to a device will keep old references # to uncast tensors - - network_orig = module - if hasattr(module, "make_functional_with_buffers"): - functional_module, ( - _, - module_buffers, - ) = module.make_functional_with_buffers(clone=True) - else: - ( - functional_module, - module_params, - module_buffers, - ) = functorch.make_functional_with_buffers(module) - - for _ in functional_module.parameters(): - # Erase meta params - none_state = [None for _ in module_params + module_buffers] - if hasattr(functional_module, "all_names_map"): - # functorch >= 0.2.0 - _swap_state( - functional_module.stateless_model, - functional_module.all_names_map, - none_state, - ) - else: - # functorch < 0.2.0 - _swap_state( - functional_module.stateless_model, - functional_module.split_names, - none_state, - ) - break - del module_params - - param_name = module_name + "_params" - - # we keep the original parameters and not the copy returned by functorch - params = network_orig.parameters() - - # unless we need to expand them, in that case we'll delete the weights to make sure that the user does not - # run anything with them expecting them to be updated - params = list(params) - module_buffers = list(module_buffers) - + try: + buffer_names = next(itertools.islice(zip(*module.named_buffers()), 1)) + except StopIteration: + buffer_names = () + params = make_functional(module, funs_to_decorate=funs_to_decorate) + functional_module = deepcopy(module) + repopulate_module(module, params) + + params_and_buffers = params + # we transform the buffers in params to make sure they follow the device + # as tensor = nn.Parameter(tensor) keeps its identity when moved to another device + + def create_buffers(tensor): + + if isinstance(tensor, torch.Tensor) and not isinstance( + tensor, (Buffer, nn.Parameter) + ): + return Buffer(tensor, requires_grad=tensor.requires_grad) + return tensor + + # separate params and buffers + params_and_buffers = params_and_buffers.apply(create_buffers) + for key in params_and_buffers.keys(True): + if "_sep_" in key: + raise KeyError( + f"The key {key} contains the '_sep_' pattern which is prohibited. Consider renaming the parameter / buffer." + ) + params_and_buffers_flat = params_and_buffers.flatten_keys("_sep_") + buffers = params_and_buffers_flat.select(*buffer_names) + params = params_and_buffers_flat.exclude(*buffer_names) + + if expand_dim and not _has_functorch: + raise ImportError( + "expanding params is only possible when functorch is installed," + "as this feature requires calls to the vmap operator." 
+ ) if expand_dim: + # Expands the dims of params and buffers. + # If the param already exist in the module, we return a simple expansion of the + # original one. Otherwise, we expand and resample it. + # For buffers, a cloned expansion (or equivalently a repeat) is returned. if compare_against is not None: compare_against = set(compare_against) else: compare_against = set() - for i, p in enumerate(params): - if p in compare_against: - # expanded parameters are 'detached': the parameter will not - # be trained to minimize loss involving this network. - p_out = p.data.expand(expand_dim, *p.shape) + + def _compare_and_expand(param): + + if param in compare_against: + expanded_param = param.data.expand(expand_dim, *param.shape) # the expanded parameter must be sent to device when to() # is called: - self._param_maps[p_out] = p + return expanded_param else: - p_out = p.repeat(expand_dim, *[1 for _ in p.shape]) + p_out = param.repeat(expand_dim, *[1 for _ in param.shape]) p_out = nn.Parameter( p_out.uniform_( p_out.min().item(), p_out.max().item() ).requires_grad_() ) - params[i] = p_out + return p_out - for i, b in enumerate(module_buffers): - b = b.expand(expand_dim, *b.shape).clone() - module_buffers[i] = b - - # # delete weights of original model as they do not correspond to the optimized weights - # network_orig.to("meta") - - params_list = params - set_params = set(self.parameters()) - setattr( - self, - "_" + param_name, - nn.ParameterList( - [ - p - for p in params_list - if isinstance(p, nn.Parameter) and p not in set_params - ] - ), - ) - setattr(self, param_name, params) - - buffer_name = module_name + "_buffers" - # we register each buffer independently - for i, p in enumerate(module_buffers): - _name = module_name + f"_buffer_{i}" - self.register_buffer(_name, p) - # replace buffer by its name - module_buffers[i] = _name - setattr( - self.__class__, - buffer_name, - property(lambda _self: [getattr(_self, _name) for _name in module_buffers]), - ) - - # we set the functional module - setattr(self, module_name, functional_module) - - name_params_target = "_target_" + param_name - name_buffers_target = "_target_" + buffer_name - if create_target_params: - target_params = [p.detach().clone() for p in getattr(self, param_name)] - for i, p in enumerate(target_params): - name = "_".join([name_params_target, str(i)]) - self.register_buffer(name, p) - target_params[i] = name - setattr( - self.__class__, - name_params_target, - property( - lambda _self: [getattr(_self, _name) for _name in target_params] - ), + params_udpated = params.apply( + _compare_and_expand, batch_size=[expand_dim, *params.shape] ) - target_buffers = [p.detach().clone() for p in getattr(self, buffer_name)] - for i, p in enumerate(target_buffers): - name = "_".join([name_buffers_target, str(i)]) - self.register_buffer(name, p) - target_buffers[i] = name - setattr( - self.__class__, - name_buffers_target, - property( - lambda _self: [getattr(_self, _name) for _name in target_buffers] - ), + params = params_udpated + buffers = buffers.apply( + lambda buffer: Buffer(buffer.expand(expand_dim, *buffer.shape).clone()), + batch_size=[expand_dim, *buffers.shape], ) - else: - setattr(self.__class__, name_params_target, None) - setattr(self.__class__, name_buffers_target, None) + params_and_buffers.update(params.unflatten_keys("_sep_")) + params_and_buffers.update(buffers.unflatten_keys("_sep_")) + params_and_buffers.batch_size = params.batch_size - setattr( - self.__class__, - name_params_target[1:], - property(lambda _self: 
self._target_param_getter(module_name)), - ) - setattr( - self.__class__, - name_buffers_target[1:], - property(lambda _self: self._target_buffer_getter(module_name)), - ) - - def _convert_to_functional_native( - self, - module: SafeModule, - module_name: str, - expand_dim: Optional[int] = None, - create_target_params: bool = False, - compare_against: Optional[List[Parameter]] = None, - ) -> None: - # To make it robust to device casting, we must register list of - # tensors as lazy calls to `getattr(self, name_of_tensor)`. - # Otherwise, casting the module to a device will keep old references - # to uncast tensors - - network_orig = module - if hasattr(module, "make_functional_with_buffers"): - functional_module, ( - params, - module_buffers, - ) = module.make_functional_with_buffers(clone=True) - else: - ( - functional_module, - params, - module_buffers, - ) = FunctionalModuleWithBuffers._create_from(module) + # self.params_to_map = params_to_map param_name = module_name + "_params" - # params must be retrieved directly because make_functional will copy the content - params_vals = TensorDict( - {name: value for name, value in network_orig.named_parameters()}, [] - ) - # rename params_vals keys to match params: otherwise we'll have to deal with - # module.module.param or such names. We assume that there is a constant prefix - # and that, when sorted, all keys will match. We could check that the values - # do match too. - keys1 = sorted(params.flatten_keys(".").keys()) - keys2 = sorted(params_vals.keys()) - for key1, key2 in zip(keys1, keys2): - params_vals.rename_key(key2, key1) - params = params_vals.unflatten_keys(".") + prev_set_params = set(self.parameters()) - if expand_dim: - raise ImportError( - "expanding params is only possible when functorch is installed," - "as this feature requires calls to the vmap operator." 
- ) + # register parameters and buffers + for key, parameter in params.items(): + if parameter not in prev_set_params: + setattr(self, "_sep_".join([module_name, key]), parameter) + else: + for _param_name, p in self.named_parameters(): + if parameter is p: + break + else: + raise RuntimeError("parameter not found") + setattr(self, "_sep_".join([module_name, key]), _param_name) + prev_set_buffers = set(self.buffers()) + for key, buffer in buffers.items(): + if buffer not in prev_set_buffers: + self.register_buffer("_sep_".join([module_name, key]), buffer) + else: + for _buffer_name, b in self.named_buffers(): + if buffer is b: + break + else: + raise RuntimeError("buffer not found") + setattr(self, "_sep_".join([module_name, key]), _buffer_name) - params_list = list(params.flatten_keys(".").values()) - set_params = set(self.parameters()) - setattr( - self, - "_" + param_name, - nn.ParameterList( - [ - p - for p in params_list - if isinstance(p, nn.Parameter) and p not in set_params - ] - ), - ) - setattr(self, param_name, params) - - buffer_name = module_name + "_buffers" - buffers_iter = list(module_buffers.flatten_keys(".").items()) - module_buffers_list = [] - for i, (key, value) in enumerate(sorted(buffers_iter)): - _name = module_name + f"_buffer_{i}" - self.register_buffer(_name, value) - # replace buffer by its name - module_buffers_list.append((_name, key)) + setattr(self, "_" + param_name, params_and_buffers) setattr( self.__class__, - buffer_name, - property( - lambda _self: TensorDict( - { - key: getattr(_self, _name) - for (_name, key) in module_buffers_list - }, - [], - device=self.device, - ).unflatten_keys(".") - ), + param_name, + property(lambda _self=self: _self._param_getter(module_name)), ) - # we set the functional module + # set the functional module setattr(self, module_name, functional_module) - name_params_target = "_target_" + param_name - name_buffers_target = "_target_" + buffer_name + # creates a map nn.Parameter name -> expanded parameter name + for key, value in params.items(True, True): + if not isinstance(key, tuple): + key = (key,) + if not isinstance(value, nn.Parameter): + # find the param name + for name, param in self.named_parameters(): + if param.data.data_ptr() == value.data_ptr() and param is not value: + self._param_maps[name] = "_sep_".join([module_name, *key]) + break + else: + raise RuntimeError("did not find matching param.") + + name_params_target = "_target_" + module_name if create_target_params: - target_params = getattr(self, param_name).detach().clone() - target_params_items = sorted(target_params.flatten_keys(".").items()) + target_params = params_and_buffers.detach().clone() + target_params_items = target_params.items(True, True) target_params_list = [] - for i, (key, val) in enumerate(target_params_items): - name = "_".join([name_params_target, str(i)]) - self.register_buffer(name, val) + for (key, val) in target_params_items: + if not isinstance(key, tuple): + key = (key,) + name = "_sep_".join([name_params_target, *key]) + self.register_buffer(name, Buffer(val)) target_params_list.append((name, key)) - setattr( - self.__class__, - name_params_target, - property( - lambda _self: TensorDict( - { - key: getattr(_self, _name) - for (_name, key) in target_params_list - }, - [], - device=self.device, - ).unflatten_keys(".") - ), - ) - - target_buffers = getattr(self, buffer_name).detach().clone() - target_buffers_items = sorted(target_buffers.flatten_keys(".").items()) - target_buffers_list = [] - for i, (key, val) in 
enumerate(target_buffers_items): - name = "_".join([name_buffers_target, str(i)]) - self.register_buffer(name, val) - target_buffers_list.append((name, key)) - setattr( - self.__class__, - name_buffers_target, - property( - lambda _self: TensorDict( - { - key: getattr(_self, _name) - for (_name, key) in target_buffers_list - }, - [], - device=self.device, - ).unflatten_keys(".") - ), - ) - + setattr(self, name_params_target + "_params", target_params) else: - setattr(self.__class__, name_params_target, None) - setattr(self.__class__, name_buffers_target, None) - - setattr( - self.__class__, - name_params_target[1:], - property(lambda _self: self._target_param_getter(module_name)), - ) + setattr(self, name_params_target + "_params", None) setattr( self.__class__, - name_buffers_target[1:], - property(lambda _self: self._target_buffer_getter(module_name)), + name_params_target[1:] + "_params", + property(lambda _self=self: _self._target_param_getter(module_name)), ) - def _target_param_getter(self, network_name): - target_name = "_target_" + network_name + "_params" + def _param_getter(self, network_name): + name = "_" + network_name + "_params" param_name = network_name + "_params" - if hasattr(self, target_name): - target_params = getattr(self, target_name) - if target_params is not None: - if isinstance(target_params, TensorDictBase): - return target_params - return tuple(target_params) + if name in self.__dict__: + params = getattr(self, name) + if params is not None: + # get targets and update + for key in params.keys(True, True): + if not isinstance(key, tuple): + key = (key,) + value_to_set = getattr(self, "_sep_".join([network_name, *key])) + if isinstance(value_to_set, str): + value_to_set = getattr(self, value_to_set).detach() + params.set(key, value_to_set) + return params else: params = getattr(self, param_name) - if isinstance(params, TensorDictBase): - return params.detach() - else: - # detach params as a surrogate for targets - return tuple(p.detach() for p in params) + return params.detach() else: raise RuntimeError( - f"{self.__class__.__name__} does not have the target param {target_name}" + f"{self.__class__.__name__} does not have the target param {name}" ) - def _target_buffer_getter(self, network_name): - target_name = "_target_" + network_name + "_buffers" - buffer_name = network_name + "_buffers" - if hasattr(self, target_name): - target_buffers = getattr(self, target_name) - if target_buffers is not None: - if isinstance(target_buffers, TensorDictBase): - return target_buffers - return tuple(target_buffers) + def _target_param_getter(self, network_name): + target_name = "_target_" + network_name + "_params" + param_name = network_name + "_params" + if target_name in self.__dict__: + target_params = getattr(self, target_name) + if target_params is not None: + # get targets and update + for key in target_params.keys(True, True): + if not isinstance(key, tuple): + key = (key,) + value_to_set = getattr( + self, "_sep_".join(["_target_" + network_name, *key]) + ) + target_params.set(key, value_to_set) + return target_params else: - buffers = getattr(self, buffer_name) - if isinstance(buffers, TensorDictBase): - return buffers.detach() - else: - return tuple(p.detach() for p in buffers) + params = getattr(self, param_name) + return params.detach() else: raise RuntimeError( - f"{self.__class__.__name__} does not have the target buffer {target_name}" + f"{self.__class__.__name__} does not have the target param {target_name}" ) def _networks(self) -> Iterator[nn.Module]: @@ 
-455,7 +292,7 @@ def device(self) -> torch.device: def register_buffer( self, name: str, tensor: Optional[Tensor], persistent: bool = True ) -> None: - tensor = tensor.to(self.device) + # tensor = tensor.to(self.device) return super().register_buffer(name, tensor, persistent) def parameters(self, recurse: bool = True) -> Iterator[Parameter]: @@ -476,16 +313,23 @@ def reset(self) -> None: def to(self, *args, **kwargs): # get the names of the parameters to map out = super().to(*args, **kwargs) - lists_of_params = { - name: value - for name, value in self.__dict__.items() - if name.endswith("_params") and (type(value) is list) - } - for _, list_of_params in lists_of_params.items(): - for i, param in enumerate(list_of_params): - # we replace the param by the expanded form if needs be - if param in self._param_maps: - list_of_params[i] = self._param_maps[param].data.expand_as(param) + for origin, target in self._param_maps.items(): + origin_value = getattr(self, origin) + target_value = getattr(self, target) + setattr(self, target, origin_value.expand_as(target_value)) + + # lists_of_params = { + # name: value + # for name, value in self.__dict__.items() + # if name.endswith("_params") and isinstance(value, TensorDictBase) + # } + # for list_of_params in lists_of_params.values(): + # for key, param in list(list_of_params.items(True)): + # if isinstance(param, TensorDictBase): + # continue + # # we replace the param by the expanded form if needs be + # if param in self._param_maps: + # list_of_params[key] = self._param_maps[param].data.expand_as(param) return out def cuda(self, device: Optional[Union[int, device]] = None) -> LossModule: diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 5692f109739..42999226605 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -5,9 +5,12 @@ from __future__ import annotations +from copy import deepcopy + from typing import Tuple import torch +from tensordict.nn import make_functional, repopulate_module from tensordict.tensordict import TensorDict, TensorDictBase from torchrl.modules import SafeModule @@ -46,6 +49,13 @@ def __init__( super().__init__() self.delay_actor = delay_actor self.delay_value = delay_value + + actor_critic = ActorCriticWrapper(actor_network, value_network) + params = make_functional(actor_critic) + self.actor_critic = deepcopy(actor_critic) + repopulate_module(actor_network, params["module", "0"]) + repopulate_module(value_network, params["module", "1"]) + self.convert_to_functional( actor_network, "actor_network", @@ -57,6 +67,8 @@ def __init__( create_target_params=self.delay_value, compare_against=list(actor_network.parameters()), ) + self.actor_critic.module[0] = self.actor_network + self.actor_critic.module[1] = self.value_network self.actor_in_keys = actor_network.in_keys @@ -116,11 +128,11 @@ def _loss_actor( td_copy = self.actor_network( td_copy, params=self.actor_network_params, - buffers=self.actor_network_buffers, ) with hold_out_params(self.value_network_params) as params: td_copy = self.value_network( - td_copy, params=params, buffers=self.value_network_buffers + td_copy, + params=params, ) return -td_copy.get("state_action_value") @@ -133,16 +145,19 @@ def _loss_value( self.value_network( td_copy, params=self.value_network_params, - buffers=self.value_network_buffers, ) pred_val = td_copy.get("state_action_value").squeeze(-1) - actor_critic = ActorCriticWrapper(self.actor_network, self.value_network) - target_params = list(self.target_actor_network_params) + list( - 
self.target_value_network_params - ) - target_buffers = list(self.target_actor_network_buffers) + list( - self.target_value_network_buffers + actor_critic = self.actor_critic + target_params = TensorDict( + { + "module": { + "0": self.target_actor_network_params, + "1": self.target_value_network_params, + } + }, + batch_size=self.target_actor_network_params.batch_size, + device=self.target_actor_network_params.device, ) with set_exploration_mode("mode"): target_value = next_state_value( @@ -150,7 +165,6 @@ def _loss_value( actor_critic, gamma=self.gamma, params=target_params, - buffers=target_buffers, ) # td_error = pred_val - target_value diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 555525161a4..b80cf2854ff 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -4,10 +4,10 @@ import numpy as np import torch + from tensordict import TensorDict from tensordict.tensordict import TensorDictBase from torch import Tensor - from torchrl.envs.utils import set_exploration_mode, step_mdp from torchrl.modules import SafeModule from torchrl.objectives import ( @@ -17,6 +17,15 @@ ) from torchrl.objectives.common import LossModule +try: + from functorch import vmap + + FUNCTORCH_ERR = "" + _has_functorch = True +except ImportError as err: + FUNCTORCH_ERR = str(err) + _has_functorch = False + class REDQLoss_deprecated(LossModule): """REDQ Loss module. @@ -66,6 +75,10 @@ def __init__( delay_qvalue: bool = True, gSDE: bool = False, ): + if not _has_functorch: + raise ImportError( + f"Failed to import functorch with error message:\n{FUNCTORCH_ERR}" + ) super().__init__() self.convert_to_functional( actor_network, @@ -163,18 +176,12 @@ def _actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: self.actor_network( tensordict_clone, params=self.actor_network_params, - buffers=self.actor_network_buffers, ) with hold_out_params(self.qvalue_network_params) as params: - tensordict_expand = self.qvalue_network( + tensordict_expand = vmap(self.qvalue_network, (None, 0))( tensordict_clone.select(*self.qvalue_network.in_keys), - tensordict_out=TensorDict( - {}, [self.num_qvalue_nets, *tensordict_clone.shape] - ), - params=params, - buffers=self.qvalue_network_buffers, - vmap=True, + params, ) state_action_value = tensordict_expand.get("state_action_value").squeeze(-1) loss_actor = -( @@ -193,12 +200,7 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor: : self.sub_sample_len ].sort()[0] with torch.no_grad(): - selected_q_params = [ - p[selected_models_idx] for p in self.target_qvalue_network_params - ] - selected_q_buffers = [ - b[selected_models_idx] for b in self.target_qvalue_network_buffers - ] + selected_q_params = self.target_qvalue_network_params[selected_models_idx] next_td = step_mdp(tensordict).select( *self.actor_network.in_keys @@ -208,17 +210,13 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor: with set_exploration_mode("random"): self.actor_network( next_td, - params=list(self.target_actor_network_params), - buffers=self.target_actor_network_buffers, + params=self.target_actor_network_params, ) sample_log_prob = next_td.get("sample_log_prob") # get q-values - next_td = self.qvalue_network( + next_td = vmap(self.qvalue_network, (None, 0))( next_td, - tensordict_out=TensorDict({}, [self.sub_sample_len, *next_td.shape]), - params=selected_q_params, - buffers=selected_q_buffers, - vmap=True, + selected_q_params, ) state_value = ( next_td.get("state_action_value") - self.alpha * 
sample_log_prob @@ -231,12 +229,9 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor: gamma=self.gamma, pred_next_val=state_value, ) - tensordict_expand = self.qvalue_network( + tensordict_expand = vmap(self.qvalue_network, (None, 0))( tensordict.select(*self.qvalue_network.in_keys), - tensordict_out=TensorDict({}, [self.num_qvalue_nets, *tensordict.shape]), - params=list(self.qvalue_network_params), - buffers=self.qvalue_network_buffers, - vmap=True, + self.qvalue_network_params, ) pred_val = tensordict_expand.get("state_action_value").squeeze(-1) td_error = abs(pred_val - target_value) diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 99b82f1404a..444616ac364 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -91,10 +91,10 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: td_copy = tensordict.clone() if td_copy.device != tensordict.device: raise RuntimeError(f"{tensordict} and {td_copy} have different devices") + assert hasattr(self.value_network, "_is_stateless") self.value_network( td_copy, params=self.value_network_params, - buffers=self.value_network_buffers, ) action = tensordict.get("action") @@ -112,7 +112,6 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: self.value_network, gamma=self.gamma, params=self.target_value_network_params, - buffers=self.target_value_network_buffers, next_val_key="chosen_action_value", ) priority_tensor = (pred_val_index - target_value).pow(2) @@ -201,7 +200,7 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: "tensordict as input" ) batch_size = tensordict.batch_size[0] - support = self.value_network.support + support = self.value_network_params["support"] atoms = support.numel() Vmin = support.min().item() Vmax = support.max().item() @@ -220,7 +219,6 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: self.value_network( td_clone, params=self.value_network_params, - buffers=self.value_network_buffers, ) # Log probabilities log p(s_t, ·; θonline) action_log_softmax = td_clone.get("action_value") @@ -237,7 +235,6 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: self.value_network( next_td, params=self.value_network_params, - buffers=self.value_network_buffers, ) # Probabilities p(s_t+n, ·; θonline) next_td_action = next_td.get("action") @@ -249,7 +246,6 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: self.value_network( next_td, params=self.target_value_network_params, - buffers=self.target_value_network_buffers, ) # Probabilities p(s_t+n, ·; θtarget) pns = next_td.get("action_value").exp() # Double-Q probabilities diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 2926e2c667a..130df12fd3d 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -65,7 +65,9 @@ def __init__( advantage_module: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, ): super().__init__() - self.convert_to_functional(actor, "actor") + self.convert_to_functional( + actor, "actor", funs_to_decorate=["forward", "get_dist"] + ) # we want to make sure there are no duplicates in the params: the # params of critic must be refs to actor if they're shared self.convert_to_functional(critic, "critic", compare_against=self.actor_params) @@ -106,7 +108,8 @@ def _log_weight( tensordict_clone = tensordict.select(*self.actor.in_keys).clone() dist, *_ = self.actor.get_dist( - tensordict_clone, params=self.actor_params, buffers=self.actor_buffers + tensordict_clone, + 
params=self.actor_params, ) log_prob = dist.log_prob(action) log_prob = log_prob.unsqueeze(-1) @@ -136,7 +139,6 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: value = self.critic( tensordict_select, params=self.critic_params, - buffers=self.critic_buffers, ).get("state_value") value_target = advantage + value.detach() loss_value = distance_loss( @@ -363,7 +365,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: previous_dist = self.actor.build_dist_from_params(tensordict_clone) current_dist, *_ = self.actor.get_dist( - tensordict_clone, params=self.actor_params, buffers=self.actor_buffers + tensordict_clone, + params=self.actor_params, ) try: kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist) diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index 70b3d6a3e8d..d8b28bc677b 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -9,18 +9,28 @@ import numpy as np import torch + +from tensordict.nn import TensorDictSequential from tensordict.tensordict import TensorDict, TensorDictBase from torch import Tensor from torchrl.envs.utils import set_exploration_mode, step_mdp from torchrl.modules import SafeModule -from torchrl.objectives.common import _has_functorch, LossModule +from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( distance_loss, - hold_out_params, next_state_value as get_next_state_value, ) +try: + from functorch import vmap + + FUNCTORCH_ERR = "" + _has_functorch = True +except ImportError as err: + FUNCTORCH_ERR = str(err) + _has_functorch = False + class REDQLoss(LossModule): """REDQ Loss module. @@ -75,13 +85,16 @@ def __init__( gSDE: bool = False, ): if not _has_functorch: - raise ImportError("REDQ requires functorch to be installed.") + raise ImportError( + f"Failed to import functorch with error message:\n{FUNCTORCH_ERR}" + ) super().__init__() self.convert_to_functional( actor_network, "actor_network", create_target_params=self.delay_actor, + funs_to_decorate=["forward", "get_dist_params"], ) # let's make sure that actor_network has `return_log_prob` set to True @@ -152,25 +165,11 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: selected_models_idx = torch.randperm(self.num_qvalue_nets)[ : self.sub_sample_len ].sort()[0] - selected_q_params = [ - p[selected_models_idx] for p in self.target_qvalue_network_params - ] - selected_q_buffers = [ - b[selected_models_idx] for b in self.target_qvalue_network_buffers - ] - - actor_params = [ - torch.stack([p1, p2], 0) - for p1, p2 in zip( - self.actor_network_params, self.target_actor_network_params - ) - ] - actor_buffers = [ - torch.stack([p1, p2], 0) - for p1, p2 in zip( - self.actor_network_buffers, self.target_actor_network_buffers - ) - ] + selected_q_params = self.target_qvalue_network_params[selected_models_idx] + + actor_params = torch.stack( + [self.actor_network_params, self.target_actor_network_params], 0 + ) tensordict_actor_grad = tensordict_select.select( *obs_keys @@ -187,60 +186,58 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: "_eps_gSDE", torch.zeros(tensordict_actor.shape, device=tensordict_actor.device), ) - tensordict_actor = self.actor_network( + # vmap doesn't support sampling, so we take it out from the vmap + td_params = vmap(self.actor_network.get_dist_params)( tensordict_actor, - params=actor_params, - buffers=actor_buffers, - vmap=(0, 0, 0), + actor_params, + ) + if isinstance(self.actor_network, TensorDictSequential): + sample_key = 
self.actor_network[-1].sample_out_key[0] + tensordict_actor_dist = self.actor_network[-1].build_dist_from_params( + td_params + ) + else: + sample_key = self.actor_network.sample_out_key[0] + tensordict_actor_dist = self.actor_network.build_dist_from_params( + td_params + ) + tensordict_actor[sample_key] = tensordict_actor_dist.rsample() + tensordict_actor["sample_log_prob"] = tensordict_actor_dist.log_prob( + tensordict_actor[sample_key] ) # repeat tensordict_actor to match the qvalue size + _actor_loss_td = ( + tensordict_actor[0] + .select(*self.qvalue_network.in_keys) + .expand(self.num_qvalue_nets, *tensordict_actor[0].batch_size) + ) # for actor loss + _qval_td = tensordict_select.select(*self.qvalue_network.in_keys).expand( + self.num_qvalue_nets, + *tensordict_select.select(*self.qvalue_network.in_keys).batch_size, + ) # for qvalue loss + _next_val_td = ( + tensordict_actor[1] + .select(*self.qvalue_network.in_keys) + .expand(self.sub_sample_len, *tensordict_actor[1].batch_size) + ) # for next value estimation tensordict_qval = torch.cat( [ - tensordict_actor[0] - .select(*self.qvalue_network.in_keys) - .expand( - self.num_qvalue_nets, *tensordict_actor[0].batch_size - ), # for actor loss - tensordict_actor[1] - .select(*self.qvalue_network.in_keys) - .expand( - self.sub_sample_len, *tensordict_actor[1].batch_size - ), # for next value estimation - tensordict_select.select(*self.qvalue_network.in_keys).expand( - self.num_qvalue_nets, - *tensordict_select.select(*self.qvalue_network.in_keys).batch_size, - ), # for qvalue loss + _actor_loss_td, + _next_val_td, + _qval_td, ], 0, ) # cat params - q_params_detach = hold_out_params(self.qvalue_network_params).params - qvalue_params = [ - torch.cat([p1, p2, p3], 0) - for p1, p2, p3 in zip( - q_params_detach, selected_q_params, self.qvalue_network_params - ) - ] - qvalue_buffers = [ - torch.cat([p1, p2, p3], 0) - for p1, p2, p3 in zip( - self.qvalue_network_buffers, - selected_q_buffers, - self.qvalue_network_buffers, - ) - ] - tensordict_qval = self.qvalue_network( + q_params_detach = self.qvalue_network_params.detach() + qvalue_params = torch.cat( + [q_params_detach, selected_q_params, self.qvalue_network_params], 0 + ) + tensordict_qval = vmap(self.qvalue_network)( tensordict_qval, - tensordict_out=TensorDict({}, tensordict_qval.shape), - params=qvalue_params, - buffers=qvalue_buffers, - vmap=( - 0, - 0, - 0, - ), # TensorDict vmap will take care of expanding the tuple as needed + qvalue_params, ) state_action_value = tensordict_qval.get("state_action_value").squeeze(-1) diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 294f79c50ec..719710de5b6 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -67,9 +67,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = self.advantage_module( tensordict, params=self.critic_params, - buffers=self.critic_buffers, target_params=self.target_critic_params, - target_buffers=self.target_critic_buffers, ) advantage = tensordict.get(self.advantage_key) @@ -77,7 +75,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = self.actor_network( tensordict, params=self.actor_network_params, - buffers=self.actor_network_buffers, ) log_prob = tensordict.get("sample_log_prob") @@ -108,14 +105,12 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: next_value = self.critic( next_td, params=self.critic_params, - buffers=self.critic_buffers, ).get("state_value") value_target = 
reward + next_value * self.gamma tensordict_select = tensordict.select(*self.critic.in_keys).clone() value = self.critic( tensordict_select, params=self.critic_params, - buffers=self.critic_buffers, ).get("state_value") loss_value = distance_loss( diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index bfc5e088813..79ed9cee11c 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -9,6 +9,7 @@ import numpy as np import torch +from tensordict.nn import make_functional from tensordict.tensordict import TensorDict, TensorDictBase from torch import Tensor @@ -19,6 +20,15 @@ from ..envs.utils import set_exploration_mode from .common import LossModule +try: + from functorch import vmap + + _has_functorch = True + err = "" +except ImportError as err: + _has_functorch = False + FUNCTORCH_ERROR = str(err) + class SACLoss(LossModule): """TorchRL implementation of the SAC loss. @@ -83,6 +93,10 @@ def __init__( delay_qvalue: bool = False, delay_value: bool = False, ) -> None: + if not _has_functorch: + raise ImportError( + f"Failed to import functorch with error message:\n{FUNCTORCH_ERROR}" + ) super().__init__() # Actor @@ -91,6 +105,10 @@ def __init__( actor_network, "actor_network", create_target_params=self.delay_actor, + funs_to_decorate=[ + "forward", + "get_dist", + ], ) # Value @@ -150,14 +168,12 @@ def __init__( self.register_buffer( "target_entropy", torch.tensor(target_entropy, device=device) ) + self.actor_critic = ActorCriticWrapper(self.actor_network, self.value_network) + make_functional(self.actor_critic) @property def device(self) -> torch.device: - for p in self.actor_network_params: - return p.device - for p in self.qvalue_network_params: - return p.device - for p in self.value_network_params: + for p in self.parameters(): return p.device raise RuntimeError( "At least one of the networks of SACLoss must have trainable " "parameters." 
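Both the DDPG hunk above and the SAC `_loss_qvalue` hunk below call the functional `ActorCriticWrapper` with a nested params TensorDict whose `("module", "0")` and `("module", "1")` entries hold the parameters of the wrapped actor and value networks. A minimal sketch of how that structure is assembled (the parameter contents are stand-ins, not the real network params):

    import torch
    from tensordict import TensorDict

    # stand-ins for the target parameter TensorDicts held by the loss module
    target_actor_params = TensorDict({"weight": torch.randn(4, 4), "bias": torch.randn(4)}, [])
    target_value_params = TensorDict({"weight": torch.randn(1, 8), "bias": torch.randn(1)}, [])

    target_params = TensorDict(
        {
            "module": {
                "0": target_actor_params,  # parameters of actor_critic.module[0]
                "1": target_value_params,  # parameters of actor_critic.module[1]
            }
        },
        batch_size=target_actor_params.batch_size,
    )
    # the wrapper can then be run statelessly: actor_critic(td, params=target_params)
    print(target_params["module", "0"])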
@@ -198,8 +214,7 @@ def _loss_actor(self, tensordict: TensorDictBase) -> Tensor: with set_exploration_mode("random"): dist = self.actor_network.get_dist( tensordict, - params=list(self.actor_network_params), - buffers=list(self.actor_network_buffers), + params=self.actor_network_params, )[0] a_reparm = dist.rsample() # if not self.actor_network.spec.is_in(a_reparm): @@ -208,11 +223,8 @@ def _loss_actor(self, tensordict: TensorDictBase) -> Tensor: td_q = tensordict.select(*self.qvalue_network.in_keys) td_q.set("action", a_reparm) - td_q = self.qvalue_network( - td_q, - params=list(self.target_qvalue_network_params), - buffers=list(self.qvalue_network_buffers), - vmap=True, + td_q = vmap(self.qvalue_network, (None, 0))( + td_q, self.target_qvalue_network_params ) min_q_logprob = td_q.get("state_action_value").min(0)[0].squeeze(-1) @@ -226,12 +238,16 @@ def _loss_actor(self, tensordict: TensorDictBase) -> Tensor: return self._alpha * log_prob - min_q_logprob def _loss_qvalue(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: - actor_critic = ActorCriticWrapper(self.actor_network, self.value_network) - params = list(self.target_actor_network_params) + list( - self.target_value_network_params - ) - buffers = list(self.target_actor_network_buffers) + list( - self.target_value_network_buffers + actor_critic = self.actor_critic + params = TensorDict( + { + "module": { + "0": self.target_actor_network_params, + "1": self.target_value_network_params, + } + }, + [], + _run_checks=False, ) with set_exploration_mode("mode"): target_value = next_state_value( @@ -240,7 +256,6 @@ def _loss_qvalue(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: gamma=self.gamma, next_val_key="state_value", params=params, - buffers=buffers, ) # value loss @@ -260,16 +275,8 @@ def _loss_qvalue(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: target_chunks = torch.stack(target_value.chunk(self.num_qvalue_nets, dim=0), 0) # if vmap=True, it is assumed that the input tensordict must be cast to the param shape - tensordict_chunks = qvalue_network( - tensordict_chunks, - params=list(self.qvalue_network_params), - buffers=list(self.qvalue_network_buffers), - vmap=( - 0, - 0, - 0, - 0, - ), + tensordict_chunks = vmap(qvalue_network)( + tensordict_chunks, self.qvalue_network_params ) pred_val = tensordict_chunks.get("state_action_value").squeeze(-1) loss_value = distance_loss( @@ -284,18 +291,14 @@ def _loss_value(self, tensordict: TensorDictBase) -> Tensor: td_copy = tensordict.select(*self.value_network.in_keys).detach() self.value_network( td_copy, - params=list(self.value_network_params), - buffers=list(self.value_network_buffers), + params=self.value_network_params, ) pred_val = td_copy.get("state_value").squeeze(-1) - action_dist = self.actor_network.get_dist( + action_dist, *_ = self.actor_network.get_dist( td_copy, - params=list(self.target_actor_network_params), - buffers=list(self.target_actor_network_buffers), - )[ - 0 - ] # resample an action + params=self.target_actor_network_params, + ) # resample an action action = action_dist.rsample() # if not self.actor_network.spec.is_in(action): # action.data.copy_(self.actor_network.spec.project(action.data)) @@ -303,11 +306,9 @@ def _loss_value(self, tensordict: TensorDictBase) -> Tensor: td_copy.set("action", action, inplace=False) qval_net = self.qvalue_network - td_copy = qval_net( + td_copy = vmap(qval_net, (None, 0))( td_copy, - params=list(self.target_qvalue_network_params), - buffers=list(self.target_qvalue_network_buffers), - vmap=True, + 
self.target_qvalue_network_params, ) min_qval = td_copy.get("state_action_value").squeeze(-1).min(0)[0] diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 4f2da57c93a..ec93e4135c1 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -4,11 +4,10 @@ # LICENSE file in the root directory of this source tree. import functools -from collections import OrderedDict from typing import Iterable, Optional, Union import torch -from tensordict.tensordict import TensorDictBase +from tensordict.tensordict import TensorDict, TensorDictBase from torch import nn, Tensor from torch.nn import functional as F @@ -97,7 +96,7 @@ def __init__( # for properties for name in loss_module.__class__.__dict__: if ( - name.startswith("_target_") + name.startswith("target_") and (name.endswith("params") or name.endswith("buffers")) and (getattr(loss_module, name) is not None) ): @@ -106,12 +105,12 @@ def __init__( # for regular lists: raise an exception for name in loss_module.__dict__: if ( - name.startswith("_target_") + name.startswith("target_") and (name.endswith("params") or name.endswith("buffers")) and (getattr(loss_module, name) is not None) ): raise RuntimeError( - "Your module seems to have a _target tensor list contained " + "Your module seems to have a target tensor list contained " "in a non-dynamic structure (such as a list). If the " "module is cast onto a device, the reference to these " "tensors will be lost." @@ -119,10 +118,10 @@ def __init__( if len(_target_names) == 0: raise RuntimeError( - "Did not found any target parameters or buffers in the loss module." + "Did not find any target parameters or buffers in the loss module." ) - _source_names = ["".join(name.split("_target_")) for name in _target_names] + _source_names = ["".join(name.split("target_")) for name in _target_names] for _source in _source_names: try: @@ -140,28 +139,28 @@ def __init__( @property def _targets(self): - return OrderedDict( - {name: getattr(self.loss_module, name) for name in self._target_names} + return TensorDict( + {name: getattr(self.loss_module, name) for name in self._target_names}, + [], ) @property def _sources(self): - return OrderedDict( - {name: getattr(self.loss_module, name) for name in self._source_names} + return TensorDict( + {name: getattr(self.loss_module, name) for name in self._source_names}, + [], ) def init_(self) -> None: - for source, target in zip(self._sources.values(), self._targets.values()): - if isinstance(source, TensorDictBase) and not source.is_empty(): - # native functional modules - source = list(zip(*sorted(source.items())))[1] - target = list(zip(*sorted(target.items())))[1] - elif isinstance(source, TensorDictBase) and source.is_empty(): - continue - for p_source, p_target in zip(source, target): - if p_target.requires_grad: - raise RuntimeError("the target parameter is part of a graph.") - p_target.data.copy_(p_source.data) + for key, source in self._sources.items(True, True): + if not isinstance(key, tuple): + key = (key,) + key = ("target_" + key[0], *key[1:]) + target = self._targets[key] + # for p_source, p_target in zip(source, target): + if target.requires_grad: + raise RuntimeError("the target parameter is part of a graph.") + target.data.copy_(source.data) self.initialized = True def step(self) -> None: @@ -170,29 +169,25 @@ def step(self) -> None: f"{self.__class__.__name__} must be " f"initialized (`{self.__class__.__name__}.init_()`) before calling step()" ) - - for source, target in zip(self._sources.values(), 
self._targets.values()): - if isinstance(source, TensorDictBase) and not source.is_empty(): - # native functional modules - source = list(zip(*sorted(source.items())))[1] - target = list(zip(*sorted(target.items())))[1] - elif isinstance(source, TensorDictBase) and source.is_empty(): - continue - for p_source, p_target in zip(source, target): - if p_target.requires_grad: - raise RuntimeError("the target parameter is part of a graph.") - if p_source.is_leaf: - self._step(p_source, p_target) - else: - p_target.copy_(p_source) + for key, source in self._sources.items(True, True): + if not isinstance(key, tuple): + key = (key,) + key = ("target_" + key[0], *key[1:]) + target = self._targets[key] + if target.requires_grad: + raise RuntimeError("the target parameter is part of a graph.") + if target.is_leaf: + self._step(source, target) + else: + target.copy_(source) def _step(self, p_source: Tensor, p_target: Tensor) -> None: raise NotImplementedError def __repr__(self) -> str: string = ( - f"{self.__class__.__name__}(sources={list(self._sources)}, targets=" - f"{list(self._targets)})" + f"{self.__class__.__name__}(sources={self._sources}, targets=" + f"{self._targets})" ) return string @@ -281,7 +276,10 @@ class hold_out_params(_context_manager): """Context manager to hold a list of parameters out of a computational graph.""" def __init__(self, params: Iterable[Tensor]) -> None: - self.params = tuple(p.detach() for p in params) + if isinstance(params, TensorDictBase): + self.params = params.detach() + else: + self.params = tuple(p.detach() for p in params) def __enter__(self) -> None: return self.params diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 6ee6ef3503b..4f558feecef 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -46,20 +46,24 @@ def __init__( super().__init__() self.register_buffer("gamma", torch.tensor(gamma)) self.value_network = value_network - self.is_functional = value_network.is_functional self.average_rewards = average_rewards self.gradient_mode = gradient_mode self.value_key = value_key + @property + def is_functional(self): + return ( + "_is_stateless" in self.value_network.__dict__ + and self.value_network.__dict__["_is_stateless"] + ) + def forward( self, tensordict: TensorDictBase, *unused_args, params: Optional[List[Tensor]] = None, - buffers: Optional[List[Tensor]] = None, target_params: Optional[List[Tensor]] = None, - target_buffers: Optional[List[Tensor]] = None, ) -> TensorDictBase: """Computes the GAE given the data in tensordict. 
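The new `is_functional` property (like the `assert hasattr(self.value_network, "_is_stateless")` added to the DQN loss above) assumes that `make_functional` leaves an `_is_stateless` flag on the module; a small illustrative check of that assumption (module and keys are placeholders):

    import torch
    from tensordict.nn.functional_modules import make_functional
    from torchrl.modules import SafeModule

    value_net = SafeModule(torch.nn.Linear(3, 1), in_keys=["observation"], out_keys=["state_value"])
    print(value_net.__dict__.get("_is_stateless", False))  # expected: False while the module is stateful
    params = make_functional(value_net)
    print(value_net.__dict__.get("_is_stateless", False))  # expected: True once made functional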
@@ -93,8 +97,6 @@ def forward( ) if params is not None: kwargs["params"] = params - if buffers is not None: - kwargs["buffers"] = buffers self.value_network(tensordict, **kwargs) value = tensordict.get(self.value_key) @@ -107,10 +109,6 @@ def forward( kwargs["params"] = target_params elif "params" in kwargs: kwargs["params"] = [param.detach() for param in kwargs["params"]] - if target_buffers is not None: - kwargs["buffers"] = target_buffers - elif "buffers" in kwargs: - kwargs["buffers"] = [buffer.detach() for buffer in kwargs["buffers"]] self.value_network(step_td, **kwargs) next_value = step_td.get(self.value_key) @@ -154,21 +152,25 @@ def __init__( self.register_buffer("gamma", torch.tensor(gamma)) self.register_buffer("lmbda", torch.tensor(lmbda)) self.value_network = value_network - self.is_functional = value_network.is_functional self.vectorized = vectorized self.average_rewards = average_rewards self.gradient_mode = gradient_mode self.value_key = value_key + @property + def is_functional(self): + return ( + "_is_stateless" in self.value_network.__dict__ + and self.value_network.__dict__["_is_stateless"] + ) + def forward( self, tensordict: TensorDictBase, *unused_args, params: Optional[List[Tensor]] = None, - buffers: Optional[List[Tensor]] = None, target_params: Optional[List[Tensor]] = None, - target_buffers: Optional[List[Tensor]] = None, ) -> TensorDictBase: """Computes the GAE given the data in tensordict. @@ -204,8 +206,6 @@ def forward( ) if params is not None: kwargs["params"] = params - if buffers is not None: - kwargs["buffers"] = buffers self.value_network(tensordict, **kwargs) value = tensordict.get(self.value_key) @@ -218,10 +218,6 @@ def forward( kwargs["params"] = target_params elif "params" in kwargs: kwargs["params"] = [param.detach() for param in kwargs["params"]] - if target_buffers is not None: - kwargs["buffers"] = target_buffers - elif "buffers" in kwargs: - kwargs["buffers"] = [buffer.detach() for buffer in kwargs["buffers"]] self.value_network(step_td, **kwargs) next_value = step_td.get(self.value_key) @@ -270,19 +266,23 @@ def __init__( self.register_buffer("gamma", torch.tensor(gamma)) self.register_buffer("lmbda", torch.tensor(lmbda)) self.value_network = value_network - self.is_functional = value_network.is_functional self.average_rewards = average_rewards self.gradient_mode = gradient_mode + @property + def is_functional(self): + return ( + "_is_stateless" in self.value_network.__dict__ + and self.value_network.__dict__["_is_stateless"] + ) + def forward( self, tensordict: TensorDictBase, *unused_args, params: Optional[List[Tensor]] = None, - buffers: Optional[List[Tensor]] = None, target_params: Optional[List[Tensor]] = None, - target_buffers: Optional[List[Tensor]] = None, ) -> TensorDictBase: """Computes the GAE given the data in tensordict. @@ -316,8 +316,6 @@ def forward( ) if params is not None: kwargs["params"] = params - if buffers is not None: - kwargs["buffers"] = buffers self.value_network(tensordict, **kwargs) value = tensordict.get("state_value") @@ -330,10 +328,6 @@ def forward( kwargs["params"] = target_params elif "params" in kwargs: kwargs["params"] = [param.detach() for param in kwargs["params"]] - if target_buffers is not None: - kwargs["buffers"] = target_buffers - elif "buffers" in kwargs: - kwargs["buffers"] = [buffer.detach() for buffer in kwargs["buffers"]] self.value_network(step_td, **kwargs) next_value = step_td.get("state_value") done = tensordict.get("done")
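When no explicit `target_params` are given, the estimators fall back on detached copies of the online parameters. With the TensorDict-based functional calls used throughout this PR, the same effect can be obtained with a single `detach()` on the params TensorDict; a hedged sketch (module and key names are illustrative):

    import torch
    from tensordict import TensorDict
    from tensordict.nn.functional_modules import make_functional
    from torchrl.modules import ValueOperator

    value_net = ValueOperator(torch.nn.Linear(3, 1), in_keys=["observation"])
    params = make_functional(value_net)

    td = TensorDict({"observation": torch.randn(10, 3)}, [10])
    value_net(td, params=params)           # online value estimate
    value_net(td, params=params.detach())  # target-style call: gradients are cut at the params
    print(td["state_value"].shape)         # torch.Size([10, 1])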