Make _LossModule robust to shared parameters (#185)
vmoens authored Jun 6, 2022
1 parent 4e4651b commit fc78766
Showing 16 changed files with 726 additions and 310 deletions.
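In short: _LossModule.convert_to_functional now deduplicates parameters that appear in several networks (e.g. an actor and a critic sharing a common trunk) through its compare_against argument. The sketch below condenses the test_shared_params test added in this diff; the construction, the convert_to_functional calls, and the assertions are lifted from that test, while the import paths (NormalParamWrapper, TanhNormal from torchrl.modules) are assumptions that may vary between versions:

    import torch
    from torch import nn

    from torchrl.modules import (
        ActorValueOperator,
        NormalParamWrapper,
        ProbabilisticActor,
        TanhNormal,
        TensorDictModule,
        ValueOperator,
    )
    from torchrl.objectives.costs.common import _LossModule

    # An actor and a value head sharing a common hidden trunk.
    td_module_hidden = TensorDictModule(
        module=nn.Linear(4, 4),
        spec=None,
        in_keys=["observation"],
        out_keys=["hidden"],
    )
    td_module_action = ProbabilisticActor(
        module=TensorDictModule(
            NormalParamWrapper(nn.Linear(4, 8)),
            in_keys=["hidden"],
            out_keys=["loc", "scale"],
        ),
        spec=None,
        dist_param_keys=["loc", "scale"],
        out_key_sample=["action"],
        distribution_class=TanhNormal,
        return_log_prob=True,
    )
    td_module_value = ValueOperator(module=nn.Linear(4, 1), in_keys=["hidden"])
    td_module = ActorValueOperator(td_module_hidden, td_module_action, td_module_value)


    class MyLoss(_LossModule):
        def __init__(self, actor_network, qvalue_network):
            super().__init__()
            self.convert_to_functional(
                actor_network,
                "actor_network",
                create_target_params=True,
            )
            # compare_against lists parameters the loss has already registered:
            # the shared trunk is not duplicated, and its entries in the qvalue
            # parameter list stay tied to the actor's parameters.
            self.convert_to_functional(
                qvalue_network,
                "qvalue_network",
                3,  # positional arg as in the test; presumably the expansion dim
                create_target_params=True,
                compare_against=list(actor_network.parameters()),
            )


    loss = MyLoss(td_module.get_policy_operator(), td_module.get_value_operator())
    # Perturb every trainable leaf; shared entries must stay in sync.
    for p in loss.parameters():
        p.data += torch.randn_like(p)

    # The trunk's weight and bias are counted once: 6 unique trainable leaves,
    # even though actor and qvalue each expose 4 entries.
    assert len(list(loss.parameters())) == 6
    assert len(loss.actor_network_params) == 4
    assert len(loss.qvalue_network_params) == 4
    assert (loss.qvalue_network_params[0] == loss.actor_network_params[0]).all()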
256 changes: 251 additions & 5 deletions test/test_cost.py
@@ -30,6 +30,8 @@
ValueOperator,
Actor,
ProbabilisticActor,
ActorValueOperator,
ActorCriticOperator,
)
from torchrl.objectives import (
DQNLoss,
@@ -41,11 +43,13 @@
KLPENPPOLoss,
)
from torchrl.objectives.costs.common import _LossModule
from torchrl.objectives.costs.redq import (
REDQLoss,
from torchrl.objectives.costs.deprecated import (
REDQLoss_deprecated,
DoubleREDQLoss_deprecated,
)
from torchrl.objectives.costs.redq import (
REDQLoss,
)
from torchrl.objectives.costs.reinforce import ReinforceLoss
from torchrl.objectives.costs.utils import hold_out_net, HardUpdate, SoftUpdate
from torchrl.objectives.returns.advantages import TDEstimate, GAE, TDLambdaEstimate
@@ -397,7 +401,7 @@ def test_ddpg(self, delay_actor, delay_value, device):
with _check_td_steady(td):
loss = loss_fn(td)

# check that loss are independent
# check that losses are independent
for k in loss.keys():
if not k.startswith("loss"):
continue
@@ -635,7 +639,7 @@ def test_sac(self, delay_value, delay_actor, delay_qvalue, num_qvalue, device):
loss = loss_fn(td)
assert loss_fn.priority_key in td.keys()

# check that loss are independent
# check that losses are independent
for k in loss.keys():
if not k.startswith("loss"):
continue
@@ -846,6 +850,48 @@ def forward(self, obs, act):
)
return qvalue.to(device)

def _create_shared_mock_actor_qvalue(
self, batch=2, obs_dim=3, action_dim=4, hidden_dim=5, device="cpu"
):
class CommonClass(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(obs_dim, hidden_dim)

def forward(self, obs):
return self.linear(obs)

class ActorClass(nn.Module):
def __init__(self):
super().__init__()
self.linear = NormalParamWrapper(nn.Linear(hidden_dim, 2 * action_dim))

def forward(self, hidden):
return self.linear(hidden)

class ValueClass(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(hidden_dim + action_dim, 1)

def forward(self, hidden, act):
return self.linear(torch.cat([hidden, act], -1))

common = TensorDictModule(
CommonClass(), in_keys=["observation"], out_keys=["hidden"]
)
actor_subnet = ProbabilisticActor(
TensorDictModule(
ActorClass(), in_keys=["hidden"], out_keys=["loc", "scale"]
),
dist_param_keys=["loc", "scale"],
distribution_class=TanhNormal,
return_log_prob=True,
)
qvalue_subnet = ValueOperator(ValueClass(), in_keys=["hidden", "action"])
model = ActorCriticOperator(common, actor_subnet, qvalue_subnet)
return model.to(device)

def _create_mock_data_redq(
self, batch=16, obs_dim=3, action_dim=4, atoms=None, device="cpu"
):
@@ -925,7 +971,7 @@ def test_redq(self, delay_qvalue, num_qvalue, device):
# check td is left untouched
assert loss_fn.priority_key in td.keys()

# check that loss are independent
# check that losses are independent
for k in loss.keys():
if not k.startswith("loss"):
continue
@@ -971,6 +1017,105 @@ def test_redq(self, delay_qvalue, num_qvalue, device):
for name, p in named_parameters:
assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient"

@pytest.mark.parametrize("delay_qvalue", (True, False))
@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("device", get_available_devices())
def test_redq_shared(self, delay_qvalue, num_qvalue, device):

torch.manual_seed(self.seed)
td = self._create_mock_data_redq(device=device)

actor_critic = self._create_shared_mock_actor_qvalue(device=device)
actor = actor_critic.get_policy_operator()
qvalue = actor_critic.get_critic_operator()

loss_fn = REDQLoss(
actor_network=actor,
qvalue_network=qvalue,
num_qvalue_nets=num_qvalue,
gamma=0.9,
loss_function="l2",
delay_qvalue=delay_qvalue,
target_entropy=0.0,
)

if delay_qvalue:
target_updater = SoftUpdate(loss_fn)
target_updater.init_()

with _check_td_steady(td):
loss = loss_fn(td)

# check that losses are independent
for k in loss.keys():
if not k.startswith("loss"):
continue
loss[k].sum().backward(retain_graph=True)
if k == "loss_actor":
assert all(
(p.grad is None) or (p.grad == 0).all()
for p in loss_fn._qvalue_network_params
)
assert not any(
(p.grad is None) or (p.grad == 0).all()
for p in loss_fn._actor_network_params
)
elif k == "loss_qvalue":
assert all(
(p.grad is None) or (p.grad == 0).all()
for p in loss_fn._actor_network_params
)
assert not any(
(p.grad is None) or (p.grad == 0).all()
for p in loss_fn._qvalue_network_params
)
elif k == "loss_alpha":
assert all(
(p.grad is None) or (p.grad == 0).all()
for p in loss_fn._actor_network_params
)
assert all(
(p.grad is None) or (p.grad == 0).all()
for p in loss_fn._qvalue_network_params
)
else:
raise NotImplementedError(k)
loss_fn.zero_grad()

# check td is left untouched
assert loss_fn.priority_key in td.keys()

sum([item for _, item in loss.items()]).backward()
named_parameters = list(loss_fn.named_parameters())
named_buffers = list(loss_fn.named_buffers())

assert len(set(p for n, p in named_parameters)) == len(list(named_parameters))
assert len(set(p for n, p in named_buffers)) == len(list(named_buffers))

for name, p in named_parameters:
assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient"

# modify params and check that expanded values are updated
for p in loss_fn.parameters():
p.data *= 0

counter = 0
for p in loss_fn.qvalue_network_params:
if not isinstance(p, nn.Parameter):
counter += 1
assert (p == loss_fn._param_maps[p]).all()
assert (p == 0).all()
assert counter == len(loss_fn._actor_network_params)
assert counter == len(loss_fn.actor_network_params)

# check that params of the original actor are those of the loss_fn
for p in actor.parameters():
assert p in set(loss_fn.parameters())

if delay_qvalue:
# check that the target updater can still step after the parameters were modified
target_updater.step()

@pytest.mark.parametrize("delay_qvalue", (True, False))
@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("device", get_available_devices())
@@ -1513,6 +1658,107 @@ def test_tdlambda(device, gamma, lmbda, N, T):
torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4)


@pytest.mark.parametrize(
"dest,expected_dtype,expected_device",
list(
zip(
get_available_devices(),
[torch.float] * len(get_available_devices()),
get_available_devices(),
)
)
+ [
["cuda", torch.float, "cuda:0"],
["double", torch.double, "cpu"],
[torch.double, torch.double, "cpu"],
[torch.half, torch.half, "cpu"],
["half", torch.half, "cpu"],
],
)
def test_shared_params(dest, expected_dtype, expected_device):
if torch.cuda.device_count() == 0 and dest == "cuda":
pytest.skip("no cuda device available")
module_hidden = torch.nn.Linear(4, 4)
td_module_hidden = TensorDictModule(
module=module_hidden,
spec=None,
in_keys=["observation"],
out_keys=["hidden"],
)
module_action = TensorDictModule(
NormalParamWrapper(torch.nn.Linear(4, 8)),
in_keys=["hidden"],
out_keys=["loc", "scale"],
)
td_module_action = ProbabilisticActor(
module=module_action,
spec=None,
dist_param_keys=["loc", "scale"],
out_key_sample=["action"],
distribution_class=TanhNormal,
return_log_prob=True,
)
module_value = torch.nn.Linear(4, 1)
td_module_value = ValueOperator(
module=module_value,
in_keys=["hidden"],
)
td_module = ActorValueOperator(td_module_hidden, td_module_action, td_module_value)

class MyLoss(_LossModule):
def __init__(self, actor_network, qvalue_network):
super().__init__()
self.convert_to_functional(
actor_network,
"actor_network",
create_target_params=True,
)
self.convert_to_functional(
qvalue_network,
"qvalue_network",
3,
create_target_params=True,
compare_against=list(actor_network.parameters()),
)

actor_network = td_module.get_policy_operator()
value_network = td_module.get_value_operator()

loss = MyLoss(actor_network, value_network)
# modify params
for p in loss.parameters():
p.data += torch.randn_like(p)

assert len(list(loss.parameters())) == 6
assert len(loss.actor_network_params) == 4
assert len(loss.qvalue_network_params) == 4
for p in loss.actor_network_params:
assert isinstance(p, nn.Parameter)
assert (loss.qvalue_network_params[0] == loss.actor_network_params[0]).all()
assert (loss.qvalue_network_params[1] == loss.actor_network_params[1]).all()

# map module
if dest == "double":
loss = loss.double()
elif dest == "cuda":
loss = loss.cuda()
elif dest == "half":
loss = loss.half()
else:
loss = loss.to(dest)

for p in loss.actor_network_params:
assert isinstance(p, nn.Parameter)
assert p.dtype is expected_dtype
assert p.device == torch.device(expected_device)
assert loss.qvalue_network_params[0].dtype is expected_dtype
assert loss.qvalue_network_params[1].dtype is expected_dtype
assert loss.qvalue_network_params[0].device == torch.device(expected_device)
assert loss.qvalue_network_params[1].device == torch.device(expected_device)
assert (loss.qvalue_network_params[0] == loss.actor_network_params[0]).all()
assert (loss.qvalue_network_params[1] == loss.actor_network_params[1]).all()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
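
The invariant these new tests pin down is worth isolating: in a loss built with compare_against, the entries of qvalue_network_params that are shared with the actor are expanded views rather than nn.Parameter leaves, and the loss's private _param_maps table ties each view back to its source parameter, so in-place updates propagate. A hedged fragment mirroring the checks in test_redq_shared above, assuming a loss built as in the sketch at the top of this page:

    from torch import nn

    # Zero every trainable leaf in place.
    for p in loss.parameters():
        p.data *= 0

    for p in loss.qvalue_network_params:
        if not isinstance(p, nn.Parameter):
            # A shared entry: an expanded view of an actor parameter,
            # tracked through the internal _param_maps table.
            assert (p == loss._param_maps[p]).all()
            assert (p == 0).all()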
12 changes: 10 additions & 2 deletions test/test_env.py
@@ -37,7 +37,13 @@
)
from torchrl.envs.utils import step_tensordict
from torchrl.envs.vec_env import ParallelEnv, SerialEnv
from torchrl.modules import ActorCriticOperator, TensorDictModule, ValueOperator, Actor
from torchrl.modules import (
ActorCriticOperator,
TensorDictModule,
ValueOperator,
Actor,
MLP,
)

try:
this_dir = os.path.dirname(os.path.realpath(__file__))
@@ -356,7 +362,9 @@ def test_parallel_env_with_policy(
in_keys=["hidden"],
out_keys=["action"],
),
ValueOperator(module=nn.LazyLinear(1), in_keys=["hidden"]),
ValueOperator(
module=MLP(out_features=1, num_cells=[]), in_keys=["hidden", "action"]
),
)

td = TensorDict(
2 changes: 1 addition & 1 deletion torchrl/collectors/collectors.py
@@ -800,9 +800,9 @@ def shutdown(self) -> None:
self._shutdown_main()

def _shutdown_main(self) -> None:
_check_for_faulty_process(self.procs)
if self.closed:
return
_check_for_faulty_process(self.procs)
self.closed = True
for idx in range(self.num_workers):
self.pipes[idx].send((None, "close"))