[Feature] Gym compatibility: Terminal and truncated (pytorch#1539)
Co-authored-by: Skander Moalla <37197319+skandermoalla@users.noreply.github.com>
Co-authored-by: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
3 people authored Sep 29, 2023
1 parent 18b33fe commit 802f0e4
Showing 53 changed files with 3,583 additions and 1,218 deletions.
6 changes: 3 additions & 3 deletions .github/unittest/linux/scripts/run_all.sh
@@ -178,16 +178,16 @@ python -m torch.utils.collect_env
#export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir
export MKL_THREADING_LAYER=GNU
export CKPT_BACKEND=torch

export MAX_IDLE_COUNT=100

pytest test/smoke_test.py -v --durations 200
pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control or test_tb'
if [ "${CU_VERSION:-}" != cpu ] ; then
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
--instafail --durations 200 --ignore test/test_rlhf.py
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py
else
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
--instafail --durations 200 --ignore test/test_rlhf.py --ignore test/test_distributed.py
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py --ignore test/test_distributed.py
fi

coverage combine
3 changes: 2 additions & 1 deletion .github/unittest/linux_examples/scripts/run_local.sh
@@ -1,6 +1,7 @@
#!/bin/bash

set -e
set -v

# Read script from line 29
filename=".github/unittest/linux_examples/scripts/run_test.sh"
@@ -12,7 +13,7 @@ script="set -e"$'\n'"$script"
script="${script//cuda:0/cpu}"

# Remove any instances of ".github/unittest/helpers/coverage_run_parallel.py"
script="${script//.circleci\/unittest\/helpers\/coverage_run_parallel.py}"
script="${script//.github\/unittest\/helpers\/coverage_run_parallel.py}"
script="${script//coverage combine}"
script="${script//coverage xml -i}"

10 changes: 6 additions & 4 deletions .github/unittest/linux_examples/scripts/run_test.sh
@@ -53,16 +53,16 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_mujoco
collector.total_frames=40 \
collector.frames_per_batch=20 \
loss.mini_batch_size=10 \
loss.ppo_epochs=1 \
loss.ppo_epochs=2 \
logger.backend= \
logger.test_interval=40
logger.test_interval=10
python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_atari.py \
collector.total_frames=80 \
collector.frames_per_batch=20 \
loss.mini_batch_size=20 \
loss.ppo_epochs=1 \
loss.ppo_epochs=2 \
logger.backend= \
logger.test_interval=40
logger.test_interval=10
python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
@@ -126,6 +126,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
optimization.utd_ratio=1 \
replay_buffer.size=120 \
env.name=Pendulum-v1 \
network.device=cuda:0 \
logger.backend=
# logger.record_video=True \
# logger.record_frames=4 \
@@ -225,6 +226,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
collector.num_workers=2 \
collector.env_per_collector=1 \
collector.collector_device=cuda:0 \
network.device=cuda:0 \
optimization.batch_size=10 \
optimization.utd_ratio=1 \
replay_buffer.size=120 \
3 changes: 3 additions & 0 deletions .github/unittest/linux_libs/scripts_gym/batch_scripts.sh
@@ -50,6 +50,8 @@ do
conda activate ./cloned_env

echo "Testing gym version: ${GYM_VERSION}"
# handling https://github.com/openai/gym/issues/3202
pip3 install wheel==0.38.4
pip3 install gym==$GYM_VERSION
$DIR/run_test.sh

@@ -67,6 +69,7 @@ do
conda activate ./cloned_env

echo "Testing gym version: ${GYM_VERSION}"
pip3 install wheel==0.38.4
pip3 install 'gym[atari]'==$GYM_VERSION
pip3 install ale-py==0.7
$DIR/run_test.sh
35 changes: 20 additions & 15 deletions docs/source/reference/envs.rst
@@ -36,7 +36,7 @@ Each env will have the following attributes:
- :obj:`env.reward_spec`: a :class:`~torchrl.data.TensorSpec` object representing
the reward spec.
- :obj:`env.done_spec`: a :class:`~torchrl.data.TensorSpec` object representing
the done-flag spec.
the done-flag spec. See the section on trajectory termination below.
- :obj:`env.input_spec`: a :class:`~torchrl.data.CompositeSpec` object containing
all the input keys (:obj:`"full_action_spec"` and :obj:`"full_state_spec"`).
It is locked and should not be modified directly.
@@ -79,22 +79,25 @@ The following figure summarizes how a rollout is executed in torchrl.

In brief, a TensorDict is created by the :meth:`~.EnvBase.reset` method,
then populated with an action by the policy before being passed to the
:meth:`~.EnvBase.step` method which writes the observations, done flag and
:meth:`~.EnvBase.step` method which writes the observations, done flag(s) and
reward under the ``"next"`` entry. The result of this call is stored for
delivery and the ``"next"`` entry is gathered by the :func:`~.utils.step_mdp`
function.
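
To make the reset/step/``step_mdp`` cycle described above concrete, here is a minimal sketch (assuming a Gym backend is installed; the environment name and the random-action stand-in for a policy are illustrative, not taken from the diff):

from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import step_mdp

env = GymEnv("CartPole-v1")
tensordict = env.reset()                           # creates the initial TensorDict
for _ in range(3):
    tensordict["action"] = env.action_spec.rand()  # random policy stand-in
    tensordict = env.step(tensordict)              # writes obs, reward, done flag(s) under "next"
    tensordict = step_mdp(tensordict)              # promotes the "next" entry for the following step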

.. note::

The Gym(nasium) API recently shifted to a splitting of the ``"done"`` state
into a ``terminated`` (the env is done and results should not be trusted)
and ``truncated`` (the maximum number of steps is reached) flags.
In TorchRL, ``"done"`` usually refers to ``"terminated"``. Truncation is
achieved via the :class:`~.StepCounter` transform class, and the output
key will be ``"truncated"`` if not chosen to be something else (e.g.
``StepCounter(max_steps=100, truncated_key="done")``).
TorchRL's collectors and rollout methods will be looking for one of these
keys when assessing if the env should be reset.
In general, all TorchRL environments have a ``"done"`` and a ``"terminated"``
entry in their output tensordict. If they are not present by design,
the :class:`~.EnvBase` metaclass will ensure that every ``"done"`` or
``"terminated"`` entry is flanked with its dual.
In TorchRL, ``"done"`` strictly refers to the union of all the end-of-trajectory
signals and should be interpreted as "the last step of a trajectory" or
equivalently "a signal indicating the need to reset".
If the environment provides it (e.g., Gymnasium), the truncation signal is also
written in the :meth:`EnvBase.step` output under a ``"truncated"`` entry.
If the environment carries a single end-of-trajectory value, it will be
interpreted as a ``"terminated"`` signal by default.
By default, TorchRL's collectors and rollout methods look for the ``"done"``
entry to assess whether the environment should be reset.
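
A hedged illustration of this convention (the environment name is arbitrary, and it assumes a Gymnasium backend so that a genuine truncation signal exists):

from torchrl.envs.libs.gym import GymEnv

env = GymEnv("Pendulum-v1")
rollout = env.rollout(3)
# the "next" sub-tensordict carries the end-of-trajectory signals
print(rollout["next", "done"].shape)        # union of the signals below
print(rollout["next", "terminated"].shape)  # env-level termination
print(rollout["next", "truncated"].shape)   # e.g. time-limit truncation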

.. note::

@@ -172,12 +175,13 @@ It is also possible to reset some but not all of the environments:
:caption: Parallel environment reset
>>> tensordict = TensorDict({"_reset": [[True], [False], [True], [True]]}, [4])
>>> env.reset(tensordict)
>>> env.reset(tensordict) # eliminates the "_reset" entry
TensorDict(
fields={
terminated: Tensor(torch.Size([4, 1]), dtype=torch.bool),
done: Tensor(torch.Size([4, 1]), dtype=torch.bool),
pixels: Tensor(torch.Size([4, 500, 500, 3]), dtype=torch.uint8),
_reset: Tensor(torch.Size([4, 1]), dtype=torch.bool)},
truncated: Tensor(torch.Size([4, 1]), dtype=torch.bool),
batch_size=torch.Size([4]),
device=None,
is_shared=True)
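
For context, a minimal sketch of a setup under which the partial-reset example above could be reproduced (the environment choice, pixel rendering and worker count are assumptions, not stated in the docs):

from torchrl.envs import ParallelEnv
from torchrl.envs.libs.gym import GymEnv

env = ParallelEnv(4, lambda: GymEnv("Pendulum-v1", from_pixels=True))
tensordict = env.reset()   # resets all 4 workers; pass a "_reset" mask to reset a subset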
@@ -238,7 +242,7 @@ Some of the main differences between these paradigms include:

- **observation** can be per-agent and also have some shared components
- **reward** can be per-agent or shared
- **done** can be per-agent or shared
- **done** (and ``"truncated"`` or ``"terminated"``) can be per-agent or shared.

TorchRL accommodates all these possible paradigms thanks to its :class:`tensordict.TensorDict` data carrier.
In particular, in multi-agent environments, per-agent keys will be carried in a nested "agents" TensorDict.
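
A hedged sketch of the nested layout this paragraph describes, with a shared ``"done"`` at the root and per-agent keys under ``"agents"`` (shapes and key names are illustrative):

import torch
from tensordict import TensorDict

n_envs, n_agents = 8, 3
td = TensorDict(
    {
        "done": torch.zeros(n_envs, 1, dtype=torch.bool),    # shared end-of-trajectory flag
        "agents": TensorDict(
            {"reward": torch.zeros(n_envs, n_agents, 1)},     # per-agent reward
            batch_size=[n_envs, n_agents],
        ),
    },
    batch_size=[n_envs],
)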
@@ -586,6 +590,7 @@ Helpers
exploration_type
check_env_specs
make_composite_from_td
terminated_or_truncated

Domain-specific
---------------
4 changes: 0 additions & 4 deletions examples/a2c/a2c_atari.py
@@ -124,10 +124,6 @@ def main(cfg: "DictConfig"): # noqa: F821
}
)

# Apply episodic end of life
data["done"].copy_(data["end_of_life"])
data["next", "done"].copy_(data["next", "end_of_life"])

losses = TensorDict({}, batch_size=[num_mini_batches])
training_start = time.time()

2 changes: 1 addition & 1 deletion examples/decision_transformer/utils.py
@@ -232,7 +232,7 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling):
batch_size=rb_cfg.batch_size,
sampler=RandomSampler(), # SamplerWithoutReplacement(drop_last=False),
transform=transforms,
use_timeout_as_done=True,
use_truncated_as_done=True,
)
full_data = data._get_dataset_from_env(rb_cfg.dataset, {})
loc = full_data["observation"].mean(axis=0).float()
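
For readers tracking the rename above, a hedged sketch of how the renamed flag would be passed (the dataset name, batch size and availability of the D4RL dependency are assumptions):

from torchrl.data.datasets.d4rl import D4RLExperienceReplay
from torchrl.data.replay_buffers import RandomSampler

replay_buffer = D4RLExperienceReplay(
    "halfcheetah-medium-v2",
    batch_size=256,
    sampler=RandomSampler(),
    use_truncated_as_done=True,  # formerly `use_timeout_as_done`
)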
44 changes: 27 additions & 17 deletions examples/ppo/ppo_atari.py
@@ -78,6 +78,9 @@ def main(cfg: "DictConfig"): # noqa: F821
normalize_advantage=True,
)

# use end-of-life as done key
loss_module.set_keys(done="eol")

# Create optimizer
optim = torch.optim.Adam(
loss_module.parameters(),
@@ -109,6 +112,18 @@ def main(cfg: "DictConfig"): # noqa: F821
)

sampling_start = time.time()

# extract cfg variables
cfg_loss_ppo_epochs = cfg.loss.ppo_epochs
cfg_optim_anneal_lr = cfg.optim.anneal_lr
cfg_optim_lr = cfg.optim.lr
cfg_loss_anneal_clip_eps = cfg.loss.anneal_clip_epsilon
cfg_loss_clip_epsilon = cfg.loss.clip_epsilon
cfg_logger_num_test_episodes = cfg.logger.num_test_episodes
cfg_optim_max_grad_norm = cfg.optim.max_grad_norm
cfg.loss.clip_epsilon = cfg_loss_clip_epsilon
losses = TensorDict({}, batch_size=[cfg_loss_ppo_epochs, num_mini_batches])

for i, data in enumerate(collector):

log_info = {}
@@ -120,7 +135,7 @@ def main(cfg: "DictConfig"): # noqa: F821
# Get training rewards and episode lengths
episode_rewards = data["next", "episode_reward"][data["next", "done"]]
if len(episode_rewards) > 0:
episode_length = data["next", "step_count"][data["next", "done"]]
episode_length = data["next", "step_count"][data["next", "stop"]]
log_info.update(
{
"train/reward": episode_rewards.mean().item(),
@@ -129,13 +144,8 @@ def main(cfg: "DictConfig"): # noqa: F821
}
)

# Apply episodic end of life
data["done"].copy_(data["end_of_life"])
data["next", "done"].copy_(data["next", "end_of_life"])

losses = TensorDict({}, batch_size=[cfg.loss.ppo_epochs, num_mini_batches])
training_start = time.time()
for j in range(cfg.loss.ppo_epochs):
for j in range(cfg_loss_ppo_epochs):

# Compute GAE
with torch.no_grad():
@@ -149,12 +159,12 @@ def main(cfg: "DictConfig"): # noqa: F821

# Linearly decrease the learning rate and clip epsilon
alpha = 1.0
if cfg.optim.anneal_lr:
if cfg_optim_anneal_lr:
alpha = 1 - (num_network_updates / total_network_updates)
for group in optim.param_groups:
group["lr"] = cfg.optim.lr * alpha
if cfg.loss.anneal_clip_epsilon:
loss_module.clip_epsilon.copy_(cfg.loss.clip_epsilon * alpha)
group["lr"] = cfg_optim_lr * alpha
if cfg_loss_anneal_clip_eps:
loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
num_network_updates += 1

# Get a data batch
@@ -172,7 +182,7 @@ def main(cfg: "DictConfig"): # noqa: F821
# Backward pass
loss_sum.backward()
torch.nn.utils.clip_grad_norm_(
list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm
list(loss_module.parameters()), max_norm=cfg_optim_max_grad_norm
)

# Update the networks
@@ -181,15 +191,15 @@ def main(cfg: "DictConfig"): # noqa: F821

# Get training losses and times
training_time = time.time() - training_start
losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
for key, value in losses.items():
losses_mean = losses.apply(lambda x: x.float().mean(), batch_size=[])
for key, value in losses_mean.items():
log_info.update({f"train/{key}": value.item()})
log_info.update(
{
"train/lr": alpha * cfg.optim.lr,
"train/lr": alpha * cfg_optim_lr,
"train/sampling_time": sampling_time,
"train/training_time": training_time,
"train/clip_epsilon": alpha * cfg.loss.clip_epsilon,
"train/clip_epsilon": alpha * cfg_loss_clip_epsilon,
}
)

@@ -201,7 +211,7 @@ def main(cfg: "DictConfig"): # noqa: F821
actor.eval()
eval_start = time.time()
test_rewards = eval_model(
actor, test_env, num_episodes=cfg.logger.num_test_episodes
actor, test_env, num_episodes=cfg_logger_num_test_episodes
)
eval_time = time.time() - eval_start
log_info.update(
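
As a side note on the schedule manipulated in the hunks above, a small self-contained sketch of the linear annealing applied to the learning rate and clip epsilon (the numbers are illustrative, not this example's configuration):

total_network_updates = 100
base_lr, base_clip_eps = 3e-4, 0.2
for num_network_updates in (0, 25, 50, 75, 100):
    alpha = 1 - num_network_updates / total_network_updates
    print(f"update {num_network_updates:3d}: lr={base_lr * alpha:.2e}, clip_eps={base_clip_eps * alpha:.3f}")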
38 changes: 24 additions & 14 deletions examples/ppo/ppo_mujoco.py
@@ -100,6 +100,17 @@ def main(cfg: "DictConfig"): # noqa: F821
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

sampling_start = time.time()

# extract cfg variables
cfg_loss_ppo_epochs = cfg.loss.ppo_epochs
cfg_optim_anneal_lr = cfg.optim.anneal_lr
cfg_optim_lr = cfg.optim.lr
cfg_loss_anneal_clip_eps = cfg.loss.anneal_clip_epsilon
cfg_loss_clip_epsilon = cfg.loss.clip_epsilon
cfg_logger_test_interval = cfg.logger.test_interval
cfg_logger_num_test_episodes = cfg.logger.num_test_episodes
losses = TensorDict({}, batch_size=[cfg_loss_ppo_epochs, num_mini_batches])

for i, data in enumerate(collector):

log_info = {}
Expand All @@ -120,9 +131,8 @@ def main(cfg: "DictConfig"): # noqa: F821
}
)

losses = TensorDict({}, batch_size=[cfg.loss.ppo_epochs, num_mini_batches])
training_start = time.time()
for j in range(cfg.loss.ppo_epochs):
for j in range(cfg_loss_ppo_epochs):

# Compute GAE
with torch.no_grad():
@@ -136,14 +146,14 @@ def main(cfg: "DictConfig"): # noqa: F821

# Linearly decrease the learning rate and clip epsilon
alpha = 1.0
if cfg.optim.anneal_lr:
if cfg_optim_anneal_lr:
alpha = 1 - (num_network_updates / total_network_updates)
for group in actor_optim.param_groups:
group["lr"] = cfg.optim.lr * alpha
group["lr"] = cfg_optim_lr * alpha
for group in critic_optim.param_groups:
group["lr"] = cfg.optim.lr * alpha
if cfg.loss.anneal_clip_epsilon:
loss_module.clip_epsilon.copy_(cfg.loss.clip_epsilon * alpha)
group["lr"] = cfg_optim_lr * alpha
if cfg_loss_anneal_clip_eps:
loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
num_network_updates += 1

# Forward pass PPO loss
@@ -166,27 +176,27 @@ def main(cfg: "DictConfig"): # noqa: F821

# Get training losses and times
training_time = time.time() - training_start
losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
for key, value in losses.items():
losses_mean = losses.apply(lambda x: x.float().mean(), batch_size=[])
for key, value in losses_mean.items():
log_info.update({f"train/{key}": value.item()})
log_info.update(
{
"train/lr": alpha * cfg.optim.lr,
"train/lr": alpha * cfg_optim_lr,
"train/sampling_time": sampling_time,
"train/training_time": training_time,
"train/clip_epsilon": alpha * cfg.loss.clip_epsilon,
"train/clip_epsilon": alpha * cfg_loss_clip_epsilon,
}
)

# Get test rewards
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
if ((i - 1) * frames_in_batch) // cfg.logger.test_interval < (
if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < (
i * frames_in_batch
) // cfg.logger.test_interval:
) // cfg_logger_test_interval:
actor.eval()
eval_start = time.time()
test_rewards = eval_model(
actor, test_env, num_episodes=cfg.logger.num_test_episodes
actor, test_env, num_episodes=cfg_logger_num_test_episodes
)
eval_time = time.time() - eval_start
log_info.update(