Skip to content

Commit

Permalink
Fault-tolerant parallel env execution (pytorch#171)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored May 30, 2022
1 parent bc84a55 commit 9f56d97
Show file tree
Hide file tree
Showing 19 changed files with 203 additions and 138 deletions.
131 changes: 65 additions & 66 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -232,22 +232,21 @@ jobs:
name: Generate cache key
# This refreshes the cache every Sunday (%U is the Sunday-based week number); the nightly build should then generate the new cache.
command: echo "$(date +"%Y-%U")" > .circleci-weekly
- restore_cache:

keys:
- env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
# - restore_cache:
# keys:
# - env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}

- run:
name: Setup
command: .circleci/unittest/linux/scripts/setup_env.sh

- save_cache:

key: env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}

paths:
- conda
- env
# - save_cache:
#
# key: env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
#
# paths:
# - conda
# - env
- run:
name: Install torchrl
command: .circleci/unittest/linux/scripts/install.sh
Expand Down Expand Up @@ -277,21 +276,21 @@ jobs:
name: Generate cache key
# This refreshes the cache every Sunday (%U is the Sunday-based week number); the nightly build should then generate the new cache.
command: echo "$(date +"%Y-%U")" > .circleci-weekly
- restore_cache:

keys:
- env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
# - restore_cache:
#
# keys:
# - env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}

- run:
name: Setup
command: docker run -e PYTHON_VERSION -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/setup_env.sh
- save_cache:

key: env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}

paths:
- conda
- env
# - save_cache:
#
# key: env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
#
# paths:
# - conda
# - env
- run:
# Here we create an envlist file that contains some env variables that we want the docker container to be aware of.
# Normally, the CIRCLECI variable is set and available on all CI workflows: https://circleci.com/docs/2.0/env-vars/#built-in-environment-variables.
Expand Down Expand Up @@ -329,21 +328,21 @@ jobs:
name: Generate cache key
# This refreshes the cache every Sunday (%U is the Sunday-based week number); the nightly build should then generate the new cache.
command: echo "$(date +"%Y-%U")" > .circleci-weekly
- restore_cache:

keys:
- env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_optdeps/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
# - restore_cache:
#
# keys:
# - env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_optdeps/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}

- run:
name: Setup
command: docker run -e PYTHON_VERSION -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_optdeps/scripts/setup_env.sh
- save_cache:

key: env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_optdeps/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}

paths:
- conda
- env
# - save_cache:
#
# key: env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_optdeps/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
#
# paths:
# - conda
# - env
- run:
# Here we create an envlist file that contains some env variables that we want the docker container to be aware of.
# Normally, the CIRCLECI variable is set and available on all CI workflows: https://circleci.com/docs/2.0/env-vars/#built-in-environment-variables.
Expand Down Expand Up @@ -382,22 +381,22 @@ jobs:
name: Generate cache key
# This refreshes the cache every Sunday (%U is the Sunday-based week number); the nightly build should then generate the new cache.
command: echo "$(date +"%Y-%U")" > .circleci-weekly
- restore_cache:

keys:
- env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_stable/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
# - restore_cache:
#
# keys:
# - env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_stable/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}

- run:
name: Setup
command: .circleci/unittest/linux_stable/scripts/setup_env.sh

- save_cache:

key: env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_stable/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}

paths:
- conda
- env
# - save_cache:
#
# key: env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_stable/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
#
# paths:
# - conda
# - env
- run:
name: Install torchrl
command: .circleci/unittest/linux_stable/scripts/install.sh
Expand Down Expand Up @@ -427,21 +426,21 @@ jobs:
name: Generate cache key
# This will refresh cache on Sundays, nightly build should generate new cache.
command: echo "$(date +"%Y-%U")" > .circleci-weekly
- restore_cache:

keys:
- env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_stable/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
# - restore_cache:
#
# keys:
# - env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_stable/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}

- run:
name: Setup
command: docker run -e PYTHON_VERSION -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_stable/scripts/setup_env.sh
- save_cache:

key: env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_stable/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}

paths:
- conda
- env
# - save_cache:
#
# key: env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_stable/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
#
# paths:
# - conda
# - env
- run:
# Here we create an envlist file that contains some env variables that we want the docker container to be aware of.
# Normally, the CIRCLECI variable is set and available on all CI workflows: https://circleci.com/docs/2.0/env-vars/#built-in-environment-variables.
Expand Down Expand Up @@ -479,21 +478,21 @@ jobs:
name: Generate cache key
# This will refresh cache on Sundays, nightly build should generate new cache.
command: echo "$(date +"%Y-%U")" > .circleci-weekly
- restore_cache:

keys:
- env-v3-macos-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
# - restore_cache:
#
# keys:
# - env-v3-macos-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}

- run:
name: Setup
command: .circleci/unittest/linux/scripts/setup_env.sh
- save_cache:

key: env-v3-macos-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}

paths:
- conda
- env
# - save_cache:
#
# key: env-v3-macos-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
#
# paths:
# - conda
# - env
- run:
name: Install torchrl
command: .circleci/unittest/linux/scripts/install.sh
Expand Down
2 changes: 2 additions & 0 deletions .circleci/unittest/linux/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ fi

export PYTORCH_TEST_WITH_SLOW='1'
python -m torch.utils.collect_env
# Avoid error: "fatal: unsafe repository" (git refuses to run in a checkout owned by a different user); '*' marks every directory as safe.
git config --global --add safe.directory '*'
root_dir="$(git rev-parse --show-toplevel)"
export MJLIB_PATH=$root_dir/.mujoco/mujoco-2.1.1/lib/libmujoco.so.2.1.1
export MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210
Expand Down
2 changes: 2 additions & 0 deletions .circleci/unittest/linux/scripts/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
set -e

