[Algorithm] GAIL #2273

Merged
35 commits merged on Aug 6, 2024
Changes from 1 commit

Commits (35)
6d7c5c4
fix norm
BY571 Jul 4, 2024
2d31f33
update docs
BY571 Jul 5, 2024
79bda13
update comments
BY571 Jul 5, 2024
1391b50
add sota-example-test
BY571 Jul 5, 2024
f444b72
update collection data slice
BY571 Jul 8, 2024
244b7ab
update docstring
BY571 Jul 8, 2024
db635d7
update config and objective with gp param
BY571 Jul 9, 2024
434622c
init cost tests gail
BY571 Jul 9, 2024
baca70f
update cost test
BY571 Jul 9, 2024
956567f
Merge branch 'main' into gail
BY571 Jul 11, 2024
8e7713f
add gail cost tests
BY571 Jul 11, 2024
714c35c
Merge remote-tracking branch 'origin/main' into gail
vmoens Jul 30, 2024
a05bef3
Merge branch 'main' into gail
BY571 Jul 31, 2024
b31da8a
Update config
BY571 Jul 31, 2024
63885b0
update gail device
BY571 Jul 31, 2024
739332c
update example tests
BY571 Jul 31, 2024
7a9919c
Merge branch 'gail' of github.com:BY571/rl into gail
BY571 Jul 31, 2024
3fd3c32
Merge remote-tracking branch 'origin/main' into gail
vmoens Aug 2, 2024
9455fef
gymnasium backend
BY571 Aug 5, 2024
fba43d2
Merge branch 'gail' of https://github.com/BY571/rl into gail
vmoens Aug 5, 2024
5e41d89
Merge remote-tracking branch 'origin/main' into gail
vmoens Aug 5, 2024
6c3f7d2
Merge remote-tracking branch 'origin/main' into gail
vmoens Aug 5, 2024
415443b
Merge remote-tracking branch 'origin/main' into gail
vmoens Aug 6, 2024
4926d80
fixes
vmoens Aug 6, 2024
cbd5dfa
init
vmoens Aug 6, 2024
70e1f49
Merge branch 'pin-mujoco' into gail
vmoens Aug 6, 2024
b8ca705
amend
vmoens Aug 6, 2024
6a00bda
Merge branch 'pin-mujoco' into gail
vmoens Aug 6, 2024
f0c225f
amend
vmoens Aug 6, 2024
511fa95
amend
vmoens Aug 6, 2024
4bc316b
amend
vmoens Aug 6, 2024
2f7e64c
Merge branch 'pin-mujoco' into gail
vmoens Aug 6, 2024
3d43e42
amend
vmoens Aug 6, 2024
63398d1
Merge branch 'pin-mujoco' into gail
vmoens Aug 6, 2024
c488bcd
amend
vmoens Aug 6, 2024
Changes from commit 8e7713f1d503fa1d29c26ee9c0b5a48ec1048d8e ("add gail cost tests"), committed by BY571 on Jul 11, 2024
85 changes: 49 additions & 36 deletions test/test_cost.py
@@ -10438,6 +10438,8 @@ def _create_mock_data_gail(self, batch=2, obs_dim=3, action_dim=4, device="cpu")
source={
"observation": obs,
"action": action,
"collector_action": action,
"collector_observation": obs,
},
device=device,
)
@@ -10455,6 +10457,8 @@ def _create_seq_mock_data_gail(
source={
"observation": obs,
"action": action,
"collector_action": action,
"collector_observation": obs,
},
device=device,
)
@@ -10478,23 +10482,26 @@ def test_gail_tensordict_keys(self):
)

@pytest.mark.parametrize("device", get_default_devices())
def test_gail_notensordict(self, device):
@pytest.mark.parametrize("use_grad_penalty", [True, False])
@pytest.mark.parametrize("gp_lambda", [0.1, 1.0])
def test_gail_notensordict(self, device, use_grad_penalty, gp_lambda):
torch.manual_seed(self.seed)
discriminator = self._create_mock_discriminator(device=device)
loss_fn = DTLoss(discriminator)

expert_td = self._create_mock_data_gail(device=device)
collector_td = self._create_mock_data_gail(device=device)
expert_td.set(
loss_fn.tensor_keys.collector_observation, collector_td["observation"]
loss_fn = GAILLoss(
discriminator, use_grad_penalty=use_grad_penalty, gp_lambda=gp_lambda
)
expert_td.set(loss_fn.tensor_keys.collector_action, collector_td["action"])

tensordict = self._create_mock_data_gail(device=device)

in_keys = self._flatten_in_keys(loss_fn.in_keys)
kwargs = dict(expert_td.flatten_keys("_").select(*in_keys))
kwargs = dict(tensordict.flatten_keys("_").select(*in_keys))

loss_val_td = loss_fn(tensordict)
if use_grad_penalty:
loss_val, _ = loss_fn(**kwargs)
else:
loss_val = loss_fn(**kwargs)

loss_val_td = loss_fn(expert_td)
loss_val = loss_fn(**kwargs)
torch.testing.assert_close(loss_val_td.get("loss"), loss_val)
# test select
loss_fn.select_out_keys("loss")
@@ -10510,25 +10517,27 @@ def test_gail_notensordict(self, device):
assert loss_discriminator == loss_val_td["loss"]

@pytest.mark.parametrize("device", get_available_devices())
def test_dt(self, device):
@pytest.mark.parametrize("use_grad_penalty", [True, False])
@pytest.mark.parametrize("gp_lambda", [0.1, 1.0])
def test_gail(self, device, use_grad_penalty, gp_lambda):
torch.manual_seed(self.seed)
td = self._create_mock_data_dt(device=device)
td = self._create_mock_data_gail(device=device)

actor = self._create_mock_actor(device=device)
discriminator = self._create_mock_discriminator(device=device)

loss_fn = DTLoss(actor)
loss_fn = GAILLoss(
discriminator, use_grad_penalty=use_grad_penalty, gp_lambda=gp_lambda
)
loss = loss_fn(td)
loss_transformer = loss["loss"]
loss_transformer.backward(retain_graph=True)
named_parameters = loss_fn.named_parameters()

for name, p in named_parameters:
if p.grad is not None and p.grad.norm() > 0.0:
assert "actor" in name
assert "alpha" not in name
assert "discriminator" in name
if p.grad is None:
assert "actor" not in name
assert "alpha" in name
assert "discriminator" not in name
loss_fn.zero_grad()

sum([loss_transformer]).backward()
@@ -10542,36 +10551,38 @@ def test_dt(self, device):
assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient"

@pytest.mark.parametrize("device", get_available_devices())
def test_dt_state_dict(self, device):
def test_gail_state_dict(self, device):
torch.manual_seed(self.seed)

actor = self._create_mock_actor(device=device)
discriminator = self._create_mock_discriminator(device=device)

loss_fn = DTLoss(actor)
loss_fn = GAILLoss(discriminator)
sd = loss_fn.state_dict()
loss_fn2 = DTLoss(actor)
loss_fn2 = GAILLoss(discriminator)
loss_fn2.load_state_dict(sd)

@pytest.mark.parametrize("device", get_available_devices())
def test_seq_dt(self, device):
@pytest.mark.parametrize("use_grad_penalty", [True, False])
@pytest.mark.parametrize("gp_lambda", [0.1, 1.0])
def test_seq_gail(self, device, use_grad_penalty, gp_lambda):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_dt(device=device)
td = self._create_seq_mock_data_gail(device=device)

actor = self._create_mock_actor(device=device)
discriminator = self._create_mock_discriminator(device=device)

loss_fn = DTLoss(actor)
loss_fn = GAILLoss(
discriminator, use_grad_penalty=use_grad_penalty, gp_lambda=gp_lambda
)
loss = loss_fn(td)
loss_transformer = loss["loss"]
loss_transformer.backward(retain_graph=True)
named_parameters = loss_fn.named_parameters()

for name, p in named_parameters:
if p.grad is not None and p.grad.norm() > 0.0:
assert "actor" in name
assert "alpha" not in name
assert "discriminator" in name
if p.grad is None:
assert "actor" not in name
assert "alpha" in name
assert "discriminator" not in name
loss_fn.zero_grad()

sum([loss_transformer]).backward()
@@ -10585,19 +10596,21 @@ def test_seq_dt(self, device):
assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient"

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_dt_reduction(self, reduction):
@pytest.mark.parametrize("use_grad_penalty", [True, False])
@pytest.mark.parametrize("gp_lambda", [0.1, 1.0])
def test_gail_reduction(self, reduction, use_grad_penalty, gp_lambda):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda")
)
td = self._create_mock_data_dt(device=device)
actor = self._create_mock_actor(device=device)
loss_fn = DTLoss(actor, reduction=reduction)
td = self._create_mock_data_gail(device=device)
discriminator = self._create_mock_discriminator(device=device)
loss_fn = GAILLoss(discriminator, reduction=reduction)
loss = loss_fn(td)
if reduction == "none":
assert loss["loss"].shape == td["action"].shape
assert loss["loss"].shape == (td["observation"].shape[0], 1)
else:
assert loss["loss"].shape == torch.Size([])

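
For orientation, here is a minimal, self-contained sketch (not part of this diff) of the pattern the reworked GAIL tests above follow: wrap a discriminator in a TensorDictModule, build a GAILLoss with the new use_grad_penalty / gp_lambda arguments, and check that only discriminator parameters receive gradients. The key names ("observation", "action", "collector_observation", "collector_action", "d_logits"), the network sizes, and the import path torchrl.objectives.GAILLoss are assumptions inferred from the diff, not part of it.

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.objectives import GAILLoss  # assumed import path for the new loss


class Discriminator(torch.nn.Module):
    """Scores (observation, action) pairs with a value in (0, 1)."""

    def __init__(self, obs_dim: int, act_dim: int) -> None:
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(obs_dim + act_dim, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 1),
            torch.nn.Sigmoid(),
        )

    def forward(self, obs: torch.Tensor, act: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([obs, act], dim=-1))


batch, obs_dim, act_dim = 8, 3, 4
discriminator = TensorDictModule(
    Discriminator(obs_dim, act_dim),
    in_keys=["observation", "action"],
    out_keys=["d_logits"],
)
loss_fn = GAILLoss(discriminator, use_grad_penalty=True, gp_lambda=0.1)

# Mock data shaped like the test helpers: expert keys plus collector_* keys.
data = TensorDict(
    {
        "observation": torch.randn(batch, obs_dim),
        "action": torch.randn(batch, act_dim),
        "collector_observation": torch.randn(batch, obs_dim),
        "collector_action": torch.randn(batch, act_dim),
    },
    batch_size=[batch],
)

out = loss_fn(data)  # carries "loss" and, with the penalty enabled, "gp_loss"
out["loss"].backward()
# The gradient-flow check the tests make: only discriminator params get gradients.
for name, p in loss_fn.named_parameters():
    if p.grad is not None and p.grad.norm() > 0.0:
        assert "discriminator" in name

The test_seq_gail variant added in this commit plays the same scenario with a [batch, seq_len]-shaped tensordict, which the reworked forward in torchrl/objectives/gail.py (below) supports by building [batch, seq_len, 1] label tensors.
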
65 changes: 43 additions & 22 deletions torchrl/objectives/gail.py
@@ -63,6 +63,13 @@ class _AcceptedKeys:

discriminator_network: TensorDictModule
discriminator_network_params: TensorDictParams
target_discriminator_network: TensorDictModule
target_discriminator_network_params: TensorDictParams

out_keys = [
"loss",
"gp_loss",
]

def __init__(
self,
@@ -84,7 +91,7 @@ def __init__(
"discriminator_network",
create_target_params=False,
)
self.loss_function = torch.nn.BCELoss()
self.loss_function = torch.nn.BCELoss(reduction="none")
self.use_grad_penalty = use_grad_penalty
self.gp_lambda = gp_lambda

@@ -95,6 +102,8 @@ def _set_in_keys(self):
keys = set(keys)
keys.add(self.tensor_keys.expert_observation)
keys.add(self.tensor_keys.expert_action)
keys.add(self.tensor_keys.collector_observation)
keys.add(self.tensor_keys.collector_action)
self._in_keys = sorted(keys, key=str)

def _forward_value_estimator_keys(self, **kwargs) -> None:
@@ -114,6 +123,8 @@ def in_keys(self, values):
def out_keys(self):
if self._out_keys is None:
keys = ["loss"]
if self.use_grad_penalty:
keys.append("gp_loss")
self._out_keys = keys
return self._out_keys

@@ -126,9 +137,19 @@ def forward(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
"""Compute the GAIL discriminator loss."""
"""The forward method.

Computes the discriminator loss and, if `use_grad_penalty` is set to True, the gradient penalty; in that case the detached gradient penalty loss is also returned for logging purposes.
To see what keys are expected in the input tensordict and what keys are expected as output, check the
class's `"in_keys"` and `"out_keys"` attributes.
"""
device = self.discriminator_network.device
tensordict = tensordict.clone(False)
batch_size = tensordict.batch_size[0]
shape = tensordict.shape
if len(shape) > 1:
batch_size, seq_len = shape
else:
batch_size = shape[0]
collector_obs = tensordict.get(self.tensor_keys.collector_observation)
collector_act = tensordict.get(self.tensor_keys.collector_action)

@@ -144,15 +165,20 @@ def forward(
self.tensor_keys.expert_action: combined_act_inputs,
},
batch_size=[2 * batch_size],
device=device,
)

# create labels
fake_labels = torch.zeros((batch_size, 1), dtype=torch.float32).to(
tensordict.device
)
real_labels = torch.ones((batch_size, 1), dtype=torch.float32).to(
tensordict.device
)
# create labels
if len(shape) > 1:
fake_labels = torch.zeros((batch_size, seq_len, 1), dtype=torch.float32).to(
device
)
real_labels = torch.ones((batch_size, seq_len, 1), dtype=torch.float32).to(
device
)
else:
fake_labels = torch.zeros((batch_size, 1), dtype=torch.float32).to(device)
real_labels = torch.ones((batch_size, 1), dtype=torch.float32).to(device)

with self.discriminator_network_params.to_module(self.discriminator_network):
d_logits = self.discriminator_network(combined_inputs).get(
@@ -167,22 +193,18 @@
collection_loss = self.loss_function(collection_preds, fake_labels)

loss = expert_loss + collection_loss
out = {"loss": loss}
out = {}
if self.use_grad_penalty:
obs = tensordict.get(self.tensor_keys.collector_observation)
acts = tensordict.get(self.tensor_keys.collector_action)
obs_e = tensordict.get(self.tensor_keys.expert_observation)
acts_e = tensordict.get(self.tensor_keys.expert_action)

obss_noise = (
torch.distributions.Uniform(0.0, 1.0)
.sample(obs_e.shape)
.to(tensordict.device)
torch.distributions.Uniform(0.0, 1.0).sample(obs_e.shape).to(device)
)
acts_noise = (
torch.distributions.Uniform(0.0, 1.0)
.sample(acts_e.shape)
.to(tensordict.device)
torch.distributions.Uniform(0.0, 1.0).sample(acts_e.shape).to(device)
)
obss_mixture = obss_noise * obs + (1 - obss_noise) * obs_e
acts_mixture = acts_noise * acts + (1 - acts_noise) * acts_e
@@ -195,6 +217,7 @@ def forward(
self.tensor_keys.expert_action: acts_mixture,
},
[],
device=device,
)

with self.discriminator_network_params.to_module(
@@ -208,9 +231,7 @@ def forward(
autograd.grad(
outputs=d_logits_mixture,
inputs=(obss_mixture, acts_mixture),
grad_outputs=torch.ones(
d_logits_mixture.size(), device=tensordict.device
),
grad_outputs=torch.ones(d_logits_mixture.size(), device=device),
create_graph=True,
retain_graph=True,
only_inputs=True,
@@ -223,8 +244,8 @@ def forward(
)

loss += gp_loss
out["gp_loss"] = gp_loss
out["gp_loss"] = gp_loss.detach()
loss = _reduce(loss, reduction=self.reduction)

out["loss"] = loss
td_out = TensorDict(out, [])
return td_out
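
One piece of the gradient-penalty computation stays collapsed in the diff above: the final penalty expression. For readers unfamiliar with the technique, here is an illustrative, hand-rolled sketch of an interpolation-based gradient penalty of the kind GAILLoss applies: score random mixtures of expert and collected samples with the discriminator and penalize the gradient of that score with respect to the mixed inputs. The common WGAN-GP form (||grad||_2 - 1)^2 is assumed for the last line; the exact expression used by GAILLoss lives in the collapsed lines, and the loss module performs all of this internally.

import torch
from torch import autograd


def gradient_penalty(
    discriminator: torch.nn.Module,
    expert_obs: torch.Tensor,
    expert_act: torch.Tensor,
    collector_obs: torch.Tensor,
    collector_act: torch.Tensor,
    gp_lambda: float = 10.0,
) -> torch.Tensor:
    """Illustrative interpolation-based gradient penalty (WGAN-GP style)."""
    # Uniform(0, 1) mixing noise, one coefficient per element, as in the diff above.
    obs_eps = torch.rand_like(expert_obs)
    act_eps = torch.rand_like(expert_act)
    obs_mix = (obs_eps * collector_obs + (1 - obs_eps) * expert_obs).requires_grad_(True)
    act_mix = (act_eps * collector_act + (1 - act_eps) * expert_act).requires_grad_(True)

    # Discriminator score on the mixed samples.
    d_mix = discriminator(obs_mix, act_mix)

    # Gradients of the score with respect to the mixed inputs.
    grads = autograd.grad(
        outputs=d_mix,
        inputs=(obs_mix, act_mix),
        grad_outputs=torch.ones_like(d_mix),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    grad = torch.cat(grads, dim=-1)

    # Penalize deviation of the gradient norm from 1 (assumed WGAN-GP form).
    return gp_lambda * ((grad.norm(2, dim=-1) - 1.0) ** 2).mean()

Penalizing along interpolations between the expert and collected sample sets, rather than on either set alone, is the standard WGAN-GP recipe for keeping the discriminator's gradients well behaved.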