Skip to content

Commit

Permalink
[BugFix] Fix Isaac (pytorch#2072)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Apr 25, 2024
1 parent 160a946 commit 23e2121
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 43 deletions.
45 changes: 27 additions & 18 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
from contextlib import nullcontext
from pathlib import Path

from torchrl._utils import logger as torchrl_logger

from torchrl.data.datasets.gen_dgrl import GenDGRLExperienceReplay

from torchrl.envs.transforms import ActionMask, TransformedEnv
from torchrl.modules import MaskedCategorical
import importlib.util

_has_isaac = importlib.util.find_spec("isaacgym") is not None

Expand All @@ -21,11 +11,13 @@
import isaacgym # noqa
import isaacgymenvs # noqa
from torchrl.envs.libs.isaacgym import IsaacGymEnv

import argparse
import importlib
import os

import time
from contextlib import nullcontext
from pathlib import Path
from sys import platform
from typing import Optional, Union

Expand Down Expand Up @@ -57,7 +49,8 @@
TensorDictSequential,
)
from torch import nn
from torchrl._utils import implement_for

from torchrl._utils import implement_for, logger as torchrl_logger
from torchrl.collectors.collectors import SyncDataCollector
from torchrl.data import (
BinaryDiscreteTensorSpec,
Expand All @@ -74,6 +67,8 @@
)
from torchrl.data.datasets.atari_dqn import AtariDQNExperienceReplay
from torchrl.data.datasets.d4rl import D4RLExperienceReplay

from torchrl.data.datasets.gen_dgrl import GenDGRLExperienceReplay
from torchrl.data.datasets.minari_data import MinariExperienceReplay
from torchrl.data.datasets.openml import OpenMLExperienceReplay
from torchrl.data.datasets.openx import OpenXExperienceReplay
Expand Down Expand Up @@ -114,13 +109,21 @@
from torchrl.envs.libs.robohive import _has_robohive, RoboHiveEnv
from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env
from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper

from torchrl.envs.transforms import ActionMask, TransformedEnv
from torchrl.envs.utils import (
check_env_specs,
ExplorationType,
MarlGroupMapType,
RandomPolicy,
)
from torchrl.modules import ActorCriticOperator, MLP, SafeModule, ValueOperator
from torchrl.modules import (
ActorCriticOperator,
MaskedCategorical,
MLP,
SafeModule,
ValueOperator,
)

_has_d4rl = importlib.util.find_spec("d4rl") is not None

Expand Down Expand Up @@ -3084,22 +3087,28 @@ def test_data(self, dataset):
)
@pytest.mark.parametrize("num_envs", [10, 20])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("from_pixels", [True, False])
class TestIsaacGym:
@classmethod
def _run_on_proc(cls, q, task, num_envs, device):
def _run_on_proc(cls, q, task, num_envs, device, from_pixels):
try:
env = IsaacGymEnv(task=task, num_envs=num_envs, device=device)
env = IsaacGymEnv(
task=task, num_envs=num_envs, device=device, from_pixels=from_pixels
)
check_env_specs(env)
q.put(("succeeded!", None))
except Exception as err:
q.put(("failed!", err))
raise err

def test_env(self, task, num_envs, device):
def test_env(self, task, num_envs, device, from_pixels):
from torch import multiprocessing as mp

q = mp.Queue(1)
proc = mp.Process(target=self._run_on_proc, args=(q, task, num_envs, device))
self._run_on_proc(q, task, num_envs, device, from_pixels)
proc = mp.Process(
target=self._run_on_proc, args=(q, task, num_envs, device, from_pixels)
)
try:
proc.start()
msg, error = q.get()
Expand Down
16 changes: 10 additions & 6 deletions torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,9 @@ def _reward_space(self, env): # noqa: F811
return rs

