[BugFix] Allow for composite action distributions in PPO/A2C losses #2391

Merged: 19 commits, Sep 4, 2024
a2c tests
albertbou92 committed Aug 14, 2024
commit 13a092d198230b030cb672138e37007767ae1968
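For context, the tests in this commit exercise actors whose policy emits a composite (nested) action distribution rather than a single flat one. Below is a minimal sketch of the construction pattern the tests follow; it assumes the `CompositeDistribution`, `ProbabilisticActor` and `TanhNormal` APIs used in the diff, and the `"action1"`/`"params"` key names are taken directly from the test code.

```python
import functools

from tensordict.nn import CompositeDistribution, NormalParamExtractor, TensorDictModule
from torch import nn
from torchrl.modules import ProbabilisticActor, TanhNormal

obs_dim, action_dim = 3, 4

# Network producing loc/scale for a single sub-action, written under nested
# ("params", "action1", ...) keys instead of flat "loc"/"scale" entries.
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
module = TensorDictModule(
    net,
    in_keys=["observation"],
    out_keys=[("params", "action1", "loc"), ("params", "action1", "scale")],
)

# CompositeDistribution maps each sub-action name to its own distribution;
# name_map routes the sampled values to nested keys under "action".
distribution_class = functools.partial(
    CompositeDistribution,
    distribution_map={"action1": TanhNormal},
    name_map={"action1": ("action", "action1")},
    log_prob_key="sample_log_prob",
)

actor = ProbabilisticActor(
    module=module,
    in_keys=["params"],
    distribution_class=distribution_class,
    return_log_prob=True,
    log_prob_key="sample_log_prob",
)
```

With this setup, calling the actor on a tensordict that contains `"observation"` should write the nested `("action", "action1")` sample and its `"sample_log_prob"` back into the tensordict, which is the layout the A2C tests below assert against.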
165 changes: 133 additions & 32 deletions test/test_cost.py
@@ -24,7 +24,6 @@
TensorDictSequential,
TensorDictSequential as Seq,
)
from torchrl.data import Composite
from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type
from torchrl.modules.models import QMixer

@@ -8595,22 +8594,45 @@ def _create_mock_actor(
obs_dim=3,
action_dim=4,
device="cpu",
action_key="action",
observation_key="observation",
sample_log_prob_key="sample_log_prob",
composite_action_dist=False,
):
# Actor
action_spec = Bounded(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
if composite_action_dist:
action_spec = Composite({action_key: {"action1": action_spec}})
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
if composite_action_dist:
distribution_class = functools.partial(
CompositeDistribution,
distribution_map={
"action1": TanhNormal,
},
name_map={
"action1": (action_key, "action1"),
},
log_prob_key=sample_log_prob_key,
)
module_out_keys = [
("params", "action1", "loc"),
("params", "action1", "scale"),
]
actor_in_keys = ["params"]
else:
distribution_class = TanhNormal
module_out_keys = actor_in_keys = ["loc", "scale"]
module = TensorDictModule(
net, in_keys=[observation_key], out_keys=["loc", "scale"]
net, in_keys=[observation_key], out_keys=module_out_keys
)
actor = ProbabilisticActor(
module=module,
in_keys=["loc", "scale"],
in_keys=actor_in_keys,
spec=action_spec,
distribution_class=TanhNormal,
distribution_class=distribution_class,
return_log_prob=True,
log_prob_key=sample_log_prob_key,
)
@@ -8634,7 +8656,15 @@ def _create_mock_value(
return value.to(device)

def _create_mock_common_layer_setup(
self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2, T=10
self,
n_obs=3,
n_act=4,
ncells=4,
batch=2,
n_hidden=2,
T=10,
composite_action_dist=False,
sample_log_prob_key="sample_log_prob",
):
common_net = MLP(
num_cells=ncells,
@@ -8655,10 +8685,11 @@ def _create_mock_common_layer_setup(
out_features=1,
)
batch = [batch, T]
action = torch.randn(*batch, n_act)
td = TensorDict(
{
"obs": torch.randn(*batch, n_obs),
"action": torch.randn(*batch, n_act),
"action": {"action1": action} if composite_action_dist else action,
"sample_log_prob": torch.randn(*batch),
"done": torch.zeros(*batch, 1, dtype=torch.bool),
"terminated": torch.zeros(*batch, 1, dtype=torch.bool),
@@ -8673,14 +8704,35 @@ def _create_mock_common_layer_setup(
names=[None, "time"],
)
common = Mod(common_net, in_keys=["obs"], out_keys=["hidden"])

if composite_action_dist:
distribution_class = functools.partial(
CompositeDistribution,
distribution_map={
"action1": TanhNormal,
},
name_map={
"action1": ("action", "action1"),
},
log_prob_key=sample_log_prob_key,
)
module_out_keys = [
("params", "action1", "loc"),
("params", "action1", "scale"),
]
actor_in_keys = ["params"]
else:
distribution_class = TanhNormal
module_out_keys = actor_in_keys = ["loc", "scale"]

actor = ProbSeq(
common,
Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
Mod(NormalParamExtractor(), in_keys=["param"], out_keys=module_out_keys),
ProbMod(
in_keys=["loc", "scale"],
in_keys=actor_in_keys,
out_keys=["action"],
distribution_class=TanhNormal,
distribution_class=distribution_class,
),
)
critic = Seq(
@@ -8704,6 +8756,7 @@ def _create_seq_mock_data_a2c(
done_key="done",
terminated_key="terminated",
sample_log_prob_key="sample_log_prob",
composite_action_dist=False,
):
# create a tensordict
total_obs = torch.randn(batch, T + 1, obs_dim, device=device)
@@ -8719,8 +8772,11 @@
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device)
action = action.masked_fill_(~mask.unsqueeze(-1), 0.0)
params_mean = torch.randn_like(action) / 10
params_scale = torch.rand_like(action) / 10
loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0)
scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0)
td = TensorDict(
batch_size=(batch, T),
source={
@@ -8732,29 +8788,46 @@
reward_key: reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
"collector": {"mask": mask},
action_key: action.masked_fill_(~mask.unsqueeze(-1), 0.0),
action_key: {"action1": action} if composite_action_dist else action,
sample_log_prob_key: torch.randn_like(action[..., 1]).masked_fill_(
~mask, 0.0
)
/ 10,
"loc": params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0),
"scale": params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
device=device,
names=[None, "time"],
)
if composite_action_dist:
td[("params", "action1", "loc")] = loc
td[("params", "action1", "scale")] = scale
else:
td["loc"] = loc
td["scale"] = scale
return td
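As a point of reference, this is roughly the nested tensordict layout the composite branch above produces: the sub-action sits under `("action", "action1")` and its distribution parameters under `("params", "action1", ...)`. The shapes here are illustrative only, mirroring the mock-data helper.

```python
import torch
from tensordict import TensorDict

batch, T, action_dim = 2, 10, 4

# Illustrative nested layout used when composite_action_dist=True.
td = TensorDict(
    {
        "action": {"action1": torch.randn(batch, T, action_dim)},
        "params": {
            "action1": {
                "loc": torch.randn(batch, T, action_dim),
                "scale": torch.rand(batch, T, action_dim),
            }
        },
        "sample_log_prob": torch.randn(batch, T),
    },
    batch_size=[batch, T],
)

print(td[("action", "action1")].shape)  # torch.Size([2, 10, 4])
```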

@pytest.mark.parametrize("gradient_mode", (True, False))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@pytest.mark.parametrize("functional", (True, False))
def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
@pytest.mark.parametrize("composite_action_dist", [True, False])
def test_a2c(
self,
device,
gradient_mode,
advantage,
td_est,
functional,
composite_action_dist,
):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_a2c(device=device)
td = self._create_seq_mock_data_a2c(
device=device, composite_action_dist=composite_action_dist
)

actor = self._create_mock_actor(device=device)
actor = self._create_mock_actor(
device=device, composite_action_dist=composite_action_dist
)
value = self._create_mock_value(device=device)
if advantage == "gae":
advantage = GAE(
@@ -8835,19 +8908,25 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional):

@pytest.mark.parametrize("gradient_mode", (True, False))
@pytest.mark.parametrize("device", get_default_devices())
def test_a2c_state_dict(self, device, gradient_mode):
@pytest.mark.parametrize("composite_action_dist", [True, False])
def test_a2c_state_dict(self, device, gradient_mode, composite_action_dist):
torch.manual_seed(self.seed)
actor = self._create_mock_actor(device=device)
actor = self._create_mock_actor(
device=device, composite_action_dist=composite_action_dist
)
value = self._create_mock_value(device=device)
loss_fn = A2CLoss(actor, value, loss_critic_type="l2")
sd = loss_fn.state_dict()
loss_fn2 = A2CLoss(actor, value, loss_critic_type="l2")
loss_fn2.load_state_dict(sd)

@pytest.mark.parametrize("separate_losses", [False, True])
def test_a2c_separate_losses(self, separate_losses):
@pytest.mark.parametrize("composite_action_dist", [True, False])
def test_a2c_separate_losses(self, separate_losses, composite_action_dist):
torch.manual_seed(self.seed)
actor, critic, common, td = self._create_mock_common_layer_setup()
actor, critic, common, td = self._create_mock_common_layer_setup(
composite_action_dist=composite_action_dist
)
loss_fn = A2CLoss(
actor_network=actor,
critic_network=critic,
@@ -8905,13 +8984,18 @@ def test_a2c_separate_losses(self, separate_losses):
@pytest.mark.parametrize("gradient_mode", (True, False))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
def test_a2c_diff(self, device, gradient_mode, advantage):
@pytest.mark.parametrize("composite_action_dist", [True, False])
def test_a2c_diff(self, device, gradient_mode, advantage, composite_action_dist):
if pack_version.parse(torch.__version__) > pack_version.parse("1.14"):
raise pytest.skip("make_functional_with_buffers needs to be changed")
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_a2c(device=device)
td = self._create_seq_mock_data_a2c(
device=device, composite_action_dist=composite_action_dist
)

actor = self._create_mock_actor(device=device)
actor = self._create_mock_actor(
device=device, composite_action_dist=composite_action_dist
)
value = self._create_mock_value(device=device)
if advantage == "gae":
advantage = GAE(
@@ -8981,8 +9065,9 @@ def test_a2c_diff(self, device, gradient_mode, advantage):
ValueEstimators.TDLambda,
],
)
def test_a2c_tensordict_keys(self, td_est):
actor = self._create_mock_actor()
@pytest.mark.parametrize("composite_action_dist", [True, False])
def test_a2c_tensordict_keys(self, td_est, composite_action_dist):
actor = self._create_mock_actor(composite_action_dist=composite_action_dist)
value = self._create_mock_value()

loss_fn = A2CLoss(actor, value, loss_critic_type="l2")
@@ -9027,7 +9112,10 @@ def test_a2c_tensordict_keys(self, td_est):
)
@pytest.mark.parametrize("advantage", ("gae", "vtrace", None))
@pytest.mark.parametrize("device", get_default_devices())
def test_a2c_tensordict_keys_run(self, device, advantage, td_est):
@pytest.mark.parametrize("composite_action_dist", [True, False])
def test_a2c_tensordict_keys_run(
self, device, advantage, td_est, composite_action_dist
):
"""Test A2C loss module with non-default tensordict keys."""
torch.manual_seed(self.seed)
gradient_mode = True
Expand All @@ -9047,10 +9135,13 @@ def test_a2c_tensordict_keys_run(self, device, advantage, td_est):
done_key=done_key,
terminated_key=terminated_key,
sample_log_prob_key=sample_log_prob_key,
composite_action_dist=composite_action_dist,
)

actor = self._create_mock_actor(
device=device, sample_log_prob_key=sample_log_prob_key
device=device,
sample_log_prob_key=sample_log_prob_key,
composite_action_dist=composite_action_dist,
)
value = self._create_mock_value(device=device, out_keys=[value_key])
if advantage == "gae":
@@ -9186,15 +9277,20 @@ def test_a2c_notensordict(
assert loss_critic == loss_val_td["loss_critic"]

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_a2c_reduction(self, reduction):
@pytest.mark.parametrize("composite_action_dist", [True, False])
def test_a2c_reduction(self, reduction, composite_action_dist):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda")
)
td = self._create_seq_mock_data_a2c(device=device)
actor = self._create_mock_actor(device=device)
td = self._create_seq_mock_data_a2c(
device=device, composite_action_dist=composite_action_dist
)
actor = self._create_mock_actor(
device=device, composite_action_dist=composite_action_dist
)
value = self._create_mock_value(device=device)
advantage = GAE(
gamma=0.9,
@@ -9221,10 +9317,15 @@

@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("clip_value", [True, None, 0.5, torch.tensor(0.5)])
def test_a2c_value_clipping(self, clip_value, device):
@pytest.mark.parametrize("composite_action_dist", [True, False])
def test_a2c_value_clipping(self, clip_value, device, composite_action_dist):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_a2c(device=device)
actor = self._create_mock_actor(device=device)
td = self._create_seq_mock_data_a2c(
device=device, composite_action_dist=composite_action_dist
)
actor = self._create_mock_actor(
device=device, composite_action_dist=composite_action_dist
)
value = self._create_mock_value(device=device)
advantage = GAE(
gamma=0.9,