[Feature] Add time_dim arg in value modules (#1946)
vmoens authored Feb 22, 2024
1 parent 40e9900 commit b28bbfe
Showing 2 changed files with 233 additions and 56 deletions.
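
The new time_dim argument tells the value estimators (GAE, TD1Estimator, TDLambdaEstimator, VTrace) which batch dimension to treat as time, both at construction and per call to value_estimate. A minimal sketch of the constructor usage, inferred from the test below — the import path and the output keys are standard torchrl conventions and an assumption here, not shown in this diff:

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.objectives.value import GAE  # assumed import path

value_net = TensorDictModule(
    nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
)

# time_dim=1 marks the second batch dim as time; with time_dim=None (the
# default) the estimator falls back to the dim named "time" / the last dim,
# which coincide in the test data below.
gae = GAE(gamma=0.98, lmbda=0.95, value_network=value_net, time_dim=1)

td = TensorDict(
    {
        "obs": torch.randn(1, 10, 3),
        "next": {
            "obs": torch.randn(1, 10, 3),
            "reward": torch.randn(1, 10, 1),
            "done": torch.zeros(1, 10, 1, dtype=torch.bool),
            "terminated": torch.zeros(1, 10, 1, dtype=torch.bool),
        },
    },
    [1, 10],
)
gae(td)  # writes "advantage" and "value_target" entries in place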
86 changes: 86 additions & 0 deletions test/test_cost.py
@@ -11990,6 +11990,92 @@ def test_non_differentiable(self, adv, shifted, kwargs):
        td = module(td.clone(False))
        assert td["advantage"].is_leaf

    @pytest.mark.parametrize(
        "adv,kwargs",
        [
            [GAE, {"lmbda": 0.95}],
            [TD1Estimator, {}],
            [TDLambdaEstimator, {"lmbda": 0.95}],
            [VTrace, {}],
        ],
    )
    def test_time_dim(self, adv, kwargs, shifted=True):
        value_net = TensorDictModule(
            nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
        )

        if adv is VTrace:
            # VTrace additionally needs an actor to compute importance weights.
            actor_net = TensorDictModule(
                nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]
            )
            actor_net = ProbabilisticActor(
                module=actor_net,
                in_keys=["logits"],
                out_keys=["action"],
                distribution_class=OneHotCategorical,
                return_log_prob=True,
            )
            module_make = functools.partial(
                adv,
                gamma=0.98,
                actor_network=actor_net,
                value_network=value_net,
                differentiable=False,
                shifted=shifted,
                **kwargs,
            )
            td = TensorDict(
                {
                    "obs": torch.randn(1, 10, 3),
                    "sample_log_prob": torch.log(torch.rand(1, 10, 1)),
                    "next": {
                        "obs": torch.randn(1, 10, 3),
                        "reward": torch.randn(1, 10, 1, requires_grad=True),
                        "done": torch.zeros(1, 10, 1, dtype=torch.bool),
                        "terminated": torch.zeros(1, 10, 1, dtype=torch.bool),
                    },
                },
                [1, 10],
                names=[None, "time"],
            )
        else:
            module_make = functools.partial(
                adv,
                gamma=0.98,
                value_network=value_net,
                differentiable=False,
                shifted=shifted,
                **kwargs,
            )
            td = TensorDict(
                {
                    "obs": torch.randn(1, 10, 3),
                    "next": {
                        "obs": torch.randn(1, 10, 3),
                        "reward": torch.randn(1, 10, 1, requires_grad=True),
                        "done": torch.zeros(1, 10, 1, dtype=torch.bool),
                    },
                },
                [1, 10],
                names=[None, "time"],
            )

        module_none = module_make(time_dim=None)
        module_0 = module_make(time_dim=0)
        module_1 = module_make(time_dim=1)

        # Dim 1 is both named "time" and last, so the default (None), an
        # explicit time_dim=1, and time_dim=0 on transposed data must agree.
        td_none = module_none(td.clone(False))
        td_1 = module_1(td.clone(False))
        td_0 = module_0(td.transpose(0, 1).clone(False))
        assert_allclose_td(td_none, td_1)
        assert_allclose_td(td_none, td_0.transpose(0, 1))

        if adv is not VTrace:
            # The per-call override should match the default, whether the time
            # dim is given as a positive or a negative index.
            vt = module_none.value_estimate(td.clone(False))
            vt_patch = module_0.value_estimate(td.clone(False), time_dim=1)
            vt_patch2 = module_0.value_estimate(td.clone(False), time_dim=-1)
            torch.testing.assert_close(vt, vt_patch)
            torch.testing.assert_close(vt, vt_patch2)
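
        # Illustrative only, not part of the diff: outside the test, the same
        # per-call override applies to any of these estimators, e.g.
        #     value_target = estimator.value_estimate(data, time_dim=-1)
        # with negative indices counting from the last batch dimension.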

    @pytest.mark.parametrize(
        "adv,kwargs",
        [
