Skip to content

Commit

Permalink
[BugFix] Fix R2Go once more (#2089)
Browse files — browse the repository at this point in the history
  • Loading branch information
vmoens authored Apr 18, 2024
1 parent 61c42e4 commit acf168e
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
12 changes: 12 additions & 0 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -13259,6 +13259,18 @@ def test_reward2go(self):
r = torch.stack([r, -r], -1)
torch.testing.assert_close(reward2go(reward, done, 0.9), r)

reward = torch.zeros(4, 1)
reward[3, 0] = 1
done = torch.zeros(4, 1, dtype=bool)
done[3, :] = True
r = torch.ones(4)
r[1:] = 0.9
reward = reward.expand(2, 4, 1)
done = done.expand(2, 4, 1)
r = torch.cumprod(r, 0).flip(0).unsqueeze(-1).expand(2, 4, 1)
r2go = reward2go(reward, done, 0.9)
torch.testing.assert_close(r2go, r)

def test_timedimtranspose_single(self):
@_transpose_time
def fun(a, b, time_dim=-2):
Expand Down
25 changes: 17 additions & 8 deletions torchrl/objectives/value/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import functools
import math

import warnings
Expand Down Expand Up @@ -1362,13 +1363,19 @@ def reward2go(
raise ValueError(
f"reward and done must share the same shape, got {reward.shape} and {done.shape}"
)
# flatten if needed
if reward.ndim > 2:
# we know time dim is at -2, let's put it at -3
rflip = reward.transpose(-2, -3)
rflip_shape = rflip.shape[-2:]
r2go = reward2go(
rflip.flatten(-2, -1), done.transpose(-2, -3).flatten(-2, -1), gamma=gamma
).unflatten(-1, rflip_shape)
return r2go.transpose(-2, -3)

# place time at dim -1
reward = reward.transpose(-2, -1)
done = done.transpose(-2, -1)
# flatten if needed
if reward.ndim > 2:
reward = reward.flatten(0, -2)
done = done.flatten(0, -2)

num_per_traj = _get_num_per_traj(done)
td0_flat = _split_and_pad_sequence(reward, num_per_traj)
Expand All @@ -1379,8 +1386,10 @@ def reward2go(
cumsum = cumsum.reshape_as(reward)
cumsum = cumsum.transpose(-2, -1)
if cumsum.shape != shape:
raise RuntimeError(
f"Wrong shape for output reward2go: {cumsum.shape} when {shape} was expected."
)
# cumsum = cumsum.view(shape)
try:
cumsum = cumsum.reshape(shape)
except RuntimeError:
raise RuntimeError(
f"Wrong shape for output reward2go: {cumsum.shape} when {shape} was expected."
)
return cumsum

1 comment on commit acf168e

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'GPU Benchmark Results'.
The benchmark result of this commit is worse than the previous benchmark result, exceeding the alert threshold of 2.

Benchmark suite Current: acf168e Previous: 61c42e4 Ratio
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 260.64145498970356 iter/sec (stddev: 0.015830092931712923) 649.9179601884035 iter/sec (stddev: 0.00030697103123672015) 2.49

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.