[Performance] Faster CatFrames.unfolding with padding="same" (pyt…
kurtamohler authored Sep 2, 2024
1 parent ca6eae4 commit d4842fe
Showing 1 changed file with 26 additions and 18 deletions.
44 changes: 26 additions & 18 deletions torchrl/envs/transforms/transforms.py
@@ -3082,6 +3082,31 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         else:
             return self.unfolding(tensordict)

+    def _apply_same_padding(self, dim, data, done_mask):
+        d = data.ndim + dim - 1
+        res = data.clone()
+        num_repeats_per_sample = done_mask.sum(dim=-1)
+
+        if num_repeats_per_sample.dim() > 2:
+            extra_dims = num_repeats_per_sample.dim() - 2
+            num_repeats_per_sample = num_repeats_per_sample.flatten(0, extra_dims)
+            res_flat_series = res.flatten(0, extra_dims)
+        else:
+            extra_dims = 0
+            res_flat_series = res
+
+        if d - 1 > extra_dims:
+            res_flat_series_flat_batch = res_flat_series.flatten(1, d - 1)
+        else:
+            res_flat_series_flat_batch = res_flat_series[:, None]
+
+        for sample_idx, num_repeats in enumerate(num_repeats_per_sample):
+            if num_repeats > 0:
+                res_slice = res_flat_series_flat_batch[sample_idx]
+                res_slice[:, :num_repeats] = res_slice[:, num_repeats : num_repeats + 1]
+
+        return res

     @set_lazy_legacy(False)
     def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase:
         # it is assumed that the last dimension of the tensordict is the time dimension
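The helper's core move, per sample, is a single in-place slice assignment: the slots counted by the done mask at the start of a series are overwritten with the first valid value, which is what padding="same" asks for. A minimal standalone sketch of that fill (not part of the commit; shapes and values are illustrative):

import torch

# Toy batch: 2 series of 6 steps; the first `n` steps of each series are
# stale frames from before an episode reset and must repeat the first
# valid value ("same" padding).
data = torch.tensor([[9., 9., 3., 4., 5., 6.],
                     [9., 1., 2., 3., 4., 5.]])
num_repeats_per_sample = torch.tensor([2, 1])  # plays the role of done_mask.sum(dim=-1)

res = data.clone()
for sample_idx, num_repeats in enumerate(num_repeats_per_sample):
    if num_repeats > 0:
        # Broadcast the first valid entry over the padded slots.
        res[sample_idx, :num_repeats] = res[sample_idx, num_repeats]

print(res)
# tensor([[3., 3., 3., 4., 5., 6.],
#         [1., 1., 2., 3., 4., 5.]])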
@@ -3192,24 +3217,7 @@ def unfold_done(done, N):
         if self.padding != "same":
             data = torch.where(done_mask_expand, self.padding_value, data)
         else:
-            # TODO: This is a pretty bad implementation, could be
-            # made more efficient but it works!
-            reset_any = reset.any(-1, False)
-            reset_vals = list(data_orig[reset_any].unbind(0))
-            j_ = float("inf")
-            reps = []
-            d = data.ndim + self.dim - 1
-            n_feat = data.shape[data.ndim + self.dim :].numel()
-            for j in done_mask_expand.flatten(d, -1).sum(-1).view(-1) // n_feat:
-                if j > j_:
-                    reset_vals = reset_vals[1:]
-                reps.extend([reset_vals[0]] * int(j))
-                j_ = j
-            if reps:
-                reps = torch.stack(reps)
-                data = torch.masked_scatter(
-                    data, done_mask_expand, reps.reshape(-1)
-                )
+            data = self._apply_same_padding(self.dim, data, done_mask)

         if first_val is not None:
             # Aggregate reset along last dim
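Net effect of the hunk above: the old branch built a Python list of reset frames and round-tripped through torch.stack and torch.masked_scatter, whereas the new branch fills each sample's padding with one in-place slice assignment inside _apply_same_padding. A hedged sketch of the call path this speeds up (env and key names are assumptions; unfolding() runs when the transform is applied to a whole trajectory):

from torchrl.envs import GymEnv
from torchrl.envs.transforms import CatFrames

env = GymEnv("Pendulum-v1")
rollout = env.rollout(50)  # trajectory tensordict; last batch dim is time

cat = CatFrames(N=4, dim=-1, in_keys=["observation"], padding="same")
out = cat(rollout)  # forward() dispatches to unfolding() for batched data
print(out["observation"].shape)  # 4 frames concatenated along the last dim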