[Doc] Make tutos runnable without colab (#1826)
vmoens authored Jan 29, 2024
1 parent 79374d8 commit 6277226
Showing 12 changed files with 156 additions and 11 deletions.
23 changes: 14 additions & 9 deletions torchrl/objectives/ppo.py
@@ -345,41 +345,46 @@ def functional(self):

     @property
     def actor(self):
-        logging.warning(
+        warnings.warn(
             f"{self.__class__.__name__}.actor is deprecated, use {self.__class__.__name__}.actor_network instead. This "
-            "link will be removed in v0.4."
+            "link will be removed in v0.4.",
+            category=DeprecationWarning,
         )
         return self.actor_network

     @property
     def critic(self):
-        logging.warning(
+        warnings.warn(
             f"{self.__class__.__name__}.critic is deprecated, use {self.__class__.__name__}.critic_network instead. This "
-            "link will be removed in v0.4."
+            "link will be removed in v0.4.",
+            category=DeprecationWarning,
         )
         return self.critic_network

     @property
     def actor_params(self):
         logging.warning(
             f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This "
-            "link will be removed in v0.4."
+            "link will be removed in v0.4.",
+            category=DeprecationWarning,
         )
         return self.actor_network_params

     @property
     def critic_params(self):
-        logging.warning(
+        warnings.warn(
             f"{self.__class__.__name__}.critic_params is deprecated, use {self.__class__.__name__}.critic_network_params instead. This "
-            "link will be removed in v0.4."
+            "link will be removed in v0.4.",
+            category=DeprecationWarning,
         )
         return self.critic_network_params

     @property
     def target_critic_params(self):
-        logging.warning(
+        warnings.warn(
             f"{self.__class__.__name__}.target_critic_params is deprecated, use {self.__class__.__name__}.target_critic_network_params instead. This "
-            "link will be removed in v0.4."
+            "link will be removed in v0.4.",
+            category=DeprecationWarning,
         )
         return self.target_critic_network_params
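
The switch from logging.warning to warnings.warn(..., category=DeprecationWarning) lets callers filter, record, or escalate the deprecation notice through the standard warnings machinery. A minimal sketch of how a caller could surface it, assuming loss_module is an already-constructed ClipPPOLoss instance (the variable name is illustrative, not part of this diff):

import warnings

# `loss_module` stands in for any PPOLoss/ClipPPOLoss instance built elsewhere.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure the warning is not swallowed
    actor = loss_module.actor  # deprecated alias for .actor_network
assert any(issubclass(w.category, DeprecationWarning) for w in caught)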

10 changes: 10 additions & 0 deletions tutorials/sphinx-tutorials/coding_ddpg.py
@@ -57,6 +57,16 @@
from typing import Tuple

warnings.filterwarnings("ignore")
from torch import multiprocessing

# TorchRL prefers the spawn start method, which restricts the creation of
# ``~torchrl.envs.ParallelEnv`` to code guarded by ``if __name__ == "__main__"``.
# For ease of reading, this tutorial switches to fork, which is also the default
# start method in Google's Colaboratory.
try:
multiprocessing.set_start_method("fork")
except RuntimeError:
assert multiprocessing.get_start_method() == "fork"

# sphinx_gallery_end_ignore

import torch.cuda
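The same start-method preamble is repeated in each tutorial below. Where fork is unavailable (e.g. on Windows, or on macOS where spawn is the default), the spawn-compatible alternative is to build the parallel environments inside the __main__ guard instead. A rough sketch under that assumption, with GymEnv("Pendulum-v1") used purely as an illustrative environment:

from torch import multiprocessing

from torchrl.envs import ParallelEnv
from torchrl.envs.libs.gym import GymEnv


def make_env():
    # Each worker process constructs its own environment instance.
    return GymEnv("Pendulum-v1")


if __name__ == "__main__":
    multiprocessing.set_start_method("spawn", force=True)
    # Safe with spawn because the workers are created under the __main__ guard.
    env = ParallelEnv(2, make_env)
    print(env.rollout(3))
    env.close()
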
12 changes: 12 additions & 0 deletions tutorials/sphinx-tutorials/coding_dqn.py
@@ -87,6 +87,18 @@
import warnings

warnings.filterwarnings("ignore")

from torch import multiprocessing

# TorchRL prefers the spawn start method, which restricts the creation of
# ``~torchrl.envs.ParallelEnv`` to code guarded by ``if __name__ == "__main__"``.
# For ease of reading, this tutorial switches to fork, which is also the default
# start method in Google's Colaboratory.
try:
multiprocessing.set_start_method("fork")
except RuntimeError:
assert multiprocessing.get_start_method() == "fork"


# sphinx_gallery_end_ignore

import os
16 changes: 16 additions & 0 deletions tutorials/sphinx-tutorials/coding_ppo.py
@@ -104,6 +104,22 @@
# description and more about the algorithm itself.
#

# sphinx_gallery_start_ignore
import warnings

warnings.filterwarnings("ignore")
from torch import multiprocessing

# TorchRL prefers the spawn start method, which restricts the creation of
# ``~torchrl.envs.ParallelEnv`` to code guarded by ``if __name__ == "__main__"``.
# For ease of reading, this tutorial switches to fork, which is also the default
# start method in Google's Colaboratory.
try:
multiprocessing.set_start_method("fork")
except RuntimeError:
assert multiprocessing.get_start_method() == "fork"

# sphinx_gallery_end_ignore

from collections import defaultdict

import matplotlib.pyplot as plt
16 changes: 16 additions & 0 deletions tutorials/sphinx-tutorials/dqn_with_rnn.py
@@ -68,6 +68,22 @@
# -----
#

# sphinx_gallery_start_ignore
import warnings

warnings.filterwarnings("ignore")
from torch import multiprocessing

# TorchRL prefers the spawn start method, which restricts the creation of
# ``~torchrl.envs.ParallelEnv`` to code guarded by ``if __name__ == "__main__"``.
# For ease of reading, this tutorial switches to fork, which is also the default
# start method in Google's Colaboratory.
try:
multiprocessing.set_start_method("fork")
except RuntimeError:
assert multiprocessing.get_start_method() == "fork"

# sphinx_gallery_end_ignore

import torch
import tqdm
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
11 changes: 11 additions & 0 deletions tutorials/sphinx-tutorials/multi_task.py
@@ -13,6 +13,17 @@
import warnings

warnings.filterwarnings("ignore")

from torch import multiprocessing

# TorchRL prefers the spawn start method, which restricts the creation of
# ``~torchrl.envs.ParallelEnv`` to code guarded by ``if __name__ == "__main__"``.
# For ease of reading, this tutorial switches to fork, which is also the default
# start method in Google's Colaboratory.
try:
multiprocessing.set_start_method("fork")
except RuntimeError:
assert multiprocessing.get_start_method() == "fork"

# sphinx_gallery_end_ignore

import torch
4 changes: 2 additions & 2 deletions tutorials/sphinx-tutorials/multiagent_ppo.py
@@ -659,8 +659,8 @@
     with torch.no_grad():
         GAE(
             tensordict_data,
-            params=loss_module.critic_params,
-            target_params=loss_module.target_critic_params,
+            params=loss_module.critic_network_params,
+            target_params=loss_module.target_critic_network_params,
         )  # Compute GAE and add it to the data

     data_view = tensordict_data.reshape(-1)  # Flatten the batch size to shuffle data
Expand Down
17 changes: 17 additions & 0 deletions tutorials/sphinx-tutorials/pendulum.py
@@ -73,6 +73,23 @@
# simulation graph.
# * Finally, we will train a simple policy to solve the system we implemented.
#

# sphinx_gallery_start_ignore
import warnings

warnings.filterwarnings("ignore")
from torch import multiprocessing

# TorchRL prefers the spawn start method, which restricts the creation of
# ``~torchrl.envs.ParallelEnv`` to code guarded by ``if __name__ == "__main__"``.
# For ease of reading, this tutorial switches to fork, which is also the default
# start method in Google's Colaboratory.
try:
multiprocessing.set_start_method("fork")
except RuntimeError:
assert multiprocessing.get_start_method() == "fork"

# sphinx_gallery_end_ignore

from collections import defaultdict
from typing import Optional

17 changes: 17 additions & 0 deletions tutorials/sphinx-tutorials/pretrained_models.py
@@ -13,6 +13,23 @@
# in one or the other context. In this tutorial, we will be using R3M (https://arxiv.org/abs/2203.12601),
# but other models (e.g. VIP) will work equally well.
#

# sphinx_gallery_start_ignore
import warnings

warnings.filterwarnings("ignore")
from torch import multiprocessing

# TorchRL prefers the spawn start method, which restricts the creation of
# ``~torchrl.envs.ParallelEnv`` to code guarded by ``if __name__ == "__main__"``.
# For ease of reading, this tutorial switches to fork, which is also the default
# start method in Google's Colaboratory.
try:
multiprocessing.set_start_method("fork")
except RuntimeError:
assert multiprocessing.get_start_method() == "fork"

# sphinx_gallery_end_ignore

import torch.cuda
from tensordict.nn import TensorDictSequential
from torch import nn
17 changes: 17 additions & 0 deletions tutorials/sphinx-tutorials/rb_tutorial.py
@@ -46,6 +46,23 @@
# replay buffer is a straightforward process, as shown in the following
# example:
#

# sphinx_gallery_start_ignore
import warnings

warnings.filterwarnings("ignore")
from torch import multiprocessing

# TorchRL prefers the spawn start method, which restricts the creation of
# ``~torchrl.envs.ParallelEnv`` to code guarded by ``if __name__ == "__main__"``.
# For ease of reading, this tutorial switches to fork, which is also the default
# start method in Google's Colaboratory.
try:
multiprocessing.set_start_method("fork")
except RuntimeError:
assert multiprocessing.get_start_method() == "fork"

# sphinx_gallery_end_ignore

import tempfile

from torchrl.data import ReplayBuffer
12 changes: 12 additions & 0 deletions tutorials/sphinx-tutorials/torchrl_demo.py
@@ -128,6 +128,18 @@
import warnings

warnings.filterwarnings("ignore")

from torch import multiprocessing

# TorchRL prefers the spawn start method, which restricts the creation of
# ``~torchrl.envs.ParallelEnv`` to code guarded by ``if __name__ == "__main__"``.
# For ease of reading, this tutorial switches to fork, which is also the default
# start method in Google's Colaboratory.
try:
multiprocessing.set_start_method("fork")
except RuntimeError:
assert multiprocessing.get_start_method() == "fork"


# sphinx_gallery_end_ignore

import torch
12 changes: 12 additions & 0 deletions tutorials/sphinx-tutorials/torchrl_envs.py
@@ -30,6 +30,18 @@
import warnings

warnings.filterwarnings("ignore")

from torch import multiprocessing

# TorchRL prefers the spawn start method, which restricts the creation of
# ``~torchrl.envs.ParallelEnv`` to code guarded by ``if __name__ == "__main__"``.
# For ease of reading, this tutorial switches to fork, which is also the default
# start method in Google's Colaboratory.
try:
multiprocessing.set_start_method("fork")
except RuntimeError:
assert multiprocessing.get_start_method() == "fork"


# sphinx_gallery_end_ignore

import torch
