[Performance] Accelerate TD lambda return estimate (pytorch#1158)
Blonck authored May 18, 2023
1 parent 555e156 commit e8a43b9
Showing 3 changed files with 153 additions and 48 deletions.
60 changes: 60 additions & 0 deletions test/test_cost.py
@@ -4306,6 +4306,66 @@ def test_tdlambda_tensor_gamma(self, device, gamma, lmbda, N, T, has_done):

torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4)

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("gamma", [0.5, 0.99])
@pytest.mark.parametrize("lmbda", [0.25, 0.99])
@pytest.mark.parametrize("N", [(3,), (7, 3)])
@pytest.mark.parametrize("T", [3, 100])
@pytest.mark.parametrize("F", [1, 4])
@pytest.mark.parametrize("has_done", [True, False])
@pytest.mark.parametrize(
"gamma_tensor", ["scalar", "tensor", "tensor_single_element"]
)
@pytest.mark.parametrize("lmbda_tensor", ["scalar", "tensor_single_element"])
def test_tdlambda_tensor_gamma_single_element(
self, device, gamma, lmbda, N, T, F, has_done, gamma_tensor, lmbda_tensor
):
"""Tests vec_td_lambda_advantage_estimate against itself with
gamma being a tensor or a scalar
"""
torch.manual_seed(0)

done = torch.zeros(*N, T, F, device=device, dtype=torch.bool)
if has_done:
done = done.bernoulli_(0.1)
reward = torch.randn(*N, T, F, device=device)
state_value = torch.randn(*N, T, F, device=device)
next_state_value = torch.randn(*N, T, F, device=device)

if gamma_tensor == "tensor":
gamma_vec = torch.full_like(reward, gamma)
elif gamma_tensor == "tensor_single_element":
gamma_vec = torch.as_tensor([gamma], device=device)
else:
gamma_vec = gamma

if gamma_tensor == "tensor_single_element":
lmbda_vec = torch.as_tensor([lmbda], device=device)
else:
lmbda_vec = lmbda

v1 = vec_td_lambda_advantage_estimate(
gamma, lmbda, state_value, next_state_value, reward, done
)
v2 = vec_td_lambda_advantage_estimate(
gamma_vec, lmbda_vec, state_value, next_state_value, reward, done
)

torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4)

        # same check with the last done being True
done[..., -1, :] = True # terminating trajectory

v1 = vec_td_lambda_advantage_estimate(
gamma, lmbda, state_value, next_state_value, reward, done
)
v2 = vec_td_lambda_advantage_estimate(
gamma_vec, lmbda_vec, state_value, next_state_value, reward, done
)

torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4)

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("gamma", [0.5, 0.99, 0.1])
@pytest.mark.parametrize("N", [(3,), (7, 3)])
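Taken together, the new test asserts one invariant: scalar decays and single-element tensor decays must give identical estimates, whichever code path they dispatch to. A minimal standalone sketch of that invariant (toy shapes; the import path is assumed from the file layout shown in this diff):

import torch

from torchrl.objectives.value.functional import vec_td_lambda_advantage_estimate

torch.manual_seed(0)
# Toy batch: B=2 trajectories, T=5 steps, F=1 feature dimension.
reward = torch.randn(2, 5, 1)
state_value = torch.randn(2, 5, 1)
next_state_value = torch.randn(2, 5, 1)
done = torch.zeros(2, 5, 1, dtype=torch.bool)

# Scalar gamma/lmbda dispatch to the new fast path...
v_scalar = vec_td_lambda_advantage_estimate(
    0.99, 0.95, state_value, next_state_value, reward, done
)
# ...and single-element tensors must land on the same result.
v_tensor = vec_td_lambda_advantage_estimate(
    torch.tensor([0.99]),
    torch.tensor([0.95]),
    state_value,
    next_state_value,
    reward,
    done,
)
torch.testing.assert_close(v_scalar, v_tensor, rtol=1e-4, atol=1e-4)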
140 changes: 92 additions & 48 deletions torchrl/objectives/value/functional.py
@@ -781,6 +781,63 @@ def td_lambda_advantage_estimate(
return advantage


def _fast_td_lambda_return_estimate(
gamma: Union[torch.Tensor, float],
lmbda: float,
next_state_value: torch.Tensor,
reward: torch.Tensor,
done: torch.Tensor,
thr: float = 1e-7,
):
"""Fast vectorized TD lambda return estimate.
In contrast to the generalized `vec_td_lambda_return_estimate` this function does not need
to allocate a big tensor of the form [B, T, T], but it only works with gamma/lmbda being scalars.
Args:
gamma (scalar): the gamma decay, can be a tensor with a single element (trajectory discount)
lmbda (scalar): the lambda decay (exponential mean discount)
next_state_value (torch.Tensor): a [*B, T, F] tensor containing next state values (value function)
reward (torch.Tensor): a [*B, T, F] tensor containing rewards
done (torch.Tensor): a [B, T] boolean tensor containing the done states
thr (float): threshold for the filter. Below this limit, components will ignored.
Defaults to 1e-7.
All tensors (values, reward and done) must have shape
``[*Batch x TimeSteps x F]``, with ``F`` feature dimensions.
"""
device = reward.device
done = done.transpose(-2, -1)
reward = reward.transpose(-2, -1)
next_state_value = next_state_value.transpose(-2, -1)

gamma_tensor = torch.tensor([gamma], device=device)
gammalmbda = gamma_tensor * lmbda

not_done = (~done).int()
num_per_traj = _get_num_per_traj(done)
nvalue_ndone = not_done * next_state_value

t = nvalue_ndone * gamma_tensor * (1 - lmbda) + reward
v3 = torch.zeros_like(t, device=device)
v3[..., -1] = nvalue_ndone[..., -1].clone()

t_flat, mask = _split_and_pad_sequence(
t + v3 * gammalmbda, num_per_traj, return_mask=True
)

    # cut off the filter as soon as (gamma * lmbda) ** k falls below `thr`
lim = int(math.log(thr) / gammalmbda.log().item())
# create decay filter [1, g, g**2, g**3, ...]
gammalmbdas = gammalmbda.pow(torch.arange(lim, device=device)).unsqueeze(-1)

ret_flat = _custom_conv1d(t_flat.unsqueeze(1), gammalmbdas)
ret = ret_flat.squeeze(1)[mask]

return ret.view_as(reward).transpose(-1, -2)
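For readers of the diff: the closed form implemented above follows from the backward recursion of the TD(lambda) return, G_t = r_t + gamma * ((1 - lmbda) * V(s_{t+1}) + lmbda * G_{t+1}), which unrolls into a geometric sum of per-step terms that `_custom_conv1d` evaluates as a convolution. A minimal sketch of that identity under simplifying assumptions (scalar decays, no done flags; `payload` is an illustrative name for the tensor called `t` above):

import torch

torch.manual_seed(0)
T, gamma, lmbda = 6, 0.9, 0.8
reward = torch.randn(T)
next_value = torch.randn(T)  # V(s_{t+1}) for every step t

# Reference: backward recursion of the TD(lambda) return,
# G_t = r_t + gamma * ((1 - lmbda) * V(s_{t+1}) + lmbda * G_{t+1}),
# bootstrapped with G_{T-1} = r_{T-1} + gamma * V(s_T).
ref = torch.empty(T)
ref[-1] = reward[-1] + gamma * next_value[-1]
for s in range(T - 2, -1, -1):
    ref[s] = reward[s] + gamma * ((1 - lmbda) * next_value[s] + lmbda * ref[s + 1])

# Fast form: fold each step into payload_s = r_s + gamma * (1 - lmbda) * V(s_{s+1}),
# add the gamma * lmbda * V bootstrap on the last step (the role of `v3` above),
# then convolve with the geometric filter [1, gl, gl**2, ...], gl = gamma * lmbda.
gl = gamma * lmbda
payload = reward + gamma * (1 - lmbda) * next_value
payload[-1] += gl * next_value[-1]
fast = torch.stack(
    [sum(gl**k * payload[s + k] for k in range(T - s)) for s in range(T)]
)

torch.testing.assert_close(ref, fast)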


@_transpose_time
def vec_td_lambda_return_estimate(
gamma,
@@ -831,10 +888,27 @@ def vec_td_lambda_return_estimate(
"""
if not (next_state_value.shape == reward.shape == done.shape):
raise RuntimeError(SHAPE_ERR)

gamma_thr = 1e-7
shape = next_state_value.shape

*batch, T, lastdim = shape

def _is_scalar(tensor):
return not isinstance(tensor, torch.Tensor) or tensor.numel() == 1

    # There are two use cases: if gamma/lmbda are scalars we can use the
    # fast implementation; if not, we must construct a gamma tensor.
if _is_scalar(gamma) and _is_scalar(lmbda):
return _fast_td_lambda_return_estimate(
gamma=gamma,
lmbda=lmbda,
next_state_value=next_state_value,
reward=reward,
done=done,
thr=gamma_thr,
)

next_state_value = next_state_value.transpose(-2, -1).unsqueeze(-2)
if len(batch):
next_state_value = next_state_value.flatten(0, len(batch))
@@ -847,62 +921,32 @@
device = reward.device
not_done = (~done).int()

first_below_thr_gamma = None

-    # 3 use cases: (1) there is one gamma per time step, (2) there is a single gamma but
-    # some intermediate dones and (3) there is a single gamma and no done.
-    # (3) can be treated much faster than (1) and (2) (lower mem footprint)
-    if (isinstance(gamma, torch.Tensor) and gamma.numel() > 1) or done.any():
-        if rolling_gamma is None:
-            rolling_gamma = True
-        if rolling_gamma:
-            gamma = gamma * not_done
-        gammas = _make_gammas_tensor(gamma, T, rolling_gamma)
-
-        if not rolling_gamma:
-            done_follows_done = done[..., 1:, :][done[..., :-1, :]].all()
-            if not done_follows_done:
-                raise NotImplementedError(
-                    "When using rolling_gamma=False and vectorized TD(lambda), "
-                    "make sure that consecutive trajectories are separated as different batch "
-                    "items. Propagating a gamma value across trajectories is not permitted with "
-                    "this method. Check that you need to use rolling_gamma=False, and if so "
-                    "consider using the non-vectorized version of the return computation or splitting "
-                    "your trajectories."
-                )
-        else:
-            gammas[..., 1:, :] = gammas[..., 1:, :] * not_done.view(-1, 1, T, 1)
-    else:
-        if rolling_gamma is not None:
-            raise RuntimeError(
-                "rolling_gamma cannot be set if a non-tensor gamma is provided"
-            )
-        gammas = torch.ones(T + 1, 1, device=device)
-        gammas[1:] = gamma
+    if rolling_gamma is None:
+        rolling_gamma = True
+    if rolling_gamma:
+        gamma = gamma * not_done
+    gammas = _make_gammas_tensor(gamma, T, rolling_gamma)
+
+    if not rolling_gamma:
+        done_follows_done = done[..., 1:, :][done[..., :-1, :]].all()
+        if not done_follows_done:
+            raise NotImplementedError(
+                "When using rolling_gamma=False and vectorized TD(lambda) with time-dependent gamma, "
+                "make sure that consecutive trajectories are separated as different batch "
+                "items. Propagating a gamma value across trajectories is not permitted with "
+                "this method. Check that you need to use rolling_gamma=False, and if so "
+                "consider using the non-vectorized version of the return computation or splitting "
+                "your trajectories."
+            )
+    else:
+        gammas[..., 1:, :] = gammas[..., 1:, :] * not_done.view(-1, 1, T, 1)

gammas_cp = torch.cumprod(gammas, -2)

lambdas = torch.ones(T + 1, 1, device=device)
lambdas[1:] = lmbda
lambdas_cp = torch.cumprod(lambdas, -2)

if not isinstance(gamma, torch.Tensor) or gamma.numel() <= 0:
        first_below_thr = gammas_cp < gamma_thr
while first_below_thr.ndimension() > 2:
# if we have multiple gammas, we only want to truncate if _all_ of
# the geometric sequences fall below the threshold
first_below_thr = first_below_thr.all(axis=0)
if first_below_thr.any():
first_below_thr_gamma = first_below_thr.nonzero()[0, 0]
        first_below_thr = lambdas_cp < gamma_thr
if first_below_thr.any() and first_below_thr_gamma is not None:
first_below_thr = max(
first_below_thr_gamma, first_below_thr.nonzero()[0, 0]
)
gammas_cp = gammas_cp[..., :first_below_thr, :]
lambdas_cp = lambdas_cp[:first_below_thr]

gammas = gammas[..., 1:, :]
lambdas = lambdas[1:]

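A note on the truncation used here and in the fast path: because the geometric coefficients decay exponentially, every term past k = log(thr) / log(gamma * lmbda) falls below the 1e-7 threshold, so the effective kernel length is capped independently of the trajectory length T. A quick back-of-envelope check (values purely illustrative):

import math

gamma, lmbda, thr = 0.99, 0.95, 1e-7
# (gamma * lmbda) ** k stays above thr only while
# k <= log(thr) / log(gamma * lmbda).
lim = int(math.log(thr) / math.log(gamma * lmbda))
print(lim)  # 262: kernel length is bounded regardless of T
assert (gamma * lmbda) ** (lim + 1) < thr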
1 change: 1 addition & 0 deletions torchrl/objectives/value/utils.py
@@ -339,6 +339,7 @@ def _inv_pad_sequence(
tensor.shape[-1], device=tensor.device, dtype=dtype
).unsqueeze(0)
mask = arange < splits.unsqueeze(1)

return tensor[mask]


