[Feature] Refactor CatFrames using a proper preallocated buffer (pyto…
vmoens authored Jan 19, 2023
1 parent 4a81a6c commit c0bc12a
Showing 3 changed files with 109 additions and 30 deletions.
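The diff below swaps CatFrames' Python-list cache of past observations for a preallocated tensor buffer that is rolled in place at every step. As a rough orientation, here is a minimal standalone sketch of that buffering strategy (simplified and hypothetical, not torchrl's actual class; the names and shapes are illustrative):

import torch

N, d = 4, 1                          # number of stacked frames, channels per frame
buffer = torch.zeros(N * d, 3, 3)    # preallocated once, reused at every step

def push(frame):
    # shift the stored frames d slots towards the front of the stack dim...
    buffer.copy_(torch.roll(buffer, shifts=-d, dims=0))
    # ...then overwrite the freed last slot with the newest observation
    buffer[-d:].copy_(frame)
    return buffer.clone()            # a copy is what gets written back out

stacked = push(torch.ones(d, 3, 3))  # untouched slots stay zero until N frames arrive
assert stacked.shape[0] == N * d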
1 change: 0 additions & 1 deletion test/test_shared.py
@@ -3,7 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
- import sys
import time
import warnings

39 changes: 25 additions & 14 deletions test/test_transforms.py
@@ -1289,44 +1289,55 @@ def test_catframes_transform_observation_spec(self):
)

@pytest.mark.parametrize("device", get_available_devices())
- def test_catframes_buffer_check_latest_frame(self, device):
+ @pytest.mark.parametrize("d", range(1, 4))
+ def test_catframes_buffer_check_latest_frame(self, device, d):
key1 = "first key"
key2 = "second key"
N = 4
keys = [key1, key2]
- key1_tensor = torch.zeros(1, 1, 3, 3, device=device)
- key2_tensor = torch.ones(1, 1, 3, 3, device=device)
+ key1_tensor = torch.ones(1, d, 3, 3, device=device) * 2
+ key2_tensor = torch.ones(1, d, 3, 3, device=device)
key_tensors = [key1_tensor, key2_tensor]
td = TensorDict(dict(zip(keys, key_tensors)), [1], device=device)
cat_frames = CatFrames(N=N, in_keys=keys)

- cat_frames(td)
- latest_frame = td.get(key2)
+ tdclone = cat_frames(td.clone())
+ latest_frame = tdclone.get(key2)

+ assert latest_frame.shape[1] == N * d
+ assert (latest_frame[0, :-d] == 0).all()
+ assert (latest_frame[0, -d:] == 1).all()

+ tdclone = cat_frames(td.clone())
+ latest_frame = tdclone.get(key2)

- assert latest_frame.shape[1] == N
- for i in range(0, N - 1):
-     assert torch.equal(latest_frame[0][i], key2_tensor[0][0])
- assert torch.equal(latest_frame[0][N - 1], key1_tensor[0][0])
+ assert latest_frame.shape[1] == N * d
+ assert (latest_frame[0, : -2 * d] == 0).all()
+ assert (latest_frame[0, -2 * d :] == 1).all()

@pytest.mark.parametrize("device", get_available_devices())
def test_catframes_reset(self, device):
key1 = "first key"
key2 = "second key"
N = 4
keys = [key1, key2]
- key1_tensor = torch.zeros(1, 1, 3, 3, device=device)
- key2_tensor = torch.ones(1, 1, 3, 3, device=device)
+ key1_tensor = torch.randn(1, 1, 3, 3, device=device)
+ key2_tensor = torch.randn(1, 1, 3, 3, device=device)
key_tensors = [key1_tensor, key2_tensor]
td = TensorDict(dict(zip(keys, key_tensors)), [1], device=device)
cat_frames = CatFrames(N=N, in_keys=keys)

cat_frames(td)
- buffer_length1 = len(cat_frames.buffer)
+ buffer = getattr(cat_frames, f"_cat_buffers_{key1}")

passed_back_td = cat_frames.reset(td)

- assert buffer_length1 == 2
assert td is passed_back_td
- assert 0 == len(cat_frames.buffer)
+ assert (0 == buffer).all()

+ _ = cat_frames._call(td)
+ assert (0 == buffer[..., :-1, :, :]).all()
+ assert (0 != buffer[..., -1:, :, :]).all()

@pytest.mark.parametrize("device", get_available_devices())
def test_finitetensordictcheck(self, device):
99 changes: 84 additions & 15 deletions torchrl/envs/transforms/transforms.py
@@ -1521,27 +1521,107 @@ class CatFrames(ObservationTransform):
cat_dim (int, optional): dimension along which to concatenate the
observations. Default is `cat_dim=-3`.
in_keys (list of str, optional): keys pointing to the frames that have
- to be concatenated.
+ to be concatenated. Defaults to ["pixels"].
+ out_keys (list of str, optional): keys pointing to where the output
+ has to be written. Defaults to the value of `in_keys`.
"""

inplace = False
+ _CAT_DIM_ERR = (
+     "cat_dim must be < 0 to accommodate tensordicts of "
+     "different batch-sizes (since negative dims are batch invariant)."
+ )

def __init__(
self,
N: int = 4,
cat_dim: int = -3,
in_keys: Optional[Sequence[str]] = None,
+ out_keys: Optional[Sequence[str]] = None,
):
if in_keys is None:
in_keys = IMAGE_KEYS
- super().__init__(in_keys=in_keys)
+ super().__init__(in_keys=in_keys, out_keys=out_keys)
self.N = N
+ if cat_dim > 0:
+     raise ValueError(self._CAT_DIM_ERR)
self.cat_dim = cat_dim
- self.buffer = []

def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
- self.buffer = []
+ """Resets _buffers."""
+ # Non-batched environments
+ if len(tensordict.batch_size) < 1 or tensordict.batch_size[0] == 1:
+     for in_key in self.in_keys:
+         buffer_name = f"_cat_buffers_{in_key}"
+         try:
+             buffer = getattr(self, buffer_name)
+             buffer.fill_(0.0)
+         except AttributeError:
+             # we'll instantiate later, when needed
+             pass

+ # Batched environments
+ else:
+     _reset = tensordict.get(
+         "_reset",
+         torch.ones(
+             tensordict.batch_size,
+             dtype=torch.bool,
+             device=tensordict.device,
+         ),
+     )
+     for in_key in self.in_keys:
+         buffer_name = f"_cat_buffers_{in_key}"
+         try:
+             buffer = getattr(self, buffer_name)
+             buffer[_reset] = 0.0
+         except AttributeError:
+             # we'll instantiate later, when needed
+             pass

return tensordict

+ def _make_missing_buffer(self, data, buffer_name):
+     shape = list(data.shape)
+     d = shape[self.cat_dim]
+     shape[self.cat_dim] = d * self.N
+     shape = torch.Size(shape)
+     self.register_buffer(
+         buffer_name,
+         torch.zeros(
+             shape,
+             dtype=data.dtype,
+             device=data.device,
+         ),
+     )
+     buffer = getattr(self, buffer_name)
+     return buffer

+ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
+     """Update the episode tensordict with the concatenated frames."""
+     for in_key, out_key in zip(self.in_keys, self.out_keys):
+         # Lazy init of buffers
+         buffer_name = f"_cat_buffers_{in_key}"
+         data = tensordict[in_key]
+         d = data.size(self.cat_dim)
+         try:
+             buffer = getattr(self, buffer_name)
+             # shift the stored frames by d along cat_dim to free the last slot
+             buffer.copy_(torch.roll(buffer, shifts=-d, dims=self.cat_dim))
+         except AttributeError:
+             buffer = self._make_missing_buffer(data, buffer_name)
+         # add new obs
+         idx = self.cat_dim
+         if idx < 0:
+             idx = buffer.ndimension() + idx
+         else:
+             raise ValueError(self._CAT_DIM_ERR)
+         idx = [slice(None, None) for _ in range(idx)] + [slice(-d, None)]
+         buffer[idx].copy_(data)
+         # add to tensordict
+         tensordict.set(out_key, buffer.clone())

+     return tensordict

@_apply_to_composite
@@ -1557,17 +1637,6 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
observation_spec.shape = torch.Size(shape)
return observation_spec

- def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
-     self.buffer.append(obs)
-     self.buffer = self.buffer[-self.N :]
-     buffer = list(reversed(self.buffer))
-     buffer = [buffer[0]] * (self.N - len(buffer)) + buffer
-     if len(buffer) != self.N:
-         raise RuntimeError(
-             f"actual buffer length ({buffer}) differs from expected (" f"{self.N})"
-         )
-     return torch.cat(buffer, self.cat_dim)

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(N={self.N}, cat_dim"
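For context, a minimal usage sketch of the constructor arguments documented above (N, cat_dim, in_keys, out_keys). The import paths and the 84x84 observation shape are assumptions; adjust them to the installed torchrl/tensordict versions:

import torch
from tensordict import TensorDict                 # assumed import path
from torchrl.envs.transforms import CatFrames     # assumed import path

cat_frames = CatFrames(N=4, cat_dim=-3, in_keys=["pixels"])

# one batch element holding a single 1x84x84 frame (key and shape are illustrative)
td = TensorDict({"pixels": torch.rand(1, 1, 84, 84)}, batch_size=[1])

td = cat_frames(td)
# the last N frames are stacked along cat_dim; slots not yet filled remain zeros
print(td["pixels"].shape)  # torch.Size([1, 4, 84, 84])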
