Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Performance] Some efficiency improvements #1250

Merged
merged 29 commits into from
Jun 9, 2023
Merged
2 changes: 1 addition & 1 deletion test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
device=self.device,
)
return TensorDict(
{"reward": n, "done": done, "observation": n},
{"done": done, "observation": n},
[
*leading_batch_size,
*batch_size,
Expand Down
40 changes: 34 additions & 6 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,8 +1014,9 @@ def test_seed():
@pytest.mark.parametrize("exclude_done", [True, False])
@pytest.mark.parametrize("exclude_action", [True, False])
@pytest.mark.parametrize("has_out", [True, False])
@pytest.mark.parametrize("lazy_stack", [False, True])
def test_steptensordict(
keep_other, exclude_reward, exclude_done, exclude_action, has_out
keep_other, exclude_reward, exclude_done, exclude_action, has_out, lazy_stack
):
torch.manual_seed(0)
tensordict = TensorDict(
Expand All @@ -1031,7 +1032,16 @@ def test_steptensordict(
},
[4],
)
if lazy_stack:
# let's spice this up a little bit: give the unbound tensordicts different keys
tds = tensordict.unbind(0)
tds[0]["this", "one"] = torch.zeros(2)
tds[1]["but", "not", "this", "one"] = torch.ones(2)
tds[0]["next", "this", "one"] = torch.ones(2) * 2
tensordict = torch.stack(tds, 0)
next_tensordict = TensorDict({}, [4]) if has_out else None
if has_out and lazy_stack:
next_tensordict = torch.stack(next_tensordict.unbind(0), 0)
out = step_mdp(
tensordict,
keep_other=keep_other,
Expand All @@ -1041,25 +1051,43 @@ def test_steptensordict(
next_tensordict=next_tensordict,
)
assert "ledzep" in out.keys()
assert out["ledzep"] is tensordict["next", "ledzep"]
if lazy_stack:
assert (out["ledzep"] == tensordict["next", "ledzep"]).all()
assert (out[0]["this", "one"] == 2).all()
if keep_other:
assert (out[1]["but", "not", "this", "one"] == 1).all()
else:
assert out["ledzep"] is tensordict["next", "ledzep"]
if keep_other:
assert "beatles" in out.keys()
assert out["beatles"] is tensordict["beatles"]
if lazy_stack:
assert (out["beatles"] == tensordict["beatles"]).all()
else:
assert out["beatles"] is tensordict["beatles"]
else:
assert "beatles" not in out.keys()
if not exclude_reward:
assert "reward" in out.keys()
assert out["reward"] is tensordict["next", "reward"]
if lazy_stack:
assert (out["reward"] == tensordict["next", "reward"]).all()
else:
assert out["reward"] is tensordict["next", "reward"]
else:
assert "reward" not in out.keys()
if not exclude_action:
assert "action" in out.keys()
assert out["action"] is tensordict["action"]
if lazy_stack:
assert (out["action"] == tensordict["action"]).all()
else:
assert out["action"] is tensordict["action"]
else:
assert "action" not in out.keys()
if not exclude_done:
assert "done" in out.keys()
assert out["done"] is tensordict["next", "done"]
if lazy_stack:
assert (out["done"] == tensordict["next", "done"]).all()
else:
assert out["done"] is tensordict["next", "done"]
else:
assert "done" not in out.keys()
if has_out:
Expand Down
6 changes: 3 additions & 3 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,7 +1115,7 @@ def test_one_hot_discrete_action_spec_rand(self):
torch.manual_seed(0)
action_spec = OneHotDiscreteTensorSpec(10)

sample = torch.stack([action_spec.rand() for _ in range(10000)], 0)
sample = action_spec.rand((100000,))

sample_list = sample.argmax(-1)
sample_list = [sum(sample_list == i).item() for i in range(10)]
Expand Down Expand Up @@ -2115,7 +2115,7 @@ def test_to_numpy(self, shape, stack_dim):
assert (val.numpy() == val_np).all()

with pytest.raises(AssertionError):
c.to_numpy(val + 1)
c.to_numpy(val + 1, safe=True)


class TestStackComposite:
Expand Down Expand Up @@ -2379,7 +2379,7 @@ def test_to_numpy(self):

td_fail = TensorDict({"a": torch.rand((2, 1, 3)) + 1}, [2, 1, 3])
with pytest.raises(AssertionError):
c.to_numpy(td_fail)
c.to_numpy(td_fail, safe=True)


# MultiDiscreteTensorSpec: Pending resolution of https://github.com/pytorch/pytorch/issues/100080.
Expand Down
1 change: 1 addition & 0 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1638,6 +1638,7 @@ def test_singel_step(self, shape):
td = lstm_module(td)
td_next = step_mdp(td, keep_other=True)
td_next = lstm_module(td_next)

assert not torch.isclose(
td_next["next", "hidden0"], td["next", "hidden0"]
).any()
Expand Down
60 changes: 42 additions & 18 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,18 +544,21 @@ def __setattr__(self, key, value):
value = torch.Size(value)
super().__setattr__(key, value)

def to_numpy(self, val: torch.Tensor, safe: bool = True) -> np.ndarray:
def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
"""Returns the np.ndarray correspondent of an input tensor.

Args:
val (torch.Tensor): tensor to be transformed_in to numpy
val (torch.Tensor): tensor to be converted to a numpy array.
safe (bool): boolean value indicating whether a check should be
performed on the value against the domain of the spec.
Defaults to the value of the ``CHECK_SPEC_ENCODE`` environment variable.

Returns:
a np.ndarray

"""
if safe is None:
safe = _CHECK_SPEC_ENCODE
if safe:
self.assert_is_in(val)
return val.detach().cpu().numpy()
Expand Down Expand Up @@ -916,7 +919,9 @@ def __eq__(self, other):
# requires unbind to be implemented
pass

def to_numpy(self, val: torch.Tensor, safe: bool = True) -> dict:
def to_numpy(self, val: torch.Tensor, safe: bool = None) -> dict:
if safe is None:
safe = _CHECK_SPEC_ENCODE
if safe:
if val.shape[self.dim] != len(self._specs):
raise ValueError(
Expand Down Expand Up @@ -1120,11 +1125,11 @@ def rand(self, shape=None) -> torch.Tensor:
shape = self.shape[:-1]
else:
shape = torch.Size([*shape, *self.shape[:-1]])
return torch.nn.functional.gumbel_softmax(
torch.rand(torch.Size([*shape, self.space.n]), device=self.device),
hard=True,
dim=-1,
).to(torch.long)
n = self.space.n
m = torch.randint(n, (*shape, 1), device=self.device)
out = torch.zeros((*shape, n), device=self.device, dtype=self.dtype)
out.scatter_(-1, m, 1)
return out

def encode(
self,
Expand Down Expand Up @@ -1153,7 +1158,9 @@ def encode(
val = torch.nn.functional.one_hot(val.long(), space.n)
return val

def to_numpy(self, val: torch.Tensor, safe: bool = True) -> np.ndarray:
def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
if safe is None:
safe = _CHECK_SPEC_ENCODE
if safe:
if not isinstance(val, torch.Tensor):
raise NotImplementedError
Expand Down Expand Up @@ -1211,17 +1218,20 @@ def __eq__(self, other):
and self.use_register == other.use_register
)

def to_categorical(self, val: torch.Tensor, safe: bool = True) -> torch.Tensor:
def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor:
"""Converts a given one-hot tensor in categorical format.

Args:
val (torch.Tensor, optional): One-hot tensor to convert to categorical format.
safe (bool): boolean value indicating whether a check should be
performed on the value against the domain of the spec.
Defaults to the value of the ``CHECK_SPEC_ENCODE`` environment variable.

Returns:
The categorical tensor.
"""
if safe is None:
safe = _CHECK_SPEC_ENCODE
if safe:
self.assert_is_in(val)
return val.argmax(-1)
Expand Down Expand Up @@ -1827,17 +1837,20 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
vals = self._split(val)
return torch.cat([super()._project(_val) for _val in vals], -1)

def to_categorical(self, val: torch.Tensor, safe: bool = True) -> torch.Tensor:
def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor:
"""Converts a given one-hot tensor in categorical format.

Args:
val (torch.Tensor, optional): One-hot tensor to convert to categorical format.
safe (bool): boolean value indicating whether a check should be
performed on the value against the domain of the spec.
Defaults to the value of the ``CHECK_SPEC_ENCODE`` environment variable.

Returns:
The categorical tensor.
"""
if safe is None:
safe = _CHECK_SPEC_ENCODE
if safe:
self.assert_is_in(val)
vals = self._split(val)
Expand Down Expand Up @@ -1991,20 +2004,23 @@ def __eq__(self, other):
and self.domain == other.domain
)

def to_numpy(self, val: TensorDict, safe: bool = True) -> dict:
def to_numpy(self, val: TensorDict, safe: bool = None) -> dict:
return super().to_numpy(val, safe)

def to_one_hot(self, val: torch.Tensor, safe: bool = True) -> torch.Tensor:
def to_one_hot(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor:
"""Encodes a discrete tensor from the spec domain into its one-hot correspondent.

Args:
val (torch.Tensor, optional): Tensor to one-hot encode.
safe (bool): boolean value indicating whether a check should be
performed on the value against the domain of the spec.
Defaults to the value of the ``CHECK_SPEC_ENCODE`` environment variable.

Returns:
The one-hot encoded tensor.
"""
if safe is None:
safe = _CHECK_SPEC_ENCODE
if safe:
self.assert_is_in(val)
return torch.nn.functional.one_hot(val, self.space.n)
Expand Down Expand Up @@ -2303,18 +2319,21 @@ def is_in(self, val: torch.Tensor) -> bool:
)

def to_one_hot(
self, val: torch.Tensor, safe: bool = True
self, val: torch.Tensor, safe: bool = None
) -> Union[MultiOneHotDiscreteTensorSpec, torch.Tensor]:
"""Encodes a discrete tensor from the spec domain into its one-hot correspondent.

Args:
val (torch.Tensor, optional): Tensor to one-hot encode.
safe (bool): boolean value indicating whether a check should be
performed on the value against the domain of the spec.
Defaults to the value of the ``CHECK_SPEC_ENCODE`` environment variable.

Returns:
The one-hot encoded tensor.
"""
if safe is None:
safe = _CHECK_SPEC_ENCODE
if safe:
self.assert_is_in(val)
return torch.cat(
Expand Down Expand Up @@ -2621,7 +2640,7 @@ def __getitem__(self, idx):
_idx = idx + (slice(None),) * (
len(v.shape) - len(self.shape) - protected_dims
)
indexed_specs[k] = v[_idx]
indexed_specs[k] = v[_idx] if v is not None else None

try:
device = self.device
Expand Down Expand Up @@ -2880,7 +2899,7 @@ def clone(self) -> CompositeSpec:
shape=self.shape,
)

def to_numpy(self, val: TensorDict, safe: bool = True) -> dict:
def to_numpy(self, val: TensorDict, safe: bool = None) -> dict:
return {key: self[key].to_numpy(val) for key, val in val.items()}

def zero(self, shape=None) -> TensorDictBase:
Expand Down Expand Up @@ -2999,7 +3018,10 @@ def unsqueeze(self, dim: int):
device = self._device

return CompositeSpec(
{key: value.unsqueeze(dim) for key, value in self.items()},
{
key: value.unsqueeze(dim) if value is not None else None
for key, value in self.items()
},
shape=shape,
device=device,
)
Expand Down Expand Up @@ -3090,7 +3112,9 @@ def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> N
def __eq__(self, other):
pass

def to_numpy(self, val: TensorDict, safe: bool = True) -> dict:
def to_numpy(self, val: TensorDict, safe: bool = None) -> dict:
if safe is None:
safe = _CHECK_SPEC_ENCODE
if safe:
if val.shape[self.dim] != len(self._specs):
raise ValueError(
Expand Down
22 changes: 10 additions & 12 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,11 +819,11 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
# sanity check
self._assert_tensordict_shape(tensordict)

tensordict.lock_() # make sure _step does not modify the tensordict
tensordict_out = self._step(tensordict)
# this tensordict should contain a "next" key
next_tensordict_out = tensordict_out.get("next", None)
if next_tensordict_out is None:
try:
next_tensordict_out = tensordict_out.get("next")
except KeyError:
raise RuntimeError(
"The value returned by env._step must be a tensordict where the "
"values at t+1 have been written under a 'next' entry. This "
Expand All @@ -835,7 +835,6 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
"tensordict. Consider emptying the TensorDict first (e.g. tensordict.empty() or "
"tensordict.select()) inside _step before writing new tensors onto this new instance."
)
tensordict.unlock_()

# TODO: Refactor this using reward spec
reward = next_tensordict_out.get(self.reward_key)
Expand Down Expand Up @@ -865,7 +864,6 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
if actual_done_shape != expected_done_shape:
done = done.view(expected_done_shape)
next_tensordict_out.set(self.done_key, done)

tensordict_out.set("next", next_tensordict_out)

if self.run_type_checks:
Expand Down Expand Up @@ -1015,9 +1013,9 @@ def set_state(self):
raise NotImplementedError

def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
if tensordict.batch_size != self.batch_size and (
if (
self.batch_locked or self.batch_size != torch.Size([])
):
) and tensordict.batch_size != self.batch_size:
raise RuntimeError(
f"Expected a tensordict with shape==env.shape, "
f"got {tensordict.batch_size} and {self.batch_size}"
Expand Down Expand Up @@ -1237,13 +1235,13 @@ def policy(td):
done_key = (done_key,)
for i in range(max_steps):
if auto_cast_to_device:
tensordict = tensordict.to(policy_device)
tensordict = tensordict.to(policy_device, non_blocking=True)
tensordict = policy(tensordict)
if auto_cast_to_device:
tensordict = tensordict.to(env_device)
tensordict = tensordict.to(env_device, non_blocking=True)
tensordict = self.step(tensordict)

tensordicts.append(tensordict.clone())
tensordicts.append(tensordict.clone(False))
done = tensordict.get(("next", *done_key))
truncated = tensordict.get(
("next", "truncated"),
Expand All @@ -1268,9 +1266,9 @@ def policy(td):
batch_size = self.batch_size if tensordict is None else tensordict.batch_size

out_td = torch.stack(tensordicts, len(batch_size))
out_td.refine_names(..., "time")
if return_contiguous:
return out_td.contiguous()
out_td = out_td.contiguous()
out_td.refine_names(..., "time")
return out_td

def _select_observation_keys(self, tensordict: TensorDictBase) -> Iterator[str]:
Expand Down
Loading