[Feature] Make advantages compatible with Terminated, Truncated, Done #1581

Merged
merged 199 commits on Oct 2, 2023

Commits (199)
9e5f303
init
vmoens Sep 15, 2023
03b6089
fix
vmoens Sep 15, 2023
512b596
amend
vmoens Sep 17, 2023
3fa8ac0
amend
vmoens Sep 17, 2023
162aa6e
amend
vmoens Sep 17, 2023
2c5cddc
amend
vmoens Sep 17, 2023
c703a02
amend
vmoens Sep 17, 2023
2b78f49
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 17, 2023
822f42a
amend
vmoens Sep 18, 2023
db8e0c1
amend
vmoens Sep 18, 2023
9c7e1a5
amend
vmoens Sep 18, 2023
99bcc4a
amend
vmoens Sep 18, 2023
9b0069e
lint
vmoens Sep 18, 2023
ac43a7e
fix step counter
vmoens Sep 18, 2023
a822407
amend
vmoens Sep 18, 2023
6cec6e1
amend
vmoens Sep 18, 2023
0612e09
amend
vmoens Sep 18, 2023
02902db
Update torchrl/envs/gym_like.py
vmoens Sep 19, 2023
7e22c55
Update torchrl/envs/gym_like.py
vmoens Sep 19, 2023
53927bc
Update torchrl/envs/gym_like.py
vmoens Sep 19, 2023
43ff66a
amend
vmoens Sep 19, 2023
77dcd09
rollout
vmoens Sep 19, 2023
f404245
fix
vmoens Sep 19, 2023
a79b57c
fix
vmoens Sep 19, 2023
149781f
fix
vmoens Sep 19, 2023
71a05b1
fix
vmoens Sep 19, 2023
87bad08
remove prints
vmoens Sep 19, 2023
92c3814
amend
vmoens Sep 19, 2023
42d4c40
amend
vmoens Sep 19, 2023
aec627c
amend
vmoens Sep 19, 2023
b2303bd
amend
vmoens Sep 19, 2023
f6a497b
amend
vmoens Sep 19, 2023
aac630f
amend
vmoens Sep 19, 2023
606ee3a
lint and fixes
vmoens Sep 19, 2023
6ba0d38
amend
vmoens Sep 19, 2023
76b3f0c
amend
vmoens Sep 19, 2023
035c274
amend
vmoens Sep 19, 2023
8bd932f
amend
vmoens Sep 19, 2023
c789e50
amend
vmoens Sep 19, 2023
3e93f13
amend
vmoens Sep 19, 2023
cba97b1
amend
vmoens Sep 19, 2023
7ec7c78
amend
vmoens Sep 19, 2023
dd4c45e
amend
vmoens Sep 19, 2023
0ea0716
fix robohive
vmoens Sep 19, 2023
16d688e
amend
vmoens Sep 20, 2023
268dbd7
amend
vmoens Sep 20, 2023
d77d1cd
amend
vmoens Sep 20, 2023
15bd9fa
amend
vmoens Sep 20, 2023
1b656f7
amend
vmoens Sep 20, 2023
2f13c95
amend
vmoens Sep 20, 2023
aa5de06
amend
vmoens Sep 20, 2023
284262f
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 20, 2023
afcc527
amend
vmoens Sep 20, 2023
9eae41d
amend
vmoens Sep 20, 2023
9210f3b
amend
vmoens Sep 20, 2023
b31b2f0
amend
vmoens Sep 21, 2023
4e8acc0
init
vmoens Sep 21, 2023
c8579f9
init
vmoens Sep 21, 2023
d95989c
Merge branch 'fix_dreamer_tests' into threads_mp
vmoens Sep 21, 2023
e22c318
Merge branch 'terminal_truncated' into myo_threaded
vmoens Sep 21, 2023
a5e9ce3
prints
vmoens Sep 21, 2023
b6e83d5
amend
vmoens Sep 21, 2023
acf89e6
amend
vmoens Sep 21, 2023
16fba2e
amend
vmoens Sep 21, 2023
697c523
amend
vmoens Sep 21, 2023
369492d
fix
vmoens Sep 21, 2023
c50263c
Merge branch 'main' into terminal_truncated
vmoens Sep 21, 2023
de82499
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 21, 2023
e14dcd1
fix
vmoens Sep 21, 2023
4a8424c
amend
vmoens Sep 21, 2023
5056b9e
amend
vmoens Sep 21, 2023
ce26e13
amend
vmoens Sep 21, 2023
5a95850
amend
vmoens Sep 21, 2023
5e38d70
Update torchrl/collectors/collectors.py
vmoens Sep 21, 2023
9eb1c98
amend
vmoens Sep 21, 2023
bccbf67
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 21, 2023
a285780
amend
vmoens Sep 21, 2023
40a8e83
Merge branch 'terminal_truncated' into myo_threaded
vmoens Sep 21, 2023
d8f9505
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 21, 2023
e1eba40
Merge remote-tracking branch 'origin/main' into threads_mp
vmoens Sep 21, 2023
9f58a8d
lint
vmoens Sep 21, 2023
1c4f35f
amend
vmoens Sep 21, 2023
93bd2e6
Merge branch 'threads_mp' into terminal_truncated
vmoens Sep 21, 2023
f6e09e3
amend
vmoens Sep 21, 2023
2cd07c1
amend
vmoens Sep 22, 2023
acf6118
amend
vmoens Sep 22, 2023
bb52ce1
tests
vmoens Sep 22, 2023
0d0bc3c
Merge branch 'main' into terminal_truncated
vmoens Sep 22, 2023
3ef139b
amend
vmoens Sep 22, 2023
0d3ba02
amend
vmoens Sep 22, 2023
0b32209
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 22, 2023
0638e13
amend
vmoens Sep 22, 2023
a3bd3e1
tmp
vmoens Sep 22, 2023
a8ed5e5
amend
vmoens Sep 22, 2023
54db14e
amend
vmoens Sep 22, 2023
0122fac
amend
vmoens Sep 23, 2023
b18e6e7
amend
vmoens Sep 23, 2023
2f99f16
amend
vmoens Sep 23, 2023
6aa0c9e
amend
vmoens Sep 23, 2023
2b42070
amend
vmoens Sep 23, 2023
7279b12
amend
vmoens Sep 23, 2023
f5ab14d
amend
vmoens Sep 23, 2023
4a9b6b9
amend
vmoens Sep 23, 2023
696324b
amend
vmoens Sep 23, 2023
8890911
Update docs/source/reference/envs.rst
vmoens Sep 24, 2023
e24c2f3
add doc
vmoens Sep 24, 2023
c029f12
amend
vmoens Sep 24, 2023
b37129d
amend
vmoens Sep 24, 2023
9afc783
amend
vmoens Sep 24, 2023
f65622a
amend
vmoens Sep 24, 2023
989eecf
fix VIP
vmoens Sep 24, 2023
77559e0
lint
vmoens Sep 24, 2023
117e41e
osx_skips
vmoens Sep 24, 2023
19fdc33
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 24, 2023
8326674
amend
vmoens Sep 24, 2023
90da3fd
refactor tests
vmoens Sep 25, 2023
f4ddb92
Refactoring: terminated, truncated, done
vmoens Sep 25, 2023
7988118
amend
vmoens Sep 25, 2023
85d5035
let _step return partial done in batched envs
vmoens Sep 25, 2023
b4fa338
fix mocking classes
vmoens Sep 25, 2023
648792d
more fixes
vmoens Sep 25, 2023
9777332
fix step count in equivalence test
vmoens Sep 25, 2023
9063b5b
fix transforms
vmoens Sep 25, 2023
f47e49d
fix transforms
vmoens Sep 25, 2023
c238e9b
fix transformed env
vmoens Sep 25, 2023
502a59b
amend
vmoens Sep 25, 2023
73af141
amend
vmoens Sep 25, 2023
983e246
amend
vmoens Sep 26, 2023
555bca9
amend
vmoens Sep 26, 2023
85f58de
remove calls to done_key
vmoens Sep 26, 2023
6fc39d0
fix step counter
vmoens Sep 26, 2023
175b701
vec envs
vmoens Sep 26, 2023
92e7826
vec envs
vmoens Sep 26, 2023
cd6eaea
amend
vmoens Sep 26, 2023
23e139b
amend
vmoens Sep 26, 2023
677c408
amend
vmoens Sep 26, 2023
3c79081
d4rl
vmoens Sep 26, 2023
54e75b0
d4rl unsqueeze
vmoens Sep 26, 2023
244e3d3
amend
vmoens Sep 26, 2023
09abb71
amend
vmoens Sep 26, 2023
384db24
minor
vmoens Sep 26, 2023
3693cac
amend
vmoens Sep 26, 2023
ca54133
amend
vmoens Sep 26, 2023
0fdc522
amend
vmoens Sep 26, 2023
fdad78f
test_terminated_or_truncated_spec
vmoens Sep 26, 2023
f454e11
more fixes
vmoens Sep 26, 2023
88cee59
--capture no
vmoens Sep 26, 2023
57ccb63
attempt to limit collector idle time
vmoens Sep 26, 2023
e5b0d23
lint
vmoens Sep 26, 2023
dfe726f
amend
vmoens Sep 26, 2023
4f6ce90
amend
vmoens Sep 26, 2023
b16b939
amend
vmoens Sep 26, 2023
daaaddd
amend
vmoens Sep 26, 2023
cd4811f
amend
vmoens Sep 26, 2023
c0f3137
fixes
vmoens Sep 26, 2023
8f9d8fe
fix r3m, vip and vc1
vmoens Sep 27, 2023
4fdf437
fix robohive, d4rl
vmoens Sep 27, 2023
4f44579
amend
vmoens Sep 27, 2023
7787b28
amend
vmoens Sep 27, 2023
5c964a9
amend
vmoens Sep 27, 2023
03001bc
amend
vmoens Sep 27, 2023
04bbee9
lint
vmoens Sep 27, 2023
6ac830e
adapt tests
vmoens Sep 27, 2023
6fbe8c0
amend
vmoens Sep 27, 2023
28fefb6
lint
vmoens Sep 27, 2023
8256e2f
fix gym 0.19
vmoens Sep 27, 2023
1619b09
missing deps
vmoens Sep 27, 2023
23926b2
fix gym truncated
vmoens Sep 27, 2023
e3b8253
fix gym truncated (bis)
vmoens Sep 27, 2023
2a4a1b6
amend
vmoens Sep 27, 2023
7f4c38b
amend
vmoens Sep 27, 2023
2e54626
Merge branch 'main' into terminal_truncated
vmoens Sep 27, 2023
977488e
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 27, 2023
4f932e0
amend
vmoens Sep 27, 2023
5e6a0f0
amend
vmoens Sep 27, 2023
411f097
lint
vmoens Sep 27, 2023
39ed4c3
final (?)
vmoens Sep 28, 2023
e56f9b7
Merge branch 'terminal_truncated' into truncated_vals
vmoens Sep 28, 2023
aa3707d
amend
vmoens Sep 28, 2023
85bc738
amend
vmoens Sep 28, 2023
21ea856
addressing review
vmoens Sep 28, 2023
71f2d86
Merge branch 'terminal_truncated' into truncated_vals
vmoens Sep 28, 2023
448728b
amend
vmoens Sep 28, 2023
f0ee4dd
more fixes
vmoens Sep 28, 2023
7906387
amend
vmoens Sep 28, 2023
72c1240
cloning dones
vmoens Sep 28, 2023
2dde687
Merge branch 'terminal_truncated' into truncated_vals
vmoens Sep 28, 2023
0d65bfa
amend
vmoens Sep 28, 2023
234b680
amend
vmoens Sep 28, 2023
2c7ffb0
amend
vmoens Sep 29, 2023
1bdccfc
Merge branch 'terminal_truncated' into truncated_vals
vmoens Sep 29, 2023
2f5b598
amend
vmoens Sep 29, 2023
b21f0ae
amend
vmoens Sep 29, 2023
1aeed2a
Merge remote-tracking branch 'origin/main' into truncated_vals
vmoens Oct 2, 2023
e635ba2
amend
vmoens Oct 2, 2023
fef45e6
fix marl examples
vmoens Oct 2, 2023
5d5b740
amend
vmoens Oct 2, 2023
42a25da
amend
vmoens Oct 2, 2023
4a8cc84
amend
vmoens Oct 2, 2023
fix marl examples
vmoens committed Oct 2, 2023
commit fef45e6ebc5ee6bc5a10b880ae863a44261514ac
13 changes: 6 additions & 7 deletions examples/multiagent/iql.py
@@ -21,6 +21,7 @@
 from torchrl.modules.models.multiagent import MultiAgentMLP
 from torchrl.objectives import DQNLoss, SoftUpdate, ValueEstimators
 from utils.logging import init_logging, log_evaluation, log_training
+from utils.utils import DoneTransform


 def rendering_callback(env, td):
@@ -118,13 +119,18 @@ def train(cfg: "DictConfig"):  # noqa: F821
         sampler=SamplerWithoutReplacement(),
         batch_size=cfg.train.minibatch_size,
     )
+    replay_buffer.append_transform(
+        DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys)
+    )

     loss_module = DQNLoss(qnet, delay_value=True)
     loss_module.set_keys(
         action_value=("agents", "action_value"),
         action=env.action_key,
         value=("agents", "chosen_action_value"),
         reward=env.reward_key,
+        done="done_expand",
+        terminated="terminated_expand",
     )
     loss_module.make_value_estimator(ValueEstimators.TD0, gamma=cfg.loss.gamma)
     target_net_updater = SoftUpdate(loss_module, eps=1 - cfg.loss.tau)
@@ -144,13 +150,6 @@ def train(cfg: "DictConfig"):  # noqa: F821

         sampling_time = time.time() - sampling_start

-        tensordict_data.set(
-            ("next", "done"),
-            tensordict_data.get(("next", "done"))
-            .unsqueeze(-1)
-            .expand(tensordict_data.get(("next", env.reward_key)).shape),
-        )  # We need to expand the done to match the reward shape
-
         current_frames = tensordict_data.numel()
         total_frames += current_frames
         data_view = tensordict_data.reshape(-1)
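The deleted block above expanded the env-level done flag inline so it matched the per-agent reward shape; DoneTransform (added below in utils/utils.py) now performs the same expansion for every done key. A minimal sketch of that shape manipulation, with hypothetical shapes (8 envs, 100 steps, 4 agents, scalar reward):

import torch

# Hypothetical shapes, for illustration only.
reward = torch.zeros(8, 100, 4, 1)               # per-agent ("next", ("agents", "reward"))
done = torch.zeros(8, 100, 1, dtype=torch.bool)  # env-level ("next", "done")

# Broadcast the env-level done flag across the agent dimension so it can be
# consumed alongside the per-agent reward by the loss/value estimator:
done_expand = done.unsqueeze(-1).expand(reward.shape)
assert done_expand.shape == torch.Size([8, 100, 4, 1])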
13 changes: 6 additions & 7 deletions examples/multiagent/maddpg_iddpg.py
@@ -26,6 +26,7 @@
 from torchrl.modules.models.multiagent import MultiAgentMLP
 from torchrl.objectives import DDPGLoss, SoftUpdate, ValueEstimators
 from utils.logging import init_logging, log_evaluation, log_training
+from utils.utils import DoneTransform


 def rendering_callback(env, td):
@@ -140,10 +141,15 @@ def train(cfg: "DictConfig"):  # noqa: F821
         sampler=SamplerWithoutReplacement(),
         batch_size=cfg.train.minibatch_size,
     )
+    replay_buffer.append_transform(
+        DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys)
+    )

     loss_module = DDPGLoss(
         actor_network=policy, value_network=value_module, delay_value=True
     )
+    loss_module.set_keys(done="done_expand", terminated="terminated_expand")
+
     loss_module.set_keys(
         state_action_value=("agents", "state_action_value"),
         reward=env.reward_key,
@@ -170,13 +176,6 @@ def train(cfg: "DictConfig"):  # noqa: F821

         sampling_time = time.time() - sampling_start

-        tensordict_data.set(
-            ("next", "done"),
-            tensordict_data.get(("next", "done"))
-            .unsqueeze(-1)
-            .expand(tensordict_data.get(("next", env.reward_key)).shape),
-        )  # We need to expand the done to match the reward shape
-
         current_frames = tensordict_data.numel()
         total_frames += current_frames
         data_view = tensordict_data.reshape(-1)
17 changes: 8 additions & 9 deletions examples/multiagent/mappo_ippo.py
@@ -22,7 +22,7 @@
 from torchrl.modules.models.multiagent import MultiAgentMLP
 from torchrl.objectives import ClipPPOLoss, ValueEstimators
 from utils.logging import init_logging, log_evaluation, log_training
-
+from utils.utils import DoneTransform

 def rendering_callback(env, td):
     env.frames.append(env.render(mode="rgb_array", agent_index_focus=None))
@@ -126,6 +126,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
         storing_device=cfg.train.device,
         frames_per_batch=cfg.collector.frames_per_batch,
         total_frames=cfg.collector.total_frames,
+        postproc=DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys),
     )

     replay_buffer = TensorDictReplayBuffer(
@@ -142,7 +143,12 @@ def train(cfg: "DictConfig"):  # noqa: F821
         entropy_coef=cfg.loss.entropy_eps,
         normalize_advantage=False,
     )
-    loss_module.set_keys(reward=env.reward_key, action=env.action_key)
+    loss_module.set_keys(
+        reward=env.reward_key,
+        action=env.action_key,
+        done="done_expand",
+        terminated="terminated_expand",
+    )
     loss_module.make_value_estimator(
         ValueEstimators.GAE, gamma=cfg.loss.gamma, lmbda=cfg.loss.lmbda
     )
@@ -165,13 +171,6 @@ def train(cfg: "DictConfig"):  # noqa: F821

         sampling_time = time.time() - sampling_start

-        tensordict_data.set(
-            ("next", "done"),
-            tensordict_data.get(("next", "done"))
-            .unsqueeze(-1)
-            .expand(tensordict_data.get(("next", env.reward_key)).shape),
-        )  # We need to expand the done to match the reward shape
-
         with torch.no_grad():
             loss_module.value_estimator(
                 tensordict_data,
15 changes: 8 additions & 7 deletions examples/multiagent/sac.py
@@ -23,6 +23,7 @@
 from torchrl.modules.models.multiagent import MultiAgentMLP
 from torchrl.objectives import DiscreteSACLoss, SACLoss, SoftUpdate, ValueEstimators
 from utils.logging import init_logging, log_evaluation, log_training
+from utils.utils import DoneTransform


 def rendering_callback(env, td):
@@ -186,6 +187,9 @@ def train(cfg: "DictConfig"):  # noqa: F821
         sampler=SamplerWithoutReplacement(),
         batch_size=cfg.train.minibatch_size,
     )
+    replay_buffer.append_transform(
+        DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys)
+    )

     if cfg.env.continuous_actions:
         loss_module = SACLoss(
@@ -198,6 +202,8 @@ def train(cfg: "DictConfig"):  # noqa: F821
             state_action_value=("agents", "state_action_value"),
             action=env.action_key,
             reward=env.reward_key,
+            done="done_expand",
+            terminated="terminated_expand",
         )
     else:
         loss_module = DiscreteSACLoss(
@@ -211,6 +217,8 @@ def train(cfg: "DictConfig"):  # noqa: F821
             action_value=("agents", "action_value"),
             action=env.action_key,
             reward=env.reward_key,
+            done="done_expand",
+            terminated="terminated_expand",
         )

     loss_module.make_value_estimator(ValueEstimators.TD0, gamma=cfg.loss.gamma)
@@ -235,13 +243,6 @@ def train(cfg: "DictConfig"):  # noqa: F821

         sampling_time = time.time() - sampling_start

-        tensordict_data.set(
-            ("next", "done"),
-            tensordict_data.get(("next", "done"))
-            .unsqueeze(-1)
-            .expand(tensordict_data.get(("next", env.reward_key)).shape),
-        )  # We need to expand the done to match the reward shape
-
         current_frames = tensordict_data.numel()
         total_frames += current_frames
         data_view = tensordict_data.reshape(-1)
32 changes: 32 additions & 0 deletions examples/multiagent/utils/utils.py
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torchrl.envs import Transform


def append_suffix(key, suffix):
    # Append a suffix to a key, recursing into the last element of nested
    # tuple keys: "done" -> "done_expand", ("agents", "done") -> ("agents", "done_expand").
    if isinstance(key, str):
        return key + suffix
    return key[:-1] + (append_suffix(key[-1], suffix),)


class DoneTransform(Transform):
    """Expands the 'done' entries (incl. terminated) to match the reward shape.

    Can be appended to a replay buffer or a collector.
    """

    def __init__(self, reward_key, done_keys):
        super().__init__()
        self.reward_key = reward_key
        self.done_keys = done_keys

    def forward(self, tensordict):
        for done_key in self.done_keys:
            tensordict.set(
                ("next", append_suffix(done_key, "_expand")),
                tensordict.get(("next", done_key))
                .unsqueeze(-1)
                .expand(tensordict.get(("next", self.reward_key)).shape),
            )
        return tensordict
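A rough usage sketch (keys and buffer arguments here are illustrative, not taken verbatim from one of the example scripts): the transform can be appended to a replay buffer, or passed to a collector via postproc as in mappo_ippo.py above.

from torchrl.data import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

# Illustrative keys; the example scripts use env.reward_key and env.done_keys.
rb = TensorDictReplayBuffer(sampler=SamplerWithoutReplacement(), batch_size=64)
rb.append_transform(
    DoneTransform(reward_key=("agents", "reward"), done_keys=["done", "terminated"])
)
# The loss module is then pointed at the expanded entries:
# loss_module.set_keys(done="done_expand", terminated="terminated_expand")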
17 changes: 10 additions & 7 deletions torchrl/objectives/value/advantages.py
@@ -242,9 +242,9 @@ def forward(
             and ``"next"`` tensordict state as returned by the environment)
             necessary to compute the value estimates and the TDEstimate.
             The data passed to this module should be structured as
-            :obj:`[*B, T, F]` where :obj:`B` are
+            :obj:`[*B, T, *F]` where :obj:`B` are
             the batch size, :obj:`T` the time dimension and :obj:`F` the
-            feature dimension(s).
+            feature dimension(s). The tensordict must have shape ``[*B, T]``.
         params (TensorDictBase, optional): A nested TensorDict containing the params
             to be passed to the functional value network module.
         target_params (TensorDictBase, optional): A nested TensorDict containing the
@@ -500,9 +500,9 @@ def forward(
             tensordict state as returned by the environment) necessary to
             compute the value estimates and the TDEstimate.
             The data passed to this module should be structured as
-            :obj:`[*B, T, F]` where :obj:`B` are
+            :obj:`[*B, T, *F]` where :obj:`B` are
             the batch size, :obj:`T` the time dimension and :obj:`F` the
-            feature dimension(s).
+            feature dimension(s). The tensordict must have shape ``[*B, T]``.
         params (TensorDictBase, optional): A nested TensorDict containing the params
             to be passed to the functional value network module.
         target_params (TensorDictBase, optional): A nested TensorDict containing the
@@ -701,8 +701,9 @@ def forward(
             ``("next", "done")``, ``("next", "terminated")``,
             and ``"next"`` tensordict state as returned by the environment)
             necessary to compute the value estimates and the TDEstimate.
-            The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` are
+            The data passed to this module should be structured as :obj:`[*B, T, *F]` where :obj:`B` are
             the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s).
+            The tensordict must have shape ``[*B, T]``.
         params (TensorDictBase, optional): A nested TensorDict containing the params
             to be passed to the functional value network module.
         target_params (TensorDictBase, optional): A nested TensorDict containing the
@@ -910,8 +911,9 @@ def forward(
             ``("next", "done")``, ``("next", "terminated")``,
             and ``"next"`` tensordict state as returned by the environment)
             necessary to compute the value estimates and the TDLambdaEstimate.
-            The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` are
+            The data passed to this module should be structured as :obj:`[*B, T, *F]` where :obj:`B` are
             the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s).
+            The tensordict must have shape ``[*B, T]``.
         params (TensorDictBase, optional): A nested TensorDict containing the params
             to be passed to the functional value network module.
         target_params (TensorDictBase, optional): A nested TensorDict containing the
@@ -1150,8 +1152,9 @@ def forward(
             ``("next", "done")``, ``("next", "terminated")``,
             and ``"next"`` tensordict state as returned by the environment)
             necessary to compute the value estimates and the GAE.
-            The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` are
+            The data passed to this module should be structured as :obj:`[*B, T, *F]` where :obj:`B` are
             the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s).
+            The tensordict must have shape ``[*B, T]``.
         params (TensorDictBase, optional): A nested TensorDict containing the params
             to be passed to the functional value network module.
         target_params (TensorDictBase, optional): A nested TensorDict containing the
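A minimal sketch of the documented layout (the shapes and the toy value network are hypothetical): the tensordict carries batch and time in its batch size ``[*B, T]``, while feature dimensions live only in the leaf tensors.

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.objectives.value import GAE

B, T = 4, 10  # batch of 4 trajectories, 10 time steps
data = TensorDict(
    {
        "observation": torch.randn(B, T, 3),
        "next": {
            "observation": torch.randn(B, T, 3),
            "reward": torch.randn(B, T, 1),
            "done": torch.zeros(B, T, 1, dtype=torch.bool),
            "terminated": torch.zeros(B, T, 1, dtype=torch.bool),
        },
    },
    batch_size=[B, T],  # [*B, T]: the feature dims (3 and 1) stay out of the batch size
)
value_net = TensorDictModule(
    torch.nn.Linear(3, 1), in_keys=["observation"], out_keys=["state_value"]
)
gae = GAE(gamma=0.99, lmbda=0.95, value_network=value_net)
gae(data)  # writes "advantage" and "value_target" into data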
2 changes: 1 addition & 1 deletion torchrl/objectives/value/functional.py
@@ -521,7 +521,7 @@ def td1_return_estimate(
     gamma = gamma * not_terminated
     g = next_state_value[..., -1, :]
     for i in reversed(range(T)):
-        # if not done and not terminated, get the bootstrapped value
+        # if not done (and hence not terminated), get the bootstrapped value
         # if done but not terminated, get next_val
         # if terminated, take nothing (gamma = 0)
         dnt = done_but_not_terminated[..., i, :]
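The distinction drawn in these comments can be illustrated with a one-step backup (a hedged sketch, not torchrl's implementation): only a terminated step zeroes the bootstrap term, while a truncated step (done but not terminated) still bootstraps from the next state's value.

import torch

def one_step_return(reward, next_value, terminated, gamma=0.99):
    # terminated -> no bootstrapping (discount effectively 0);
    # truncated (done but not terminated) -> keep bootstrapping from next_value.
    return reward + gamma * (~terminated).float() * next_value

reward = torch.tensor([1.0, 1.0])
next_value = torch.tensor([2.0, 2.0])
terminated = torch.tensor([False, True])
print(one_step_return(reward, next_value, terminated))
# tensor([2.9800, 1.0000])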