Skip to content

Commit

Permalink
[BugFix] Fix brax (#1346)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jul 2, 2023
1 parent c42ff78 commit fa4fe1d
Show file tree
Hide file tree
Showing 10 changed files with 32 additions and 23 deletions.
21 changes: 11 additions & 10 deletions torchrl/envs/libs/brax.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,32 +386,33 @@ def backward(ctx, _, grad_next_obs, grad_next_reward, *grad_next_qp_values):
# raise RuntimeError("grad_next_qp_values")

pipeline_state = dict(
zip(ctx.next_state["pipeline_state"].keys(), grad_next_qp_values)
zip(ctx.next_state.get("pipeline_state").keys(), grad_next_qp_values)
)
none_keys = []

def _make_none(key, val):
if val is not None:
return val
none_keys.append(key)
return torch.zeros_like(ctx.next_state["pipeline_state"][key])
return torch.zeros_like(ctx.next_state.get(("pipeline_state", key)))

pipeline_state = {
key: _make_none(key, val) for key, val in pipeline_state.items()
}

metrics = ctx.next_state.get("metrics", None)
if metrics is None:
metrics = {}
info = ctx.next_state.get("info", None)
if info is None:
info = {}
grad_next_state_td = TensorDict(
source={
"pipeline_state": pipeline_state,
"obs": grad_next_obs,
"reward": grad_next_reward,
"done": torch.zeros_like(ctx.next_state["done"]),
"metrics": {
k: torch.zeros_like(v) for k, v in ctx.next_state["metrics"].items()
},
"info": {
k: torch.zeros_like(v) for k, v in ctx.next_state["info"].items()
},
"done": torch.zeros_like(ctx.next_state.get("done")),
"metrics": {k: torch.zeros_like(v) for k, v in metrics.items()},
"info": {k: torch.zeros_like(v) for k, v in info.items()},
},
device=ctx.env.device,
batch_size=ctx.env.batch_size,
Expand Down
18 changes: 13 additions & 5 deletions torchrl/envs/libs/jax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,20 +83,28 @@ def _object_to_tensordict(obj, device, batch_size) -> TensorDictBase:
elif isinstance(value, (jnp.ndarray, np.ndarray)):
t[name] = _ndarray_to_tensor(value).to(device)
else:
t[name] = _object_to_tensordict(value, device, batch_size)
return make_tensordict(t, device=device, batch_size=batch_size)
nested = _object_to_tensordict(value, device, batch_size)
if nested is not None:
t[name] = nested
if len(t):
return make_tensordict(t, device=device, batch_size=batch_size)
# discard empty tensordicts
return None


def _tensordict_to_object(tensordict: TensorDictBase, object_example):
"""Converts a TensorDict to a namedtuple or a dataclass."""
t = {}
_fields = _get_object_fields(object_example)
for name, value in tensordict.items():
example = _fields[name]
for name, example in _fields.items():
value = tensordict.get(name, None)
if isinstance(value, TensorDictBase):
t[name] = _tensordict_to_object(value, example)
elif value is None:
t[name] = value
if isinstance(example, dict):
t[name] = _tensordict_to_object({}, example)
else:
t[name] = None
else:
if value.dtype is torch.bool:
value = value.to(torch.uint8)
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
_GAMMA_LMBDA_DEPREC_WARNING,
_cache_values,
_GAMMA_LMBDA_DEPREC_WARNING,
default_value_kwargs,
distance_loss,
ValueEstimators,
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from torchrl.modules import ProbabilisticActor
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
_GAMMA_LMBDA_DEPREC_WARNING,
_cache_values,
_GAMMA_LMBDA_DEPREC_WARNING,
default_value_kwargs,
distance_loss,
ValueEstimators,
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
_GAMMA_LMBDA_DEPREC_WARNING,
_cache_values,
_GAMMA_LMBDA_DEPREC_WARNING,
default_value_kwargs,
distance_loss,
ValueEstimators,
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
from torchrl.objectives import default_value_kwargs, distance_loss, ValueEstimators
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import _GAMMA_LMBDA_DEPREC_WARNING, _cache_values
from torchrl.objectives.utils import _cache_values, _GAMMA_LMBDA_DEPREC_WARNING
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator

try:
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from torch import distributions as d

from torchrl.objectives.utils import (
_GAMMA_LMBDA_DEPREC_WARNING,
_cache_values,
_GAMMA_LMBDA_DEPREC_WARNING,
default_value_kwargs,
distance_loss,
ValueEstimators,
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
_GAMMA_LMBDA_DEPREC_WARNING,
_cache_values,
_GAMMA_LMBDA_DEPREC_WARNING,
default_value_kwargs,
distance_loss,
ValueEstimators,
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
_GAMMA_LMBDA_DEPREC_WARNING,
_cache_values,
_GAMMA_LMBDA_DEPREC_WARNING,
default_value_kwargs,
distance_loss,
ValueEstimators,
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from torchrl.envs.utils import step_mdp
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
_GAMMA_LMBDA_DEPREC_WARNING,
_cache_values,
_GAMMA_LMBDA_DEPREC_WARNING,
default_value_kwargs,
distance_loss,
ValueEstimators,
Expand Down

0 comments on commit fa4fe1d

Please sign in to comment.