Skip to content

Commit

Permalink
[Performance] Faster DMC (pytorch#2002)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Mar 8, 2024
1 parent ad73733 commit 358475a
Show file tree
Hide file tree
Showing 8 changed files with 318 additions and 79 deletions.
1 change: 0 additions & 1 deletion test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -12151,7 +12151,6 @@ def test_args_kwargs_timedim(self, device):
time_dim=-3,
)[0]


v2 = vec_generalized_advantage_estimate(
gamma=gamma,
lmbda=lmbda,
Expand Down
49 changes: 49 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
from torchrl.envs.libs.gym import _has_gym, GymEnv, GymWrapper
from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv
from torchrl.envs.utils import (
_StepMDP,
_terminated_or_truncated,
check_env_specs,
check_marl_grouping,
Expand Down Expand Up @@ -1312,6 +1313,54 @@ def test_steptensordict(
if has_out:
assert out is next_tensordict

@pytest.mark.parametrize("keep_other", [True, False])
@pytest.mark.parametrize("exclude_reward", [True, False])
@pytest.mark.parametrize("exclude_done", [False, True])
@pytest.mark.parametrize("exclude_action", [False, True])
@pytest.mark.parametrize(
    "envcls",
    [
        # Each environment class is listed once: a repeated entry only
        # re-runs identical test cases.
        ContinuousActionVecMockEnv,
        CountingBatchedEnv,
        CountingEnv,
        NestedCountingEnv,
        HeterogeneousCountingEnv,
        DiscreteActionConvMockEnv,
    ],
)
def test_step_class(
    self,
    envcls,
    keep_other,
    exclude_reward,
    exclude_done,
    exclude_action,
):
    """Check that ``_StepMDP`` produces the same output as ``step_mdp``.

    A tensordict obtained from ``env.rand_step(env.reset())`` is passed
    through ``step_mdp`` (with the env's done/action/reward keys given
    explicitly) and through a ``_StepMDP`` instance built from the env with
    the same ``keep_other`` / ``exclude_*`` flags. Both results must be equal.
    """
    torch.manual_seed(0)
    env = envcls()

    tensordict = env.rand_step(env.reset())
    # The input is locked before being handed to ``step_mdp``; the same
    # (locked) tensordict is then given to the ``_StepMDP`` instance.
    out = step_mdp(
        tensordict.lock_(),
        keep_other=keep_other,
        exclude_reward=exclude_reward,
        exclude_done=exclude_done,
        exclude_action=exclude_action,
        done_keys=env.done_keys,
        action_keys=env.action_keys,
        reward_keys=env.reward_keys,
    )
    step_func = _StepMDP(
        env,
        keep_other=keep_other,
        exclude_reward=exclude_reward,
        exclude_done=exclude_done,
        exclude_action=exclude_action,
    )
    out2 = step_func(tensordict)
    assert (out == out2).all()

@pytest.mark.parametrize("nested_obs", [True, False])
@pytest.mark.parametrize("nested_action", [True, False])
@pytest.mark.parametrize("nested_done", [True, False])
Expand Down
16 changes: 10 additions & 6 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1555,9 +1555,9 @@ def __init__(
dtype = torch.get_default_dtype()

if not isinstance(low, torch.Tensor):
low = torch.as_tensor(low, dtype=dtype, device=device)
low = torch.tensor(low, dtype=dtype, device=device)
if not isinstance(high, torch.Tensor):
high = torch.as_tensor(high, dtype=dtype, device=device)
high = torch.tensor(high, dtype=dtype, device=device)
if high.device != device:
high = high.to(device)
if low.device != device:
Expand Down Expand Up @@ -3599,13 +3599,17 @@ def project(self, val: TensorDictBase) -> TensorDictBase:
def rand(self, shape=None) -> TensorDictBase:
if shape is None:
shape = torch.Size([])
_dict = {
key: self[key].rand(shape) for key in self.keys() if self[key] is not None
}
_dict = {}
for key, item in self.items():
if item is not None:
_dict[key] = item.rand(shape)
return TensorDict(
_dict,
batch_size=[*shape, *self.shape],
batch_size=torch.Size([*shape, *self.shape]),
device=self._device,
# No need to run checks since we know Composite is compliant with
# TensorDict requirements
_run_checks=False,
)

def keys(
Expand Down
62 changes: 38 additions & 24 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
from torchrl.envs.utils import (
_make_compatible_policy,
_repr_by_depth,
_StepMDP,
_terminated_or_truncated,
_update_during_reset,
get_available_libraries,
step_mdp,
)

LIBRARIES = get_available_libraries()
Expand Down Expand Up @@ -2513,6 +2513,14 @@ def rollout(
out_td.refine_names(..., "time")
return out_td

@property
def _step_mdp(self):
    """Return the ``_StepMDP`` callable attached to this instance.

    The callable is looked up in the instance ``__dict__`` under the
    ``"_step_mdp_value"`` key. When no value is stored there (or the stored
    value is ``None``), a new ``_StepMDP(self, exclude_action=False)`` is
    created, stored under that key and returned.
    """
    instance_dict = self.__dict__
    cached = instance_dict.get("_step_mdp_value")
    if cached is not None:
        return cached
    # First access: build the step function and remember it for later calls.
    cached = _StepMDP(self, exclude_action=False)
    instance_dict["_step_mdp_value"] = cached
    return cached

def _rollout_stop_early(
self,
*,
Expand Down Expand Up @@ -2543,15 +2551,8 @@ def _rollout_stop_early(
if i == max_steps - 1:
# we don't truncate, as one could potentially continue the run
break
tensordict = step_mdp(
tensordict,
keep_other=True,
exclude_action=False,
exclude_reward=True,
reward_keys=self.reward_keys,
action_keys=self.action_keys,
done_keys=self.done_keys,
)
tensordict = self._step_mdp(tensordict)

# done and truncated are in done_keys
# We read if any key is done.
any_done = _terminated_or_truncated(
Expand Down Expand Up @@ -2649,18 +2650,23 @@ def step_and_maybe_reset(
tensordict = self.step(tensordict)
# done and truncated are in done_keys
# We read if any key is done.
tensordict_ = step_mdp(
tensordict,
keep_other=True,
exclude_action=False,
exclude_reward=True,
reward_keys=self.reward_keys,
action_keys=self.action_keys,
done_keys=self.done_keys,
)
tensordict_ = self._step_mdp(tensordict)
tensordict_ = self.maybe_reset(tensordict_)
return tensordict, tensordict_

@property
def _simple_done(self):
    """Tell whether the done spec holds only the plain done entries.

    The result is ``True`` when the set of keys returned by
    ``self.full_done_spec.keys()`` is exactly
    ``{"done", "truncated", "terminated"}`` or ``{"done", "terminated"}``,
    and ``False`` otherwise. The value is stored in the instance ``__dict__``
    under ``"_simple_done_value"`` and reused on later accesses.
    """
    instance_dict = self.__dict__
    flag = instance_dict.get("_simple_done_value")
    if flag is not None:
        # Already computed (either True or False): reuse it.
        return flag
    done_key_set = set(self.full_done_spec.keys())
    flag = done_key_set in (
        {"done", "truncated", "terminated"},
        {"done", "terminated"},
    )
    instance_dict["_simple_done_value"] = flag
    return flag

def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
"""Checks the done keys of the input tensordict and, if needed, resets the environment where it is done.
Expand All @@ -2672,11 +2678,19 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
not reset and contains the new reset data where the environment was reset.
"""
any_done = _terminated_or_truncated(
tensordict,
full_done_spec=self.output_spec["full_done_spec"],
key="_reset",
)
if self._simple_done:
done = tensordict._get_str("done", default=None)
any_done = done.any()
if any_done:
tensordict._set_str(
"_reset", done.clone(), validated=True, inplace=False
)
else:
any_done = _terminated_or_truncated(
tensordict,
full_done_spec=self.output_spec["full_done_spec"],
key="_reset",
)
if any_done:
tensordict = self.reset(tensordict)
return tensordict
Expand Down
56 changes: 31 additions & 25 deletions torchrl/envs/gym_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from __future__ import annotations

import abc
import itertools
import warnings
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand All @@ -32,7 +31,8 @@ def __call__(
) -> TensorDictBase:
raise NotImplementedError

@abc.abstractproperty
@property
@abc.abstractmethod
def info_spec(self) -> Dict[str, TensorSpec]:
raise NotImplementedError

Expand Down Expand Up @@ -235,7 +235,14 @@ def read_reward(self, reward):
reward (torch.Tensor or TensorDict): reward to be mapped.
"""
return self.reward_spec.encode(reward, ignore_device=True)
if isinstance(reward, int) and reward == 0:
return self.reward_spec.zero()
reward = self.reward_spec.encode(reward, ignore_device=True)

if reward is None:
reward = torch.tensor(np.nan).expand(self.reward_spec.shape)

return reward

def read_obs(
self, observations: Union[Dict[str, Any], torch.Tensor, np.ndarray]
Expand All @@ -253,14 +260,21 @@ def read_obs(
# naming it 'state' will result in envs that have a different name for the state vector
# when queried with and without pixels
observations["observation"] = observations.pop("state")
if not isinstance(observations, (TensorDict, dict)):
(key,) = itertools.islice(self.observation_spec.keys(True, True), 1)
observations = {key: observations}
for key, val in observations.items():
observations[key] = self.observation_spec[key].encode(
val, ignore_device=True
)
# observations = self.observation_spec.encode(observations, ignore_device=True)
if not isinstance(observations, Mapping):
for key, spec in self.observation_spec.items(True, True):
observations_dict = {}
observations_dict[key] = spec.encode(observations, ignore_device=True)
# we don't check that there is only one spec because obs spec also
# contains the data spec of the info dict.
break
else:
raise RuntimeError("Could not find any element in observation_spec.")
observations = observations_dict
else:
for key, val in observations.items():
observations[key] = self.observation_spec[key].encode(
val, ignore_device=True
)
return observations

def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
Expand All @@ -277,14 +291,9 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
done,
info_dict,
) = self._output_transform(self._env.step(action_np))
if isinstance(obs, list) and len(obs) == 1:
# Until gym 0.25.2 we had rendered frames returned in lists of length 1
obs = obs[0]

if _reward is None:
_reward = self.reward_spec.zero()

reward = reward + _reward
if _reward is not None:
reward = reward + _reward

terminated, truncated, done, do_break = self.read_done(
terminated=terminated, truncated=truncated, done=done
Expand All @@ -294,17 +303,13 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:

reward = self.read_reward(reward)
obs_dict = self.read_obs(obs)

if reward is None:
reward = torch.tensor(np.nan).expand(self.reward_spec.shape)

obs_dict[self.reward_key] = reward

# if truncated/terminated is not in the keys, we just don't pass it even if it
# is defined.
if terminated is None:
terminated = done
if truncated is not None and "truncated" in self.done_keys:
if truncated is not None:
obs_dict["truncated"] = truncated
obs_dict["done"] = done
obs_dict["terminated"] = terminated
Expand All @@ -322,7 +327,8 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict_out = TensorDict(
obs_dict, batch_size=tensordict.batch_size, _run_checks=False
)
tensordict_out = tensordict_out.to(self.device, non_blocking=True)
if self.device is not None:
tensordict_out = tensordict_out.to(self.device, non_blocking=True)

if self.info_dict_reader and (info_dict is not None):
if not isinstance(info_dict, dict):
Expand Down
10 changes: 10 additions & 0 deletions torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,11 @@ def _output_transform(self, step_outputs_tuple): # noqa: F811
# if it's not a ndarray, we must return bool
# since it's not a bool, we make it so
terminated = bool(terminated)

if isinstance(observations, list) and len(observations) == 1:
# Until gym 0.25.2 we had rendered frames returned in lists of length 1
observations = observations[0]

return (observations, reward, terminated, truncated, done, info)

@implement_for("gym", "0.24", "0.26")
Expand All @@ -1083,6 +1088,11 @@ def _output_transform(self, step_outputs_tuple): # noqa: F811
# if it's not a ndarray, we must return bool
# since it's not a bool, we make it so
terminated = bool(terminated)

if isinstance(observations, list) and len(observations) == 1:
# Until gym 0.25.2 we had rendered frames returned in lists of length 1
observations = observations[0]

return (observations, reward, terminated, truncated, done, info)

@implement_for("gym", "0.26", None)
Expand Down
Loading

0 comments on commit 358475a

Please sign in to comment.