[Minor] docstrings and setup fixes (pytorch#726)
vmoens authored Dec 6, 2022
1 parent 79eeb3c commit 44fb1b7
Showing 6 changed files with 148 additions and 13 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -226,7 +226,7 @@ def _main(argv):
"hydra-submitit-launcher",
],
"checkpointing": [
"torchinductor",
"torchsnapshot",
],
},
zip_safe=False,
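The "checkpointing" extra previously required "torchinductor" (the torch.compile backend, not a checkpointing package) and now pulls in "torchsnapshot". For context, a minimal sketch of what that dependency is used for, assuming torchsnapshot's Snapshot.take/restore API and a placeholder model; this is an illustration of the extra, not code from this commit:

    from torch import nn
    from torchsnapshot import Snapshot  # installed via the "checkpointing" extra

    model = nn.Linear(3, 1)  # placeholder for a policy or value network
    app_state = {"model": model}

    # write the state to disk ...
    snapshot = Snapshot.take(path="/tmp/torchrl_checkpoint", app_state=app_state)
    # ... and later restore it in place
    snapshot.restore(app_state=app_state)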
146 changes: 137 additions & 9 deletions torchrl/collectors/collectors.py
@@ -22,15 +22,15 @@
from torch import multiprocessing as mp
from torch.utils.data import IterableDataset

from torchrl._utils import _check_for_faulty_process, prod
from torchrl.collectors.utils import split_trajectories
from torchrl.data import TensorSpec
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
from torchrl.envs.common import EnvBase

from torchrl.envs.transforms import TransformedEnv
from torchrl.envs.utils import set_exploration_mode, step_mdp

from .._utils import _check_for_faulty_process, prod
from ..data import TensorSpec
from ..data.utils import CloudpickleWrapper, DEVICE_TYPING
from ..envs.common import EnvBase
from ..envs.vec_env import _BatchedEnv
from .utils import split_trajectories
from torchrl.envs.vec_env import _BatchedEnv

_TIMEOUT = 1.0
_MIN_TIMEOUT = 1e-3 # should be several orders of magnitude smaller than the time spent collecting a trajectory
@@ -296,6 +296,48 @@ class SyncDataCollector(_DataCollector):
updated. This feature should be used cautiously: if the same tensordict is added to a replay buffer for instance,
the whole content of the buffer will be identical.
Default is False.
Examples:
>>> from torchrl.envs.libs.gym import GymEnv
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
>>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
>>> collector = SyncDataCollector(
... create_env_fn=env_maker,
... policy=policy,
... total_frames=2000,
... max_frames_per_traj=50,
... frames_per_batch=200,
... init_random_frames=-1,
... reset_at_each_iter=False,
... device="cpu",
... passing_device="cpu",
... )
>>> for i, data in enumerate(collector):
... if i == 2:
... print(data)
... break
TensorDict(
fields={
action: Tensor(torch.Size([4, 50, 1]), dtype=torch.float32),
done: Tensor(torch.Size([4, 50, 1]), dtype=torch.bool),
mask: Tensor(torch.Size([4, 50, 1]), dtype=torch.bool),
next: TensorDict(
fields={
observation: Tensor(torch.Size([4, 50, 3]), dtype=torch.float32)},
batch_size=torch.Size([4, 50]),
device=cpu,
is_shared=False),
observation: Tensor(torch.Size([4, 50, 3]), dtype=torch.float32),
reward: Tensor(torch.Size([4, 50, 1]), dtype=torch.float32),
step_count: Tensor(torch.Size([4, 50, 1]), dtype=torch.float32),
traj_ids: Tensor(torch.Size([4, 50, 1, 1]), dtype=torch.float32)},
batch_size=torch.Size([4, 50]),
device=cpu,
is_shared=False)
>>> del collector
"""

def __init__(
@@ -471,8 +513,10 @@ def set_seed(self, seed: int, static_seed: bool = False) -> int:
seed will be incremented for each of these. The resulting seed is the seed of the last environment.
Examples:
>>> from torchrl.envs import ParallelEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> env_fn = lambda: GymEnv("Pendulum-v1")
>>> env_fn_parallel = lambda: ParallelEnv(6, env_fn)
>>> env_fn_parallel = ParallelEnv(6, env_fn)
>>> collector = SyncDataCollector(env_fn_parallel)
>>> out_seed = collector.set_seed(1) # out_seed = 6
@@ -724,7 +768,7 @@ class _MultiDataCollector(_DataCollector):
workers only once the total number of frames has been collected on the server.
create_env_kwargs (dict, optional): A (list of) dictionaries with the arguments used to create an environment
max_frames_per_traj: Maximum steps per trajectory. Note that a trajectory can span over multiple batches
(unless reset_at_each_iter is set to True, see below). Once a trajectory reaches n_steps_max,
the environment is reset. If the environment wraps multiple environments together, the number of steps
is tracked for each environment independently. Negative values are allowed, in which case this argument
is ignored.
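A short sketch of how the two arguments described above are typically passed together, with one kwargs dictionary per environment constructor; the environments, policy and kwargs values are placeholders, and matching the kwargs list one-to-one with create_env_fn is an assumption based on the description rather than code from this diff:

    from tensordict.nn import TensorDictModule
    from torch import nn
    from torchrl.collectors.collectors import MultiSyncDataCollector
    from torchrl.envs.libs.gym import GymEnv

    policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
    collector = MultiSyncDataCollector(
        create_env_fn=[GymEnv, GymEnv],  # one constructor per worker
        create_env_kwargs=[
            {"env_name": "Pendulum-v1"},                   # kwargs for the first worker's env
            {"env_name": "Pendulum-v1", "device": "cpu"},  # kwargs for the second worker's env
        ],
        policy=policy,
        total_frames=400,
        frames_per_batch=200,
        max_frames_per_traj=50,  # trajectories longer than 50 steps trigger a reset
    )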
@@ -1094,6 +1138,48 @@ class MultiSyncDataCollector(_MultiDataCollector):
trajectory and the start of the next collection.
This class can be safely used with online RL algorithms.
Examples:
>>> from torchrl.envs.libs.gym import GymEnv
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
>>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
>>> collector = MultiSyncDataCollector(
... create_env_fn=[env_maker, env_maker],
... policy=policy,
... total_frames=2000,
... max_frames_per_traj=50,
... frames_per_batch=200,
... init_random_frames=-1,
... reset_at_each_iter=False,
... devices="cpu",
... passing_devices="cpu",
... )
>>> for i, data in enumerate(collector):
... if i == 2:
... print(data)
... break
TensorDict(
fields={
action: Tensor(torch.Size([4, 50, 1]), dtype=torch.float32),
done: Tensor(torch.Size([4, 50, 1]), dtype=torch.bool),
mask: Tensor(torch.Size([4, 50, 1]), dtype=torch.bool),
next: TensorDict(
fields={
observation: Tensor(torch.Size([4, 50, 3]), dtype=torch.float32)},
batch_size=torch.Size([4, 50]),
device=cpu,
is_shared=False),
observation: Tensor(torch.Size([4, 50, 3]), dtype=torch.float32),
reward: Tensor(torch.Size([4, 50, 1]), dtype=torch.float32),
step_count: Tensor(torch.Size([4, 50, 1]), dtype=torch.float32),
traj_ids: Tensor(torch.Size([4, 50, 1, 1]), dtype=torch.float32)},
batch_size=torch.Size([4, 50]),
device=cpu,
is_shared=False)
>>> collector.shutdown()
>>> del collector
"""

__doc__ += _MultiDataCollector.__doc__
@@ -1189,6 +1275,48 @@ class MultiaSyncDataCollector(_MultiDataCollector):
the batch of rollouts is collected and the next call to the iterator.
This class can be safely used with offline RL algorithms.
Examples:
>>> from torchrl.envs.libs.gym import GymEnv
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
>>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
>>> collector = MultiaSyncDataCollector(
... create_env_fn=[env_maker, env_maker],
... policy=policy,
... total_frames=2000,
... max_frames_per_traj=50,
... frames_per_batch=200,
... init_random_frames=-1,
... reset_at_each_iter=False,
... devices="cpu",
... passing_devices="cpu",
... )
>>> for i, data in enumerate(collector):
... if i == 2:
... print(data)
... break
TensorDict(
fields={
action: Tensor(torch.Size([4, 50, 1]), dtype=torch.float32),
done: Tensor(torch.Size([4, 50, 1]), dtype=torch.bool),
mask: Tensor(torch.Size([4, 50, 1]), dtype=torch.bool),
next: TensorDict(
fields={
observation: Tensor(torch.Size([4, 50, 3]), dtype=torch.float32)},
batch_size=torch.Size([4, 50]),
device=cpu,
is_shared=False),
observation: Tensor(torch.Size([4, 50, 3]), dtype=torch.float32),
reward: Tensor(torch.Size([4, 50, 1]), dtype=torch.float32),
step_count: Tensor(torch.Size([4, 50, 1]), dtype=torch.float32),
traj_ids: Tensor(torch.Size([4, 50, 1, 1]), dtype=torch.float32)},
batch_size=torch.Size([4, 50]),
device=cpu,
is_shared=False)
>>> collector.shutdown()
>>> del collector
"""

__doc__ += _MultiDataCollector.__doc__
2 changes: 1 addition & 1 deletion torchrl/data/utils.py
@@ -57,7 +57,7 @@ def __setstate__(self, ob: bytes):

self.fn, self.kwargs = pickle.loads(ob)

def __call__(self, **kwargs) -> Any:
def __call__(self, *args, **kwargs) -> Any:
kwargs = {k: item for k, item in kwargs.items()}
kwargs.update(self.kwargs)
return self.fn(**kwargs)
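For context on the signature change just above, a self-contained sketch of the cloudpickle round-trip pattern this wrapper relies on: serialize a closure with cloudpickle in __getstate__, load it back with plain pickle in __setstate__, and let __call__ forward both positional and keyword arguments. This is an illustration only, not the torchrl class; it assumes the cloudpickle package is available, and forwarding *args to the callable is part of the sketch rather than shown in the diff:

    import pickle

    import cloudpickle


    class PicklableCallable:
        """Carries a callable (e.g. a lambda env constructor) across process boundaries."""

        def __init__(self, fn, **kwargs):
            self.fn = fn
            self.kwargs = kwargs

        def __getstate__(self):
            # cloudpickle serializes lambdas/closures that plain pickle rejects
            return cloudpickle.dumps((self.fn, self.kwargs))

        def __setstate__(self, ob):
            self.fn, self.kwargs = pickle.loads(ob)

        def __call__(self, *args, **kwargs):
            kwargs = {**kwargs, **self.kwargs}
            return self.fn(*args, **kwargs)


    make_env = PicklableCallable(lambda name: f"env:{name}")
    restored = pickle.loads(pickle.dumps(make_env))  # survives a pickle round-trip
    print(restored("Pendulum-v1"))  # env:Pendulum-v1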
6 changes: 5 additions & 1 deletion torchrl/modules/models/models.py
@@ -991,13 +991,17 @@ class LSTMNet(nn.Module):
>>> batch = 7
>>> time_steps = 6
>>> in_features = 4
>>> out_features = 10
>>> hidden_size = 5
>>> net = LSTMNet(
... out_features,
... {"input_size": hidden_size, "hidden_size": hidden_size},
... {"out_features": hidden_size},
... )
>>> # test single step vs multi-step
>>> x = torch.randn(batch, time_steps, in_features)
>>> x = torch.randn(batch, time_steps, in_features) # >=3 dims = multi-step
>>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x)
>>> x = torch.randn(batch, in_features) # 2 dims = single step
>>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x)
"""
3 changes: 3 additions & 0 deletions torchrl/trainers/__init__.py
@@ -8,8 +8,11 @@
ClearCudaCache,
CountFramesLog,
LogReward,
mask_batch,
OptimizerHook,
Recorder,
ReplayBuffer,
ReplayBufferTrainer,
RewardNormalizer,
SelectKeys,
Trainer,
2 changes: 1 addition & 1 deletion torchrl/trainers/trainers.py
@@ -138,7 +138,7 @@ def __init__(
total_frames: int,
frame_skip: int,
loss_module: Union[LossModule, Callable[[TensorDictBase], TensorDictBase]],
optimizer: optim.Optimizer,
optimizer: Optional[optim.Optimizer] = None,
logger: Optional[Logger] = None,
optim_steps_per_batch: int = 500,
clip_grad_norm: bool = True,
