Skip to content

Commit

Permalink
[BugFix] Fix habitat (pytorch#1941)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Feb 20, 2024
1 parent 93885b6 commit 13bef42
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
5 changes: 3 additions & 2 deletions .github/unittest/linux_libs/scripts_habitat/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ pip install pip --upgrade

conda env update --file "${this_dir}/environment.yml" --prune

conda install habitat-sim withbullet headless -c conda-forge -c aihabitat-nightly -y
conda run python -m pip install git+https://github.com/facebookresearch/habitat-lab.git#subdirectory=habitat-lab
#conda install habitat-sim withbullet headless -c conda-forge -c aihabitat -y
conda install habitat-sim withbullet headless -c conda-forge -c aihabitat -y
conda run python -m pip install git+https://github.com/facebookresearch/habitat-lab.git@stable#subdirectory=habitat-lab
#conda run python -m pip install git+https://github.com/facebookresearch/habitat-lab.git#subdirectory=habitat-baselines
conda run python -m pip install "gym[atari,accept-rom-license]" pygame
3 changes: 2 additions & 1 deletion test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,8 @@ def test_mb_env_batch_lock(self, device, seed=0):
mb_env.step(td)

with pytest.raises(
RuntimeError, match=re.escape("Expected a tensordict with shape==env.batch_size")
RuntimeError,
match=re.escape("Expected a tensordict with shape==env.batch_size"),
):
mb_env.step(td_expanded)

Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/libs/habitat.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(self, env_name, **kwargs):
device_num = torch.device(kwargs.pop("device", 0)).index
kwargs["override_options"] = [
f"habitat.simulator.habitat_sim_v0.gpu_device_id={device_num}",
"habitat.simulator.concur_render=False",
]
super().__init__(env_name=env_name, **kwargs)

Expand Down

0 comments on commit 13bef42

Please sign in to comment.