Skip to content

Commit

Permalink
[Feature] Jumanji from_pixels=True (#2129)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Apr 30, 2024
1 parent 741947a commit 7b9305d
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 29 deletions.
8 changes: 4 additions & 4 deletions .github/unittest/linux_libs/scripts_jumanji/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ git submodule sync && git submodule update --init --recursive
printf "Installing PyTorch with cu121"
if [[ "$TORCH_VERSION" == "nightly" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U
else
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 -U
fi
elif [[ "$TORCH_VERSION" == "stable" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
else
pip3 install torch --index-url https://download.pytorch.org/whl/cu121
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121
fi
else
printf "Failed to install pytorch"
Expand Down
2 changes: 1 addition & 1 deletion .github/unittest/linux_libs/scripts_jumanji/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,6 @@ export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON
# this workflow only tests the libs
python -c "import jumanji"

python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestJumanji --error-for-skips
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestJumanji --error-for-skips --runslow
coverage combine
coverage xml -i
31 changes: 24 additions & 7 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1327,14 +1327,15 @@ def test_habitat_render(self, envname, from_pixels):
assert "pixels" in rollout.keys()


def _jumanji_envs():
    """Return the Jumanji environment names used to parametrize the tests.

    When jumanji is available, this is the ``[-10:-5]`` slice of
    ``JumanjiEnv.available_envs``; otherwise an empty tuple is returned so
    that the parametrization can still be built at collection time.
    """
    if _has_jumanji:
        return JumanjiEnv.available_envs[-10:-5]
    return ()


@pytest.mark.skipif(not _has_jumanji, reason="jumanji not installed")
@pytest.mark.parametrize(
"envname",
[
"TSP-v1",
"Snake-v1",
],
)
@pytest.mark.slow
@pytest.mark.parametrize("envname", _jumanji_envs())
class TestJumanji:
def test_jumanji_seeding(self, envname):
final_seed = []
Expand Down Expand Up @@ -1413,6 +1414,22 @@ def test_jumanji_consistency(self, envname, batch_size):
t2 = torch.tensor(onp.asarray(t2)).view_as(t1)
torch.testing.assert_close(t1, t2)

@pytest.mark.parametrize("batch_size", [[3], []])
def test_jumanji_rendering(self, envname, batch_size):
    """Check that ``from_pixels=True`` yields a usable ``"pixels"`` entry.

    The test runs both with a batch size (``[3]``) and without one (``[]``).
    """
    env = JumanjiEnv(envname, from_pixels=True, batch_size=batch_size)
    env.set_seed(0)
    # Smoke check: the rendering transform must accept the base env's
    # observation spec; the returned spec is not inspected here.
    env.transform.transform_observation_spec(env.base_env.observation_spec)

    rollout = env.rollout(10)
    frames = rollout["pixels"]
    # The entry may not come back as a tensor; normalise it before the
    # value checks below.
    if not isinstance(frames, torch.Tensor):
        frames = torch.as_tensor(np.asarray(frames))
    # More than one distinct value means the frames are not a constant image.
    assert frames.unique().numel() > 1
    assert frames.dtype == torch.uint8

    check_env_specs(env)


ENVPOOL_CLASSIC_CONTROL_ENVS = [
PENDULUM_VERSIONED(),
Expand Down
10 changes: 5 additions & 5 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1950,13 +1950,13 @@ def clone(self) -> NonTensorSpec:
return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype)

def rand(self, shape):
return NonTensorData(data=None, shape=self.shape, device=self.device)
return NonTensorData(data=None, batch_size=self.shape, device=self.device)

def zero(self, shape):
return NonTensorData(data=None, shape=self.shape, device=self.device)
def zero(self, batch_size):
return NonTensorData(data=None, batch_size=self.shape, device=self.device)

def one(self, shape):
return NonTensorData(data=None, shape=self.shape, device=self.device)
def one(self, batch_size):
return NonTensorData(data=None, batch_size=self.shape, device=self.device)

def is_in(self, val: torch.Tensor) -> bool:
shape = torch.broadcast_shapes(self.shape, val.shape)
Expand Down
157 changes: 149 additions & 8 deletions torchrl/envs/libs/jumanji.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import importlib.util
from typing import Dict, Optional, Tuple, Union

import numpy as np
import torch
from packaging import version
from tensordict import TensorDict, TensorDictBase

from torchrl.envs.common import _EnvPostInit
from torchrl.envs.utils import _classproperty

_has_jumanji = importlib.util.find_spec("jumanji") is not None
Expand All @@ -18,6 +22,8 @@
CompositeSpec,
DEVICE_TYPING,
DiscreteTensorSpec,
MultiDiscreteTensorSpec,
MultiOneHotDiscreteTensorSpec,
OneHotDiscreteTensorSpec,
TensorSpec,
UnboundedContinuousTensorSpec,
Expand Down Expand Up @@ -61,6 +67,17 @@ def _jumanji_to_torchrl_spec_transform(
if dtype is None:
dtype = numpy_to_torch_dtype_dict[spec.dtype]
return action_space_cls(spec.num_values, dtype=dtype, device=device)
if isinstance(spec, jumanji.specs.MultiDiscreteArray):
action_space_cls = (
MultiDiscreteTensorSpec
if categorical_action_encoding
else MultiOneHotDiscreteTensorSpec
)
if dtype is None:
dtype = numpy_to_torch_dtype_dict[spec.dtype]
return action_space_cls(
torch.as_tensor(np.asarray(spec.num_values)), dtype=dtype, device=device
)
elif isinstance(spec, jumanji.specs.BoundedArray):
shape = spec.shape
if dtype is None:
Expand Down Expand Up @@ -98,7 +115,15 @@ def _jumanji_to_torchrl_spec_transform(
raise TypeError(f"Unsupported spec type {type(spec)}")


class JumanjiWrapper(GymLikeEnv):
class _JumanjiMakeRender(_EnvPostInit):
    """Metaclass hook that enables rendering right after construction.

    Once the parent metaclass has built the environment, the instance is
    replaced by the result of its ``make_render()`` method when its
    ``from_pixels`` attribute is truthy; otherwise it is returned unchanged.
    """

    def __call__(self, *args, **kwargs):
        env = super().__call__(*args, **kwargs)
        return env.make_render() if env.from_pixels else env


class JumanjiWrapper(GymLikeEnv, metaclass=_JumanjiMakeRender):
"""Jumanji environment wrapper.
Jumanji offers a vectorized simulation framework based on Jax.
Expand All @@ -120,7 +145,10 @@ class JumanjiWrapper(GymLikeEnv):
Defaults to ``False``.
Keyword Args:
from_pixels (bool, optional): Not yet supported.
from_pixels (bool, optional): Whether the environment should render its output.
This will drastically impact the environment throughput. Only the first environment
will be rendered. See :meth:`~torchrl.envs.JumanjiWrapper.render` for more information.
Defaults to `False`.
frame_skip (int, optional): if provided, indicates for how many steps the
same action is to be repeated. The observation returned will be the
last observation of the sequence, whereas the reward will be the sum
Expand Down Expand Up @@ -301,22 +329,27 @@ class JumanjiWrapper(GymLikeEnv):
def available_envs(cls):
if not _has_jumanji:
return []
return list(_get_envs())
return sorted(_get_envs())

@property
def lib(self):
    """The ``jumanji`` module, imported lazily on access.

    Raises:
        ImportError: if the installed jumanji version is lower than 1.0.0.
    """
    import jumanji

    installed = version.parse(jumanji.__version__)
    minimum = version.parse("1.0.0")
    if installed < minimum:
        raise ImportError("jumanji version must be >= 1.0.0")
    return jumanji

def __init__(self, env: "jumanji.env.Environment" = None, **kwargs): # noqa: F821
def __init__(
self,
env: "jumanji.env.Environment" = None, # noqa: F821
categorical_action_encoding=True,
**kwargs,
):
if not _has_jumanji:
raise ImportError(
"jumanji is not installed or importing it failed. Consider checking your installation."
)
self.categorical_action_encoding = categorical_action_encoding
if env is not None:
kwargs["env"] = env
super().__init__(**kwargs)
Expand All @@ -334,10 +367,38 @@ def _build_env(
self.from_pixels = from_pixels
self.pixels_only = pixels_only

if from_pixels:
raise NotImplementedError("TODO")
return env

def make_render(self):
    """Returns a transformed environment that can be rendered.

    A :class:`~torchrl.record.PixelRenderTransform` writing to the
    ``"pixels"`` key is appended to this environment. When the environment
    has a non-empty batch size, the transform is asked to produce its
    output as a numpy array stored as non-tensor data.

    Examples:
        >>> from torchrl.envs import JumanjiEnv
        >>> from torchrl.record import CSVLogger, VideoRecorder
        >>>
        >>> envname = JumanjiEnv.available_envs[-1]
        >>> logger = CSVLogger("jumanji", video_format="mp4", video_fps=2)
        >>> env = JumanjiEnv(envname, from_pixels=True)
        >>>
        >>> env = env.append_transform(
        ...     VideoRecorder(logger=logger, in_keys=["pixels"], tag=envname)
        ... )
        >>> env.set_seed(0)
        >>> r = env.rollout(100)
        >>> env.transform.dump()

    """
    from torchrl.record import PixelRenderTransform

    # An empty batch size is falsy, so both flags are only set for batched envs.
    batched = bool(self.batch_size)
    render_transform = PixelRenderTransform(
        out_keys=["pixels"],
        pass_tensordict=True,
        as_non_tensor=batched,
        as_numpy=batched,
    )
    return self.append_transform(render_transform)

def _make_state_example(self, env):
import jax
from jax import numpy as jnp
Expand All @@ -359,7 +420,9 @@ def _make_state_spec(self, env) -> TensorSpec:

def _make_action_spec(self, env) -> TensorSpec:
action_spec = _jumanji_to_torchrl_spec_transform(
env.action_spec, device=self.device
env.action_spec,
device=self.device,
categorical_action_encoding=self.categorical_action_encoding,
)
action_spec = action_spec.expand(*self.batch_size, *action_spec.shape)
return action_spec
Expand Down Expand Up @@ -445,6 +508,84 @@ def read_obs(self, obs):
obs_dict = _object_to_tensordict(obs, self.device, self.batch_size)
return super().read_obs(obs_dict)

def render(
    self,
    tensordict,
    matplotlib_backend: str | None = None,
    as_numpy: bool = False,
    **kwargs,
):
    """Renders the environment output given an input tensordict.

    This method is intended to be called by the :class:`~torchrl.record.PixelRenderTransform`
    created whenever `from_pixels=True` is selected.
    To create an appropriate rendering transform, use a similar call as below:

        >>> from torchrl.record import PixelRenderTransform
        >>> matplotlib_backend = None  # Change this value if a specific matplotlib backend has to be used.
        >>> env = env.append_transform(
        ...     PixelRenderTransform(out_keys=["pixels"], pass_tensordict=True, matplotlib_backend=matplotlib_backend)
        ... )

    This pipeline will write a `"pixels"` entry in your output tensordict.

    Only the first environment of a batch is rendered: the tensordict (and
    the matching state example) is indexed at 0 until no batch dimension
    is left.

    Args:
        tensordict (TensorDictBase): a tensordict containing a state to represent
        matplotlib_backend (str, optional): the matplotlib backend
        as_numpy (bool, optional): if ``False``, the np.ndarray will be converted to a torch.Tensor.
            Defaults to ``False``.
        **kwargs: forwarded to the wrapped environment's ``render`` method.

    Returns:
        The rendered figure, decoded from an in-memory PNG and truncated to
        at most its first three entries along the leading dimension
        (dropping e.g. an alpha channel). A ``torch.Tensor`` unless
        ``as_numpy=True``, in which case a numpy array is returned.

    Raises:
        ImportError: if torchvision, matplotlib or PIL cannot be imported.
    """
    import io

    import jax
    import jax.numpy as jnp
    import jumanji

    try:
        import matplotlib
        import matplotlib.pyplot as plt

        # ``import PIL`` alone does not load the ``PIL.Image`` submodule;
        # import it explicitly instead of relying on another package
        # having imported it first.
        import PIL.Image
        import torchvision.transforms.v2.functional
    except ImportError as err:
        raise ImportError(
            "Rendering with Jumanji requires torchvision, matplotlib and PIL to be installed."
        ) from err

    if matplotlib_backend is not None:
        matplotlib.use(matplotlib_backend)

    # Get only one env
    _state_example = self._state_example
    while tensordict.ndim:
        tensordict = tensordict[0]
        _state_example = jax.tree_util.tree_map(
            lambda x: jnp.take(x, 0, axis=0), _state_example
        )

    # Remember the global state that is modified below so that it can be
    # restored even if rendering fails.
    is_notebook = jumanji.environments.is_notebook
    was_interactive = plt.isinteractive()
    try:
        # Patch jumanji is_notebook
        jumanji.environments.is_notebook = lambda: False
        plt.ion()

        state = _tensordict_to_object(tensordict.get("state"), _state_example)
        self._env.render(state, **kwargs)

        with io.BytesIO() as buf:
            plt.savefig(buf, format="png")
            buf.seek(0)
            # Load the image into a PIL object and decode it while the
            # buffer is still open.
            img = PIL.Image.open(buf)
            img_array = torchvision.transforms.v2.functional.pil_to_tensor(img)
        plt.close()

        img_array = img_array[:3]
        return img_array.numpy() if as_numpy else img_array
    finally:
        try:
            if not was_interactive:
                plt.ioff()
        finally:
            jumanji.environments.is_notebook = is_notebook

def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
import jax

Expand Down
Loading

0 comments on commit 7b9305d

Please sign in to comment.