From 8e43ac8f48b9b6aead0b693674e8176bbda8221b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 10 Jul 2024 15:54:27 +0100 Subject: [PATCH] [Refactor] Use td.transpose in multi-step transform (#2288) --- test/test_transforms.py | 1 + torchrl/data/postprocs/postprocs.py | 5 +---- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index ba4bdc5d439..fcfd6f08aff 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -10706,6 +10706,7 @@ def test_multistep_transform(self): rollout = env.rollout( 2, auto_reset=False, tensordict=td, break_when_any_done=False ).contiguous() + assert rollout.shape[:-1] == env.batch_size assert "reward" not in rollout.keys() out = t._inv_call(rollout) td = rollout[..., -1]["next"].exclude("reward") diff --git a/torchrl/data/postprocs/postprocs.py b/torchrl/data/postprocs/postprocs.py index 4d15ba9a78d..a46b1c04ecc 100644 --- a/torchrl/data/postprocs/postprocs.py +++ b/torchrl/data/postprocs/postprocs.py @@ -233,10 +233,7 @@ def _multi_step_func( try: # let's try to reshape the tensordict tensordict.batch_size = done.shape - tensordict = tensordict.apply( - lambda x: x.transpose(ndim - 1, tensordict.ndim - 1), - batch_size=done.transpose(ndim - 1, tensordict.ndim - 1).shape, - ) + tensordict = tensordict.transpose(ndim - 1, tensordict.ndim - 1) done = tensordict.get(("next", done_key)) except Exception as err: raise RuntimeError(