def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821
# If batch_size is provided, we se it to tell what batch size must be used
# instead of self.batch_size
cur_batch_size = self.batch_size if batch_size is None else torch.Size([])
action_spec = _gym_to_torchrl_spec_transform(
env.action_space,
device=self.device,
Expand All @@ -956,14 +959,14 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821
if not isinstance(observation_spec, CompositeSpec):
if self.from_pixels:
observation_spec = CompositeSpec(
pixels=observation_spec, shape=self.batch_size
pixels=observation_spec, shape=cur_batch_size
)
else:
observation_spec = CompositeSpec(
observation=observation_spec, shape=self.batch_size
observation=observation_spec, shape=cur_batch_size
)
elif observation_spec.shape[: len(self.batch_size)] != self.batch_size:
observation_spec.shape = self.batch_size
elif observation_spec.shape[: len(cur_batch_size)] != cur_batch_size:
observation_spec.shape = cur_batch_size

reward_space = self._reward_space(env)
if reward_space is not None:
Expand All @@ -983,10 +986,11 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821
observation_spec = observation_spec.expand(
*batch_size, *observation_spec.shape
)

self.done_spec = self._make_done_spec()
self.action_spec = action_spec
if reward_spec.shape[: len(self.batch_size)] != self.batch_size:
self.reward_spec = reward_spec.expand(*self.batch_size, *reward_spec.shape)
if reward_spec.shape[: len(cur_batch_size)] != cur_batch_size:
self.reward_spec = reward_spec.expand(*cur_batch_size, *reward_spec.shape)
else:
self.reward_spec = reward_spec
self.observation_spec = observation_spec
Expand Down
50 changes: 33 additions & 17 deletions torchrl/envs/libs/isaacgym.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch

from tensordict import TensorDictBase
from torchrl.data import CompositeSpec
from torchrl.envs.libs.gym import GymWrapper
from torchrl.envs.utils import _classproperty, make_composite_from_td

Expand Down Expand Up @@ -49,19 +50,23 @@ def __init__(
warnings.warn(
"IsaacGym environment support is an experimental feature that may change in the future."
)
num_envs = env.num_envs
super().__init__(
env, torch.device(env.device), batch_size=torch.Size([num_envs]), **kwargs
env, torch.device(env.device), batch_size=torch.Size([]), **kwargs
)
if not hasattr(self, "task"):
# by convention in IsaacGymEnvs
self.task = env.__name__

def _make_specs(self, env: "gym.Env") -> None: # noqa: F821
super()._make_specs(env, batch_size=self.batch_size)
self.full_done_spec = {
key: spec.squeeze(-1) for key, spec in self.full_done_spec.items(True, True)
}
self.full_done_spec = CompositeSpec(
{
key: spec.squeeze(-1)
for key, spec in self.full_done_spec.items(True, True)
},
shape=self.batch_size,
)

self.observation_spec["obs"] = self.observation_spec["observation"]
del self.observation_spec["observation"]

Expand All @@ -78,7 +83,18 @@ def _make_specs(self, env: "gym.Env") -> None: # noqa: F821
obs_spec.unlock_()
obs_spec.update(specs)
obs_spec.lock_()
self.__dict__["full_observation_spec"] = obs_spec

def _output_transform(self, output):
obs, reward, done, info = output
if self.from_pixels:
obs["pixels"] = self._env.render(mode="rgb_array")
return obs, reward, done ^ done, done, done, info

def _reset_output_transform(self, reset_data):
reset_data.pop("reward", None)
if self.from_pixels:
reset_data["pixels"] = self._env.render(mode="rgb_array")
return reset_data, {}

@classmethod
def _make_envs(cls, *, task, num_envs, device, seed=None, headless=True, **kwargs):
Expand Down Expand Up @@ -125,15 +141,8 @@ def read_done(
done = done.bool()
return terminated, truncated, done, done.any()

def read_reward(self, total_reward, step_reward):
"""Reads a reward and the total reward so far (in the frame skip loop) and returns a sum of the two.
Args:
total_reward (torch.Tensor or TensorDict): total reward so far in the step
step_reward (reward in the format provided by the inner env): reward of this particular step
"""
return total_reward + step_reward
def read_reward(self, total_reward):
return total_reward

def read_obs(
self, observations: Union[Dict[str, Any], torch.Tensor, np.ndarray]
Expand Down Expand Up @@ -183,6 +192,13 @@ def __init__(self, task=None, *, env=None, num_envs, device, **kwargs):
raise RuntimeError("Cannot provide both `task` and `env` arguments.")
elif env is not None:
task = env
envs = self._make_envs(task=task, num_envs=num_envs, device=device, **kwargs)
from_pixels = kwargs.pop("from_pixels", False)
envs = self._make_envs(
task=task,
num_envs=num_envs,
device=device,
virtual_screen_capture=True,
**kwargs,
)
self.task = task
super().__init__(envs, **kwargs)
super().__init__(envs, from_pixels=from_pixels, **kwargs)
4 changes: 2 additions & 2 deletions torchrl/envs/libs/jumanji.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def _jumanji_to_torchrl_spec_transform(
dtype = numpy_to_torch_dtype_dict[spec.dtype]
return BoundedTensorSpec(
shape=shape,
low=np.asarray(spec.low),
high=np.asarray(spec.high),
low=np.asarray(spec.minimum),
high=np.asarray(spec.maximum),
dtype=dtype,
device=device,
)
Expand Down

0 comments on commit 23e2121

Please sign in to comment.