[Feature] More restrictive tests on docstrings (pytorch#457)
vmoens authored Oct 19, 2022
1 parent b584d78 commit 041dd53
Showing 80 changed files with 876 additions and 1,519 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -28,3 +28,4 @@ repos:
rev: 6.1.1
hooks:
- id: pydocstyle
files: ^torchrl/
22 changes: 12 additions & 10 deletions docs/source/reference/objectives.rst
@@ -53,25 +53,27 @@ PPO

Returns
-------
.. currentmodule:: torchrl.objectives.value

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

value.GAE
value.TDLambdaEstimate
value.TDEstimate
value.functional.generalized_advantage_estimate
value.functional.vec_generalized_advantage_estimate
value.functional.vec_td_lambda_return_estimate
value.functional.vec_td_lambda_advantage_estimate
value.functional.td_lambda_return_estimate
value.functional.td_lambda_advantage_estimate
value.functional.td_advantage_estimate
GAE
TDLambdaEstimate
TDEstimate
functional.generalized_advantage_estimate
functional.vec_generalized_advantage_estimate
functional.vec_td_lambda_return_estimate
functional.vec_td_lambda_advantage_estimate
functional.td_lambda_return_estimate
functional.td_lambda_advantage_estimate
functional.td_advantage_estimate
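
The estimators listed above (GAE, TDLambdaEstimate, TDEstimate and their functional counterparts) compute variants of the same quantity. As a rough plain-PyTorch illustration of the recursion behind generalized advantage estimation — a sketch of the textbook formula, not the torchrl implementation or its exact signatures:

```python
import torch

def gae_illustration(reward, value, next_value, done, gamma=0.99, lmbda=0.95):
    """Textbook GAE recursion over a single [T]-shaped trajectory.

    delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    A_t     = delta_t + gamma * lmbda * (1 - done_t) * A_{t+1}
    """
    advantage = torch.zeros_like(reward)
    gae = torch.zeros(())
    for t in reversed(range(reward.shape[0])):
        not_done = 1.0 - done[t].float()
        delta = reward[t] + gamma * next_value[t] * not_done - value[t]
        gae = delta + gamma * lmbda * not_done * gae
        advantage[t] = gae
    # The regression target for the value network is advantage + V(s_t).
    return advantage, advantage + value

adv, value_target = gae_illustration(
    reward=torch.ones(5),
    value=torch.zeros(5),
    next_value=torch.zeros(5),
    done=torch.tensor([0, 0, 0, 0, 1]),
)
```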


Utils
-----
.. currentmodule:: torchrl.objectives

.. autosummary::
:toctree: generated/
11 changes: 10 additions & 1 deletion setup.cfg
@@ -17,4 +17,13 @@ per-file-ignores =
exclude = venv

[pydocstyle]
select = D417 # Missing argument descriptions in the docstring
;select = D417 # Missing argument descriptions in the docstring
;inherit = false
match = .*\.py
;match_dir = ^(?!(.circlecli|test)).*
convention = google
add-ignore = D100, D104, D105, D107, D102
ignore-decorators =
test_*
; test/*.py
; .circleci/*
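
Together with the pre-commit hook above, these settings enforce the Google docstring convention on torchrl/ while skipping the D1xx "missing docstring" codes listed under add-ignore; D417 is the check that every documented function describes its arguments. A minimal, purely illustrative sketch of a docstring shape that such settings accept (not taken from the repository):

```python
def scale_reward(reward: float, factor: float = 1.0) -> float:
    """Scales a reward by a constant factor.

    The summary sits on the first line and ends with a period; the Args
    section below is what D417 looks for when that check is enabled.

    Args:
        reward (float): raw reward value.
        factor (float, optional): multiplicative factor. Defaults to 1.0.

    Returns:
        The scaled reward.
    """
    return reward * factor
```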
2 changes: 2 additions & 0 deletions test/test_postprocs.py
@@ -90,6 +90,8 @@ def test_multistep(n, key, device, T=11):


class TestSplits:
"""Tests the splitting of collected tensordicts in trajectories."""

@staticmethod
def create_fake_trajs(
num_workers=32,
2 changes: 2 additions & 0 deletions test/test_tensor_spec.py
@@ -384,6 +384,8 @@ def test_nested_composite_spec(self, is_complete, device, dtype):


class TestEquality:
"""Tests spec comparison."""

@staticmethod
def _ts_make_all_fields_equal(ts_to, ts_from):
ts_to.shape = ts_from.shape
4 changes: 2 additions & 2 deletions torchrl/_extension.py
@@ -8,13 +8,13 @@


def is_module_available(*modules: str) -> bool:
r"""Returns if a top-level module with :attr:`name` exists *without**
importing it.
r"""Returns if a top-level module with :attr:`name` exists *without** importing it.
This is generally safer than a try-catch block around an
`import X`. It avoids third party libraries breaking assumptions of some of
our tests, e.g., setting multiprocessing start method when imported
(see librosa/#747, torchvision/#544).
"""
return all(importlib.util.find_spec(m) is not None for m in modules)

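
A short usage sketch for the helper above; the gym/pygame module names are only illustrative:

```python
from torchrl._extension import is_module_available

# Guard an optional dependency without triggering its import side effects.
if is_module_available("gym"):
    import gym  # imported only once the module spec is known to exist
else:
    gym = None

# Several modules can be checked at once; all of them must be importable.
has_render_stack = is_module_available("gym", "pygame")
```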
28 changes: 25 additions & 3 deletions torchrl/_utils.py
@@ -6,9 +6,7 @@


class timeit:
"""
A dirty but easy to use decorator for profiling code
"""
"""A dirty but easy to use decorator for profiling code."""

_REG = {}

@@ -71,6 +69,17 @@ def _check_for_faulty_process(processes):


def seed_generator(seed):
"""A seed generator function.
Given a seeding integer, generates a deterministic next seed to be used in a
seeding sequence.
Args:
seed (int): initial seed.
Returns: Next seed of the chain.
"""
max_seed_val = (
2 ** 32 - 1
) # https://discuss.pytorch.org/t/what-is-the-max-seed-you-can-set-up/145688
@@ -80,6 +89,14 @@ def seed_generator(seed):


class KeyDependentDefaultDict(collections.defaultdict):
"""A key-dependent default dict.
Examples:
>>> my_dict = KeyDependentDefaultDict(lambda key: "foo_" + key)
>>> print(my_dict["bar"])
foo_bar
"""

def __init__(self, fun):
self.fun = fun
super().__init__()
@@ -91,6 +108,11 @@ def __missing__(self, key):


def prod(sequence):
"""General prod function, that generalised usage across math and np.
Created for multiple python versions compatibility).
"""
if hasattr(math, "prod"):
return math.prod(sequence)
else:
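
A hedged usage sketch for the helpers above, assuming seed_generator simply maps one integer seed to the next seed of the chain, as its docstring states:

```python
from torchrl._utils import KeyDependentDefaultDict, prod, seed_generator

# Derive one deterministic seed per worker from a single base seed.
seed, worker_seeds = 42, []
for _ in range(4):
    seed = seed_generator(seed)
    worker_seeds.append(seed)

# prod() works whether or not math.prod exists (fallback for Python < 3.8).
n_elements = prod([3, 4, 5])  # 60

# The default value is computed from the missing key itself.
buffers = KeyDependentDefaultDict(lambda key: f"buffer_{key}")
print(buffers["obs"])  # buffer_obs
```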
35 changes: 17 additions & 18 deletions torchrl/collectors/collectors.py
@@ -49,6 +49,7 @@
class RandomPolicy:
def __init__(self, action_spec: TensorSpec):
"""Random policy for a given action_spec.
This is a wrapper around the action_spec.rand method.
@@ -63,6 +64,7 @@ def __init__(self, action_spec: TensorSpec):
>>> action_spec = NdBoundedTensorSpec(-torch.ones(3), torch.ones(3))
>>> actor = RandomPolicy(spec=action_spec)
>>> td = actor(TensorDict(batch_size=[])) # selects a random action in the cube [-1; 1]
"""
self.action_spec = action_spec

@@ -127,7 +129,9 @@ def _get_policy_and_device(
) -> Tuple[
ProbabilisticTensorDictModule, torch.device, Union[None, Callable[[], dict]]
]:
"""From a policy and a device, assigns the self.device attribute to
"""Util method to get a policy and its device given the collector __init__ inputs.
From a policy and a device, assigns the self.device attribute to
the desired device and maps the policy onto it or (if the device is
omitted) assigns the self.device attribute to the policy device.
@@ -247,8 +251,7 @@ def __repr__(self) -> str:


class SyncDataCollector(_DataCollector):
"""
Generic data collector for RL problems. Requires and environment constructor and a policy.
"""Generic data collector for RL problems. Requires and environment constructor and a policy.
Args:
create_env_fn (Callable), returns an instance of EnvBase class.
@@ -684,15 +687,13 @@ def __del__(self):
self.shutdown() # make sure env is closed

def state_dict(self) -> OrderedDict:
"""Returns the local state_dict of the data collector (environment
and policy).
"""Returns the local state_dict of the data collector (environment and policy).
Returns:
an ordered dictionary with fields `"policy_state_dict"` and
an ordered dictionary with fields :obj:`"policy_state_dict"` and
`"env_state_dict"`.
"""

if isinstance(self.env, TransformedEnv):
env_state_dict = self.env.transform.state_dict()
elif isinstance(self.env, _BatchedEnv):
@@ -716,7 +717,7 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None:
Args:
state_dict (OrderedDict): ordered dictionary containing the fields
`"policy_state_dict"` and `"env_state_dict"`.
`"policy_state_dict"` and :obj:`"env_state_dict"`.
"""
strict = kwargs.get("strict", True)
@@ -791,7 +792,7 @@ class _MultiDataCollector(_DataCollector):
reset_when_done (bool, optional): if True, the contained environment will be reset
every time it hits a done. If the env contains multiple independent envs, a
reset index will be passed to it to reset only those environments that need to
be reset. In practice, this will happen through a call to `env.reset(tensordict)`,
be reset. In practice, this will happen through a call to :obj:`env.reset(tensordict)`,
in other words, if the env is a multi-agent env, all agents will be
reset once one of them is done.
Defaults to `True`.
@@ -1081,8 +1082,8 @@ def reset(self, reset_idx: Optional[Sequence[bool]] = None) -> None:
raise RuntimeError(f"Expected msg='reset', got {msg}")

def state_dict(self) -> OrderedDict:
"""
Returns the state_dict of the data collector.
"""Returns the state_dict of the data collector.
Each field represents a worker containing its own state_dict.
"""
@@ -1098,15 +1099,13 @@ def state_dict(self) -> OrderedDict:
return state_dict

def load_state_dict(self, state_dict: OrderedDict) -> None:
"""
Loads the state_dict on the workers.
"""Loads the state_dict on the workers.
Args:
state_dict (OrderedDict): state_dict of the form
``{"worker0": state_dict0, "worker1": state_dict1}``.
"""

for idx in range(self.num_workers):
self.pipes[idx].send((state_dict[f"worker{idx}"], "load_state_dict"))
for idx in range(self.num_workers):
@@ -1116,13 +1115,13 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:


class MultiSyncDataCollector(_MultiDataCollector):
"""Runs a given number of DataCollectors on separate processes
synchronously.
"""Runs a given number of DataCollectors on separate processes synchronously.
The collection starts when the next item of the collector is queried,
and no environment step is computed in between the reception of a batch of
trajectory and the start of the next collection.
This class can be safely used with online RL algorithms.
"""

__doc__ += _MultiDataCollector.__doc__
@@ -1212,12 +1211,12 @@ def iterator(self) -> Iterator[TensorDictBase]:


class MultiaSyncDataCollector(_MultiDataCollector):
"""Runs a given number of DataCollectors on separate processes
asynchronously.
"""Runs a given number of DataCollectors on separate processes asynchronously.
The collection keeps on occurring on all processes even between the time
the batch of rollouts is collected and the next call to the iterator.
This class can be safely used with offline RL algorithms.
"""

__doc__ += _MultiDataCollector.__doc__
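
A rough end-to-end sketch of the collector API described above; the exact constructor keywords (frames_per_batch, total_frames) and the GymEnv import path are assumptions to check against the class docstrings:

```python
from torchrl.collectors.collectors import RandomPolicy, SyncDataCollector
from torchrl.envs.libs.gym import GymEnv

env_maker = lambda: GymEnv("Pendulum-v1")
policy = RandomPolicy(env_maker().action_spec)

collector = SyncDataCollector(
    env_maker,
    policy,
    frames_per_batch=64,
    total_frames=256,
)
for batch in collector:
    # Each batch is a TensorDict of transitions. The collector can be
    # checkpointed through the state_dict/load_state_dict pair documented
    # above ("policy_state_dict" / "env_state_dict" fields).
    state = collector.state_dict()
    collector.load_state_dict(state)
    break
collector.shutdown()
```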
5 changes: 4 additions & 1 deletion torchrl/collectors/utils.py
@@ -28,7 +28,10 @@ def stacked_output_fun(*args, **kwargs):


def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase:
"""Takes a tensordict with a key traj_ids that indicates the id of each trajectory.
"""A util function for trajectory separation.
Takes a tensordict with a key traj_ids that indicates the id of each trajectory.
From there, builds a B x T x ... zero-padded tensordict with B batches of max duration T.
"""
traj_ids = rollout_tensordict.get("traj_ids")
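
To make the padding concrete, a hedged sketch of calling the helper on a flat rollout containing two trajectories; the required keys and the output layout are assumptions based on the docstring, and the TensorDict import path depends on the torchrl version:

```python
import torch
from torchrl.data import TensorDict  # newer versions: from tensordict import TensorDict
from torchrl.collectors.utils import split_trajectories

# A flat rollout of 6 steps holding two trajectories (ids 0 and 1).
rollout = TensorDict(
    {
        "observation": torch.randn(6, 3),
        "reward": torch.randn(6, 1),
        "done": torch.zeros(6, 1, dtype=torch.bool),
        "traj_ids": torch.tensor([0, 0, 0, 1, 1, 1]),
    },
    batch_size=[6],
)

# Expected outcome per the docstring: a [B, T, ...] zero-padded tensordict,
# with B the number of trajectories and T the longest trajectory.
padded = split_trajectories(rollout)
print(padded.batch_size)  # e.g. torch.Size([2, 3])
```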
10 changes: 7 additions & 3 deletions torchrl/data/postprocs/postprocs.py
@@ -107,8 +107,9 @@ def select_and_repeat(


class MultiStep(nn.Module):
"""
Multistep reward, as presented in 'Sutton, R. S. 1988. Learning to
"""Multistep reward transform.
Presented in 'Sutton, R. S. 1988. Learning to
predict by the methods of temporal differences. Machine learning 3(
1):9–44.'
@@ -140,7 +141,9 @@ def __init__(
)

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"""Args:
"""Re-writes a tensordict following the multi-step transform.
Args:
tensordict: TensorDict instance with Batch x Time-steps x ...
dimensions.
The TensorDict must contain a "reward" and "done" key. All
@@ -160,6 +163,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
- The "reward" values will be replaced by the newly computed
rewards.
Returns:
in-place transformation of the input tensordict.
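
As a reminder of the quantity the transform writes back into the "reward" entry, a plain-Python n-step return illustration (not the module's batched, done-aware implementation):

```python
import torch

def n_step_reward_illustration(reward: torch.Tensor, gamma: float, n_steps: int) -> torch.Tensor:
    """Entry t becomes sum_{k=0}^{n_steps-1} gamma**k * reward[t + k], truncated at T."""
    T = reward.shape[0]
    out = torch.zeros_like(reward)
    for t in range(T):
        for k in range(min(n_steps, T - t)):
            out[t] += (gamma ** k) * reward[t + k]
    return out

# With gamma=0.9 and n_steps=2, rewards [1, 1, 1] become [1.9, 1.9, 1.0].
print(n_step_reward_illustration(torch.tensor([1.0, 1.0, 1.0]), gamma=0.9, n_steps=2))
```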
19 changes: 9 additions & 10 deletions torchrl/data/replay_buffers/rb_prototype.py
@@ -9,13 +9,13 @@
from .replay_buffers import pin_memory_output, stack_tensors, stack_td
from .samplers import Sampler, RandomSampler
from .storages import Storage, ListStorage
from .utils import INT_CLASSES, to_numpy
from .utils import INT_CLASSES, _to_numpy
from .writers import Writer, RoundRobinWriter


class ReplayBuffer:
"""
#TODO: Description of the ReplayBuffer class needed.
"""A generic, composable replay buffer class.
Args:
storage (Storage, optional): the storage to be used. If none is provided
a default ListStorage with max_size of 1_000 will be created.
@@ -73,7 +73,7 @@ def __repr__(self) -> str:

@pin_memory_output
def __getitem__(self, index: Union[int, torch.Tensor]) -> Any:
index = to_numpy(index)
index = _to_numpy(index)
with self._replay_lock:
data = self._storage[index]

@@ -97,8 +97,7 @@ def add(self, data: Any) -> int:
return index

def extend(self, data: Sequence) -> torch.Tensor:
"""Extends the replay buffer with one or more elements contained in
an iterable.
"""Extends the replay buffer with one or more elements contained in an iterable.
Args:
data (iterable): collection of data to be added to the replay
@@ -130,8 +129,8 @@ def _sample(self, batch_size: int) -> Tuple[Any, dict]:
return data, info

def sample(self, batch_size: int) -> Tuple[Any, dict]:
"""
Samples a batch of data from the replay buffer.
"""Samples a batch of data from the replay buffer.
Uses Sampler to sample indices, and retrieves them from Storage.
Args:
@@ -158,8 +157,8 @@ def sample(self, batch_size: int) -> Tuple[Any, dict]:


class TensorDictReplayBuffer(ReplayBuffer):
"""
TensorDict-specific wrapper around the ReplayBuffer class.
"""TensorDict-specific wrapper around the ReplayBuffer class.
Args:
priority_key (str): the key at which priority is assumed to be stored
within TensorDicts added to this ReplayBuffer.
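
A hedged usage sketch of the prototype buffer; the storage and sampler constructor arguments are assumptions based on the Args section and imports shown above:

```python
import torch
from torchrl.data.replay_buffers.rb_prototype import ReplayBuffer
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.data.replay_buffers.storages import ListStorage

# Storage, sampler and writer are composable; unset components fall back to
# the defaults described in the class docstring.
rb = ReplayBuffer(storage=ListStorage(max_size=100), sampler=RandomSampler())

rb.add(torch.zeros(4))                          # single element
rb.extend([torch.ones(4), 2 * torch.ones(4)])   # iterable of elements

data, info = rb.sample(batch_size=2)            # (sampled data, sampler info)
```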