[Feature] shifted for all adv (#1276)
vmoens authored Jun 14, 2023
1 parent 671637a commit 02963e2
Showing 2 changed files with 38 additions and 7 deletions.

test/test_cost.py (7 changes: 6 additions, 1 deletion)
@@ -7101,14 +7101,16 @@ def test_diff_reward(
             [TDLambdaEstimator, {"lmbda": 0.95}],
         ],
     )
-    def test_non_differentiable(self, adv, kwargs):
+    @pytest.mark.parametrize("shifted", [True, False])
+    def test_non_differentiable(self, adv, shifted, kwargs):
         value_net = TensorDictModule(
             nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
         )
         module = adv(
             gamma=0.98,
             value_network=value_net,
             differentiable=False,
+            shifted=shifted,
             **kwargs,
         )
         td = TensorDict(
@@ -7135,12 +7137,14 @@ def test_non_differentiable(self, adv, kwargs):
     )
     @pytest.mark.parametrize("has_value_net", [True, False])
     @pytest.mark.parametrize("skip_existing", [True, False, None])
+    @pytest.mark.parametrize("shifted", [True, False])
     def test_skip_existing(
         self,
         adv,
         kwargs,
         has_value_net,
         skip_existing,
+        shifted,
     ):
         if has_value_net:
             value_net = TensorDictModule(
@@ -7155,6 +7159,7 @@ def test_skip_existing(
             gamma=0.98,
             value_network=value_net,
             differentiable=True,
+            shifted=shifted,
             skip_existing=skip_existing,
             **kwargs,
         )

torchrl/objectives/value/advantages.py (38 changes: 32 additions, 6 deletions)
@@ -78,23 +78,22 @@ def _call_value_nets(
             ndim = i + 1
     if single_call:
         # get data at t and last of t+1
-        data = torch.cat(
+        data_in = torch.cat(
             [
                 data.select(*in_keys, value_key, strict=False),
                 data.get("next").select(*in_keys, value_key, strict=False)[..., -1:],
             ],
             -1,
         )
-        print("single", data)
         # next_params should be None or be identical to params
         if next_params is not None and next_params is not params:
             raise ValueError(
                 "the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed."
             )
         if params is not None:
-            value_est = value_net(data, params).get(value_key)
+            value_est = value_net(data_in, params).get(value_key)
         else:
-            value_est = value_net(data).get(value_key)
+            value_est = value_net(data_in).get(value_key)
         idx = (slice(None),) * (ndim - 1) + (slice(None, -1),)
         idx_ = (slice(None),) * (ndim - 1) + (slice(1, None),)
         value, value_ = value_est[idx], value_est[idx_]
@@ -117,8 +116,8 @@ def _call_value_nets(
         data_out = torch.vmap(value_net, (0,))(data_in)
         value_est = data_out.get(value_key)
         value, value_ = value_est[0], value_est[1]
-        data.set(value_key, value)
-        data.set(("next", value_key), value_)
+    data.set(value_key, value)
+    data.set(("next", value_key), value_)
     if detach_next:
         value_ = value_.detach()
     return value, value_
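
The single-call branch above is the machinery behind the new shifted option: the observations at t and the last "next" observation are stacked along the time dimension, the value network is run once over the stack, and the output is split back into the value at t and the value at t+1. A minimal sketch of that indexing logic in plain PyTorch (toy shapes and names are illustrative; the actual code operates on TensorDicts and selects the value network's in_keys):

```python
import torch
from torch import nn

# Toy setup: 4 trajectories, 10 time steps, 3-dimensional observations.
value_net = nn.Linear(3, 1)
obs = torch.randn(4, 10, 3)       # observations at t
next_obs = torch.randn(4, 10, 3)  # observations at t + 1

ndim = 2  # index of the time dimension + 1, as in the hunk above

# One forward pass over [obs_0, ..., obs_T, next_obs_T] instead of two
# separate passes over obs and next_obs.
data_in = torch.cat([obs, next_obs[:, -1:]], dim=ndim - 1)  # (4, 11, 3)
value_est = value_net(data_in)                              # (4, 11, 1)

# Split the stacked estimate back into value(t) and value(t+1),
# mirroring idx / idx_ above.
idx = (slice(None),) * (ndim - 1) + (slice(None, -1),)
idx_ = (slice(None),) * (ndim - 1) + (slice(1, None),)
value, value_ = value_est[idx], value_est[idx_]
assert value.shape == value_.shape == (4, 10, 1)
```

Because both slices come out of the same forward pass, the trick only works when a single set of parameters serves both time steps, which is exactly the restriction documented for shifted below.
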
@@ -572,6 +571,13 @@ class TD1Estimator(ValueEstimatorBase):
             of the advantage entry. Defaults to ``"value_target"``.
         value_key (str or tuple of str, optional): [Deprecated] the value key to
             read from the input tensordict. Defaults to ``"state_value"``.
+        shifted (bool, optional): if ``True``, the value and next value are
+            estimated with a single call to the value network. This is faster
+            but is only valid whenever (1) the ``"next"`` value is shifted by
+            only one time step (which is not the case with multi-step value
+            estimation, for instance) and (2) when the parameters used at time
+            ``t`` and ``t+1`` are identical (which is not the case when target
+            parameters are to be used). Defaults to ``False``.
     """

@@ -586,13 +592,15 @@ def __init__(
         advantage_key: NestedKey = None,
         value_target_key: NestedKey = None,
         value_key: NestedKey = None,
+        shifted: bool = False,
     ):
         super().__init__(
             value_network=value_network,
             differentiable=differentiable,
             advantage_key=advantage_key,
             value_target_key=value_target_key,
             value_key=value_key,
+            shifted=shifted,
             skip_existing=skip_existing,
         )
         try:
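
For reference, a minimal construction sketch of the new argument on TD1Estimator, mirroring the test above (the 3-to-1 linear value network and gamma=0.98 are copied from the test; the import path is assumed to be re-exported from torchrl.objectives.value):

```python
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.objectives.value import TD1Estimator

value_net = TensorDictModule(
    nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
)
# shifted=True: value(t) and value(t+1) come from a single call to value_net.
estimator = TD1Estimator(gamma=0.98, value_network=value_net, shifted=True)
```
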
@@ -758,6 +766,13 @@ class TDLambdaEstimator(ValueEstimatorBase):
             of the advantage entry. Defaults to ``"value_target"``.
         value_key (str or tuple of str, optional): [Deprecated] the value key to
             read from the input tensordict. Defaults to ``"state_value"``.
+        shifted (bool, optional): if ``True``, the value and next value are
+            estimated with a single call to the value network. This is faster
+            but is only valid whenever (1) the ``"next"`` value is shifted by
+            only one time step (which is not the case with multi-step value
+            estimation, for instance) and (2) when the parameters used at time
+            ``t`` and ``t+1`` are identical (which is not the case when target
+            parameters are to be used). Defaults to ``False``.
     """

@@ -774,6 +789,7 @@ def __init__(
         advantage_key: NestedKey = None,
         value_target_key: NestedKey = None,
         value_key: NestedKey = None,
+        shifted: bool = False,
     ):
         super().__init__(
             value_network=value_network,
@@ -782,6 +798,7 @@ def __init__(
             value_target_key=value_target_key,
             value_key=value_key,
             skip_existing=skip_existing,
+            shifted=shifted,
         )
         try:
             device = next(value_network.parameters()).device
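
The same flag on TDLambdaEstimator, with the docstring's second condition made explicit: a single call uses one set of weights for both t and t+1, so shifted=True cannot be combined with distinct target parameters (that is what the ValueError in _call_value_nets above guards against). A hedged sketch; gamma and lmbda are taken from the test parametrization:

```python
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.objectives.value import TDLambdaEstimator

value_net = TensorDictModule(
    nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
)
estimator = TDLambdaEstimator(
    gamma=0.98, lmbda=0.95, value_network=value_net, shifted=True
)
# Fine: the same parameters serve both time steps.
# Not fine: passing separate target parameters together with shifted=True;
# the single-call path raises a ValueError in that case (see above).
```
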
@@ -959,6 +976,13 @@ class GAE(ValueEstimatorBase):
             of the advantage entry. Defaults to ``"value_target"``.
         value_key (str or tuple of str, optional): [Deprecated] the value key to
             read from the input tensordict. Defaults to ``"state_value"``.
+        shifted (bool, optional): if ``True``, the value and next value are
+            estimated with a single call to the value network. This is faster
+            but is only valid whenever (1) the ``"next"`` value is shifted by
+            only one time step (which is not the case with multi-step value
+            estimation, for instance) and (2) when the parameters used at time
+            ``t`` and ``t+1`` are identical (which is not the case when target
+            parameters are to be used). Defaults to ``False``.

     GAE will return an :obj:`"advantage"` entry containing the advange value. It will also
     return a :obj:`"value_target"` entry with the return value that is to be used
@@ -987,8 +1011,10 @@ def __init__(
         advantage_key: NestedKey = None,
         value_target_key: NestedKey = None,
         value_key: NestedKey = None,
+        shifted: bool = False,
     ):
         super().__init__(
+            shifted=shifted,
             value_network=value_network,
             differentiable=differentiable,
             advantage_key=advantage_key,
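
Finally, an end-to-end sketch with GAE and the new flag. The tensordict layout below ("obs", ("next", "obs"), ("next", "reward"), ("next", "done"), plus a named "time" dimension) and the call signature are assumptions about the expected input rather than something shown in this diff; the "advantage" and "value_target" outputs are the ones named in the GAE docstring above.

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.objectives.value import GAE

value_net = TensorDictModule(
    nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
)
gae = GAE(gamma=0.98, lmbda=0.95, value_network=value_net, shifted=True)

batch, time = 4, 10
td = TensorDict(
    {
        "obs": torch.randn(batch, time, 3),
        "next": {
            "obs": torch.randn(batch, time, 3),
            "reward": torch.randn(batch, time, 1),
            "done": torch.zeros(batch, time, 1, dtype=torch.bool),
        },
    },
    batch_size=[batch, time],
)
# Mark the time dimension; the single-call path is assumed to rely on it
# to know along which dimension to stack obs and the last next-obs.
td.names = [None, "time"]

td = gae(td)  # writes "advantage" and "value_target" entries
```
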
