diff --git a/test/test_transforms.py b/test/test_transforms.py index ba4bdc5d439..fcfd6f08aff 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -10706,6 +10706,7 @@ def test_multistep_transform(self): rollout = env.rollout( 2, auto_reset=False, tensordict=td, break_when_any_done=False ).contiguous() + assert rollout.shape[:-1] == env.batch_size assert "reward" not in rollout.keys() out = t._inv_call(rollout) td = rollout[..., -1]["next"].exclude("reward") diff --git a/torchrl/data/postprocs/postprocs.py b/torchrl/data/postprocs/postprocs.py index 4d15ba9a78d..a46b1c04ecc 100644 --- a/torchrl/data/postprocs/postprocs.py +++ b/torchrl/data/postprocs/postprocs.py @@ -233,10 +233,7 @@ def _multi_step_func( try: # let's try to reshape the tensordict tensordict.batch_size = done.shape - tensordict = tensordict.apply( - lambda x: x.transpose(ndim - 1, tensordict.ndim - 1), - batch_size=done.transpose(ndim - 1, tensordict.ndim - 1).shape, - ) + tensordict = tensordict.transpose(ndim - 1, tensordict.ndim - 1) done = tensordict.get(("next", done_key)) except Exception as err: raise RuntimeError(