Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add LossModule.reset_parameters_recursive #2546

Merged
merged 1 commit into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 217 additions & 0 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,12 @@ def get_devices():


class LossModuleTestBase:
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
assert hasattr(
cls, "test_reset_parameters_recursive"
), "Please add a test_reset_parameters_recursive test for this class"

def _flatten_in_keys(self, in_keys):
return [
in_key if isinstance(in_key, str) else "_".join(list(unravel_keys(in_key)))
Expand Down Expand Up @@ -252,6 +258,42 @@ def set_advantage_keys_through_loss_test(
getattr(test_fn.value_estimator.tensor_keys, advantage_key) == new_key
)

@classmethod
def reset_parameters_recursive_test(cls, loss_fn):
def get_params(loss_fn):
for key, item in loss_fn.__dict__.items():
if isinstance(item, nn.Module):
module_name = key
params_name = f"{module_name}_params"
target_name = f"target_{module_name}_params"
params = loss_fn._modules.get(params_name, None)
target = loss_fn._modules.get(target_name, None)

if params is not None:
yield params_name, params._param_td

else:
for subparam_name, subparam in loss_fn.named_parameters():
if module_name in subparam_name:
yield subparam_name, subparam

if target is not None:
yield target_name, target

old_params = {}

for param_name, param in get_params(loss_fn):
with torch.no_grad():
# Change the parameter to ensure that reset will change it again
param += 1000
old_params[param_name] = param.clone()

loss_fn.reset_parameters_recursive()

for param_name, param in get_params(loss_fn):
old_param = old_params[param_name]
assert (param != old_param).any()


@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("vmap_randomness", (None, "different", "same", "error"))
Expand Down Expand Up @@ -494,6 +536,11 @@ def _create_seq_mock_data_dqn(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor(action_spec_type="one_hot")
loss_fn = DQNLoss(actor)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize(
"delay_value,double_dqn", ([False, False], [True, False], [True, True])
)
Expand Down Expand Up @@ -1066,6 +1113,12 @@ def _create_mock_data_dqn(
td.refine_names(None, "time")
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor(action_spec_type="one_hot")
mixer = self._create_mock_mixer()
loss_fn = QMixerLoss(actor, mixer)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("delay_value", (False, True))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
Expand Down Expand Up @@ -1570,6 +1623,12 @@ def _create_seq_mock_data_ddpg(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
value = self._create_mock_value()
loss_fn = DDPGLoss(actor, value)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("delay_actor,delay_value", [(False, False), (True, True)])
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
Expand Down Expand Up @@ -2210,6 +2269,16 @@ def _create_seq_mock_data_td3(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
value = self._create_mock_value()
loss_fn = TD3Loss(
actor,
value,
bounds=(-1, 1),
)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.skipif(not _has_functorch, reason="functorch not installed")
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize(
Expand Down Expand Up @@ -2916,6 +2985,16 @@ def _create_seq_mock_data_td3bc(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
value = self._create_mock_value()
loss_fn = TD3BCLoss(
actor,
value,
bounds=(-1, 1),
)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.skipif(not _has_functorch, reason="functorch not installed")
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize(
Expand Down Expand Up @@ -3720,6 +3799,20 @@ def _create_seq_mock_data_sac(
)
return td

def test_reset_parameters_recursive(self, version):
actor = self._create_mock_actor()
qvalue = self._create_mock_qvalue()
if version == 1:
value = self._create_mock_value()
else:
value = None
loss_fn = SACLoss(
actor_network=actor,
qvalue_network=qvalue,
value_network=value,
)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("delay_value", (True, False))
@pytest.mark.parametrize("delay_actor", (True, False))
@pytest.mark.parametrize("delay_qvalue", (True, False))
Expand Down Expand Up @@ -4591,6 +4684,17 @@ def _create_seq_mock_data_sac(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
qvalue = self._create_mock_qvalue()
loss_fn = DiscreteSACLoss(
actor_network=actor,
qvalue_network=qvalue,
num_actions=actor.spec["action"].space.n,
action_space="one-hot",
)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("delay_qvalue", (True, False))
@pytest.mark.parametrize("num_qvalue", [2])
@pytest.mark.parametrize("device", get_default_devices())
Expand Down Expand Up @@ -5227,6 +5331,15 @@ def _create_seq_mock_data_crossq(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
qvalue = self._create_mock_qvalue()
loss_fn = CrossQLoss(
actor_network=actor,
qvalue_network=qvalue,
)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
Expand Down Expand Up @@ -5962,6 +6075,15 @@ def _create_seq_mock_data_redq(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
qvalue = self._create_mock_qvalue()
loss_fn = REDQLoss(
actor_network=actor,
qvalue_network=qvalue,
)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("delay_qvalue", (True, False))
@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("device", get_default_devices())
Expand Down Expand Up @@ -6792,6 +6914,15 @@ def _create_seq_mock_data_cql(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
qvalue = self._create_mock_qvalue()
loss_fn = CQLLoss(
actor_network=actor,
qvalue_network=qvalue,
)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("delay_actor", (True, False))
@pytest.mark.parametrize("delay_qvalue", (True, True))
@pytest.mark.parametrize("max_q_backup", [True, False])
Expand Down Expand Up @@ -7367,6 +7498,13 @@ def _create_seq_mock_data_dcql(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor(
action_spec_type="one_hot",
)
loss_fn = DiscreteCQLLoss(actor)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("delay_value", (False, True))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
Expand Down Expand Up @@ -7938,6 +8076,13 @@ def _create_seq_mock_data_ppo(

return td

@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
def test_reset_parameters_recursive(self, loss_class):
actor = self._create_mock_actor()
value = self._create_mock_value()
loss_fn = loss_class(actor, value)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@pytest.mark.parametrize("gradient_mode", (True, False))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
Expand Down Expand Up @@ -9016,6 +9161,12 @@ def _create_seq_mock_data_a2c(
td["scale"] = scale
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
value = self._create_mock_value()
loss_fn = A2CLoss(actor, value)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("gradient_mode", (True, False))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
Expand Down Expand Up @@ -9624,6 +9775,27 @@ def test_a2c_value_clipping(self, clip_value, device, composite_action_dist):
class TestReinforce(LossModuleTestBase):
seed = 0

def test_reset_parameters_recursive(self):
n_obs = 3
n_act = 5
value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])
net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
module = TensorDictModule(
net, in_keys=["observation"], out_keys=["loc", "scale"]
)
actor_net = ProbabilisticActor(
module,
distribution_class=TanhNormal,
return_log_prob=True,
in_keys=["loc", "scale"],
spec=Unbounded(n_act),
)
loss_fn = ReinforceLoss(
actor_net,
critic_network=value_net,
)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("gradient_mode", [True, False])
@pytest.mark.parametrize("advantage", ["gae", "td", "td_lambda", None])
@pytest.mark.parametrize(
Expand Down Expand Up @@ -10323,6 +10495,11 @@ def _create_value_model(self, rssm_hidden_dim, state_dim, mlp_num_units=13):
value_model(td)
return value_model

def test_reset_parameters_recursive(self, device):
world_model = self._create_world_model_model(10, 5).to(device)
loss_fn = DreamerModelLoss(world_model)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("lambda_kl", [0, 1.0])
@pytest.mark.parametrize("lambda_reco", [0, 1.0])
@pytest.mark.parametrize("lambda_reward", [0, 1.0])
Expand Down Expand Up @@ -10604,6 +10781,11 @@ def _create_seq_mock_data_odt(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
loss_fn = OnlineDTLoss(actor)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("device", get_available_devices())
def test_odt(self, device):
torch.manual_seed(self.seed)
Expand Down Expand Up @@ -10831,6 +11013,11 @@ def _create_seq_mock_data_dt(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
loss_fn = DTLoss(actor)
self.reset_parameters_recursive_test(loss_fn)

def test_dt_tensordict_keys(self):
actor = self._create_mock_actor()
loss_fn = DTLoss(actor)
Expand Down Expand Up @@ -11034,6 +11221,11 @@ def _create_seq_mock_data_gail(
)
return td

def test_reset_parameters_recursive(self):
discriminator = self._create_mock_discriminator()
loss_fn = GAILLoss(discriminator)
self.reset_parameters_recursive_test(loss_fn)

def test_gail_tensordict_keys(self):
discriminator = self._create_mock_discriminator()
loss_fn = GAILLoss(discriminator)
Expand Down Expand Up @@ -11406,6 +11598,17 @@ def _create_seq_mock_data_iql(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
qvalue = self._create_mock_qvalue()
value = self._create_mock_value()
loss_fn = IQLLoss(
actor_network=actor,
qvalue_network=qvalue,
value_network=value,
)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("temperature", [0.0, 0.1, 1.0, 10.0])
Expand Down Expand Up @@ -12214,6 +12417,18 @@ def _create_seq_mock_data_discrete_iql(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
qvalue = self._create_mock_qvalue()
value = self._create_mock_value()
loss_fn = DiscreteIQLLoss(
actor_network=actor,
qvalue_network=qvalue,
value_network=value,
action_space="one-hot",
)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("temperature", [0.0, 0.1, 1.0, 10.0])
Expand Down Expand Up @@ -12842,6 +13057,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
)
loss = MyLoss(actor_module)

LossModuleTestBase.reset_parameters_recursive_test(loss)

if create_target_params:
SoftUpdate(loss, eps=0.5)

Expand Down
Loading
Loading