From 041dd53189cc8b1350c6a6665e990a6f518d928a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 19 Oct 2022 14:18:20 +0100 Subject: [PATCH] [Feature] More restrictive tests on docstrings (#457) --- .pre-commit-config.yaml | 1 + docs/source/reference/objectives.rst | 22 +- setup.cfg | 11 +- test/test_postprocs.py | 2 + test/test_tensor_spec.py | 2 + torchrl/_extension.py | 4 +- torchrl/_utils.py | 28 +- torchrl/collectors/collectors.py | 35 +- torchrl/collectors/utils.py | 5 +- torchrl/data/postprocs/postprocs.py | 10 +- torchrl/data/replay_buffers/rb_prototype.py | 19 +- torchrl/data/replay_buffers/replay_buffers.py | 151 ++----- torchrl/data/replay_buffers/samplers.py | 23 +- torchrl/data/replay_buffers/storages.py | 11 +- torchrl/data/replay_buffers/utils.py | 61 +-- torchrl/data/replay_buffers/writers.py | 14 +- torchrl/data/tensor_specs.py | 91 ++--- torchrl/data/tensordict/memmap.py | 16 +- torchrl/data/tensordict/metatensor.py | 18 +- torchrl/data/tensordict/tensordict.py | 299 +++++++------- torchrl/data/tensordict/utils.py | 32 +- torchrl/data/utils.py | 8 +- torchrl/envs/common.py | 58 +-- torchrl/envs/env_creator.py | 1 + torchrl/envs/gym_like.py | 40 +- torchrl/envs/libs/dm_control.py | 8 +- torchrl/envs/libs/gym.py | 11 +- torchrl/envs/libs/utils.py | 46 +-- torchrl/envs/model_based/common.py | 4 +- torchrl/envs/transforms/functional.py | 1 + torchrl/envs/transforms/r3m.py | 2 +- torchrl/envs/transforms/transforms.py | 132 +++--- torchrl/envs/transforms/utils.py | 18 +- torchrl/envs/utils.py | 66 +-- torchrl/envs/vec_env.py | 22 +- torchrl/modules/distributions/continuous.py | 20 +- .../modules/distributions/truncated_normal.py | 10 +- torchrl/modules/functional_modules.py | 56 ++- torchrl/modules/models/exploration.py | 25 +- torchrl/modules/models/models.py | 44 +- torchrl/modules/models/recipes/impala.py | 8 +- torchrl/modules/models/utils.py | 22 +- torchrl/modules/planners/cem.py | 2 +- torchrl/modules/planners/common.py | 6 +- torchrl/modules/tensordict_module/actors.py | 62 +-- torchrl/modules/tensordict_module/common.py | 27 +- torchrl/modules/tensordict_module/deprec.py | 384 ------------------ .../modules/tensordict_module/exploration.py | 20 +- .../tensordict_module/probabilistic.py | 30 +- torchrl/modules/tensordict_module/sequence.py | 22 +- .../modules/tensordict_module/world_models.py | 8 +- torchrl/modules/utils/mappings.py | 21 +- torchrl/objectives/__init__.py | 3 +- torchrl/objectives/common.py | 10 +- torchrl/objectives/ddpg.py | 9 +- torchrl/objectives/deprecated.py | 8 +- torchrl/objectives/dqn.py | 16 +- torchrl/objectives/functional.py | 6 +- torchrl/objectives/ppo.py | 9 +- torchrl/objectives/redq.py | 7 +- torchrl/objectives/reinforce.py | 5 +- torchrl/objectives/sac.py | 15 +- torchrl/objectives/utils.py | 22 +- torchrl/objectives/value/__init__.py | 2 - torchrl/objectives/value/advantages.py | 7 +- torchrl/objectives/value/functional.py | 19 +- torchrl/objectives/value/returns.py | 23 -- torchrl/objectives/value/utils.py | 17 +- torchrl/objectives/value/vtrace.py | 12 +- torchrl/record/recorder.py | 15 +- torchrl/trainers/helpers/collectors.py | 19 +- torchrl/trainers/helpers/envs.py | 22 +- torchrl/trainers/helpers/logger.py | 2 + torchrl/trainers/helpers/models.py | 23 +- torchrl/trainers/loggers/common.py | 5 +- torchrl/trainers/loggers/csv.py | 19 +- torchrl/trainers/loggers/mlflow.py | 3 +- torchrl/trainers/loggers/tensorboard.py | 19 +- torchrl/trainers/loggers/wandb.py | 21 +- torchrl/trainers/trainers.py | 18 +- 80 files changed, 876 
insertions(+), 1519 deletions(-) delete mode 100644 torchrl/modules/tensordict_module/deprec.py delete mode 100644 torchrl/objectives/value/returns.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dccc1c1a9b2..1afafcca1cc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,3 +28,4 @@ repos: rev: 6.1.1 hooks: - id: pydocstyle + files: ^torchrl/ diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index c6b68157084..579a4f67ffe 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -53,25 +53,27 @@ PPO Returns ------- +.. currentmodule:: torchrl.objectives.value .. autosummary:: :toctree: generated/ :template: rl_template_noinherit.rst - value.GAE - value.TDLambdaEstimate - value.TDEstimate - value.functional.generalized_advantage_estimate - value.functional.vec_generalized_advantage_estimate - value.functional.vec_td_lambda_return_estimate - value.functional.vec_td_lambda_advantage_estimate - value.functional.td_lambda_return_estimate - value.functional.td_lambda_advantage_estimate - value.functional.td_advantage_estimate + GAE + TDLambdaEstimate + TDEstimate + functional.generalized_advantage_estimate + functional.vec_generalized_advantage_estimate + functional.vec_td_lambda_return_estimate + functional.vec_td_lambda_advantage_estimate + functional.td_lambda_return_estimate + functional.td_lambda_advantage_estimate + functional.td_advantage_estimate Utils ----- +.. currentmodule:: torchrl.objectives .. autosummary:: :toctree: generated/ diff --git a/setup.cfg b/setup.cfg index 78b3c8fbafd..014de2f29a1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,4 +17,13 @@ per-file-ignores = exclude = venv [pydocstyle] -select = D417 # Missing argument descriptions in the docstring +;select = D417 # Missing argument descriptions in the docstring +;inherit = false +match = .*\.py +;match_dir = ^(?!(.circlecli|test)).* +convention = google +add-ignore = D100, D104, D105, D107, D102 +ignore-decorators = + test_* +; test/*.py +; .circleci/* diff --git a/test/test_postprocs.py b/test/test_postprocs.py index 003ab09012c..317668dd18d 100644 --- a/test/test_postprocs.py +++ b/test/test_postprocs.py @@ -90,6 +90,8 @@ def test_multistep(n, key, device, T=11): class TestSplits: + """Tests the splitting of collected tensordicts in trajectories.""" + @staticmethod def create_fake_trajs( num_workers=32, diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py index 67a1adc0ac4..0c624cc50cb 100644 --- a/test/test_tensor_spec.py +++ b/test/test_tensor_spec.py @@ -384,6 +384,8 @@ def test_nested_composite_spec(self, is_complete, device, dtype): class TestEquality: + """Tests spec comparison.""" + @staticmethod def _ts_make_all_fields_equal(ts_to, ts_from): ts_to.shape = ts_from.shape diff --git a/torchrl/_extension.py b/torchrl/_extension.py index 0047a23b6b1..30788159a57 100644 --- a/torchrl/_extension.py +++ b/torchrl/_extension.py @@ -8,13 +8,13 @@ def is_module_available(*modules: str) -> bool: - r"""Returns if a top-level module with :attr:`name` exists *without** - importing it. + r"""Returns if a top-level module with :attr:`name` exists *without** importing it. This is generally safer than try-catch block around a `import X`. It avoids third party libraries breaking assumptions of some of our tests, e.g., setting multiprocessing start method when imported (see librosa/#747, torchvision/#544). 
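An illustrative doctest-style sketch (not applied by this patch) of the guarded-import pattern that the `is_module_available` docstring above describes; it relies only on the function and module path shown in the diff:

    >>> from torchrl._extension import is_module_available
    >>> # Check for an optional dependency without importing it.
    >>> if is_module_available("gym"):
    ...     import gym  # safe: the module is known to exist
    >>> is_module_available("a_module_that_does_not_exist")
    False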
+ """ return all(importlib.util.find_spec(m) is not None for m in modules) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index b7ed60dd160..a05ad395dfa 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -6,9 +6,7 @@ class timeit: - """ - A dirty but easy to use decorator for profiling code - """ + """A dirty but easy to use decorator for profiling code.""" _REG = {} @@ -71,6 +69,17 @@ def _check_for_faulty_process(processes): def seed_generator(seed): + """A seed generator function. + + Given a seeding integer, generates a deterministic next seed to be used in a + seeding sequence. + + Args: + seed (int): initial seed. + + Returns: Next seed of the chain. + + """ max_seed_val = ( 2 ** 32 - 1 ) # https://discuss.pytorch.org/t/what-is-the-max-seed-you-can-set-up/145688 @@ -80,6 +89,14 @@ def seed_generator(seed): class KeyDependentDefaultDict(collections.defaultdict): + """A key-dependent default dict. + + Examples: + >>> my_dict = KeyDependentDefaultDict(lambda key: "foo_" + key) + >>> print(my_dict["bar"]) + foo_bar + """ + def __init__(self, fun): self.fun = fun super().__init__() @@ -91,6 +108,11 @@ def __missing__(self, key): def prod(sequence): + """General prod function, that generalised usage across math and np. + + Created for multiple python versions compatibility). + + """ if hasattr(math, "prod"): return math.prod(sequence) else: diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 4fc49a8a59a..94b376ec748 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -49,6 +49,7 @@ class RandomPolicy: def __init__(self, action_spec: TensorSpec): """Random policy for a given action_spec. + This is a wrapper around the action_spec.rand method. @@ -63,6 +64,7 @@ def __init__(self, action_spec: TensorSpec): >>> action_spec = NdBoundedTensorSpec(-torch.ones(3), torch.ones(3)) >>> actor = RandomPolicy(spec=action_spec) >>> td = actor(TensorDict(batch_size=[])) # selects a random action in the cube [-1; 1] + """ self.action_spec = action_spec @@ -127,7 +129,9 @@ def _get_policy_and_device( ) -> Tuple[ ProbabilisticTensorDictModule, torch.device, Union[None, Callable[[], dict]] ]: - """From a policy and a device, assigns the self.device attribute to + """Util method to get a policy and its device given the collector __init__ inputs. + + From a policy and a device, assigns the self.device attribute to the desired device and maps the policy onto it or (if the device is ommitted) assigns the self.device attribute to the policy device. @@ -247,8 +251,7 @@ def __repr__(self) -> str: class SyncDataCollector(_DataCollector): - """ - Generic data collector for RL problems. Requires and environment constructor and a policy. + """Generic data collector for RL problems. Requires and environment constructor and a policy. Args: create_env_fn (Callable), returns an instance of EnvBase class. @@ -684,15 +687,13 @@ def __del__(self): self.shutdown() # make sure env is closed def state_dict(self) -> OrderedDict: - """Returns the local state_dict of the data collector (environment - and policy). + """Returns the local state_dict of the data collector (environment and policy). Returns: - an ordered dictionary with fields `"policy_state_dict"` and + an ordered dictionary with fields :obj:`"policy_state_dict"` and `"env_state_dict"`. 
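To make the checkpointing round trip documented by `state_dict`/`load_state_dict` above concrete, a hedged sketch (not applied by this patch); `collector` is a placeholder for any SyncDataCollector built elsewhere:

    >>> import torch
    >>> state = collector.state_dict()  # ordered dict with "policy_state_dict" and "env_state_dict"
    >>> torch.save(state, "collector_ckpt.pt")
    >>> # ... later, e.g. after rebuilding the collector with the same constructor arguments
    >>> collector.load_state_dict(torch.load("collector_ckpt.pt"), strict=True)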
""" - if isinstance(self.env, TransformedEnv): env_state_dict = self.env.transform.state_dict() elif isinstance(self.env, _BatchedEnv): @@ -716,7 +717,7 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None: Args: state_dict (OrderedDict): ordered dictionary containing the fields - `"policy_state_dict"` and `"env_state_dict"`. + `"policy_state_dict"` and :obj:`"env_state_dict"`. """ strict = kwargs.get("strict", True) @@ -791,7 +792,7 @@ class _MultiDataCollector(_DataCollector): reset_when_done (bool, optional): if True, the contained environment will be reset every time it hits a done. If the env contains multiple independent envs, a reset index will be passed to it to reset only thos environments that need to - be reset. In practice, this will happen through a call to `env.reset(tensordict)`, + be reset. In practice, this will happen through a call to :obj:`env.reset(tensordict)`, in other words, if the env is a multi-agent env, all agents will be reset once one of them is done. Defaults to `True`. @@ -1081,8 +1082,8 @@ def reset(self, reset_idx: Optional[Sequence[bool]] = None) -> None: raise RuntimeError(f"Expected msg='reset', got {msg}") def state_dict(self) -> OrderedDict: - """ - Returns the state_dict of the data collector. + """Returns the state_dict of the data collector. + Each field represents a worker containing its own state_dict. """ @@ -1098,15 +1099,13 @@ def state_dict(self) -> OrderedDict: return state_dict def load_state_dict(self, state_dict: OrderedDict) -> None: - """ - Loads the state_dict on the workers. + """Loads the state_dict on the workers. Args: state_dict (OrderedDict): state_dict of the form ``{"worker0": state_dict0, "worker1": state_dict1}``. """ - for idx in range(self.num_workers): self.pipes[idx].send((state_dict[f"worker{idx}"], "load_state_dict")) for idx in range(self.num_workers): @@ -1116,13 +1115,13 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: class MultiSyncDataCollector(_MultiDataCollector): - """Runs a given number of DataCollectors on separate processes - synchronously. + """Runs a given number of DataCollectors on separate processes synchronously. The collection starts when the next item of the collector is queried, and no environment step is computed in between the reception of a batch of trajectory and the start of the next collection. This class can be safely used with online RL algorithms. + """ __doc__ += _MultiDataCollector.__doc__ @@ -1212,12 +1211,12 @@ def iterator(self) -> Iterator[TensorDictBase]: class MultiaSyncDataCollector(_MultiDataCollector): - """Runs a given number of DataCollectors on separate processes - asynchronously. + """Runs a given number of DataCollectors on separate processes asynchronously. The collection keeps on occuring on all processes even between the time the batch of rollouts is collected and the next call to the iterator. This class can be safely used with offline RL algorithms. + """ __doc__ += _MultiDataCollector.__doc__ diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 949f8e6f5c1..06d5eb6b219 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -28,7 +28,10 @@ def stacked_output_fun(*args, **kwargs): def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: - """Takes a tensordict with a key traj_ids that indicates the id of each trajectory. + """A util function for trajectory separation. + + Takes a tensordict with a key traj_ids that indicates the id of each trajectory. 
+ From there, builds a B x T x ... zero-padded tensordict with B batches on max duration T """ traj_ids = rollout_tensordict.get("traj_ids") diff --git a/torchrl/data/postprocs/postprocs.py b/torchrl/data/postprocs/postprocs.py index db9538b0793..21f0cd6e474 100644 --- a/torchrl/data/postprocs/postprocs.py +++ b/torchrl/data/postprocs/postprocs.py @@ -107,8 +107,9 @@ def select_and_repeat( class MultiStep(nn.Module): - """ - Multistep reward, as presented in 'Sutton, R. S. 1988. Learning to + """Multistep reward transform. + + Presented in 'Sutton, R. S. 1988. Learning to predict by the methods of temporal differences. Machine learning 3( 1):9–44.' @@ -140,7 +141,9 @@ def __init__( ) def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - """Args: + """Re-writes a tensordict following the multi-step transform. + + Args: tensordict: TennsorDict instance with Batch x Time-steps x ... dimensions. The TensorDict must contain a "reward" and "done" key. All @@ -160,6 +163,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - The "reward" values will be replaced by the newly computed rewards. + Returns: in-place transformation of the input tensordict. diff --git a/torchrl/data/replay_buffers/rb_prototype.py b/torchrl/data/replay_buffers/rb_prototype.py index 8c273937416..73809aee9b7 100644 --- a/torchrl/data/replay_buffers/rb_prototype.py +++ b/torchrl/data/replay_buffers/rb_prototype.py @@ -9,13 +9,13 @@ from .replay_buffers import pin_memory_output, stack_tensors, stack_td from .samplers import Sampler, RandomSampler from .storages import Storage, ListStorage -from .utils import INT_CLASSES, to_numpy +from .utils import INT_CLASSES, _to_numpy from .writers import Writer, RoundRobinWriter class ReplayBuffer: - """ - #TODO: Description of the ReplayBuffer class needed. + """A generic, composable replay buffer class. + Args: storage (Storage, optional): the storage to be used. If none is provided a default ListStorage with max_size of 1_000 will be created. @@ -73,7 +73,7 @@ def __repr__(self) -> str: @pin_memory_output def __getitem__(self, index: Union[int, torch.Tensor]) -> Any: - index = to_numpy(index) + index = _to_numpy(index) with self._replay_lock: data = self._storage[index] @@ -97,8 +97,7 @@ def add(self, data: Any) -> int: return index def extend(self, data: Sequence) -> torch.Tensor: - """Extends the replay buffer with one or more elements contained in - an iterable. + """Extends the replay buffer with one or more elements contained in an iterable. Args: data (iterable): collection of data to be added to the replay @@ -130,8 +129,8 @@ def _sample(self, batch_size: int) -> Tuple[Any, dict]: return data, info def sample(self, batch_size: int) -> Tuple[Any, dict]: - """ - Samples a batch of data from the replay buffer. + """Samples a batch of data from the replay buffer. + Uses Sampler to sample indices, and retrieves them from Storage. Args: @@ -158,8 +157,8 @@ def sample(self, batch_size: int) -> Tuple[Any, dict]: class TensorDictReplayBuffer(ReplayBuffer): - """ - TensorDict-specific wrapper around the ReplayBuffer class. + """TensorDict-specific wrapper around the ReplayBuffer class. + Args: priority_key (str): the key at which priority is assumed to be stored within TensorDicts added to this ReplayBuffer. 
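Before moving to the next file, a short usage sketch (not applied by this patch) of the composable prototype buffer documented in the hunk above; import paths and keyword names are assumed from the file locations and docstrings in this diff:

    >>> import torch
    >>> from torchrl.data.replay_buffers.rb_prototype import ReplayBuffer
    >>> from torchrl.data.replay_buffers.storages import ListStorage
    >>> from torchrl.data.replay_buffers.samplers import RandomSampler
    >>> rb = ReplayBuffer(storage=ListStorage(max_size=100), sampler=RandomSampler())
    >>> indices = rb.extend([torch.randn(3) for _ in range(10)])  # returns the written indices
    >>> data, info = rb.sample(batch_size=4)  # the prototype's sample() returns (data, info)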
diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 7b1223b487e..7652cd66e10 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -5,7 +5,6 @@ import collections import concurrent.futures -import functools import threading from typing import Any, Callable, List, Optional, Sequence, Tuple, Union @@ -22,13 +21,12 @@ from torchrl.data.replay_buffers.storages import Storage, ListStorage from torchrl.data.replay_buffers.utils import INT_CLASSES from torchrl.data.replay_buffers.utils import ( - cat_fields_to_device, - to_numpy, - to_torch, + _to_numpy, + _to_torch, ) from torchrl.data.tensordict.tensordict import ( TensorDictBase, - stack as stack_td, + _stack as stack_td, LazyStackedTensorDict, ) from torchrl.data.utils import DEVICE_TYPING @@ -38,14 +36,11 @@ "PrioritizedReplayBuffer", "TensorDictReplayBuffer", "TensorDictPrioritizedReplayBuffer", - "create_replay_buffer", - "create_prioritized_replay_buffer", ] def stack_tensors(list_of_tensor_iterators: List) -> Tuple[torch.Tensor]: - """Zips a list of iterables containing tensor-like objects and stacks the - resulting lists of tensors together. + """Zips a list of iterables containing tensor-like objects and stacks the resulting lists of tensors together. Args: list_of_tensor_iterators (list): Sequence containing similar iterators, @@ -82,8 +77,7 @@ def _pin_memory(output: Any) -> Any: def pin_memory_output(fun) -> Callable: - """Calls pin_memory on outputs of decorated function if they have such - method.""" + """Calls pin_memory on outputs of decorated function if they have such method.""" def decorated_fun(self, *args, **kwargs): output = fun(self, *args, **kwargs) @@ -102,8 +96,7 @@ def decorated_fun(self, *args, **kwargs): class ReplayBuffer: - """ - Circular replay buffer. + """Circular replay buffer. Args: size (int): integer indicating the maximum size of the replay buffer. @@ -153,7 +146,7 @@ def __len__(self) -> int: @pin_memory_output def __getitem__(self, index: Union[int, Tensor]) -> Any: - index = to_numpy(index) + index = _to_numpy(index) with self._replay_lock: data = self._storage[index] @@ -187,8 +180,7 @@ def add(self, data: Any) -> int: return ret def extend(self, data: Sequence[Any]): - """Extends the replay buffer with one or more elements contained in - an iterable. + """Extends the replay buffer with one or more elements contained in an iterable. Args: data (iterable): collection of data to be added to the replay @@ -281,8 +273,9 @@ def __repr__(self) -> str: class PrioritizedReplayBuffer(ReplayBuffer): - """ - Prioritized replay buffer as presented in + """Prioritized replay buffer. + + Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay." 
(https://arxiv.org/abs/1511.05952) @@ -348,7 +341,7 @@ def __init__( @pin_memory_output def __getitem__(self, index: Union[int, Tensor]) -> Any: - index = to_numpy(index) + index = _to_numpy(index) with self._replay_lock: p_min = self._min_tree.query(0, self._capacity) @@ -367,7 +360,7 @@ def __getitem__(self, index: Union[int, Tensor]) -> Any: # x = first_field(data) # if isinstance(x, torch.Tensor): device = data.device if hasattr(data, "device") else torch.device("cpu") - weight = to_torch(weight, device, self._pin_memory) + weight = _to_torch(weight, device, self._pin_memory) return data, weight @property @@ -398,7 +391,7 @@ def _add_or_extend( do_add: bool = True, ) -> torch.Tensor: if priority is not None: - priority = to_numpy(priority) + priority = _to_numpy(priority) max_priority = np.max(priority) with self._replay_lock: self._max_priority = max(self._max_priority, max_priority) @@ -470,18 +463,16 @@ def _sample(self, batch_size: int) -> Tuple[Any, torch.Tensor, torch.Tensor]: # x = first_field(data) # avoid calling tree.flatten # if isinstance(x, torch.Tensor): device = data.device if hasattr(data, "device") else torch.device("cpu") - weight = to_torch(weight, device, self._pin_memory) + weight = _to_torch(weight, device, self._pin_memory) return data, weight, index def sample(self, batch_size: int) -> Tuple[Any, np.ndarray, torch.Tensor]: - """Gather a batch of data according to the non-uniform multinomial - distribution with weights computed with the provided priorities of - each input. + """Gathers a batch of data according to the non-uniform multinomial distribution with weights computed with the provided priorities of each input. Args: batch_size (int): float of data to be collected. - Returns: + Returns: a random sample from the replay buffer. """ if not self._prefetch: @@ -529,8 +520,8 @@ def update_priority( "priority should be a number or an iterable of the same " "length as index" ) - index = to_numpy(index) - priority = to_numpy(priority) + index = _to_numpy(index) + priority = _to_numpy(priority) with self._replay_lock: self._max_priority = max(self._max_priority, np.max(priority)) @@ -540,9 +531,7 @@ def update_priority( class TensorDictReplayBuffer(ReplayBuffer): - """ - TensorDict-specific wrapper around the ReplayBuffer class. - """ + """TensorDict-specific wrapper around the ReplayBuffer class.""" def __init__( self, @@ -561,8 +550,8 @@ def collate_fn(x): class TensorDictPrioritizedReplayBuffer(PrioritizedReplayBuffer): - """ - TensorDict-specific wrapper around the PrioritizedReplayBuffer class. + """TensorDict-specific wrapper around the PrioritizedReplayBuffer class. + This class returns tensordicts with a new key "index" that represents the index of each element in the replay buffer. It also facilitates the call to the 'update_priority' method, as it only requires for the @@ -574,14 +563,14 @@ class TensorDictPrioritizedReplayBuffer(PrioritizedReplayBuffer): used, with α = 0 corresponding to the uniform case. beta (float): importance sampling negative exponent. priority_key (str, optional): key where the priority value can be - found in the stored tensordicts. Default is `"td_error"` + found in the stored tensordicts. Default is :obj:`"td_error"` eps (float, optional): delta added to the priorities to ensure that the buffer does not contain null priorities. collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used when using batched loading from a map-style dataset. 
pin_memory (bool, optional): whether pin_memory() should be called on - the rb samples. Default is `False`. + the rb samples. Default is :obj:`False`. prefetch (int, optional): number of next batches to be prefetched using multithreading. storage (Storage, optional): the storage to be used. If none is provided, @@ -681,8 +670,7 @@ def extend( return idx def update_priority(self, tensordict: TensorDictBase) -> None: - """Updates the priorities of the tensordicts stored in the replay - buffer. + """Updates the priorities of the tensordicts stored in the replay buffer. Args: tensordict: tensordict with key-value pairs 'self.priority_key' @@ -699,10 +687,7 @@ def update_priority(self, tensordict: TensorDictBase) -> None: return super().update_priority(tensordict.get("index"), priority=priority) def sample(self, size: int, return_weight: bool = False) -> TensorDictBase: - """ - Gather a batch of tensordicts according to the non-uniform multinomial - distribution with weights computed with the priority_key of each - input tensordict. + """Gather a batch of tensordicts according to the non-uniform multinomial distribution with weights computed with the priority_key of each input tensordict. Args: size (int): size of the batch to be returned @@ -720,90 +705,6 @@ def sample(self, size: int, return_weight: bool = False) -> TensorDictBase: return td -def create_replay_buffer( - size: int, - device: Optional[DEVICE_TYPING] = None, - collate_fn: Callable = None, - pin_memory: bool = False, - prefetch: Optional[int] = None, -) -> ReplayBuffer: - """ - Helper function to create a Replay buffer. - - Args: - size (int): integer indicating the maximum size of the replay buffer. - device (str, int or torch.device, optional): device where to cast the - samples. - collate_fn (callable, optional): merges a list of samples to form a - mini-batch of Tensor(s)/outputs. Used when using batched loading - from a map-style dataset. - pin_memory (bool): whether pin_memory() should be called on the rb - samples. - prefetch (int, optional): number of next batches to be prefetched - using multithreading. - - Returns: - a ReplayBuffer instance - - """ - if isinstance(device, str): - device = torch.device(device) - - if device.type == "cuda" and collate_fn is None: - # Postman will add batch_dim for uploaded data, so using cat instead of - # stack here. - collate_fn = functools.partial(cat_fields_to_device, device=device) - - return ReplayBuffer(size, collate_fn, pin_memory, prefetch) - - -def create_prioritized_replay_buffer( - size: int, - alpha: float, - beta: float, - eps: float = 1e-8, - device: Optional[DEVICE_TYPING] = "cpu", - collate_fn: Callable = None, - pin_memory: bool = False, - prefetch: Optional[int] = None, -) -> PrioritizedReplayBuffer: - """ - Helper function to create a Prioritized Replay buffer. - - Args: - size (int): integer indicating the maximum size of the replay buffer. - alpha (float): exponent α determines how much prioritization is used, - with α = 0 corresponding to the uniform case. - beta (float): importance sampling negative exponent. - eps (float): delta added to the priorities to ensure that the buffer - does not contain null priorities. - device (str, int or torch.device, optional): device where to cast the - samples. - collate_fn (callable, optional): merges a list of samples to form a - mini-batch of Tensor(s)/outputs. Used when using batched loading - from a map-style dataset. - pin_memory (bool): whether pin_memory() should be called on the rb - samples. 
- prefetch (int, optional): number of next batches to be prefetched - using multithreading. - - Returns: - a ReplayBuffer instance - - """ - if isinstance(device, str): - device = torch.device(device) - - if device.type == "cuda" and collate_fn is None: - # Postman will add batch_dim for uploaded data, so using cat instead of - # stack here. - collate_fn = functools.partial(cat_fields_to_device, device=device) - - return PrioritizedReplayBuffer( - size, alpha, beta, eps, collate_fn, pin_memory, prefetch - ) - - class InPlaceSampler: def __init__(self, device: Optional[DEVICE_TYPING] = None): self.out = None diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 23e08697762..1da861b8183 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -11,10 +11,12 @@ SumSegmentTreeFp64, ) from .storages import Storage -from .utils import INT_CLASSES, to_numpy +from .utils import INT_CLASSES, _to_numpy class Sampler(ABC): + """A generic sampler base class for composable Replay Buffers.""" + @abstractmethod def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]: raise NotImplementedError @@ -36,15 +38,17 @@ def default_priority(self) -> float: class RandomSampler(Sampler): + """A uniformly random sampler for composable replay buffers.""" + def sample(self, storage: Storage, batch_size: int) -> Tuple[np.array, dict]: index = np.random.randint(0, len(storage), size=batch_size) return index, {} class PrioritizedSampler(Sampler): - """ - Prioritized sampler for replay buffer as presented in - "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. + """Prioritized sampler for replay buffer. + + Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay." (https://arxiv.org/abs/1511.05952) @@ -54,6 +58,7 @@ class PrioritizedSampler(Sampler): beta (float): importance sampling negative exponent. eps (float): delta added to the priorities to ensure that the buffer does not contain null priorities. + """ def __init__( @@ -144,14 +149,14 @@ def extend(self, index: torch.Tensor) -> None: def update_priority( self, index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor] ) -> None: - """ - Updates the priority of the data pointed by the index. + """Updates the priority of the data pointed by the index. Args: index (int or torch.Tensor): indexes of the priorities to be updated. priority (Number or torch.Tensor): new priorities of the - indexed elements + indexed elements. + """ if isinstance(index, INT_CLASSES): if not isinstance(priority, float): @@ -170,8 +175,8 @@ def update_priority( "priority should be a number or an iterable of the same " "length as index" ) - index = to_numpy(index) - priority = to_numpy(priority) + index = _to_numpy(index) + priority = _to_numpy(priority) self._max_priority = max(self._max_priority, np.max(priority)) priority = np.power(priority + self._eps, self._alpha) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index e8967ab0390..2aea8237ff3 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -54,6 +54,13 @@ def __len__(self): class ListStorage(Storage): + """A storage stored in a list. + + Args: + max_size (int): the maximum number of elements stored in the storage. 
+ + """ + def __init__(self, max_size: int): super().__init__(max_size) self._storage = [] @@ -101,7 +108,7 @@ class LazyTensorStorage(Storage): size (int): size of the storage, i.e. maximum number of elements stored in the buffer. device (torch.device, optional): device where the sampled tensors will be - stored and sent. Default is `torch.device("cpu")`. + stored and sent. Default is :obj:`torch.device("cpu")`. """ def __init__(self, max_size, scratch_dir=None, device=None): @@ -176,7 +183,7 @@ class LazyMemmapStorage(LazyTensorStorage): in the buffer. scratch_dir (str or path): directory where memmap-tensors will be written. device (torch.device, optional): device where the sampled tensors will be - stored and sent. Default is `torch.device("cpu")`. + stored and sent. Default is :obj:`torch.device("cpu")`. """ def __init__(self, max_size, scratch_dir=None, device=None): diff --git a/torchrl/data/replay_buffers/utils.py b/torchrl/data/replay_buffers/utils.py index 30d749d4c66..d9e254d491d 100644 --- a/torchrl/data/replay_buffers/utils.py +++ b/torchrl/data/replay_buffers/utils.py @@ -18,49 +18,11 @@ INT_CLASSES = (int, np.integer) -def fields_pin_memory(input): - raise NotImplementedError - # return tree.map_structure(lambda x: pin_memory(x), input) - - -def pin_memory(data: Tensor) -> Tensor: - if isinstance(data, torch.Tensor): - return data.pin_memory() - else: - return data - - -def to_numpy(data: Tensor) -> np.ndarray: +def _to_numpy(data: Tensor) -> np.ndarray: return data.detach().cpu().numpy() if isinstance(data, torch.Tensor) else data -def fast_map(func, *inputs): - raise NotImplementedError - # flat_inputs = (tree.flatten(x) for x in inputs) - # entries = zip(*flat_inputs) - # return tree.unflatten_as(inputs[-1], [func(*x) for x in entries]) - - -def stack_tensors(input): - if not len(input): - raise RuntimeError("input length must be non-null") - if isinstance(input[0], torch.Tensor): - return torch.stack(input) - else: - return np.stack(input) - - -def stack_fields(input): - if not len(input): - raise RuntimeError("stack_fields requires non-empty list if tensors") - return fast_map(lambda *x: stack_tensors(x), *input) - - -def first_field(data) -> Tensor: - raise NotImplementedError - - -def to_torch( +def _to_torch( data: Tensor, device, pin_memory: bool = False, non_blocking: bool = False ) -> torch.Tensor: if isinstance(data, np.generic): @@ -75,22 +37,3 @@ def to_torch( data = data.to(device, non_blocking=non_blocking) return data - - -def cat_fields_to_device( - input, device, pin_memory: bool = False, non_blocking: bool = False -): - input_on_device = fields_to_device(input, device, pin_memory, non_blocking) - return cat_fields(input_on_device) - - -def cat_fields(input): - if not input: - raise RuntimeError("cat_fields requires a non-empty input collection.") - return fast_map(lambda *x: torch.cat(x), *input) - - -def fields_to_device( - input, device, pin_memory: bool = False, non_blocking: bool = False -): # type:ignore - raise NotImplementedError diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index 2dd2ce72e23..88657e1fcba 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -8,6 +8,8 @@ class Writer(ABC): + """A ReplayBuffer base Writer class.""" + def __init__(self) -> None: self._storage = None @@ -16,22 +18,18 @@ def register_storage(self, storage: Storage) -> None: @abstractmethod def add(self, data: Any) -> int: - """ - Inserts one piece of data at an appropriate 
index, - and returns that index. - """ + """Inserts one piece of data at an appropriate index, and returns that index.""" raise NotImplementedError @abstractmethod def extend(self, data: Sequence) -> torch.Tensor: - """ - Inserts a series of data points at appropriate indices, - and returns a tensor containing the indices. - """ + """Inserts a series of data points at appropriate indices, and returns a tensor containing the indices.""" raise NotImplementedError class RoundRobinWriter(Writer): + """A RoundRobin Writer class for composable replay buffers.""" + def __init__(self, **kw) -> None: super().__init__(**kw) self._cursor = 0 diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 8cd1c48a7f2..1a3cf1f88d8 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -73,6 +73,14 @@ def _default_dtype_and_device( class invertible_dict(dict): + """An invertible dictionary. + + Examples: + >>> my_dict = invertible_dict(a=3, b=2) + >>> inv_dict = my_dict.invert() + >>> assert {2, 3} == set(inv_dict.keys()) + """ + def __init__(self, *args, inv_dict=None, **kwargs): if inv_dict is None: inv_dict = dict() @@ -99,9 +107,7 @@ def inverse(self): class Box: - """ - A box of values - """ + """A box of values.""" def __iter__(self): raise NotImplementedError @@ -120,10 +126,7 @@ class Values: @dataclass(repr=False) class ContinuousBox(Box): - """ - A continuous box of values, in between a minimum and a maximum. - - """ + """A continuous box of values, in between a minimum and a maximum.""" minimum: torch.Tensor maximum: torch.Tensor @@ -154,10 +157,7 @@ def __eq__(self, other): @dataclass(repr=False) class DiscreteBox(Box): - """ - A box of discrete values - - """ + """A box of discrete values.""" n: int register = invertible_dict() @@ -171,10 +171,7 @@ def __repr__(self): @dataclass(repr=False) class BoxList(Box): - """ - A box of discrete values - - """ + """A box of discrete values.""" boxes: List @@ -191,10 +188,7 @@ def __repr__(self): @dataclass(repr=False) class BinaryBox(Box): - """ - A box of n binary values - - """ + """A box of n binary values.""" n: int @@ -207,9 +201,7 @@ def __repr__(self): @dataclass(repr=False) class TensorSpec: - """ - Parent class of the tensor meta-data containers for observation, actions - and rewards. + """Parent class of the tensor meta-data containers for observation, actions and rewards. Args: shape (torch.Size): size of the tensor @@ -227,8 +219,7 @@ class TensorSpec: domain: str = "" def encode(self, val: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: - """Encodes a value given the specified spec, and return the - corresponding tensor. + """Encodes a value given the specified spec, and return the corresponding tensor. Args: val (np.ndarray or torch.Tensor): value to be encoded as tensor. @@ -281,7 +272,7 @@ def to_numpy(self, val: torch.Tensor, safe: bool = True) -> np.ndarray: @abc.abstractmethod def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor: - """Indexes the input tensor + """Indexes the input tensor. Args: index (int, torch.Tensor, slice or list): index of the tensor @@ -298,8 +289,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: @abc.abstractmethod def is_in(self, val: torch.Tensor) -> bool: - """If the value `val` is in the box defined by the TensorSpec, - returns True, otherwise False. + """If the value :obj:`val` is in the box defined by the TensorSpec, returns True, otherwise False. 
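As a concrete instance of the `is_in`/`project` contract just documented, a small sketch (not applied by this patch) that reuses the `NdBoundedTensorSpec` constructor already shown in the RandomPolicy example earlier in this patch:

    >>> import torch
    >>> from torchrl.data.tensor_specs import NdBoundedTensorSpec
    >>> spec = NdBoundedTensorSpec(-torch.ones(3), torch.ones(3))
    >>> val = 2 * torch.ones(3)        # lies outside the [-1, 1] box
    >>> assert not spec.is_in(val)
    >>> projected = spec.project(val)  # mapped back into the box by the spec's heuristic
    >>> assert spec.is_in(projected)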
Args: val (torch.Tensor): value to be checked @@ -311,8 +301,7 @@ def is_in(self, val: torch.Tensor) -> bool: raise NotImplementedError def project(self, val: torch.Tensor) -> torch.Tensor: - """If the input tensor is not in the TensorSpec box, it maps it back - to it given some heuristic. + """If the input tensor is not in the TensorSpec box, it maps it back to it given some heuristic. Args: val (torch.Tensor): tensor to be mapped to the box. @@ -326,8 +315,7 @@ def project(self, val: torch.Tensor) -> torch.Tensor: return val def assert_is_in(self, value: torch.Tensor) -> None: - """Asserts whether a tensor belongs to the box, and raises an - exception otherwise. + """Asserts whether a tensor belongs to the box, and raises an exception otherwise. Args: value (torch.Tensor): value to be checked. @@ -341,8 +329,7 @@ def assert_is_in(self, value: torch.Tensor) -> None: ) def type_check(self, value: torch.Tensor, key: str = None) -> None: - """Checks the input value dtype against the TensorSpec dtype and - raises an exception if they don't match. + """Checks the input value dtype against the TensorSpec dtype and raises an exception if they don't match. Args: value (torch.Tensor): tensor whose dtype has to be checked @@ -359,8 +346,7 @@ def type_check(self, value: torch.Tensor, key: str = None) -> None: @abc.abstractmethod def rand(self, shape=None) -> torch.Tensor: - """Returns a random tensor in the box. The sampling will be uniform - unless the box is unbounded. + """Returns a random tensor in the box. The sampling will be uniform unless the box is unbounded. Args: shape (torch.Size): shape of the random tensor @@ -409,8 +395,7 @@ def __repr__(self): @dataclass(repr=False) class BoundedTensorSpec(TensorSpec): - """ - A bounded, unidimensional, continuous tensor spec. + """A bounded, unidimensional, continuous tensor spec. Args: minimum (np.ndarray, torch.Tensor or number): lower bound of the box. @@ -502,8 +487,8 @@ def is_in(self, val: torch.Tensor) -> bool: @dataclass(repr=False) class OneHotDiscreteTensorSpec(TensorSpec): - """ - A unidimensional, one-hot discrete tensor spec. + """A unidimensional, one-hot discrete tensor spec. + By default, TorchRL assumes that categorical variables are encoded as one-hot encodings of the variable. This allows for simple indexing of tensors, e.g. @@ -628,8 +613,7 @@ def __eq__(self, other): @dataclass(repr=False) class UnboundedContinuousTensorSpec(TensorSpec): - """ - An unbounded, unidimensional, continuous tensor spec. + """An unbounded, unidimensional, continuous tensor spec. Args: device (str, int or torch.device, optional): device of the tensors. @@ -661,8 +645,7 @@ def is_in(self, val: torch.Tensor) -> bool: @dataclass(repr=False) class UnboundedDiscreteTensorSpec(TensorSpec): - """ - An unbounded, unidimensional, discrete tensor spec. + """An unbounded, unidimensional, discrete tensor spec. Args: device (str, int or torch.device, optional): device of the tensors. @@ -701,8 +684,7 @@ def is_in(self, val: torch.Tensor) -> bool: @dataclass(repr=False) class NdBoundedTensorSpec(BoundedTensorSpec): - """ - A bounded, multi-dimensional, continuous tensor spec. + """A bounded, multi-dimensional, continuous tensor spec. Args: minimum (np.ndarray, torch.Tensor or number): lower bound of the box. @@ -794,8 +776,7 @@ def __init__( @dataclass(repr=False) class NdUnboundedContinuousTensorSpec(UnboundedContinuousTensorSpec): - """ - An unbounded, multi-dimensional, continuous tensor spec. + """An unbounded, multi-dimensional, continuous tensor spec. 
Args: device (str, int or torch.device, optional): device of the tensors. @@ -824,8 +805,7 @@ def __init__( @dataclass(repr=False) class NdUnboundedDiscreteTensorSpec(UnboundedDiscreteTensorSpec): - """ - An unbounded, multi-dimensional, discrete tensor spec. + """An unbounded, multi-dimensional, discrete tensor spec. Args: device (str, int or torch.device, optional): device of the tensors. @@ -859,8 +839,7 @@ def __init__( @dataclass(repr=False) class BinaryDiscreteTensorSpec(TensorSpec): - """ - A binary discrete tensor spec. + """A binary discrete tensor spec. Args: n (int): length of the binary vector. @@ -908,8 +887,7 @@ def is_in(self, val: torch.Tensor) -> bool: @dataclass(repr=False) class MultOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): - """ - A concatenation of one-hot discrete tensor spec. + """A concatenation of one-hot discrete tensor spec. Args: nvec (iterable of integers): cardinality of each of the elements of @@ -1020,13 +998,12 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: class CompositeSpec(TensorSpec): - """ - A composition of TensorSpecs. + """A composition of TensorSpecs. Args: **kwargs (key (str): value (TensorSpec)): dictionary of tensorspecs to be stored. Values can be None, in which case is_in will be assumed - to be `True` for the corresponding tensors, and `project()` will have no + to be :obj:`True` for the corresponding tensors, and :obj:`project()` will have no effect. `spec.encode` cannot be used with missing values. Examples: @@ -1219,7 +1196,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: return self def to_numpy(self, val: TensorDict, safe: bool = True) -> dict: - return {key: self[key].to_numpy(val) for key, val in val.items()} + return {key: self[key]._to_numpy(val) for key, val in val.items()} def zero(self, shape=None) -> TensorDictBase: if shape is None: diff --git a/torchrl/data/tensordict/memmap.py b/torchrl/data/tensordict/memmap.py index ba4f04c463a..f0428e1d281 100644 --- a/torchrl/data/tensordict/memmap.py +++ b/torchrl/data/tensordict/memmap.py @@ -30,7 +30,7 @@ def implements_for_memmap(torch_function) -> Callable: - """Register a torch function override for ScalarTensor""" + """Register a torch function override for ScalarTensor.""" @functools.wraps(torch_function) def decorator(func): @@ -75,7 +75,7 @@ class MemmapTensor(object): created from torch.Tensor objects. Default is "cpu". dtype (torch.dtype, optional): dtype of the loaded tensor. This should not be used with MemmapTensors created from torch.Tensor - objects. Default is `torch.get_default_dtype()`. + objects. Default is :obj:`torch.get_default_dtype()`. transfer_ownership (bool, optional): affects the ownership after serialization: if True, the current process looses ownership immediately after serialization. If False, the current process keeps the ownership @@ -324,7 +324,7 @@ def numel(self) -> int: return self._numel def clone(self) -> MemmapTensor: - """Clones the MemmapTensor onto another tensor + """Clones the MemmapTensor onto another tensor. Returns: a new torch.Tensor with the same data but a new storage. @@ -359,7 +359,7 @@ def shape(self) -> torch.Size: return self._shape def cpu(self) -> torch.Tensor: - """Defines the device of the MemmapTensor as "cpu" + """Defines the device of the MemmapTensor as "cpu". 
Returns: a MemmapTensor where device has been modified in-place @@ -368,7 +368,7 @@ def cpu(self) -> torch.Tensor: return self def cuda(self) -> torch.Tensor: - """Defines the device of the MemmapTensor as "cuda" + """Defines the device of the MemmapTensor as "cuda". Returns: a MemmapTensor where device has been modified in-place @@ -384,8 +384,7 @@ def copy_(self, other: Union[torch.Tensor, MemmapTensor]) -> MemmapTensor: return self def set_transfer_ownership(self, value: bool = True) -> MemmapTensor: - """Controls whether the ownership will be transferred to another - process upon serialization/deserialization + """Controls whether the ownership will be transferred to another process upon serialization/deserialization. Args: value (bool): if True, the ownership will be transferred. @@ -538,7 +537,7 @@ def to( tensor will be retrieved, mapped to the desired dtype and cast to a new MemmapTensor. - Returns: + Returns: the same memmap-tensor with the changed device. """ if isinstance(dest, (int, str, torch.device)): @@ -621,5 +620,6 @@ def _cat( def set_transfer_ownership(memmap: MemmapTensor, value: bool = True) -> None: + """Changes the transfer_ownership attribute of a MemmapTensor.""" if isinstance(memmap, MemmapTensor): memmap.set_transfer_ownership(value) diff --git a/torchrl/data/tensordict/metatensor.py b/torchrl/data/tensordict/metatensor.py index 0df81b4c380..7b00e139bde 100644 --- a/torchrl/data/tensordict/metatensor.py +++ b/torchrl/data/tensordict/metatensor.py @@ -20,7 +20,7 @@ def implements_for_meta(torch_function) -> Callable: - """Register a torch function override for ScalarTensor""" + """Register a torch function override for ScalarTensor.""" @functools.wraps(torch_function) def decorator(func): @@ -31,14 +31,13 @@ def decorator(func): class MetaTensor: - """MetaTensor is a custom class that stores the meta-information about a - tensor without requiring to access the tensor. + """MetaTensor is a custom class that stores the meta-information about a tensor without requiring to access the tensor. This is intended to be used with tensors that have a high access cost. MetaTensor supports more operations than tensors on 'meta' device ( `torch.tensor(..., device='meta')`). For instance, MetaTensor supports some operations on its shape and device, - such as `mt.to(device)`, `mt.view(*new_shape)`, `mt.expand( + such as :obj:`mt.to(device)`, :obj:`mt.view(*new_shape)`, :obj:`mt.expand( *expand_shape)` etc. Args: @@ -142,7 +141,6 @@ def share_memory_(self) -> MetaTensor: self """ - self._is_shared = True self.class_name = "SharedTensor" if self.device.type != "cuda" else "Tensor" return self @@ -165,10 +163,9 @@ def ndimension(self) -> int: return self._ndim def clone(self) -> MetaTensor: - """ + """Clones the meta-tensor. - Returns: - a new MetaTensor with the same specs. + Returns: a new MetaTensor with the same specs. 
""" return MetaTensor( @@ -230,6 +227,7 @@ def __repr__(self) -> str: ) def unsqueeze(self, dim: int) -> MetaTensor: + """Unsqueezes the meta-tensor along the desired dim.""" clone = self.clone() new_shape = [] shape = [i for i in clone.shape] @@ -243,6 +241,7 @@ def unsqueeze(self, dim: int) -> MetaTensor: return clone def squeeze(self, dim: Optional[int] = None) -> MetaTensor: + """Squeezes the meta-tensor along the desired dim.""" clone = self.clone() shape = clone.shape if dim is None: @@ -260,6 +259,7 @@ def squeeze(self, dim: Optional[int] = None) -> MetaTensor: return clone def permute(self, dims: int) -> MetaTensor: + """Permutes the dims of the meta-tensor.""" clone = self.clone() new_shape = [self.shape[dim] for dim in dims] clone.shape = torch.Size(new_shape) @@ -270,6 +270,7 @@ def view( *shape: Sequence, size: Optional[Union[List, Tuple, torch.Size]] = None, ) -> MetaTensor: + """Returns a view of a reshaped meta-tensor.""" if len(shape) == 0 and size is not None: return self.view(*size) elif len(shape) == 1 and isinstance(shape[0], (list, tuple, torch.Size)): @@ -336,6 +337,7 @@ def stack_meta( dim: int = 0, safe: bool = False, ) -> MetaTensor: + """Stacks similar meta-tensors into a single meta-tensor.""" dtype = ( list_of_meta_tensors[0].dtype if len(list_of_meta_tensors) diff --git a/torchrl/data/tensordict/tensordict.py b/torchrl/data/tensordict/tensordict.py index fcc7985d675..1c0f9e53676 100644 --- a/torchrl/data/tensordict/tensordict.py +++ b/torchrl/data/tensordict/tensordict.py @@ -83,10 +83,7 @@ class TensorDictBase(Mapping, metaclass=abc.ABCMeta): - """ - TensorDictBase is an abstract parent class for TensorDicts, the torchrl - data container. - """ + """TensorDictBase is an abstract parent class for TensorDicts, the torchrl data container.""" _safe = False _lazy = False @@ -111,13 +108,14 @@ def _make_meta(self, key: str) -> MetaTensor: @property def shape(self) -> torch.Size: - """See TensorDictBase.batch_size""" + """See :obj:`TensorDictBase.batch_size`.""" return self.batch_size @property @abc.abstractmethod def batch_size(self) -> torch.Size: """Shape of (or batch_size) of a TensorDict. + The shape of a tensordict corresponds to the common N first dimensions of the tensors it contains, where N is an arbitrary number. The TensorDict shape is controlled by the user upon @@ -131,8 +129,9 @@ def batch_size(self) -> torch.Size: raise NotImplementedError def size(self, dim: Optional[int] = None): - """Returns the size of the dimension indicated by `dim`. If dim is not - specified, returns the batch_size (or shape) of the TensorDict. + """Returns the size of the dimension indicated by :obj:`dim`. + + If dim is not specified, returns the batch_size (or shape) of the TensorDict. """ if dim is None: @@ -150,7 +149,7 @@ def _batch_size_setter(self, new_batch_size: torch.Size) -> None: raise RuntimeError( "modifying the batch size of a lazy repesentation of a " "tensordict is not permitted. Consider instantiating the " - "tensordict fist by calling `td = td.to_tensordict()` before " + "tensordict fist by calling :obj:`td = td.to_tensordict()` before " "resetting the batch size." ) if self.batch_size == new_batch_size: @@ -179,7 +178,9 @@ def dim(self) -> int: @property @abc.abstractmethod def device(self) -> Union[None, torch.device]: - """Device of a TensorDict. If the TensorDict has a specified device, all + """Device of a TensorDict. + + If the TensorDict has a specified device, all tensors of a tensordict must live on the same device. 
If the TensorDict device is None, then different values can be located on different devices. @@ -267,7 +268,7 @@ def set( item (torch.Tensor): value to be stored in the tensordict inplace (bool, optional): if True and if a key matches an existing key in the tensordict, then the update will occur in-place - for that key-value pair. Default is `False`. + for that key-value pair. Default is :obj:`False`. Returns: self @@ -347,8 +348,7 @@ def _default_get( def get( self, key: str, default: Union[str, COMPATIBLE_TYPES] = "_no_default_" ) -> COMPATIBLE_TYPES: - """ - Gets the value stored with the input key. + """Gets the value stored with the input key. Args: key (str): key to be queried. @@ -370,8 +370,7 @@ def _get_meta(self, key) -> MetaTensor: ) def apply_(self, fn: Callable) -> TensorDictBase: - """Applies a callable to all values stored in the tensordict and - re-writes them in-place. + """Applies a callable to all values stored in the tensordict and re-writes them in-place. Args: fn (Callable): function to be applied to the tensors in the @@ -389,15 +388,14 @@ def apply( batch_size: Optional[Sequence[int]] = None, inplace: bool = False, ) -> TensorDictBase: - """Applies a callable to all values stored in the tensordict and sets - them in a new tensordict. + """Applies a callable to all values stored in the tensordict and sets them in a new tensordict. Args: fn (Callable): function to be applied to the tensors in the tensordict. batch_size (sequence of int, optional): if provided, the resulting TensorDict will have the desired batch_size. - The `batch_size` argument should match the batch_size after + The :obj:`batch_size` argument should match the batch_size after the transformation. inplace (bool, optional): if True, changes are made in-place. Default is False. @@ -429,19 +427,18 @@ def update( inplace: bool = False, **kwargs, ) -> TensorDictBase: - """Updates the TensorDict with values from either a dictionary or - another TensorDict. + """Updates the TensorDict with values from either a dictionary or another TensorDict. Args: input_dict_or_td (TensorDictBase or dict): Does not keyword arguments - (unlike `dict.update()`). + (unlike :obj:`dict.update()`). clone (bool, optional): whether the tensors in the input ( tensor) dict should be cloned before being set. Default is `False`. inplace (bool, optional): if True and if a key matches an existing key in the tensordict, then the update will occur in-place - for that key-value pair. Default is `False`. - **kwargs: keyword arguments for the `TensorDict.set` method + for that key-value pair. Default is :obj:`False`. + **kwargs: keyword arguments for the :obj:`TensorDict.set` method Returns: self @@ -466,15 +463,14 @@ def update_( input_dict_or_td: Union[Dict[str, COMPATIBLE_TYPES], TensorDictBase], clone: bool = False, ) -> TensorDictBase: - """Updates the TensorDict in-place with values from either a dictionary - or another TensorDict. + """Updates the TensorDict in-place with values from either a dictionary or another TensorDict. Unlike TensorDict.update, this function will throw an error if the key is unknown to the TensorDict Args: input_dict_or_td (TensorDictBase or dict): Does not keyword - arguments (unlike `dict.update()`). + arguments (unlike :obj:`dict.update()`). clone (bool, optional): whether the tensors in the input ( tensor) dict should be cloned before being set. Default is `False`. 
@@ -503,15 +499,13 @@ def update_at_( idx: INDEX_TYPING, clone: bool = False, ) -> TensorDictBase: - """Updates the TensorDict in-place at the specified index with - values from either a dictionary or another TensorDict. + """Updates the TensorDict in-place at the specified index with values from either a dictionary or another TensorDict. - Unlike TensorDict.update, this function will throw an error if the - key is unknown to the TensorDict. + Unlike TensorDict.update, this function will throw an error if the key is unknown to the TensorDict. Args: input_dict_or_td (TensorDictBase or dict): Does not keyword arguments - (unlike `dict.update()`). + (unlike :obj:`dict.update()`). idx (int, torch.Tensor, iterable, slice): index of the tensordict where the update should occur. clone (bool, optional): whether the tensors in the input ( @@ -529,15 +523,14 @@ def update_at_( ... 'b': torch.ones(1, 4, 10)}, batch_size=[1, 4]), ... slice(1, 2)) TensorDict( - fields={a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32), - b: Tensor(torch.Size([3, 4, 10]),\ -dtype=torch.float32)}, - shared=False, + fields={ + a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32), + b: Tensor(torch.Size([3, 4, 10]), dtype=torch.float32)}, batch_size=torch.Size([3, 4]), - device=cpu) + device=None, + is_shared=False) """ - for key, value in input_dict_or_td.items(): if not isinstance(value, _accepted_classes): raise TypeError( @@ -611,33 +604,23 @@ def _process_input( @abc.abstractmethod def pin_memory(self) -> TensorDictBase: - """Calls pin_memory() on the stored tensors.""" + """Calls :obj:`pin_memory` on the stored tensors.""" raise NotImplementedError(f"{self.__class__.__name__}") - # @abc.abstractmethod - # def is_pinned(self) -> bool: - # """Checks if tensors are pinned.""" - # raise NotImplementedError(f"{self.__class__.__name__}") - def items(self) -> Iterator[Tuple[str, COMPATIBLE_TYPES]]: - """ - Returns a generator of key-value pairs for the tensordict. - - """ + """Returns a generator of key-value pairs for the tensordict.""" for k in self.keys(): yield k, self.get(k) def values(self) -> Iterator[COMPATIBLE_TYPES]: - """ - Returns a generator representing the values for the tensordict. - - """ + """Returns a generator representing the values for the tensordict.""" for k in self.keys(): yield self.get(k) def items_meta(self, make_unset: bool = True) -> Iterator[Tuple[str, MetaTensor]]: - """Returns a generator of key-value pairs for the tensordict, where the - values are MetaTensor instances corresponding to the stored tensors. + """Returns a generator of key-value pairs for the tensordict. + + The values are MetaTensor instances corresponding to the stored tensors. """ if make_unset: @@ -647,8 +630,9 @@ def items_meta(self, make_unset: bool = True) -> Iterator[Tuple[str, MetaTensor] return self._dict_meta.items() def values_meta(self, make_unset: bool = True) -> Iterator[MetaTensor]: - """Returns a generator representing the values for the tensordict, those - values are MetaTensor instances corresponding to the stored tensors. + """Returns a generator representing the values for the tensordict. + + Those values are MetaTensor instances corresponding to the stored tensors. 
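A brief sketch (not applied by this patch) of the meta-data iteration described above; it assumes `TensorDict` is importable from `torchrl.data` and relies only on the MetaTensor attributes (`shape`, `dtype`, `device`) documented elsewhere in this diff:

    >>> import torch
    >>> from torchrl.data import TensorDict
    >>> td = TensorDict({"obs": torch.zeros(4, 3), "reward": torch.zeros(4, 1)}, batch_size=[4])
    >>> for key, meta in td.items_meta():  # MetaTensor values: the tensor data itself is not accessed
    ...     print(key, meta.shape, meta.dtype, meta.device)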
""" if make_unset: @@ -660,12 +644,13 @@ def values_meta(self, make_unset: bool = True) -> Iterator[MetaTensor]: @abc.abstractmethod def keys(self) -> KeysView: """Returns a generator of tensordict keys.""" - raise NotImplementedError(f"{self.__class__.__name__}") def expand(self, *shape) -> TensorDictBase: - """Expands each tensors of the tensordict according to - `tensor.expand(*shape, *tensor.shape)` + """Expands each tensors of the tensordict according to the torch.expand function. + + In practice, this amends to: :obj:`tensor.expand(*shape, *tensor.shape)`. + Supports iterables to specify the shape Examples: @@ -674,6 +659,7 @@ def expand(self, *shape) -> TensorDictBase: >>> td_expand = td.expand(10, 3, 4) >>> assert td_expand.shape == torch.Size([10, 3, 4]) >>> assert td_expand.get("a").shape == torch.Size([10, 3, 4, 5]) + """ d = dict() tensordict_dims = self.batch_dims @@ -716,15 +702,18 @@ def __bool__(self) -> bool: raise ValueError("Converting a tensordict to boolean value is not permitted") def __ne__(self, other: object) -> TensorDictBase: - """XOR operation over two tensordicts, for evey key. The two - tensordicts must have the same key set. + """XOR operation over two tensordicts, for evey key. + + The two tensordicts must have the same key set. + + Args: + other (TensorDictBase, dict, or float): the value to compare against. Returns: a new TensorDict instance with all tensors are boolean tensors of the same shape as the original tensors. """ - if not isinstance(other, (TensorDictBase, dict, float, int)): return False if not isinstance(other, TensorDictBase) and isinstance(other, dict): @@ -747,8 +736,7 @@ def __ne__(self, other: object) -> TensorDictBase: return TensorDict(batch_size=self.batch_size, source=d, device=self.device) def __eq__(self, other: object) -> TensorDictBase: - """Compares two tensordicts against each other, for every key. The two - tensordicts must have the same key set. + """Compares two tensordicts against each other, for every key. The two tensordicts must have the same key set. Returns: a new TensorDict instance with all tensors are boolean @@ -789,8 +777,7 @@ def del_(self, key: str) -> TensorDictBase: @abc.abstractmethod def select(self, *keys: str, inplace: bool = False) -> TensorDictBase: - """Selects the keys of the tensordict and returns an new tensordict - with only the selected keys. + """Selects the keys of the tensordict and returns an new tensordict with only the selected keys. The values are not copied: in-place modifications a tensor of either of the original or new tensordict will result in a change in both @@ -799,7 +786,7 @@ def select(self, *keys: str, inplace: bool = False) -> TensorDictBase: Args: *keys (str): keys to select inplace (bool): if True, the tensordict is pruned in place. - Default is `False`. + Default is :obj:`False`. Returns: A new tensordict with the selected keys only. @@ -815,7 +802,7 @@ def exclude(self, *keys: str, inplace: bool = False) -> TensorDictBase: def set_at_( self, key: str, value: COMPATIBLE_TYPES, idx: INDEX_TYPING ) -> TensorDictBase: - """Sets the values in-place at the index indicated by `idx`. + """Sets the values in-place at the index indicated by :obj:`idx`. Args: key (str): key to be modified. 
@@ -829,11 +816,11 @@ def set_at_( raise NotImplementedError(f"{self.__class__.__name__}") def copy_(self, tensordict: TensorDictBase) -> TensorDictBase: - """See `TensorDictBase.update_`.""" + """See :obj:`TensorDictBase.update_`.""" return self.update_(tensordict) def copy_at_(self, tensordict: TensorDictBase, idx: INDEX_TYPING) -> TensorDictBase: - """See `TensorDictBase.update_at_`.""" + """See :obj:`TensorDictBase.update_at_`.""" return self.update_at_(tensordict, idx) def get_at( @@ -882,7 +869,6 @@ def memmap_(self, prefix=None, lock=True) -> TensorDictBase: self. """ - raise NotImplementedError(f"{self.__class__.__name__}") @abc.abstractmethod @@ -902,7 +888,6 @@ def detach(self) -> TensorDictBase: a new tensordict with no tensor requiring gradient. """ - return TensorDict( {key: item.detach() for key, item in self.items()}, batch_size=self.batch_size, @@ -934,9 +919,9 @@ def zero_(self) -> TensorDictBase: return self def unbind(self, dim: int) -> Tuple[TensorDictBase, ...]: - """Returns a tuple of indexed tensordicts unbound along the - indicated dimension. Resulting tensordicts will share - the storage of the initial tensordict. + """Returns a tuple of indexed tensordicts unbound along the indicated dimension. + + Resulting tensordicts will share the storage of the initial tensordict. """ idx = [ @@ -946,8 +931,9 @@ def unbind(self, dim: int) -> Tuple[TensorDictBase, ...]: return tuple(self[_idx] for _idx in idx) def chunk(self, chunks: int, dim: int = 0) -> Tuple[TensorDictBase, ...]: - """Attempts to split a tendordict into the specified number of - chunks. Each chunk is a view of the input tensordict. + """Splits a tendordict into the specified number of chunks, if possible. + + Each chunk is a view of the input tensordict. Args: chunks (int): number of chunks to return @@ -982,6 +968,7 @@ def clone(self, recurse: bool = True) -> TensorDictBase: Args: recurse (bool, optional): if True, each tensor contained in the TensorDict will be copied too. Default is `True`. + """ return TensorDict( source={ @@ -1011,9 +998,9 @@ def __torch_function__( def to( self, dest: Union[DEVICE_TYPING, Type, torch.Size], **kwargs ) -> TensorDictBase: - """Maps a TensorDictBase subclass either on a new device or to another - TensorDictBase subclass (if permitted). Casting tensors to a new dtype - is not allowed, as tensordicts are not bound to contain a single + """Maps a TensorDictBase subclass either on a new device or to another TensorDictBase subclass (if permitted). + + Casting tensors to a new dtype is not allowed, as tensordicts are not bound to contain a single tensor dtype. Args: @@ -1049,7 +1036,7 @@ def _change_batch_size(self, new_size: torch.Size): raise NotImplementedError def cpu(self) -> TensorDictBase: - """Casts a tensordict to cpu (if not already on cpu).""" + """Casts a tensordict to CPU.""" return self.to("cpu") def cuda(self, device: int = 0) -> TensorDictBase: @@ -1082,7 +1069,7 @@ def masked_fill_(self, mask: Tensor, value: Union[float, bool]) -> TensorDictBas @abc.abstractmethod def masked_fill(self, mask: Tensor, value: Union[float, bool]) -> TensorDictBase: - """Out-of-place version of masked_fill + """Out-of-place version of masked_fill. Args: mask (boolean torch.Tensor): mask of values to be filled. 
Shape @@ -1105,8 +1092,7 @@ def masked_fill(self, mask: Tensor, value: Union[float, bool]) -> TensorDictBase raise NotImplementedError def masked_select(self, mask: Tensor) -> TensorDictBase: - """Masks all tensors of the TensorDict and return a new TensorDict - instance with similar keys pointing to masked values. + """Masks all tensors of the TensorDict and return a new TensorDict instance with similar keys pointing to masked values. Args: mask (torch.Tensor): boolean mask to be used for the tensors. @@ -1134,42 +1120,23 @@ def masked_select(self, mask: Tensor) -> TensorDictBase: @abc.abstractmethod def is_contiguous(self) -> bool: - """ - - Returns: - boolean indicating if all the tensors are contiguous. - - """ + """Returns a boolean indicating if all the tensors are contiguous.""" raise NotImplementedError @abc.abstractmethod def contiguous(self) -> TensorDictBase: - """ - - Returns: - a new tensordict of the same type with contiguous values ( - or self if values are already contiguous). - - """ + """Returns a new tensordict of the same type with contiguous values (or self if values are already contiguous).""" raise NotImplementedError def to_dict(self) -> Dict[str, Any]: - """ - - Returns: - dictionary with key-value pairs matching those of the - tensordict. - - """ + """Returns a dictionary with key-value pairs matching those of the tensordict.""" return { key: value.to_dict() if isinstance(value, TensorDictBase) else value for key, value in self.items() } def unsqueeze(self, dim: int) -> TensorDictBase: - """Unsqueeze all tensors for a dimension comprised in between - `-td.batch_dims` and `td.batch_dims` and returns them in a new - tensordict. + """Unsqueeze all tensors for a dimension comprised in between `-td.batch_dims` and `td.batch_dims` and returns them in a new tensordict. Args: dim (int): dimension along which to unsqueeze @@ -1193,9 +1160,7 @@ def unsqueeze(self, dim: int) -> TensorDictBase: ) def squeeze(self, dim: int) -> TensorDictBase: - """Squeezes all tensors for a dimension comprised in between - `-td.batch_dims+1` and `td.batch_dims-1` and returns them - in a new tensordict. + """Squeezes all tensors for a dimension comprised in between `-td.batch_dims+1` and `td.batch_dims-1` and returns them in a new tensordict. Args: dim (int): dimension along which to squeeze @@ -1261,8 +1226,7 @@ def view( *shape: int, size: Optional[Union[List, Tuple, torch.Size]] = None, ) -> TensorDictBase: - """Returns a tensordict with views of the tensors according to a new - shape, compatible with the tensordict batch_size. + """Returns a tensordict with views of the tensors according to a new shape, compatible with the tensordict batch_size. Args: *shape (int): new shape of the resulting tensordict. @@ -1304,7 +1268,7 @@ def permute( *dims_list: int, dims=None, ) -> TensorDictBase: - """Returns a view of a tensordict with the batch dimensions permuted according to dims + """Returns a view of a tensordict with the batch dimensions permuted according to dims. Args: *dims_list (int): the new ordering of the batch dims of the tensordict. Alternatively, @@ -1518,17 +1482,13 @@ def unflatten_keys( return out def __len__(self) -> int: - """ - - Returns: - Length of first dimension, if there is, otherwise 0. 
- - """ + """Returns the length of first dimension, if there is, otherwise 0.""" return self.shape[0] if self.batch_dims else 0 def __getitem__(self, idx: INDEX_TYPING) -> TensorDictBase: - """Indexes all tensors according to idx and returns a new tensordict - where the values share the storage of the original tensors (even + """Indexes all tensors according to the provided index. + + Returns a new tensordict where the values share the storage of the original tensors (even when the index is a torch.Tensor). Any in-place modification to the resulting tensordict will impact the parent tensordict too. @@ -1765,9 +1725,9 @@ class TensorDict(TensorDictBase): - Reading: `td.get(key)`, `td.get_at(key, index)` - - Content modification: `td.set(key, value)`, `td.set_(key, value)`, - `td.update(td_or_dict)`, `td.update_(td_or_dict)`, `td.fill_(key, - value)`, `td.rename_key(old_name, new_name)`, etc. + - Content modification: :obj:`td.set(key, value)`, :obj:`td.set_(key, value)`, + :obj:`td.update(td_or_dict)`, :obj:`td.update_(td_or_dict)`, :obj:`td.fill_(key, + value)`, :obj:`td.rename_key(old_name, new_name)`, etc. - Operations on multiple tensordicts: `torch.cat(tensordict_list, dim)`, `torch.stack(tensordict_list, dim)`, `td1 == td2` etc. @@ -1919,8 +1879,10 @@ def batch_dims(self, value: COMPATIBLE_TYPES) -> None: @property def device(self) -> Union[None, torch.device]: - """Returns `None` if device hasn't been provided in the constructor - or set via `tensordict.to(device)`. + """Device of the tensordict. + + Returns `None` if device hasn't been provided in the constructor or set via `tensordict.to(device)`. + """ return self._device @@ -1994,9 +1956,10 @@ def pin_memory(self) -> TensorDictBase: return self def expand(self, *shape) -> TensorDictBase: - """Expands every tensor with `(*shape, *tensor.shape)` and returns the - same tensordict with new tensors with expanded shapes. + """Expands every tensor with `(*shape, *tensor.shape)` and returns the same tensordict with new tensors with expanded shapes. + Supports iterables to specify the shape. + """ d = dict() tensordict_dims = self.batch_dims @@ -2044,8 +2007,10 @@ def set( _run_checks: bool = True, _meta_val: Optional[MetaTensor] = None, ) -> TensorDictBase: - """Sets a value in the TensorDict. If inplace=True (default is False), - and if the key already exists, set will call set_ (in place setting). + """Sets a value in the TensorDict. + + If inplace=True (default is False), and if the key already exists, set will call set_ (in place setting). + """ if self.is_locked: if not inplace or key not in self.keys(): @@ -2352,8 +2317,10 @@ def keys(self) -> KeysView: class _ErrorInteceptor: - """Context manager for catching errors and modifying message. Intended for - use with stacking / concatenation operations applied to TensorDicts. + """Context manager for catching errors and modifying message. + + Intended for use with stacking / concatenation operations applied to TensorDicts. 
+ """ DEFAULT_EXC_MSG = "Expected all tensors to be on the same device" @@ -2382,7 +2349,7 @@ def __exit__(self, exc_type, exc_value, _): def implements_for_td(torch_function: Callable) -> Callable: - """Register a torch function override for ScalarTensor""" + """Register a torch function override for ScalarTensor.""" @functools.wraps(torch_function) def decorator(func): @@ -2401,6 +2368,7 @@ def assert_allclose_td( equal_nan: bool = True, msg: str = "", ) -> bool: + """Compares two tensordicts and raise an exception if their content does not match exactly.""" if not isinstance(actual, TensorDictBase) or not isinstance( expected, TensorDictBase ): @@ -2440,12 +2408,12 @@ def assert_allclose_td( @implements_for_td(torch.unbind) -def unbind(td: TensorDictBase, *args, **kwargs) -> Tuple[TensorDictBase, ...]: +def _unbind(td: TensorDictBase, *args, **kwargs) -> Tuple[TensorDictBase, ...]: return td.unbind(*args, **kwargs) @implements_for_td(torch.full_like) -def full_like(td: TensorDictBase, fill_value, **kwargs) -> TensorDictBase: +def _full_like(td: TensorDictBase, fill_value, **kwargs) -> TensorDictBase: td_clone = td.clone() for key in td_clone.keys(): td_clone.fill_(key, fill_value) @@ -2462,7 +2430,7 @@ def full_like(td: TensorDictBase, fill_value, **kwargs) -> TensorDictBase: @implements_for_td(torch.zeros_like) -def zeros_like(td: TensorDictBase, **kwargs) -> TensorDictBase: +def _zeros_like(td: TensorDictBase, **kwargs) -> TensorDictBase: td_clone = td.clone() for key in td_clone.keys(): td_clone.fill_(key, 0.0) @@ -2479,7 +2447,7 @@ def zeros_like(td: TensorDictBase, **kwargs) -> TensorDictBase: @implements_for_td(torch.ones_like) -def ones_like(td: TensorDictBase, **kwargs) -> TensorDictBase: +def _ones_like(td: TensorDictBase, **kwargs) -> TensorDictBase: td_clone = td.clone() for key in td_clone.keys(): td_clone.fill_(key, 1.0) @@ -2494,32 +2462,32 @@ def ones_like(td: TensorDictBase, **kwargs) -> TensorDictBase: @implements_for_td(torch.clone) -def clone(td: TensorDictBase, *args, **kwargs) -> TensorDictBase: +def _clone(td: TensorDictBase, *args, **kwargs) -> TensorDictBase: return td.clone(*args, **kwargs) @implements_for_td(torch.squeeze) -def squeeze(td: TensorDictBase, *args, **kwargs) -> TensorDictBase: +def _squeeze(td: TensorDictBase, *args, **kwargs) -> TensorDictBase: return td.squeeze(*args, **kwargs) @implements_for_td(torch.unsqueeze) -def unsqueeze(td: TensorDictBase, *args, **kwargs) -> TensorDictBase: +def _unsqueeze(td: TensorDictBase, *args, **kwargs) -> TensorDictBase: return td.unsqueeze(*args, **kwargs) @implements_for_td(torch.masked_select) -def masked_select(td: TensorDictBase, *args, **kwargs) -> TensorDictBase: +def _masked_select(td: TensorDictBase, *args, **kwargs) -> TensorDictBase: return td.masked_select(*args, **kwargs) @implements_for_td(torch.permute) -def permute(td: TensorDictBase, dims) -> TensorDictBase: +def _permute(td: TensorDictBase, dims) -> TensorDictBase: return td.permute(*dims) @implements_for_td(torch.cat) -def cat( +def _cat( list_of_tensordicts: Sequence[TensorDictBase], dim: int = 0, device: DEVICE_TYPING = None, @@ -2572,7 +2540,7 @@ def cat( @implements_for_td(torch.stack) -def stack( +def _stack( list_of_tensordicts: Sequence[TensorDictBase], dim: int = 0, device: DEVICE_TYPING = None, @@ -2679,8 +2647,7 @@ def stack( def pad(tensordict: TensorDictBase, pad_size: Sequence[int], value: float = 0.0): - """Pads all tensors in a tensordict along the batch dimensions with a constant value, - returning a new tensordict + """Pads all 
tensors in a tensordict along the batch dimensions with a constant value, returning a new tensordict. Args: tensordict (TensorDict): The tensordict to pad @@ -2697,7 +2664,8 @@ def pad(tensordict: TensorDictBase, pad_size: Sequence[int], value: float = 0.0) A new TensorDict padded along the batch dimensions Examples: - >>> from torchrl.data import TensorDict, pad + >>> from torchrl.data import TensorDict + >>> from torchrl.data.tensordict.tensordict import pad >>> import torch >>> td = TensorDict({'a': torch.ones(3, 4, 1), ... 'b': torch.ones(3, 4, 1, 1)}, batch_size=[3, 4]) @@ -2709,8 +2677,8 @@ def pad(tensordict: TensorDictBase, pad_size: Sequence[int], value: float = 0.0) torch.Size([4, 6, 1]) >>> print(padded_td.get("b").shape) torch.Size([4, 6, 1, 1]) - """ + """ if len(pad_size) > 2 * len(tensordict.batch_size): raise RuntimeError( "The length of pad_size must be <= 2 * the number of batch dimensions" @@ -2780,8 +2748,7 @@ def pad_sequence_td( class SubTensorDict(TensorDictBase): - """ - A TensorDict that only sees an index of the stored tensors. + """A TensorDict that only sees an index of the stored tensors. By default, indexing a tensordict with an iterable will result in a SubTensorDict. This is done such that a TensorDict indexed with @@ -2798,11 +2765,11 @@ class SubTensorDict(TensorDictBase): >>> print(type(td_index), td_index.shape) \ torch.Size([3]) - >>> td_index = td[:, slice(None)] + >>> td_index = td[slice(None), slice(None)] >>> print(type(td_index), td_index.shape) \ torch.Size([3, 4]) - >>> td_index = td[:, Tensor([0, 2]).to(torch.long)] + >>> td_index = td.get_sub_tensordict((slice(None), torch.tensor([0, 2], dtype=torch.long))) >>> print(type(td_index), td_index.shape) \ torch.Size([3, 2]) @@ -3218,6 +3185,7 @@ def share_memory_(self, lock=True) -> TensorDictBase: def merge_tensordicts(*tensordicts: TensorDictBase) -> TensorDictBase: + """Merges tensordicts together.""" if len(tensordicts) < 2: raise RuntimeError( f"at least 2 tensordicts must be provided, got" f" {len(tensordicts)}" @@ -3254,6 +3222,7 @@ class LazyStackedTensorDict(TensorDictBase): torch.Size([3, 10, 4]) >>> print(td_stack[:, 0] is tds[0]) True + """ _safe = False @@ -3803,6 +3772,8 @@ def masked_fill(self, mask: Tensor, value: Union[float, bool]) -> TensorDictBase class SavedTensorDict(TensorDictBase): + """A saved tensordict class.""" + _safe = False _lazy = False @@ -4061,8 +4032,11 @@ def to(self, dest: Union[DEVICE_TYPING, Type], **kwargs): def to_tensordict(self): """Returns a regular TensorDict instance from the TensorDictBase. - Makes a copy of the tensor dict. - Memmap and shared memory tensors are converted to regular tensors. + + Makes a copy of the tensor dict. + + Memmap and shared memory tensors are converted to regular tensors. + Returns: a new TensorDict object containing the same values. @@ -4191,8 +4165,9 @@ def __init__( def _update_custom_op_kwargs( self, source_meta_tensor: MetaTensor ) -> Dict[str, Any]: - """Allows for a transformation to be customized for a certain shape, - device or dtype. By default, this is a no-op on self.custom_op_kwargs + """Allows for a transformation to be customized for a certain shape, device or dtype. 
+ + By default, this is a no-op on self.custom_op_kwargs Args: source_meta_tensor: corresponding MetaTensor @@ -4205,8 +4180,7 @@ def _update_custom_op_kwargs( return self.custom_op_kwargs def _update_inv_op_kwargs(self, source_tensor: Tensor) -> Dict[str, Any]: - """Allows for an inverse transformation to be customized for a - certain shape, device or dtype. + """Allows for an inverse transformation to be customized for a certain shape, device or dtype. By default, this is a no-op on self.inv_op_kwargs @@ -4494,9 +4468,10 @@ def _stack_onto_( class SqueezedTensorDict(_CustomOpTensorDict): - """ - A lazy view on a squeezed TensorDict. + """A lazy view on a squeezed TensorDict. + See the `UnsqueezedTensorDict` class documentation for more information. + """ def unsqueeze(self, dim: int) -> TensorDictBase: @@ -4561,8 +4536,7 @@ def view( class PermutedTensorDict(_CustomOpTensorDict): - """ - A lazy view on a TensorDict with the batch dimensions permuted. + """A lazy view on a TensorDict with the batch dimensions permuted. When calling `tensordict.permute(dims_list, dim)`, a lazy view of this operation is returned such that the following code snippet works without raising an @@ -4579,6 +4553,7 @@ class PermutedTensorDict(_CustomOpTensorDict): torch.Size([6, 5, 4]) >>> print(td_permute.permute(dims=(2, 1, 0)) is td) True + """ def permute( @@ -4737,8 +4712,7 @@ def make_tensordict( device: Optional[DEVICE_TYPING] = None, **kwargs, # source ) -> TensorDict: - """ - Returns a TensorDict created from the keyword arguments. + """Returns a TensorDict created from the keyword arguments. If batch_size is not specified, returns the maximum batch size possible @@ -4746,6 +4720,7 @@ def make_tensordict( **kwargs (TensorDict or torch.Tensor): keyword arguments as data source. batch_size (iterable of int, optional): a batch size for the tensordict. device (torch.device or compatible type, optional): a device for the TensorDict. + """ if batch_size is None: batch_size = _find_max_batch_size(kwargs) diff --git a/torchrl/data/tensordict/utils.py b/torchrl/data/tensordict/utils.py index 9fa22681d31..42110ee4552 100644 --- a/torchrl/data/tensordict/utils.py +++ b/torchrl/data/tensordict/utils.py @@ -25,9 +25,16 @@ def _sub_index(tensor: torch.Tensor, idx: INDEX_TYPING) -> torch.Tensor: - """Allows indexing of tensors with nested tuples, i.e. - tensor[tuple1][tuple2] can be indexed via _sub_index(tensor, (tuple1, - tuple2)) + """Allows indexing of tensors with nested tuples. + + >>> sub_tensor1 = tensor[tuple1][tuple2] + >>> sub_tensor2 = _sub_index(tensor, (tuple1, tuple2)) + >>> assert torch.allclose(sub_tensor1, sub_tensor2) + + Args: + tensor (torch.Tensor): tensor to be indexed. + idx (tuple of indices): indices sequence to be used. + """ if isinstance(idx, tuple) and len(idx) and isinstance(idx[0], tuple): idx0 = idx[0] @@ -40,15 +47,13 @@ def _getitem_batch_size( shape: torch.Size, items: INDEX_TYPING, ) -> torch.Size: - """ - Given an input shape and an index, returns the size of the resulting - indexed tensor. + """Given an input shape and an index, returns the size of the resulting indexed tensor. This function is aimed to be used when indexing is an expensive operation. 
Args: - shape: Input shape - items: Index of the hypothetical tensor + shape (torch.Size): Input shape + items (index): Index of the hypothetical tensor Returns: Size of the resulting object (tensor or tensordict) @@ -115,9 +120,14 @@ def _getitem_batch_size( def convert_ellipsis_to_idx(idx: Union[Tuple, Ellipsis], batch_size: List[int]): - """ - Given an index containing an ellipsis or just an ellipsis, converts any ellipsis to slice(None) - Example: idx = (..., 0), batch_size = [1,2,3] -> new_index = (slice(None), slice(None), 0) + """Given an index containing an ellipsis or just an ellipsis, converts any ellipsis to slice(None). + + Example: + >>> idx = (..., 0) + >>> batch_size = [1,2,3] + >>> new_index = convert_ellipsis_to_idx(idx, batch_size) + >>> print(new_index) + (slice(None, None, None), slice(None, None, None), 0) Args: idx (tuple, Ellipsis): Input index diff --git a/torchrl/data/utils.py b/torchrl/data/utils.py index 72d7eb939e7..0cab9a1b6d4 100644 --- a/torchrl/data/utils.py +++ b/torchrl/data/utils.py @@ -36,6 +36,8 @@ class CloudpickleWrapper(object): + """A wrapper for functions that allow for serialization in multiprocessed settings.""" + def __init__(self, fn: Callable, **kwargs): if fn.__class__.__name__ == "EnvCreator": raise RuntimeError( @@ -66,6 +68,7 @@ def expand_as_right( dest: Union[torch.Tensor, "MemmapTensor", "TensorDictBase"], # noqa: F821 ): """Expand a tensor on the right to match another tensor shape. + Args: tensor: tensor to be expanded dest: tensor providing the target shape @@ -78,8 +81,8 @@ def expand_as_right( >>> dest = torch.zeros(3,4,5) >>> print(expand_as_right(tensor, dest).shape) torch.Size([3,4,5]) - """ + """ if dest.ndimension() < tensor.ndimension(): raise RuntimeError( "expand_as_right requires the destination tensor to have less " @@ -101,6 +104,7 @@ def expand_right( tensor: Union[torch.Tensor, "MemmapTensor"], shape: Sequence[int] # noqa: F821 ) -> torch.Tensor: """Expand a tensor on the right to match a desired shape. + Args: tensor: tensor to be expanded shape: target shape @@ -113,8 +117,8 @@ def expand_right( >>> shape = (3,4,5) >>> print(expand_right(tensor, shape).shape) torch.Size([3,4,5]) - """ + """ tensor_expand = tensor while tensor_expand.ndimension() < len(shape): tensor_expand = tensor_expand.unsqueeze(-1) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 97af90b82fa..75446676682 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -42,6 +42,8 @@ def _tensor_to_np(t): class EnvMetaData: + """A class for environment meta-data storage and passing in multiprocessed settings.""" + def __init__( self, tensordict: TensorDictBase, @@ -134,7 +136,18 @@ def keys(self) -> Sequence[str]: def build_tensordict( self, next_observation: bool = True, log_prob: bool = False ) -> TensorDictBase: - """returns a TensorDict with empty tensors of the desired shape""" + """Returns a TensorDict with empty tensors of the desired shape. + + Args: + next_observation (bool, optional): if False, the observation returned + will be of the current step only (no :obj:`"next_"` key will be present). + Default is True. + log_prob (bool, optional): If True, a log_prob key-value pair will be added + to the tensordict. + + Returns: A tensordict populated according to the env specs. 
+
+        """
         # build a tensordict from specs
         td = TensorDict({}, batch_size=torch.Size([]))
         action_placeholder = torch.zeros(
@@ -146,7 +159,7 @@ def build_tensordict(
         for (key, item) in self["observation_spec"].items():
             if not key.startswith("next_"):
                 raise RuntimeError(
-                    f"All observation keys must start with the `'next_'` prefix. Found {key}"
+                    f"All observation keys must start with the 'next_' prefix. Found {key}"
                 )
             observation_placeholder = torch.zeros(item.shape, dtype=item.dtype)
             if next_observation:
@@ -174,8 +187,7 @@ def build_tensordict(
 class EnvBase(nn.Module, metaclass=abc.ABCMeta):
-    """
-    Abstract environment parent class.
+    """Abstract environment parent class.
     Properties:
         - observation_spec (CompositeSpec): sampling spec of the observations;
@@ -243,8 +255,8 @@ def __new__(cls, *args, _inplace_update=False, _batch_locked=True, **kwargs):
     @property
     def batch_locked(self) -> bool:
-        """
-        Whether the environnement can be used with a batch size different from the one it was initialized with or not.
+        """Whether the environment can be used with a batch size different from the one it was initialized with or not.
+
        If True, the env needs to be used with a tensordict having the same batch size as the env.
        batch_locked is an immutable property.
        """
@@ -299,6 +311,7 @@ def observation_spec(self, value: TensorSpec) -> None:
     def step(self, tensordict: TensorDictBase) -> TensorDictBase:
         """Makes a step in the environment.
+
         Step accepts a single argument, tensordict, which usually carries an 'action' key which indicates the action to be taken.
         Step will call an out-place private method, _step, which is the method to be re-written by EnvBase subclasses.
@@ -311,7 +324,6 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
             (+ others if needed).
         """
-
         # sanity check
         self._assert_tensordict_shape(tensordict)
@@ -367,15 +379,17 @@ def reset(
         **kwargs,
     ) -> TensorDictBase:
         """Resets the environment.
-        As for step and _step, only the private method `_reset` should be overwritten by EnvBase subclasses.
+
+        As for step and _step, only the private method :obj:`_reset` should be overwritten by EnvBase subclasses.
         Args:
             tensordict (TensorDictBase, optional): tensordict to be used to contain the resulting new observation.
                 In some cases, this input can also be used to pass argument to the reset function.
-            execute_step (bool, optional): if True, a `step_mdp` is executed on the output TensorDict,
-                hereby removing the `"next_"` prefixes from the keys.
+            execute_step (bool, optional): if True, a :obj:`step_mdp` is executed on the output TensorDict,
+                hereby removing the :obj:`"next_"` prefixes from the keys.
             kwargs (optional): other arguments to be passed to the native reset function.
+
         Returns: a tensordict (or the input tensordict, if any), modified in place
             with the resulting observations.
@@ -419,8 +433,7 @@ def numel(self) -> int:
         return prod(self.batch_size)
     def set_seed(self, seed: int, static_seed: bool = False) -> int:
-        """Sets the seed of the environment and returns the next seed to be used (
-        which is the input seed if a single environment is present)
+        """Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present).
         Args:
             seed (int): seed to be set
@@ -484,9 +497,8 @@ def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBa
     @property
     def specs(self) -> Specs:
-        """
+        """Returns a Specs container where all the environment specs are contained.
- Returns a Specs container where all the environment specs are contained. This feature allows one to create an environment, retrieve all of the specs in a single data container and then erase the environment from the workspace. @@ -513,14 +525,14 @@ def rollout( max_steps (int): maximum number of steps to be executed. The actual number of steps can be smaller if the environment reaches a done state before max_steps have been executed. policy (callable, optional): callable to be called to compute the desired action. If no policy is provided, - actions will be called using `env.rand_step()` + actions will be called using :obj:`env.rand_step()` default = None callback (callable, optional): function to be called at each iteration with the given TensorDict. auto_reset (bool, optional): if True, resets automatically the environment if it is in a done state when the rollout is initiated. - Default is `True`. + Default is :obj:`True`. auto_cast_to_device (bool, optional): if True, the device of the tensordict is automatically cast to the - policy device before the policy is used. Default is `False`. + policy device before the policy is used. Default is :obj:`False`. break_when_any_done (bool): breaks if any of the done state is True. Default is True. return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True. tensordict (TensorDict, optional): if auto_reset is False, an initial @@ -649,11 +661,7 @@ def to(self, device: DEVICE_TYPING) -> EnvBase: return super().to(device) def fake_tensordict(self) -> TensorDictBase: - """ - Returns a fake tensordict with key-value pairs that match in shape, device - and dtype what can be expected during an environment rollout. - - """ + """Returns a fake tensordict with key-value pairs that match in shape, device and dtype what can be expected during an environment rollout.""" input_spec = self.input_spec fake_input = input_spec.zero(self.batch_size) observation_spec = self.observation_spec @@ -678,7 +686,7 @@ def fake_tensordict(self) -> TensorDictBase: class _EnvWrapper(EnvBase, metaclass=abc.ABCMeta): """Abstract environment wrapper class. - Unlike EnvBase, _EnvWrapper comes with a `_build_env` private method that will be called upon instantiation. + Unlike EnvBase, _EnvWrapper comes with a :obj:`_build_env` private method that will be called upon instantiation. Interfaces with other libraries should be coded using _EnvWrapper. It is possible to directly query attributed from the nested environment it its name does not conflict with @@ -807,9 +815,7 @@ def make_tensordict( env: _EnvWrapper, policy: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None, ) -> TensorDictBase: - """ - Returns a zeroed-tensordict with fields matching those required for a full step - (action selection and environment step) in the environment + """Returns a zeroed-tensordict with fields matching those required for a full step (action selection and environment step) in the environment. 
Args: env (_EnvWrapper): environment defining the observation, action and reward space; diff --git a/torchrl/envs/env_creator.py b/torchrl/envs/env_creator.py index 843d9b7418b..cc39ab053c0 100644 --- a/torchrl/envs/env_creator.py +++ b/torchrl/envs/env_creator.py @@ -167,6 +167,7 @@ def env_creator(fun: Callable) -> EnvCreator: def get_env_metadata( env_or_creator: Union[EnvBase, Callable], kwargs: Optional[Dict] = None ): + """Retrieves a EnvMetaData object from an env.""" if isinstance(env_or_creator, (EnvBase,)): return EnvMetaData.build_metadata_from_env(env_or_creator) elif not isinstance(env_or_creator, EnvBase) and not isinstance( diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 1670900f683..ed6d60f81c2 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -21,9 +21,7 @@ class BaseInfoDictReader(metaclass=abc.ABCMeta): - """ - Base class for info-readers. - """ + """Base class for info-readers.""" @abc.abstractmethod def __call__( @@ -37,8 +35,7 @@ def info_spec(self) -> Dict[str, TensorSpec]: class default_info_dict_reader(BaseInfoDictReader): - """ - Default info-key reader. + """Default info-key reader. In cases where keys can be directly written to a tensordict (mostly if they abide to the tensordict shape), one simply needs to indicate the keys to be registered during @@ -100,20 +97,17 @@ def info_spec(self) -> Dict[str, TensorSpec]: class GymLikeEnv(_EnvWrapper): - _info_dict_reader: BaseInfoDictReader - - """ - A gym-like env is an environment whose behaviour is similar to gym environments in what - common methods (specifically reset and step) are expected to do. + """A gym-like env is an environment. + Its behaviour is similar to gym environments in what common methods (specifically reset and step) are expected to do. - A `GymLikeEnv` has a `.step()` method with the following signature: + A :obj:`GymLikeEnv` has a :obj:`.step()` method with the following signature: ``env.step(action: np.ndarray) -> Tuple[Union[np.ndarray, dict], double, bool, *info]`` where the outputs are the observation, reward and done state respectively. In this implementation, the info output is discarded (but specific keys can be read - by updating info_dict_reader, see `set_info_dict_reader` class method). + by updating info_dict_reader, see :obj:`set_info_dict_reader` class method). By default, the first output is written at the "next_observation" key-value pair in the output tensordict, unless the first output is a dictionary. In that case, each observation output will be put at the corresponding @@ -122,14 +116,15 @@ class GymLikeEnv(_EnvWrapper): It is also expected that env.reset() returns an observation similar to the one observed after a step is completed. """ + _info_dict_reader: BaseInfoDictReader + @classmethod def __new__(cls, *args, **kwargs): cls._info_dict_reader = None return super().__new__(cls, *args, _batch_locked=True, **kwargs) def read_action(self, action): - """Reads the action obtained from the input TensorDict and transforms it - in the format expected by the contained environment. + """Reads the action obtained from the input TensorDict and transforms it in the format expected by the contained environment. Args: action (Tensor or TensorDict): an action to be taken in the environment @@ -140,7 +135,9 @@ def read_action(self, action): return self.action_spec.to_numpy(action, safe=False) def read_done(self, done): - """Reads a done state and returns a tuple containing: + """Done state reader. 
+ + Reads a done state and returns a tuple containing: - a done state to be set in the environment - a boolean value indicating whether the frame_skip loop should be broken @@ -151,8 +148,7 @@ def read_done(self, done): return done, done def read_reward(self, total_reward, step_reward): - """Reads a reward and the total reward so far (in the frame skip loop) - and returns a sum of the two. + """Reads a reward and the total reward so far (in the frame skip loop) and returns a sum of the two. Args: total_reward (torch.Tensor or TensorDict): total reward so far in the step @@ -164,8 +160,7 @@ def read_reward(self, total_reward, step_reward): def read_obs( self, observations: Union[Dict[str, Any], torch.Tensor, np.ndarray] ) -> Dict[str, Any]: - """Reads an observation from the environment and returns an observation - compatible with the output TensorDict. + """Reads an observation from the environment and returns an observation compatible with the output TensorDict. Args: observations (observation under a format dictated by the inner env): observation to be read. @@ -250,7 +245,7 @@ def _reset( return tensordict_out def _output_transform(self, step_outputs_tuple: Tuple) -> Tuple: - """To be overwritten when step_outputs differ from Tuple[Observation: Union[np.ndarray, dict], reward: Number, done:Bool]""" + """To be overwritten when step_outputs differ from Tuple[Observation: Union[np.ndarray, dict], reward: Number, done:Bool].""" if not isinstance(step_outputs_tuple, tuple): raise TypeError( f"Expected step_outputs_tuple type to be Tuple but got {type(step_outputs_tuple)}" @@ -258,8 +253,9 @@ def _output_transform(self, step_outputs_tuple: Tuple) -> Tuple: return step_outputs_tuple def set_info_dict_reader(self, info_dict_reader: BaseInfoDictReader) -> GymLikeEnv: - """ - Sets an info_dict_reader function. This function should take as input an + """Sets an info_dict_reader function. + + This function should take as input an info_dict dictionary and the tensordict returned by the step function, and write values in an ad-hoc manner from one to the other. diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index 6267b8880e2..f86c43bc98e 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -101,8 +101,7 @@ def _robust_to_tensor(array: Union[float, np.ndarray]) -> torch.Tensor: class DMControlWrapper(GymLikeEnv): - """ - DeepMind Control lab environment wrapper. + """DeepMind Control lab environment wrapper. Args: env (dm_control.suite env): environment instance @@ -115,6 +114,7 @@ class DMControlWrapper(GymLikeEnv): >>> td = env.rand_step() >>> print(td) >>> print(env.available_envs) + """ git_url = "https://github.com/deepmind/dm_control" @@ -252,8 +252,7 @@ def __repr__(self) -> str: class DMControlEnv(DMControlWrapper): - """ - DeepMind Control lab environment wrapper. + """DeepMind Control lab environment wrapper. 
Args: env_name (str): name of the environment @@ -269,6 +268,7 @@ class DMControlEnv(DMControlWrapper): >>> td = env.rand_step() >>> print(td) >>> print(env.available_envs) + """ def __init__(self, env_name, task_name, **kwargs): diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index dd4d24ad94d..db84d9baf50 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -21,7 +21,7 @@ ) from ...data.utils import numpy_to_torch_dtype_dict from ..gym_like import GymLikeEnv, default_info_dict_reader -from ..utils import classproperty +from ..utils import _classproperty try: import gym @@ -129,8 +129,7 @@ def _is_from_pixels(env): class GymWrapper(GymLikeEnv): - """ - OpenAI Gym environment wrapper. + """OpenAI Gym environment wrapper. Examples: >>> env = gym.make("Pendulum-v0") @@ -138,6 +137,7 @@ class GymWrapper(GymLikeEnv): >>> td = env.rand_step() >>> print(td) >>> print(env.available_envs) + """ git_url = "https://github.com/openai/gym" @@ -186,7 +186,7 @@ def _build_env( env = PixelObservationWrapper(env, pixels_only=pixels_only) return env - @classproperty + @_classproperty def available_envs(cls) -> List[str]: return _get_envs() @@ -266,8 +266,7 @@ def info_dict_reader(self, value: callable): class GymEnv(GymWrapper): - """ - OpenAI Gym environment wrapper. + """OpenAI Gym environment wrapper. Examples: >>> env = GymEnv(env_name="Pendulum-v0", frame_skip=4) diff --git a/torchrl/envs/libs/utils.py b/torchrl/envs/libs/utils.py index bbb716d0f7b..01be86f0fbf 100644 --- a/torchrl/envs/libs/utils.py +++ b/torchrl/envs/libs/utils.py @@ -16,34 +16,32 @@ class GymPixelObservationWrapper(ObservationWrapper): - """Augment observations by pixel values.""" + """Augment observations by pixel values. + + Args: + env: The environment to wrap. + pixels_only: If :obj:`True` (default), the original observation returned + by the wrapped environment will be discarded, and a dictionary + observation will only include pixels. If :obj:`False`, the + observation dictionary will contain both the original + observations and the pixel observations. + render_kwargs: Optional :obj:`dict` containing keyword arguments passed + to the :obj:`self.render` method. + pixel_keys: Optional custom string specifying the pixel + observation's key in the :obj:`OrderedDict` of observations. + Defaults to 'pixels'. + + Raises: + ValueError: If :obj:`env`'s observation spec is not compatible with the + wrapper. Supported formats are a single array, or a dict of + arrays. + ValueError: If :obj:`env`'s observation already contains any of the + specified :obj:`pixel_keys`. + """ def __init__( self, env, pixels_only=True, render_kwargs=None, pixel_keys=("pixels",) ): - """Initializes a new pixel Wrapper. - - Args: - env: The environment to wrap. - pixels_only: If `True` (default), the original observation returned - by the wrapped environment will be discarded, and a dictionary - observation will only include pixels. If `False`, the - observation dictionary will contain both the original - observations and the pixel observations. - render_kwargs: Optional `dict` containing keyword arguments passed - to the `self.render` method. - pixel_keys: Optional custom string specifying the pixel - observation's key in the `OrderedDict` of observations. - Defaults to 'pixels'. - - Raises: - ValueError: If `env`'s observation spec is not compatible with the - wrapper. Supported formats are a single array, or a dict of - arrays. - ValueError: If `env`'s observation already contains any of the - specified `pixel_keys`. 
- """ - super().__init__(env) if render_kwargs is None: diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index 55cf3f38b21..cc0fad356dd 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -136,9 +136,7 @@ def __new__(cls, *args, **kwargs): ) def set_specs_from_env(self, env: EnvBase): - """ - Sets the specs of the environment from the specs of the given environment. - """ + """Sets the specs of the environment from the specs of the given environment.""" self.observation_spec = deepcopy(env.observation_spec).to(self.device) self.reward_spec = deepcopy(env.reward_spec).to(self.device) self.input_spec = deepcopy(env.input_spec).to(self.device) diff --git a/torchrl/envs/transforms/functional.py b/torchrl/envs/transforms/functional.py index 2fb7bff62ad..e2135c3e814 100644 --- a/torchrl/envs/transforms/functional.py +++ b/torchrl/envs/transforms/functional.py @@ -28,6 +28,7 @@ def _assert_channels(img: Tensor, permitted: List[int]) -> None: def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: + """Turns an RGB image into grayscale.""" if img.ndim < 3: raise TypeError( "Input image tensor should have at least 3 dimensions, but found" diff --git a/torchrl/envs/transforms/r3m.py b/torchrl/envs/transforms/r3m.py index a95eb726fff..918cffdf7b8 100644 --- a/torchrl/envs/transforms/r3m.py +++ b/torchrl/envs/transforms/r3m.py @@ -149,7 +149,7 @@ class R3MTransform(Compose): The R3MTransform is created in a lazy manner: the object will be initialized only when an attribute (a spec or the forward method) will be queried. - The reason for this is that the `_init()` method requires some attributes of + The reason for this is that the :obj:`_init()` method requires some attributes of the parent environment (if any) to be accessed: by making the class lazy we can ensure that the following code snippet works as expected: diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 76e24f1748f..350a5bc175b 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -92,18 +92,18 @@ class Transform(nn.Module): the same or another) tensordict as output, where a series of values have been modified or created with a new key. When instantiating a new transform, the keys that are to be read from are passed to the - constructor via the `keys` argument. + constructor via the :obj:`keys` argument. Transforms are to be combined with their target environments with the - TransformedEnv class, which takes as arguments an `EnvBase` instance + TransformedEnv class, which takes as arguments an :obj:`EnvBase` instance and a transform. If multiple transforms are to be used, they can be - concatenated using the `Compose` class. + concatenated using the :obj:`Compose` class. A transform can be stateless or stateful (e.g. CatTransform). Because of - this, Transforms support the `reset` operation, which should reset the + this, Transforms support the :obj:`reset` operation, which should reset the transform to its initial state (such that successive trajectories are kept independent). - Notably, `Transform` subclasses take care of transforming the affected + Notably, :obj:`Transform` subclasses take care of transforming the affected specs from an environment: when querying `transformed_env.observation_spec`, the resulting objects will describe the specs of the transformed_in tensors. 
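Putting the pieces described above together, a minimal sketch of a transformed environment (assuming :obj:`GymEnv` is importable from :obj:`torchrl.envs.libs.gym` and the transforms from :obj:`torchrl.envs.transforms`, as in the :obj:`Compose` example further down in this file; the exact keys and shapes depend on the wrapped environment):

    >>> from torchrl.envs.libs.gym import GymEnv
    >>> from torchrl.envs.transforms import Compose, Resize, ToTensorImage, TransformedEnv
    >>> base_env = GymEnv("Pendulum-v0", from_pixels=True)
    >>> env = TransformedEnv(base_env, Compose(ToTensorImage(), Resize(84, 84)))
    >>> print(env.observation_spec)  # spec describes the transformed (3 x 84 x 84) pixels
    >>> td = env.reset()             # reset() also resets stateful transforms in the chain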
@@ -148,6 +148,7 @@ def init(self, tensordict) -> None: def _apply_transform(self, obs: torch.Tensor) -> None: """Applies the transform to a tensor. + This operation can be called multiple times (if multiples keys of the tensordict match the keys of the transform). @@ -155,10 +156,7 @@ def _apply_transform(self, obs: torch.Tensor) -> None: raise NotImplementedError def _call(self, tensordict: TensorDictBase) -> TensorDictBase: - """Reads the input tensordict, and for the selected keys, applies the - transform. - - """ + """Reads the input tensordict, and for the selected keys, applies the transform.""" self._check_inplace() for key_in, key_out in zip(self.keys_in, self.keys_out): if key_in in tensordict.keys(): @@ -189,8 +187,7 @@ def inv(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: - """Transforms the input spec such that the resulting spec matches - transform mapping. + """Transforms the input spec such that the resulting spec matches transform mapping. Args: input_spec (TensorSpec): spec before the transform @@ -202,8 +199,7 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: return input_spec def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - """Transforms the observation spec such that the resulting spec - matches transform mapping. + """Transforms the observation spec such that the resulting spec matches transform mapping. Args: observation_spec (TensorSpec): spec before the transform @@ -215,8 +211,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec return observation_spec def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: - """Transforms the reward spec such that the resulting spec matches - transform mapping. + """Transforms the reward spec such that the resulting spec matches transform mapping. Args: reward_spec (TensorSpec): spec before the transform @@ -225,7 +220,6 @@ def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: expected spec after the transform """ - return reward_spec def dump(self, **kwargs) -> None: @@ -269,13 +263,12 @@ def empty_cache(self): class TransformedEnv(EnvBase): - """ - A transformed environment. + """A transformed_in environment. Args: env (EnvBase): original environment to be transformed_in. transform (Transform, optional): transform to apply to the tensordict resulting - from `env.step(td)`. If none is provided, an empty Compose + from :obj:`env.step(td)`. If none is provided, an empty Compose placeholder in an eval mode is used. cache_specs (bool, optional): if True, the specs will be cached once and for all after the first call (i.e. 
the specs will be @@ -354,7 +347,7 @@ def _inplace_update(self): @property def observation_spec(self) -> TensorSpec: - """Observation spec of the transformed environment""" + """Observation spec of the transformed environment.""" if self._observation_spec is None or not self.cache_specs: observation_spec = self.transform.transform_observation_spec( deepcopy(self.base_env.observation_spec) @@ -367,13 +360,12 @@ def observation_spec(self) -> TensorSpec: @property def action_spec(self) -> TensorSpec: - """Action spec of the transformed environment""" + """Action spec of the transformed environment.""" return self.input_spec["action"] @property def input_spec(self) -> TensorSpec: - """Action spec of the transformed environment""" - + """Action spec of the transformed environment.""" if self._input_spec is None or not self.cache_specs: input_spec = self.transform.transform_input_spec( deepcopy(self.base_env.input_spec) @@ -386,8 +378,7 @@ def input_spec(self) -> TensorSpec: @property def reward_spec(self) -> TensorSpec: - """Reward spec of the transformed environment""" - + """Reward spec of the transformed environment.""" if self._reward_spec is None or not self.cache_specs: reward_spec = self.transform.transform_reward_spec( deepcopy(self.base_env.reward_spec) @@ -409,7 +400,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict_out def set_seed(self, seed: int, static_seed: bool = False) -> int: - """Set the seeds of the environment""" + """Set the seeds of the environment.""" return self.base_env.set_seed(seed, static_seed=static_seed) def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs): @@ -557,10 +548,7 @@ def __del__(self): class ObservationTransform(Transform): - """ - Abstract class for transformations of the observations. - - """ + """Abstract class for transformations of the observations.""" inplace = False @@ -579,8 +567,7 @@ def __init__( class Compose(Transform): - """ - Composes a chain of transforms. + """Composes a chain of transforms. Examples: >>> env = GymEnv("Pendulum-v0") @@ -694,8 +681,7 @@ def __repr__(self) -> str: class ToTensorImage(ObservationTransform): - """Transforms a numpy-like image (3 x W x H) to a pytorch image - (3 x W x H). + """Transforms a numpy-like image (3 x W x H) to a pytorch image (3 x W x H). Transforms an observation image from a (... x W x H x 3) 0..255 uint8 tensor to a single/double precision floating point (3 x W x H) tensor @@ -765,8 +751,7 @@ def _pixel_observation(self, spec: TensorSpec) -> None: class RewardClipping(Transform): - """ - Clips the reward between `clamp_min` and `clamp_max`. + """Clips the reward between `clamp_min` and `clamp_max`. Args: clip_min (scalar): minimum value of the resulting reward. @@ -828,11 +813,7 @@ def __repr__(self) -> str: class BinarizeReward(Transform): - """ - Maps the reward to a binary value (0 or 1) if the reward is null or - non-null, respectively. - - """ + """Maps the reward to a binary value (0 or 1) if the reward is null or non-null, respectively.""" inplace = True @@ -857,8 +838,7 @@ def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: class Resize(ObservationTransform): - """ - Resizes an pixel observation. + """Resizes an pixel observation. Args: w (int): resulting width @@ -929,7 +909,7 @@ def __repr__(self) -> str: class CenterCrop(ObservationTransform): - """Crops the center of an image + """Crops the center of an image. 
Args: w (int): resulting width @@ -1192,10 +1172,7 @@ def inv(self, tensordict: TensorDictBase) -> TensorDictBase: class GrayScale(ObservationTransform): - """ - Turns a pixel observation to grayscale. - - """ + """Turns a pixel observation to grayscale.""" inplace = False @@ -1224,7 +1201,8 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec class ObservationNorm(ObservationTransform): - """ + """Observation affine transformation layer. + Normalizes an observation according to .. math:: @@ -1388,8 +1366,9 @@ def __repr__(self) -> str: class RewardScaling(Transform): - """ - Affine transform of the reward according to + """Affine transform of the reward. + + The reward is transformed according to: .. math:: reward = reward * scale + loc @@ -1441,11 +1420,7 @@ def __repr__(self) -> str: class FiniteTensorDictCheck(Transform): - """ - This transform will check that all the items of the tensordict are - finite, and raise an exception if they are not. - - """ + """This transform will check that all the items of the tensordict are finite, and raise an exception if they are not.""" inplace = False @@ -1468,8 +1443,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: class DoubleToFloat(Transform): - """ - Maps actions float to double before they are called on the environment. + """Maps actions float to double before they are called on the environment. Examples: >>> td = TensorDict( @@ -1539,8 +1513,8 @@ def __repr__(self) -> str: class CatTensors(Transform): - """ - Concatenates several keys in a single tensor. + """Concatenates several keys in a single tensor. + This is especially useful if multiple keys describe a single state (e.g. "observation_position" and "observation_velocity") @@ -1604,15 +1578,6 @@ def __init__( self.unsqueeze_if_oor = unsqueeze_if_oor def _check_keys_in(self, keys_in, out_key): - # if ( - # ("reward" in keys_in) - # or ("action" in keys_in) - # or ("reward" in keys_in) - # ): - # raise RuntimeError( - # "Concatenating observations and reward / action / done state " - # "is not allowed." - # ) if not out_key.startswith("next_") and all( key.startswith("next_") for key in keys_in ): @@ -1716,8 +1681,7 @@ def __repr__(self) -> str: class DiscreteActionProjection(Transform): - """Projects discrete actions from a high dimensional space to a low - dimensional space. + """Projects discrete actions from a high dimensional space to a low dimensional space. Given a discrete action (from 1 to N) encoded as a one-hot vector and a maximum action index M (with M < N), transforms the action such that @@ -1778,8 +1742,7 @@ def __repr__(self) -> str: class NoopResetEnv(Transform): - """ - Runs a series of random actions when an environment is reset. + """Runs a series of random actions when an environment is reset. Args: env (EnvBase): env on which the random actions have to be @@ -1798,6 +1761,7 @@ class NoopResetEnv(Transform): def __init__(self, noops: int = 30, random: bool = True): """Sample initial states by taking random number of no-ops on reset. + No-op is assumed to be action 0. """ super().__init__([]) @@ -1965,10 +1929,7 @@ def __repr__(self) -> str: class PinMemoryTransform(Transform): - """ - Calls pin_memory on the tensordict to facilitate writing on CUDA devices. - - """ + """Calls pin_memory on the tensordict to facilitate writing on CUDA devices.""" def __init__(self): super().__init__([]) @@ -1984,6 +1945,11 @@ def _sum_left(val, dest): class gSDENoise(Transform): + """A gSDE noise initializer. 
+
+    See the :func:`~torchrl.modules.models.exploration.gSDEModule` for more info.
+    """
+
     inplace = False
     def __init__(
@@ -2021,8 +1987,8 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
 class VecNorm(Transform):
-    """
-    Moving average normalization layer for torchrl environments.
+    """Moving average normalization layer for torchrl environments.
+
     VecNorm keeps track of the summary statistics of a dataset to standardize
     it on-the-fly. If the transform is in 'eval' mode, the running
     statistics are not updated.
@@ -2170,10 +2136,7 @@ def _update(self, key, value, N) -> torch.Tensor:
         return (value - mean) / std.clamp_min(self.eps)
     def to_observation_norm(self) -> Union[Compose, ObservationNorm]:
-        """Converts VecNorm into an ObservationNorm class that can be used
-        at inference time.
-
-        """
+        """Converts VecNorm into an ObservationNorm class that can be used at inference time."""
         out = []
         for key in self.keys_in:
             _sum = self._td.get(key + "_sum")
@@ -2200,8 +2163,7 @@ def build_td_for_shared_vecnorm(
     keys_prefix: Optional[Sequence[str]] = None,
     memmap: bool = False,
 ) -> TensorDictBase:
-    """Creates a shared tensordict that can be sent to different processes
-    for normalization across processes.
+    """Creates a shared tensordict for normalization across processes.
     Args:
         env (EnvBase): example environment to be used to create the
diff --git a/torchrl/envs/transforms/utils.py b/torchrl/envs/transforms/utils.py
index dba5ab1622a..e0ff27ec9b3 100644
--- a/torchrl/envs/transforms/utils.py
+++ b/torchrl/envs/transforms/utils.py
@@ -3,23 +3,20 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import contextlib
 from typing import Callable, Optional, Tuple
 import torch
 from torch.utils._pytree import tree_map
-@contextlib.contextmanager
-def no_dispatch():
-    guard = torch._C._DisableTorchDispatch()
-    try:
-        yield
-    finally:
-        del guard
+class FiniteTensor(torch.Tensor):
+    """A finite tensor.
+    If the data contained in this tensor contains non-finite values (nan or inf),
+    a :obj:`RuntimeError` will be thrown.
+
+    """
-class FiniteTensor(torch.Tensor):
     @staticmethod
     def __new__(cls, elem: torch.Tensor, *args, **kwargs):
         if not torch.isfinite(elem).all():
@@ -40,8 +37,7 @@ def __torch_dispatch__(
     # TODO: also explicitly recheck invariants on inplace/out mutation
         if kwargs:
             raise Exception("Expected empty kwargs")
-        with no_dispatch():
-            rs = func(*args)
+        rs = func(*args)
         return tree_map(
             lambda e: FiniteTensor(e) if isinstance(e, torch.Tensor) else e, rs
         )
diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py
index 7e2f713797e..d58922eb225 100644
--- a/torchrl/envs/utils.py
+++ b/torchrl/envs/utils.py
@@ -13,7 +13,7 @@
 AVAILABLE_LIBRARIES = {pkg.key for pkg in pkg_resources.working_set}
-class classproperty(property):
+class _classproperty(property):
     def __get__(self, cls, owner):
         return classmethod(self.fget).__get__(None, owner)()
@@ -26,23 +26,24 @@ def step_mdp(
     exclude_done: bool = True,
     exclude_action: bool = True,
 ) -> TensorDictBase:
-    """
-    Given a tensordict retrieved after a step, returns another tensordict with all the 'next_' prefixes are removed,
-    i.e. all the `'next_some_other_string'` keys will be renamed onto `'some_other_string'` keys.
+    """Creates a new tensordict that reflects a step in time of the input tensordict.
+
+    Given a tensordict retrieved after a step, returns another tensordict with all the :obj:`'next_'` prefixes removed,
+    i.e.
all the :obj:`'next_some_other_string'` keys will be renamed onto :obj:`'some_other_string'` keys. Args: tensordict (TensorDictBase): tensordict with keys to be renamed next_tensordict (TensorDictBase, optional): destination tensordict - keep_other (bool, optional): if True, all keys that do not start with `'next_'` will be kept. + keep_other (bool, optional): if True, all keys that do not start with :obj:`'next_'` will be kept. Default is True. - exclude_reward (bool, optional): if True, the `"reward"` key will be discarded + exclude_reward (bool, optional): if True, the :obj:`"reward"` key will be discarded from the resulting tensordict. Default is True. - exclude_done (bool, optional): if True, the `"done"` key will be discarded + exclude_done (bool, optional): if True, the :obj:`"done"` key will be discarded from the resulting tensordict. Default is True. - exclude_action (bool, optional): if True, the `"action"` key will be discarded + exclude_action (bool, optional): if True, the :obj:`"action"` key will be discarded from the resulting tensordict. Default is True. @@ -99,67 +100,34 @@ def step_mdp( def get_available_libraries(): - """ - - Returns: - all the supported libraries - - """ + """Returns all the supported libraries.""" return SUPPORTED_LIBRARIES def _check_gym(): - """ - - Returns: - True if the gym library is installed - - """ + """Returns True if the gym library is installed.""" return "gym" in AVAILABLE_LIBRARIES def _check_gym_atari(): - """ - - Returns: - True if the gym library is installed and atari envs can be found. - - """ + """Returns True if the gym library is installed and atari envs can be found.""" if not _check_gym(): return False return "atari-py" in AVAILABLE_LIBRARIES def _check_mario(): - """ - - Returns: - True if the "gym-super-mario-bros" library is installed. - - """ - + """Returns True if the "gym-super-mario-bros" library is installed.""" return "gym-super-mario-bros" in AVAILABLE_LIBRARIES def _check_dmcontrol(): - """ - - Returns: - True if the "dm-control" library is installed. - - """ - + """Returns True if the "dm-control" library is installed.""" return "dm-control" in AVAILABLE_LIBRARIES def _check_dmlab(): - """ - - Returns: - True if the "deepmind-lab" library is installed. - - """ - + """Returns True if the "deepmind-lab" library is installed.""" return "deepmind-lab" in AVAILABLE_LIBRARIES @@ -196,8 +164,7 @@ def _check_dmlab(): class set_exploration_mode(_DecoratorContextManager): - """ - Sets the exploration mode of all ProbabilisticTDModules to the desired mode. + """Sets the exploration mode of all ProbabilisticTDModules to the desired mode. Args: mode (str): mode to use when the policy is being called. @@ -207,6 +174,7 @@ class set_exploration_mode(_DecoratorContextManager): >>> env.rollout(policy=policy, max_steps=100) # rollout with the "mode" interaction mode >>> with set_exploration_mode("random"): >>> env.rollout(policy=policy, max_steps=100) # rollout with the "random" interaction mode + """ def __init__(self, mode: str = "mode"): diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index a32687ecd1a..32fbcc96c39 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -72,9 +72,8 @@ def __call__(self, *args, **kwargs): class _BatchedEnv(EnvBase): - """ + """Batched environments allow the user to query an arbitrary method / attribute of the environment running remotely. - Batched environments allow the user to query an arbitrary method / attribute of the environment running remotely. 
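The renaming rule described in the step_mdp docstring above can be illustrated with a minimal doctest-style sketch (assuming the TensorDict constructor and the torchrl.envs.utils.step_mdp signature shown in this diff; the key names are arbitrary examples):

    >>> import torch
    >>> from torchrl.data import TensorDict
    >>> from torchrl.envs.utils import step_mdp
    >>> td = TensorDict(
    ...     {"next_observation": torch.ones(3), "action": torch.zeros(1), "reward": torch.zeros(1)},
    ...     [],
    ... )
    >>> out = step_mdp(td)
    >>> assert (out.get("observation") == 1).all()  # 'next_observation' was renamed to 'observation'
    >>> assert "action" not in out.keys()           # dropped because exclude_action defaults to True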
Those queries will return a list of length equal to the number of workers containing the values resulting from those queries. >>> env = ParallelEnv(3, my_env_fun) @@ -95,9 +94,9 @@ class _BatchedEnv(EnvBase): drastically decrease the IO burden when the tensordict is placed in shared memory / memory map. env_input_keys will typically contain "action" and if this list is not provided this object will look for corresponding keys. When working with stateless models, it is important to include the - state to be read by the environment. If none is provided, _BatchedEnv will use the `EnvBase.input_spec` + state to be read by the environment. If none is provided, _BatchedEnv will use the :obj:`EnvBase.input_spec` keys as indicators of the keys to be sent to the env. - pin_memory (bool): if True and device is "cpu", calls `pin_memory` on the tensordicts when created. + pin_memory (bool): if True and device is "cpu", calls :obj:`pin_memory` on the tensordicts when created. selected_keys (list of str, optional): keys that have to be returned by the environment. When creating a batch of environment, it might be the case that only some of the keys are to be returned. For instance, if the environment returns 'next_pixels' and 'next_vector', the user might only @@ -112,15 +111,15 @@ class _BatchedEnv(EnvBase): shared_memory (bool): whether or not the returned tensordict will be placed in shared memory; memmap (bool): whether or not the returned tensordict will be placed in memory map. policy_proof (callable, optional): if provided, it'll be used to get the list of - tensors to return through the `step()` and `reset()` methods, such as `"hidden"` etc. + tensors to return through the :obj:`step()` and :obj:`reset()` methods, such as :obj:`"hidden"` etc. device (str, int, torch.device): for consistency, this argument is kept. However this argument should not be passed, as the device will be inferred from the environments. It is assumed that all environments will run on the same device as a common shared tensordict will be used to pass data from process to process. The device can be - changed after instantiation using `env.to(device)`. + changed after instantiation using :obj:`env.to(device)`. allow_step_when_done (bool, optional): if True, batched environments can execute steps after a done state is encountered. - Defaults to `False`. + Defaults to :obj:`False`. """ @@ -535,10 +534,7 @@ def to(self, device: DEVICE_TYPING): class SerialEnv(_BatchedEnv): - """ - Creates a series of environments in the same process. - - """ + """Creates a series of environments in the same process.""" __doc__ += _BatchedEnv.__doc__ @@ -663,8 +659,8 @@ def to(self, device: DEVICE_TYPING): class ParallelEnv(_BatchedEnv): - """ - Creates one environment per process. + """Creates one environment per process. + TensorDicts are passed via shared memory or memory map. """ diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index f4f37424d70..6fa54783c45 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -56,7 +56,7 @@ class IndependentNormal(D.Independent): tanh_loc (bool, optional): if True, the above formula is used for the location scaling, otherwise the raw value is kept. - Default is `True`; + Default is :obj:`True`; """ num_params: int = 2 @@ -87,10 +87,7 @@ def mode(self): class SafeTanhTransform(D.TanhTransform): - """ - TanhTransform subclass that ensured that the transformation is numerically invertible. 
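A short usage sketch for the _BatchedEnv / ParallelEnv behaviour described above, assuming a working gym installation and the GymEnv wrapper; the environment id and worker count are arbitrary examples:

    >>> from torchrl.envs import ParallelEnv
    >>> from torchrl.envs.libs.gym import GymEnv
    >>> env = ParallelEnv(3, lambda: GymEnv("Pendulum-v1"))
    >>> tensordict = env.rollout(max_steps=5)  # expected batch dims: [num_workers, n_steps]
    >>> env.close()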
- - """ + """TanhTransform subclass that ensured that the transformation is numerically invertible.""" def _call(self, x: torch.Tensor) -> torch.Tensor: y = safetanh(x) @@ -182,7 +179,7 @@ class TruncatedNormal(D.Independent): max (torch.Tensor or number, optional): maximum value of the distribution. Default = 1.0; tanh_loc (bool, optional): if True, the above formula is used for the location scaling, otherwise the raw value is kept. - Default is `True`; + Default is :obj:`True`; """ num_params: int = 2 @@ -299,7 +296,7 @@ class TanhNormal(D.TransformedDistribution): event_dims (int, optional): number of dimensions describing the action. Default is 1; tanh_loc (bool, optional): if True, the above formula is used for the location scaling, otherwise the raw - value is kept. Default is `True`; + value is kept. Default is :obj:`True`; """ arg_constraints = { @@ -397,8 +394,7 @@ def mode(self): def uniform_sample_tanhnormal(dist: TanhNormal, size=None) -> torch.Tensor: - """ - Defines what uniform sampling looks like for a TanhNormal distribution. + """Defines what uniform sampling looks like for a TanhNormal distribution. Args: dist (TanhNormal): distribution defining the space where the sampling should occur. @@ -414,8 +410,7 @@ def uniform_sample_tanhnormal(dist: TanhNormal, size=None) -> torch.Tensor: class Delta(D.Distribution): - """ - Delta distribution. + """Delta distribution. Args: param (torch.Tensor): parameter of the delta distribution; @@ -488,8 +483,7 @@ def mean(self) -> torch.Tensor: class TanhDelta(D.TransformedDistribution): - """ - Implements a Tanh transformed_in Delta distribution. + """Implements a Tanh transformed_in Delta distribution. Args: param (torch.Tensor): parameter of the delta distribution; diff --git a/torchrl/modules/distributions/truncated_normal.py b/torchrl/modules/distributions/truncated_normal.py index a742b5fd23b..0cccf120bcf 100644 --- a/torchrl/modules/distributions/truncated_normal.py +++ b/torchrl/modules/distributions/truncated_normal.py @@ -21,9 +21,9 @@ class TruncatedStandardNormal(Distribution): - """ - Truncated Standard Normal distribution - https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + """Truncated Standard Normal distribution. + + Source: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf """ arg_constraints = { @@ -134,8 +134,8 @@ def rsample(self, sample_shape=None): class TruncatedNormal(TruncatedStandardNormal): - """ - Truncated Normal distribution + """Truncated Normal distribution. + https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf """ diff --git a/torchrl/modules/functional_modules.py b/torchrl/modules/functional_modules.py index df07a68f776..6e08a29d741 100644 --- a/torchrl/modules/functional_modules.py +++ b/torchrl/modules/functional_modules.py @@ -37,49 +37,49 @@ def _process_batched_inputs(in_dims, args, func): if not isinstance(in_dims, int) and not isinstance(in_dims, tuple): raise ValueError( - f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " - f"expected `in_dims` to be int or a (potentially nested) tuple " - f"matching the structure of inputs, got: {type(in_dims)}." + f"""vmap({_get_name(func)}, in_dims={in_dims}, ...)(): +expected `in_dims` to be int or a (potentially nested) tuple +matching the structure of inputs, got: {type(in_dims)}.""" ) if len(args) == 0: raise ValueError( - f"vmap({_get_name(func)})(): got no inputs. Maybe you forgot to add " - f"inputs, or you are trying to vmap over a function with no inputs. 
" - f"The latter is unsupported." + f"""vmap({_get_name(func)})(): got no inputs. Maybe you forgot to add +inputs, or you are trying to vmap over a function with no inputs. +The latter is unsupported.""" ) flat_args, args_spec = tree_flatten(args) flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec) if flat_in_dims is None: raise ValueError( - f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " - f"in_dims is not compatible with the structure of `inputs`. " - f"in_dims has structure {tree_flatten(in_dims)[1]} but inputs " - f"has structure {args_spec}." + f"""vmap({_get_name(func)}, in_dims={in_dims}, ...)(): +in_dims is not compatible with the structure of `inputs`. +in_dims has structure {tree_flatten(in_dims)[1]} but inputs +has structure {args_spec}.""" ) for i, (arg, in_dim) in enumerate(zip(flat_args, flat_in_dims)): if not isinstance(in_dim, int) and in_dim is not None: raise ValueError( - f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " - f"Got in_dim={in_dim} for an input but in_dim must be either " - f"an integer dimension or None." + f"""vmap({_get_name(func)}, in_dims={in_dims}, ...)(): +Got in_dim={in_dim} for an input but in_dim must be either +an integer dimension or None.""" ) if isinstance(in_dim, int) and not isinstance( arg, (Tensor, TensorDictBase) ): raise ValueError( - f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " - f"Got in_dim={in_dim} for an input but the input is of type " - f"{type(arg)}. We cannot vmap over non-Tensor arguments, " - f"please use None as the respective in_dim" + f"""vmap({_get_name(func)}, in_dims={in_dims}, ...)(): +Got in_dim={in_dim} for an input but the input is of type +{type(arg)}. We cannot vmap over non-Tensor arguments, +please use None as the respective in_dim""" ) if in_dim is not None and (in_dim < -arg.dim() or in_dim >= arg.dim()): raise ValueError( - f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " - f"Got in_dim={in_dim} for some input, but that input is a Tensor " - f"of dimensionality {arg.dim()} so expected in_dim to satisfy " - f"-{arg.dim()} <= in_dim < {arg.dim()}." + f"""vmap({_get_name(func)}, in_dims={in_dims}, ...)(): +Got in_dim={in_dim} for some input, but that input is a Tensor +of dimensionality {arg.dim()} so expected in_dim to satisfy +-{arg.dim()} <= in_dim < {arg.dim()}.""" ) if in_dim is not None and in_dim < 0: flat_in_dims[i] = in_dim % arg.dim() @@ -168,9 +168,7 @@ def incompatible_error(): class FunctionalModule(nn.Module): - """ - This is the callable object returned by :func:`make_functional`. - """ + """This is the callable object returned by :func:`make_functional`.""" def __init__(self, stateless_model): super(FunctionalModule, self).__init__() @@ -199,9 +197,7 @@ def forward(self, params, *args, **kwargs): class FunctionalModuleWithBuffers(nn.Module): - """ - This is the callable object returned by :func:`make_functional`. 
- """ + """This is the callable object returned by :func:`make_functional`.""" def __init__(self, stateless_model): super(FunctionalModuleWithBuffers, self).__init__() @@ -242,7 +238,8 @@ def forward(self, params, buffers, *args, **kwargs): # Some utils for these -def extract_weights(model): +def extract_weights(model: nn.Module): + """Extracts the weights of a model in a tensordict.""" tensordict = TensorDict({}, []) for name, param in list(model.named_parameters(recurse=False)): setattr(model, name, None) @@ -257,7 +254,8 @@ def extract_weights(model): return None -def extract_buffers(model): +def extract_buffers(model: nn.Module): + """Extracts the buffers of a model in a tensordict.""" tensordict = TensorDict({}, []) for name, param in list(model.named_buffers(recurse=False)): setattr(model, name, None) diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py index 570b6126215..0d94be22bf6 100644 --- a/torchrl/modules/models/exploration.py +++ b/torchrl/modules/models/exploration.py @@ -21,8 +21,9 @@ class NoisyLinear(nn.Linear): - """ - Noisy Linear Layer, as presented in "Noisy Networks for Exploration", https://arxiv.org/abs/1706.10295v3 + """Noisy Linear Layer. + + Presented in "Noisy Networks for Exploration", https://arxiv.org/abs/1706.10295v3 A Noisy Linear Layer is a linear layer with parametric noise added to the weights. This induced stochasticity can be used in RL networks for the agent's policy to aid efficient exploration. The parameters of the noise are learned @@ -41,6 +42,7 @@ class NoisyLinear(nn.Linear): default: None std_init (scalar): initial value of the Gaussian standard deviation before optimization. default: 1.0 + """ def __init__( @@ -145,8 +147,7 @@ def bias(self) -> Optional[torch.Tensor]: class NoisyLazyLinear(LazyModuleMixin, NoisyLinear): - """ - Noisy Lazy Linear Layer. + """Noisy Lazy Linear Layer. This class makes the Noisy Linear layer lazy, in that the in_feature argument does not need to be passed at initialization (but is inferred after the first call to the layer). @@ -162,6 +163,7 @@ class NoisyLazyLinear(LazyModuleMixin, NoisyLinear): default: None std_init (scalar): initial value of the Gaussian standard deviation before optimization. default: 1.0 + """ def __init__( @@ -225,20 +227,21 @@ def bias(self) -> torch.Tensor: def reset_noise(layer: nn.Module) -> None: + """Resets the noise of noisy layers.""" if hasattr(layer, "reset_noise"): layer.reset_noise() class gSDEModule(nn.Module): - """A gSDE exploration module as presented in "Smooth Exploration for - Robotic Reinforcement Learning" by Antonin Raffin, Jens Kober, - Freek Stulp (https://arxiv.org/abs/2005.05719) + """A gSDE exploration module. + + Presented in "Smooth Exploration for Robotic Reinforcement Learning" by Antonin Raffin, Jens Kober, Freek Stulp (https://arxiv.org/abs/2005.05719) gSDEModule adds a state-dependent exploration noise to an input action. It also outputs the mean, scale (standard deviation) of the normal distribution, as well as the Gaussian noise used. - The noise input should be reset through a `torchrl.envs.transforms.gSDENoise` + The noise input should be reset through a :obj:`torchrl.envs.transforms.gSDENoise` instance: each time the environment is reset, the input noise will be set to zero by the environment transform, indicating to gSDEModule that it has to be resampled. 
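The NoisyLinear layer documented above can be exercised with a few lines; this sketch assumes the constructor and reset_noise method defined in this file, and the shapes are arbitrary:

    >>> import torch
    >>> from torchrl.modules.models.exploration import NoisyLinear
    >>> layer = NoisyLinear(4, 8)
    >>> x = torch.randn(2, 4)
    >>> y_before = layer(x)
    >>> layer.reset_noise()  # draw a fresh noise sample for the weight and bias
    >>> y_after = layer(x)
    >>> assert not torch.allclose(y_before, y_after)  # learnable parameters unchanged, effective weights resampled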
This scheme allows us to have the environemt tell the module to resample a @@ -297,6 +300,7 @@ class gSDEModule(nn.Module): >>> action_second_call = tensordict.get("action").clone() >>> assert (action_second_call == action_first_call).all() # actions are the same >>> assert (action_first_call != dist.base_dist.base_dist.loc).all() # actions are truly stochastic + """ def __init__( @@ -402,11 +406,12 @@ def to(self, device_or_dtype: Union[torch.dtype, DEVICE_TYPING]): class LazygSDEModule(LazyModuleMixin, gSDEModule): """Lazy gSDE Module. + This module behaves exactly as gSDEModule except that it does not require the user to specify the action and state dimension. If the input state is multi-dimensional (i.e. more than one state is provided), the - sigma value is initialized such that the resulting variance will match `sigma_init` - (or 1 if no `sigma_init` value is provided). + sigma value is initialized such that the resulting variance will match :obj:`sigma_init` + (or 1 if no :obj:`sigma_init` value is provided). """ diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index b30630a9db5..a359f115f99 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -34,9 +34,8 @@ class MLP(nn.Sequential): - """ + """A multi-layer perceptron. - A multi-layer perceptron. If MLP receives more than one input, it concatenates them all along the last dimension before passing the resulting tensor through the network. This is aimed at allowing for a seamless interface with calls of the type of @@ -270,8 +269,7 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: class ConvNet(nn.Sequential): - """ - A convolutional neural network. + """A convolutional neural network. Args: in_features (int, optional): number of input features; @@ -479,9 +477,9 @@ def _make_net(self, device: Optional[DEVICE_TYPING]) -> nn.Module: class DuelingMlpDQNet(nn.Module): - """ - Creates a Dueling MLP Q-network, as presented in - https://arxiv.org/abs/1511.06581 + """Creates a Dueling MLP Q-network. + + Presented in https://arxiv.org/abs/1511.06581 Args: out_features (int): number of features for the advantage network @@ -564,8 +562,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class DuelingCnnDQNet(nn.Module): - """ - Creates a Dueling CNN Q-network, as presented in https://arxiv.org/abs/1511.06581 + """Dueling CNN Q-network. + + Presented in https://arxiv.org/abs/1511.06581 Args: out_features (int): number of features for the advantage network @@ -637,8 +636,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class DistributionalDQNnet(nn.Module): - """ - Distributional Deep Q-Network. + """Distributional Deep Q-Network. Args: DQNet (nn.Module): Q-Network with output length equal to the number of atoms: @@ -684,8 +682,9 @@ def ddpg_init_last_layer( class DdpgCnnActor(nn.Module): - """ - DDPG Convolutional Actor class, as presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", + """DDPG Convolutional Actor class. + + Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf The DDPG Convolutional Actor takes as input an observation (some simple transformation of the observed pixels) and @@ -717,7 +716,7 @@ class DdpgCnnActor(nn.Module): 'bias_last_layer': True, } use_avg_pooling (bool, optional): if True, a nn.AvgPooling layer is - used to aggregate the output. Default is `False`. + used to aggregate the output. Default is :obj:`False`. 
device (Optional[DEVICE_TYPING]): device to create the module on. """ @@ -769,8 +768,9 @@ def forward(self, observation: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor class DdpgMlpActor(nn.Module): - """ - DDPG Actor class, as presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", + """DDPG Actor class. + + Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf The DDPG Actor takes as input an observation vector and returns an action from it. @@ -816,8 +816,9 @@ def forward(self, observation: torch.Tensor) -> torch.Tensor: class DdpgCnnQNet(nn.Module): - """ - DDPG Convolutional Q-value class, as presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", + """DDPG Convolutional Q-value class. + + Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf The DDPG Q-value network takes as input an observation and an action, and returns a scalar from it. @@ -846,7 +847,7 @@ class DdpgCnnQNet(nn.Module): 'bias_last_layer': True, } use_avg_pooling (bool, optional): if True, a nn.AvgPooling layer is - used to aggregate the output. Default is `True`. + used to aggregate the output. Default is :obj:`True`. device (Optional[DEVICE_TYPING]): device to create the module on. """ @@ -897,8 +898,9 @@ def forward(self, observation: torch.Tensor, action: torch.Tensor) -> torch.Tens class DdpgMlpQNet(nn.Module): - """ - DDPG Q-value MLP class, as presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", + """DDPG Q-value MLP class. + + Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf The DDPG Q-value network takes as input an observation and an action, and returns a scalar from it. diff --git a/torchrl/modules/models/recipes/impala.py b/torchrl/modules/models/recipes/impala.py index 88a8c50362f..01e4ddc200f 100644 --- a/torchrl/modules/models/recipes/impala.py +++ b/torchrl/modules/models/recipes/impala.py @@ -67,7 +67,7 @@ def forward(self, x): return x -class ImpalaNet(nn.Module): +class ImpalaNet(nn.Module): # noqa: D101 def __init__( self, num_actions, @@ -105,7 +105,7 @@ def __init__( self.policy = nn.Linear(core_output_size, self.num_actions) self.baseline = nn.Linear(core_output_size, 1) - def forward(self, x, reward, done, core_state=None, mask=None): + def forward(self, x, reward, done, core_state=None, mask=None): # noqa: D102 if self.batch_first: B, T, *x_shape = x.shape batch_shape = torch.Size([B, T]) @@ -170,10 +170,10 @@ def _allocate_masked_x(self, x, mask): return x_empty -class ImpalaNetTensorDict(ImpalaNet): +class ImpalaNetTensorDict(ImpalaNet): # noqa: D101 observation_key = "pixels" - def forward(self, tensordict: TensorDictBase): + def forward(self, tensordict: TensorDictBase): # noqa: D102 x = tensordict.get(self.observation_key) done = tensordict.get("done").squeeze(-1) reward = tensordict.get("reward").squeeze(-1) diff --git a/torchrl/modules/models/utils.py b/torchrl/modules/models/utils.py index a7fe2b40cad..cc1c4f6057a 100644 --- a/torchrl/modules/models/utils.py +++ b/torchrl/modules/models/utils.py @@ -24,8 +24,8 @@ class SqueezeLayer(nn.Module): - """ - Squeezing layer. + """Squeezing layer. + Squeezes some given singleton dimensions of an input tensor. 
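As a counterpart to the MLP and DDPG network docstrings above, a minimal sketch of the lazily built MLP; the argument names follow the MLP docstring in this patch and the sizes are arbitrary:

    >>> import torch
    >>> from torchrl.modules import MLP
    >>> net = MLP(out_features=2, depth=2, num_cells=32)  # in_features is inferred on the first call
    >>> y = net(torch.randn(5, 10))
    >>> assert y.shape == torch.Size([5, 2])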
Args: @@ -41,7 +41,7 @@ def __init__(self, dims: Sequence[int] = (-1,)): raise RuntimeError("dims must all be < 0") self.dims = dims - def forward(self, input: torch.Tensor) -> torch.Tensor: + def forward(self, input: torch.Tensor) -> torch.Tensor: # noqa: D102 for dim in self.dims: if input.shape[dim] != 1: raise RuntimeError( @@ -52,8 +52,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: class Squeeze2dLayer(SqueezeLayer): - """ - Squeezing layer for convolutional neural networks. + """Squeezing layer for convolutional neural networks. + Squeezes the last two singleton dimensions of an input tensor. """ @@ -63,8 +63,8 @@ def __init__(self): class SquashDims(nn.Module): - """ - A squashing layer. + """A squashing layer. + Flattens the N last dimensions of an input tensor. Args: @@ -82,8 +82,8 @@ def forward(self, value: torch.Tensor) -> torch.Tensor: def _find_depth(depth: Optional[int], *list_or_ints: Sequence): - """ - Find depth based on a sequence of inputs and a depth indicator. + """Find depth based on a sequence of inputs and a depth indicator. + If the depth is None, it is inferred by the length of one (or more) matching lists of integers. Raises an exception if depth does not match the list lengths or if lists lengths @@ -110,8 +110,7 @@ def _find_depth(depth: Optional[int], *list_or_ints: Sequence): def create_on_device( module_class: Type[nn.Module], device: Optional[DEVICE_TYPING], *args, **kwargs ) -> nn.Module: - """ - Create a new instance of `module_class` on `device`. + """Create a new instance of :obj:`module_class` on :obj:`device`. The new instance is created directly on the device if its constructor supports this. @@ -120,6 +119,7 @@ def create_on_device( device (DEVICE_TYPING): device to create the module on. *args: positional arguments to be passed to the module constructor. **kwargs: keyword arguments to be passed to the module constructor. + """ fullargspec = inspect.getfullargspec(module_class.__init__) if "device" in fullargspec.args or "device" in fullargspec.kwonlyargs: diff --git a/torchrl/modules/planners/cem.py b/torchrl/modules/planners/cem.py index 893d9cf2dd7..0cb87f9012c 100644 --- a/torchrl/modules/planners/cem.py +++ b/torchrl/modules/planners/cem.py @@ -32,7 +32,7 @@ class CEMPlanner(MPCPlannerBase): Args: env (EnvBase): The environment to perform the planning step on (can be - `ModelBasedEnv` or `EnvBase`). + `ModelBasedEnv` or :obj:`EnvBase`). planning_horizon (int): The length of the simulated trajectories optim_steps (int): The number of optimization steps used by the MPC planner diff --git a/torchrl/modules/planners/common.py b/torchrl/modules/planners/common.py index 21775d7f038..1cd3c48a0e2 100644 --- a/torchrl/modules/planners/common.py +++ b/torchrl/modules/planners/common.py @@ -17,11 +17,11 @@ class MPCPlannerBase(TensorDictModule, metaclass=abc.ABCMeta): """MPCPlannerBase abstract Module. - This class inherits from `TensorDictModule`. Provided a `TensorDict`, this module will perform a Model Predictive Control (MPC) planning step. - At the end of the planning step, the `MPCPlanner` will return a proposed action. + This class inherits from :obj:`TensorDictModule`. Provided a :obj:`TensorDict`, this module will perform a Model Predictive Control (MPC) planning step. + At the end of the planning step, the :obj:`MPCPlanner` will return a proposed action. Args: - env (EnvBase): The environment to perform the planning step on (Can be `ModelBasedEnvBase` or `EnvBase`). 
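The create_on_device helper documented above can be sketched as follows; nn.Linear stands in for any module class and the device string is only an example:

    >>> from torch import nn
    >>> from torchrl.modules.models.utils import create_on_device
    >>> layer = create_on_device(nn.Linear, "cpu", 3, 4)  # nn.Linear accepts a device keyword, so it is built on the target device
    >>> layer.weight.device
    device(type='cpu')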
+ env (EnvBase): The environment to perform the planning step on (Can be :obj:`ModelBasedEnvBase` or :obj:`EnvBase`). action_key (str, optional): The key that will point to the computed action. """ diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 355b58253b1..cf87317cdfd 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -37,7 +37,7 @@ class Actor(TensorDictModule): The Actor class comes with default values for the out_keys (["action"]) and if the spec is provided but not as a CompositeSpec object, it will be - automatically translated into `spec = CompositeSpec(action=spec)` + automatically translated into :obj:`spec = CompositeSpec(action=spec)` Examples: >>> from torchrl.data import TensorDict, @@ -85,11 +85,11 @@ def __init__( class ProbabilisticActor(ProbabilisticTensorDictModule): - """ - General class for probabilistic actors in RL. + """General class for probabilistic actors in RL. + The Actor class comes with default values for the out_keys (["action"]) and if the spec is provided but not as a CompositeSpec object, it will be - automatically translated into `spec = CompositeSpec(action=spec)` + automatically translated into :obj:`spec = CompositeSpec(action=spec)` Examples: >>> from torchrl.data import TensorDict, NdBoundedTensorSpec @@ -149,8 +149,7 @@ def __init__( class ValueOperator(TensorDictModule): - """ - General class for value functions in RL. + """General class for value functions in RL. The ValueOperator class comes with default values for the in_keys and out_keys arguments (["observation"] and ["state_value"] or @@ -211,8 +210,8 @@ def __init__( class QValueHook: - """ - Q-Value hook for Q-value policies. + """Q-Value hook for Q-value policies. + Given a the output of a regular nn.Module, representing the values of the different discrete actions available, a QValueHook will transform these values into their argmax component (i.e. the resulting greedy action). Currently, this is returned as a one-hot encoding. @@ -408,8 +407,8 @@ def _binary(value: torch.Tensor, support: torch.Tensor) -> torch.Tensor: class QValueActor(Actor): - """ - DQN Actor subclass. + """DQN Actor subclass. + This class hooks the module such that it returns a one-hot encoding of the argmax value. Examples: @@ -448,8 +447,8 @@ def __init__(self, *args, action_space: int = "one_hot", **kwargs): class DistributionalQValueActor(QValueActor): - """ - Distributional DQN Actor subclass. + """Distributional DQN Actor subclass. + This class hooks the module such that it returns a one-hot encoding of the argmax value on its support. Examples: @@ -499,8 +498,7 @@ def __init__( class ActorValueOperator(TensorDictSequential): - """ - Actor-value operator. + """Actor-value operator. This class wraps together an actor and a value model that share a common observation embedding network: @@ -616,25 +614,16 @@ def __init__( ) def get_policy_operator(self) -> TensorDictSequential: - """ - - Returns a stand-alone policy operator that maps an observation to an action. - - """ + """Returns a stand-alone policy operator that maps an observation to an action.""" return TensorDictSequential(self.module[0], self.module[1]) def get_value_operator(self) -> TensorDictSequential: - """ - - Returns a stand-alone value network operator that maps an observation to a value estimate. 
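To make the Actor default-key behaviour above concrete, a short sketch in the style of the existing docstring examples; the spec class and sizes mirror those used elsewhere in this patch:

    >>> import torch
    >>> from torchrl.data import TensorDict, NdBoundedTensorSpec
    >>> from torchrl.modules import Actor
    >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3])
    >>> action_spec = NdBoundedTensorSpec(shape=torch.Size([4]), minimum=-1, maximum=1)
    >>> actor = Actor(module=torch.nn.Linear(4, 4), spec=action_spec)
    >>> td = actor(td)  # reads "observation", writes "action" by default
    >>> assert td.get("action").shape == torch.Size([3, 4])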
- - """ + """Returns a stand-alone value network operator that maps an observation to a value estimate.""" return TensorDictSequential(self.module[0], self.module[2]) class ActorCriticOperator(ActorValueOperator): - """ - Actor-critic operator. + """Actor-critic operator. This class wraps together an actor and a value model that share a common observation embedding network: @@ -757,11 +746,7 @@ def __init__(self, *args, **kwargs): ) def get_critic_operator(self) -> TensorDictModuleWrapper: - """ - - Returns a stand-alone critic network operator that maps a state-action pair to a critic estimate. - - """ + """Returns a stand-alone critic network operator that maps a state-action pair to a critic estimate.""" return self def get_value_operator(self) -> TensorDictModuleWrapper: @@ -773,8 +758,7 @@ def get_value_operator(self) -> TensorDictModuleWrapper: class ActorCriticWrapper(TensorDictSequential): - """ - Actor-value operator without common module. + """Actor-value operator without common module. This class wraps together an actor and a value model that do not share a common observation embedding network: @@ -864,17 +848,9 @@ def __init__( ) def get_policy_operator(self) -> TensorDictSequential: - """ - - Returns a stand-alone policy operator that maps an observation to an action. - - """ + """Returns a stand-alone policy operator that maps an observation to an action.""" return self.module[0] def get_value_operator(self) -> TensorDictSequential: - """ - - Returns a stand-alone value network operator that maps an observation to a value estimate. - - """ + """Returns a stand-alone value network operator that maps an observation to a value estimate.""" return self.module[1] diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 2213e5d7a26..a69e86e46cf 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -85,12 +85,11 @@ def _forward_hook_safe_action(module, tensordict_in, tensordict_out): class TensorDictModule(nn.Module): - """A TensorDictModule, for TensorDict module, is a python wrapper around a `nn.Module` that reads and writes to a - TensorDict, instead of reading and returning tensors. + """A TensorDictModule, is a python wrapper around a :obj:`nn.Module` that reads and writes to a TensorDict. Args: module (nn.Module): a nn.Module used to map the input to the output parameter space. Can be a functional - module (FunctionalModule or FunctionalModuleWithBuffers), in which case the `forward` method will expect + module (FunctionalModule or FunctionalModuleWithBuffers), in which case the :obj:`forward` method will expect the params (and possibly) buffers keyword arguments. in_keys (iterable of str): keys to be read from input tensordict and passed to the module. If it contains more than one element, the values will be passed in the order given by the in_keys iterable. @@ -100,11 +99,11 @@ class TensorDictModule(nn.Module): spec characterize the space of the first output tensor. safe (bool): if True, the value of the output is checked against the input spec. Out-of-domain sampling can occur because of exploration policies or numerical under/overflow issues. - If this value is out of bounds, it is projected back onto the desired space using the `TensorSpec.project` - method. Default is `False`. + If this value is out of bounds, it is projected back onto the desired space using the :obj:`TensorSpec.project` + method. Default is :obj:`False`. 
Embedding a neural network in a TensorDictModule only requires to specify the input and output keys. The domain spec can - be passed along if needed. TensorDictModule support functional and regular `nn.Module` objects. In the functional + be passed along if needed. TensorDictModule support functional and regular :obj:`nn.Module` objects. In the functional case, the 'params' (and 'buffers') keyword argument must be specified: Examples: @@ -383,8 +382,9 @@ def forward( return tensordict_out def random(self, tensordict: TensorDictBase) -> TensorDictBase: - """Samples a random element in the target space, irrespective of any input. If multiple output keys are present, - only the first will be written in the input `tensordict`. + """Samples a random element in the target space, irrespective of any input. + + If multiple output keys are present, only the first will be written in the input :obj:`tensordict`. Args: tensordict (TensorDictBase): tensordict where the output value should be written. @@ -398,7 +398,7 @@ def random(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict def random_sample(self, tensordict: TensorDictBase) -> TensorDictBase: - """see TensorDictModule.random(...)""" + """See :obj:`TensorDictModule.random(...)`.""" return self.random(tensordict) @property @@ -425,8 +425,8 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(\n{fields})" def make_functional_with_buffers(self, clone: bool = True, native: bool = False): - """ - Transforms a stateful module in a functional module and returns its parameters and buffers. + """Transforms a stateful module in a functional module and returns its parameters and buffers. + Unlike functorch.make_functional_with_buffers, this method supports lazy modules. Args: @@ -533,8 +533,8 @@ def num_buffers(self): class TensorDictModuleWrapper(nn.Module): - """ - Wrapper calss for TensorDictModule objects. + """Wrapper calss for TensorDictModule objects. + Once created, a TensorDictModuleWrapper will behave exactly as the TensorDictModule it contains except for the methods that are overwritten. @@ -571,6 +571,7 @@ class TensorDictModuleWrapper(nn.Module): >>> tensordict_module_wrapped = EpsilonGreedyExploration(tensordict_module) >>> tensordict_module_wrapped(td, params=params, buffers=buffers) >>> print(td.get("output")) + """ def __init__(self, td_module: TensorDictModule): diff --git a/torchrl/modules/tensordict_module/deprec.py b/torchrl/modules/tensordict_module/deprec.py deleted file mode 100644 index b0c2b58c271..00000000000 --- a/torchrl/modules/tensordict_module/deprec.py +++ /dev/null @@ -1,384 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -from _warnings import warn -from copy import deepcopy -from typing import Union, Callable, Sequence, Type, Optional, Tuple - -import torch -from torch import Tensor, nn, distributions as d - -from torchrl.data import TensorSpec, DEVICE_TYPING -from torchrl.data.tensordict.tensordict import TensorDictBase -from torchrl.envs.utils import exploration_mode -from torchrl.modules import TensorDictModule, Delta, distributions_maps - - -class ProbabilisticTDModule(TensorDictModule): - """ - DEPRECATED - - A probabilistic TD Module. 
- ProbabilisticTDModule is a special case of a TensorDictModule where the output is sampled given some rule, specified by - the input `default_interaction_mode` argument and the `exploration_mode()` global function. - - A ProbabilisticTDModule instance has two main features: - - It reads and writes TensorDict objects - - It uses a real mapping R^n -> R^m to create a distribution in R^d from which values can be sampled or computed. - When the __call__ / forward method is called, a distribution is created, and a value computed (using the 'mean', - 'mode', 'median' attribute or the 'rsample', 'sample' method). - - By default, ProbabilisticTDModule distribution class is a Delta distribution, making ProbabilisticTDModule a - simple wrapper around a deterministic mapping function (i.e. it can be used interchangeably with its parent - TensorDictModule). - - Args: - module (nn.Module): a nn.Module used to map the input to the output parameter space. Can be a functional - module (FunctionalModule or FunctionalModuleWithBuffers), in which case the `forward` method will expect - the params (and possibly) buffers keyword arguments. - spec (TensorSpec): specs of the first output tensor. Used when calling td_module.random() to generate random - values in the target space. - in_keys (iterable of str): keys to be read from input tensordict and passed to the module. If it - contains more than one element, the values will be passed in the order given by the in_keys iterable. - out_keys (iterable of str): keys to be written to the input tensordict. The length of out_keys must match the - number of tensors returned by the distribution sampling method plus the extra tensors returned by the - module. - distribution_class (Type, optional): a torch.distributions.Distribution class to be used for sampling. - Default is Delta. - distribution_kwargs (dict, optional): kwargs to be passed to the distribution. - default_interaction_mode (str, optional): default method to be used to retrieve the output value. Should be one of: - 'mode', 'median', 'mean' or 'random' (in which case the value is sampled randomly from the distribution). - Default is 'mode'. - Note: When a sample is drawn, the `ProbabilisticTDModule` instance will fist look for the interaction mode - dictated by the `exploration_mode()` global function. If this returns `None` (its default value), - then the `default_interaction_mode` of the `ProbabilisticTDModule` instance will be used. - Note that DataCollector instances will use `set_exploration_mode` to `"random"` by default. - return_log_prob (bool, optional): if True, the log-probability of the distribution sample will be written in the - tensordict with the key `f'{in_keys[0]}_log_prob'`. Default is `False`. - safe (bool, optional): if True, the value of the sample is checked against the input spec. Out-of-domain sampling can - occur because of exploration policies or numerical under/overflow issues. As for the `spec` argument, - this check will only occur for the distribution sample, but not the other tensors returned by the input - module. If the sample is out of bounds, it is projected back onto the desired space using the - `TensorSpec.project` - method. - Default is `False`. - save_dist_params (bool, optional): if True, the parameters of the distribution (i.e. the output of the module) - will be written to the tensordict along with the sample. Those parameters can be used to - re-compute the original distribution later on (e.g. 
to compute the divergence between the distribution - used to sample the action and the updated distribution in PPO). - Default is `False`. - cache_dist (bool, optional): if True, the parameters of the distribution (i.e. the output of the module) - will be written to the tensordict along with the sample. Those parameters can be used to - re-compute the original distribution later on (e.g. to compute the divergence between the distribution - used to sample the action and the updated distribution in PPO). - Default is `False`. - - Examples: - >>> from torchrl.data import TensorDict, NdUnboundedContinuousTensorSpec - >>> from torchrl.modules import TanhNormal - >>> import functorch, torch - >>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,]) - >>> spec = NdUnboundedContinuousTensorSpec(4) - >>> module = torch.nn.GRUCell(4, 8) - >>> module_func, params, buffers = functorch.make_functional_with_buffers(module) - >>> td_module = ProbabilisticTDModule( - ... module=module_func, - ... spec=spec, - ... in_keys=["input"], - ... out_keys=["output"], - ... distribution_class=TanhNormal, - ... return_log_prob=True, - ... ) - >>> _ = td_module(td, params=params, buffers=buffers) - >>> print(td) - TensorDict( - fields={ - input: Tensor(torch.Size([3, 4]), dtype=torch.float32), - hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32), - output: Tensor(torch.Size([3, 4]), dtype=torch.float32), - output_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, - batch_size=torch.Size([3]), - device=cpu, - is_shared=False) - - >>> # In the vmap case, the tensordict is again expended to match the batch: - >>> params = tuple(p.expand(4, *p.shape).contiguous().normal_() for p in params) - >>> buffers = tuple(b.expand(4, *b.shape).contiguous().normal_() for p in buffers) - >>> td_vmap = td_module(td, params=params, buffers=buffers, vmap=True) - >>> print(td_vmap) - TensorDict( - fields={ - input: Tensor(torch.Size([3, 4]), dtype=torch.float32), - hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32), - output: Tensor(torch.Size([3, 4]), dtype=torch.float32), - output_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, - batch_size=torch.Size([3]), - device=cpu, - is_shared=False) - - """ - - def __init__( - self, - module: Union[Callable[[Tensor], Tensor], nn.Module], - spec: TensorSpec, - in_keys: Sequence[str], - out_keys: Sequence[str], - distribution_class: Type = Delta, - distribution_kwargs: Optional[dict] = None, - default_interaction_mode: str = "mode", - _n_empirical_est: int = 1000, - return_log_prob: bool = False, - safe: bool = False, - save_dist_params: bool = False, - cache_dist: bool = False, - ): - warn( - "ProbabilisticTDModule will be deprecated soon, consider using ProbabilisticTensorDictModule instead." 
- ) - super().__init__( - spec=spec, - module=module, - out_keys=out_keys, - in_keys=in_keys, - safe=safe, - ) - - self.save_dist_params = save_dist_params - self._n_empirical_est = _n_empirical_est - self.cache_dist = cache_dist if hasattr(distribution_class, "update") else False - self._dist = None - - if isinstance(distribution_class, str): - distribution_class = distributions_maps.get(distribution_class.lower()) - self.distribution_class = distribution_class - self.distribution_kwargs = ( - distribution_kwargs if distribution_kwargs is not None else dict() - ) - self.return_log_prob = return_log_prob - - self.default_interaction_mode = default_interaction_mode - self.interact = False - - def get_dist( - self, - tensordict: TensorDictBase, - **kwargs, - ) -> Tuple[torch.distributions.Distribution, ...]: - """Calls the module using the tensors retrieved from the 'in_keys' attribute and returns a distribution - using its output. - - Args: - tensordict (TensorDictBase): tensordict with the input values for the creation of the distribution. - - Returns: - a distribution along with other tensors returned by the module. - - """ - tensors = [tensordict.get(key, None) for key in self.in_keys] - out_tensors = self._call_module(tensors, **kwargs) - if isinstance(out_tensors, Tensor): - out_tensors = (out_tensors,) - if self.save_dist_params: - for i, _tensor in enumerate(out_tensors): - tensordict.set(f"{self.out_keys[0]}_dist_param_{i}", _tensor) - dist, num_params = self.build_dist_from_params(out_tensors) - tensors = out_tensors[num_params:] - - return (dist, *tensors) - - def build_dist_from_params( - self, params: Tuple[Tensor, ...] - ) -> Tuple[d.Distribution, int]: - """Given a tuple of temsors, returns a distribution object and the number of parameters used for it. - - Args: - params (Tuple[Tensor, ...]): tensors to be used for the distribution construction. - - Returns: - a distribution object and the number of parameters used for its construction. - - """ - num_params = ( - self.distribution_class.num_params - if hasattr(self.distribution_class, "num_params") - else 1 - ) - if self.cache_dist and self._dist is not None: - self._dist.update(*params[:num_params]) - dist = self._dist - else: - dist = self.distribution_class( - *params[:num_params], **self.distribution_kwargs - ) - if self.cache_dist: - self._dist = dist - return dist, num_params - - def forward( - self, - tensordict: TensorDictBase, - tensordict_out: Optional[TensorDictBase] = None, - **kwargs, - ) -> TensorDictBase: - - dist, *tensors = self.get_dist(tensordict, **kwargs) - out_tensor = self._dist_sample( - dist, *tensors, interaction_mode=exploration_mode() - ) - tensordict_out = self._write_to_tensordict( - tensordict, - [out_tensor] + list(tensors), - tensordict_out, - vmap=kwargs.get("vmap", 0), - ) - if self.return_log_prob: - log_prob = dist.log_prob(out_tensor) - tensordict_out.set("_".join([self.out_keys[0], "log_prob"]), log_prob) - return tensordict_out - - def log_prob(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: - """ - Samples/computes an action using the module and writes this value onto the input tensordict along - with its log-probability. - - Args: - tensordict (TensorDictBase): tensordict containing the in_keys specified in the initializer. - - Returns: - the same tensordict with the out_keys values added/updated as well as a - f"{out_keys[0]}_log_prob" key containing the log-probability of the first output. 
- - """ - dist, *_ = self.get_dist(tensordict, **kwargs) - lp = dist.log_prob(tensordict.get(self.out_keys[0])) - tensordict.set(self.out_keys[0] + "_log_prob", lp) - return tensordict - - def _dist_sample( - self, - dist: d.Distribution, - *tensors: Tensor, - interaction_mode: bool = None, - eps: float = None, - ) -> Tensor: - if interaction_mode is None: - interaction_mode = self.default_interaction_mode - if not isinstance(dist, d.Distribution): - raise TypeError(f"type {type(dist)} not recognised by _dist_sample") - - if interaction_mode == "mode": - if hasattr(dist, "mode"): - return dist.mode - else: - raise NotImplementedError( - f"method {type(dist)}.mode is not implemented" - ) - - elif interaction_mode == "median": - if hasattr(dist, "median"): - return dist.median - else: - raise NotImplementedError( - f"method {type(dist)}.median is not implemented" - ) - - elif interaction_mode == "mean": - try: - return dist.mean - except AttributeError or NotImplementedError: - if dist.has_rsample: - return dist.rsample((self._n_empirical_est,)).mean(0) - else: - return dist.sample((self._n_empirical_est,)).mean(0) - - elif interaction_mode == "random": - if dist.has_rsample: - return dist.rsample() - else: - return dist.sample() - elif interaction_mode == "net_output": - if len(tensors) > 1: - raise RuntimeError( - "Multiple values passed to _dist_sample when trying to return a single action " - "tensor." - ) - return tensors[0] - else: - raise NotImplementedError(f"unknown interaction_mode {interaction_mode}") - - @property - def device(self): - for p in self.parameters(): - return p.device - return torch.device("cpu") - - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> ProbabilisticTDModule: - if self.spec is not None: - self.spec = self.spec.to(dest) - out = super().to(dest) - return out - - def __deepcopy__(self, memodict=None): - if memodict is None: - memodict = dict() - self._dist = None - cls = self.__class__ - result = cls.__new__(cls) - memodict[id(self)] = result - for k, v in self.__dict__.items(): - setattr(result, k, deepcopy(v, memodict)) - return result - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(module={self.module}, distribution_class={self.distribution_class}, device={self.device})" - - -class ProbabilisticActor_deprecated(ProbabilisticTDModule): - """ - General class for probabilistic actors in RL. - The Actor class comes with default values for the in_keys and out_keys - arguments (["observation"] and ["action"], respectively). - - Examples: - >>> from torchrl.data import TensorDict, NdBoundedTensorSpec - >>> from torchrl.modules import Actor, TanhNormal - >>> import torch, functorch - >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,]) - >>> action_spec = NdBoundedTensorSpec(shape=torch.Size([4]), - ... minimum=-1, maximum=1) - >>> module = torch.nn.Linear(4, 8) - >>> fmodule, params, buffers = functorch.make_functional_with_buffers( - ... module) - >>> td_module = ProbabilisticActor_deprecated( - ... module=fmodule, - ... spec=action_spec, - ... distribution_class=TanhNormal, - ... 
) - >>> td_module(td, params=params, buffers=buffers) - >>> print(td.get("action")) - - """ - - def __init__( - self, - *args, - in_keys: Optional[Sequence[str]] = None, - out_keys: Optional[Sequence[str]] = None, - **kwargs, - ): - if in_keys is None: - in_keys = ["observation"] - if out_keys is None: - out_keys = ["action"] - - super().__init__( - *args, - in_keys=in_keys, - out_keys=out_keys, - **kwargs, - ) diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 29d05a2f484..927566537cc 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -27,8 +27,7 @@ class EGreedyWrapper(TensorDictModuleWrapper): - """ - Epsilon-Greedy PO wrapper. + """Epsilon-Greedy PO wrapper. Args: policy (TensorDictModule): a deterministic policy. @@ -96,6 +95,7 @@ def __init__( def step(self, frames: int = 1) -> None: """A step of epsilon decay. + After self.annealing_num_steps, this function is a no-op. Args: @@ -135,8 +135,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: class AdditiveGaussianWrapper(TensorDictModuleWrapper): - """ - Additive Gaussian PO wrapper. + """Additive Gaussian PO wrapper. Args: policy (TensorDictModule): a policy. @@ -145,7 +144,7 @@ class AdditiveGaussianWrapper(TensorDictModuleWrapper): sigma_end (scalar, optional): final epsilon value. default: 0.1 annealing_num_steps (int, optional): number of steps it will take for - sigma to reach the `sigma_end` value. + sigma to reach the :obj:`sigma_end` value. action_key (str, optional): if the policy module has more than one output key, its output spec will be of type CompositeSpec. One needs to know where to find the action spec. @@ -189,6 +188,7 @@ def __init__( def step(self, frames: int = 1) -> None: """A step of sigma decay. + After self.annealing_num_steps, this function is a no-op. Args: @@ -230,9 +230,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper): - """ - Ornstein-Uhlenbeck exploration policy wrapper as presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", - https://arxiv.org/pdf/1509.02971.pdf. + """Ornstein-Uhlenbeck exploration policy wrapper. + + Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf. The OU exploration is to be used with continuous control policies and introduces a auto-correlated exploration noise. This enables a sort of 'structured' exploration. @@ -242,7 +242,7 @@ class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper): Sigma equation: current_sigma = (-(sigma - sigma_min) / (n_steps_annealing) * n_steps + sigma).clamp_min(sigma_min) - To keep track of the steps and noise from sample to sample, an `"ou_prev_noise{id}"` and `"ou_steps{id}"` keys + To keep track of the steps and noise from sample to sample, an :obj:`"ou_prev_noise{id}"` and :obj:`"ou_steps{id}"` keys will be written in the input/output tensordict. It is expected that the tensordict will be zeroed at reset, indicating that a new trajectory is being collected. If not, and is the same tensordict is used for consecutive trajectories, the step count will keep on increasing across rollouts. Note that the collector classes take care of @@ -273,7 +273,7 @@ class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper): key (str): key of the action to be modified. 
default: "action" safe (bool): if True, actions that are out of bounds given the action specs will be projected in the space - given the `TensorSpec.project` heuristic. + given the :obj:`TensorSpec.project` heuristic. default: True Examples: diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 58c0dbae980..99f4ef7f3c2 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -21,23 +21,23 @@ class ProbabilisticTensorDictModule(TensorDictModule): - """ - A probabilistic TD Module. + """A probabilistic TD Module. + `ProbabilisticTDModule` is a special case of a TDModule where the output is - sampled given some rule, specified by the input `default_interaction_mode` - argument and the `exploration_mode()` global function. + sampled given some rule, specified by the input :obj:`default_interaction_mode` + argument and the :obj:`exploration_mode()` global function. It consists in a wrapper around another TDModule that returns a tensordict - updated with the distribution parameters. `ProbabilisticTensorDictModule` is - responsible for constructing the distribution (through the `get_dist()` method) - and/or sampling from this distribution (through a regular `__call__()` to the + updated with the distribution parameters. :obj:`ProbabilisticTensorDictModule` is + responsible for constructing the distribution (through the :obj:`get_dist()` method) + and/or sampling from this distribution (through a regular :obj:`__call__()` to the module). - A `ProbabilisticTensorDictModule` instance has two main features: + A :obj:`ProbabilisticTensorDictModule` instance has two main features: - It reads and writes TensorDict objects - It uses a real mapping R^n -> R^m to create a distribution in R^d from which values can be sampled or computed. - When the `__call__` / `forward` method is called, a distribution is created, + When the :obj:`__call__` / :obj:`forward` method is called, a distribution is created, and a value computed (using the 'mean', 'mode', 'median' attribute or the 'rsample', 'sample' method). The sampling step is skipped if the inner TDModule has already created the desired key-value pair. @@ -48,30 +48,30 @@ class ProbabilisticTensorDictModule(TensorDictModule): Args: module (nn.Module): a nn.Module used to map the input to the output parameter space. Can be a functional - module (FunctionalModule or FunctionalModuleWithBuffers), in which case the `forward` method will expect + module (FunctionalModule or FunctionalModuleWithBuffers), in which case the :obj:`forward` method will expect the params (and possibly) buffers keyword arguments. dist_param_keys (str or iterable of str or dict): key(s) that will be produced by the inner TDModule and that will be used to build the distribution. Importantly, if it's an iterable of string or a string, those keys must match the keywords used by the distribution - class of interest, e.g. `"loc"` and `"scale"` for the Normal distribution + class of interest, e.g. :obj:`"loc"` and :obj:`"scale"` for the Normal distribution and similar. If dist_param_keys is a dictionary,, the keys are the keys of the distribution and the values are the keys in the tensordict that will get match to the corresponding distribution keys. out_key_sample (str or iterable of str): keys where the sampled values will be - written. Importantly, if this key is part of the `out_keys` of the inner model, + written. 
Importantly, if this key is part of the :obj:`out_keys` of the inner model, the sampling step will be skipped. spec (TensorSpec): specs of the first output tensor. Used when calling td_module.random() to generate random values in the target space. safe (bool, optional): if True, the value of the sample is checked against the input spec. Out-of-domain sampling can - occur because of exploration policies or numerical under/overflow issues. As for the `spec` argument, + occur because of exploration policies or numerical under/overflow issues. As for the :obj:`spec` argument, this check will only occur for the distribution sample, but not the other tensors returned by the input module. If the sample is out of bounds, it is projected back onto the desired space using the `TensorSpec.project` method. - Default is `False`. + Default is :obj:`False`. default_interaction_mode (str, optional): default method to be used to retrieve the output value. Should be one of: 'mode', 'median', 'mean' or 'random' (in which case the value is sampled randomly from the distribution). Default is 'mode'. - Note: When a sample is drawn, the `ProbabilisticTDModule` instance will fist look for the interaction mode + Note: When a sample is drawn, the :obj:`ProbabilisticTDModule` instance will fist look for the interaction mode dictated by the `exploration_mode()` global function. If this returns `None` (its default value), then the `default_interaction_mode` of the `ProbabilisticTDModule` instance will be used. Note that DataCollector instances will use `set_exploration_mode` to `"random"` by default. diff --git a/torchrl/modules/tensordict_module/sequence.py b/torchrl/modules/tensordict_module/sequence.py index ec39a7421ce..ae85f5e6279 100644 --- a/torchrl/modules/tensordict_module/sequence.py +++ b/torchrl/modules/tensordict_module/sequence.py @@ -39,23 +39,23 @@ class TensorDictSequential(TensorDictModule): - """ - A sequence of TDModules. - Similarly to `nn.Sequence` which passes a tensor through a chain of mappings that read and write a single tensor + """A sequence of TensorDictModules. + + Similarly to :obj:`nn.Sequence` which passes a tensor through a chain of mappings that read and write a single tensor each, this module will read and write over a tensordict by querying each of the input modules. - When calling a `TDSequence` instance with a functional module, it is expected that the parameter lists (and + When calling a :obj:`TensorDictSequencial` instance with a functional module, it is expected that the parameter lists (and buffers) will be concatenated in a single list. Args: - modules (iterable of TDModules): ordered sequence of TDModule instances to be run sequentially. + modules (iterable of TensorDictModules): ordered sequence of TensorDictModule instances to be run sequentially. partial_tolerant (bool, optional): if True, the input tensordict can miss some of the input keys. If so, the only module that will be executed are those who can be executed given the keys that are present. - Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is `True` AND if the + Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is :obj:`True` AND if the stack does not have the required keys, then TensorDictSequential will scan through the sub-tensordicts looking for those that have the required keys, if any. 
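To make the chaining behaviour described above concrete, here is a minimal sketch of a two-module TensorDictSequential built from plain PyTorch layers. The import paths mirror the Examples shown in these docstrings (torchrl.modules.tensordict_module and torchrl.data at this stage of the codebase); treat it as an illustration of the read/write-keys mechanism rather than canonical usage.

    import torch
    from torch import nn
    from torchrl.data import TensorDict
    from torchrl.modules.tensordict_module import TensorDictModule, TensorDictSequential

    # Each module reads and writes keys of a shared tensordict:
    # the first writes "hidden", the second consumes it and writes "action".
    backbone = TensorDictModule(nn.Linear(4, 8), in_keys=["observation"], out_keys=["hidden"])
    head = TensorDictModule(nn.Linear(8, 2), in_keys=["hidden"], out_keys=["action"])
    seq = TensorDictSequential(backbone, head)

    td = TensorDict({"observation": torch.randn(3, 4)}, batch_size=[3])
    seq(td)  # runs both modules in order, writing "hidden" and "action" into td
    print(td.get("action").shape)  # torch.Size([3, 2])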
- TDSequence supports functional, modular and vmap coding: + TensorDictSequential supports functional, modular and vmap coding: Examples: >>> from torchrl.modules.tensordict_module import ProbabilisticTensorDictModule >>> from torchrl.data import TensorDict, NdUnboundedContinuousTensorSpec @@ -230,9 +230,7 @@ def _split_param( def select_subsequence( self, in_keys: Iterable[str] = None, out_keys: Iterable[str] = None ) -> "TensorDictSequential": - """ - Returns a new TensorDictSequential with only the modules that are necessary to compute - the given output keys with the given input keys. + """Returns a new TensorDictSequential with only the modules that are necessary to compute the given output keys with the given input keys. Args: in_keys: input keys of the subsequence we want to select @@ -377,8 +375,8 @@ def __delitem__(self, index: Union[int, slice]) -> None: self.module.__delitem__(idx=index) def make_functional_with_buffers(self, clone: bool = True, native: bool = False): - """ - Transforms a stateful module in a functional module and returns its parameters and buffers. + """Transforms a stateful module into a functional module and returns its parameters and buffers. + Unlike functorch.make_functional_with_buffers, this method supports lazy modules. Args: diff --git a/torchrl/modules/tensordict_module/world_models.py b/torchrl/modules/tensordict_module/world_models.py index 304359a1e60..64ee3bc59e7 100644 --- a/torchrl/modules/tensordict_module/world_models.py +++ b/torchrl/modules/tensordict_module/world_models.py @@ -33,13 +33,9 @@ def __init__( ) def get_transition_model_operator(self) -> TensorDictSequential: - """ - Returns a transition operator that maps either an observation to a world state or a world state to the next world state. - """ + """Returns a transition operator that maps either an observation to a world state or a world state to the next world state.""" return self.module[0] def get_reward_operator(self) -> TensorDictSequential: - """ - Returns a reward operator that maps a world state to a reward. - """ + """Returns a reward operator that maps a world state to a reward.""" return self.module[1] diff --git a/torchrl/modules/utils/mappings.py b/torchrl/modules/utils/mappings.py index 3f34fe37bf2..fb279594c61 100644 --- a/torchrl/modules/utils/mappings.py +++ b/torchrl/modules/utils/mappings.py @@ -12,9 +12,10 @@ def inv_softplus(bias: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: - """ - inverse softplus function. + """Inverse softplus function. + Args: + bias (float or tensor): the value to be softplus-inverted. """ is_tensor = True if not isinstance(bias, torch.Tensor): @@ -27,8 +28,11 @@ def inv_softplus(bias: Union[float, torch.Tensor]) -> Union[float, torch.Tensor] class biased_softplus(nn.Module): - """ - A biased softplus layer. + """A biased softplus module. + + The bias indicates the value that is to be returned when a zero-tensor is + passed through the transform. + Args: bias (scalar): 'bias' of the softplus transform. If bias=1.0, then a _bias shift will be computed such that softplus(0.0 + _bias) = bias. @@ -46,9 +50,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def expln(x): - """ - A smooth, continuous positive mapping presented in "State-Dependent - Exploration for Policy Gradient Methods" + """A smooth, continuous positive mapping presented in "State-Dependent Exploration for Policy Gradient Methods".
+ https://people.idsia.ch/~juergen/ecml2008rueckstiess.pdf """ @@ -60,8 +63,7 @@ def expln(x): def mappings(key: str) -> Callable: - """ - Given an input string, return a surjective function f(x): R -> R^+ + """Given an input string, returns a surjective function f(x): R -> R^+. Args: key (str): one of "softplus", "exp", "relu", "expln", @@ -71,6 +73,7 @@ def mappings(key: str) -> Callable: Alternatively, the ```"biased_softplus_{bias}_{min_val}"``` syntax can be used. In that case, the additional ```min_val``` term is a floating point number that will be used to encode the minimum value of the softplus transform. In practice, the equation used is softplus(x + bias) + min_val, where bias and min_val are values computed such that the conditions above are met. + Returns: a Callable diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index 0663f6e5807..d7f6806ae4e 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -17,4 +17,5 @@ next_state_value, hold_out_net, ) -from .value import bellman_max, c_val, dv_val, vtrace, GAE, TDLambdaEstimate, TDEstimate + +# from .value import bellman_max, c_val, dv_val, vtrace, GAE, TDLambdaEstimate, TDEstimate diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index ef7c5fa6daa..cc438724c90 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -18,12 +18,13 @@ class LossModule(nn.Module): - """ - A parent class for RL losses. + """A parent class for RL losses. + LossModule inherits from nn.Module. It is designed to read an input TensorDict and return another tensordict with loss keys named "loss_*". Splitting the loss in its component can then be used by the trainer to log the various loss values throughout training. Other scalars present in the output tensordict will be logged too. + """ def __init__(self): @@ -31,8 +32,8 @@ def __init__(self): self._param_maps = dict() def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - """It is designed to read an input TensorDict and return another tensordict - with loss keys named "loss*". + """It is designed to read an input TensorDict and return another tensordict with loss keys named "loss*". + Splitting the loss in its component can then be used by the trainer to log the various loss values throughout training. Other scalars present in the output tensordict will be logged too. @@ -43,6 +44,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: A new tensordict with no batch dimension containing various loss scalars which will be named "loss*". It is essential that the losses are returned with this name as they will be read by the trainer before backpropagation. + """ raise NotImplementedError diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 2c61a5bf0a0..8e2feeb0375 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -22,8 +22,8 @@ class DDPGLoss(LossModule): - """ - The DDPG Loss class. + """The DDPG Loss class. + Args: actor_network (TensorDictModule): a policy operator. value_network (TensorDictModule): a Q value operator. @@ -32,9 +32,9 @@ class DDPGLoss(LossModule): via the value operator. loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". delay_actor (bool, optional): whether to separate the target actor networks from the actor networks used for - data collection. Default is `False`. + data collection. Default is :obj:`False`. 
delay_value (bool, optional): whether to separate the target value networks from the value networks used for - data collection. Default is `False`. + data collection. Default is :obj:`False`. """ def __init__( @@ -68,6 +68,7 @@ def __init__( def forward(self, input_tensordict: TensorDictBase) -> TensorDict: """Computes the DDPG losses given a tensordict sampled from the replay buffer. + This function will also write a "td_error" key that can be used by prioritized replay buffers to assign a priority to items in the tensordict. diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 1228c3c649e..13df1c3b616 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -19,8 +19,8 @@ class REDQLoss_deprecated(LossModule): - """ - REDQ Loss module. + """REDQ Loss module. + REDQ (RANDOMIZED ENSEMBLED DOUBLE Q-LEARNING: LEARNING FAST WITHOUT A MODEL https://openreview.net/pdf?id=AY8zfZm0tDd) generalizes the idea of using an ensemble of Q-value functions to train a SAC-like algorithm. @@ -42,7 +42,7 @@ class REDQLoss_deprecated(LossModule): Default is 0.1. max_alpha (float, optional): max value of alpha. Default is 10.0. - fixed_alpha (bool, optional): whether alpha should be trained to match a target entropy. Default is `False`. + fixed_alpha (bool, optional): whether alpha should be trained to match a target entropy. Default is :obj:`False`. target_entropy (Union[str, Number], optional): Target entropy for the stochastic policy. Default is "auto". """ @@ -266,4 +266,6 @@ def _loss_alpha(self, log_pi: Tensor) -> Tensor: class DoubleREDQLoss_deprecated(REDQLoss_deprecated): + """[Deprecated] Class for delayed target-REDQ (which should be the default behaviour).""" + delay_qvalue: bool = True diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index e6e2b8c89cb..882a4056c7c 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -17,14 +17,15 @@ class DQNLoss(LossModule): - """ - The DQN Loss class. + """The DQN Loss class. + Args: value_network (ProbabilisticTDModule): a Q value operator. gamma (scalar): a discount factor for return computation. loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". delay_value (bool, optional): whether to duplicate the value network into a new target value network to - create a double DQN. Default is `False`. + create a double DQN. Default is :obj:`False`. + """ def __init__( @@ -54,8 +55,8 @@ def __init__( self.priority_key = priority_key def forward(self, input_tensordict: TensorDictBase) -> TensorDict: - """ - Computes the DQN loss given a tensordict sampled from the replay buffer. + """Computes the DQN loss given a tensordict sampled from the replay buffer. + This function will also write a "td_error" key that can be used by prioritized replay buffers to assign a priority to items in the tensordict. @@ -67,7 +68,6 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: a tensor containing the DQN loss. """ - device = self.device if self.device is not None else input_tensordict.device tensordict = input_tensordict.to(device) if tensordict.device != device: @@ -122,8 +122,8 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: class DistributionalDQNLoss(LossModule): - """ - A distributional DQN loss class. + """A distributional DQN loss class. 
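The forward pass documented above regresses the Q-value of the taken action onto a bootstrapped target and exposes the residual as a "td_error" entry. A back-of-the-envelope sketch of that quantity in plain PyTorch (function and tensor names are made up for the example; this is not the DQNLoss code itself):

    import torch

    def dqn_td_error(q_values, action, reward, done, next_q_values_target, gamma=0.99):
        # q_values:             [B, n_actions] online network output for s_t
        # next_q_values_target: [B, n_actions] target network output for s_{t+1}
        # action:               [B] integer action indices taken at s_t
        q_taken = q_values.gather(-1, action.unsqueeze(-1)).squeeze(-1)
        next_value = next_q_values_target.max(dim=-1).values
        target = reward + gamma * (1 - done.float()) * next_value
        return target.detach() - q_taken  # what a priority would typically be derived from

    q = torch.randn(8, 4)
    q_next = torch.randn(8, 4)
    a = torch.randint(0, 4, (8,))
    td_err = dqn_td_error(q, a, torch.randn(8), torch.zeros(8, dtype=torch.bool), q_next)
    print(td_err.shape)  # torch.Size([8])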
+ Distributional DQN uses a value network that outputs a distribution of values over a discrete support of discounted returns (unlike regular DQN where the value network outputs a single point prediction of the diff --git a/torchrl/objectives/functional.py b/torchrl/objectives/functional.py index 99f9a439647..2d0554d1c78 100644 --- a/torchrl/objectives/functional.py +++ b/torchrl/objectives/functional.py @@ -9,8 +9,8 @@ def cross_entropy_loss( log_policy: torch.Tensor, action: torch.Tensor, inplace: bool = False ) -> torch.Tensor: - """ - Returns the cross entropy loss defined as the log-softmax value indexed by the action index. + """Returns the cross entropy loss defined as the log-softmax value indexed by the action index. + Supports discrete (integer) actions or one-hot encodings. Args: @@ -21,8 +21,6 @@ def cross_entropy_loss( This is usually faster but it will change the value of log-policy in place, which may lead to unwanted behaviours. - Returns: - """ if action.shape == log_policy.shape: if action.dtype not in (torch.bool, torch.long, torch.uint8): diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 9542365bbd5..7b9be16d3b9 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -17,8 +17,7 @@ class PPOLoss(LossModule): - """ - A parent PPO loss class. + """A parent PPO loss class. PPO (Proximal Policy Optimisation) is a model-free, online RL algorithm that makes use of a recorded (batch of) trajectories to perform several optimization steps, while actively preventing the updated policy to deviate too @@ -159,8 +158,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: class ClipPPOLoss(PPOLoss): - """ - Clipped PPO loss. + """Clipped PPO loss. The clipped importance weighted loss is computed as follows: loss = -min( weight * advantage, min(max(weight, 1-eps), 1+eps) * advantage) @@ -264,8 +262,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: class KLPENPPOLoss(PPOLoss): - """ - KL Penalty PPO loss. + """KL Penalty PPO loss. The KL penalty loss has the following formula: loss = loss - beta * KL(old_policy, new_policy) diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index 6ea013d40ea..2ddf72d9657 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -23,8 +23,7 @@ class REDQLoss(LossModule): - """ - REDQ Loss module. + """REDQ Loss module. REDQ (RANDOMIZED ENSEMBLED DOUBLE Q-LEARNING: LEARNING FAST WITHOUT A MODEL https://openreview.net/pdf?id=AY8zfZm0tDd) generalizes the idea of using an ensemble of Q-value functions to @@ -47,10 +46,10 @@ class REDQLoss(LossModule): Default is 0.1. max_alpha (float, optional): max value of alpha. Default is 10.0. - fixed_alpha (bool, optional): whether alpha should be trained to match a target entropy. Default is `False`. + fixed_alpha (bool, optional): whether alpha should be trained to match a target entropy. Default is :obj:`False`. target_entropy (Union[str, Number], optional): Target entropy for the stochastic policy. Default is "auto". delay_qvalue (bool, optional): Whether to separate the target Q value networks from the Q value networks used - for data collection. Default is `False`. + for data collection. Default is :obj:`False`. gSDE (bool, optional): Knowing if gSDE is used is necessary to create random noise variables. 
Default is False diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index f3473dc29b8..2c593f978f1 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -10,8 +10,9 @@ class ReinforceLoss(LossModule): - """Reinforce loss module, as presented in - "Simple statistical gradient-following algorithms for connectionist reinforcement learning", Williams, 1992 + """Reinforce loss module. + + Presented in "Simple statistical gradient-following algorithms for connectionist reinforcement learning", Williams, 1992 https://doi.org/10.1007/BF00992696 """ diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index d49b80095a4..396b02af861 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -23,8 +23,9 @@ class SACLoss(LossModule): - """ - TorchRL implementation of the SAC loss, as presented in "Soft Actor-Critic: Off-Policy Maximum Entropy Deep + """TorchRL implementation of the SAC loss. + + Presented in "Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor" https://arxiv.org/pdf/1801.01290.pdf Args: @@ -51,19 +52,19 @@ class SACLoss(LossModule): fixed_alpha (bool, optional): if True, alpha will be fixed to its initial value. Otherwise, alpha will be optimized to match the 'target_entropy' value. - Default is `False`. + Default is :obj:`False`. target_entropy (float or str, optional): Target entropy for the stochastic policy. Default is "auto", where target entropy is - computed as `-prod(n_actions)`. + computed as :obj:`-prod(n_actions)`. delay_actor (bool, optional): Whether to separate the target actor networks from the actor networks used for data collection. - Default is `False`. + Default is :obj:`False`. delay_qvalue (bool, optional): Whether to separate the target Q value networks from the Q value networks used for data collection. - Default is `False`. + Default is :obj:`False`. delay_value (bool, optional): Whether to separate the target value networks from the value networks used for data collection. - Default is `False`. + Default is :obj:`False`. """ def __init__( diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index e5b88813988..19e85d847fe 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -44,15 +44,14 @@ def distance_loss( loss_function: str, strict_shape: bool = True, ) -> torch.Tensor: - """ - Computes a distance loss between two tensors. + """Computes a distance loss between two tensors. Args: v1 (Tensor): a tensor with a shape compatible with v2 v2 (Tensor): a tensor with a shape compatible with v1 loss_function (str): One of "l2", "l1" or "smooth_l1" representing which loss function is to be used. strict_shape (bool): if False, v1 and v2 are allowed to have a different shape. - Default is `True`. + Default is :obj:`True`. Returns: A tensor of the shape v1.view_as(v2) or v2.view_as(v1) with values equal to the distance loss between the @@ -95,8 +94,7 @@ class ValueLoss: class TargetNetUpdater: - """ - An abstract class for target network update in Double DQN/DDPG. + """An abstract class for target network update in Double DQN/DDPG. Args: loss_module (DQNLoss or DDPGLoss): loss module where the target network should be updated. @@ -201,8 +199,8 @@ def __repr__(self) -> str: class SoftUpdate(TargetNetUpdater): - """ - A soft-update class for target network update in Double DQN/DDPG. + """A soft-update class for target network update in Double DQN/DDPG. 
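A minimal sketch of the soft (Polyak) update described here and its hard-copy counterpart, written against plain nn.Module pairs rather than a LossModule; the eps value below is only an illustrative default.

    import torch
    from torch import nn

    @torch.no_grad()
    def soft_update(source: nn.Module, target: nn.Module, eps: float = 0.995):
        # Polyak averaging: target <- eps * target + (1 - eps) * source
        for p_src, p_tgt in zip(source.parameters(), target.parameters()):
            p_tgt.data.mul_(eps).add_(p_src.data, alpha=1 - eps)

    @torch.no_grad()
    def hard_update(source: nn.Module, target: nn.Module):
        # Plain copy, typically triggered every N optimisation steps
        target.load_state_dict(source.state_dict())

    qnet, qnet_target = nn.Linear(4, 2), nn.Linear(4, 2)
    hard_update(qnet, qnet_target)  # start from identical weights
    soft_update(qnet, qnet_target)  # then let the target track the online net slowly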
+ This was proposed in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf Args: @@ -229,8 +227,8 @@ def _step(self, p_source: Tensor, p_target: Tensor) -> None: class HardUpdate(TargetNetUpdater): - """ - A hard-update class for target network update in Double DQN/DDPG (by contrast with soft updates). + """A hard-update class for target network update in Double DQN/DDPG (by contrast with soft updates). + This was proposed in the original Double DQN paper: "Deep Reinforcement Learning with Double Q-learning", https://arxiv.org/abs/1509.06461. @@ -302,8 +300,9 @@ def next_state_value( pred_next_val: Optional[Tensor] = None, **kwargs, ) -> torch.Tensor: - """ - Computes the next state value (without gradient) to compute a target for the MSE loss + """Computes the next state value (without gradient) to compute a target value. + + The target value is ususally used to compute a distance loss (e.g. MSE): L = Sum[ (q_value - target_value)^2 ] The target value is computed as r + gamma ** n_steps_to_next * value_next_state @@ -323,6 +322,7 @@ def next_state_value( Returns: a Tensor of the size of the input tensordict containing the predicted value state. + """ if "steps_to_next_obs" in tensordict.keys(): steps_to_next_obs = tensordict.get("steps_to_next_obs").squeeze(-1) diff --git a/torchrl/objectives/value/__init__.py b/torchrl/objectives/value/__init__.py index 87e1edab3d6..d226058da97 100644 --- a/torchrl/objectives/value/__init__.py +++ b/torchrl/objectives/value/__init__.py @@ -4,5 +4,3 @@ # LICENSE file in the root directory of this source tree. from .advantages import GAE, TDLambdaEstimate, TDEstimate -from .returns import bellman_max -from .vtrace import c_val, dv_val, vtrace diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 24dc2b7e654..27169de0647 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -38,7 +38,7 @@ class TDEstimate(nn.Module): average_rewards (bool, optional): if True, rewards will be standardized before the TD is computed. gradient_mode (bool, optional): if True, gradients are propagated throught - the computation of the value function. Default is `False`. + the computation of the value function. Default is :obj:`False`. value_key (str, optional): key pointing to the state value. Default is `"state_value"`. """ @@ -251,8 +251,8 @@ def forward( class GAE(nn.Module): - """ - A class wrapper around the generalized advantage estimate functional. + """A class wrapper around the generalized advantage estimate functional. + Refer to "HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION" https://arxiv.org/pdf/1506.02438.pdf for more context. @@ -263,6 +263,7 @@ class GAE(nn.Module): average_rewards (bool): if True, rewards will be standardized before the GAE is computed. gradient_mode (bool): if True, gradients are propagated throught the computation of the value function. Default is `False`. + """ def __init__( diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index 0e45a1fbc3b..e276e0d7670 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -28,8 +28,8 @@ def generalized_advantage_estimate( reward: torch.Tensor, done: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Get generalized advantage estimate of a trajectory + """Get generalized advantage estimate of a trajectory. 
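For reference, the estimators listed above all derive from the same per-step TD residual; here is the textbook GAE recursion on a single trajectory in plain PyTorch. The vectorized functions in this module compute the same quantity without the Python loop; the names and shapes below are assumptions for the illustration.

    import torch

    def gae(reward, value, next_value, done, gamma=0.99, lmbda=0.95):
        # reward, value, next_value, done: [T] tensors for one trajectory
        not_done = 1.0 - done.float()
        advantage = torch.zeros_like(reward)
        running = torch.zeros(())
        for t in reversed(range(reward.shape[0])):
            delta = reward[t] + gamma * next_value[t] * not_done[t] - value[t]
            running = delta + gamma * lmbda * not_done[t] * running
            advantage[t] = running
        value_target = advantage + value  # target used to fit the value network
        return advantage, value_target

    T = 5
    adv, target = gae(torch.randn(T), torch.randn(T), torch.randn(T), torch.zeros(T))
    print(adv.shape, target.shape)  # torch.Size([5]) torch.Size([5])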
+ Refer to "HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION" https://arxiv.org/pdf/1506.02438.pdf for more context. @@ -43,6 +43,7 @@ def generalized_advantage_estimate( reward (Tensor): reward of taking actions in the environment. must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor done (Tensor): boolean flag for end of episode. + """ for tensor in (next_state_value, state_value, reward, done): if tensor.shape[-1] != 1: @@ -79,8 +80,8 @@ def vec_generalized_advantage_estimate( reward: torch.Tensor, done: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Get generalized advantage estimate of a trajectory + """Get generalized advantage estimate of a trajectory. + Refer to "HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION" https://arxiv.org/pdf/1506.02438.pdf for more context. @@ -94,6 +95,7 @@ def vec_generalized_advantage_estimate( reward (Tensor): reward of taking actions in the environment. must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor done (Tensor): boolean flag for end of episode. + """ for tensor in (next_state_value, state_value, reward, done): if tensor.shape[-1] != 1: @@ -140,8 +142,8 @@ def td_advantage_estimate( reward: torch.Tensor, done: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Get generalized advantage estimate of a trajectory + """Get generalized advantage estimate of a trajectory. + Refer to "HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION" https://arxiv.org/pdf/1506.02438.pdf for more context. @@ -154,6 +156,7 @@ def td_advantage_estimate( reward (Tensor): reward of taking actions in the environment. must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor done (Tensor): boolean flag for end of episode. + """ for tensor in (next_state_value, state_value, reward, done): if tensor.shape[-1] != 1: @@ -204,6 +207,7 @@ def td_lambda_return_estimate( v4, ] Default is True. + """ for tensor in (next_state_value, reward, done): if tensor.shape[-1] != 1: @@ -298,6 +302,7 @@ def td_lambda_advantage_estimate( v4, ] Default is True. + """ if not state_value.shape == next_state_value.shape: raise RuntimeError("shape of state_value and next_state_value must match") @@ -351,6 +356,7 @@ def vec_td_lambda_advantage_estimate( v4, ] Default is True. + """ return ( vec_td_lambda_return_estimate( @@ -397,7 +403,6 @@ def vec_td_lambda_return_estimate( Default is True. """ - shape = next_state_value.shape if not shape[-1] == 1: raise RuntimeError("last dimension of inputs shape must be singleton") diff --git a/torchrl/objectives/value/returns.py b/torchrl/objectives/value/returns.py deleted file mode 100644 index af8997f969f..00000000000 --- a/torchrl/objectives/value/returns.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -from typing import Union - -import torch -from torch import nn - - -def bellman_max( - next_observation: torch.Tensor, - reward: torch.Tensor, - done: torch.Tensor, - gamma: Union[float, torch.Tensor], - value_model: nn.Module, -): - qmax = value_model(next_observation).max(dim=-1)[0] - nonterminal_target = reward + gamma * qmax - terminal_target = reward - target = done * terminal_target + (~done) * nonterminal_target - return target diff --git a/torchrl/objectives/value/utils.py b/torchrl/objectives/value/utils.py index 531828c20ab..13fe75b8a8c 100644 --- a/torchrl/objectives/value/utils.py +++ b/torchrl/objectives/value/utils.py @@ -8,6 +8,7 @@ def _custom_conv1d(tensor: torch.Tensor, filter: torch.Tensor): """Computes a conv1d filter over a value. + This is usually used to compute a discounted return: Tensor: Filter Result (discounted return) [ r_0, [ 1, [ ret_0, r_1, g, ret_1, r_2, g^2, ret_2, r_3, g^3 ret_3 ] 0 ] | v This function takes care of applying the one-sided zero padding. In this example, - `Filter_dim` = `Time` = 4, but in practice Filter_dim can be <= to `Time`. + `Filter_dim` = :obj:`Time` = 4, but in practice Filter_dim can be <= to :obj:`Time`. Args: tensor (torch.Tensor): a [ Batch x 1 x Time ] floating-point tensor @@ -29,7 +30,6 @@ def _custom_conv1d(tensor: torch.Tensor, filter: torch.Tensor): Returns: a filtered tensor of the same shape as the input tensor. """ - if filter.ndimension() > 2: # filter will have shape batch_dims x timesteps x filter_dim x 1 # reshape to batch_dims x timesteps x 1 x filter_dim ready for convolving @@ -132,7 +132,8 @@ def roll_by_gather(mat: torch.Tensor, dim: int, shifts: torch.LongTensor): def _make_gammas_tensor(gamma: torch.Tensor, T: int, rolling_gamma: bool): - """prepares a decay tensor for a matrix multiplication: + """Prepares a decay tensor for a matrix multiplication. + Given a tensor gamma of size [*batch, T, 1], it will return a new tensor with size [*batch, T, T+1, 1]. In the rolling_gamma case, a rolling of the gamma values will be performed [ [ 1, g1, g2, g3], [ 1, g2, g3, 0], [ 1, g3, 0, 0]] Args: - gamma: - T: - rolling_gamma: + gamma (torch.tensor): the gamma tensor to be prepared. + T (int): the time length. + rolling_gamma (bool): if True, the gamma value is set for each step + independently. If False, the gamma value at (i, t) will be used for the + trajectory following (i, t).
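The discounted-return-as-convolution idea sketched in the _custom_conv1d docstring above can be reproduced in a few lines of plain PyTorch. This is an illustration of the padding-and-filter trick only, not the actual implementation (which also handles batched, time-varying gammas):

    import torch
    import torch.nn.functional as F

    def discounted_return_via_conv(rewards: torch.Tensor, gamma: float) -> torch.Tensor:
        # rewards: [Batch x 1 x Time], as in the docstring above
        time = rewards.shape[-1]
        gammas = gamma ** torch.arange(time, dtype=rewards.dtype)  # [1, g, g^2, ...]
        padded = F.pad(rewards, (0, time - 1))  # one-sided zero padding on the right
        # conv1d is a cross-correlation, so the filter is used as-is (no flip needed)
        return F.conv1d(padded, gammas.view(1, 1, -1))

    r = torch.ones(1, 1, 4)
    print(discounted_return_via_conv(r, 0.9))
    # tensor([[[3.4390, 2.7100, 1.9000, 1.0000]]])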
- Returns: + Returns: the prepared gamma decay tensor """ # some reshaping code vendored from vec_td_lambda_return_estimate diff --git a/torchrl/objectives/value/vtrace.py b/torchrl/objectives/value/vtrace.py index 11f92ac1b18..43f5246502f 100644 --- a/torchrl/objectives/value/vtrace.py +++ b/torchrl/objectives/value/vtrace.py @@ -9,7 +9,7 @@ import torch -def c_val( +def _c_val( log_pi: torch.Tensor, log_mu: torch.Tensor, c: Union[float, torch.Tensor] = 1, @@ -17,7 +17,7 @@ def c_val( return (log_pi - log_mu).clamp_max(math.log(c)).exp() -def dv_val( +def _dv_val( rewards: torch.Tensor, vals: torch.Tensor, gamma: Union[float, torch.Tensor], @@ -25,13 +25,13 @@ def dv_val( log_pi: torch.Tensor, log_mu: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - rho = c_val(log_pi, log_mu, rho_bar) + rho = _c_val(log_pi, log_mu, rho_bar) next_vals = torch.cat([vals[:, 1:], torch.zeros_like(vals[:, :1])], 1) dv = rho * (rewards + gamma * next_vals - vals) return dv, rho -def vtrace( +def _vtrace( rewards: torch.Tensor, vals: torch.Tensor, log_pi: torch.Tensor, @@ -44,8 +44,8 @@ def vtrace( if not isinstance(gamma, torch.Tensor): gamma = torch.full_like(vals, gamma) - dv, rho = dv_val(rewards, vals, gamma, rho_bar, log_pi, log_mu) - c = c_val(log_pi, log_mu, c_bar) + dv, rho = _dv_val(rewards, vals, gamma, rho_bar, log_pi, log_mu) + c = _c_val(log_pi, log_mu, c_bar) v_out = [] v_out.append(vals[:, -1] + dv[:, -1]) diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py index 206bdb5d384..0cbbc42ca86 100644 --- a/torchrl/record/recorder.py +++ b/torchrl/record/recorder.py @@ -21,8 +21,8 @@ class VideoRecorder(ObservationTransform): - """ - Video Recorder transform. + """Video Recorder transform. + Will record a series of observations from an environment and write them to a Logger object when needed. @@ -31,13 +31,14 @@ class VideoRecorder(ObservationTransform): should be written. tag (str): the video tag in the logger. keys_in (Sequence[str], optional): keys to be read to produce the video. - Default is `"next_pixels"`. + Default is :obj:`"next_pixels"`. skip (int): frame interval in the output video. Default is 2. center_crop (int, optional): value of square center crop. make_grid (bool, optional): if True, a grid is created assuming that a tensor of shape [B x W x H x 3] is provided, with B being the batch size. Default is True. + """ def __init__( @@ -131,15 +132,15 @@ def dump(self, suffix: Optional[str] = None) -> None: class TensorDictRecorder(Transform): - """ - TensorDict recorder. - When the 'dump' method is called, this class will save a stack of the tensordict resulting from `env.step(td)` in a + """TensorDict recorder. + + When the 'dump' method is called, this class will save a stack of the tensordict resulting from :obj:`env.step(td)` in a file with a prefix defined by the out_file_base argument. Args: out_file_base (str): a string defining the prefix of the file where the tensordict will be written. skip_reset (bool): if True, the first TensorDict of the list will be discarded (usually the tensordict - resulting from the call to `env.reset()`) + resulting from the call to :obj:`env.reset()`) default: True skip (int): frame interval for the saved tensordict. 
default: 4 diff --git a/torchrl/trainers/helpers/collectors.py b/torchrl/trainers/helpers/collectors.py index 6e70fb218f5..0b3e1edbe23 100644 --- a/torchrl/trainers/helpers/collectors.py +++ b/torchrl/trainers/helpers/collectors.py @@ -34,8 +34,7 @@ def sync_async_collector( num_collectors: Optional[int] = None, **kwargs, ) -> MultiaSyncDataCollector: - """ - Runs asynchronous collectors, each running synchronous environments. + """Runs asynchronous collectors, each running synchronous environments. .. aafig:: @@ -78,7 +77,6 @@ def sync_async_collector( **kwargs: Other kwargs passed to the data collectors """ - return _make_collector( MultiaSyncDataCollector, env_fns=env_fns, @@ -96,8 +94,7 @@ def sync_sync_collector( num_collectors: Optional[int] = None, **kwargs, ) -> Union[SyncDataCollector, MultiSyncDataCollector]: - """ - Runs synchronous collectors, each running synchronous environments. + """Runs synchronous collectors, each running synchronous environments. E.g. @@ -262,8 +259,7 @@ def make_collector_offpolicy( cfg: "DictConfig", # noqa: F821 make_env_kwargs: Optional[Dict] = None, ) -> _DataCollector: - """ - Returns a data collector for off-policy algorithms. + """Returns a data collector for off-policy algorithms. Args: make_env (Callable): environment creator @@ -327,6 +323,15 @@ def make_collector_onpolicy( cfg: "DictConfig", # noqa: F821 make_env_kwargs: Optional[Dict] = None, ) -> _DataCollector: + """Makes a collector in on-policy settings. + + Args: + make_env (Callable): environment creator + actor_model_explore (TensorDictModule): Model instance used for evaluation and exploration update + cfg (DictConfig): config for creating collector object + make_env_kwargs (dict): kwargs for the env creator + + """ collector_helper = sync_sync_collector ms = None diff --git a/torchrl/trainers/helpers/envs.py b/torchrl/trainers/helpers/envs.py index f5349a4766a..5968f1c7921 100644 --- a/torchrl/trainers/helpers/envs.py +++ b/torchrl/trainers/helpers/envs.py @@ -47,9 +47,8 @@ def correct_for_frame_skip(cfg: "DictConfig") -> "DictConfig": # noqa: F821 - """ - Correct the arguments for the input frame_skip, by dividing all the arguments that reflect a count of frames by the - frame_skip. + """Correct the arguments for the input frame_skip, by dividing all the arguments that reflect a count of frames by the frame_skip. + This is aimed at avoiding unknowingly over-sampling from the environment, i.e. targetting a total number of frames of 1M but actually collecting frame_skip * 1M frames. @@ -226,14 +225,13 @@ def transformed_env_constructor( state_dim_gsde: Optional[int] = None, batch_dims: Optional[int] = 0, ) -> Union[Callable, EnvCreator]: - """ - Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor. + """Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor. Args: cfg (DictConfig): a DictConfig containing the arguments of the script. video_tag (str, optional): video tag to be passed to the Logger object logger (Logger, optional): logger associated with the script - stats (dict, optional): a dictionary containing the `loc` and `scale` for the `ObservationNorm` transform + stats (dict, optional): a dictionary containing the :obj:`loc` and :obj:`scale` for the `ObservationNorm` transform norm_obs_only (bool, optional): If `True` and `VecNorm` is used, the reward won't be normalized online. Default is `False`. use_env_creator (bool, optional): wheter the `EnvCreator` class should be used. 
By using `EnvCreator`, @@ -357,6 +355,18 @@ def get_stats_random_rollout( proof_environment: EnvBase = None, key: Optional[str] = None, ): + """Gathers stats (loc and scale) from an environment using random rollouts. + + Args: + cfg (DictConfig): a config object with `init_env_steps` field, indicating + the total number of frames to be collected to compute the stats. + proof_environment (EnvBase instance, optional): if provided, this env will + be used to execute the rollouts. If not, it will be created using + the cfg object. + key (str, optional): if provided, the stats of this key will be gathered. + If not, it is expected that only one key exists in `env.observation_spec`. + + """ proof_env_is_none = proof_environment is None if proof_env_is_none: proof_environment = transformed_env_constructor( diff --git a/torchrl/trainers/helpers/logger.py b/torchrl/trainers/helpers/logger.py index b02278ae31e..f6e506dd605 100644 --- a/torchrl/trainers/helpers/logger.py +++ b/torchrl/trainers/helpers/logger.py @@ -9,6 +9,8 @@ @dataclass class LoggerConfig: + """Logger config data-class.""" + logger: str = "csv" # recorder type to be used. One of 'tensorboard', 'wandb' or 'csv' record_video: bool = False diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index 0df2b5565ce..490288cbbe7 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -77,8 +77,7 @@ def make_dqn_actor( proof_environment: EnvBase, cfg: "DictConfig", device: torch.device # noqa: F821 ) -> Actor: - """ - DQN constructor helper function. + """DQN constructor helper function. Args: proof_environment (EnvBase): a dummy environment to retrieve the observation and action spec. @@ -204,8 +203,7 @@ def make_ddpg_actor( value_net_kwargs: Optional[dict] = None, device: DEVICE_TYPING = "cpu", ) -> torch.nn.ModuleList: - """ - DDPG constructor helper function. + """DDPG constructor helper function. Args: proof_environment (EnvBase): a dummy environment to retrieve the observation and action spec @@ -223,8 +221,7 @@ https://arxiv.org/pdf/1509.02971.pdf. Examples: - >>> from torchrl.trainers.helpers.envs import parser_env_args - >>> from torchrl.trainers.helpers.models import make_ddpg_actor, parser_model_args_continuous + >>> from torchrl.trainers.helpers.models import make_ddpg_actor >>> from torchrl.envs.libs.gym import GymEnv >>> from torchrl.envs.transforms import CatTensors, TransformedEnv, DoubleToFloat, Compose >>> import hydra @@ -268,7 +265,6 @@ device=cpu, is_shared=False) """ - # TODO: https://arxiv.org/pdf/1804.08617.pdf from_pixels = cfg.from_pixels @@ -407,8 +403,8 @@ def make_ppo_model( observation_key=None, **kwargs, ) -> ActorValueOperator: - """ - Actor-value model constructor helper function. + """Actor-value model constructor helper function. + Currently constructs MLP networks with immutable default arguments as described in "Proximal Policy Optimization Algorithms", https://arxiv.org/abs/1707.06347 Other configurations can easily be implemented by modifying this function at will. @@ -699,8 +695,7 @@ def make_sac_model( observation_key=None, **kwargs, ) -> nn.ModuleList: - """ - Actor, Q-value and value model constructor helper function for SAC. + """Actor, Q-value and value model constructor helper function for SAC. Follows default parameters proposed in SAC original paper: https://arxiv.org/pdf/1801.01290.pdf. Other configurations can easily be implemented by modifying this function at will.
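The loc/scale statistics that get_stats_random_rollout feeds into the ObservationNorm transform amount to a mean and a clamped standard deviation over randomly collected observations. A self-contained sketch with made-up shapes (not the helper itself):

    import torch

    def rollout_stats(observations: torch.Tensor, eps: float = 1e-4) -> dict:
        # observations: [N x obs_dim], collected by acting randomly in the env
        loc = observations.mean(dim=0)
        scale = observations.std(dim=0).clamp_min(eps)  # guard against near-constant dims
        return {"loc": loc, "scale": scale}

    obs = torch.randn(1000, 17) * 3.0 + 2.0
    stats = rollout_stats(obs)
    normalized = (obs - stats["loc"]) / stats["scale"]
    print(normalized.mean(0).abs().max(), normalized.std(0).mean())  # ~0 and ~1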
@@ -720,7 +715,6 @@ def make_sac_model( A nn.ModuleList containing the actor, qvalue operator(s) and the value operator. Examples: - >>> from torchrl.trainers.helpers.envs import parser_env_args >>> from torchrl.trainers.helpers.models import make_sac_model, parser_model_args_continuous >>> from torchrl.envs.libs.gym import GymEnv >>> from torchrl.envs.transforms import CatTensors, TransformedEnv, DoubleToFloat, Compose @@ -920,8 +914,8 @@ def make_redq_model( observation_key=None, **kwargs, ) -> nn.ModuleList: - """ - Actor and Q-value model constructor helper function for REDQ. + """Actor and Q-value model constructor helper function for REDQ. + Follows default parameters proposed in REDQ original paper: https://openreview.net/pdf?id=AY8zfZm0tDd. Other configurations can easily be implemented by modifying this function at will. A single instance of the Q-value model is returned. It will be multiplicated by the loss function. @@ -992,7 +986,6 @@ def make_redq_model( is_shared=False) """ - tanh_loc = cfg.tanh_loc default_policy_scale = cfg.default_policy_scale gSDE = cfg.gSDE diff --git a/torchrl/trainers/loggers/common.py b/torchrl/trainers/loggers/common.py index d377cf7a6b5..e9429172cdf 100644 --- a/torchrl/trainers/loggers/common.py +++ b/torchrl/trainers/loggers/common.py @@ -12,10 +12,7 @@ class Logger: - """ - A template for loggers - - """ + """A template for loggers.""" def __init__(self, exp_name: str, log_dir: str) -> None: self.exp_name = exp_name diff --git a/torchrl/trainers/loggers/csv.py b/torchrl/trainers/loggers/csv.py index e193c3aa43e..93126f6de22 100644 --- a/torchrl/trainers/loggers/csv.py +++ b/torchrl/trainers/loggers/csv.py @@ -13,6 +13,8 @@ class CSVExperiment: + """A CSV logger experiment class.""" + def __init__(self, log_dir: str): self.scalars = defaultdict(lambda: []) self.videos_counter = defaultdict(lambda: 0) @@ -56,8 +58,7 @@ def __repr__(self) -> str: class CSVLogger(Logger): - """ - A minimal-dependecy CSV-logger. + """A minimal-dependecy CSV-logger. Args: exp_name (str): The name of the experiment. @@ -72,16 +73,12 @@ def __init__(self, exp_name: str, log_dir: Optional[str] = None) -> None: self._has_imported_moviepy = False def _create_experiment(self) -> "CSVExperiment": - """ - Creates a CSV experiment. - - """ + """Creates a CSV experiment.""" log_dir = str(os.path.join(self.log_dir, self.exp_name)) return CSVExperiment(log_dir) def log_scalar(self, name: str, value: float, step: int = None) -> None: - """ - Logs a scalar value to the tensorboard. + """Logs a scalar value to the tensorboard. Args: name (str): The name of the scalar. @@ -91,8 +88,7 @@ def log_scalar(self, name: str, value: float, step: int = None) -> None: self.experiment.add_scalar(name, value, global_step=step) def log_video(self, name: str, video: Tensor, step: int = None, **kwargs) -> None: - """ - Log videos inputs to the tensorboard. + """Log videos inputs to the tensorboard. Args: name (str): The name of the video. @@ -113,8 +109,7 @@ def log_video(self, name: str, video: Tensor, step: int = None, **kwargs) -> Non ) def log_hparams(self, cfg: "DictConfig") -> None: # noqa: F821 - """ - Logs the hyperparameters of the experiment. + """Logs the hyperparameters of the experiment. Args: cfg (DictConfig): The configuration of the experiment. 
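To show what a minimal-dependency CSV logger boils down to, here is a toy scalar logger in the same spirit; the class name and file layout are invented for the example and do not reflect the CSVExperiment implementation:

    import os
    from collections import defaultdict

    class TinyCSVLogger:
        def __init__(self, log_dir: str):
            self.log_dir = log_dir
            os.makedirs(os.path.join(log_dir, "scalars"), exist_ok=True)
            self._auto_step = defaultdict(int)

        def log_scalar(self, name: str, value: float, step: int = None) -> None:
            # one CSV file per scalar name, one "step,value" row per call
            if step is None:
                step = self._auto_step[name]
                self._auto_step[name] += 1
            with open(os.path.join(self.log_dir, "scalars", f"{name}.csv"), "a") as f:
                f.write(f"{step},{value}\n")

    logger = TinyCSVLogger("/tmp/torchrl_csv_demo")
    for i in range(3):
        logger.log_scalar("r_training", 0.1 * i, step=i)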
diff --git a/torchrl/trainers/loggers/mlflow.py b/torchrl/trainers/loggers/mlflow.py index ba4a8ab390e..dbc6929ad8d 100644 --- a/torchrl/trainers/loggers/mlflow.py +++ b/torchrl/trainers/loggers/mlflow.py @@ -115,8 +115,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None: mlflow.log_artifact(f.name, "videos") def log_hparams(self, cfg: "DictConfig") -> None: # noqa: F821 - """ - Logs the hyperparameters of the experiment. + """Logs the hyperparameters of the experiment. Args: cfg (DictConfig): The configuration of the experiment. diff --git a/torchrl/trainers/loggers/tensorboard.py b/torchrl/trainers/loggers/tensorboard.py index aa46735fd69..c506837ff41 100644 --- a/torchrl/trainers/loggers/tensorboard.py +++ b/torchrl/trainers/loggers/tensorboard.py @@ -18,8 +18,7 @@ class TensorboardLogger(Logger): - """ - Wrapper for the Tensoarboard logger. + """Wrapper for the Tensoarboard logger. Args: exp_name (str): The name of the experiment. @@ -35,11 +34,11 @@ def __init__(self, exp_name: str, log_dir: str = "tb_logs") -> None: self._has_imported_moviepy = False def _create_experiment(self) -> "SummaryWriter": - """ - Creates a tensorboard experiment. + """Creates a tensorboard experiment. Args: exp_name (str): The name of the experiment. + Returns: SummaryWriter: The tensorboard experiment. @@ -51,24 +50,24 @@ def _create_experiment(self) -> "SummaryWriter": return SummaryWriter(log_dir=log_dir) def log_scalar(self, name: str, value: float, step: int = None) -> None: - """ - Logs a scalar value to the tensorboard. + """Logs a scalar value to the tensorboard. Args: name (str): The name of the scalar. value (float): The value of the scalar. step (int, optional): The step at which the scalar is logged. Defaults to None. + """ self.experiment.add_scalar(name, value, global_step=step) def log_video(self, name: str, video: Tensor, step: int = None, **kwargs) -> None: - """ - Log videos inputs to the tensorboard. + """Log videos inputs to the tensorboard. Args: name (str): The name of the video. video (Tensor): The video to be logged. step (int, optional): The step at which the video is logged. Defaults to None. + """ # check for correct format of the video tensor ((N), T, C, H, W) # check that the color channel (C) is either 1 or 3 @@ -93,11 +92,11 @@ def log_video(self, name: str, video: Tensor, step: int = None, **kwargs) -> Non ) def log_hparams(self, cfg: "DictConfig") -> None: # noqa: F821 - """ - Logs the hyperparameters of the experiment. + """Logs the hyperparameters of the experiment. Args: cfg (DictConfig): The configuration of the experiment. + """ txt = "\n\t".join([f"{k}: {val}" for k, val in sorted(vars(cfg).items())]) self.experiment.add_text("hparams", txt) diff --git a/torchrl/trainers/loggers/wandb.py b/torchrl/trainers/loggers/wandb.py index 738049feb95..d259bb6b916 100644 --- a/torchrl/trainers/loggers/wandb.py +++ b/torchrl/trainers/loggers/wandb.py @@ -29,8 +29,7 @@ class WandbLogger(Logger): - """ - Wrapper for the wandb logger. + """Wrapper for the wandb logger. Args: exp_name (str): The name of the experiment. @@ -86,15 +85,14 @@ def __init__( self.video_log_counter = 0 def _create_experiment(self) -> "WandbLogger": - """ - Creates a wandb experiment. + """Creates a wandb experiment. Args: exp_name (str): The name of the experiment. + Returns: WandbLogger: The wandb experiment logger. 
""" - if self.offline: os.environ["WANDB_MODE"] = "dryrun" @@ -103,8 +101,7 @@ def _create_experiment(self) -> "WandbLogger": return wandb.init(**self._wandb_kwargs) def log_scalar(self, name: str, value: float, step: Optional[int] = None) -> None: - """ - Logs a scalar value to wandb. + """Logs a scalar value to wandb. Args: name (str): The name of the scalar. @@ -118,8 +115,7 @@ def log_scalar(self, name: str, value: float, step: Optional[int] = None) -> Non self.experiment.log({name: value}) def log_video(self, name: str, video: Tensor, **kwargs) -> None: - """ - Log videos inputs to wandb. + """Log videos inputs to wandb. Args: name (str): The name of the video. @@ -127,7 +123,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None: **kwargs: Other keyword arguments. By construction, log_video supports 'step' (integer indicating the step index), 'format' (default is 'mp4') and 'fps' (default: 6). Other kwargs are - passed as-is to the `experiment.log` method. + passed as-is to the :obj:`experiment.log` method. """ # check for correct format of the video tensor ((N), T, C, H, W) # check that the color channel (C) is either 1 or 3 @@ -165,13 +161,12 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None: ) def log_hparams(self, cfg: "DictConfig") -> None: # noqa: F821 - """ - Logs the hyperparameters of the experiment. + """Logs the hyperparameters of the experiment. Args: cfg (DictConfig): The configuration of the experiment. - """ + """ if type(cfg) is not dict and _has_omgaconf: if not _has_omgaconf: raise ImportError( diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 60300cdb62c..48c52decd0a 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -72,7 +72,7 @@ class Trainer: of its specific operations: they all must be hooked at specific points in the training loop. - To build a Trainer, one needs an iterable data source (a `collector`), a + To build a Trainer, one needs an iterable data source (a :obj:`collector`), a loss module and an optimizer. Args: @@ -94,12 +94,12 @@ class Trainer: clip_grad_norm (bool, optional): If True, the gradients will be clipped based on the total norm of the model parameters. If False, all the partial derivatives will be clamped to - (-clip_norm, clip_norm). Default is `True`. + (-clip_norm, clip_norm). Default is :obj:`True`. clip_norm (Number, optional): value to be used for clipping gradients. Default is 100.0. progress_bar (bool, optional): If True, a progress bar will be displayed using tqdm. If tqdm is not installed, this option - won't have any effect. Default is `True` + won't have any effect. Default is :obj:`True` seed (int, optional): Seed to be used for the collector, pytorch and numpy. Default is 42. save_trainer_interval (int, optional): How often the trainer should be @@ -534,9 +534,9 @@ class ReplayBufferTrainer: flatten_tensordicts (bool, optional): if True, the tensordicts will be flattened (or equivalently masked with the valid mask obtained from the collector) before being passed to the replay buffer. Otherwise, - no transform will be achieved other than padding (see `max_dims` arg below). + no transform will be achieved other than padding (see :obj:`max_dims` arg below). 
Defaults to True - max_dims (sequence of int, optional): if `flatten_tensordicts` is set to False, + max_dims (sequence of int, optional): if :obj:`flatten_tensordicts` is set to False, this will be a list of the length of the batch_size of the provided tensordicts that represent the maximum size of each. If provided, this list of sizes will be used to pad the tensordict and make their shape @@ -624,9 +624,9 @@ class LogReward: """Reward logger hook. Args: - logname (str, optional): name of the rewards to be logged. Default is `"r_training"`. + logname (str, optional): name of the rewards to be logged. Default is :obj:`"r_training"`. log_pbar (bool, optional): if True, the reward value will be logged on - the progression bar. Default is `False`. + the progression bar. Default is :obj:`False`. Examples: >>> log_reward = LogReward("reward") @@ -762,7 +762,7 @@ class BatchSubSampler: sub_traj_len (int, optional): length of the trajectories that sub-samples must have in online settings. Default is -1 (i.e. takes the full length of the trajectory) - min_sub_traj_len (int, optional): minimum value of `sub_traj_len`, in + min_sub_traj_len (int, optional): minimum value of :obj:`sub_traj_len`, in case some elements of the batch contain few steps. Default is -1 (i.e. no minimum value) @@ -798,8 +798,8 @@ def __call__(self, batch: TensorDictBase) -> TensorDictBase: dimensions, it is assumed that the first dimension represents the batch, and the second the time. If so, the resulting subsample will contain consecutive samples across time. - """ + """ if batch.ndimension() == 1: return batch[torch.randperm(batch.shape[0])[: self.batch_size]]
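Finally, a small sketch of the sub-trajectory sampling that BatchSubSampler performs on [Batch x Time] inputs; plain tensors stand in for the tensordict and the function name is invented for the illustration:

    import torch

    def subsample_trajectories(batch: torch.Tensor, batch_size: int, sub_traj_len: int):
        # batch: [B x T x ...]; draws batch_size // sub_traj_len random (trajectory, start)
        # pairs and keeps sub_traj_len consecutive steps from each, so samples stay
        # ordered in time within every sub-trajectory.
        n_traj, traj_len = batch.shape[0], batch.shape[1]
        n_samples = batch_size // sub_traj_len
        traj_idx = torch.randint(0, n_traj, (n_samples,))
        start_idx = torch.randint(0, traj_len - sub_traj_len + 1, (n_samples,))
        time_idx = start_idx.unsqueeze(-1) + torch.arange(sub_traj_len)
        return batch[traj_idx.unsqueeze(-1), time_idx]

    data = torch.randn(16, 100, 4)  # 16 trajectories, 100 steps, 4 features
    sub = subsample_trajectories(data, batch_size=32, sub_traj_len=8)
    print(sub.shape)  # torch.Size([4, 8, 4])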