[Algorithm] Online Decision transformer #1149

Merged
merged 136 commits into from
Aug 30, 2023
7d004a0
set struc
BY571 Mar 31, 2023
520b8fb
architecture test
BY571 Apr 4, 2023
4859735
Merge branch 'main' into decision_transformer
BY571 Apr 14, 2023
d19a3e0
Merge branch 'main' into decision_transformer
BY571 Apr 21, 2023
18d3035
update dt transforms
BY571 Apr 21, 2023
d521fa2
update padding
BY571 Apr 21, 2023
c123fe0
take off outputhead
BY571 Apr 24, 2023
cfcc073
update target and testscript
BY571 Apr 26, 2023
2c314c5
Merge branch 'main' into decision_transformer
BY571 Apr 26, 2023
e377ae8
add r2g
BY571 Apr 26, 2023
8b69d6a
update context mask
BY571 Apr 28, 2023
72dc7c8
Merge branch 'main' into decision_transformer
BY571 May 2, 2023
7b9d029
add offline example script first tests
BY571 May 2, 2023
0672e1c
Merge branch 'main' into decision_transformer
BY571 May 4, 2023
e2fb927
Update objective loss
BY571 May 4, 2023
2c657cc
Merge branch 'main' into decision_transformer
BY571 May 5, 2023
69b0974
updates
BY571 May 5, 2023
a5e5da7
add objective
BY571 May 11, 2023
34fc6e8
fix
BY571 May 11, 2023
0200e29
small fixes
BY571 May 12, 2023
001413c
Merge branch 'main' into decision_transformer
BY571 May 12, 2023
9470797
update DT loss docstring
BY571 May 12, 2023
6b8185d
update dt inference wrapper docstring with example
BY571 May 12, 2023
76e3a27
add odt cost tests
BY571 May 12, 2023
247cfd6
Merge branch 'main' into decision_transformer
BY571 May 18, 2023
082a75e
try to add inverse catframes
BY571 May 18, 2023
2b636a6
as_inverse add to catframes
BY571 May 19, 2023
b1788f5
make dt / odt split
BY571 May 22, 2023
c6e3229
add dt odt script
BY571 May 22, 2023
aaa09dd
add dt config
BY571 May 23, 2023
45cbd61
split config
BY571 Jun 1, 2023
54e2b98
merge main and update
BY571 Jun 2, 2023
86ddc44
fix
BY571 Jun 2, 2023
112e800
Merge branch 'main' into decision_transformer
BY571 Jun 2, 2023
170ab13
fix
BY571 Jun 2, 2023
d5177cd
Merge branch 'decision_transformer' of https://github.com/BY571/rl in…
BY571 Jun 2, 2023
1fcbf0e
description catframes
BY571 Jun 2, 2023
165459d
add dt test
BY571 Jun 2, 2023
50f0aa8
add cfg to logger
BY571 Jun 2, 2023
3cc456e
take off detach
BY571 Jun 8, 2023
e890264
Merge branch 'main' into decision_transformer
BY571 Jun 12, 2023
0497449
add loss to docs
BY571 Jun 12, 2023
b24a7f8
update proof_env creation
BY571 Jun 12, 2023
8e04add
move batch to device
BY571 Jun 12, 2023
2ad7af5
remove gpt2model and import directly from hf
BY571 Jun 12, 2023
aeccb22
update docstring
BY571 Jun 12, 2023
2414e9b
update actor docstring
BY571 Jun 12, 2023
b03f3fe
add dispach, in-out-keys
BY571 Jun 12, 2023
e5c4575
update inference actor inputs
BY571 Jun 12, 2023
a5213ce
add inference wrapper to docs
BY571 Jun 12, 2023
6a6b18e
Merge branch 'main' into decision_transformer
BY571 Jun 26, 2023
1f9f885
fix _data
BY571 Jun 27, 2023
792d35c
extract lamb opti
BY571 Jun 27, 2023
0d9fa42
add DT args and example in docstring
BY571 Jun 27, 2023
83642c7
update constant target return and reduction
BY571 Jun 27, 2023
39dda00
fixes for target return transform
BY571 Jun 27, 2023
c5c71e6
update add transformers installed check
BY571 Jun 27, 2023
ca36a0f
update docstring actor DT
BY571 Jun 27, 2023
9c0dfbb
add docstring for modules and examples
BY571 Jun 27, 2023
ddb284e
udpate config
BY571 Jun 27, 2023
cf5de9a
take off unsqueeze in models
BY571 Jun 28, 2023
c3d0ffa
add loss function to config
BY571 Jun 29, 2023
e4ea278
add loss function to config
BY571 Jun 29, 2023
d2c1b08
update loss module
BY571 Jun 29, 2023
a62a647
udpate DT actor docstring
BY571 Jun 29, 2023
2009060
add default transformer config
BY571 Jun 29, 2023
40522c1
merge main
BY571 Jun 29, 2023
623d79a
Merge branch 'main' into decision_transformer
vmoens Jul 3, 2023
77630bd
amend
vmoens Jul 3, 2023
f2defcb
doc
vmoens Jul 3, 2023
f891bd2
tests
vmoens Jul 3, 2023
cf5bc01
lint
vmoens Jul 3, 2023
6d4b591
fix
vmoens Jul 3, 2023
7c0df55
amend
vmoens Jul 3, 2023
d3a3d77
amend
vmoens Jul 4, 2023
8730244
Merge remote-tracking branch 'origin/main' into decision_transformer
vmoens Jul 6, 2023
b1c73da
amend
vmoens Jul 6, 2023
2ec7b0f
fix tests
vmoens Jul 6, 2023
8b8f7b1
fix tests
vmoens Jul 6, 2023
a00aae4
Merge remote-tracking branch 'origin/main' into decision_transformer
vmoens Jul 6, 2023
f49d07d
mesalib glew glfw libosmesa6-dev
vmoens Jul 6, 2023
ff4c34a
libosmesa6-dev
vmoens Jul 6, 2023
40024ed
patchelf
vmoens Jul 6, 2023
dffe5fc
temp hiding
vmoens Jul 6, 2023
2657662
Merge branch 'main' into decision_transformer
vmoens Jul 7, 2023
81d9b34
amend
vmoens Jul 7, 2023
c75eb39
amend
vmoens Jul 7, 2023
540d82b
amend
vmoens Jul 7, 2023
091a119
amend
vmoens Jul 7, 2023
87866a7
amend
vmoens Jul 7, 2023
f0606de
Merge branch 'main' into decision_transformer
vmoens Jul 7, 2023
24c129f
amend
vmoens Jul 7, 2023
ad8d412
empty
vmoens Jul 7, 2023
4a64716
fix wandb
vmoens Jul 7, 2023
f00359d
Merge remote-tracking branch 'origin/main' into decision_transformer_ssh
vmoens Jul 7, 2023
edaa7b5
lint
vmoens Jul 7, 2023
8abb8f3
amend
vmoens Jul 7, 2023
dfcff63
amend
vmoens Jul 7, 2023
395456c
amend
vmoens Jul 7, 2023
d58675e
amend
vmoens Jul 7, 2023
244e429
amend
vmoens Jul 7, 2023
c9338b3
amend
vmoens Jul 7, 2023
cdadf46
Added list of D4RL datasets
MateuszGuzek Jul 10, 2023
4bada6f
Merge remote-tracking branch 'origin/main' into decision_transformer_ssh
vmoens Jul 10, 2023
311d00d
minor
vmoens Jul 10, 2023
587cff6
amend
vmoens Jul 10, 2023
98688b6
Merge branch 'd4rl_direct_download' into decision_transformer_ssh
vmoens Jul 10, 2023
7342c83
amend
vmoens Jul 10, 2023
18c6b00
amend
vmoens Jul 10, 2023
c4c02e6
revert d4rl
vmoens Jul 10, 2023
d67a822
amend
vmoens Jul 11, 2023
29a1067
amend
vmoens Jul 11, 2023
0b8d564
amend
vmoens Jul 11, 2023
3988ebf
Merge remote-tracking branch 'origin/main' into decision_transformer
vmoens Jul 11, 2023
b08d3d4
fix
vmoens Jul 11, 2023
4e57244
Merge remote-tracking branch 'origin/main' into decision_transformer
vmoens Jul 11, 2023
a522db0
fix reward scale, reduce target return config
BY571 Jul 12, 2023
17a86d7
Merge branch 'decision_transformer' of https://github.com/BY571/rl in…
BY571 Jul 12, 2023
aefbf61
amend
vmoens Jul 13, 2023
9afb0a7
Merge branch 'decision_transformer' of https://github.com/BY571/rl in…
vmoens Jul 13, 2023
1c7cbbf
amend
vmoens Jul 13, 2023
11d8779
zero padding, fix obs loc, std for normalization
BY571 Jul 26, 2023
094808a
Merge branch 'decision_transformer' of https://github.com/BY571/rl in…
BY571 Jul 26, 2023
b383339
Merge branch 'main' into decision_transformer
vmoens Jul 28, 2023
3ff2fc6
temp - SerialEnv
vmoens Jul 30, 2023
c43a02d
merge main into branch
BY571 Aug 2, 2023
9135fa7
fix obs norm, fix action context
BY571 Aug 4, 2023
c3a67c8
update buffer transforms to not use catframes
BY571 Aug 7, 2023
ca505eb
test dist, small fixes
BY571 Aug 14, 2023
b260785
update utils
BY571 Aug 22, 2023
a820015
update and fixes
BY571 Aug 23, 2023
a29c3b4
Merge branch 'main' into decision_transformer
BY571 Aug 23, 2023
6220c04
pull changes
BY571 Aug 23, 2023
17093b7
running examples
vmoens Aug 26, 2023
a717c8e
update header, docs and delete dtwrapper
BY571 Aug 28, 2023
3846a21
Merge branch 'decision_transformer' of https://github.com/BY571/rl in…
BY571 Aug 28, 2023
add r2g
BY571 committed Apr 26, 2023
commit e377ae8360028f2233dc86fd852f11c6001782fa
1 change: 1 addition & 0 deletions examples/decision_transformer/dt_online.py
@@ -63,6 +63,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
    print(td_test)

    collector = make_collector(cfg, policy=actor)

    replay_buffer = make_replay_buffer(cfg.replay_buffer)
    for data in collector:
        data_view = data.reshape(-1)
11 changes: 5 additions & 6 deletions examples/decision_transformer/utils.py
@@ -4,14 +4,15 @@

 from torchrl.collectors import SyncDataCollector
 from torchrl.data import CompositeSpec, LazyMemmapStorage, TensorDictReplayBuffer
-from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
+from torchrl.data.replay_buffers.samplers import RandomSampler
 from torchrl.envs import (
     CatFrames,
     EnvCreator,
     NoopResetEnv,
     ObservationNorm,
     ParallelEnv,
     RenameTransform,
+    Reward2GoTransform,
     StepCounter,
     TargetReturn,
     TensorDictPrimer,
@@ -168,12 +169,10 @@ def make_collector(cfg, policy):


 def make_replay_buffer(rb_cfg):
-    if rb_cfg.prb:
-        sampler = PrioritizedSampler(max_capacity=rb_cfg.capacity, alpha=0.7, beta=0.5)
-    else:
-        sampler = RandomSampler()
+    r2g = Reward2GoTransform(gamma=1.0, out_keys=["return_to_go"])
+    sampler = RandomSampler()
     return TensorDictReplayBuffer(
-        storage=LazyMemmapStorage(rb_cfg.capacity), sampler=sampler
+        storage=LazyMemmapStorage(rb_cfg.capacity), sampler=sampler, transform=r2g
     )

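For readers unfamiliar with the quantity behind the new `return_to_go` key: the return-to-go at step t is the sum of (optionally discounted) rewards from t to the end of the episode. The sketch below is a plain-Python illustration of that definition, not TorchRL's `Reward2GoTransform` implementation. The function name `reward_to_go` and the flat list layout with a per-step `dones` flag are assumptions made for the example.

```python
from typing import List, Sequence


def reward_to_go(
    rewards: Sequence[float],
    dones: Sequence[bool],
    gamma: float = 1.0,
) -> List[float]:
    """Return the reward-to-go for each step of a flat batch of transitions.

    Hypothetical helper for illustration only. `dones[t]` is True on the
    last step of an episode, so the running sum restarts there.
    """
    if len(rewards) != len(dones):
        raise ValueError("rewards and dones must have the same length")

    out = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each step can reuse the return of the next step.
    for t in reversed(range(len(rewards))):
        if dones[t]:
            # Last step of an episode: nothing from later steps leaks in.
            running = 0.0
        running = rewards[t] + gamma * running
        out[t] = running
    return out


# Two episodes stored back to back, undiscounted as in the diff (gamma=1.0).
rtg = reward_to_go(
    rewards=[1.0, 1.0, 1.0, 2.0, 2.0],
    dones=[False, False, True, False, True],
    gamma=1.0,
)
print(rtg)
```

With `gamma=1.0`, as in this commit, the value is simply the remaining undiscounted episode return, which is the conditioning signal a Decision Transformer expects.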