[Performance] Some efficiency improvements (pytorch#1250)
vmoens authored Jun 9, 2023
1 parent 0d67d39 commit 99afe8b
Showing 9 changed files with 203 additions and 112 deletions.
test/mocking_classes.py (2 changes: 1 addition & 1 deletion)
@@ -373,7 +373,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
            device=self.device,
        )
        return TensorDict(
-            {"reward": n, "done": done, "observation": n},
+            {"done": done, "observation": n},
            [
                *leading_batch_size,
                *batch_size,
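
Note that `_reset` no longer writes a "reward" entry: rewards are only produced by `step()`, under the "next" entry of its output. A minimal sketch of the resulting reset convention (illustrative shapes, assuming the `tensordict` package):

import torch
from tensordict import TensorDict

# Sketch: a reset tensordict carries a done flag and observations, but no
# reward; the batch size [4] and observation shape (4, 3) are illustrative.
reset_td = TensorDict(
    {
        "done": torch.zeros(4, 1, dtype=torch.bool),
        "observation": torch.randn(4, 3),
    },
    batch_size=[4],
)
assert "reward" not in reset_td.keys()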
test/test_env.py (40 changes: 34 additions & 6 deletions)
@@ -1014,8 +1014,9 @@ def test_seed():
@pytest.mark.parametrize("exclude_done", [True, False])
@pytest.mark.parametrize("exclude_action", [True, False])
@pytest.mark.parametrize("has_out", [True, False])
@pytest.mark.parametrize("lazy_stack", [False, True])
def test_steptensordict(
keep_other, exclude_reward, exclude_done, exclude_action, has_out
keep_other, exclude_reward, exclude_done, exclude_action, has_out, lazy_stack
):
torch.manual_seed(0)
tensordict = TensorDict(
@@ -1031,7 +1032,16 @@ def test_steptensordict(
        },
        [4],
    )
+    if lazy_stack:
+        # let's spice this a little bit
+        tds = tensordict.unbind(0)
+        tds[0]["this", "one"] = torch.zeros(2)
+        tds[1]["but", "not", "this", "one"] = torch.ones(2)
+        tds[0]["next", "this", "one"] = torch.ones(2) * 2
+        tensordict = torch.stack(tds, 0)
    next_tensordict = TensorDict({}, [4]) if has_out else None
+    if has_out and lazy_stack:
+        next_tensordict = torch.stack(next_tensordict.unbind(0), 0)
    out = step_mdp(
        tensordict,
        keep_other=keep_other,
@@ -1041,25 +1051,43 @@ def test_steptensordict(
        next_tensordict=next_tensordict,
    )
    assert "ledzep" in out.keys()
-    assert out["ledzep"] is tensordict["next", "ledzep"]
+    if lazy_stack:
+        assert (out["ledzep"] == tensordict["next", "ledzep"]).all()
+        assert (out[0]["this", "one"] == 2).all()
+        if keep_other:
+            assert (out[1]["but", "not", "this", "one"] == 1).all()
+    else:
+        assert out["ledzep"] is tensordict["next", "ledzep"]
    if keep_other:
        assert "beatles" in out.keys()
-        assert out["beatles"] is tensordict["beatles"]
+        if lazy_stack:
+            assert (out["beatles"] == tensordict["beatles"]).all()
+        else:
+            assert out["beatles"] is tensordict["beatles"]
    else:
        assert "beatles" not in out.keys()
    if not exclude_reward:
        assert "reward" in out.keys()
-        assert out["reward"] is tensordict["next", "reward"]
+        if lazy_stack:
+            assert (out["reward"] == tensordict["next", "reward"]).all()
+        else:
+            assert out["reward"] is tensordict["next", "reward"]
    else:
        assert "reward" not in out.keys()
    if not exclude_action:
        assert "action" in out.keys()
-        assert out["action"] is tensordict["action"]
+        if lazy_stack:
+            assert (out["action"] == tensordict["action"]).all()
+        else:
+            assert out["action"] is tensordict["action"]
    else:
        assert "action" not in out.keys()
    if not exclude_done:
        assert "done" in out.keys()
-        assert out["done"] is tensordict["next", "done"]
+        if lazy_stack:
+            assert (out["done"] == tensordict["next", "done"]).all()
+        else:
+            assert out["done"] is tensordict["next", "done"]
    else:
        assert "done" not in out.keys()
    if has_out:
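
The `is` assertions above only hold for a plain TensorDict: a lazily stacked tensordict rebuilds its leaves on access, so the `lazy_stack` branches fall back to value equality. A minimal sketch of the distinction (assuming the `tensordict` package):

import torch
from tensordict import TensorDict

tds = [TensorDict({"a": torch.ones(3)}, []) for _ in range(2)]
stacked = torch.stack(tds, 0)  # a lazily stacked tensordict

# Values survive the round trip...
assert (stacked["a"] == 1).all()
# ...but each access stacks the leaves anew, so tensor identity is lost.
assert stacked["a"] is not stacked["a"]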
test/test_specs.py (6 changes: 3 additions & 3 deletions)
@@ -1115,7 +1115,7 @@ def test_one_hot_discrete_action_spec_rand(self):
        torch.manual_seed(0)
        action_spec = OneHotDiscreteTensorSpec(10)

-        sample = torch.stack([action_spec.rand() for _ in range(10000)], 0)
+        sample = action_spec.rand((100000,))

        sample_list = sample.argmax(-1)
        sample_list = [sum(sample_list == i).item() for i in range(10)]
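
The updated test draws one big batch from the spec instead of looping: ten thousand Python-level `rand()` calls plus a `torch.stack` collapse into a single vectorized draw. A sketch of the two paths (assuming torchrl's spec API):

import torch
from torchrl.data import OneHotDiscreteTensorSpec

spec = OneHotDiscreteTensorSpec(10)

# slow: one Python round trip (and small kernel launches) per sample
slow = torch.stack([spec.rand() for _ in range(10000)], 0)
# fast: a single batched draw
fast = spec.rand((100000,))

assert fast.shape == torch.Size([100000, 10])
assert (fast.sum(-1) == 1).all()  # every row is one-hot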
@@ -2115,7 +2115,7 @@ def test_to_numpy(self, shape, stack_dim):
        assert (val.numpy() == val_np).all()

        with pytest.raises(AssertionError):
-            c.to_numpy(val + 1)
+            c.to_numpy(val + 1, safe=True)


class TestStackComposite:
@@ -2379,7 +2379,7 @@ def test_to_numpy(self):

        td_fail = TensorDict({"a": torch.rand((2, 1, 3)) + 1}, [2, 1, 3])
        with pytest.raises(AssertionError):
-            c.to_numpy(td_fail)
+            c.to_numpy(td_fail, safe=True)


# MultiDiscreteTensorSpec: Pending resolution of https://github.com/pytorch/pytorch/issues/100080.
test/test_tensordictmodules.py (1 change: 1 addition & 0 deletions)
@@ -1638,6 +1638,7 @@ def test_singel_step(self, shape):
        td = lstm_module(td)
        td_next = step_mdp(td, keep_other=True)
        td_next = lstm_module(td_next)
+
        assert not torch.isclose(
            td_next["next", "hidden0"], td["next", "hidden0"]
        ).any()
torchrl/data/tensor_specs.py (60 changes: 42 additions & 18 deletions)
@@ -544,18 +544,21 @@ def __setattr__(self, key, value):
            value = torch.Size(value)
        super().__setattr__(key, value)

-    def to_numpy(self, val: torch.Tensor, safe: bool = True) -> np.ndarray:
+    def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
        """Returns the np.ndarray correspondent of an input tensor.

        Args:
-            val (torch.Tensor): tensor to be transformed_in to numpy
+            val (torch.Tensor): tensor to be transformed_in to numpy.
            safe (bool): boolean value indicating whether a check should be
                performed on the value against the domain of the spec.
+                Defaults to the value of the ``CHECK_SPEC_ENCODE`` environment variable.

        Returns:
            a np.ndarray

        """
+        if safe is None:
+            safe = _CHECK_SPEC_ENCODE
        if safe:
            self.assert_is_in(val)
        return val.detach().cpu().numpy()
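
`safe` now defaults to `None` and falls back to the module-level `_CHECK_SPEC_ENCODE` flag, so the potentially costly `assert_is_in` domain check can be switched off process-wide rather than paid on every call. The flag's definition sits outside this diff; per the docstring it is read from the `CHECK_SPEC_ENCODE` environment variable, roughly as in this hypothetical sketch (the helper and the default value are assumptions; only the variable name comes from the docstring):

import os

def _strtobool(value: str) -> bool:
    # hypothetical helper: interpret common truthy strings
    return value.strip().lower() in ("y", "yes", "t", "true", "on", "1")

# assumed default; the real module may choose differently
_CHECK_SPEC_ENCODE = _strtobool(os.environ.get("CHECK_SPEC_ENCODE", "1"))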
@@ -916,7 +919,9 @@ def __eq__(self, other):
        # requires unbind to be implemented
        pass

-    def to_numpy(self, val: torch.Tensor, safe: bool = True) -> dict:
+    def to_numpy(self, val: torch.Tensor, safe: bool = None) -> dict:
+        if safe is None:
+            safe = _CHECK_SPEC_ENCODE
        if safe:
            if val.shape[self.dim] != len(self._specs):
                raise ValueError(
@@ -1120,11 +1125,11 @@ def rand(self, shape=None) -> torch.Tensor:
            shape = self.shape[:-1]
        else:
            shape = torch.Size([*shape, *self.shape[:-1]])
-        return torch.nn.functional.gumbel_softmax(
-            torch.rand(torch.Size([*shape, self.space.n]), device=self.device),
-            hard=True,
-            dim=-1,
-        ).to(torch.long)
+        n = self.space.n
+        m = torch.randint(n, (*shape, 1), device=self.device)
+        out = torch.zeros((*shape, n), device=self.device, dtype=self.dtype)
+        out.scatter_(-1, m, 1)
+        return out

    def encode(
        self,
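
This is the core of the speed-up in `rand`: rather than materializing a `(*shape, n)` noise tensor and running `gumbel_softmax` with a hard argmax, the new code draws categorical indices directly and scatters ones into a zero tensor of the spec's dtype. Either way each sample is a random one-hot vector. A standalone sketch of the two paths (plain PyTorch, illustrative sizes):

import torch

shape, n = (100000,), 10

# old path: uniform noise -> gumbel_softmax(hard=True) -> cast to long
old = torch.nn.functional.gumbel_softmax(
    torch.rand(*shape, n), hard=True, dim=-1
).to(torch.long)

# new path: draw indices, scatter ones; no softmax, no float intermediate
idx = torch.randint(n, (*shape, 1))
new = torch.zeros((*shape, n), dtype=torch.long).scatter_(-1, idx, 1)

assert old.shape == new.shape == torch.Size([100000, 10])
assert (new.sum(-1) == 1).all()  # exactly one hot entry per row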
@@ -1153,7 +1158,9 @@ def encode(
        val = torch.nn.functional.one_hot(val.long(), space.n)
        return val

-    def to_numpy(self, val: torch.Tensor, safe: bool = True) -> np.ndarray:
+    def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
+        if safe is None:
+            safe = _CHECK_SPEC_ENCODE
        if safe:
            if not isinstance(val, torch.Tensor):
                raise NotImplementedError
@@ -1211,17 +1218,20 @@ def __eq__(self, other):
            and self.use_register == other.use_register
        )

-    def to_categorical(self, val: torch.Tensor, safe: bool = True) -> torch.Tensor:
+    def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor:
        """Converts a given one-hot tensor in categorical format.

        Args:
            val (torch.Tensor, optional): One-hot tensor to convert in categorical format.
            safe (bool): boolean value indicating whether a check should be
                performed on the value against the domain of the spec.
+                Defaults to the value of the ``CHECK_SPEC_ENCODE`` environment variable.

        Returns:
            The categorical tensor.

        """
+        if safe is None:
+            safe = _CHECK_SPEC_ENCODE
        if safe:
            self.assert_is_in(val)
        return val.argmax(-1)
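
`to_categorical` is simply an `argmax` over the one-hot dimension, the inverse of the `torch.nn.functional.one_hot` call used by `encode`. A quick round-trip sketch (plain PyTorch):

import torch

cat = torch.tensor([2, 0, 1])
one_hot = torch.nn.functional.one_hot(cat, num_classes=3)
assert (one_hot.argmax(-1) == cat).all()  # the round trip recovers the categories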
@@ -1827,17 +1837,20 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
        vals = self._split(val)
        return torch.cat([super()._project(_val) for _val in vals], -1)

-    def to_categorical(self, val: torch.Tensor, safe: bool = True) -> torch.Tensor:
+    def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor:
        """Converts a given one-hot tensor in categorical format.

        Args:
            val (torch.Tensor, optional): One-hot tensor to convert in categorical format.
            safe (bool): boolean value indicating whether a check should be
                performed on the value against the domain of the spec.
+                Defaults to the value of the ``CHECK_SPEC_ENCODE`` environment variable.

        Returns:
            The categorical tensor.

        """
+        if safe is None:
+            safe = _CHECK_SPEC_ENCODE
        if safe:
            self.assert_is_in(val)
        vals = self._split(val)
@@ -1991,20 +2004,23 @@ def __eq__(self, other):
            and self.domain == other.domain
        )

-    def to_numpy(self, val: TensorDict, safe: bool = True) -> dict:
+    def to_numpy(self, val: TensorDict, safe: bool = None) -> dict:
        return super().to_numpy(val, safe)

-    def to_one_hot(self, val: torch.Tensor, safe: bool = True) -> torch.Tensor:
+    def to_one_hot(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor:
        """Encodes a discrete tensor from the spec domain into its one-hot correspondent.

        Args:
            val (torch.Tensor, optional): Tensor to one-hot encode.
            safe (bool): boolean value indicating whether a check should be
                performed on the value against the domain of the spec.
+                Defaults to the value of the ``CHECK_SPEC_ENCODE`` environment variable.

        Returns:
            The one-hot encoded tensor.

        """
+        if safe is None:
+            safe = _CHECK_SPEC_ENCODE
        if safe:
            self.assert_is_in(val)
        return torch.nn.functional.one_hot(val, self.space.n)
@@ -2303,18 +2319,21 @@ def is_in(self, val: torch.Tensor) -> bool:
        )

    def to_one_hot(
-        self, val: torch.Tensor, safe: bool = True
+        self, val: torch.Tensor, safe: bool = None
    ) -> Union[MultiOneHotDiscreteTensorSpec, torch.Tensor]:
        """Encodes a discrete tensor from the spec domain into its one-hot correspondent.

        Args:
            val (torch.Tensor, optional): Tensor to one-hot encode.
            safe (bool): boolean value indicating whether a check should be
                performed on the value against the domain of the spec.
+                Defaults to the value of the ``CHECK_SPEC_ENCODE`` environment variable.

        Returns:
            The one-hot encoded tensor.

        """
+        if safe is None:
+            safe = _CHECK_SPEC_ENCODE
        if safe:
            self.assert_is_in(val)
        return torch.cat(
@@ -2621,7 +2640,7 @@ def __getitem__(self, idx):
            _idx = idx + (slice(None),) * (
                len(v.shape) - len(self.shape) - protected_dims
            )
-            indexed_specs[k] = v[_idx]
+            indexed_specs[k] = v[_idx] if v is not None else None

        try:
            device = self.device
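
`CompositeSpec` entries may be `None` placeholders (a key that is declared but not yet specified), so `__getitem__` here, and `unsqueeze` further down, now propagate `None` instead of trying to index or reshape it. A sketch of the behavior (assuming torchrl's `CompositeSpec` semantics):

from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

spec = CompositeSpec(
    observation=UnboundedContinuousTensorSpec(shape=(4, 3)),
    reward=None,  # declared but unspecified entry
    shape=(4,),
)
sub = spec[0]  # indexes "observation"; "reward" stays None instead of raising
assert sub["observation"].shape == (3,)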
@@ -2880,7 +2899,7 @@ def clone(self) -> CompositeSpec:
            shape=self.shape,
        )

-    def to_numpy(self, val: TensorDict, safe: bool = True) -> dict:
+    def to_numpy(self, val: TensorDict, safe: bool = None) -> dict:
        return {key: self[key].to_numpy(val) for key, val in val.items()}

    def zero(self, shape=None) -> TensorDictBase:
@@ -2999,7 +3018,10 @@ def unsqueeze(self, dim: int):
            device = self._device

        return CompositeSpec(
-            {key: value.unsqueeze(dim) for key, value in self.items()},
+            {
+                key: value.unsqueeze(dim) if value is not None else None
+                for key, value in self.items()
+            },
            shape=shape,
            device=device,
        )
@@ -3090,7 +3112,9 @@ def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> N
    def __eq__(self, other):
        pass

-    def to_numpy(self, val: TensorDict, safe: bool = True) -> dict:
+    def to_numpy(self, val: TensorDict, safe: bool = None) -> dict:
+        if safe is None:
+            safe = _CHECK_SPEC_ENCODE
        if safe:
            if val.shape[self.dim] != len(self._specs):
                raise ValueError(
torchrl/envs/common.py (22 changes: 10 additions & 12 deletions)
@@ -819,11 +819,11 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
        # sanity check
        self._assert_tensordict_shape(tensordict)

-        tensordict.lock_()  # make sure _step does not modify the tensordict
        tensordict_out = self._step(tensordict)
        # this tensordict should contain a "next" key
-        next_tensordict_out = tensordict_out.get("next", None)
-        if next_tensordict_out is None:
+        try:
+            next_tensordict_out = tensordict_out.get("next")
+        except KeyError:
            raise RuntimeError(
                "The value returned by env._step must be a tensordict where the "
                "values at t+1 have been written under a 'next' entry. This "
@@ -835,7 +835,6 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
"tensordict. Consider emptying the TensorDict first (e.g. tensordict.empty() or "
"tensordict.select()) inside _step before writing new tensors onto this new instance."
)
tensordict.unlock_()

# TODO: Refactor this using reward spec
reward = next_tensordict_out.get(self.reward_key)
@@ -865,7 +864,6 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
        if actual_done_shape != expected_done_shape:
            done = done.view(expected_done_shape)
            next_tensordict_out.set(self.done_key, done)
-
        tensordict_out.set("next", next_tensordict_out)

        if self.run_type_checks:
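
Two hot-path changes in `step`: the `lock_()`/`unlock_()` pair around `_step` is gone, and the missing-"next" check moves from get-with-default plus an `is None` test to try/except, relying on `TensorDict.get` raising `KeyError` for absent keys, as the new code does. The pattern in isolation (a toy stand-in, assuming the `tensordict` package):

from tensordict import TensorDict

td = TensorDict({}, [])  # stand-in for a _step output that forgot "next"
try:
    next_td = td.get("next")  # raises KeyError when the key is absent
except KeyError:
    print("env._step must write its outputs under a 'next' entry")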
@@ -1015,9 +1013,9 @@ def set_state(self):
        raise NotImplementedError

    def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
-        if tensordict.batch_size != self.batch_size and (
+        if (
            self.batch_locked or self.batch_size != torch.Size([])
-        ):
+        ) and tensordict.batch_size != self.batch_size:
            raise RuntimeError(
                f"Expected a tensordict with shape==env.shape, "
                f"got {tensordict.batch_size} and {self.batch_size}"
@@ -1237,13 +1235,13 @@ def policy(td):
            done_key = (done_key,)
        for i in range(max_steps):
            if auto_cast_to_device:
-                tensordict = tensordict.to(policy_device)
+                tensordict = tensordict.to(policy_device, non_blocking=True)
            tensordict = policy(tensordict)
            if auto_cast_to_device:
-                tensordict = tensordict.to(env_device)
+                tensordict = tensordict.to(env_device, non_blocking=True)
            tensordict = self.step(tensordict)

-            tensordicts.append(tensordict.clone())
+            tensordicts.append(tensordict.clone(False))
            done = tensordict.get(("next", *done_key))
            truncated = tensordict.get(
                ("next", "truncated"),
@@ -1268,9 +1266,9 @@ def policy(td):
        batch_size = self.batch_size if tensordict is None else tensordict.batch_size

        out_td = torch.stack(tensordicts, len(batch_size))
-        out_td.refine_names(..., "time")
        if return_contiguous:
-            return out_td.contiguous()
+            out_td = out_td.contiguous()
+        out_td.refine_names(..., "time")
        return out_td

    def _select_observation_keys(self, tensordict: TensorDictBase) -> Iterator[str]: