Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Fix brax #1346

Merged
merged 3 commits on Jul 2, 2023 (source/target branch names not captured in this extract)
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
init
  • Loading branch information
vmoens committed Jul 2, 2023
commit c582a5cfcfd7644592e5f1b11a10b08b6b1b2c76
21 changes: 11 additions & 10 deletions torchrl/envs/libs/brax.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,32 +386,33 @@ def backward(ctx, _, grad_next_obs, grad_next_reward, *grad_next_qp_values):
# raise RuntimeError("grad_next_qp_values")

pipeline_state = dict(
zip(ctx.next_state["pipeline_state"].keys(), grad_next_qp_values)
zip(ctx.next_state.get("pipeline_state").keys(), grad_next_qp_values)
)
none_keys = []

def _make_none(key, val):
if val is not None:
return val
none_keys.append(key)
return torch.zeros_like(ctx.next_state["pipeline_state"][key])
return torch.zeros_like(ctx.next_state.get(("pipeline_state", key)))

pipeline_state = {
key: _make_none(key, val) for key, val in pipeline_state.items()
}

metrics = ctx.next_state.get("metrics", None)
if metrics is None:
metrics = {}
info = ctx.next_state.get("info", None)
if info is None:
info = {}
grad_next_state_td = TensorDict(
source={
"pipeline_state": pipeline_state,
"obs": grad_next_obs,
"reward": grad_next_reward,
"done": torch.zeros_like(ctx.next_state["done"]),
"metrics": {
k: torch.zeros_like(v) for k, v in ctx.next_state["metrics"].items()
},
"info": {
k: torch.zeros_like(v) for k, v in ctx.next_state["info"].items()
},
"done": torch.zeros_like(ctx.next_state.get("done")),
"metrics": {k: torch.zeros_like(v) for k, v in metrics.items()},
"info": {k: torch.zeros_like(v) for k, v in info.items()},
},
device=ctx.env.device,
batch_size=ctx.env.batch_size,
Expand Down
18 changes: 13 additions & 5 deletions torchrl/envs/libs/jax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,20 +83,28 @@ def _object_to_tensordict(obj, device, batch_size) -> TensorDictBase:
elif isinstance(value, (jnp.ndarray, np.ndarray)):
t[name] = _ndarray_to_tensor(value).to(device)
else:
t[name] = _object_to_tensordict(value, device, batch_size)
return make_tensordict(t, device=device, batch_size=batch_size)
nested = _object_to_tensordict(value, device, batch_size)
if nested is not None:
t[name] = nested
if len(t):
return make_tensordict(t, device=device, batch_size=batch_size)
# discard empty tensordicts
return None


def _tensordict_to_object(tensordict: TensorDictBase, object_example):
"""Converts a TensorDict to a namedtuple or a dataclass."""
t = {}
_fields = _get_object_fields(object_example)
for name, value in tensordict.items():
example = _fields[name]
for name, example in _fields.items():
value = tensordict.get(name, None)
if isinstance(value, TensorDictBase):
t[name] = _tensordict_to_object(value, example)
elif value is None:
t[name] = value
if isinstance(example, dict):
t[name] = _tensordict_to_object({}, example)
else:
t[name] = None
else:
if value.dtype is torch.bool:
value = value.to(torch.uint8)
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
_GAMMA_LMBDA_DEPREC_WARNING,
_cache_values,
_GAMMA_LMBDA_DEPREC_WARNING,
default_value_kwargs,
distance_loss,
ValueEstimators,
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from torchrl.modules import ProbabilisticActor
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
_GAMMA_LMBDA_DEPREC_WARNING,
_cache_values,
_GAMMA_LMBDA_DEPREC_WARNING,
default_value_kwargs,
distance_loss,
ValueEstimators,
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
_GAMMA_LMBDA_DEPREC_WARNING,
_cache_values,
_GAMMA_LMBDA_DEPREC_WARNING,
default_value_kwargs,
distance_loss,
ValueEstimators,
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
from torchrl.objectives import default_value_kwargs, distance_loss, ValueEstimators
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import _GAMMA_LMBDA_DEPREC_WARNING, _cache_values
from torchrl.objectives.utils import _cache_values, _GAMMA_LMBDA_DEPREC_WARNING
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator

try:
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from torch import distributions as d

from torchrl.objectives.utils import (
_GAMMA_LMBDA_DEPREC_WARNING,
_cache_values,
_GAMMA_LMBDA_DEPREC_WARNING,
default_value_kwargs,
distance_loss,
ValueEstimators,
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
_GAMMA_LMBDA_DEPREC_WARNING,
_cache_values,
_GAMMA_LMBDA_DEPREC_WARNING,
default_value_kwargs,
distance_loss,
ValueEstimators,
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
_GAMMA_LMBDA_DEPREC_WARNING,
_cache_values,
_GAMMA_LMBDA_DEPREC_WARNING,
default_value_kwargs,
distance_loss,
ValueEstimators,
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from torchrl.envs.utils import step_mdp
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
_GAMMA_LMBDA_DEPREC_WARNING,
_cache_values,
_GAMMA_LMBDA_DEPREC_WARNING,
default_value_kwargs,
distance_loss,
ValueEstimators,
Expand Down