Skip to content

Commit

Permalink
[BugFix] Fix Cython for D4RL (pytorch#1429)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Aug 8, 2023
1 parent 2fe836c commit e39e701
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .circleci/unittest/linux_libs/scripts_d4rl/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ ln -s /usr/bin/swig3.0 /usr/bin/swig
# we install d4rl here bc env variables have been updated
git clone https://github.com/Farama-Foundation/d4rl.git
cd d4rl
pip3 install -U 'mujoco-py<2.1,>=2.0'
#pip3 install -U 'mujoco-py<2.1,>=2.0'
pip3 install -U "gym[classic_control,atari,accept-rom-license]"==0.23
pip3 install -U six
pip install -e .
Expand Down
12 changes: 11 additions & 1 deletion .circleci/unittest/linux_libs/scripts_d4rl/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# Do not install PyTorch and torchvision here, otherwise they also get cached.

set -e
set -v

this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
# Avoid error: "fatal: unsafe repository"
Expand Down Expand Up @@ -39,6 +40,12 @@ if [ ! -d "${env_dir}" ]; then
fi
conda activate "${env_dir}"

#pip3 uninstall cython -y
#pip uninstall cython -y
#conda uninstall cython -y
pip3 install "cython<3"
conda install -c anaconda cython="<3.0.0" -y


# 3. Install mujoco
printf "* Installing mujoco and related\n"
Expand All @@ -53,14 +60,17 @@ wget https://www.roboti.us/file/mjkey.txt
cp mjkey.txt ./mujoco200_linux/bin/
# install mujoco-py locally
git clone https://github.com/vmoens/mujoco-py.git
cd mujoco-py
git checkout v2.0.2.1
pip install -e .
cd $this_dir

# 4. Install Conda dependencies
printf "* Installing dependencies (except PyTorch)\n"
echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml"
cat "${this_dir}/environment.yml"

pip install pip --upgrade
pip3 install pip --upgrade

# 5. env variables
if [[ $OSTYPE == 'darwin'* ]]; then
Expand Down
5 changes: 5 additions & 0 deletions .circleci/unittest/linux_libs/scripts_habitat/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ if [ ! -d "${env_dir}" ]; then
conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION"
fi
conda activate "${env_dir}"
#pip3 uninstall cython -y
#pip uninstall cython -y
#conda uninstall cython -y
pip3 install "cython<3"
conda install -c anaconda cython="<3.0.0" -y


# 3. Install Conda dependencies
Expand Down
13 changes: 0 additions & 13 deletions test/test_postprocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,19 +320,6 @@ def test_splits(self, num_workers, traj_len):
== split_trajs.get(("collector", "traj_ids")).max() + 1
)

@pytest.mark.parametrize("num_workers", range(3, 34, 3))
@pytest.mark.parametrize("traj_len", [10, 17, 50, 97])
def test_splits_notraj(self, num_workers, traj_len):

trajs = TestSplits.create_fake_trajs(num_workers, traj_len)
trajs_pop = trajs.clone()
del trajs_pop[("collector", "traj_ids")]
split_trajs = split_trajectories(trajs, prefix="collector")
split_trajs_pop = split_trajectories(trajs_pop, prefix="collector")
del split_trajs[("collector", "traj_ids")]
del split_trajs_pop[("collector", "traj_ids")]
assert (split_trajs == split_trajs_pop).all()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down

0 comments on commit e39e701

Please sign in to comment.