[Algorithm] GAIL #2273

Merged (35 commits, Aug 6, 2024)

Changes from 1 commit. Commits:
6d7c5c4  fix norm (BY571, Jul 4, 2024)
2d31f33  update docs (BY571, Jul 5, 2024)
79bda13  update comments (BY571, Jul 5, 2024)
1391b50  add sota-example-test (BY571, Jul 5, 2024)
f444b72  update collection data slice (BY571, Jul 8, 2024)
244b7ab  update docstring (BY571, Jul 8, 2024)
db635d7  update config and objective with gp param (BY571, Jul 9, 2024)
434622c  init cost tests gail (BY571, Jul 9, 2024)
baca70f  update cost test (BY571, Jul 9, 2024)
956567f  Merge branch 'main' into gail (BY571, Jul 11, 2024)
8e7713f  add gail cost tests (BY571, Jul 11, 2024)
714c35c  Merge remote-tracking branch 'origin/main' into gail (vmoens, Jul 30, 2024)
a05bef3  Merge branch 'main' into gail (BY571, Jul 31, 2024)
b31da8a  Update config (BY571, Jul 31, 2024)
63885b0  update gail device (BY571, Jul 31, 2024)
739332c  update example tests (BY571, Jul 31, 2024)
7a9919c  Merge branch 'gail' of github.com:BY571/rl into gail (BY571, Jul 31, 2024)
3fd3c32  Merge remote-tracking branch 'origin/main' into gail (vmoens, Aug 2, 2024)
9455fef  gymnasium backend (BY571, Aug 5, 2024)
fba43d2  Merge branch 'gail' of https://github.com/BY571/rl into gail (vmoens, Aug 5, 2024)
5e41d89  Merge remote-tracking branch 'origin/main' into gail (vmoens, Aug 5, 2024)
6c3f7d2  Merge remote-tracking branch 'origin/main' into gail (vmoens, Aug 5, 2024)
415443b  Merge remote-tracking branch 'origin/main' into gail (vmoens, Aug 6, 2024)
4926d80  fixes (vmoens, Aug 6, 2024)
cbd5dfa  init (vmoens, Aug 6, 2024)
70e1f49  Merge branch 'pin-mujoco' into gail (vmoens, Aug 6, 2024)
b8ca705  amend (vmoens, Aug 6, 2024)
6a00bda  Merge branch 'pin-mujoco' into gail (vmoens, Aug 6, 2024)
f0c225f  amend (vmoens, Aug 6, 2024)
511fa95  amend (vmoens, Aug 6, 2024)
4bc316b  amend (vmoens, Aug 6, 2024)
2f7e64c  Merge branch 'pin-mujoco' into gail (vmoens, Aug 6, 2024)
3d43e42  amend (vmoens, Aug 6, 2024)
63398d1  Merge branch 'pin-mujoco' into gail (vmoens, Aug 6, 2024)
c488bcd  amend (vmoens, Aug 6, 2024)
init cost tests gail
BY571 committed Jul 9, 2024
commit 434622c15cbbc0bb743613a9dd28a96acdedcdb2
199 changes: 199 additions & 0 deletions test/test_cost.py
@@ -8976,6 +8976,205 @@ def test_dt_reduction(self, reduction):
assert loss["loss"].shape == torch.Size([])


class TestGAIL(LossModuleTestBase):
seed = 0

def _create_mock_discriminator(
self, batch=2, obs_dim=3, action_dim=4, device="cpu"
):
# Discriminator
body = TensorDictModule(
MLP(
in_features=obs_dim + action_dim,
out_features=32,
depth=1,
num_cells=32,
activation_class=torch.nn.ReLU,
activate_last_layer=True,
),
in_keys=["observation", "action"],
out_keys="hidden",
)
head = TensorDictModule(
MLP(
in_features=32,
out_features=1,
depth=0,
num_cells=32,
activation_class=torch.nn.Sigmoid,
activate_last_layer=True,
),
in_keys="hidden",
out_keys="d_logits",
)
discriminator = TensorDictSequential(body, head)

return discriminator.to(device)

def _create_mock_data_gail(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
# create a tensordict
obs = torch.randn(batch, obs_dim, device=device)
action = torch.randn(batch, action_dim, device=device).clamp(-1, 1)
td = TensorDict(
batch_size=(batch,),
source={
"observation": obs,
"action": action,
},
device=device,
)
return td

def _create_seq_mock_data_gail(
self, batch=2, T=4, obs_dim=3, action_dim=4, device="cpu"
):
# create a tensordict
obs = torch.randn(batch, T, obs_dim, device=device)
action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)

td = TensorDict(
batch_size=(batch, T),
source={
"observation": obs,
"action": action,
},
device=device,
)
return td
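
    # NOTE: the tests below were carried over from the DT loss suite as
    # scaffolding in this initial commit; they still reference DTLoss and the
    # DT mock helpers, and are adapted to the GAIL objective in the later
    # commit "add gail cost tests" (8e7713f).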

def test_dt_tensordict_keys(self):
actor = self._create_mock_actor()
loss_fn = DTLoss(actor)

default_keys = {
"action_target": "action",
"action_pred": "action",
}

self.tensordict_keys_test(
loss_fn,
default_keys=default_keys,
)

@pytest.mark.parametrize("device", get_default_devices())
def test_dt_notensordict(self, device):
torch.manual_seed(self.seed)
actor = self._create_mock_actor(device=device)
td = self._create_mock_data_dt(device=device)
loss_fn = DTLoss(actor)

in_keys = self._flatten_in_keys(loss_fn.in_keys)
kwargs = dict(td.flatten_keys("_").select(*in_keys))

loss_val_td = loss_fn(td)
loss_val = loss_fn(**kwargs)
torch.testing.assert_close(loss_val_td.get("loss"), loss_val)
# test select
loss_fn.select_out_keys("loss")
if torch.__version__ >= "2.0.0":
loss_actor = loss_fn(**kwargs)
else:
with pytest.raises(
RuntimeError,
match="You are likely using tensordict.nn.dispatch with keyword arguments",
):
loss_actor = loss_fn(**kwargs)
return
assert loss_actor == loss_val_td["loss"]

@pytest.mark.parametrize("device", get_available_devices())
def test_dt(self, device):
torch.manual_seed(self.seed)
td = self._create_mock_data_dt(device=device)

actor = self._create_mock_actor(device=device)

loss_fn = DTLoss(actor)
loss = loss_fn(td)
loss_transformer = loss["loss"]
loss_transformer.backward(retain_graph=True)
named_parameters = loss_fn.named_parameters()

for name, p in named_parameters:
if p.grad is not None and p.grad.norm() > 0.0:
assert "actor" in name
assert "alpha" not in name
if p.grad is None:
assert "actor" not in name
assert "alpha" in name
loss_fn.zero_grad()

sum([loss_transformer]).backward()
named_parameters = list(loss_fn.named_parameters())
named_buffers = list(loss_fn.named_buffers())

assert len({p for n, p in named_parameters}) == len(list(named_parameters))
assert len({p for n, p in named_buffers}) == len(list(named_buffers))

for name, p in named_parameters:
assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient"

@pytest.mark.parametrize("device", get_available_devices())
def test_dt_state_dict(self, device):
torch.manual_seed(self.seed)

actor = self._create_mock_actor(device=device)

loss_fn = DTLoss(actor)
sd = loss_fn.state_dict()
loss_fn2 = DTLoss(actor)
loss_fn2.load_state_dict(sd)

@pytest.mark.parametrize("device", get_available_devices())
def test_seq_dt(self, device):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_dt(device=device)

actor = self._create_mock_actor(device=device)

loss_fn = DTLoss(actor)
loss = loss_fn(td)
loss_transformer = loss["loss"]
loss_transformer.backward(retain_graph=True)
named_parameters = loss_fn.named_parameters()

for name, p in named_parameters:
if p.grad is not None and p.grad.norm() > 0.0:
assert "actor" in name
assert "alpha" not in name
if p.grad is None:
assert "actor" not in name
assert "alpha" in name
loss_fn.zero_grad()

sum([loss_transformer]).backward()
named_parameters = list(loss_fn.named_parameters())
named_buffers = list(loss_fn.named_buffers())

assert len({p for n, p in named_parameters}) == len(list(named_parameters))
assert len({p for n, p in named_buffers}) == len(list(named_buffers))

for name, p in named_parameters:
assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient"

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_dt_reduction(self, reduction):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda")
)
td = self._create_mock_data_dt(device=device)
actor = self._create_mock_actor(device=device)
loss_fn = DTLoss(actor, reduction=reduction)
loss = loss_fn(td)
if reduction == "none":
assert loss["loss"].shape == td["action"].shape
else:
assert loss["loss"].shape == torch.Size([])


@pytest.mark.skipif(
not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}"
)
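
For orientation, here is a minimal sketch of the update step these mocks are building toward: the standard GAIL discriminator loss (binary cross-entropy between expert and policy state-action pairs), written from first principles rather than against torchrl's GAILLoss, whose interface lands in later commits of this PR. The architecture mirrors _create_mock_discriminator above; the batch contents, the "d_prob" key name, and the optimizer settings are illustrative.

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.modules import MLP

obs_dim, action_dim, batch = 3, 4, 8

# Same two-stage discriminator as _create_mock_discriminator above.
body = TensorDictModule(
    MLP(
        in_features=obs_dim + action_dim,
        out_features=32,
        depth=1,
        num_cells=32,
        activation_class=torch.nn.ReLU,
        activate_last_layer=True,
    ),
    in_keys=["observation", "action"],
    out_keys=["hidden"],
)
# The sigmoid head means the output is a probability D(s, a), not a raw
# logit (the diff names this key "d_logits"; "d_prob" here is illustrative).
head = TensorDictModule(
    MLP(
        in_features=32,
        out_features=1,
        depth=0,
        num_cells=32,
        activation_class=torch.nn.Sigmoid,
        activate_last_layer=True,
    ),
    in_keys=["hidden"],
    out_keys=["d_prob"],
)
discriminator = TensorDictSequential(body, head)
optim = torch.optim.Adam(discriminator.parameters(), lr=3e-4)


def make_batch():
    # Stand-in for real expert / collected transitions.
    return TensorDict(
        {
            "observation": torch.randn(batch, obs_dim),
            "action": torch.randn(batch, action_dim).clamp(-1, 1),
        },
        batch_size=[batch],
    )


expert_td, policy_td = make_batch(), make_batch()

d_expert = discriminator(expert_td)["d_prob"]
d_policy = discriminator(policy_td)["d_prob"]

# GAIL discriminator objective: drive D toward 1 on expert pairs and
# toward 0 on policy pairs (binary cross-entropy written out explicitly).
eps = 1e-8
loss = -(torch.log(d_expert + eps).mean() + torch.log(1 - d_policy + eps).mean())
optim.zero_grad()
loss.backward()
optim.step()

# The generator (policy) is then rewarded with, e.g., -log(1 - D(s, a)),
# turning the discriminator into a learned reward signal.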