From fa4fe1d3aee1e7526c930bbf49796f213b5ef919 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 2 Jul 2023 12:37:57 +0100 Subject: [PATCH] [BugFix] Fix brax (#1346) --- torchrl/envs/libs/brax.py | 21 +++++++++++---------- torchrl/envs/libs/jax_utils.py | 18 +++++++++++++----- torchrl/objectives/a2c.py | 2 +- torchrl/objectives/cql.py | 2 +- torchrl/objectives/ddpg.py | 2 +- torchrl/objectives/deprecated.py | 2 +- torchrl/objectives/ppo.py | 2 +- torchrl/objectives/redq.py | 2 +- torchrl/objectives/sac.py | 2 +- torchrl/objectives/td3.py | 2 +- 10 files changed, 32 insertions(+), 23 deletions(-) diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index 402106345ac..06c5e4db28a 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -386,7 +386,7 @@ def backward(ctx, _, grad_next_obs, grad_next_reward, *grad_next_qp_values): # raise RuntimeError("grad_next_qp_values") pipeline_state = dict( - zip(ctx.next_state["pipeline_state"].keys(), grad_next_qp_values) + zip(ctx.next_state.get("pipeline_state").keys(), grad_next_qp_values) ) none_keys = [] @@ -394,24 +394,25 @@ def _make_none(key, val): if val is not None: return val none_keys.append(key) - return torch.zeros_like(ctx.next_state["pipeline_state"][key]) + return torch.zeros_like(ctx.next_state.get(("pipeline_state", key))) pipeline_state = { key: _make_none(key, val) for key, val in pipeline_state.items() } - + metrics = ctx.next_state.get("metrics", None) + if metrics is None: + metrics = {} + info = ctx.next_state.get("info", None) + if info is None: + info = {} grad_next_state_td = TensorDict( source={ "pipeline_state": pipeline_state, "obs": grad_next_obs, "reward": grad_next_reward, - "done": torch.zeros_like(ctx.next_state["done"]), - "metrics": { - k: torch.zeros_like(v) for k, v in ctx.next_state["metrics"].items() - }, - "info": { - k: torch.zeros_like(v) for k, v in ctx.next_state["info"].items() - }, + "done": torch.zeros_like(ctx.next_state.get("done")), + "metrics": {k: torch.zeros_like(v) for k, v in metrics.items()}, + "info": {k: torch.zeros_like(v) for k, v in info.items()}, }, device=ctx.env.device, batch_size=ctx.env.batch_size, diff --git a/torchrl/envs/libs/jax_utils.py b/torchrl/envs/libs/jax_utils.py index 266225ce77b..5eac0a42c9e 100644 --- a/torchrl/envs/libs/jax_utils.py +++ b/torchrl/envs/libs/jax_utils.py @@ -83,20 +83,28 @@ def _object_to_tensordict(obj, device, batch_size) -> TensorDictBase: elif isinstance(value, (jnp.ndarray, np.ndarray)): t[name] = _ndarray_to_tensor(value).to(device) else: - t[name] = _object_to_tensordict(value, device, batch_size) - return make_tensordict(t, device=device, batch_size=batch_size) + nested = _object_to_tensordict(value, device, batch_size) + if nested is not None: + t[name] = nested + if len(t): + return make_tensordict(t, device=device, batch_size=batch_size) + # discard empty tensordicts + return None def _tensordict_to_object(tensordict: TensorDictBase, object_example): """Converts a TensorDict to a namedtuple or a dataclass.""" t = {} _fields = _get_object_fields(object_example) - for name, value in tensordict.items(): - example = _fields[name] + for name, example in _fields.items(): + value = tensordict.get(name, None) if isinstance(value, TensorDictBase): t[name] = _tensordict_to_object(value, example) elif value is None: - t[name] = value + if isinstance(example, dict): + t[name] = _tensordict_to_object({}, example) + else: + t[name] = None else: if value.dtype is torch.bool: value = value.to(torch.uint8) diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 51fcd70d203..89314c0603b 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -14,8 +14,8 @@ from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - _GAMMA_LMBDA_DEPREC_WARNING, _cache_values, + _GAMMA_LMBDA_DEPREC_WARNING, default_value_kwargs, distance_loss, ValueEstimators, diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 1bf01c28883..06e8de28bf7 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -21,8 +21,8 @@ from torchrl.modules import ProbabilisticActor from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - _GAMMA_LMBDA_DEPREC_WARNING, _cache_values, + _GAMMA_LMBDA_DEPREC_WARNING, default_value_kwargs, distance_loss, ValueEstimators, diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index d5e7a30d6ad..03966fd21a0 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -18,8 +18,8 @@ from torchrl.modules.tensordict_module.actors import ActorCriticWrapper from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - _GAMMA_LMBDA_DEPREC_WARNING, _cache_values, + _GAMMA_LMBDA_DEPREC_WARNING, default_value_kwargs, distance_loss, ValueEstimators, diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 2476d4dc2ba..02b82ff430c 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -21,7 +21,7 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp from torchrl.objectives import default_value_kwargs, distance_loss, ValueEstimators from torchrl.objectives.common import LossModule -from torchrl.objectives.utils import _GAMMA_LMBDA_DEPREC_WARNING, _cache_values +from torchrl.objectives.utils import _cache_values, _GAMMA_LMBDA_DEPREC_WARNING from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator try: diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index e7d145b6943..c18a8487a37 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -14,8 +14,8 @@ from torch import distributions as d from torchrl.objectives.utils import ( - _GAMMA_LMBDA_DEPREC_WARNING, _cache_values, + _GAMMA_LMBDA_DEPREC_WARNING, default_value_kwargs, distance_loss, ValueEstimators, diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index 2c69343c4f9..0429cac96b9 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -20,8 +20,8 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - _GAMMA_LMBDA_DEPREC_WARNING, _cache_values, + _GAMMA_LMBDA_DEPREC_WARNING, default_value_kwargs, distance_loss, ValueEstimators, diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 9aa993852e6..aeb9adbafea 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -22,8 +22,8 @@ from torchrl.modules.tensordict_module.actors import ActorCriticWrapper from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - _GAMMA_LMBDA_DEPREC_WARNING, _cache_values, + _GAMMA_LMBDA_DEPREC_WARNING, default_value_kwargs, distance_loss, ValueEstimators, diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index 2be74c575f8..78445593f4d 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -16,8 +16,8 @@ from torchrl.envs.utils import step_mdp from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( - _GAMMA_LMBDA_DEPREC_WARNING, _cache_values, + _GAMMA_LMBDA_DEPREC_WARNING, default_value_kwargs, distance_loss, ValueEstimators,