[Algorithm] IMPALA and VTrace module #1506

Merged
merged 118 commits into from
Nov 23, 2023
550d397
vtrace module
albertbou92 Sep 6, 2023
2a693a7
vtrace module
albertbou92 Sep 6, 2023
7a8ee38
vtrace module
albertbou92 Sep 6, 2023
a6434fd
vtrace module
albertbou92 Sep 6, 2023
3e76450
vtrace module
albertbou92 Sep 6, 2023
7613cb9
vtrace module
albertbou92 Sep 6, 2023
8c70b0a
vtrace module
albertbou92 Sep 6, 2023
e8dd5be
vtrace module
albertbou92 Sep 6, 2023
3ace31f
vtrace module
albertbou92 Sep 7, 2023
3efe601
impala
albertbou92 Sep 7, 2023
2458927
impala example
albertbou92 Sep 7, 2023
c6a60cc
impala example
albertbou92 Sep 7, 2023
f157766
impala example
albertbou92 Sep 7, 2023
f0cf4f5
impala example
albertbou92 Sep 7, 2023
ae6fb62
docs clarifications
albertbou92 Sep 18, 2023
081c2da
docs
albertbou92 Sep 19, 2023
89a5a9e
fixes
albertbou92 Sep 19, 2023
888fbcb
fixes
albertbou92 Sep 19, 2023
f3f9832
config
albertbou92 Sep 19, 2023
5b7d642
fixes
albertbou92 Sep 19, 2023
7dca35e
fixes
albertbou92 Sep 19, 2023
e8c35ef
fixes
albertbou92 Sep 19, 2023
a4af09f
fixes
albertbou92 Sep 20, 2023
8bf4787
move vtrace to adv script
albertbou92 Sep 21, 2023
8648e10
tests
albertbou92 Sep 21, 2023
dfc1c82
tests
albertbou92 Sep 22, 2023
3da748e
tests
albertbou92 Sep 22, 2023
a568378
fix
albertbou92 Sep 22, 2023
596c6cc
format
albertbou92 Sep 22, 2023
ee692f5
working impala script
albertbou92 Sep 25, 2023
a5eb8b6
working impala script
albertbou92 Sep 25, 2023
b9e81d2
test offpolicy losses
albertbou92 Sep 25, 2023
d04d050
minor script fixes
albertbou92 Sep 25, 2023
2a15708
test onpolicy losses
albertbou92 Sep 25, 2023
a5c2046
test fix
albertbou92 Sep 25, 2023
dbde27c
test fix
albertbou92 Sep 25, 2023
7411235
test fix
albertbou92 Sep 26, 2023
fa5f835
test fix
albertbou92 Sep 26, 2023
30e0cc1
test fix
albertbou92 Oct 2, 2023
0b9ed5c
fixes
albertbou92 Oct 2, 2023
15034f8
merge main
albertbou92 Oct 2, 2023
b3c0c9e
multi node
albertbou92 Oct 2, 2023
c634112
multi node
albertbou92 Oct 2, 2023
df16ace
multi node
albertbou92 Oct 2, 2023
3403e29
fix tests
albertbou92 Oct 3, 2023
9da24eb
fix tests
albertbou92 Oct 3, 2023
795620f
fix tests
albertbou92 Oct 3, 2023
bdc2392
Merge branch 'main' into vtrace
albertbou92 Oct 3, 2023
53aceba
merge main
albertbou92 Oct 3, 2023
02cebf6
multinode script
albertbou92 Oct 3, 2023
5c0aec0
call actor func
albertbou92 Oct 3, 2023
c8ef2c7
faster scripts
albertbou92 Oct 3, 2023
e024c09
multinode script
albertbou92 Oct 3, 2023
55b7947
simplify utils
albertbou92 Oct 3, 2023
6d6df00
revert tests
albertbou92 Oct 4, 2023
d4536d1
Merge branch 'main' into vtrace
albertbou92 Oct 4, 2023
5ebcfb8
Merge branch 'main' into vtrace
albertbou92 Oct 4, 2023
9e1d64b
introduce review feedback
albertbou92 Oct 4, 2023
224ae91
torch compile
albertbou92 Oct 4, 2023
e543888
torch compile
albertbou92 Oct 4, 2023
db541c0
fix
albertbou92 Oct 4, 2023
937b819
fix
albertbou92 Oct 4, 2023
9e33035
tests
albertbou92 Oct 4, 2023
199bc3b
adapt ppo tests
albertbou92 Oct 5, 2023
f1b11dd
adapt ppo tests
albertbou92 Oct 5, 2023
1d8d1ef
adapt ppo tests
albertbou92 Oct 5, 2023
ebf74b8
fix tests ppo
albertbou92 Oct 5, 2023
1180993
fix tests a2c
albertbou92 Oct 5, 2023
6e73acd
fix tests a2c
albertbou92 Oct 5, 2023
2ecb103
fix tests a2c
albertbou92 Oct 5, 2023
cd07719
fix tests reinforce
albertbou92 Oct 5, 2023
f1d2770
fix tests values
albertbou92 Oct 5, 2023
676d8f5
fix tests values
albertbou92 Oct 5, 2023
f491e7d
fix tests adv
albertbou92 Oct 5, 2023
7a63dd6
fix tests adv
albertbou92 Oct 5, 2023
d30bb9d
fix tests adv
albertbou92 Oct 5, 2023
e19c671
Merge branch 'main' into vtrace
albertbou92 Oct 5, 2023
a9e1db3
code examples
albertbou92 Oct 5, 2023
32cd518
code examples
albertbou92 Oct 5, 2023
53617bb
fix tests adv
albertbou92 Oct 5, 2023
0bc7b8c
fix tests adv
albertbou92 Oct 5, 2023
40cc02f
code examples tests
albertbou92 Oct 5, 2023
cbd923e
code examples tests
albertbou92 Oct 5, 2023
2235c02
code example with submitit
albertbou92 Oct 5, 2023
1a8efd1
code example with submitit
albertbou92 Oct 5, 2023
3ef4001
code example with submitit
albertbou92 Oct 5, 2023
fcc1121
code example with submitit
albertbou92 Oct 5, 2023
dd2a7f3
code example with submitit
albertbou92 Oct 5, 2023
624b2d6
code example with submitit
albertbou92 Oct 5, 2023
7e30069
code example with submitit
albertbou92 Oct 5, 2023
8d6c064
Merge branch 'main' into vtrace
albertbou92 Oct 5, 2023
157ad9b
Merge branch 'main' into vtrace
albertbou92 Nov 14, 2023
597623b
fix logging
albertbou92 Nov 14, 2023
5c21c1e
fix example
albertbou92 Nov 19, 2023
e47dbc3
fix example
albertbou92 Nov 19, 2023
607ad53
Merge branch 'main' into vtrace
albertbou92 Nov 22, 2023
c23401a
Update examples/impala/impala_multi_node_ray.py
albertbou92 Nov 22, 2023
886b4e0
Update torchrl/objectives/value/advantages.py
albertbou92 Nov 22, 2023
803fc4f
Update torchrl/objectives/value/advantages.py
albertbou92 Nov 22, 2023
72a3c6e
Update torchrl/objectives/value/advantages.py
albertbou92 Nov 22, 2023
4a061b5
Update examples/impala/impala_multi_node_ray.py
albertbou92 Nov 22, 2023
e7069e4
Update examples/impala/impala_multi_node_ray.py
albertbou92 Nov 22, 2023
5399cf1
merge main
albertbou92 Nov 22, 2023
39584ab
fixes
albertbou92 Nov 22, 2023
c68fd40
fixes
albertbou92 Nov 22, 2023
638c0d6
format
albertbou92 Nov 22, 2023
2f8b545
fixes
albertbou92 Nov 22, 2023
6ddcb3a
fixes
albertbou92 Nov 22, 2023
94306bf
fixes
albertbou92 Nov 22, 2023
9cc0284
fixes
albertbou92 Nov 23, 2023
63392f0
fixes
albertbou92 Nov 23, 2023
e61f342
fixes
albertbou92 Nov 23, 2023
6d384d5
fixes
albertbou92 Nov 23, 2023
89770e4
submitit example
albertbou92 Nov 23, 2023
9132a60
submitit example
albertbou92 Nov 23, 2023
89a803b
README
albertbou92 Nov 23, 2023
bd02b30
fix tests
albertbou92 Nov 23, 2023
0a382bb
fix unused_args
albertbou92 Nov 23, 2023
vtrace module
albertbou92 committed Sep 6, 2023
commit 3e7645002f5d74ebdfb89293b40a638738ade00f
24 changes: 14 additions & 10 deletions torchrl/objectives/value/vtrace.py
@@ -27,11 +27,12 @@ def _c_val(
 log_mu: torch.Tensor,
 c: Union[float, torch.Tensor] = 1,
 ) -> torch.Tensor:
-return (log_pi - log_mu).clamp_max(math.log(c)).exp()
+return (log_pi - log_mu).clamp_max(math.log(c)).exp().unsqueeze(-1)


 def _dv_val(
 rewards: torch.Tensor,
+vals: torch.Tensor,
 next_vals: torch.Tensor,
 gamma: Union[float, torch.Tensor],
 rho_bar: Union[float, torch.Tensor],
@@ -70,7 +71,7 @@ def _vtrace(
 v_out = torch.stack(list(reversed(v_out)), 1) # values
 return v_out, rho

-@_transpose_time
+# @_transpose_time
 def vtrace_correction(
 gamma: float,
 log_pi: torch.Tensor,
@@ -104,23 +105,25 @@ def vtrace_correction(
``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.

"""

if not (next_state_value.shape == state_value.shape == reward.shape == done.shape):
raise RuntimeError(SHAPE_ERR)

dtype = next_state_value.dtype
device = state_value.device

import ipdb; ipdb.set_trace()
delta, rho = _dv_val(reward, next_state_value, gamma, rho_bar, log_pi, log_mu)
c = _c_val(log_pi, log_mu, c_bar)
delta, rho = _dv_val(reward, state_value, next_state_value, gamma, rho_bar, log_pi, log_mu)
cs = _c_val(log_pi, log_mu, c_bar)

not_done = (~done).int()
*batch_size, time_steps, lastdim = not_done.shape
acc = 0
v_out = torch.empty(*batch_size, time_steps, lastdim, device=device, dtype=dtype)
gnotdone = gamma * not_done
import ipdb; ipdb.set_trace() # TODO: Review!
for t in reversed(range(time_steps)):
import ipdb; ipdb.set_trace() # TODO: Review!
acc = delta[..., t, :] + (gnotdone[..., t, :] * acc * c)
acc = delta[..., t, :] + (gnotdone[..., t, :] * acc * cs[t])
v_out[..., t, :].copy_(acc + state_value[..., t, :])

advantage = rho * (reward + gamma * v_out - state_value) # TODO: Review!
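The backward loop in this hunk appears to follow the V-trace recursion from the IMPALA paper, v_s - V(x_s) = delta_s + gamma * c_s * (v_{s+1} - V(x_{s+1})), with delta_s = rho_s * (r_s + gamma * V(x_{s+1}) - V(x_s)). Below is a minimal pure-Python sketch of that recursion for a single trajectory. It is an illustration under assumptions, not the torchrl implementation: `vtrace_targets` is a hypothetical name, inputs are plain lists rather than tensors, and the bootstrap term is assumed to be masked on terminal steps (the body of `_dv_val` is not shown in this diff).

```python
import math


def vtrace_targets(rewards, values, next_values, dones, log_pi, log_mu,
                   gamma=0.99, rho_bar=1.0, c_bar=1.0):
    """V-trace targets and advantages for one trajectory (hypothetical helper)."""
    steps = len(rewards)
    # Clipped importance weights: rho_t = min(rho_bar, pi/mu), c_t = min(c_bar, pi/mu).
    ratios = [math.exp(lp - lm) for lp, lm in zip(log_pi, log_mu)]
    rhos = [min(rho_bar, r) for r in ratios]
    cs = [min(c_bar, r) for r in ratios]

    targets = [0.0] * steps
    advantages = [0.0] * steps
    acc = 0.0  # running value of v_{t+1} - V(x_{t+1})
    next_target = next_values[-1]  # bootstrap value after the last step
    for t in reversed(range(steps)):
        # Assumption: no bootstrapping across a terminal step.
        not_done = 0.0 if dones[t] else 1.0
        delta = rhos[t] * (rewards[t] + gamma * not_done * next_values[t] - values[t])
        acc = delta + gamma * not_done * cs[t] * acc
        targets[t] = values[t] + acc
        # Advantage as written in the paper: uses the next-step target v_{t+1}.
        advantages[t] = rhos[t] * (rewards[t] + gamma * not_done * next_target - values[t])
        next_target = targets[t]
    return targets, advantages


# On-policy check: identical log-probs give ratios of 1, so each target
# should satisfy v_t = r_t + gamma * v_{t+1}, i.e. an n-step return.
targets, advantages = vtrace_targets(
    rewards=[1.0, 1.0, 1.0],
    values=[0.5, 0.4, 0.3],
    next_values=[0.4, 0.3, 0.2],
    dones=[False, False, False],
    log_pi=[-1.0, -1.0, -1.0],
    log_mu=[-1.0, -1.0, -1.0],
    gamma=0.9,
)
```

Two points may be worth checking while resolving the TODOs. If a leading batch dimension is present, `cs[t]` indexes the first axis rather than time, whereas `delta[..., t, :]` slices time. Also, the paper's policy-gradient advantage uses the next-step target, rho_s * (r_s + gamma * v_{s+1} - V(x_s)), while the `advantage` line above uses `v_out` at the same step.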
@@ -195,7 +198,7 @@ def __init__(
 value_network: TensorDictModule,
 average_gae: bool = False,
 differentiable: bool = False,
-vectorized: bool = True,
+vectorized: bool = False, # TODO: Review!
 skip_existing: Optional[bool] = None,
 advantage_key: NestedKey = None,
 value_target_key: NestedKey = None,
@@ -324,9 +327,10 @@ def forward(
 import ipdb; ipdb.set_trace()
 raise NotImplementedError
 else:
-import ipdb; ipdb.set_trace()
-log_pi = tensordict.get(self.tensor_keys.log_pi) # new / local log prob
-log_mu = tensordict.get(self.tensor_keys.log_mu) # old / distributed log prob
+# log_pi = tensordict.get(self.tensor_keys.log_pi) # new / local log prob
+log_pi = tensordict.get("sample_log_prob")
+# log_mu = tensordict.get(self.tensor_keys.log_mu) # old / distributed log prob
+log_mu = tensordict.get("sample_log_prob")

 done = tensordict.get(("next", self.tensor_keys.done))
 if self.vectorized:
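Because this commit reads both `log_pi` and `log_mu` from `"sample_log_prob"`, the log-ratio is zero everywhere, so the clipped weights computed by `_c_val` collapse to min(c, 1). The sketch below shows that clamp in log space. It uses scalars and a hypothetical `clipped_ratio` helper rather than the tensor version in the diff.

```python
import math


def clipped_ratio(log_pi, log_mu, bar=1.0):
    """Return min(bar, pi/mu) by clamping the log-ratio before exponentiating."""
    return math.exp(min(log_pi - log_mu, math.log(bar)))


same = clipped_ratio(-1.2, -1.2)   # identical log-probs: ratio is exactly 1
above = clipped_ratio(-0.5, -1.5)  # ratio e > 1, clipped down to the bar
below = clipped_ratio(-2.0, -1.0)  # ratio 1/e < 1, left unchanged
```

With both thresholds at 1 or above, identical keys make the correction a no-op. That is convenient while debugging shapes, but the off-policy path is presumably not exercised until the separate `log_pi` and `log_mu` keys are restored.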