
Commit

[Doc] Per-release doc (pytorch#2108)
vmoens authored Apr 25, 2024
1 parent 45c91cc commit 62a3adb
Showing 20 changed files with 125 additions and 97 deletions.
57 changes: 32 additions & 25 deletions .github/workflows/docs.yml
@@ -108,37 +108,44 @@ jobs:
REF_TYPE=${{ github.ref_type }}
REF_NAME=${{ github.ref_name }}
# TODO: adopt this behaviour
# if [[ "${REF_TYPE}" == branch ]]; then
# TARGET_FOLDER="${REF_NAME}"
# elif [[ "${REF_TYPE}" == tag ]]; then
# case "${REF_NAME}" in
# *-rc*)
# echo "Aborting upload since this is an RC tag: ${REF_NAME}"
# exit 0
# ;;
# *)
# # Strip the leading "v" as well as the trailing patch version. For example:
# # 'v0.15.2' -> '0.15'
# TARGET_FOLDER=$(echo "${REF_NAME}" | sed 's/v\([0-9]\+\)\.\([0-9]\+\)\.[0-9]\+/\1.\2/')
# ;;
# esac
# fi
TARGET_FOLDER="./"
if [[ "${REF_TYPE}" == branch ]]; then
if [[ "${REF_NAME}" == main ]]; then
TARGET_FOLDER="${REF_NAME}"
# Debug:
# else
# TARGET_FOLDER="release-doc"
fi
elif [[ "${REF_TYPE}" == tag ]]; then
case "${REF_NAME}" in
*-rc*)
echo "Aborting upload since this is an RC tag: ${REF_NAME}"
exit 0
;;
*)
# Strip the leading "v" as well as the trailing patch version. For example:
# 'v0.15.2' -> '0.15'
TARGET_FOLDER=$(echo "${REF_NAME}" | sed 's/v\([0-9]\+\)\.\([0-9]\+\)\.[0-9]\+/\1.\2/')
;;
esac
fi
echo "Target Folder: ${TARGET_FOLDER}"
# mkdir -p "${TARGET_FOLDER}"
# rm -rf "${TARGET_FOLDER}"/*
mkdir -p "${TARGET_FOLDER}"
rm -rf "${TARGET_FOLDER}"/*
echo $(ls "${RUNNER_ARTIFACT_DIR}")
rsync -a "${RUNNER_ARTIFACT_DIR}"/ "${TARGET_FOLDER}"
git add "${TARGET_FOLDER}" || true
# if [[ "${TARGET_FOLDER}" == main ]]; then
# mkdir -p _static
# rm -rf _static/*
# cp -r "${TARGET_FOLDER}"/_static/* _static
# git add _static || true
# fi
# Debug
# if [[ "${TARGET_FOLDER}" == "main" ]] || [[ "${TARGET_FOLDER}" == "release-doc" ]]; then
if [[ "${TARGET_FOLDER}" == "main" ]] ; then
mkdir -p _static
rm -rf _static/*
cp -r "${TARGET_FOLDER}"/_static/* _static
git add _static || true
fi
git config user.name 'pytorchbot'
git config user.email 'soumith+bot@pytorch.org'
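
For reference, the tag-to-folder mapping performed by the sed call above can be reproduced in a few lines of Python (a sketch, not part of the commit; the helper name is invented):

import re


def target_folder(ref_name: str) -> str:
    # Strip the leading "v" and the trailing patch version, e.g. 'v0.15.2' -> '0.15'
    return re.sub(r"^v(\d+)\.(\d+)\.\d+$", r"\1.\2", ref_name)


assert target_folder("v0.15.2") == "0.15"
assert target_folder("v0.4.0") == "0.4"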
13 changes: 6 additions & 7 deletions docs/source/reference/collectors.rst
@@ -82,7 +82,7 @@ In each of these cases, the last dimension (``T`` for ``time``) is adapted such
that the batch size equals the ``frames_per_batch`` argument passed to the
collector.

.. warning:: :class:`~torchrl.collectors.collectors.MultiSyncDataCollector` should not be
.. warning:: :class:`~torchrl.collectors.MultiSyncDataCollector` should not be
used with ``cat_results=0``, as the data will be stacked along the batch
dimension for batched environments, or along the time dimension for single environments,
which can introduce some confusion when swapping one with the other.
@@ -91,12 +91,12 @@ collector.
better interchangeability between configurations, collector classes and other
components.

Whereas :class:`~torchrl.collectors.collectors.MultiSyncDataCollector`
Whereas :class:`~torchrl.collectors.MultiSyncDataCollector`
has a dimension corresponding to the number of sub-collectors being run (``B``),
:class:`~torchrl.collectors.collectors.MultiaSyncDataCollector` doesn't. This
is easily understood when considering that :class:`~torchrl.collectors.collectors.MultiaSyncDataCollector`
:class:`~torchrl.collectors.MultiaSyncDataCollector` doesn't. This
is easily understood when considering that :class:`~torchrl.collectors.MultiaSyncDataCollector`
delivers batches of data on a first-come, first-served basis, whereas
:class:`~torchrl.collectors.collectors.MultiSyncDataCollector` gathers data from
:class:`~torchrl.collectors.MultiSyncDataCollector` gathers data from
each sub-collector before delivering it.
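
The shape difference described in this paragraph can be checked with a short script (a sketch, not part of this diff; it assumes gymnasium's CartPole-v1 is available, that policy=None falls back to a random policy, and that cat_results accepts "stack" as in recent TorchRL releases):

from torchrl.collectors import MultiaSyncDataCollector, MultiSyncDataCollector
from torchrl.envs import GymEnv


def make_env():
    return GymEnv("CartPole-v1")


if __name__ == "__main__":
    # Synchronous multi-collector: with cat_results="stack", each batch keeps a
    # leading dimension for the B=2 sub-collectors, e.g. torch.Size([2, 32]).
    sync = MultiSyncDataCollector(
        [make_env, make_env],
        policy=None,  # None falls back to a random policy
        frames_per_batch=64,
        total_frames=128,
        cat_results="stack",
    )
    for batch in sync:
        print("MultiSync batch size:", batch.batch_size)
        break
    sync.shutdown()

    # Asynchronous multi-collector: batches come from whichever sub-collector
    # finishes first, so there is no sub-collector dimension, e.g. torch.Size([64]).
    asynchronous = MultiaSyncDataCollector(
        [make_env, make_env],
        policy=None,
        frames_per_batch=64,
        total_frames=128,
    )
    for batch in asynchronous:
        print("MultiaSync batch size:", batch.batch_size)
        break
    asynchronous.shutdown()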

Collectors and replay buffers interoperability
@@ -168,7 +168,7 @@ batches written in the buffer won't come from the same source (thereby interrupt

Single node data collectors
---------------------------
.. currentmodule:: torchrl.collectors.collectors
.. currentmodule:: torchrl.collectors

.. autosummary::
:toctree: generated/
@@ -178,7 +178,6 @@ Single node data collectors
SyncDataCollector
MultiSyncDataCollector
MultiaSyncDataCollector
RandomPolicy
aSyncDataCollector


11 changes: 6 additions & 5 deletions docs/source/reference/envs.rst
@@ -849,14 +849,15 @@ Helpers
:toctree: generated/
:template: rl_template_fun.rst

step_mdp
get_available_libraries
set_exploration_mode #deprecated
set_exploration_type
RandomPolicy
check_env_specs
exploration_mode #deprecated
exploration_type
check_env_specs
get_available_libraries
make_composite_from_td
set_exploration_mode #deprecated
set_exploration_type
step_mdp
terminated_or_truncated

Domain-specific
2 changes: 0 additions & 2 deletions docs/source/reference/trainers.rst
@@ -195,8 +195,6 @@ Builders
make_collector_offpolicy
make_collector_onpolicy
make_dqn_loss
make_redq_loss
make_redq_model
make_replay_buffer
make_target_updater
make_trainer
2 changes: 2 additions & 0 deletions torchrl/data/__init__.py
@@ -31,6 +31,8 @@
WriterEnsemble,
)
from .rlhf import (
AdaptiveKLController,
ConstantKLController,
create_infinite_iterator,
get_dataloader,
PairwiseDataset,
2 changes: 1 addition & 1 deletion torchrl/data/rlhf/__init__.py
@@ -11,4 +11,4 @@
)
from .prompt import PromptData, PromptTensorDictTokenizer
from .reward import PairwiseDataset, RewardData
from .utils import RolloutFromModel
from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel
1 change: 1 addition & 0 deletions torchrl/envs/batched_envs.py
@@ -1568,6 +1568,7 @@ def _shutdown_workers(self) -> None:
if self._verbose:
torchrl_logger.info(f"closing {i}")
channel.send(("close", None))
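# With every close message sent above, wait (up to self._timeout) for each worker's event, then clear it: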
for i in range(self.num_workers):
self._events[i].wait(self._timeout)
self._events[i].clear()

4 changes: 2 additions & 2 deletions torchrl/envs/libs/jumanji.py
@@ -67,8 +67,8 @@ def _jumanji_to_torchrl_spec_transform(
dtype = numpy_to_torch_dtype_dict[spec.dtype]
return BoundedTensorSpec(
shape=shape,
low=np.asarray(spec.minimum),
high=np.asarray(spec.maximum),
low=np.asarray(spec.low),
high=np.asarray(spec.high),
dtype=dtype,
device=device,
)
7 changes: 5 additions & 2 deletions torchrl/envs/transforms/transforms.py
@@ -3120,8 +3120,11 @@ def unfold_done(done, N):
reset_vals = reset_vals[1:]
reps.extend([reset_vals[0]] * int(j))
j_ = j
reps = torch.stack(reps)
data = torch.masked_scatter(data, done_mask_expand, reps.reshape(-1))
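# Skip the scatter when no reset values were gathered: torch.stack on an empty list raises an error.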
if reps:
reps = torch.stack(reps)
data = torch.masked_scatter(
data, done_mask_expand, reps.reshape(-1)
)

if first_val is not None:
# Aggregate reset along last dim
2 changes: 1 addition & 1 deletion tutorials/sphinx-tutorials/coding_ddpg.py
@@ -339,7 +339,7 @@ def _loss_value(
batch_size=self.target_actor_network_params.batch_size,
device=self.target_actor_network_params.device,
)
with target_params.to_module(self.value_estimator):
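# Temporarily load the target parameters onto the full actor-critic module while computing the target value.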
with target_params.to_module(self.actor_critic):
target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)

# Computes the value loss: L2, L1 or smooth L1 depending on `self.loss_function`
41 changes: 29 additions & 12 deletions tutorials/sphinx-tutorials/coding_dqn.py
@@ -104,17 +104,20 @@

try:
multiprocessing.set_start_method("spawn" if is_sphinx else "fork")
mp_context = "fork"
except RuntimeError:
# If we can't set the method globally we can still run the parallel env with "fork"
# This will fail on windows! Use "spawn" and put the script within `if __name__ == "__main__"`
mp_context = "fork"
pass


# sphinx_gallery_end_ignore
import os
import uuid

import torch
from torch import nn
from torchrl.collectors import MultiaSyncDataCollector
from torchrl.collectors import MultiaSyncDataCollector, SyncDataCollector
from torchrl.data import LazyMemmapStorage, MultiStep, TensorDictReplayBuffer
from torchrl.envs import (
EnvCreator,
@@ -217,20 +217,26 @@ def is_notebook() -> bool:
def make_env(
parallel=False,
obs_norm_sd=None,
num_workers=1,
):
if obs_norm_sd is None:
obs_norm_sd = {"standard_normal": True}
if parallel:

def maker():
return GymEnv(
"CartPole-v1",
from_pixels=True,
pixels_only=True,
device=device,
)

base_env = ParallelEnv(
num_workers,
EnvCreator(
lambda: GymEnv(
"CartPole-v1",
from_pixels=True,
pixels_only=True,
device=device,
)
),
EnvCreator(maker),
# Don't create a sub-process if we have only one worker
serial_for_single=True,
mp_start_method=mp_context,
)
else:
base_env = GymEnv(
@@ -279,6 +288,7 @@ def get_norm_stats():
# ``C=4`` (because of :class:`~torchrl.envs.CatFrames`).
print("state dict of the observation norm:", obs_norm_sd)
test_env.close()
del test_env
return obs_norm_sd


@@ -426,8 +436,15 @@ def get_collector(
total_frames,
device,
):
cls = MultiaSyncDataCollector
env_arg = [make_env(parallel=True, obs_norm_sd=stats)] * num_collectors
# We can't use nested child processes with mp_start_method="fork"
if is_fork:
cls = SyncDataCollector
env_arg = make_env(parallel=True, obs_norm_sd=stats, num_workers=num_workers)
else:
cls = MultiaSyncDataCollector
env_arg = [
make_env(parallel=True, obs_norm_sd=stats, num_workers=num_workers)
] * num_collectors
data_collector = cls(
env_arg,
policy=actor_explore,
4 changes: 2 additions & 2 deletions tutorials/sphinx-tutorials/dqn_with_rnn.py
@@ -307,7 +307,7 @@
# either by passing a string or an action-spec. This allows us to use
# Categorical (sometimes called "sparse") encoding or the one-hot version of it.
#
qval = QValueModule(action_space=env.action_spec)
qval = QValueModule(action_space=None, spec=env.action_spec)
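
As the comment above notes, the action encoding can be given either as a string or as a spec; a minimal sketch (not part of this diff, assuming a one-hot spec with three actions):

from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.modules import QValueModule

spec = OneHotDiscreteTensorSpec(3)
qval_from_spec = QValueModule(action_space=None, spec=spec)  # encoding inferred from the spec
qval_from_string = QValueModule(action_space="one_hot")  # encoding named explicitly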

######################################################################
# .. note::
@@ -412,7 +412,7 @@
#

utd = 16
pbar = tqdm.tqdm(total=1_000_000)
pbar = tqdm.tqdm(total=collector.total_frames)
longest = 0

traj_lens = []
4 changes: 1 addition & 3 deletions tutorials/sphinx-tutorials/getting-started-1.py
@@ -273,9 +273,7 @@

policy = TensorDictSequential(
value_net, # writes action values in our tensordict
QValueModule(
action_space=env.action_spec
), # Reads the "action_value" entry by default
QValueModule(spec=env.action_spec), # Reads the "action_value" entry by default
)

###################################
2 changes: 1 addition & 1 deletion tutorials/sphinx-tutorials/getting-started-5.py
@@ -66,7 +66,7 @@

value_mlp = MLP(out_features=env.action_spec.shape[-1], num_cells=[64, 64])
value_net = Mod(value_mlp, in_keys=["observation"], out_keys=["action_value"])
policy = Seq(value_net, QValueModule(env.action_spec))
policy = Seq(value_net, QValueModule(spec=env.action_spec))
exploration_module = EGreedyModule(
env.action_spec, annealing_num_steps=100_000, eps_init=0.5
)
7 changes: 4 additions & 3 deletions tutorials/sphinx-tutorials/multi_task.py
@@ -12,6 +12,8 @@
# sphinx_gallery_start_ignore
import warnings

from tensordict import LazyStackedTensorDict

warnings.filterwarnings("ignore")

from torch import multiprocessing
@@ -31,7 +33,6 @@

# sphinx_gallery_end_ignore

import torch
from tensordict.nn import TensorDictModule, TensorDictSequential
from torch import nn

@@ -77,9 +78,9 @@
tdreset1 = env1.reset()
tdreset2 = env2.reset()

# In TorchRL, stacking is done in a lazy manner: the original tensordicts
# With LazyStackedTensorDict, stacking is done in a lazy manner: the original tensordicts
# can still be recovered by indexing the main tensordict
tdreset = torch.stack([tdreset1, tdreset2], 0)
tdreset = LazyStackedTensorDict.lazy_stack([tdreset1, tdreset2], 0)
assert tdreset[0] is tdreset1
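
A self-contained version of the lazy-stacking behaviour shown here (a sketch, not part of this diff; it only needs torch and tensordict):

import torch
from tensordict import LazyStackedTensorDict, TensorDict

a = TensorDict({"obs": torch.zeros(3)}, [])
b = TensorDict({"obs": torch.ones(3)}, [])
stacked = LazyStackedTensorDict.lazy_stack([a, b], 0)
assert stacked[0] is a  # indexing recovers the original tensordict, not a copy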

###############################################################################
15 changes: 7 additions & 8 deletions tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py
@@ -4,13 +4,6 @@
===========================================================================
**Author**: `Matteo Bettini <https://github.com/matteobettini>`_
.. note::
If you are interested in Multi-Agent Reinforcement Learning (MARL) in
TorchRL, check out
`BenchMARL <https://github.com/facebookresearch/BenchMARL>`__: a benchmarking library where you
can train and compare MARL sota-implementations, tasks, and models using TorchRL!
This tutorial demonstrates how to use PyTorch and TorchRL to
solve a Competitive Multi-Agent Reinforcement Learning (MARL) problem.
@@ -141,6 +134,12 @@

from tqdm import tqdm

# Check if we're building the doc, in which case disable video rendering
try:
is_sphinx = __sphinx_build__
except NameError:
is_sphinx = False

######################################################################
# Define Hyperparameters
# ----------------------
@@ -879,7 +878,7 @@ def process_batch(batch: TensorDictBase) -> TensorDictBase:
# logger `video_logger`. Note that this code may require some external dependencies such as torchvision.
#

if use_vmas:
if use_vmas and not is_sphinx:
# Replace tmpdir with any desired path where the video should be saved
with tempfile.TemporaryDirectory() as tmpdir:
video_logger = CSVLogger("vmas_logs", tmpdir, video_format="mp4")