[CI] Fix D4RL tests in CI (pytorch#976)
vmoens authored Mar 21, 2023
1 parent 8cd2a3a commit 47ddd32
Showing 5 changed files with 86 additions and 51 deletions.
8 changes: 4 additions & 4 deletions .circleci/config.yml
@@ -1279,10 +1279,10 @@ workflows:
cu_version: cu117
name: unittest_linux_habitat_gpu_py3.8
python_version: '3.8'
# - unittest_linux_d4rl_gpu:
# cu_version: cu117
# name: unittest_linux_d4rl_gpu_py3.8
# python_version: '3.8'
- unittest_linux_d4rl_gpu:
cu_version: cu117
name: unittest_linux_d4rl_gpu_py3.8
python_version: '3.8'
- unittest_linux_jumanji_gpu:
cu_version: cu117
name: unittest_linux_jumanji_gpu_py3.8
9 changes: 9 additions & 0 deletions knowledge_base/MUJOCO_INSTALLATION.md
@@ -139,6 +139,15 @@ issues when running `import mujoco_py` and some troubleshooting for each of them
```
_Solution_: This should disappear once `mesalib` is installed: `conda install -y -c conda-forge mesalib`
3.
```
ImportError: /lib/x86_64-linux-gnu/libstdc++.so.6: version `GLIBCXX_3.4.29' not found (required by /path/to/conda/envs/compile/bin/../lib/libOSMesa.so.8)
```
_Solution_: Install `libgcc`, e.g. `conda install libgcc -y`, then make sure it is loaded at execution time (a verification sketch follows this file's diff):
```
export LD_PRELOAD=$LD_PRELOAD:/path/to/conda/envs/compile/lib/libstdc++.so.6
```
4.
```
FileNotFoundError: [Errno 2] No such file or directory: 'patchelf'
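Not part of the commit itself: for the `GLIBCXX_3.4.29` error in item 3 of the troubleshooting list above, a minimal Python sketch to list the `GLIBCXX` version tags a given `libstdc++.so.6` actually exports. The library path is an assumption; point it at the file named in the error message.

```python
# Minimal sketch: list the GLIBCXX version tags embedded in a libstdc++ shared object,
# to check whether GLIBCXX_3.4.29 is provided before/after installing libgcc.
# The default path below is an assumption (typical Debian/Ubuntu location).
import re
from pathlib import Path

def glibcxx_versions(path="/lib/x86_64-linux-gnu/libstdc++.so.6"):
    data = Path(path).read_bytes()
    tags = set(re.findall(rb"GLIBCXX_[0-9]+(?:\.[0-9]+)*", data))
    return sorted(tag.decode() for tag in tags)

if __name__ == "__main__":
    print(glibcxx_versions())  # e.g. [..., 'GLIBCXX_3.4.28', 'GLIBCXX_3.4.29', ...]
```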
2 changes: 1 addition & 1 deletion setup.py
@@ -204,7 +204,7 @@ def _main(argv):
"numpy",
"packaging",
"cloudpickle",
"tensordict>=0.0.3",
"tensordict>=0.1.0",
],
extras_require={
"atari": [
69 changes: 37 additions & 32 deletions test/test_libs.py
@@ -1108,6 +1108,39 @@ def make_vmas():

@pytest.mark.skipif(not _has_d4rl, reason=f"D4RL not found: {D4RL_ERR}")
class TestD4RL:
@pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"])
def test_terminate_on_end(self, task):
t0 = time.time()
data_true = D4RLExperienceReplay(
task,
split_trajs=True,
from_env=False,
terminate_on_end=True,
batch_size=2,
use_timeout_as_done=False,
)
_ = D4RLExperienceReplay(
task,
split_trajs=True,
from_env=False,
terminate_on_end=False,
batch_size=2,
use_timeout_as_done=False,
)
data_from_env = D4RLExperienceReplay(
task,
split_trajs=True,
from_env=True,
batch_size=2,
use_timeout_as_done=False,
)
keys = set(data_from_env._storage._storage.keys(True, True))
keys = keys.intersection(data_true._storage._storage.keys(True, True))
assert_allclose_td(
data_true._storage._storage.select(*keys),
data_from_env._storage._storage.select(*keys),
)

@pytest.mark.parametrize(
"task",
[
@@ -1116,8 +1149,8 @@ class TestD4RL:
# "maze2d-open-v0",
# "maze2d-open-dense-v0",
# "relocate-human-v1",
# "walker2d-medium-replay-v2",
"ant-medium-v2",
"walker2d-medium-replay-v2",
# "ant-medium-v2",
# # "flow-merge-random-v0",
# "kitchen-partial-v0",
# # "carla-town-v0",
@@ -1128,7 +1161,7 @@ def test_d4rl_dummy(self, task):
_ = D4RLExperienceReplay(task, split_trajs=True, from_env=True, batch_size=2)
print(f"completed test after {time.time()-t0}s")

@pytest.mark.parametrize("task", ["ant-medium-v2"])
@pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"])
@pytest.mark.parametrize("split_trajs", [True, False])
@pytest.mark.parametrize("from_env", [True, False])
def test_dataset_build(self, task, split_trajs, from_env):
@@ -1146,35 +1179,7 @@ def test_dataset_build(self, task, split_trajs, from_env):
assert sim.shape[-1] == offline.shape[-1], key
print(f"completed test after {time.time()-t0}s")

@pytest.mark.parametrize("task", ["ant-medium-v2"])
def test_terminate_on_end(self, task):
t0 = time.time()
data_true = D4RLExperienceReplay(
task,
split_trajs=True,
from_env=False,
terminate_on_end=True,
batch_size=2,
)
_ = D4RLExperienceReplay(
task,
split_trajs=True,
from_env=False,
terminate_on_end=False,
batch_size=2,
)
data_from_env = D4RLExperienceReplay(
task, split_trajs=True, from_env=True, batch_size=2
)
keys = set(data_from_env._storage._storage.keys(True, True))
keys = keys.intersection(data_true._storage._storage.keys(True, True))
assert_allclose_td(
data_true._storage._storage.select(*keys),
data_from_env._storage._storage.select(*keys),
)
print(f"completed test after {time.time()-t0}s")

@pytest.mark.parametrize("task", ["ant-medium-v2"])
@pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"])
@pytest.mark.parametrize("split_trajs", [True, False])
def test_d4rl_iteration(self, task, split_trajs):
t0 = time.time()
49 changes: 35 additions & 14 deletions torchrl/data/datasets/d4rl.py
@@ -78,6 +78,15 @@ class D4RLExperienceReplay(TensorDictReplayBuffer):
containing meta-data and info entries that the former does
not possess.
.. note::
The keys in ``from_env=True`` and ``from_env=False`` *may* unexpectedly
differ. In particular, the ``"timeout"`` key (used to determine the
end of an episode) may be absent when ``from_env=False`` but present
otherwise, leading to a different slicing when ``split_trajs`` is enabled.
use_timeout_as_done (bool, optional): if ``True``, ``done = terminal | timeout``.
Otherwise, only the ``terminal`` key is used. Defaults to ``True``.
**env_kwargs (key-value pairs): additional kwargs for
:func:`d4rl.qlearning_dataset`. Supports ``terminate_on_end``
(``False`` by default) or other kwargs defined by the D4RL library.
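As a usage sketch (not part of this commit's diff): the task name and `batch_size` below mirror the values used in the updated tests, and `terminate_on_end` is forwarded to `d4rl.qlearning_dataset` through `**env_kwargs`. Running it assumes a working `d4rl` install and the corresponding dataset download.

```python
# Usage sketch for the new arguments (values mirror test/test_libs.py in this commit).
from torchrl.data.datasets.d4rl import D4RLExperienceReplay

data = D4RLExperienceReplay(
    "walker2d-medium-replay-v2",
    batch_size=2,
    split_trajs=True,
    from_env=False,             # build from d4rl.qlearning_dataset rather than the env's dataset
    terminate_on_end=True,      # forwarded to d4rl.qlearning_dataset via **env_kwargs
    use_timeout_as_done=False,  # derive "done" from "terminal" only, ignoring "timeout"
)
sample = data.sample()  # a TensorDict batch of 2 transitions
```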
@@ -105,16 +114,20 @@ def __init__(
transform: Optional["Transform"] = None, # noqa-F821
split_trajs: bool = False,
from_env: bool = True,
use_timeout_as_done: bool = True,
**env_kwargs,
):

if not _has_d4rl:
raise ImportError("Could not import d4rl") from D4RL_ERR
self.from_env = from_env
self.use_timeout_as_done = use_timeout_as_done
if from_env:
dataset = self._get_dataset_from_env(name, env_kwargs)
else:
dataset = self._get_dataset_direct(name, env_kwargs)
# Fill unknown next states with 0
dataset["next", "observation"][dataset["next", "done"].squeeze()] = 0

if split_trajs:
dataset = split_trajectories(dataset)
@@ -160,11 +173,14 @@ def _get_dataset_direct(self, name, env_kwargs):
dataset.rename_key("terminals", "terminal")
if "timeouts" in dataset.keys():
dataset.rename_key("timeouts", "timeout")
dataset.set(
"done",
dataset.get("terminal")
| dataset.get("timeout", torch.zeros((), dtype=torch.bool)),
)
if self.use_timeout_as_done:
dataset.set(
"done",
dataset.get("terminal")
| dataset.get("timeout", torch.zeros((), dtype=torch.bool)),
)
else:
dataset.set("done", dataset.get("terminal"))
dataset.rename_key("rewards", "reward")
dataset.rename_key("actions", "action")

@@ -183,9 +199,10 @@ def _get_dataset_direct(self, name, env_kwargs):
dataset["next"].update(
dataset.select("reward", "done", "terminal", "timeout", strict=False)
)
dataset = (
dataset.clone()
) # make sure that all tensors have a different data_ptr
self._shift_reward_done(dataset)
# Fill unknown next states with 0
dataset["next", "observation"][dataset["next", "done"].squeeze()] = 0
self.specs = env.specs.clone()
return dataset

@@ -223,11 +240,14 @@ def _get_dataset_from_env(self, name, env_kwargs):
dataset.rename_key("terminals", "terminal")
if "timeouts" in dataset.keys():
dataset.rename_key("timeouts", "timeout")
dataset.set(
"done",
dataset.get("terminal")
| dataset.get("timeout", torch.zeros((), dtype=torch.bool)),
)
if self.use_timeout_as_done:
dataset.set(
"done",
dataset.get("terminal")
| dataset.get("timeout", torch.zeros((), dtype=torch.bool)),
)
else:
dataset.set("done", dataset.get("terminal"))
dataset.rename_key("rewards", "reward")
dataset.rename_key("actions", "action")
try:
@@ -253,9 +273,10 @@ def _get_dataset_from_env(self, name, env_kwargs):
dataset["next"].update(
dataset.select("reward", "done", "terminal", "timeout", strict=False)
)
dataset = (
dataset.clone()
) # make sure that all tensors have a different data_ptr
self._shift_reward_done(dataset)
# Fill unknown next states with 0
dataset["next", "observation"][dataset["next", "done"].squeeze()] = 0
self.specs = env.specs.clone()
return dataset
