[BugFix] Allow for composite action distributions in PPO/A2C losses #2391

Merged
merged 19 commits into from
Sep 4, 2024
fix tests ppo
albertbou92 committed Aug 12, 2024
commit a668d0c98c490483e696b2ee55bff6bd8789b6b8
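For context, here is a minimal, illustrative sketch (not part of this commit) of the composite-action-distribution actor that the tests below construct when `composite_action_dist=True`. It mirrors the keys used in the diff; sizes such as `obs_dim` and the value network are placeholders, and the final loss wiring is a hypothetical usage example rather than code from the PR.

```python
# Sketch only: composite-action actor as exercised by the updated PPO tests.
import functools

import torch
from tensordict.nn import CompositeDistribution, NormalParamExtractor, TensorDictModule
from torch import nn

from torchrl.data import Bounded, Composite
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss

obs_dim, action_dim = 3, 4  # placeholder sizes, matching the test defaults

# Nested spec: the leaf action lives under ("action", "action1").
action_spec = Composite(
    {
        "action": {
            "action1": Bounded(
                -torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
            )
        }
    }
)

# The policy network writes distribution parameters into a "params" sub-tensordict.
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
module = TensorDictModule(
    net,
    in_keys=["observation"],
    out_keys=[("params", "action1", "loc"), ("params", "action1", "scale")],
)

# CompositeDistribution maps each parameter group to its own distribution class
# and routes the sampled value to the nested ("action", "action1") entry.
distribution_class = functools.partial(
    CompositeDistribution,
    distribution_map={"action1": TanhNormal},
    name_map={"action1": ("action", "action1")},
)

actor = ProbabilisticActor(
    module=module,
    distribution_class=distribution_class,
    in_keys=["params"],
    spec=action_spec,
    return_log_prob=True,
)

# Hypothetical wiring into a PPO loss, analogous to what the tests do further down.
value = ValueOperator(nn.Linear(obs_dim, 1), in_keys=["observation"])
loss_module = ClipPPOLoss(actor_network=actor, critic_network=value)
```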
200 changes: 130 additions & 70 deletions test/test_cost.py
@@ -13,8 +13,8 @@

from packaging import version as pack_version
from tensordict._C import unravel_keys

from tensordict.nn import (
CompositeDistribution,
InteractionType,
ProbabilisticTensorDictModule,
ProbabilisticTensorDictModule as ProbMod,
@@ -24,8 +24,8 @@
TensorDictSequential,
TensorDictSequential as Seq,
)
from torchrl.data import Composite
from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type

from torchrl.modules.models import QMixer

_has_functorch = True
@@ -7533,72 +7533,48 @@ def test_dcql_reduction(self, reduction):
class TestPPO(LossModuleTestBase):
seed = 0

def _create_mock_actor_old(
def _create_mock_actor(
self,
batch=2,
obs_dim=3,
action_dim=4,
device="cpu",
observation_key="observation",
sample_log_prob_key="sample_log_prob",
composite_action_dist=False,
):
# Actor
action_spec = Bounded(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
if composite_action_dist:
action_spec = Composite({"action": {"action1": action_spec}})
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
if composite_action_dist:
distribution_class = functools.partial(
CompositeDistribution,
distribution_map={
"action1": TanhNormal,
},
name_map={
"action1": ("action", "action1"),
},
)
module_out_keys = [
("params", "action1", "loc"),
("params", "action1", "scale"),
]
actor_in_keys = ["params"]
else:
distribution_class = TanhNormal
module_out_keys = actor_in_keys = ["loc", "scale"]
module = TensorDictModule(
net, in_keys=[observation_key], out_keys=["loc", "scale"]
)
actor = ProbabilisticActor(
module=module,
distribution_class=TanhNormal,
in_keys=["loc", "scale"],
spec=action_spec,
return_log_prob=True,
log_prob_key=sample_log_prob_key,
)
return actor.to(device)

def _create_mock_actor(
self,
batch=2,
obs_dim=3,
action_dim=4,
device="cpu",
observation_key="observation",
sample_log_prob_key="sample_log_prob",
composite_action_dist=True,
):
from tensordict.nn import CompositeDistribution
from torchrl.data import Composite

# Actor
action_spec = Composite({
"action":{
"action1":
Bounded(-torch.ones(action_dim), torch.ones(action_dim), (action_dim,))
}
}
)
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
module = TensorDictModule(
net, in_keys=[observation_key], out_keys=[("params", "action1", "loc"), ("params", "action1", "scale")]
)
distribution_class = functools.partial(
CompositeDistribution,
distribution_map={
"action1": TanhNormal,
},
name_map={
"action1": ("action", "action1"),
}
net, in_keys=[observation_key], out_keys=module_out_keys
)
actor = ProbabilisticActor(
module=module,
distribution_class=distribution_class,
in_keys=["params"],
out_keys=["action"],
in_keys=actor_in_keys,
spec=action_spec,
return_log_prob=True,
log_prob_key=sample_log_prob_key,
@@ -7622,22 +7598,49 @@ def _create_mock_value(
)
return value.to(device)

def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
def _create_mock_actor_value(
self,
batch=2,
obs_dim=3,
action_dim=4,
device="cpu",
composite_action_dist=False,
):
# Actor
action_spec = Bounded(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
if composite_action_dist:
action_spec = Composite({"action": {"action1": action_spec}})
base_layer = nn.Linear(obs_dim, 5)
net = nn.Sequential(
base_layer, nn.Linear(5, 2 * action_dim), NormalParamExtractor()
)
if composite_action_dist:
distribution_class = functools.partial(
CompositeDistribution,
distribution_map={
"action1": TanhNormal,
},
name_map={
"action1": ("action", "action1"),
},
)
module_out_keys = [
("params", "action1", "loc"),
("params", "action1", "scale"),
]
actor_in_keys = ["params"]
else:
distribution_class = TanhNormal
module_out_keys = actor_in_keys = ["loc", "scale"]
module = TensorDictModule(
net, in_keys=["observation"], out_keys=["loc", "scale"]
net, in_keys=["observation"], out_keys=module_out_keys
)
actor = ProbabilisticActor(
module=module,
distribution_class=TanhNormal,
in_keys=["loc", "scale"],
distribution_class=distribution_class,
in_keys=actor_in_keys,
spec=action_spec,
return_log_prob=True,
)
@@ -7649,22 +7652,47 @@ def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu
return actor.to(device), value.to(device)

def _create_mock_actor_value_shared(
self, batch=2, obs_dim=3, action_dim=4, device="cpu"
self,
batch=2,
obs_dim=3,
action_dim=4,
device="cpu",
composite_action_dist=False,
):
# Actor
action_spec = Bounded(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
if composite_action_dist:
action_spec = Composite({"action": {"action1": action_spec}})
base_layer = nn.Linear(obs_dim, 5)
common = TensorDictModule(
base_layer, in_keys=["observation"], out_keys=["hidden"]
)
net = nn.Sequential(nn.Linear(5, 2 * action_dim), NormalParamExtractor())
module = TensorDictModule(net, in_keys=["hidden"], out_keys=["loc", "scale"])
if composite_action_dist:
distribution_class = functools.partial(
CompositeDistribution,
distribution_map={
"action1": TanhNormal,
},
name_map={
"action1": ("action", "action1"),
},
)
module_out_keys = [
("params", "action1", "loc"),
("params", "action1", "scale"),
]
actor_in_keys = ["params"]
else:
distribution_class = TanhNormal
module_out_keys = actor_in_keys = ["loc", "scale"]
module = TensorDictModule(net, in_keys=["hidden"], out_keys=module_out_keys)
actor_head = ProbabilisticActor(
module=module,
distribution_class=TanhNormal,
in_keys=["loc", "scale"],
distribution_class=distribution_class,
in_keys=actor_in_keys,
spec=action_spec,
return_log_prob=True,
)
@@ -7694,7 +7722,7 @@ def _create_mock_data_ppo(
done_key="done",
terminated_key="terminated",
sample_log_prob_key="sample_log_prob",
composite_action_dist=True,
composite_action_dist=False,
):
# create a tensordict
obs = torch.randn(batch, obs_dim, device=device)
@@ -7742,7 +7770,7 @@ def _create_seq_mock_data_ppo(
device="cpu",
sample_log_prob_key="sample_log_prob",
action_key="action",
composite_action_dist=True,
composite_action_dist=False,
):
# create a tensordict
total_obs = torch.randn(batch, T + 1, obs_dim, device=device)
@@ -7796,6 +7824,7 @@ def _create_seq_mock_data_ppo(
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@pytest.mark.parametrize("functional", [True, False])
@pytest.mark.parametrize("composite_action_dist", [True, False])
def test_ppo(
self,
loss_class,
@@ -7804,11 +7833,16 @@
advantage,
td_est,
functional,
composite_action_dist,
):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_ppo(device=device)
td = self._create_seq_mock_data_ppo(
device=device, composite_action_dist=composite_action_dist
)

actor = self._create_mock_actor(device=device)
actor = self._create_mock_actor(
device=device, composite_action_dist=composite_action_dist
)
value = self._create_mock_value(device=device)
if advantage == "gae":
advantage = GAE(
@@ -7848,7 +7882,10 @@ def test_ppo(

loss = loss_fn(td)
if isinstance(loss_fn, KLPENPPOLoss):
kl = loss.pop("kl")
if "kl" in loss:
kl = loss.pop("kl")
else:
kl = loss.pop("kl_approx")
assert (kl != 0).any()

loss_critic = loss["loss_critic"]
@@ -7898,11 +7935,16 @@ def test_ppo_state_dict(self, loss_class, device, gradient_mode):
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
def test_ppo_shared(self, loss_class, device, advantage):
@pytest.mark.parametrize("composite_action_dist", [True, False])
def test_ppo_shared(self, loss_class, device, advantage, composite_action_dist):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_ppo(device=device)
td = self._create_seq_mock_data_ppo(
device=device, composite_action_dist=composite_action_dist
)

actor, value = self._create_mock_actor_value(device=device)
actor, value = self._create_mock_actor_value(
device=device, composite_action_dist=composite_action_dist
)
if advantage == "gae":
advantage = GAE(
gamma=0.9,
@@ -7984,18 +8026,24 @@ def test_ppo_shared(self, loss_class, device, advantage):
)
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("separate_losses", [True, False])
@pytest.mark.parametrize("composite_action_dist", [True, False])
def test_ppo_shared_seq(
self,
loss_class,
device,
advantage,
separate_losses,
composite_action_dist,
):
"""Tests PPO with shared module with and without passing twice across the common module."""
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_ppo(device=device)
td = self._create_seq_mock_data_ppo(
device=device, composite_action_dist=composite_action_dist
)

model, actor, value = self._create_mock_actor_value_shared(device=device)
model, actor, value = self._create_mock_actor_value_shared(
device=device, composite_action_dist=composite_action_dist
)
value2 = value[-1] # prune the common module
if advantage == "gae":
advantage = GAE(
@@ -8053,8 +8101,20 @@ def test_ppo_shared_seq(
grad2 = TensorDict(dict(model.named_parameters()), []).apply(
lambda x: x.grad.clone()
)
assert_allclose_td(loss, loss2)
assert_allclose_td(grad, grad2)
if composite_action_dist and loss_class is KLPENPPOLoss:
# KL computation for composite dist is based on randomly
# sampled data, thus will not be the same.
# Similarly, objective loss depends on the KL, so it will
# not be the same either.
# Finally, gradients will be different too.
loss.pop("kl", None)
loss2.pop("kl", None)
loss.pop("loss_objective", None)
loss2.pop("loss_objective", None)
assert_allclose_td(loss, loss2)
else:
assert_allclose_td(loss, loss2)
assert_allclose_td(grad, grad2)
model.zero_grad()

@pytest.mark.skipif(