[Doc] Refactor DDPG and DQN tutos to narrow the scope (pytorch#979)
vmoens authored Apr 7, 2023
1 parent b4873b7 commit c3765cf
Showing 29 changed files with 1,470 additions and 5,422 deletions.
Binary file added docs/source/_static/img/replaybuffer_traj.png
3,824 changes: 2 additions & 3,822 deletions docs/source/_static/js/theme.js

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/source/reference/data.rst
Original file line number Diff line number Diff line change
@@ -218,7 +218,7 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
 Utils
 -----

-.. currentmodule:: torchrl.data.datasets
+.. currentmodule:: torchrl.data

 .. autosummary::
     :toctree: generated/
1 change: 0 additions & 1 deletion docs/source/reference/envs.rst
@@ -114,7 +114,6 @@ provides more information on how to design a custom environment from scratch.
     EnvBase
     GymLikeEnv
     EnvMetaData
-    Specs

 Vectorized envs
 ---------------
2 changes: 1 addition & 1 deletion docs/source/reference/modules.rst
@@ -32,7 +32,7 @@ TensorDict modules

 Hooks
 -----
-.. currentmodule:: torchrl.modules.tensordict_module.actors
+.. currentmodule:: torchrl.modules

 .. autosummary::
     :toctree: generated/
8 changes: 5 additions & 3 deletions docs/source/reference/objectives.rst
@@ -16,13 +16,15 @@ The main characteristics of TorchRL losses are:
   method will receive a tensordict as input that contains all the necessary
   information to return a loss value.
 - They output a :class:`tensordict.TensorDict` instance with the loss values
-  written under a ``"loss_<smth>`` where ``smth`` is a string describing the
+  written under a ``"loss_<smth>"`` where ``smth`` is a string describing the
   loss. Additional keys in the tensordict may be useful metrics to log during
   training time.
 .. note::
   The reason we return independent losses is to let the user use a different
   optimizer for different sets of parameters for instance. Summing the losses
-  can be simply done via ``sum(loss for key, loss in loss_vals.items() if key.startswith("loss_")``.
+  can be simply done via
+
+    >>> loss_val = sum(loss for key, loss in loss_vals.items() if key.startswith("loss_"))

 Training value functions
 ------------------------
@@ -216,5 +218,5 @@ Utils
     next_state_value
     SoftUpdate
     HardUpdate
-    ValueFunctions
+    ValueEstimators
     default_value_kwargs
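
The note changed in objectives.rst describes summing every entry whose key starts with "loss_". A minimal sketch of that filter follows. It uses a plain dict of floats as a stand-in for the TensorDict a loss module returns. The key names, the values and the `total_loss` helper are made up for illustration and are not part of TorchRL.

```python
# Stand-in for the output of a loss module: a plain dict, not a TensorDict.
# Key names and values are hypothetical.
loss_vals = {
    "loss_actor": 0.25,
    "loss_value": 0.5,
    "entropy": 1.5,  # a logged metric, not a loss: must stay out of the sum
}


def total_loss(loss_vals):
    """Sum only the entries whose key starts with "loss_"."""
    return sum(loss for key, loss in loss_vals.items() if key.startswith("loss_"))


loss_val = total_loss(loss_vals)
```

Filtering on the key prefix keeps logged metrics such as the entropy out of the value that is backpropagated, while still allowing each "loss_" entry to be handled separately when different optimizers are used.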
2 changes: 1 addition & 1 deletion docs/source/reference/trainers.rst
@@ -73,7 +73,7 @@ Hooks can be split into 3 categories: **data processing** (:obj:`"batch_process"
 - **Data processing** hooks update a tensordict of data. Hooks :obj:`__call__` method should accept
   a :obj:`TensorDict` object as input and update it given some strategy.
   Examples of such hooks include Replay Buffer extension (:obj:`ReplayBufferTrainer.extend`), data normalization (including normalization
-  constants update), data subsampling (:doc:`BatchSubSampler`) and such.
+  constants update), data subsampling (:class:`torchrl.trainers.BatchSubSampler`) and such.

 - **Logging** hooks take a batch of data presented as a :obj:`TensorDict` and write in the logger
   some information retrieved from that data. Examples include the :obj:`Recorder` hook, the reward
17 changes: 8 additions & 9 deletions test/test_trainer.py
@@ -89,11 +89,10 @@ class MockingLossModule(nn.Module):

 def mocking_trainer(file=None, optimizer=_mocking_optim) -> Trainer:
     trainer = Trainer(
-        MockingCollector(),
-        *[
-            None,
-        ]
-        * 2,
+        collector=MockingCollector(),
+        total_frames=None,
+        frame_skip=None,
+        optim_steps_per_batch=None,
         loss_module=MockingLossModule(),
         optimizer=optimizer,
         save_trainer_file=file,
@@ -862,7 +861,7 @@ def test_recorder(self, N=8):
         with tempfile.TemporaryDirectory() as folder:
             logger = TensorboardLogger(exp_name=folder)

-            recorder = transformed_env_constructor(
+            environment = transformed_env_constructor(
                 args,
                 video_tag="tmp",
                 norm_obs_only=True,
@@ -874,7 +873,7 @@ def test_recorder(self, N=8):
                 record_frames=args.record_frames,
                 frame_skip=args.frame_skip,
                 policy_exploration=None,
-                recorder=recorder,
+                environment=environment,
                 record_interval=args.record_interval,
             )
             trainer = mocking_trainer()
@@ -936,7 +935,7 @@ def _make_recorder_and_trainer(tmpdirname):
         raise NotImplementedError
     trainer = mocking_trainer(file)

-    recorder = transformed_env_constructor(
+    environment = transformed_env_constructor(
         args,
         video_tag="tmp",
         norm_obs_only=True,
@@ -948,7 +947,7 @@ def _make_recorder_and_trainer(tmpdirname):
         record_frames=args.record_frames,
         frame_skip=args.frame_skip,
         policy_exploration=None,
-        recorder=recorder,
+        environment=environment,
         record_interval=args.record_interval,
     )
     recorder.register(trainer)
1 change: 1 addition & 0 deletions torchrl/data/__init__.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+from . import datasets
 from .postprocs import MultiStep
 from .replay_buffers import (
     LazyMemmapStorage,
1 change: 1 addition & 0 deletions torchrl/data/datasets/__init__.py
@@ -1 +1,2 @@
 from .d4rl import D4RLExperienceReplay
+from .openml import OpenMLExperienceReplay
9 changes: 7 additions & 2 deletions torchrl/data/datasets/openml.py
@@ -8,8 +8,13 @@
 import numpy as np
 from tensordict.tensordict import TensorDict

-from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
-from torchrl.data.replay_buffers import Sampler, SamplerWithoutReplacement, Writer
+from torchrl.data.replay_buffers import (
+    LazyMemmapStorage,
+    Sampler,
+    SamplerWithoutReplacement,
+    TensorDictReplayBuffer,
+    Writer,
+)


 class OpenMLExperienceReplay(TensorDictReplayBuffer):
14 changes: 7 additions & 7 deletions torchrl/data/postprocs/postprocs.py
@@ -82,9 +82,9 @@ def _get_reward(
 class MultiStep(nn.Module):
     """Multistep reward transform.

-    Presented in 'Sutton, R. S. 1988. Learning to
-    predict by the methods of temporal differences. Machine learning 3(
-    1):9–44.'
+    Presented in
+
+    | Sutton, R. S. 1988. Learning to predict by the methods of temporal differences. Machine learning 3(1):9–44.

     This module maps the "next" observation to the t + n "next" observation.
     It is an identity transform whenever :attr:`n_steps` is 0.
@@ -153,6 +153,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         """
         tensordict = tensordict.clone(False)
         done = tensordict.get(("next", "done"))
+        truncated = tensordict.get(
+            ("next", "truncated"), torch.zeros((), dtype=done.dtype, device=done.device)
+        )
+        done = done | truncated

         # we'll be using the done states to index the tensordict.
         # if the shapes don't match we're in trouble.
@@ -175,10 +179,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
                 "(trailing singleton dimension excluded)."
             ) from err

-        truncated = tensordict.get(
-            ("next", "truncated"), torch.zeros((), dtype=done.dtype, device=done.device)
-        )
-        done = done | truncated
         mask = tensordict.get(("collector", "mask"), None)
         reward = tensordict.get(("next", "reward"))
         *batch, T = tensordict.batch_size
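
The postprocs.py hunks move the `done | truncated` merge ahead of the code that indexes with `done`, so that every later use of `done` presumably also sees truncated steps. A minimal pure-Python sketch of that merge follows. Lists of booleans stand in for tensors, a plain dict stands in for the tensordict, and `merged_done` is a hypothetical helper, not a TorchRL function.

```python
def merged_done(step):
    """Return done OR truncated, treating a missing "truncated" entry as all-False.

    `step` is a plain dict standing in for the ("next", ...) entries of a
    tensordict; lists of booleans stand in for boolean tensors.
    """
    done = step["done"]
    # Same idea as the default passed to tensordict.get(...) in the diff:
    # when no truncation signal was recorded, fall back to "never truncated".
    truncated = step.get("truncated", [False] * len(done))
    return [d or t for d, t in zip(done, truncated)]


flags = merged_done(
    {"done": [False, True, False], "truncated": [False, False, True]}
)
```

Computing the merged flag once, before anything consumes `done`, avoids the situation where part of a function works with the raw flag and another part with the merged one.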
18 changes: 4 additions & 14 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -11,7 +11,7 @@

 import torch
 from tensordict.tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
-from tensordict.utils import expand_right
+from tensordict.utils import expand_as_right

 from torchrl.data.utils import DEVICE_TYPING

@@ -708,6 +708,8 @@ def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor:
         return index

     def update_tensordict_priority(self, data: TensorDictBase) -> None:
+        if not isinstance(self._sampler, PrioritizedSampler):
+            return
         priority = torch.tensor(
             [self._get_priority(td) for td in data],
             dtype=torch.float,
@@ -753,19 +755,7 @@ def sample(
         data, info = super().sample(batch_size, return_info=True)
         if include_info in (True, None):
             for k, v in info.items():
-                data.set(k, torch.tensor(v, device=data.device))
-        if "_batch_size" in data.keys():
-            # we need to reset the batch-size
-            shape = data.pop("_batch_size")
-            shape = shape[0]
-            shape = torch.Size([data.shape[0], *shape])
-            # we may need to update some values in the data
-            for key, value in data.items():
-                if value.ndim >= len(shape):
-                    continue
-                value = expand_right(value, shape)
-                data.set(key, value)
-            data.batch_size = shape
+                data.set(k, expand_as_right(torch.tensor(v, device=data.device), data))
         if return_info:
             return data, info
         return data
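
The guard added to `update_tensordict_priority` makes the call a no-op when the buffer's sampler is not a `PrioritizedSampler`. The sketch below reproduces that pattern with hypothetical stand-in classes. None of them are the TorchRL classes, and the boolean return value is added here only so the behaviour can be observed (the method in the diff returns `None`).

```python
class Sampler:
    """Hypothetical base sampler with no notion of priority."""


class PrioritizedSampler(Sampler):
    """Hypothetical sampler that stores one priority per index."""

    def __init__(self):
        self.priorities = {}

    def update_priority(self, index, priority):
        self.priorities[index] = priority


class Buffer:
    """Hypothetical buffer that delegates priority updates to its sampler."""

    def __init__(self, sampler):
        self._sampler = sampler

    def update_priority(self, index, priority):
        # Same guard as in the diff: skip silently for non-prioritized samplers.
        if not isinstance(self._sampler, PrioritizedSampler):
            return False
        self._sampler.update_priority(index, priority)
        return True


uniform = Buffer(Sampler())
prioritized_sampler = PrioritizedSampler()
prioritized = Buffer(prioritized_sampler)
```

With this guard, training code can call the priority update unconditionally and swap samplers without touching the call site.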
37 changes: 35 additions & 2 deletions torchrl/data/replay_buffers/storages.py
@@ -14,6 +14,7 @@
 from tensordict.memmap import MemmapTensor
 from tensordict.prototype import is_tensorclass
 from tensordict.tensordict import is_tensor_collection, TensorDict, TensorDictBase
+from tensordict.utils import expand_right

 from torchrl._utils import _CKPT_BACKEND, VERBOSE
 from torchrl.data.replay_buffers.utils import INT_CLASSES
@@ -423,10 +424,42 @@ def _mem_map_tensor_as_tensor(mem_map_tensor: MemmapTensor) -> torch.Tensor:
         return mem_map_tensor._tensor


+def _reset_batch_size(x):
+    """Resets the batch size of a tensordict.
+
+    In some cases we save the original shape of the tensordict as a tensor (or memmap tensor).
+    This function will read that tensor, extract its items and reset the shape
+    of the tensordict to it. If items have an incompatible shape (e.g. "index")
+    they will be expanded to the right to match it.
+
+    """
+    shape = x.pop("_batch_size", None)
+    if shape is not None:
+        # we need to reset the batch-size
+        if isinstance(shape, MemmapTensor):
+            shape = shape.as_tensor()
+        locked = x.is_locked
+        if locked:
+            x.unlock_()
+        shape = [s.item() for s in shape[0]]
+        shape = torch.Size([x.shape[0], *shape])
+        # we may need to update some values in the data
+        for key, value in x.items():
+            if value.ndim >= len(shape):
+                continue
+            value = expand_right(value, shape)
+            x.set(key, value)
+        x.batch_size = shape
+        if locked:
+            x.lock_()
+    return x
+
+
 def _collate_list_tensordict(x):
     out = torch.stack(x, 0)
     if isinstance(out, TensorDictBase):
-        return out.to_tensordict()
+        return _reset_batch_size(out.to_tensordict())
     return out


@@ -436,7 +469,7 @@ def _collate_list_tensors(*x):

 def _collate_contiguous(x):
     if isinstance(x, TensorDictBase):
-        return x.to_tensordict()
+        return _reset_batch_size(x).to_tensordict()
     return x.clone()


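
The shape bookkeeping in `_reset_batch_size` can be followed without tensors. The sketch below works on shapes only, under these assumptions: `current_shape` is the batch shape of the sampled tensordict, `saved_shape` is the per-item shape read from the `"_batch_size"` entry, and `entry_shapes` maps each key to the shape of its value. The function name and the example values are hypothetical.

```python
def reset_batch_shape(current_shape, saved_shape, entry_shapes):
    """Return the restored batch shape and the keys that need right-expansion.

    Mirrors two steps of the diff on plain tuples:
    - the new batch shape is (number of sampled items, *saved per-item shape);
    - entries with fewer dimensions than the new batch shape (e.g. "index")
      are the ones that would be expanded to the right.
    """
    new_shape = (current_shape[0], *saved_shape)
    needs_expansion = [
        key for key, shape in entry_shapes.items() if len(shape) < len(new_shape)
    ]
    return new_shape, needs_expansion


new_shape, to_expand = reset_batch_shape(
    current_shape=(32,),
    saved_shape=(4, 5),
    entry_shapes={"observation": (32, 4, 5, 3), "index": (32,)},
)
```

In this example the sampled batch of 32 items recovers a batch shape of (32, 4, 5), and only "index" has too few dimensions to fit it.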
60 changes: 51 additions & 9 deletions torchrl/envs/transforms/transforms.py
@@ -2602,6 +2602,13 @@ class VecNorm(Transform):
             default: 0.99
         eps (number, optional): lower bound of the running standard
             deviation (for numerical underflow). Default is 1e-4.
+        shapes (List[torch.Size], optional): if provided, represents the shape
+            of each in_keys. Its length must match the one of ``in_keys``.
+            Each shape must match the trailing dimension of the corresponding
+            entry.
+            If not, the feature dimensions of the entry (ie all dims that do
+            not belong to the tensordict batch-size) will be considered as
+            feature dimension.

     Examples:
         >>> from torchrl.envs.libs.gym import GymEnv
@@ -2629,6 +2636,7 @@ def __init__(
         lock: mp.Lock = None,
         decay: float = 0.9999,
         eps: float = 1e-4,
+        shapes: List[torch.Size] = None,
     ) -> None:
         if lock is None:
             lock = mp.Lock()
@@ -2656,8 +2664,14 @@ def __init__(

         self.lock = lock
         self.decay = decay
+        self.shapes = shapes
         self.eps = eps

+    def _key_str(self, key):
+        if not isinstance(key, str):
+            key = "_".join(key)
+        return key
+
     def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
         if self.lock is not None:
             self.lock.acquire()
@@ -2681,17 +2695,44 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
     forward = _call

     def _init(self, tensordict: TensorDictBase, key: str) -> None:
-        if self._td is None or key + "_sum" not in self._td.keys():
-            td_view = tensordict.view(-1)
-            td_select = td_view[0]
-            d = {key + "_sum": torch.zeros_like(td_select.get(key))}
-            d.update({key + "_ssq": torch.zeros_like(td_select.get(key))})
+        key_str = self._key_str(key)
+        if self._td is None or key_str + "_sum" not in self._td.keys():
+            if key is not key_str and key_str in tensordict.keys():
+                raise RuntimeError(
+                    f"Conflicting key names: {key_str} from VecNorm and input tensordict keys."
+                )
+            if self.shapes is None:
+                td_view = tensordict.view(-1)
+                td_select = td_view[0]
+                item = td_select.get(key)
+                d = {key_str + "_sum": torch.zeros_like(item)}
+                d.update({key_str + "_ssq": torch.zeros_like(item)})
+            else:
+                idx = 0
+                for in_key in self.in_keys:
+                    if in_key != key:
+                        idx += 1
+                    else:
+                        break
+                shape = self.shapes[idx]
+                item = tensordict.get(key)
+                d = {
+                    key_str
+                    + "_sum": torch.zeros(shape, device=item.device, dtype=item.dtype)
+                }
+                d.update(
+                    {
+                        key_str
+                        + "_ssq": torch.zeros(
+                            shape, device=item.device, dtype=item.dtype
+                        )
+                    }
+                )
+
             d.update(
                 {
-                    key
-                    + "_count": torch.zeros(
-                        1, device=td_select.get(key).device, dtype=torch.float
-                    )
+                    key_str
+                    + "_count": torch.zeros(1, device=item.device, dtype=torch.float)
                 }
             )
             if self._td is None:
@@ -2702,6 +2743,7 @@ def _init(self, tensordict: TensorDictBase, key: str) -> None:
                 pass

     def _update(self, key, value, N) -> torch.Tensor:
+        key = self._key_str(key)
         _sum = self._td.get(key + "_sum")
         _ssq = self._td.get(key + "_ssq")
         _count = self._td.get(key + "_count")
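
The VecNorm changes rest on two small pieces of logic: flattening a nested key into a string that names the running statistics, and picking the entry of `shapes` that corresponds to a given key. The standalone sketch below shows both, outside of any class. `key_str` follows the `_key_str` method from the diff. `stat_names` and `shape_for` are hypothetical helpers, and `shape_for` uses `list.index` as a compact alternative to the counting loop in `_init`, which it matches only when the key is present in `in_keys`.

```python
def key_str(key):
    """Flatten a nested key (tuple of strings) into a single string."""
    if not isinstance(key, str):
        key = "_".join(key)
    return key


def stat_names(key):
    """Names under which the running sum, sum of squares and count are stored."""
    base = key_str(key)
    return [base + suffix for suffix in ("_sum", "_ssq", "_count")]


def shape_for(key, in_keys, shapes):
    """Return the user-provided shape matching `key`.

    Assumes `key` is in `in_keys` and that `shapes` has one entry per key.
    """
    return shapes[in_keys.index(key)]


names = stat_names(("next", "reward"))
shape = shape_for("b", ["a", "b"], [(3,), (4,)])
```

Flattening means that the nested key `("next", "obs")` and the plain key `"next_obs"` map to the same string, which is the collision the `RuntimeError` in `_init` is there to report.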
2 changes: 2 additions & 0 deletions torchrl/modules/__init__.py
@@ -41,10 +41,12 @@
     ActorValueOperator,
     AdditiveGaussianWrapper,
     DistributionalQValueActor,
+    DistributionalQValueHook,
     EGreedyWrapper,
     OrnsteinUhlenbeckProcessWrapper,
     ProbabilisticActor,
     QValueActor,
+    QValueHook,
     SafeModule,
     SafeProbabilisticModule,
     SafeProbabilisticTensorDictSequential,
2 changes: 2 additions & 0 deletions torchrl/modules/tensordict_module/__init__.py
@@ -9,8 +9,10 @@
     ActorCriticWrapper,
     ActorValueOperator,
     DistributionalQValueActor,
+    DistributionalQValueHook,
     ProbabilisticActor,
     QValueActor,
+    QValueHook,
     ValueOperator,
 )
 from .common import SafeModule