Skip to content

Commit

Permalink
[BugFix] Fix jumanji (#2064)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Apr 7, 2024
1 parent 8dc35be commit cf685b7
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions torchrl/envs/libs/jumanji.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib.util

from typing import Dict, Optional, Tuple, Union

import numpy as np
import torch
from packaging import version
from tensordict import TensorDict, TensorDictBase
from torchrl.envs.utils import _classproperty

Expand Down Expand Up @@ -307,6 +307,9 @@ def available_envs(cls):
def lib(self):
import jumanji

if version.parse(jumanji.__version__) < version.parse("1.0.0"):
raise ImportError("jumanji version must be >= 1.0.0")

return jumanji

def __init__(self, env: "jumanji.env.Environment" = None, **kwargs): # noqa: F821
Expand Down Expand Up @@ -356,15 +359,15 @@ def _make_state_spec(self, env) -> TensorSpec:

def _make_action_spec(self, env) -> TensorSpec:
action_spec = _jumanji_to_torchrl_spec_transform(
env.action_spec(), device=self.device
env.action_spec, device=self.device
)
action_spec = action_spec.expand(*self.batch_size, *action_spec.shape)
return action_spec

def _make_observation_spec(self, env) -> TensorSpec:
jumanji = self.lib

spec = env.observation_spec()
spec = env.observation_spec
new_spec = _jumanji_to_torchrl_spec_transform(spec, device=self.device)
if isinstance(spec, jumanji.specs.Array):
return CompositeSpec(observation=new_spec).expand(self.batch_size)
Expand All @@ -377,7 +380,7 @@ def _make_observation_spec(self, env) -> TensorSpec:

def _make_reward_spec(self, env) -> TensorSpec:
reward_spec = _jumanji_to_torchrl_spec_transform(
env.reward_spec(), device=self.device
env.reward_spec, device=self.device
)
if not len(reward_spec.shape):
reward_spec.shape = torch.Size([1])
Expand Down

0 comments on commit cf685b7

Please sign in to comment.