[Doc] Fix tanh_loc in docstrings (pytorch#1203)
vmoens authored May 29, 2023
1 parent f3e9a1d commit 3528ffc
Showing 13 changed files with 20 additions and 20 deletions.
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/samplers.py
@@ -88,7 +88,7 @@ class SamplerWithoutReplacement(Sampler):
        When the sampler reaches the end of the list of available indices, a new sample order
        will be generated and the resulting indices will be completed with this new draw, which
-       can lead to duplicated indices, unless the :obj:`drop_last` argument is set to :obj:`True`.
+       can lead to duplicated indices, unless the :obj:`drop_last` argument is set to ``True``.
    """
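The sampler behavior described in that docstring can be sketched in plain Python. This is a simplified stand-in for `SamplerWithoutReplacement` (which operates on storage indices and tensors), not the actual torchrl implementation; the function name and signature here are illustrative only:

```python
import random

def sample_without_replacement(n_items, batch_size, drop_last=False, seed=0):
    """Yield batches of indices without replacement; when the shuffled pool
    runs short, either drop the remainder (drop_last=True) or complete the
    final batch with a fresh draw, which may duplicate indices."""
    rng = random.Random(seed)
    order = list(range(n_items))
    rng.shuffle(order)
    while len(order) >= batch_size:
        batch, order = order[:batch_size], order[batch_size:]
        yield batch
    if order and not drop_last:
        # New sample order: the leftover indices are completed with a fresh
        # shuffle, so indices from the current pass can appear twice.
        refill = list(range(n_items))
        rng.shuffle(refill)
        yield order + refill[: batch_size - len(order)]
```

With 10 items and a batch size of 4, `drop_last=True` yields two full batches, while `drop_last=False` yields a third batch padded from a new draw.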
2 changes: 1 addition & 1 deletion torchrl/data/tensor_specs.py
@@ -2389,7 +2389,7 @@ class CompositeSpec(TensorSpec):
            This is useful to build nested CompositeSpecs with tuple indices.
        **kwargs (key (str): value (TensorSpec)): dictionary of tensorspecs
            to be stored. Values can be None, in which case is_in will be assumed
-           to be :obj:`True` for the corresponding tensors, and :obj:`project()` will have no
+           to be ``True`` for the corresponding tensors, and :obj:`project()` will have no
            effect. `spec.encode` cannot be used with missing values.

    Examples:
2 changes: 1 addition & 1 deletion torchrl/envs/common.py
@@ -1106,7 +1106,7 @@ def rollout(
        callback (callable, optional): function to be called at each iteration with the given TensorDict.
        auto_reset (bool, optional): if ``True``, resets automatically the environment
            if it is in a done state when the rollout is initiated.
-           Default is :obj:`True`.
+           Default is ``True``.
        auto_cast_to_device (bool, optional): if ``True``, the device of the tensordict is automatically cast to the
            policy device before the policy is used. Default is ``False``.
        break_when_any_done (bool): breaks if any of the done state is True. If False, a reset() is
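The `auto_reset` and `break_when_any_done` semantics in that docstring can be modeled with a toy loop. This is a sketch under stated assumptions: `ToyEnv` and this `rollout` are hypothetical stand-ins, not torchrl's `EnvBase.rollout`, which works on TensorDicts and batched environments:

```python
class ToyEnv:
    """Minimal stand-in environment (not torchrl's EnvBase): counts steps
    and reports done after `horizon` steps."""
    def __init__(self, horizon=3):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return {"step": self.t}

    def step(self, action):
        self.t += 1
        return {"step": self.t, "action": action}, self.t >= self.horizon

def rollout(env, policy, max_steps, auto_reset=True, break_when_any_done=True):
    """Simplified rollout loop: optionally reset up front, then either stop
    at the first done state or reset and keep collecting."""
    if auto_reset:
        env.reset()
    trajectory = []
    for _ in range(max_steps):
        td, done = env.step(policy())
        trajectory.append(td)
        if done:
            if break_when_any_done:
                break
            env.reset()  # a reset() is issued and the rollout continues
    return trajectory
```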
2 changes: 1 addition & 1 deletion torchrl/envs/libs/utils.py
@@ -32,7 +32,7 @@ class GymPixelObservationWrapper(ObservationWrapper):
    Args:
        env: The environment to wrap.
-       pixels_only: If :obj:`True` (default), the original observation returned
+       pixels_only: If ``True`` (default), the original observation returned
            by the wrapped environment will be discarded, and a dictionary
            observation will only include pixels. If ``False``, the
            observation dictionary will contain both the original
2 changes: 1 addition & 1 deletion torchrl/envs/transforms/r3m.py
@@ -213,7 +213,7 @@ class R3MTransform(Compose):
            Defaults to 244.
        stack_images (bool, optional): if False, the images given in the :obj:`in_keys`
            argument will be treaded separetely and each will be given a single,
-           separated entry in the output tensordict. Defaults to :obj:`True`.
+           separated entry in the output tensordict. Defaults to ``True``.
        download (bool, torchvision Weights config or corresponding string):
            if ``True``, the weights will be downloaded using the torch.hub download
            API (i.e. weights will be cached for future use).
2 changes: 1 addition & 1 deletion torchrl/envs/transforms/vip.py
@@ -179,7 +179,7 @@ class VIPTransform(Compose):
            Defaults to 244.
        stack_images (bool, optional): if False, the images given in the :obj:`in_keys`
            argument will be treaded separetely and each will be given a single,
-           separated entry in the output tensordict. Defaults to :obj:`True`.
+           separated entry in the output tensordict. Defaults to ``True``.
        download (bool, torchvision Weights config or corresponding string):
            if ``True``, the weights will be downloaded using the torch.hub download
            API (i.e. weights will be cached for future use).
14 changes: 7 additions & 7 deletions torchrl/modules/distributions/continuous.py
@@ -57,9 +57,9 @@ class IndependentNormal(D.Independent):
            Default is 5.0
-       tanh_loc (bool, optional): if ``True``, the above formula is used for the location scaling, otherwise the raw value
-           is kept.
-           Default is :obj:`True`;
+       tanh_loc (bool, optional): if ``False``, the above formula is used for
+           the location scaling, otherwise the raw value
+           is kept. Default is ``False``;
    """

    num_params: int = 2
@@ -69,7 +69,7 @@ def __init__(
        loc: torch.Tensor,
        scale: torch.Tensor,
        upscale: float = 5.0,
-       tanh_loc: bool = True,
+       tanh_loc: bool = False,
        event_dim: int = 1,
        **kwargs,
    ):
@@ -187,7 +187,7 @@ class TruncatedNormal(D.Independent):
        max (torch.Tensor or number, optional): maximum value of the distribution. Default = 1.0;
        tanh_loc (bool, optional): if ``True``, the above formula is used for
            the location scaling, otherwise the raw value is kept.
-           Default is :obj:`True`;
+           Default is ``False``;
    """

    num_params: int = 2
@@ -204,7 +204,7 @@ def __init__(
        upscale: Union[torch.Tensor, float] = 5.0,
        min: Union[torch.Tensor, float] = -1.0,
        max: Union[torch.Tensor, float] = 1.0,
-       tanh_loc: bool = True,
+       tanh_loc: bool = False,
    ):
        err_msg = "TanhNormal max values must be strictly greater than min values"
        if isinstance(max, torch.Tensor) or isinstance(min, torch.Tensor):
@@ -304,7 +304,7 @@ class TanhNormal(FasterTransformedDistribution):
        event_dims (int, optional): number of dimensions describing the action.
            Default is 1;
        tanh_loc (bool, optional): if ``True``, the above formula is used for the location scaling, otherwise the raw
-           value is kept. Default is :obj:`True`;
+           value is kept. Default is ``False``;
    """

    arg_constraints = {
2 changes: 1 addition & 1 deletion torchrl/modules/models/models.py
@@ -884,7 +884,7 @@ class DdpgCnnQNet(nn.Module):
                'bias_last_layer': True,
            }
        use_avg_pooling (bool, optional): if ``True``, a nn.AvgPooling layer is
-           used to aggregate the output. Default is :obj:`True`.
+           used to aggregate the output. Default is ``True``.
        device (Optional[DEVICE_TYPING]): device to create the module on.
    """
2 changes: 1 addition & 1 deletion torchrl/modules/tensordict_module/probabilistic.py
@@ -210,7 +210,7 @@ class SafeProbabilisticTensorDictSequential(
        of the input keys. If so, the only module that will be executed are those
        who can be executed given the keys that are present. Also, if the input
        tensordict is a lazy stack of tensordicts AND if partial_tolerant is
-       :obj:`True` AND if the stack does not have the required keys, then
+       ``True`` AND if the stack does not have the required keys, then
        TensorDictSequential will scan through the sub-tensordicts looking for those
        that have the required keys, if any.
2 changes: 1 addition & 1 deletion torchrl/modules/tensordict_module/sequence.py
@@ -25,7 +25,7 @@ class SafeSequential(TensorDictSequential, SafeModule):
        partial_tolerant (bool, optional): if ``True``, the input tensordict can miss some of the input keys.
            If so, the only module that will be executed are those who can be executed given the keys that
            are present.
-           Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is :obj:`True` AND if the
+           Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is ``True`` AND if the
            stack does not have the required keys, then SafeSequential will scan through the sub-tensordicts
            looking for those that have the required keys, if any.
2 changes: 1 addition & 1 deletion torchrl/objectives/utils.py
@@ -92,7 +92,7 @@ def distance_loss(
        v2 (Tensor): a tensor with a shape compatible with v1
        loss_function (str): One of "l2", "l1" or "smooth_l1" representing which loss function is to be used.
        strict_shape (bool): if False, v1 and v2 are allowed to have a different shape.
-           Default is :obj:`True`.
+           Default is ``True``.

    Returns:
        A tensor of the shape v1.view_as(v2) or v2.view_as(v1) with values equal to the distance loss between the
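The three loss functions named in that docstring can be illustrated on plain scalars. This is a simplified stand-in for `torchrl.objectives.utils.distance_loss` (which operates on tensors and delegates to the torch functional losses); the smooth-L1 branch assumes the standard Huber form with beta = 1:

```python
def distance_loss(v1, v2, loss_function="l2"):
    """Elementwise distance between two sequences of scalars."""
    diff = [a - b for a, b in zip(v1, v2)]
    if loss_function == "l2":
        return [d ** 2 for d in diff]
    if loss_function == "l1":
        return [abs(d) for d in diff]
    if loss_function == "smooth_l1":
        # Huber-style: quadratic near zero, linear in the tails (beta = 1).
        return [0.5 * d ** 2 if abs(d) < 1 else abs(d) - 0.5 for d in diff]
    raise NotImplementedError(loss_function)
```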
2 changes: 1 addition & 1 deletion torchrl/objectives/value/advantages.py
@@ -750,7 +750,7 @@ class GAE(ValueEstimatorBase):
    GAE will return an :obj:`"advantage"` entry containing the advange value. It will also
    return a :obj:`"value_target"` entry with the return value that is to be used
-   to train the value network. Finally, if :obj:`gradient_mode` is :obj:`True`,
+   to train the value network. Finally, if :obj:`gradient_mode` is ``True``,
    an additional and differentiable :obj:`"value_error"` entry will be returned,
    which simple represents the difference between the return and the value network
    output (i.e. an additional distance loss should be applied to that signed value).
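The `"advantage"` and `"value_target"` entries mentioned above follow the standard GAE recursion. A single-trajectory sketch (pure Python for clarity; torchrl's GAE class is vectorized over tensordicts, and the parameter names here mirror the usual gamma/lambda convention rather than any exact torchrl signature):

```python
def gae(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):
    """Generalized Advantage Estimation over one trajectory. Returns
    (advantage, value_target) lists, mirroring the entries the GAE
    estimator writes into the tensordict."""
    advantage = [0.0] * len(rewards)
    gae_t = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - float(dones[t])
        # TD error: r_t + gamma * V(s_{t+1}) - V(s_t), masked at episode ends.
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        # Exponentially weighted sum of future TD errors.
        gae_t = delta + gamma * lmbda * not_done * gae_t
        advantage[t] = gae_t
    # The value target is the advantage plus the current value estimate.
    value_target = [a + v for a, v in zip(advantage, values)]
    return advantage, value_target
```

With gamma = lambda = 1 this reduces to the Monte-Carlo return minus the value estimate.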
4 changes: 2 additions & 2 deletions torchrl/trainers/trainers.py
@@ -114,12 +114,12 @@ class Trainer:
        clip_grad_norm (bool, optional): If True, the gradients will be clipped
            based on the total norm of the model parameters. If False,
            all the partial derivatives will be clamped to
-           (-clip_norm, clip_norm). Default is :obj:`True`.
+           (-clip_norm, clip_norm). Default is ``True``.
        clip_norm (Number, optional): value to be used for clipping gradients.
            Default is None (no clip norm).
        progress_bar (bool, optional): If True, a progress bar will be
            displayed using tqdm. If tqdm is not installed, this option
-           won't have any effect. Default is :obj:`True`
+           won't have any effect. Default is ``True``
        seed (int, optional): Seed to be used for the collector, pytorch and
            numpy. Default is ``None``.
        save_trainer_interval (int, optional): How often the trainer should be
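The two clipping strategies the `clip_grad_norm` flag selects between can be shown on a flat list of scalar gradients. A hypothetical helper, not the Trainer's actual code (which uses `torch.nn.utils` on parameter tensors):

```python
def clip_gradients(grads, clip_norm, clip_grad_norm=True):
    """clip_grad_norm=True: rescale all gradients so their total L2 norm is
    at most clip_norm (direction preserved). clip_grad_norm=False: clamp
    each partial derivative elementwise to (-clip_norm, clip_norm)."""
    if clip_grad_norm:
        total_norm = sum(g * g for g in grads) ** 0.5
        if total_norm > clip_norm:
            scale = clip_norm / total_norm
            return [g * scale for g in grads]
        return list(grads)
    return [max(-clip_norm, min(clip_norm, g)) for g in grads]
```

Norm-based clipping preserves the gradient direction, while elementwise clamping can change it, which is why the norm-based mode is the default.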
