Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] RB MultiStep transform #2008

Merged
merged 10 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
amend
  • Loading branch information
vmoens committed Mar 18, 2024
commit ddfce880d7d4c554d39d8859f4123d6054856007
24 changes: 24 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10259,6 +10259,30 @@ def test_multistep_transform(self):

assert_allclose_td(outs, outs_3)

def test_multistep_transform_changes(self):
    """Exercise ``MultiStepTransform`` while ``n_steps`` grows between writes.

    A 100-frame rollout is split in chunks of 10. After each chunk is
    written to the replay buffer, ``n_steps`` is incremented and the
    buffer content as well as the transform's internal ``_buffer`` are
    checked against the chunk that was just written.
    """
    n_frames = 100
    next_entries = {
        "steps": torch.arange(1, n_frames + 1),
        "reward": torch.ones(n_frames, 1),
    }
    # Each end-of-trajectory flag gets its own all-False tensor.
    for flag in ("done", "terminated", "truncated"):
        next_entries[flag] = torch.zeros(n_frames, 1, dtype=torch.bool)
    data = TensorDict(
        {"steps": torch.arange(n_frames), "next": next_entries},
        batch_size=[n_frames],
    )

    t = MultiStepTransform(3, 0.98)
    rb = ReplayBuffer(storage=LazyTensorStorage(n_frames), transform=t)
    for chunk in data.split(10):
        rb.extend(chunk)
        # Change the look-ahead window on the fly.
        t.n_steps = t.n_steps + 1
        last_step = chunk["steps"][-1]
        assert (rb[:]["steps"] == torch.arange(len(rb))).all()
        assert rb[:]["next", "steps"][-1] == last_step
        assert t._buffer["steps"][-1] == last_step


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down
52 changes: 39 additions & 13 deletions torchrl/envs/transforms/rb_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ class MultiStepTransform(Transform):
outputs the transformed previous ``n_steps`` with the ``T-n_steps`` current
frames.

All entries in the ``"next"`` tensordict that are not part of the ``done_keys``
or ``reward_keys`` will be replaced by their counterpart at step
``t + n_steps - 1``.

This transform is a more hyperparameter resistant version of
:class:`~torchrl.data.postprocs.postprocs.MultiStep`:
the replay buffer transform will make the multi-step transform insensitive
Expand All @@ -29,18 +33,24 @@ class MultiStepTransform(Transform):
(because collectors have no memory of previous output).

Args:
n_steps (int): Number of steps in multi-step.
n_steps (int): Number of steps in multi-step. The number of steps can be
dynamically changed by changing the ``n_steps`` attribute of this
transform.
gamma (float): Discount factor.

Keyword Args:
reward_key (NestedKey, optional): the reward key in the input tensordict.
Defaults to ``"reward"``.
done_key (NestedKey, optional): the done key in the input tensordict.
reward_keys (list of NestedKey, optional): the reward keys in the input tensordict.
The reward entries indicated by these keys will be accumulated and discounted
across ``n_steps`` steps in the future. A corresponding ``<reward_key>_orig``
entry will be written in the ``"next"`` entry of the output tensordict
to keep track of the original value of the reward.
Defaults to ``["reward"]``.
done_key (NestedKey, optional): the done key in the input tensordict, used to indicate
an end of trajectory.
Defaults to ``"done"``.
terminated_key (NestedKey, optional): the terminated key in the input tensordict.
Defaults to ``"terminated"``.
truncated_key (NestedKey, optional): the truncated key in the input tensordict.
Defaults to ``"truncated"``.
done_keys (list of NestedKey, optional): the list of end keys in the input tensordict.
All the entries indicated by these keys will be left untouched by the transform.
Defaults to ``["done", "truncated", "terminated"]``.
mask_key (NestedKey, optional): the mask key in the input tensordict.
The mask represents the valid frames in the input tensordict and
should have a shape that allows the input tensordict to be masked
Expand Down Expand Up @@ -114,9 +124,25 @@ def __init__(
self.done_keys = done_keys
self.mask_key = mask_key
self.gamma = gamma
self.buffer = None
self._buffer = None
self._validated = False

@property
def n_steps(self):
    """The look-ahead window of the transform.

    This value can be dynamically edited during training.
    """
    return self._n_steps

@n_steps.setter
def n_steps(self, value):
    """Validate and store a new look-ahead window.

    Args:
        value (int): the new number of steps; must be at least 1.

    Raises:
        ValueError: if ``value`` is not an ``int``, is a ``bool``, or is
            smaller than 1.
    """
    # ``bool`` is a subclass of ``int``: without the explicit check,
    # ``True`` would be accepted and stored as the step count.
    if isinstance(value, bool) or not isinstance(value, int) or value < 1:
        raise ValueError(
            "The value of n_steps must be a strictly positive integer."
        )
    self._n_steps = value

@property
def done_key(self):
    """The done key, as stored in the private ``_done_key`` attribute."""
    return self._done_key
Expand Down Expand Up @@ -182,10 +208,10 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
return out[..., : -self.n_steps]

def _append_tensordict(self, data):
    """Prepend the retained frames to ``data`` and refresh the retained tail.

    On the first call (``self._buffer`` is ``None``) ``data`` is returned
    as is. On later calls the frames kept from the previous call are
    concatenated in front of ``data`` along the last dimension. In both
    cases the last ``self.n_steps`` frames of the result are copied into
    ``self._buffer`` for the next call.

    Args:
        data: the incoming batch of frames.
            NOTE(review): presumably a tensordict whose last batch
            dimension is time; confirm against the caller.

    Returns:
        The concatenation of the previous buffer (if any) and ``data``.
    """
    if self._buffer is None:
        total_cat = data
    else:
        total_cat = torch.cat([self._buffer, data], -1)
    # ``n_steps`` is read at call time, so a value changed between calls
    # takes effect on the next retained slice.
    self._buffer = total_cat[..., -self.n_steps :].copy()
    return total_cat
Loading