[CI] Fix D4RL tests in CI #976

Merged: 11 commits, Mar 21, 2023
Changes from 1 commit
amend
vmoens committed Mar 17, 2023
commit 02d4a0313ec2be06c6f64a7de997599199b63e04
9 changes: 9 additions & 0 deletions knowledge_base/MUJOCO_INSTALLATION.md
@@ -139,6 +139,15 @@ issues when running `import mujoco_py` and some troubleshooting for each of them
 ```
 
 _Solution_: This should disappear once `mesalib` is installed: `conda install -y -c conda-forge mesalib`
+3.
+```
+ImportError: /lib/x86_64-linux-gnu/libstdc++.so.6: version `GLIBCXX_3.4.29' not found (required by /path/to/conda/envs/compile/bin/../lib/libOSMesa.so.8)
+```
+
+_Solution_: Install libgcc, e.g.: `conda install libgcc -y`. Then make sure that it is being loaded during execution:
+```
+export LD_PRELOAD=$LD_PRELOAD:/path/to/conda/envs/compile/lib/libstdc++.so.6
+```
 4.
 ```
 FileNotFoundError: [Errno 2] No such file or directory: 'patchelf'
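As a quick sanity check for the `GLIBCXX_3.4.29` fix in item 3 above, the following sketch (not part of this PR, assuming the `strings` utility from binutils is on the PATH, with the conda path left as a placeholder) verifies that a given `libstdc++.so.6` actually exports the required version tag before it is added to `LD_PRELOAD`:

```python
import subprocess

def has_glibcxx(lib_path: str, version: str = "GLIBCXX_3.4.29") -> bool:
    """Return True if the shared object embeds the requested GLIBCXX version tag."""
    out = subprocess.run(
        ["strings", lib_path], capture_output=True, text=True, check=True
    )
    return version in out.stdout.splitlines()

# Placeholder path -- substitute your own conda environment.
print(has_glibcxx("/path/to/conda/envs/compile/lib/libstdc++.so.6"))
```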
4 changes: 3 additions & 1 deletion test/test_libs.py
@@ -1117,16 +1117,18 @@ def test_terminate_on_end(self, task):
             from_env=False,
             terminate_on_end=True,
             batch_size=2,
+            use_timeout_as_done=False,
         )
         _ = D4RLExperienceReplay(
             task,
             split_trajs=True,
             from_env=False,
             terminate_on_end=False,
             batch_size=2,
+            use_timeout_as_done=False,
         )
         data_from_env = D4RLExperienceReplay(
-            task, split_trajs=True, from_env=True, batch_size=2
+            task, split_trajs=True, from_env=True, batch_size=2, use_timeout_as_done=False,
         )
         keys = set(data_from_env._storage._storage.keys(True, True))
         keys = keys.intersection(data_true._storage._storage.keys(True, True))
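The reason these tests now pass `use_timeout_as_done=False` is that the `from_env=False` dataset may lack the `"timeout"` key, so the default `done = terminal | timeout` can differ between the two loading paths; with the flag off, both paths derive `done` from `"terminal"` alone. A small illustrative sketch with made-up tensors (not taken from the test suite):

```python
import torch

terminal = torch.tensor([False, False, True, False])
timeout = torch.tensor([False, True, False, False])  # may simply be absent when from_env=False

# Default (use_timeout_as_done=True): a timeout also marks the end of a trajectory.
done_default = terminal | timeout
# With use_timeout_as_done=False: only true terminations count.
done_terminal_only = terminal

print(done_default)        # tensor([False,  True,  True, False])
print(done_terminal_only)  # tensor([False, False,  True, False])
```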
37 changes: 27 additions & 10 deletions torchrl/data/datasets/d4rl.py
@@ -78,6 +78,15 @@ class D4RLExperienceReplay(TensorDictReplayBuffer):
             containing meta-data and info entries that the former does
             not possess.
 
+            .. note::
+
+              The keys in ``from_env=True`` and ``from_env=False`` *may* unexpectedly
+              differ. In particular, the ``"timeout"`` key (used to determine the
+              end of an episode) may be absent when ``from_env=False`` but present
+              otherwise, leading to a different slicing when ``split_trajs`` is enabled.
+
+        use_timeout_as_done (bool, optional): if ``True``, ``done = terminal | timeout``.
+            Otherwise, only the ``terminal`` key is used. Defaults to ``True``.
         **env_kwargs (key-value pairs): additional kwargs for
             :func:`d4rl.qlearning_dataset`. Supports ``terminate_on_end``
             (``False`` by default) or other kwargs if defined by D4RL library.
@@ -105,12 +114,14 @@ def __init__(
         transform: Optional["Transform"] = None,  # noqa-F821
         split_trajs: bool = False,
         from_env: bool = True,
+        use_timeout_as_done: bool = True,
         **env_kwargs,
     ):
 
         if not _has_d4rl:
             raise ImportError("Could not import d4rl") from D4RL_ERR
         self.from_env = from_env
+        self.use_timeout_as_done = use_timeout_as_done
         if from_env:
             dataset = self._get_dataset_from_env(name, env_kwargs)
         else:
@@ -160,11 +171,14 @@ def _get_dataset_direct(self, name, env_kwargs):
         dataset.rename_key("terminals", "terminal")
         if "timeouts" in dataset.keys():
             dataset.rename_key("timeouts", "timeout")
-        dataset.set(
-            "done",
-            dataset.get("terminal")
-            | dataset.get("timeout", torch.zeros((), dtype=torch.bool)),
-        )
+        if self.use_timeout_as_done:
+            dataset.set(
+                "done",
+                dataset.get("terminal")
+                | dataset.get("timeout", torch.zeros((), dtype=torch.bool)),
+            )
+        else:
+            dataset.set("done", dataset.get("terminal"))
         dataset.rename_key("rewards", "reward")
         dataset.rename_key("actions", "action")

@@ -223,11 +237,14 @@ def _get_dataset_from_env(self, name, env_kwargs):
         dataset.rename_key("terminals", "terminal")
         if "timeouts" in dataset.keys():
             dataset.rename_key("timeouts", "timeout")
-        dataset.set(
-            "done",
-            dataset.get("terminal")
-            | dataset.get("timeout", torch.zeros((), dtype=torch.bool)),
-        )
+        if self.use_timeout_as_done:
+            dataset.set(
+                "done",
+                dataset.get("terminal")
+                | dataset.get("timeout", torch.zeros((), dtype=torch.bool)),
+            )
+        else:
+            dataset.set("done", dataset.get("terminal"))
         dataset.rename_key("rewards", "reward")
         dataset.rename_key("actions", "action")
         try:
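A rough usage sketch of the new flag, assuming `d4rl` and its datasets are installed; the task name and batch size below are arbitrary illustrations rather than values from this PR:

```python
from torchrl.data.datasets.d4rl import D4RLExperienceReplay

# Load transitions through d4rl.qlearning_dataset (from_env=False) and keep
# "done" tied to true terminations only, ignoring timeouts.
data = D4RLExperienceReplay(
    "hopper-medium-v1",   # example task name
    split_trajs=False,
    from_env=False,
    batch_size=32,
    use_timeout_as_done=False,
)
sample = data.sample()       # a TensorDict of 32 transitions
print(sample["done"].sum())  # number of terminal transitions in the batch
```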