Skip to content

Commit

Permalink
[Refactor] Use td.transpose in multi-step transform (pytorch#2288)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jul 10, 2024
1 parent dcd332d commit 8e43ac8
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
1 change: 1 addition & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10706,6 +10706,7 @@ def test_multistep_transform(self):
rollout = env.rollout(
2, auto_reset=False, tensordict=td, break_when_any_done=False
).contiguous()
assert rollout.shape[:-1] == env.batch_size
assert "reward" not in rollout.keys()
out = t._inv_call(rollout)
td = rollout[..., -1]["next"].exclude("reward")
Expand Down
5 changes: 1 addition & 4 deletions torchrl/data/postprocs/postprocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,7 @@ def _multi_step_func(
try:
# let's try to reshape the tensordict
tensordict.batch_size = done.shape
tensordict = tensordict.apply(
lambda x: x.transpose(ndim - 1, tensordict.ndim - 1),
batch_size=done.transpose(ndim - 1, tensordict.ndim - 1).shape,
)
tensordict = tensordict.transpose(ndim - 1, tensordict.ndim - 1)
done = tensordict.get(("next", done_key))
except Exception as err:
raise RuntimeError(
Expand Down

0 comments on commit 8e43ac8

Please sign in to comment.