Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Minor] Fix typos in advantages.py #2492

Merged
merged 1 commit into from
Oct 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions torchrl/objectives/value/advantages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,17 +1177,17 @@ class GAE(ValueEstimatorBase):
device (torch.device, optional): device of the module.
time_dim (int, optional): the dimension corresponding to the time
in the input tensordict. If not provided, defaults to the dimension
markes with the ``"time"`` name if any, and to the last dimension
marked with the ``"time"`` name if any, and to the last dimension
otherwise. Can be overridden during a call to
:meth:`~.value_estimate`.
Negative dimensions are considered with respect to the input
tensordict.

GAE will return an :obj:`"advantage"` entry containing the advange value. It will also
GAE will return an :obj:`"advantage"` entry containing the advantage value. It will also
return a :obj:`"value_target"` entry with the return value that is to be used
to train the value network. Finally, if :obj:`gradient_mode` is ``True``,
an additional and differentiable :obj:`"value_error"` entry will be returned,
which simple represents the difference between the return and the value network
which simply represents the difference between the return and the value network
output (i.e. an additional distance loss should be applied to that signed value).

.. note::
Expand Down Expand Up @@ -1262,7 +1262,7 @@ def forward(
target params to be passed to the functional value network module.
time_dim (int, optional): the dimension corresponding to the time
in the input tensordict. If not provided, defaults to the dimension
markes with the ``"time"`` name if any, and to the last dimension
marked with the ``"time"`` name if any, and to the last dimension
otherwise.
Negative dimensions are considered with respect to the input
tensordict.
Expand Down Expand Up @@ -1310,7 +1310,7 @@ def forward(
"""
if tensordict.batch_dims < 1:
raise RuntimeError(
"Expected input tensordict to have at least one dimensions, got "
"Expected input tensordict to have at least one dimension, got "
f"tensordict.batch_size = {tensordict.batch_size}"
)
reward = tensordict.get(("next", self.tensor_keys.reward))
Expand Down
Loading