From 218d5bf70b9eb52cbead567f70021ec35eafaede Mon Sep 17 00:00:00 2001
From: kurtamohler
Date: Sat, 9 Nov 2024 13:40:00 -0800
Subject: [PATCH] [Feature] Add `LossModule.reset_parameters_recursive` (#2546)

`LossModule.reset_parameters_recursive` re-initializes, in place, the
parameters of every network wrapped by a loss module, including the
detached target parameters when they exist. A shared test exercises the
new method for each loss class, and `LossModuleTestBase.__init_subclass__`
enforces that every loss test class defines such a test.

---
 test/test_cost.py            | 217 +++++++++++++++++++++++++++++++++++
 torchrl/objectives/common.py |  24 ++++
 2 files changed, 241 insertions(+)

diff --git a/test/test_cost.py b/test/test_cost.py
index 0b36f5b8961..1b54b8bf111 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -195,6 +195,12 @@ def get_devices():
 
 
 class LossModuleTestBase:
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        assert hasattr(
+            cls, "test_reset_parameters_recursive"
+        ), "Please add a test_reset_parameters_recursive test for this class"
+
     def _flatten_in_keys(self, in_keys):
         return [
             in_key if isinstance(in_key, str) else "_".join(list(unravel_keys(in_key)))
@@ -252,6 +258,42 @@ def set_advantage_keys_through_loss_test(
             getattr(test_fn.value_estimator.tensor_keys, advantage_key) == new_key
         )
 
+    @classmethod
+    def reset_parameters_recursive_test(cls, loss_fn):
+        def get_params(loss_fn):
+            for key, item in loss_fn.__dict__.items():
+                if isinstance(item, nn.Module):
+                    module_name = key
+                    params_name = f"{module_name}_params"
+                    target_name = f"target_{module_name}_params"
+                    params = loss_fn._modules.get(params_name, None)
+                    target = loss_fn._modules.get(target_name, None)
+
+                    if params is not None:
+                        yield params_name, params._param_td
+
+                    else:
+                        for subparam_name, subparam in loss_fn.named_parameters():
+                            if module_name in subparam_name:
+                                yield subparam_name, subparam
+
+                    if target is not None:
+                        yield target_name, target
+
+        old_params = {}
+
+        for param_name, param in get_params(loss_fn):
+            with torch.no_grad():
+                # Change the parameter to ensure that reset will change it again
+                param += 1000
+            old_params[param_name] = param.clone()
+
+        loss_fn.reset_parameters_recursive()
+
+        for param_name, param in get_params(loss_fn):
+            old_param = old_params[param_name]
+            assert (param != old_param).any()
+
 
 @pytest.mark.parametrize("device", get_default_devices())
 @pytest.mark.parametrize("vmap_randomness", (None, "different", "same", "error"))
@@ -494,6 +536,11 @@ def _create_seq_mock_data_dqn(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor(action_spec_type="one_hot")
+        loss_fn = DQNLoss(actor)
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize(
         "delay_value,double_dqn", ([False, False], [True, False], [True, True])
     )
@@ -1066,6 +1113,12 @@ def _create_mock_data_dqn(
         td.refine_names(None, "time")
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor(action_spec_type="one_hot")
+        mixer = self._create_mock_mixer()
+        loss_fn = QMixerLoss(actor, mixer)
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("delay_value", (False, True))
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
@@ -1570,6 +1623,12 @@ def _create_seq_mock_data_ddpg(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        value = self._create_mock_value()
+        loss_fn = DDPGLoss(actor, value)
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("delay_actor,delay_value", [(False, False), (True, True)])
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@@ -2210,6 +2269,16 @@ def _create_seq_mock_data_td3(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        value = self._create_mock_value()
+        loss_fn = TD3Loss(
+            actor,
+            value,
+            bounds=(-1, 1),
+        )
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.skipif(not _has_functorch, reason="functorch not installed")
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize(
@@ -2916,6 +2985,16 @@ def _create_seq_mock_data_td3bc(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        value = self._create_mock_value()
+        loss_fn = TD3BCLoss(
+            actor,
+            value,
+            bounds=(-1, 1),
+        )
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.skipif(not _has_functorch, reason="functorch not installed")
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize(
@@ -3720,6 +3799,20 @@ def _create_seq_mock_data_sac(
         )
         return td
 
+    def test_reset_parameters_recursive(self, version):
+        actor = self._create_mock_actor()
+        qvalue = self._create_mock_qvalue()
+        if version == 1:
+            value = self._create_mock_value()
+        else:
+            value = None
+        loss_fn = SACLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            value_network=value,
+        )
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("delay_value", (True, False))
     @pytest.mark.parametrize("delay_actor", (True, False))
     @pytest.mark.parametrize("delay_qvalue", (True, False))
@@ -4591,6 +4684,17 @@ def _create_seq_mock_data_sac(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        qvalue = self._create_mock_qvalue()
+        loss_fn = DiscreteSACLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            num_actions=actor.spec["action"].space.n,
+            action_space="one-hot",
+        )
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("delay_qvalue", (True, False))
     @pytest.mark.parametrize("num_qvalue", [2])
     @pytest.mark.parametrize("device", get_default_devices())
@@ -5227,6 +5331,15 @@ def _create_seq_mock_data_crossq(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        qvalue = self._create_mock_qvalue()
+        loss_fn = CrossQLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+        )
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@@ -5962,6 +6075,15 @@ def _create_seq_mock_data_redq(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        qvalue = self._create_mock_qvalue()
+        loss_fn = REDQLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+        )
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("delay_qvalue", (True, False))
     @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
     @pytest.mark.parametrize("device", get_default_devices())
@@ -6792,6 +6914,15 @@ def _create_seq_mock_data_cql(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        qvalue = self._create_mock_qvalue()
+        loss_fn = CQLLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+        )
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("delay_actor", (True, False))
     @pytest.mark.parametrize("delay_qvalue", (True, True))
     @pytest.mark.parametrize("max_q_backup", [True, False])
@@ -7367,6 +7498,13 @@ def _create_seq_mock_data_dcql(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor(
+            action_spec_type="one_hot",
+        )
+        loss_fn = DiscreteCQLLoss(actor)
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("delay_value", (False, True))
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
@@ -7938,6 +8076,13 @@ def _create_seq_mock_data_ppo(
 
         return td
 
+    @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
+    def test_reset_parameters_recursive(self, loss_class):
+        actor = self._create_mock_actor()
+        value = self._create_mock_value()
+        loss_fn = loss_class(actor, value)
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
     @pytest.mark.parametrize("gradient_mode", (True, False))
     @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@@ -9016,6 +9161,12 @@ def _create_seq_mock_data_a2c(
         td["scale"] = scale
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        value = self._create_mock_value()
+        loss_fn = A2CLoss(actor, value)
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("gradient_mode", (True, False))
     @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
     @pytest.mark.parametrize("device", get_default_devices())
@@ -9624,6 +9775,27 @@ def test_a2c_value_clipping(self, clip_value, device, composite_action_dist):
 class TestReinforce(LossModuleTestBase):
     seed = 0
 
+    def test_reset_parameters_recursive(self):
+        n_obs = 3
+        n_act = 5
+        value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])
+        net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
+        module = TensorDictModule(
+            net, in_keys=["observation"], out_keys=["loc", "scale"]
+        )
+        actor_net = ProbabilisticActor(
+            module,
+            distribution_class=TanhNormal,
+            return_log_prob=True,
+            in_keys=["loc", "scale"],
+            spec=Unbounded(n_act),
+        )
+        loss_fn = ReinforceLoss(
+            actor_net,
+            critic_network=value_net,
+        )
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("gradient_mode", [True, False])
     @pytest.mark.parametrize("advantage", ["gae", "td", "td_lambda", None])
     @pytest.mark.parametrize(
@@ -10323,6 +10495,11 @@ def _create_value_model(self, rssm_hidden_dim, state_dim, mlp_num_units=13):
         value_model(td)
         return value_model
 
+    def test_reset_parameters_recursive(self, device):
+        world_model = self._create_world_model_model(10, 5).to(device)
+        loss_fn = DreamerModelLoss(world_model)
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("lambda_kl", [0, 1.0])
     @pytest.mark.parametrize("lambda_reco", [0, 1.0])
     @pytest.mark.parametrize("lambda_reward", [0, 1.0])
@@ -10604,6 +10781,11 @@ def _create_seq_mock_data_odt(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        loss_fn = OnlineDTLoss(actor)
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("device", get_available_devices())
     def test_odt(self, device):
         torch.manual_seed(self.seed)
@@ -10831,6 +11013,11 @@ def _create_seq_mock_data_dt(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        loss_fn = DTLoss(actor)
+        self.reset_parameters_recursive_test(loss_fn)
+
     def test_dt_tensordict_keys(self):
         actor = self._create_mock_actor()
         loss_fn = DTLoss(actor)
@@ -11034,6 +11221,11 @@ def _create_seq_mock_data_gail(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        discriminator = self._create_mock_discriminator()
+        loss_fn = GAILLoss(discriminator)
+        self.reset_parameters_recursive_test(loss_fn)
+
     def test_gail_tensordict_keys(self):
         discriminator = self._create_mock_discriminator()
         loss_fn = GAILLoss(discriminator)
@@ -11406,6 +11598,17 @@ def _create_seq_mock_data_iql(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        qvalue = self._create_mock_qvalue()
+        value = self._create_mock_value()
+        loss_fn = IQLLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            value_network=value,
+        )
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("temperature", [0.0, 0.1, 1.0, 10.0])
@@ -12214,6 +12417,18 @@ def _create_seq_mock_data_discrete_iql(
         )
         return td
 
+    def test_reset_parameters_recursive(self):
+        actor = self._create_mock_actor()
+        qvalue = self._create_mock_qvalue()
+        value = self._create_mock_value()
+        loss_fn = DiscreteIQLLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            value_network=value,
+            action_space="one-hot",
+        )
+        self.reset_parameters_recursive_test(loss_fn)
+
     @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("temperature", [0.0, 0.1, 1.0, 10.0])
@@ -12842,6 +13057,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
         )
         loss = MyLoss(actor_module)
 
+        LossModuleTestBase.reset_parameters_recursive_test(loss)
+
         if create_target_params:
             SoftUpdate(loss, eps=0.5)
 
diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py
index f6935ceae82..be05e2fa66b 100644
--- a/torchrl/objectives/common.py
+++ b/torchrl/objectives/common.py
@@ -471,6 +471,30 @@ def reset(self) -> None:
         # mainly used for PPO with KL target
         pass
 
+    def _reset_module_parameters(self, module_name, module):
+        params_name = f"{module_name}_params"
+        target_name = f"target_{module_name}_params"
+        params = self._modules.get(params_name, None)
+        target = self._modules.get(target_name, None)
+
+        if params is not None:
+            # Functional case: swap the functional parameters back into the
+            # module so that the reset re-initializes them in place.
+            with params.to_module(module):
+                module.reset_parameters_recursive()
+        else:
+            module.reset_parameters_recursive()
+
+        if target is not None:
+            # Target parameters are detached copies; reset them separately.
+            with target.to_module(module):
+                module.reset_parameters_recursive()
+
+    def reset_parameters_recursive(self) -> None:
+        """Reset the parameters of all the networks wrapped by this loss module.
+
+        This includes the detached target parameters, if any.
+        """
+        for key, item in self.__dict__.items():
+            if isinstance(item, nn.Module):
+                self._reset_module_parameters(key, item)
+
     @property
     def value_estimator(self) -> ValueEstimatorBase:
         """The value function blends in the reward and value estimate(s) from upcoming state(s)/state-action pair(s) into a target value estimate for the value network."""
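--
Usage note (illustrative, not part of the patch): a minimal sketch of how the
new method can be exercised, mirroring the DDPG test above. The toy `QValue`
module, the network shapes, and the final assertion are assumptions made for
the sake of the example; any networks accepted by `DDPGLoss` behave the same
way.

    import torch
    from torch import nn
    from torchrl.modules import Actor, ValueOperator
    from torchrl.objectives import DDPGLoss

    n_obs, n_act = 3, 1

    class QValue(nn.Module):
        """Toy state-action value network (illustrative only)."""

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(n_obs + n_act, 1)

        def forward(self, obs, act):
            return self.linear(torch.cat([obs, act], -1))

    actor = Actor(nn.Linear(n_obs, n_act), in_keys=["observation"])
    value = ValueOperator(QValue(), in_keys=["observation", "action"])
    loss_fn = DDPGLoss(actor, value)

    before = loss_fn.actor_network_params.clone()

    # Re-initialize every network wrapped by the loss in place, including
    # the detached target parameters, without rebuilding the loss module.
    loss_fn.reset_parameters_recursive()

    # At least some parameters should have changed after the reset.
    assert (loss_fn.actor_network_params != before).any()

Because the functional parameters are swapped back into each network before
resetting, the re-initialization lands on the tensors the loss actually
trains, and the detached target parameters are re-initialized separately.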