Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Algorithm] IMPALA and VTrace module #1506

Merged
merged 118 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
118 commits
Select commit Hold shift + click to select a range
550d397
vtrace module
albertbou92 Sep 6, 2023
2a693a7
vtrace module
albertbou92 Sep 6, 2023
7a8ee38
vtrace module
albertbou92 Sep 6, 2023
a6434fd
vtrace module
albertbou92 Sep 6, 2023
3e76450
vtrace module
albertbou92 Sep 6, 2023
7613cb9
vtrace module
albertbou92 Sep 6, 2023
8c70b0a
vtrace module
albertbou92 Sep 6, 2023
e8dd5be
vtrace module
albertbou92 Sep 6, 2023
3ace31f
vtrace module
albertbou92 Sep 7, 2023
3efe601
impala
albertbou92 Sep 7, 2023
2458927
impala example
albertbou92 Sep 7, 2023
c6a60cc
impala example
albertbou92 Sep 7, 2023
f157766
impala example
albertbou92 Sep 7, 2023
f0cf4f5
impala example
albertbou92 Sep 7, 2023
ae6fb62
docs clarifications
albertbou92 Sep 18, 2023
081c2da
docs
albertbou92 Sep 19, 2023
89a5a9e
fixes
albertbou92 Sep 19, 2023
888fbcb
fixes
albertbou92 Sep 19, 2023
f3f9832
config
albertbou92 Sep 19, 2023
5b7d642
fixes
albertbou92 Sep 19, 2023
7dca35e
fixes
albertbou92 Sep 19, 2023
e8c35ef
fixes
albertbou92 Sep 19, 2023
a4af09f
fixes
albertbou92 Sep 20, 2023
8bf4787
move vtrace to adv script
albertbou92 Sep 21, 2023
8648e10
tests
albertbou92 Sep 21, 2023
dfc1c82
tests
albertbou92 Sep 22, 2023
3da748e
tests
albertbou92 Sep 22, 2023
a568378
fix
albertbou92 Sep 22, 2023
596c6cc
format
albertbou92 Sep 22, 2023
ee692f5
working impala script
albertbou92 Sep 25, 2023
a5eb8b6
working impala script
albertbou92 Sep 25, 2023
b9e81d2
test offpolicy losses
albertbou92 Sep 25, 2023
d04d050
minor script fixes
albertbou92 Sep 25, 2023
2a15708
test onpolicy losses
albertbou92 Sep 25, 2023
a5c2046
test fix
albertbou92 Sep 25, 2023
dbde27c
test fix
albertbou92 Sep 25, 2023
7411235
test fix
albertbou92 Sep 26, 2023
fa5f835
test fix
albertbou92 Sep 26, 2023
30e0cc1
test fix
albertbou92 Oct 2, 2023
0b9ed5c
fixes
albertbou92 Oct 2, 2023
15034f8
merge main
albertbou92 Oct 2, 2023
b3c0c9e
multi node
albertbou92 Oct 2, 2023
c634112
multi node
albertbou92 Oct 2, 2023
df16ace
multi node
albertbou92 Oct 2, 2023
3403e29
fix tests
albertbou92 Oct 3, 2023
9da24eb
fix tests
albertbou92 Oct 3, 2023
795620f
fix tests
albertbou92 Oct 3, 2023
bdc2392
Merge branch 'main' into vtrace
albertbou92 Oct 3, 2023
53aceba
merge main
albertbou92 Oct 3, 2023
02cebf6
multinode script
albertbou92 Oct 3, 2023
5c0aec0
call actor func
albertbou92 Oct 3, 2023
c8ef2c7
faster scripts
albertbou92 Oct 3, 2023
e024c09
multinode script
albertbou92 Oct 3, 2023
55b7947
simplify utils
albertbou92 Oct 3, 2023
6d6df00
revert tests
albertbou92 Oct 4, 2023
d4536d1
Merge branch 'main' into vtrace
albertbou92 Oct 4, 2023
5ebcfb8
Merge branch 'main' into vtrace
albertbou92 Oct 4, 2023
9e1d64b
introduce review feedback
albertbou92 Oct 4, 2023
224ae91
torch compile
albertbou92 Oct 4, 2023
e543888
torch compile
albertbou92 Oct 4, 2023
db541c0
fix
albertbou92 Oct 4, 2023
937b819
fix
albertbou92 Oct 4, 2023
9e33035
tests
albertbou92 Oct 4, 2023
199bc3b
adapt ppo tests
albertbou92 Oct 5, 2023
f1b11dd
adapt ppo tests
albertbou92 Oct 5, 2023
1d8d1ef
adapt ppo tests
albertbou92 Oct 5, 2023
ebf74b8
fix tests ppo
albertbou92 Oct 5, 2023
1180993
fix tests a2c
albertbou92 Oct 5, 2023
6e73acd
fix tests a2c
albertbou92 Oct 5, 2023
2ecb103
fix tests a2c
albertbou92 Oct 5, 2023
cd07719
fix tests reinforce
albertbou92 Oct 5, 2023
f1d2770
fix tests values
albertbou92 Oct 5, 2023
676d8f5
fix tests values
albertbou92 Oct 5, 2023
f491e7d
fix tests adv
albertbou92 Oct 5, 2023
7a63dd6
fix tests adv
albertbou92 Oct 5, 2023
d30bb9d
fix tests adv
albertbou92 Oct 5, 2023
e19c671
Merge branch 'main' into vtrace
albertbou92 Oct 5, 2023
a9e1db3
code examples
albertbou92 Oct 5, 2023
32cd518
code examples
albertbou92 Oct 5, 2023
53617bb
fix tests adv
albertbou92 Oct 5, 2023
0bc7b8c
fix tests adv
albertbou92 Oct 5, 2023
40cc02f
code examples tests
albertbou92 Oct 5, 2023
cbd923e
code examples tests
albertbou92 Oct 5, 2023
2235c02
code example with submitit
albertbou92 Oct 5, 2023
1a8efd1
code example with submitit
albertbou92 Oct 5, 2023
3ef4001
code example with submitit
albertbou92 Oct 5, 2023
fcc1121
code example with submitit
albertbou92 Oct 5, 2023
dd2a7f3
code example with submitit
albertbou92 Oct 5, 2023
624b2d6
code example with submitit
albertbou92 Oct 5, 2023
7e30069
code example with submitit
albertbou92 Oct 5, 2023
8d6c064
Merge branch 'main' into vtrace
albertbou92 Oct 5, 2023
157ad9b
Merge branch 'main' into vtrace
albertbou92 Nov 14, 2023
597623b
fix logging
albertbou92 Nov 14, 2023
5c21c1e
fix example
albertbou92 Nov 19, 2023
e47dbc3
fix example
albertbou92 Nov 19, 2023
607ad53
Merge branch 'main' into vtrace
albertbou92 Nov 22, 2023
c23401a
Update examples/impala/impala_multi_node_ray.py
albertbou92 Nov 22, 2023
886b4e0
Update torchrl/objectives/value/advantages.py
albertbou92 Nov 22, 2023
803fc4f
Update torchrl/objectives/value/advantages.py
albertbou92 Nov 22, 2023
72a3c6e
Update torchrl/objectives/value/advantages.py
albertbou92 Nov 22, 2023
4a061b5
Update examples/impala/impala_multi_node_ray.py
albertbou92 Nov 22, 2023
e7069e4
Update examples/impala/impala_multi_node_ray.py
albertbou92 Nov 22, 2023
5399cf1
merge main
albertbou92 Nov 22, 2023
39584ab
fixes
albertbou92 Nov 22, 2023
c68fd40
fixes
albertbou92 Nov 22, 2023
638c0d6
format
albertbou92 Nov 22, 2023
2f8b545
fixes
albertbou92 Nov 22, 2023
6ddcb3a
fixes
albertbou92 Nov 22, 2023
94306bf
fixes
albertbou92 Nov 22, 2023
9cc0284
fixes
albertbou92 Nov 23, 2023
63392f0
fixes
albertbou92 Nov 23, 2023
e61f342
fixes
albertbou92 Nov 23, 2023
6d384d5
fixes
albertbou92 Nov 23, 2023
89770e4
submitit example
albertbou92 Nov 23, 2023
9132a60
submitit example
albertbou92 Nov 23, 2023
89a803b
README
albertbou92 Nov 23, 2023
bd02b30
fix tests
albertbou92 Nov 23, 2023
0a382bb
fix unused_args
albertbou92 Nov 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
vtrace module
  • Loading branch information
