diff --git a/.github/unittest/linux_libs/scripts_jumanji/install.sh b/.github/unittest/linux_libs/scripts_jumanji/install.sh index 95a4a5a0e29..04875d6fa3d 100755 --- a/.github/unittest/linux_libs/scripts_jumanji/install.sh +++ b/.github/unittest/linux_libs/scripts_jumanji/install.sh @@ -28,15 +28,15 @@ git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with cu121" if [[ "$TORCH_VERSION" == "nightly" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U else - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 -U fi elif [[ "$TORCH_VERSION" == "stable" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install torch --index-url https://download.pytorch.org/whl/cpu + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu else - pip3 install torch --index-url https://download.pytorch.org/whl/cu121 + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121 fi else printf "Failed to install pytorch" diff --git a/.github/unittest/linux_libs/scripts_jumanji/run_test.sh b/.github/unittest/linux_libs/scripts_jumanji/run_test.sh index 542daa6eb99..e62c1935430 100755 --- a/.github/unittest/linux_libs/scripts_jumanji/run_test.sh +++ b/.github/unittest/linux_libs/scripts_jumanji/run_test.sh @@ -29,6 +29,6 @@ export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON # this workflow only tests the libs python -c "import jumanji" -python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestJumanji --error-for-skips +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v 
def _jumanji_envs():
    """Pick the Jumanji environment names used to parametrize ``TestJumanji``.

    Returns an empty tuple when Jumanji is not installed, so that evaluating
    the parametrization at import time never touches the library. Otherwise
    returns the ``[-10:-5]`` slice of ``JumanjiEnv.available_envs``.
    """
    return JumanjiEnv.available_envs[-10:-5] if _has_jumanji else ()
def zero(self, batch_size=None):
    """Return an empty ``NonTensorData`` placeholder for this spec.

    Args:
        batch_size: accepted so that callers passing a shape positionally keep
            working; the value is not read. It now defaults to ``None`` so the
            method can also be called without arguments.

    Returns:
        NonTensorData: built with ``data=None``, ``batch_size=self.shape`` and
        ``device=self.device``.
    """
    # The result is always sized from the spec itself, never from the argument.
    return NonTensorData(data=None, batch_size=self.shape, device=self.device)


def one(self, batch_size=None):
    """Return an empty ``NonTensorData`` placeholder for this spec.

    Non-tensor data has no meaningful "one" value, so this builds the same
    placeholder as :meth:`zero`.

    Args:
        batch_size: accepted so that callers passing a shape positionally keep
            working; the value is not read. Defaults to ``None``.

    Returns:
        NonTensorData: built with ``data=None``, ``batch_size=self.shape`` and
        ``device=self.device``.
    """
    return NonTensorData(data=None, batch_size=self.shape, device=self.device)
class _JumanjiMakeRender(_EnvPostInit):
    """Metaclass that swaps in a rendering environment when pixels are requested.

    The instance is first built through the parent metaclass. If its
    ``from_pixels`` attribute is truthy, the value returned to the caller is
    whatever ``instance.make_render()`` returns; otherwise the instance itself
    is returned unchanged.
    """

    def __call__(self, *args, **kwargs):
        env = super().__call__(*args, **kwargs)
        return env.make_render() if env.from_pixels else env