this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
# Avoid error: "fatal: unsafe repository" (git refuses to run in a checkout owned by a different user); '*' marks every directory as safe.
git config --global --add safe.directory '*'
root_dir="$(git rev-parse --show-toplevel)"
conda_dir="${root_dir}/conda"
env_dir="${root_dir}/env"
Expand Down
2 changes: 2 additions & 0 deletions .circleci/unittest/linux_optdeps/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ conda activate ./env

export PYTORCH_TEST_WITH_SLOW='1'
python -m torch.utils.collect_env
# Avoid error: "fatal: unsafe repository" (git refuses to run in a checkout owned by a different user); '*' marks every directory as safe.
git config --global --add safe.directory '*'
root_dir="$(git rev-parse --show-toplevel)"
export MJLIB_PATH=$root_dir/.mujoco/mujoco-2.1.1/lib/libmujoco.so.2.1.1
export MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210
Expand Down
2 changes: 2 additions & 0 deletions .circleci/unittest/linux_optdeps/scripts/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
set -e

this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
# Avoid error: "fatal: unsafe repository" (git refuses to run in a checkout owned by a different user); '*' marks every directory as safe.
git config --global --add safe.directory '*'
root_dir="$(git rev-parse --show-toplevel)"
conda_dir="${root_dir}/conda"
env_dir="${root_dir}/env"
Expand Down
2 changes: 2 additions & 0 deletions .circleci/unittest/linux_stable/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ fi

export PYTORCH_TEST_WITH_SLOW='1'
python -m torch.utils.collect_env
# Avoid error: "fatal: unsafe repository" (git refuses to run in a checkout owned by a different user); '*' marks every directory as safe.
git config --global --add safe.directory '*'
root_dir="$(git rev-parse --show-toplevel)"
export MJLIB_PATH=$root_dir/.mujoco/mujoco-2.1.1/lib/libmujoco.so.2.1.1
export MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210
Expand Down
2 changes: 2 additions & 0 deletions .circleci/unittest/linux_stable/scripts/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
set -e

this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
# Avoid error: "fatal: unsafe repository" (git refuses to run in a checkout owned by a different user); '*' marks every directory as safe.
git config --global --add safe.directory '*'
root_dir="$(git rev-parse --show-toplevel)"
conda_dir="${root_dir}/conda"
env_dir="${root_dir}/env"
Expand Down
5 changes: 2 additions & 3 deletions examples/ddpg/configs/humanoid_pixels.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
env_name=humanoid
env_task=walk
env_library=dm_control
async_collection
record_video
prb
exp_name=humanoid

num_workers=32
env_per_collector=8
num_workers=4
env_per_collector=1

frame_skip=2
frames_per_batch=500
Expand Down
14 changes: 9 additions & 5 deletions examples/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import argparse

_configargparse = False

import torch.cuda
from torch.utils.tensorboard import SummaryWriter
from torchrl.envs.transforms import RewardScaling, TransformedEnv
Expand Down Expand Up @@ -117,10 +118,11 @@ def main(args):

model = make_ddpg_actor(
proof_env,
args,
args=args,
device=device,
)
loss_module, target_net_updater = make_ddpg_loss(model, args)

actor_model_explore = model[0]
if args.ou_exploration:
if args.gSDE:
Expand All @@ -131,15 +133,16 @@ def main(args):
if device == torch.device("cpu"):
# mostly for debugging
actor_model_explore.share_memory()

if args.gSDE:
if args.gSDE:
with torch.no_grad(), set_exploration_mode("random"):
# get dimensions to build the parallel env
proof_td = actor_model_explore(proof_env.reset().to(device))
with torch.no_grad(), set_exploration_mode("random"):
# get dimensions to build the parallel env
proof_td = actor_model_explore(proof_env.reset().to(device))
action_dim_gsde, state_dim_gsde = proof_td.get("_eps_gSDE").shape[-2:]
del proof_td
else:
action_dim_gsde, state_dim_gsde = None, None

proof_env.close()
create_env_fn = parallel_env_constructor(
args=args,
Expand Down Expand Up @@ -174,6 +177,7 @@ def main(args):
recorder_rm = TransformedEnv(recorder.env, recorder.transform[1:])
else:
recorder_rm = recorder

if isinstance(create_env_fn, ParallelEnv):
recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"])
create_env_fn.close()
Expand Down
7 changes: 2 additions & 5 deletions examples/redq/configs/humanoid_pixels.txt
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
env_name=humanoid
env_task=walk
env_library=dm_control
async_collection
record_video
prb
normalize_rewards_online
normalize_rewards_online_scale=10
exp_name=humanoid
tanh_loc

num_workers=32
env_per_collector=8
num_workers=4
env_per_collector=1

frame_skip=2
frames_per_batch=500
Expand Down
5 changes: 4 additions & 1 deletion examples/redq/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,15 @@ def main(args):
)()
model = make_redq_model(
proof_env,
device=device,
args=args,
device=device,
)
loss_module, target_net_updater = make_redq_loss(model, args)

actor_model_explore = model[0]
if args.ou_exploration:
if args.gSDE:
raise RuntimeError("gSDE and ou_exploration are incompatible")
actor_model_explore = OrnsteinUhlenbeckProcessWrapper(
actor_model_explore, annealing_num_steps=args.annealing_frames
).to(device)
Expand Down Expand Up @@ -166,6 +168,7 @@ def main(args):
norm_obs_only=True,
stats=stats,
writer=writer,
use_env_creator=False,
)()

# remove video recorder from recorder to have matching state_dict keys
Expand Down
Loading

0 comments on commit 9f56d97

Please sign in to comment.