[Feature] SAC compatibility with composite distributions. (#2447)
albertbou92 authored Oct 11, 2024
1 parent 56cc525 commit ec04c35
Showing 2 changed files with 140 additions and 40 deletions.
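
For context, the pattern these tests exercise can be sketched as follows. This is a minimal sketch, not code from the commit: an actor whose TanhNormal leaf is wrapped in a CompositeDistribution, a Q-network that unpacks the nested action, and a SACLoss wired on top. The key names ("action1", "params") and the dimensions follow the test helpers in the diff below; everything else is illustrative.

    import functools

    import torch
    from torch import nn
    from tensordict import TensorDictBase
    from tensordict.nn import CompositeDistribution, NormalParamExtractor, TensorDictModule
    from torchrl.data import Bounded, Composite
    from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
    from torchrl.objectives import SACLoss

    obs_dim, action_dim = 3, 4

    # Composite spec: the action lives at the nested key ("action", "action1").
    action_spec = Composite(
        {
            "action": {
                "action1": Bounded(
                    -torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
                )
            }
        }
    )

    # The policy head writes distribution parameters under nested "params" keys.
    net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
    module = TensorDictModule(
        net,
        in_keys=["observation"],
        out_keys=[("params", "action1", "loc"), ("params", "action1", "scale")],
    )
    actor = ProbabilisticActor(
        module=module,
        in_keys=["params"],
        out_keys=["action"],
        spec=action_spec,
        distribution_class=functools.partial(
            CompositeDistribution,
            distribution_map={"action1": TanhNormal},
            aggregate_probabilities=True,
        ),
    )

    # The critic receives the composite action as a TensorDict and unpacks the leaf.
    class QValue(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(obs_dim + action_dim, 1)

        def forward(self, obs, act):
            if isinstance(act, TensorDictBase):
                act = act.get("action1")
            return self.linear(torch.cat([obs, act], -1))

    qvalue = ValueOperator(module=QValue(), in_keys=["observation", "action"])
    loss_fn = SACLoss(actor_network=actor, qvalue_network=qvalue)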
test/test_cost.py — 150 changes: 118 additions & 32 deletions
@@ -48,7 +48,7 @@
 from mocking_classes import ContinuousActionConvMockEnv
 
 # from torchrl.data.postprocs.utils import expand_as_right
-from tensordict import assert_allclose_td, TensorDict
+from tensordict import assert_allclose_td, TensorDict, TensorDictBase
 from tensordict.nn import NormalParamExtractor, TensorDictModule
 from tensordict.nn.utils import Buffer
 from tensordict.utils import unravel_key
@@ -3450,21 +3450,40 @@ def _create_mock_actor(
         device="cpu",
         observation_key="observation",
         action_key="action",
+        composite_action_dist=False,
     ):
         # Actor
         action_spec = Bounded(
             -torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
         )
+        if composite_action_dist:
+            action_spec = Composite({action_key: {"action1": action_spec}})
         net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
+        if composite_action_dist:
+            distribution_class = functools.partial(
+                CompositeDistribution,
+                distribution_map={
+                    "action1": TanhNormal,
+                },
+                aggregate_probabilities=True,
+            )
+            module_out_keys = [
+                ("params", "action1", "loc"),
+                ("params", "action1", "scale"),
+            ]
+            actor_in_keys = ["params"]
+        else:
+            distribution_class = TanhNormal
+            module_out_keys = actor_in_keys = ["loc", "scale"]
         module = TensorDictModule(
-            net, in_keys=[observation_key], out_keys=["loc", "scale"]
+            net, in_keys=[observation_key], out_keys=module_out_keys
         )
         actor = ProbabilisticActor(
             module=module,
-            in_keys=["loc", "scale"],
-            spec=action_spec,
-            distribution_class=TanhNormal,
+            distribution_class=distribution_class,
+            in_keys=actor_in_keys,
+            out_keys=[action_key],
+            spec=action_spec,
         )
         return actor.to(device)
 
@@ -3484,6 +3503,8 @@ def __init__(self):
                 self.linear = nn.Linear(obs_dim + action_dim, 1)
 
             def forward(self, obs, act):
+                if isinstance(act, TensorDictBase):
+                    act = act.get("action1")
                 return self.linear(torch.cat([obs, act], -1))
 
         module = ValueClass()
@@ -3512,8 +3533,26 @@ def _create_mock_value(
         return value.to(device)
 
     def _create_mock_common_layer_setup(
-        self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2
+        self,
+        n_obs=3,
+        n_act=4,
+        ncells=4,
+        batch=2,
+        n_hidden=2,
+        composite_action_dist=False,
     ):
+        class QValueClass(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear1 = nn.Linear(n_hidden + n_act, n_hidden)
+                self.relu = nn.ReLU()
+                self.linear2 = nn.Linear(n_hidden, 1)
+
+            def forward(self, obs, act):
+                if isinstance(act, TensorDictBase):
+                    act = act.get("action1")
+                return self.linear2(self.relu(self.linear1(torch.cat([obs, act], -1))))
+
         common = MLP(
             num_cells=ncells,
             in_features=n_obs,
@@ -3526,17 +3565,13 @@ def _create_mock_common_layer_setup(
             depth=1,
             out_features=2 * n_act,
         )
-        qvalue = MLP(
-            in_features=n_hidden + n_act,
-            num_cells=ncells,
-            depth=1,
-            out_features=1,
-        )
+        qvalue = QValueClass()
         batch = [batch]
+        action = torch.randn(*batch, n_act)
         td = TensorDict(
             {
                 "obs": torch.randn(*batch, n_obs),
-                "action": torch.randn(*batch, n_act),
+                "action": {"action1": action} if composite_action_dist else action,
                 "done": torch.zeros(*batch, 1, dtype=torch.bool),
                 "terminated": torch.zeros(*batch, 1, dtype=torch.bool),
                 "next": {
@@ -3549,14 +3584,30 @@ def _create_seq_mock_data_sac(
             batch,
         )
         common = Mod(common, in_keys=["obs"], out_keys=["hidden"])
+        if composite_action_dist:
+            distribution_class = functools.partial(
+                CompositeDistribution,
+                distribution_map={
+                    "action1": TanhNormal,
+                },
+                aggregate_probabilities=True,
+            )
+            module_out_keys = [
+                ("params", "action1", "loc"),
+                ("params", "action1", "scale"),
+            ]
+            actor_in_keys = ["params"]
+        else:
+            distribution_class = TanhNormal
+            module_out_keys = actor_in_keys = ["loc", "scale"]
         actor = ProbSeq(
             common,
             Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
-            Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
+            Mod(NormalParamExtractor(), in_keys=["param"], out_keys=module_out_keys),
             ProbMod(
-                in_keys=["loc", "scale"],
+                in_keys=actor_in_keys,
                 out_keys=["action"],
-                distribution_class=TanhNormal,
+                distribution_class=distribution_class,
             ),
         )
         qvalue_head = Mod(
@@ -3582,6 +3633,7 @@ def _create_mock_data_sac(
         done_key="done",
         terminated_key="terminated",
         reward_key="reward",
+        composite_action_dist=False,
     ):
         # create a tensordict
         obs = torch.randn(batch, obs_dim, device=device)
@@ -3603,14 +3655,21 @@ def _create_seq_mock_data_sac(
                     terminated_key: terminated,
                     reward_key: reward,
                 },
-                action_key: action,
+                action_key: {"action1": action} if composite_action_dist else action,
            },
             device=device,
         )
         return td
 
     def _create_seq_mock_data_sac(
-        self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu"
+        self,
+        batch=8,
+        T=4,
+        obs_dim=3,
+        action_dim=4,
+        atoms=None,
+        device="cpu",
+        composite_action_dist=False,
     ):
         # create a tensordict
         total_obs = torch.randn(batch, T + 1, obs_dim, device=device)
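
As an aside, with composite_action_dist=True the sampled batch simply nests the action tensor one level deeper than in the flat case. A minimal sketch of the resulting layout, with illustrative shapes (not code from the commit):

    import torch
    from tensordict import TensorDict

    batch, obs_dim, action_dim = 8, 3, 4
    td = TensorDict(
        {
            "observation": torch.randn(batch, obs_dim),
            # Composite case: the action is a nested TensorDict leaf, not a flat tensor.
            "action": {"action1": torch.randn(batch, action_dim)},
            "next": {
                "observation": torch.randn(batch, obs_dim),
                "reward": torch.randn(batch, 1),
                "done": torch.zeros(batch, 1, dtype=torch.bool),
                "terminated": torch.zeros(batch, 1, dtype=torch.bool),
            },
        },
        batch_size=[batch],
    )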
@@ -3626,6 +3685,7 @@ def _create_seq_mock_data_sac(
         done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
         terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
         mask = torch.ones(batch, T, dtype=torch.bool, device=device)
+        action = action.masked_fill_(~mask.unsqueeze(-1), 0.0)
         td = TensorDict(
             batch_size=(batch, T),
             source={
@@ -3637,7 +3697,7 @@ def _create_seq_mock_data_sac(
                     "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
                 },
                 "collector": {"mask": mask},
-                "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
+                "action": {"action1": action} if composite_action_dist else action,
             },
             names=[None, "time"],
             device=device,
@@ -3650,6 +3710,7 @@ def _create_seq_mock_data_sac(
     @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
+    @pytest.mark.parametrize("composite_action_dist", [True, False])
     def test_sac(
         self,
         delay_value,
@@ -3659,14 +3720,19 @@ def test_sac(
         device,
         version,
         td_est,
+        composite_action_dist,
     ):
         if (delay_actor or delay_qvalue) and not delay_value:
             pytest.skip("incompatible config")
 
         torch.manual_seed(self.seed)
-        td = self._create_mock_data_sac(device=device)
+        td = self._create_mock_data_sac(
+            device=device, composite_action_dist=composite_action_dist
+        )
 
-        actor = self._create_mock_actor(device=device)
+        actor = self._create_mock_actor(
+            device=device, composite_action_dist=composite_action_dist
+        )
         qvalue = self._create_mock_qvalue(device=device)
         if version == 1:
             value = self._create_mock_value(device=device)
@@ -3816,6 +3882,7 @@ def test_sac(
     @pytest.mark.parametrize("delay_qvalue", (True, False))
     @pytest.mark.parametrize("num_qvalue", [2])
     @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize("composite_action_dist", [True, False])
     def test_sac_state_dict(
         self,
         delay_value,
@@ -3824,13 +3891,16 @@ def test_sac_state_dict(
         num_qvalue,
         device,
         version,
+        composite_action_dist,
     ):
         if (delay_actor or delay_qvalue) and not delay_value:
             pytest.skip("incompatible config")
 
         torch.manual_seed(self.seed)
 
-        actor = self._create_mock_actor(device=device)
+        actor = self._create_mock_actor(
+            device=device, composite_action_dist=composite_action_dist
+        )
         qvalue = self._create_mock_qvalue(device=device)
         if version == 1:
             value = self._create_mock_value(device=device)
@@ -3866,15 +3936,19 @@ def test_sac_state_dict(
 
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("separate_losses", [False, True])
+    @pytest.mark.parametrize("composite_action_dist", [True, False])
     def test_sac_separate_losses(
         self,
         device,
         separate_losses,
         version,
+        composite_action_dist,
         n_act=4,
     ):
         torch.manual_seed(self.seed)
-        actor, qvalue, common, td = self._create_mock_common_layer_setup(n_act=n_act)
+        actor, qvalue, common, td = self._create_mock_common_layer_setup(
+            n_act=n_act, composite_action_dist=composite_action_dist
+        )
 
         loss_fn = SACLoss(
             actor_network=actor,
@@ -3960,6 +4034,7 @@ def test_sac_separate_losses(
     @pytest.mark.parametrize("delay_qvalue", (True, False))
     @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
     @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize("composite_action_dist", [True, False])
     def test_sac_batcher(
         self,
         n,
@@ -3969,13 +4044,18 @@ def test_sac_batcher(
         num_qvalue,
         device,
         version,
+        composite_action_dist,
     ):
         if (delay_actor or delay_qvalue) and not delay_value:
             pytest.skip("incompatible config")
         torch.manual_seed(self.seed)
-        td = self._create_seq_mock_data_sac(device=device)
+        td = self._create_seq_mock_data_sac(
+            device=device, composite_action_dist=composite_action_dist
+        )
 
-        actor = self._create_mock_actor(device=device)
+        actor = self._create_mock_actor(
+            device=device, composite_action_dist=composite_action_dist
+        )
         qvalue = self._create_mock_qvalue(device=device)
         if version == 1:
             value = self._create_mock_value(device=device)
@@ -4126,10 +4206,11 @@ def test_sac_batcher(
     @pytest.mark.parametrize(
         "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda]
     )
-    def test_sac_tensordict_keys(self, td_est, version):
-        td = self._create_mock_data_sac()
+    @pytest.mark.parametrize("composite_action_dist", [True, False])
+    def test_sac_tensordict_keys(self, td_est, version, composite_action_dist):
+        td = self._create_mock_data_sac(composite_action_dist=composite_action_dist)
 
-        actor = self._create_mock_actor()
+        actor = self._create_mock_actor(composite_action_dist=composite_action_dist)
         qvalue = self._create_mock_qvalue()
         if version == 1:
             value = self._create_mock_value()
@@ -4149,7 +4230,7 @@ def test_sac_tensordict_keys(self, td_est, version):
             "value": "state_value",
             "state_action_value": "state_action_value",
             "action": "action",
-            "log_prob": "_log_prob",
+            "log_prob": "sample_log_prob",
             "reward": "reward",
             "done": "done",
             "terminated": "terminated",
@@ -4311,15 +4392,20 @@ def test_state_dict(self, version):
         loss.load_state_dict(state)
 
     @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
-    def test_sac_reduction(self, reduction, version):
+    @pytest.mark.parametrize("composite_action_dist", [True, False])
+    def test_sac_reduction(self, reduction, version, composite_action_dist):
         torch.manual_seed(self.seed)
         device = (
             torch.device("cpu")
             if torch.cuda.device_count() == 0
             else torch.device("cuda")
        )
-        td = self._create_mock_data_sac(device=device)
-        actor = self._create_mock_actor(device=device)
+        td = self._create_mock_data_sac(
+            device=device, composite_action_dist=composite_action_dist
+        )
+        actor = self._create_mock_actor(
+            device=device, composite_action_dist=composite_action_dist
+        )
         qvalue = self._create_mock_qvalue(device=device)
         if version == 1:
             value = self._create_mock_value(device=device)
[Diff for the second changed file did not render and is not shown.]
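
To exercise both branches of the new parametrization locally, a filter on the test names should suffice. A sketch, assuming a development checkout of torchrl with pytest installed (the -k expression also matches other SAC variants in the file, which is harmless here):

    import pytest

    # Select the SAC tests; each now runs with composite_action_dist=True and False.
    pytest.main(["test/test_cost.py", "-k", "sac"])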