def render(
    self,
    tensordict,
    matplotlib_backend: str | None = None,
    as_numpy: bool = False,
    **kwargs,
):
    """Renders the environment output given an input tensordict.

    This method is intended to be called by the :class:`~torchrl.record.PixelRenderTransform`
    created whenever `from_pixels=True` is selected.
    To create an appropriate rendering transform, use a similar call as below:

        >>> from torchrl.record import PixelRenderTransform
        >>> matplotlib_backend = None  # Change this value if a specific matplotlib backend has to be used.
        >>> env = env.append_transform(
        ...     PixelRenderTransform(out_keys=["pixels"], pass_tensordict=True, matplotlib_backend=matplotlib_backend)
        ... )

    This pipeline will write a `"pixels"` entry in your output tensordict.

    Args:
        tensordict (TensorDictBase): a tensordict containing a ``"state"`` entry to
            represent. If it has batch dimensions, only the first element along
            each of them is rendered.
        matplotlib_backend (str, optional): the matplotlib backend. When given it is
            applied with ``matplotlib.use`` before rendering.
        as_numpy (bool, optional): if ``False``, the np.ndarray will be converted to a torch.Tensor.
            Defaults to ``False``.
        **kwargs: forwarded to the wrapped environment's ``render`` method.

    Returns:
        The first three channels of the rendered PNG, as produced by
        ``torchvision``'s ``pil_to_tensor``; a ``torch.Tensor`` by default, or a
        ``numpy.ndarray`` if ``as_numpy=True``.

    Raises:
        ImportError: if torchvision, matplotlib or PIL cannot be imported.

    """
    import io

    import jax
    import jax.numpy as jnp
    import jumanji

    try:
        import matplotlib
        import matplotlib.pyplot as plt

        # ``import PIL`` alone does not guarantee that the ``PIL.Image``
        # submodule is loaded, so import it explicitly.
        import PIL.Image
        import torchvision.transforms.v2.functional
    except ImportError as err:
        raise ImportError(
            "Rendering with Jumanji requires torchvision, matplotlib and PIL to be installed."
        ) from err

    if matplotlib_backend is not None:
        matplotlib.use(matplotlib_backend)

    # Get only one env: strip every batch dimension by taking index 0, both on
    # the tensordict and on the example state used to rebuild the state.
    _state_example = self._state_example
    while tensordict.ndim:
        tensordict = tensordict[0]
        _state_example = jax.tree_util.tree_map(
            lambda x: jnp.take(x, 0, axis=0), _state_example
        )

    # Remember the global state we are about to modify so it can be restored
    # even if rendering fails.
    was_interactive = plt.isinteractive()
    is_notebook = jumanji.environments.is_notebook
    # Patch jumanji is_notebook
    jumanji.environments.is_notebook = lambda: False
    try:
        plt.ion()
        state = _tensordict_to_object(tensordict.get("state"), _state_example)
        self._env.render(state, **kwargs)
        with io.BytesIO() as buf:
            plt.savefig(buf, format="png")
            buf.seek(0)
            # Load the image into a PIL object and decode it while the
            # buffer is still open.
            img = PIL.Image.open(buf)
            img_array = torchvision.transforms.v2.functional.pil_to_tensor(img)
    finally:
        # Undo the patch first: a plain assignment cannot fail.
        jumanji.environments.is_notebook = is_notebook
        if not was_interactive:
            plt.ioff()
        plt.close()

    # Keep only the first three channels (drops alpha when present).
    if not as_numpy:
        return img_array[:3]
    return img_array[:3].numpy()
Examples: @@ -422,6 +428,7 @@ def __init__( ] = None, as_non_tensor: bool = None, render_method: str = "render", + pass_tensordict: bool = False, **kwargs, ) -> None: if out_keys is None: @@ -439,6 +446,7 @@ def __init__( self.kwargs = kwargs self.render_method = render_method self._enabled = True + self.pass_tensordict = pass_tensordict super().__init__(in_keys=[], out_keys=out_keys) def _reset( @@ -450,7 +458,12 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: if not self._enabled: return tensordict - array = getattr(self.parent, self.render_method)(**self.kwargs) + method = getattr(self.parent, self.render_method) + if not self.pass_tensordict: + array = method(**self.kwargs) + else: + array = method(tensordict, **self.kwargs) + if self.preproc: array = self.preproc(array) if self.as_non_tensor is None: @@ -489,7 +502,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec switch = True self.switch() parent = self.parent - td_in = TensorDict({}, batch_size=parent.batch_size, device=parent.device) + td_in = parent.reset() self._call(td_in) obs = td_in.get(self.out_keys[0]) if isinstance(obs, NonTensorData):