[BugFix] Fix args/kwargs passing in advantages (#2001)
vmoens authored Mar 7, 2024
1 parent 07eb02d · commit 130a213
Showing 3 changed files with 121 additions and 10 deletions.
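
In brief: the advantage helpers in torchrl/objectives/value/functional.py now take `time_dim` as a keyword-only argument, and the internal time-transposition wrapper no longer fails when every tensor is passed by keyword. The sketch below is not part of the commit — shapes and values are arbitrary — but it illustrates the calling convention the new test locks in:

    import torch
    from torchrl.objectives.value.functional import vec_generalized_advantage_estimate

    gamma, lmbda = 0.95, 0.95
    T, B = 20, 4
    state_value = torch.randn(T, B, 1)
    next_state_value = torch.randn(T, B, 1)
    reward = torch.randn(T, B, 1)
    done = torch.zeros(T, B, 1, dtype=torch.bool)
    done[-1] = True
    terminated = done.clone()

    # An all-keyword call works, with time_dim pointing at T (dim -3 for these shapes)...
    adv, value_target = vec_generalized_advantage_estimate(
        gamma=gamma,
        lmbda=lmbda,
        state_value=state_value,
        next_state_value=next_state_value,
        reward=reward,
        done=done,
        terminated=terminated,
        time_dim=-3,
    )

    # ...whereas passing time_dim positionally now raises a TypeError
    # (too many positional arguments), as checked by the new test.
    try:
        vec_generalized_advantage_estimate(
            gamma, lmbda, state_value, next_state_value, reward, done, terminated, -3
        )
    except TypeError as err:
        print(err)
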
93 changes: 93 additions & 0 deletions test/test_cost.py
@@ -12119,6 +12119,99 @@ def test_successive_traj_tdadv(self, device, N, T):
)
torch.testing.assert_close(v1, torch.cat([v1a, v1b], -2), rtol=1e-4, atol=1e-4)

@pytest.mark.parametrize("device", get_default_devices())
def test_args_kwargs_timedim(self, device):
torch.manual_seed(0)

lmbda = 0.95
N = (2, 3)
B = (4,)
T = 20

terminated = torch.zeros(*N, T, *B, 1, device=device, dtype=torch.bool)
terminated[..., T // 2 - 1, :, :] = 1
done = terminated.clone()
done[..., -1, :, :] = 1

reward = torch.randn(*N, T, *B, 1, device=device)
state_value = torch.randn(*N, T, *B, 1, device=device)
next_state_value = torch.randn(*N, T, *B, 1, device=device)

# avoid low values of gamma
gamma = 0.95

v1 = vec_generalized_advantage_estimate(
gamma,
lmbda,
state_value,
next_state_value,
reward,
done=done,
terminated=terminated,
time_dim=-3,
)[0]


v2 = vec_generalized_advantage_estimate(
gamma=gamma,
lmbda=lmbda,
state_value=state_value,
next_state_value=next_state_value,
reward=reward,
done=done,
terminated=terminated,
time_dim=-3,
)[0]

with pytest.raises(TypeError, match="positional arguments"):
v3 = vec_generalized_advantage_estimate(
gamma,
lmbda,
state_value,
next_state_value,
reward,
done,
terminated,
-3,
)[0]

v3 = vec_generalized_advantage_estimate(
gamma,
lmbda,
state_value,
next_state_value,
reward,
done,
terminated,
time_dim=-3,
)[0]

v4 = vec_generalized_advantage_estimate(
gamma,
lmbda,
state_value,
next_state_value,
reward,
done,
terminated,
time_dim=2,
)[0]

v5 = vec_generalized_advantage_estimate(
gamma=gamma,
lmbda=lmbda,
state_value=state_value,
next_state_value=next_state_value,
reward=reward,
done=done,
terminated=terminated,
time_dim=-3,
)[0]
torch.testing.assert_close(v1, v2)
torch.testing.assert_close(v1, v3)
torch.testing.assert_close(v1, v4)
torch.testing.assert_close(v1, v5)

@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("N", [(3,), (3, 7)])
@pytest.mark.parametrize("T", [3, 5, 200])
21 changes: 17 additions & 4 deletions torchrl/objectives/value/functional.py
@@ -81,13 +81,16 @@ def transpose_tensor(tensor):
return tensor, single_dim

if time_dim != -2:
args, single_dim = zip(*(transpose_tensor(arg) for arg in args))
single_dim = any(single_dim)
single_dim = False
if args:
args, single_dim = zip(*(transpose_tensor(arg) for arg in args))
single_dim = any(single_dim)
for k, item in list(kwargs.items()):
item, sd = transpose_tensor(item)
single_dim = single_dim or sd
kwargs[k] = item
out = fun(*args, time_dim=-2, **kwargs)
# We don't pass time_dim because it isn't supposed to be used thereafter
out = fun(*args, **kwargs)
if isinstance(out, torch.Tensor):
out = transpose_tensor(out)[0]
if single_dim:
@@ -96,7 +99,8 @@ def transpose_tensor(tensor):
if single_dim:
return tuple(transpose_tensor(_out)[0].squeeze(-2) for _out in out)
return tuple(transpose_tensor(_out)[0] for _out in out)
out = fun(*args, time_dim=time_dim, **kwargs)
# We don't pass time_dim because it isn't supposed to be used thereafter
out = fun(*args, **kwargs)
if isinstance(out, tuple):
for _out in out:
if _out.ndim < 2:
@@ -123,6 +127,7 @@ def generalized_advantage_estimate(
reward: torch.Tensor,
done: torch.Tensor,
terminated: torch.Tensor | None = None,
*,
time_dim: int = -2,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Generalized advantage estimate of a trajectory.
@@ -267,6 +272,7 @@ def vec_generalized_advantage_estimate(
reward: torch.Tensor,
done: torch.Tensor,
terminated: torch.Tensor | None = None,
*,
time_dim: int = -2,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Vectorized Generalized advantage estimate of a trajectory.
@@ -457,6 +463,7 @@ def td1_return_estimate(
done: torch.Tensor,
terminated: torch.Tensor | None = None,
rolling_gamma: bool = None,
*,
time_dim: int = -2,
) -> torch.Tensor:
r"""TD(1) return estimate.
@@ -760,6 +767,7 @@ def td_lambda_return_estimate(
done: torch.Tensor,
terminated: torch.Tensor | None = None,
rolling_gamma: bool = None,
*,
time_dim: int = -2,
) -> torch.Tensor:
r"""TD(:math:`\lambda`) return estimate.
@@ -864,6 +872,7 @@ def td_lambda_advantage_estimate(
done: torch.Tensor,
terminated: torch.Tensor | None = None,
rolling_gamma: bool = None,
# not a kwarg because used directly
time_dim: int = -2,
) -> torch.Tensor:
r"""TD(:math:`\lambda`) advantage estimate.
@@ -997,6 +1006,7 @@ def vec_td_lambda_return_estimate(
done,
terminated: torch.Tensor | None = None,
rolling_gamma: Optional[bool] = None,
*,
time_dim: int = -2,
):
r"""Vectorized TD(:math:`\lambda`) return estimate.
@@ -1148,6 +1158,7 @@ def vec_td_lambda_advantage_estimate(
done,
terminated: torch.Tensor | None = None,
rolling_gamma: bool = None,
# not a kwarg because used directly
time_dim: int = -2,
):
r"""Vectorized TD(:math:`\lambda`) advantage estimate.
@@ -1230,6 +1241,7 @@ def vtrace_advantage_estimate(
terminated: torch.Tensor | None = None,
rho_thresh: Union[float, torch.Tensor] = 1.0,
c_thresh: Union[float, torch.Tensor] = 1.0,
# not a kwarg because used directly
time_dim: int = -2,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Computes V-Trace off-policy actor critic targets.
@@ -1310,6 +1322,7 @@ def reward2go(
reward,
done,
gamma,
*,
time_dim: int = -2,
):
"""Compute the discounted cumulative sum of rewards given multiple trajectories and the episode ends.
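
The substantive fix sits in the time-transposition wrapper shown in the first hunk above: previously `args, single_dim = zip(*(transpose_tensor(arg) for arg in args))` ran unconditionally, so a call that supplied every tensor by keyword (leaving `args` empty) blew up on the unpacking; the new `if args:` guard prevents that, and `time_dim` is no longer forwarded to the wrapped implementation, which is not supposed to use it after the transposition. The toy snippet below — plain Python, with a hypothetical `transpose_all` standing in for the wrapper logic — shows the failure mode the guard removes:

    def transpose_all(*args):
        # mirrors: args, single_dim = zip(*(transpose_tensor(arg) for arg in args))
        transposed, flags = zip(*((a, False) for a in args))
        return transposed, flags

    transpose_all(1, 2, 3)  # -> ((1, 2, 3), (False, False, False))
    transpose_all()         # ValueError: not enough values to unpack (expected 2, got 0)
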
17 changes: 11 additions & 6 deletions torchrl/objectives/value/utils.py
@@ -182,14 +182,15 @@ def _make_gammas_tensor(gamma: torch.Tensor, T: int, rolling_gamma: bool):
return gammas


def _flatten_batch(tensor):
def _flatten_batch(tensor, time_dim=-1):
"""Because we mark the end of each batch with a truncated signal, we can concatenate them.
Args:
tensor (torch.Tensor): a tensor of shape [*B, T]
tensor (torch.Tensor): a tensor of shape [*B, T, *F]
time_dim (int, optional): the time dimension T. Defaults to -1.
"""
return tensor.flatten(0, -1)
return tensor.flatten(0, time_dim)


def _get_num_per_traj(done):
@@ -211,7 +212,10 @@ def _get_num_per_traj(done):


def _split_and_pad_sequence(
tensor: Union[torch.Tensor, TensorDictBase], splits: torch.Tensor, return_mask=False
tensor: Union[torch.Tensor, TensorDictBase],
splits: torch.Tensor,
return_mask=False,
time_dim=-1,
):
"""Given a tensor of size [*B, T, F] and the corresponding traj lengths (flattened), returns the padded trajectories [NPad, Tmax, *other].
@@ -277,17 +281,18 @@ def _split_and_pad_sequence(
[19, 19, 19]]])
"""
tensor = _flatten_batch(tensor)
max_seq_len = torch.max(splits)
shape = (len(splits), max_seq_len)

# int16 supports length up to 32767
dtype = (
torch.int16 if tensor.shape[-1] < torch.iinfo(torch.int16).max else torch.int32
torch.int16 if tensor.shape[-2] < torch.iinfo(torch.int16).max else torch.int32
)
arange = torch.arange(max_seq_len, device=tensor.device, dtype=dtype).unsqueeze(0)
mask = arange < splits.unsqueeze(1)

tensor = _flatten_batch(tensor, time_dim=time_dim)

def _fill_tensor(tensor):
empty_tensor = torch.zeros(
*shape,
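
The utils.py changes thread a `time_dim` argument through the private helpers: `_flatten_batch` now collapses all leading dimensions up to and including the time dimension, and `_split_and_pad_sequence` flattens only after sizing its mask. Conceptually, split-and-pad turns concatenated trajectories back into a padded [NPad, Tmax, *F] tensor; the snippet below is a rough pure-PyTorch equivalent (illustration only — the real helper fills a zero tensor through a boolean mask and can also return that mask):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    flat = torch.arange(10.0).unsqueeze(-1)  # concatenated trajectories, [T_total, F]
    splits = torch.tensor([4, 6])            # per-trajectory lengths, as from _get_num_per_traj

    padded = pad_sequence(flat.split(splits.tolist()), batch_first=True)
    print(padded.shape)                      # torch.Size([2, 6, 1])
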