albertbou92 committed Sep 6, 2023
commit 8c70b0a5931a873f454a35c12a3750902370c759
2 changes: 1 addition & 1 deletion examples/impala/impala.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def main(cfg: "DictConfig"): # noqa: F821
gamma=cfg.loss.gamma,
value_network=critic,
actor_network=actor,
average_gae=False,
average_adv=False,
)
loss_module = ClipPPOLoss(
actor=actor,
Expand Down
119 changes: 61 additions & 58 deletions torchrl/objectives/value/vtrace.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@
from tensordict.utils import NestedKey
from torch import nn, Tensor
from torchrl.objectives.utils import hold_out_net
from torchrl.objectives.value.advantages import ValueEstimatorBase, _self_set_skip_existing, _self_set_grad_enabled, _call_value_nets
from torchrl.objectives.value.functional import _transpose_time, SHAPE_ERR
from torchrl.objectives.value.advantages import (
ValueEstimatorBase,
_self_set_skip_existing,
_self_set_grad_enabled,
_call_value_nets)
from torchrl.objectives.value.functional import _transpose_time, SHAPE_ERR, td0_return_estimate


def _c_val(
log_pi: torch.Tensor,
log_mu: torch.Tensor,
c: Union[float, torch.Tensor] = 1,
) -> torch.Tensor:
return (log_pi - log_mu).clamp_max(math.log(c)).exp().unsqueeze(-1)

return (log_pi - log_mu).clamp_max(math.log(c)).exp().unsqueeze(-1) # TODO: Review!

def _dv_val(
rewards: torch.Tensor,
Expand Down Expand Up @@ -80,8 +83,8 @@ def vtrace_correction(
next_state_value: torch.Tensor,
reward: torch.Tensor,
done: torch.Tensor,
rho_bar: Union[float, torch.Tensor] = 1.0,
c_bar: Union[float, torch.Tensor] = 1.0,
rho_thresh: Union[float, torch.Tensor] = 1.0,
c_thresh: Union[float, torch.Tensor] = 1.0,
time_dim: int = -2,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""V-Trace off-policy correction method.
Expand Down Expand Up @@ -112,8 +115,10 @@ def vtrace_correction(
dtype = next_state_value.dtype
device = state_value.device

delta, rho = _dv_val(reward, state_value, next_state_value, gamma, rho_bar, log_pi, log_mu)
cs = _c_val(log_pi, log_mu, c_bar)
delta, clipped_rho = _dv_val(reward, state_value, next_state_value, gamma, rho_thresh, log_pi, log_mu)
torch.clamp(torch.exp(log_rhos), max=clip_c_thres)

clipped_cs = _c_val(log_pi, log_mu, c_thresh)

not_done = (~done).int()
*batch_size, time_steps, lastdim = not_done.shape
albertbou92 marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -124,6 +129,8 @@ def vtrace_correction(
for t in reversed(range(time_steps)):
# TODO: Review!
acc = delta[..., t, :] + (gnotdone[..., t, :] * acc * cs[..., t, :])


v_out[..., t, :].copy_(acc + state_value[..., t, :])

advantage = rho * (reward + gamma * v_out - state_value) # TODO: Review!
Expand All @@ -140,7 +147,8 @@ class VTrace(ValueEstimatorBase):
Args:
gamma (scalar): exponential mean discount.
value_network (TensorDictModule): value operator used to retrieve the value estimates.
average_gae (bool): if ``True``, the resulting GAE values will be standardized.
actor_network (TensorDictModule, optional): actor operator used to retrieve the log prob.
average_adv (bool): if ``True``, the resulting advantage values will be standardized.
Default is ``False``.
differentiable (bool, optional): if ``True``, gradients are propagated through
the computation of the value function. Default is ``False``.
Expand All @@ -149,10 +157,6 @@ class VTrace(ValueEstimatorBase):
The proper way to make the function call non-differentiable is to
decorate it in a `torch.no_grad()` context manager/decorator or
pass detached parameters for functional modules.

# vectorized (bool, optional): whether to use the vectorized version of the
# lambda return. Default is `True`.

skip_existing (bool, optional): if ``True``, the value network will skip
modules which outputs are already present in the tensordict.
Defaults to ``None``, ie. the value of :func:`tensordict.nn.skip_existing()`
Expand All @@ -164,27 +168,21 @@ class VTrace(ValueEstimatorBase):
of the advantage entry. Defaults to ``"value_target"``.
value_key (str or tuple of str, optional): [Deprecated] the value key to
read from the input tensordict. Defaults to ``"state_value"``.

# shifted (bool, optional): if ``True``, the value and next value are
# estimated with a single call to the value network. This is faster
# but is only valid whenever (1) the ``"next"`` value is shifted by
# only one time step (which is not the case with multi-step value
# estimation, for instance) and (2) when the parameters used at time
# ``t`` and ``t+1`` are identical (which is not the case when target
# parameters are to be used). Defaults to ``False``.
shifted (bool, optional): if ``True``, the value and next value are
estimated with a single call to the value network. This is faster
but is only valid whenever (1) the ``"next"`` value is shifted by
only one time step (which is not the case with multi-step value
estimation, for instance) and (2) when the parameters used at time
``t`` and ``t+1`` are identical (which is not the case when target
parameters are to be used). Defaults to ``False``.

VTrace will return an :obj:`"advantage"` entry containing the advantage value. It will also
return a :obj:`"value_target"` entry with the V-Trace target value.

# Finally, if :obj:`gradient_mode` is ``True``,
# an additional and differentiable :obj:`"value_error"` entry will be returned,
# which simple represents the difference between the return and the value network
# output (i.e. an additional distance loss should be applied to that signed value).

# .. note::
# As other advantage functions do, if the ``value_key`` is already present
# in the input tensordict, the VTrace module will ignore the calls to the value
# network (if any) and use the provided value instead.
.. note::
As other advantage functions do, if the ``value_key`` is already present
in the input tensordict, the VTrace module will ignore the calls to the value
network (if any) and use the provided value instead.

"""

Expand All @@ -196,10 +194,10 @@ def __init__(
c_bar: Union[float, torch.Tensor] = 1.0,
actor_network: TensorDictModule = None,
value_network: TensorDictModule,
average_gae: bool = False,
average_adv: bool = False,
differentiable: bool = False,
vectorized: bool = False, # TODO: Review!
skip_existing: Optional[bool] = None,
log_prob_key: NestedKey = "sample_log_prob", # TODO: should be added to _AcceptedKeys?
advantage_key: NestedKey = None,
value_target_key: NestedKey = None,
value_key: NestedKey = None,
Expand All @@ -221,9 +219,13 @@ def __init__(
self.register_buffer("gamma", torch.tensor(gamma, device=device))
self.register_buffer("rho_bar", torch.tensor(rho_bar, device=device))
self.register_buffer("c_bar", torch.tensor(c_bar, device=device))
self.average_gae = average_gae
self.vectorized = vectorized
self.average_adv = average_adv
self.actor_network = actor_network
self._log_prob_key = log_prob_key

@property
def log_prob_key(self):
return self._log_prob_key

@_self_set_skip_existing
@_self_set_grad_enabled
Expand Down Expand Up @@ -322,34 +324,35 @@ def forward(
value = tensordict.get(self.tensor_keys.value)
next_value = tensordict.get(("next", self.tensor_keys.value))

# TODO: raise ValueError if log_mu is not present
log_mu = tensordict.get("sample_log_prob")
# Make sure we have the log prob computed at collection time
if self.log_prob_key not in tensordict.keys():
raise ValueError(f"Expected {self.log_prob_key} to be in tensordict")
log_mu = tensordict.get(self.log_prob_key)

# Make sure we have the new log prob and the old log prob
if self.actor_network is not None:
# TODO: review
log_pi = self.actor_network(tensordict.select(self.actor_network.in_keys)).get("sample_log_prob") # old / distributed log prob
else:
log_pi = tensordict.get("sample_log_prob") # new / local log prob # TODO: Review!
# Compute the current log prob
with hold_out_net(self.actor_network):
log_pi = self.actor_network(
tensordict.select(self.actor_network.in_keys)
).get(self.log_prob_key)

# Compute the V-Trace correction
done = tensordict.get(("next", self.tensor_keys.done))
if self.vectorized:
raise NotImplementedError
else:
adv, value_target = vtrace_correction(
gamma,
log_pi,
log_mu,
value,
next_value,
reward,
done,
rho_bar=self.rho_bar,
c_bar=self.c_bar,
time_dim=tensordict.ndim - 1,
)
adv, value_target = vtrace_correction(
gamma,
log_pi,
log_mu,
value,
next_value,
reward,
done,
rho_bar=self.rho_bar,
c_bar=self.c_bar,
time_dim=tensordict.ndim - 1,
)

# TODO: where are returns computed?

if self.average_gae:
if self.average_adv:
loc = adv.mean()
scale = adv.std().clamp_min(1e-4)
adv = adv - loc
Expand Down