[Algorithm] Online Decision transformer (pytorch#1149)
Co-authored-by: vmoens <vincentmoens@gmail.com>
Co-authored-by: Mateusz Guzek <matguzek@meta.com>
3 people authored Aug 30, 2023
1 parent 121ecd9 commit b444007
Showing 32 changed files with 2,537 additions and 73 deletions.
1 change: 1 addition & 0 deletions .circleci/unittest/linux/scripts/environment.yml
@@ -28,3 +28,4 @@ dependencies:
- av
- coverage
- ray
- transformers
3 changes: 1 addition & 2 deletions .circleci/unittest/linux_examples/scripts/environment.yml
@@ -20,11 +20,10 @@ dependencies:
- pyyaml
- scipy
- hydra-core
- tensorboard
- imageio==2.26.0
- wandb
- dm_control
- mlflow
- av
- coverage
- vmas
- transformers
53 changes: 30 additions & 23 deletions .circleci/unittest/linux_examples/scripts/run_all.sh
@@ -7,29 +7,16 @@ set -v
# ================================ Init ============================================== #


if [[ $OSTYPE != 'darwin'* ]]; then
apt-get update && apt-get upgrade -y
apt-get install -y vim git wget

apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev
apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2

if [ "${CU_VERSION:-}" == cpu ] ; then
# solves version `GLIBCXX_3.4.29' not found for tensorboard
# apt-get install -y gcc-4.9
apt-get upgrade -y libstdc++6
apt-get dist-upgrade -y
else
apt-get install -y g++ gcc
fi
apt-get update && apt-get upgrade -y
apt-get install -y vim git wget

fi
apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libosmesa6-dev
apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2
apt-get install -y g++ gcc patchelf

this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
if [[ $OSTYPE != 'darwin'* ]]; then
# from cudagl docker image
cp $this_dir/10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json
fi
# from cudagl docker image
cp $this_dir/10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json


# ==================================================================================== #
@@ -69,8 +56,8 @@ conda activate "${env_dir}"
printf "* Installing mujoco and related\n"
mkdir -p $root_dir/.mujoco
cd $root_dir/.mujoco/
wget https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-x86_64.tar.gz
tar -xf mujoco-2.1.1-linux-x86_64.tar.gz
#wget https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-x86_64.tar.gz
#tar -xf mujoco-2.1.1-linux-x86_64.tar.gz
wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz
tar -xf mujoco210-linux-x86_64.tar.gz
cd "${root_dir}"
@@ -80,9 +67,16 @@ printf "* Installing dependencies (except PyTorch)\n"
echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml"
cat "${this_dir}/environment.yml"

export MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210
export DISPLAY=unix:0.0
#export MJLIB_PATH=$root_dir/.mujoco/mujoco-2.1.1/lib/libmujoco.so.2.1.1
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$root_dir/.mujoco/mujoco210/bin
export SDL_VIDEODRIVER=dummy
export MUJOCO_GL=egl
export PYOPENGL_PLATFORM=egl

conda env config vars set MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 \
DISPLAY=unix:0.0 \
MJLIB_PATH=$root_dir/.mujoco/mujoco-2.1.1/lib/libmujoco.so.2.1.1 \
LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$root_dir/.mujoco/mujoco210/bin \
SDL_VIDEODRIVER=dummy \
MUJOCO_GL=egl \
@@ -95,6 +89,19 @@ conda env update --file "${this_dir}/environment.yml" --prune
conda deactivate
conda activate "${env_dir}"

# install d4rl
pip install free-mujoco-py
pip install git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl

# TODO: move this down -- will break torchrl installation
conda install -y -c conda-forge libstdcxx-ng=12
## find libstdc
STDC_LOC=$(find conda/ -name "libstdc++.so.6" | head -1)
conda env config vars set LD_PRELOAD=${root_dir}/$STDC_LOC

# compile mujoco-py now (its build otherwise happens lazily at first import, at runtime)
python -c """import gym;import d4rl"""

# install ale-py: manylinux names are broken for CentOS so we need to manually download and
# rename them
PY_VERSION=$(python --version)
27 changes: 24 additions & 3 deletions .circleci/unittest/linux_examples/scripts/run_test.sh
@@ -26,9 +26,28 @@ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir
export MKL_THREADING_LAYER=GNU

python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200
python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200
#python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200

# ==================================================================================== #
# ================================ gym 0.23 ========================================== #

# With batched environments
python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/dt.py \
optim.pretrain_gradient_steps=55 \
optim.updates_per_episode=3 \
optim.warmup_steps=10 \
optim.device=cuda:0 \
logger.backend=
python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/online_dt.py \
optim.pretrain_gradient_steps=55 \
optim.updates_per_episode=3 \
optim.warmup_steps=10 \
optim.device=cuda:0 \
logger.backend=
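
The two runs above exercise the new offline (dt.py) and online (online_dt.py) Decision Transformer examples with a deliberately tiny budget; `logger.backend=` overrides the logger backend with an empty string, so the CI job needs no wandb/tensorboard backend. The settings are ordinary Hydra command-line overrides. A minimal sketch of that pattern follows — the config-file name and the optim/logger groups mirror the example, but the shipped dt_config.yaml is not reproduced here, so treat the keys as assumptions:

# override_sketch.py -- hypothetical; assumes a dt_config.yaml next to it with
# `optim:` and `logger:` groups like the example's config.
import hydra
from omegaconf import DictConfig


@hydra.main(config_path=".", config_name="dt_config")
def main(cfg: DictConfig) -> None:
    # `optim.pretrain_gradient_steps=55` on the command line replaces the value from
    # dt_config.yaml; `logger.backend=` replaces it with an empty string.
    print(cfg.optim.pretrain_gradient_steps, repr(cfg.logger.backend))


if __name__ == "__main__":
    main()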

# ==================================================================================== #
# ================================ Gymnasium ========================================= #

python .circleci/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \
env.num_envs=1 \
env.device=cuda:0 \
@@ -136,7 +155,8 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/iql/iql_onli
env_per_collector=2 \
collector_device=cuda:0 \
device=cuda:0 \
mode=offline
mode=offline \
logger=

# With single envs
python .circleci/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \
@@ -234,7 +254,8 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/iql/iql_onli
env_per_collector=1 \
mode=offline \
device=cuda:0 \
collector_device=cuda:0
collector_device=cuda:0 \
logger=
python .circleci/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
22 changes: 21 additions & 1 deletion benchmarks/conftest.py
@@ -57,7 +57,7 @@ def pytest_addoption(parser):
    parser.addoption("--rank", action="store")


@pytest.fixture(autouse=True)
@pytest.fixture(scope="session", autouse=True)
def set_warnings() -> None:
    warnings.filterwarnings(
        "ignore",
@@ -69,3 +69,23 @@ def set_warnings() -> None:
        category=UserWarning,
        message=r"Couldn't cast the policy onto the desired device on remote process",
    )
    warnings.filterwarnings(
        "ignore",
        category=DeprecationWarning,
        message=r"Deprecated call to `pkg_resources.declare_namespace",
    )
    warnings.filterwarnings(
        "ignore",
        category=DeprecationWarning,
        message=r"Using or importing the ABCs",
    )
    warnings.filterwarnings(
        "ignore",
        category=DeprecationWarning,
        message=r"Please use `coo_matrix` from the `scipy.sparse` namespace",
    )
    warnings.filterwarnings(
        "ignore",
        category=DeprecationWarning,
        message=r"jax.tree_util.register_keypaths is deprecated|jax.ShapedArray is deprecated",
    )
13 changes: 8 additions & 5 deletions docs/source/reference/modules.rst
@@ -261,7 +261,7 @@ without shared parameters. It is mainly intended as a replacement for
ActorCriticWrapper
ActorValueOperator
ValueOperator

DecisionTransformerInferenceWrapper

Domain-specific TensorDict modules
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -322,18 +322,21 @@ algorithms, such as DQN, DDPG or Dreamer.
:toctree: generated/
:template: rl_template_noinherit.rst

DuelingCnnDQNet
DistributionalDQNnet
DTActor
DdpgCnnActor
DdpgCnnQNet
DdpgMlpActor
DdpgMlpQNet
DecisionTransformer
DistributionalDQNnet
DreamerActor
DuelingCnnDQNet
LSTMModule
ObsEncoder
ObsDecoder
RSSMPrior
ObsEncoder
OnlineDTActor
RSSMPosterior
RSSMPrior
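
The Decision Transformer entries added to this list (DecisionTransformer, DTActor, OnlineDTActor, plus DecisionTransformerInferenceWrapper above) are exercised by the example script included further down in this commit. As a quick orientation, here is a minimal sketch of the inference wrapping pattern, mirroring how examples/decision_transformer/dt.py uses it — `policy` is a placeholder for the actor built by the example's make_dt_model helper, and the inference_context value is illustrative rather than a recommended setting:

from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper

# `policy` is assumed to be the trained Decision Transformer actor
# (the example builds it with make_dt_model(cfg) in its utils.py).
inference_policy = DecisionTransformerInferenceWrapper(
    policy=policy,
    inference_context=20,  # placeholder; the example reads cfg.env.inference_context
)
# the wrapped policy can then be passed to test_env.rollout(policy=inference_policy, ...)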

Multi-agent-specific modules
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
10 changes: 10 additions & 0 deletions docs/source/reference/objectives.rst
@@ -137,6 +137,16 @@ CQL

CQLLoss

DT
----

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

DTLoss
OnlineDTLoss
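
The two loss modules documented here are consumed by the example scripts added in this commit. Below is a minimal sketch of the offline case, inferred from how dt.py (further down) uses the loss returned by its make_dt_loss helper; the constructor argument and the "loss" output key mirror that usage, while `actor` and `batch` are placeholders for the example's actor network and an offline replay-buffer sample:

import torch
from torchrl.objectives import DTLoss

# `actor`: the Decision Transformer actor; `batch`: a TensorDict sampled offline.
loss_module = DTLoss(actor)
optim = torch.optim.Adam(loss_module.parameters(), lr=1e-4)

loss_vals = loss_module(batch)  # a TensorDict carrying (at least) a "loss" entry
loss_vals["loss"].backward()
torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_norm=0.25)
optim.step()
optim.zero_grad()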

TD3
----

98 changes: 98 additions & 0 deletions examples/decision_transformer/dt.py
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Decision Transformer Example.
This is a self-contained example of an offline Decision Transformer training script.
The helper functions are coded in the utils.py associated with this script.
"""

import hydra
import torch
import tqdm

from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper

from utils import (
    make_dt_loss,
    make_dt_model,
    make_dt_optimizer,
    make_env,
    make_logger,
    make_offline_replay_buffer,
)


@hydra.main(config_path=".", config_name="dt_config")
def main(cfg: "DictConfig"):  # noqa: F821
    model_device = cfg.optim.device
    logger = make_logger(cfg)
    offline_buffer, obs_loc, obs_std = make_offline_replay_buffer(
        cfg.replay_buffer, cfg.env.reward_scaling
    )
    test_env = make_env(cfg.env, obs_loc, obs_std)
    actor = make_dt_model(cfg)
    policy = actor.to(model_device)

    loss_module = make_dt_loss(cfg.loss, actor)
    transformer_optim, scheduler = make_dt_optimizer(cfg.optim, loss_module)
    inference_policy = DecisionTransformerInferenceWrapper(
        policy=policy,
        inference_context=cfg.env.inference_context,
    ).to(model_device)

    pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps)

    r0 = None
    l0 = None

    pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps
    clip_grad = cfg.optim.clip_grad
    eval_steps = cfg.logger.eval_steps
    pretrain_log_interval = cfg.logger.pretrain_log_interval
    reward_scaling = cfg.env.reward_scaling

    print(" ***Pretraining*** ")
    # Pretraining
    for i in range(pretrain_gradient_steps):
        pbar.update(1)
        data = offline_buffer.sample()
        # loss
        loss_vals = loss_module(data.to(model_device))
        transformer_loss = loss_vals["loss"]

        # backprop: clip gradients only after backward() has populated them
        transformer_optim.zero_grad()
        transformer_loss.backward()
        torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad)
        transformer_optim.step()

        scheduler.step()

        # evaluation
        with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
            if i % pretrain_log_interval == 0:
                eval_td = test_env.rollout(
                    max_steps=eval_steps,
                    policy=inference_policy,
                    auto_cast_to_device=True,
                )
        if r0 is None:
            r0 = eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
        if l0 is None:
            l0 = transformer_loss.item()

        eval_reward = eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
        if logger is not None:
            for key, value in loss_vals.items():
                logger.log_scalar(key, value.item(), i)
            logger.log_scalar("evaluation reward", eval_reward, i)

        pbar.set_description(
            f"[Pre-Training] loss: {transformer_loss.item(): 4.4f} (init: {l0: 4.4f}), evaluation reward: {eval_reward: 4.4f} (init={r0: 4.4f})"
        )


if __name__ == "__main__":
    main()
