From d4842fecd836c85ef3b6750296d430a5acf4678e Mon Sep 17 00:00:00 2001
From: kurtamohler
Date: Mon, 2 Sep 2024 08:15:33 -0700
Subject: [PATCH] [Performance] Faster `CatFrames.unfolding` with `padding="same"` (#2407)

---
 torchrl/envs/transforms/transforms.py | 44 ++++++++++++++++-----------
 1 file changed, 26 insertions(+), 18 deletions(-)

diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 2e2883c33bf..7f8403c793e 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -3082,6 +3082,31 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         else:
             return self.unfolding(tensordict)
 
+    def _apply_same_padding(self, dim, data, done_mask):
+        d = data.ndim + dim - 1
+        res = data.clone()
+        num_repeats_per_sample = done_mask.sum(dim=-1)
+
+        if num_repeats_per_sample.dim() > 2:
+            extra_dims = num_repeats_per_sample.dim() - 2
+            num_repeats_per_sample = num_repeats_per_sample.flatten(0, extra_dims)
+            res_flat_series = res.flatten(0, extra_dims)
+        else:
+            extra_dims = 0
+            res_flat_series = res
+
+        if d - 1 > extra_dims:
+            res_flat_series_flat_batch = res_flat_series.flatten(1, d - 1)
+        else:
+            res_flat_series_flat_batch = res_flat_series[:, None]
+
+        for sample_idx, num_repeats in enumerate(num_repeats_per_sample):
+            if num_repeats > 0:
+                res_slice = res_flat_series_flat_batch[sample_idx]
+                res_slice[:, :num_repeats] = res_slice[:, num_repeats : num_repeats + 1]
+
+        return res
+
     @set_lazy_legacy(False)
     def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase:
         # it is assumed that the last dimension of the tensordict is the time dimension
@@ -3192,24 +3217,7 @@ def unfold_done(done, N):
             if self.padding != "same":
                 data = torch.where(done_mask_expand, self.padding_value, data)
             else:
-                # TODO: This is a pretty bad implementation, could be
-                # made more efficient but it works!
-                reset_any = reset.any(-1, False)
-                reset_vals = list(data_orig[reset_any].unbind(0))
-                j_ = float("inf")
-                reps = []
-                d = data.ndim + self.dim - 1
-                n_feat = data.shape[data.ndim + self.dim :].numel()
-                for j in done_mask_expand.flatten(d, -1).sum(-1).view(-1) // n_feat:
-                    if j > j_:
-                        reset_vals = reset_vals[1:]
-                    reps.extend([reset_vals[0]] * int(j))
-                    j_ = j
-                if reps:
-                    reps = torch.stack(reps)
-                    data = torch.masked_scatter(
-                        data, done_mask_expand, reps.reshape(-1)
-                    )
+                data = self._apply_same_padding(self.dim, data, done_mask)
 
             if first_val is not None:
                 # Aggregate reset along last dim
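
Reviewer sketch (not part of the patch): a minimal, standalone illustration of the "same"-padding semantics that `_apply_same_padding` vectorizes. For each unfolded window, the slots that reach back past an episode boundary are overwritten with the first frame of the current episode. The helper name `toy_same_padding` and the toy shapes are hypothetical and chosen only for illustration; the patched method additionally flattens batch and feature dimensions before applying the same fill.

    # Standalone sketch, not torchrl code; `toy_same_padding` is a hypothetical helper.
    import torch

    def toy_same_padding(windows: torch.Tensor, done_mask: torch.Tensor) -> torch.Tensor:
        """Overwrite the leading `done` slots of each window with the first valid frame.

        windows:   [T, N] unfolded frames, one length-N window per time step
        done_mask: [T, N] True where a slot reaches past the start of the episode
        """
        res = windows.clone()
        num_repeats = done_mask.sum(dim=-1)  # leading slots to fill, per window
        for i, n in enumerate(num_repeats.tolist()):
            if n > 0:
                # repeat the first in-episode frame into the n slots before it
                res[i, :n] = res[i, n]
        return res

    if __name__ == "__main__":
        obs = torch.arange(6, dtype=torch.float32)  # frames 0..5; new episode starts at t=3
        windows = obs.unfold(0, 3, 1)               # [[0,1,2], [1,2,3], [2,3,4], [3,4,5]]
        done_mask = torch.tensor(
            [[False, False, False],
             [True, True, False],
             [True, False, False],
             [False, False, False]]
        )
        # prints [[0,1,2], [3,3,3], [3,3,4], [3,4,5]]
        print(toy_same_padding(windows, done_mask))

The loop in the sketch plays the role of the `for sample_idx, num_repeats in enumerate(...)` loop in the patch: the old implementation rebuilt the replacement values with Python-level list manipulation and a `masked_scatter`, while the new code fills each affected window in place from its first post-reset slot, which is what makes it faster.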