[Algorithm] Online Decision transformer #1149

Merged 136 commits on Aug 30, 2023
Changes from 1 commit (of 136 commits)
7d004a0
set struc
BY571 Mar 31, 2023
520b8fb
architecture test
BY571 Apr 4, 2023
4859735
Merge branch 'main' into decision_transformer
BY571 Apr 14, 2023
d19a3e0
Merge branch 'main' into decision_transformer
BY571 Apr 21, 2023
18d3035
update dt transforms
BY571 Apr 21, 2023
d521fa2
update padding
BY571 Apr 21, 2023
c123fe0
take off outputhead
BY571 Apr 24, 2023
cfcc073
update target and testscript
BY571 Apr 26, 2023
2c314c5
Merge branch 'main' into decision_transformer
BY571 Apr 26, 2023
e377ae8
add r2g
BY571 Apr 26, 2023
8b69d6a
update context mask
BY571 Apr 28, 2023
72dc7c8
Merge branch 'main' into decision_transformer
BY571 May 2, 2023
7b9d029
add offline example script first tests
BY571 May 2, 2023
0672e1c
Merge branch 'main' into decision_transformer
BY571 May 4, 2023
e2fb927
Update objective loss
BY571 May 4, 2023
2c657cc
Merge branch 'main' into decision_transformer
BY571 May 5, 2023
69b0974
updates
BY571 May 5, 2023
a5e5da7
add objective
BY571 May 11, 2023
34fc6e8
fix
BY571 May 11, 2023
0200e29
small fixes
BY571 May 12, 2023
001413c
Merge branch 'main' into decision_transformer
BY571 May 12, 2023
9470797
update DT loss docstring
BY571 May 12, 2023
6b8185d
update dt inference wrapper docstring with example
BY571 May 12, 2023
76e3a27
add odt cost tests
BY571 May 12, 2023
247cfd6
Merge branch 'main' into decision_transformer
BY571 May 18, 2023
082a75e
try to add inverse catframes
BY571 May 18, 2023
2b636a6
as_inverse add to catframes
BY571 May 19, 2023
b1788f5
make dt / odt split
BY571 May 22, 2023
c6e3229
add dt odt script
BY571 May 22, 2023
aaa09dd
add dt config
BY571 May 23, 2023
45cbd61
split config
BY571 Jun 1, 2023
54e2b98
merge main and update
BY571 Jun 2, 2023
86ddc44
fix
BY571 Jun 2, 2023
112e800
Merge branch 'main' into decision_transformer
BY571 Jun 2, 2023
170ab13
fix
BY571 Jun 2, 2023
d5177cd
Merge branch 'decision_transformer' of https://github.com/BY571/rl in…
BY571 Jun 2, 2023
1fcbf0e
description catframes
BY571 Jun 2, 2023
165459d
add dt test
BY571 Jun 2, 2023
50f0aa8
add cfg to logger
BY571 Jun 2, 2023
3cc456e
take off detach
BY571 Jun 8, 2023
e890264
Merge branch 'main' into decision_transformer
BY571 Jun 12, 2023
0497449
add loss to docs
BY571 Jun 12, 2023
b24a7f8
update proof_env creation
BY571 Jun 12, 2023
8e04add
move batch to device
BY571 Jun 12, 2023
2ad7af5
remove gpt2model and import directly from hf
BY571 Jun 12, 2023
aeccb22
update docstring
BY571 Jun 12, 2023
2414e9b
update actor docstring
BY571 Jun 12, 2023
b03f3fe
add dispach, in-out-keys
BY571 Jun 12, 2023
e5c4575
update inference actor inputs
BY571 Jun 12, 2023
a5213ce
add inference wrapper to docs
BY571 Jun 12, 2023
6a6b18e
Merge branch 'main' into decision_transformer
BY571 Jun 26, 2023
1f9f885
fix _data
BY571 Jun 27, 2023
792d35c
extract lamb opti
BY571 Jun 27, 2023
0d9fa42
add DT args and example in docstring
BY571 Jun 27, 2023
83642c7
update constant target return and reduction
BY571 Jun 27, 2023
39dda00
fixes for target return transform
BY571 Jun 27, 2023
c5c71e6
update add transformers installed check
BY571 Jun 27, 2023
ca36a0f
update docstring actor DT
BY571 Jun 27, 2023
9c0dfbb
add docstring for modules and examples
BY571 Jun 27, 2023
ddb284e
udpate config
BY571 Jun 27, 2023
cf5de9a
take off unsqueeze in models
BY571 Jun 28, 2023
c3d0ffa
add loss function to config
BY571 Jun 29, 2023
e4ea278
add loss function to config
BY571 Jun 29, 2023
d2c1b08
update loss module
BY571 Jun 29, 2023
a62a647
udpate DT actor docstring
BY571 Jun 29, 2023
2009060
add default transformer config
BY571 Jun 29, 2023
40522c1
merge main
BY571 Jun 29, 2023
623d79a
Merge branch 'main' into decision_transformer
vmoens Jul 3, 2023
77630bd
amend
vmoens Jul 3, 2023
f2defcb
doc
vmoens Jul 3, 2023
f891bd2
tests
vmoens Jul 3, 2023
cf5bc01
lint
vmoens Jul 3, 2023
6d4b591
fix
vmoens Jul 3, 2023
7c0df55
amend
vmoens Jul 3, 2023
d3a3d77
amend
vmoens Jul 4, 2023
8730244
Merge remote-tracking branch 'origin/main' into decision_transformer
vmoens Jul 6, 2023
b1c73da
amend
vmoens Jul 6, 2023
2ec7b0f
fix tests
vmoens Jul 6, 2023
8b8f7b1
fix tests
vmoens Jul 6, 2023
a00aae4
Merge remote-tracking branch 'origin/main' into decision_transformer
vmoens Jul 6, 2023
f49d07d
mesalib glew glfw libosmesa6-dev
vmoens Jul 6, 2023
ff4c34a
libosmesa6-dev
vmoens Jul 6, 2023
40024ed
patchelf
vmoens Jul 6, 2023
dffe5fc
temp hiding
vmoens Jul 6, 2023
2657662
Merge branch 'main' into decision_transformer
vmoens Jul 7, 2023
81d9b34
amend
vmoens Jul 7, 2023
c75eb39
amend
vmoens Jul 7, 2023
540d82b
amend
vmoens Jul 7, 2023
091a119
amend
vmoens Jul 7, 2023
87866a7
amend
vmoens Jul 7, 2023
f0606de
Merge branch 'main' into decision_transformer
vmoens Jul 7, 2023
24c129f
amend
vmoens Jul 7, 2023
ad8d412
empty
vmoens Jul 7, 2023
4a64716
fix wandb
vmoens Jul 7, 2023
f00359d
Merge remote-tracking branch 'origin/main' into decision_transformer_ssh
vmoens Jul 7, 2023
edaa7b5
lint
vmoens Jul 7, 2023
8abb8f3
amend
vmoens Jul 7, 2023
dfcff63
amend
vmoens Jul 7, 2023
395456c
amend
vmoens Jul 7, 2023
d58675e
amend
vmoens Jul 7, 2023
244e429
amend
vmoens Jul 7, 2023
c9338b3
amend
vmoens Jul 7, 2023
cdadf46
Added list of D4RL datasets
MateuszGuzek Jul 10, 2023
4bada6f
Merge remote-tracking branch 'origin/main' into decision_transformer_ssh
vmoens Jul 10, 2023
311d00d
minor
vmoens Jul 10, 2023
587cff6
amend
vmoens Jul 10, 2023
98688b6
Merge branch 'd4rl_direct_download' into decision_transformer_ssh
vmoens Jul 10, 2023
7342c83
amend
vmoens Jul 10, 2023
18c6b00
amend
vmoens Jul 10, 2023
c4c02e6
revert d4rl
vmoens Jul 10, 2023
d67a822
amend
vmoens Jul 11, 2023
29a1067
amend
vmoens Jul 11, 2023
0b8d564
amend
vmoens Jul 11, 2023
3988ebf
Merge remote-tracking branch 'origin/main' into decision_transformer
vmoens Jul 11, 2023
b08d3d4
fix
vmoens Jul 11, 2023
4e57244
Merge remote-tracking branch 'origin/main' into decision_transformer
vmoens Jul 11, 2023
a522db0
fix reward scale, reduce target return config
BY571 Jul 12, 2023
17a86d7
Merge branch 'decision_transformer' of https://github.com/BY571/rl in…
BY571 Jul 12, 2023
aefbf61
amend
vmoens Jul 13, 2023
9afb0a7
Merge branch 'decision_transformer' of https://github.com/BY571/rl in…
vmoens Jul 13, 2023
1c7cbbf
amend
vmoens Jul 13, 2023
11d8779
zero padding, fix obs loc, std for normalization
BY571 Jul 26, 2023
094808a
Merge branch 'decision_transformer' of https://github.com/BY571/rl in…
BY571 Jul 26, 2023
b383339
Merge branch 'main' into decision_transformer
vmoens Jul 28, 2023
3ff2fc6
temp - SerialEnv
vmoens Jul 30, 2023
c43a02d
merge main into branch
BY571 Aug 2, 2023
9135fa7
fix obs norm, fix action context
BY571 Aug 4, 2023
c3a67c8
update buffer transforms to not use catframes
BY571 Aug 7, 2023
ca505eb
test dist, small fixes
BY571 Aug 14, 2023
b260785
update utils
BY571 Aug 22, 2023
a820015
update and fixes
BY571 Aug 23, 2023
a29c3b4
Merge branch 'main' into decision_transformer
BY571 Aug 23, 2023
6220c04
pull changes
BY571 Aug 23, 2023
17093b7
running examples
vmoens Aug 26, 2023
a717c8e
update header, docs and delete dtwrapper
BY571 Aug 28, 2023
3846a21
Merge branch 'decision_transformer' of https://github.com/BY571/rl in…
BY571 Aug 28, 2023
update header, docs and delete dtwrapper
BY571 committed Aug 28, 2023
commit a717c8e1411ca3c5b14ef735b089eed51d26f685
51 changes: 4 additions & 47 deletions examples/decision_transformer/lamb.py
@@ -1,51 +1,8 @@
""" PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb
This optimizer code was adapted from the following (starting with latest)
* https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py
* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
* https://github.com/cybertronai/pytorch-lamb
Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is
similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX.
In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU.
Original copyrights for above sources are below.
Modifications Copyright 2021 Ross Wightman
"""
# Copyright (c) 2021, Habana Labs Ltd. All rights reserved.

# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# MIT License
#
# Copyright (c) 2019 cybertronai
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Lamb optimizer directly copied from https://github.com/facebookresearch/online-dt
import math

import torch
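
For reference, a minimal usage sketch of the trimmed-down optimizer follows. It assumes lamb.py still exposes a Lamb class with a torch.optim.Optimizer-style constructor (as in the timm and online-dt implementations it was copied from); the model below is a placeholder, not part of the example script.

# Hedged usage sketch: `Lamb` and its constructor signature are assumed to match
# the timm/online-dt implementation this file was copied from.
import torch
import torch.nn as nn

from lamb import Lamb  # local module in examples/decision_transformer/

model = nn.Linear(17, 6)  # placeholder network
optimizer = Lamb(model.parameters(), lr=1.0e-4, weight_decay=5.0e-4)

loss = model(torch.randn(8, 17)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
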
4 changes: 2 additions & 2 deletions examples/decision_transformer/odt_config.yaml
@@ -35,12 +35,12 @@ replay_buffer:
buffer_prefetch: 64
capacity: 1_000_000
buffer_scratch_dir: "/tmp/"
device: cpu
device: cuda:0
prefetch: 3

# Optimization
optim:
device: cpu
device: cuda:0
lr: 1.0e-4
weight_decay: 5.0e-4
batch_size: 256
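
The replay-buffer and optimizer devices now default to cuda:0 instead of cpu. Below is a hedged sketch of how such fields are typically read from the Hydra/OmegaConf config; the exact wiring lives in examples/decision_transformer/utils.py and may differ.

# Assumed consumption pattern for the two device fields changed above.
import torch
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/decision_transformer/odt_config.yaml")

optim_device = torch.device(cfg.optim.device)            # "cuda:0" after this change
buffer_device = torch.device(cfg.replay_buffer.device)   # buffer storage/sampling device

# Sampled batches would then be moved to the optimizer device before the update,
# e.g. batch = replay_buffer.sample().to(optim_device)
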
4 changes: 2 additions & 2 deletions examples/decision_transformer/online_dt.py
@@ -13,7 +13,7 @@

from torchrl.envs.libs.gym import set_gym_backend

# from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper

from utils import (
@@ -81,7 +81,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
scheduler.step()

# evaluation
with torch.no_grad():
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
inference_policy.eval()
if i % pretrain_log_interval == 0:
eval_td = test_env.rollout(
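
The evaluation rollout is now wrapped in set_exploration_type(ExplorationType.MODE), so the inference policy returns the mode of its action distribution instead of sampling during evaluation. A sketch of the pattern follows; names such as test_env, inference_policy, and the horizon are placeholders for objects built earlier in online_dt.py.

# Deterministic-evaluation sketch; `test_env` and `inference_policy` refer to
# objects defined elsewhere in online_dt.py.
import torch
from torchrl.envs.utils import ExplorationType, set_exploration_type

with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
    inference_policy.eval()
    eval_td = test_env.rollout(
        max_steps=1000,
        policy=inference_policy,
        auto_cast_to_device=True,
    )
    eval_reward = eval_td["next", "reward"].sum(-2).mean().item()
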
5 changes: 5 additions & 0 deletions examples/decision_transformer/utils.py
@@ -1,3 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn

import torch.optim
137 changes: 12 additions & 125 deletions torchrl/modules/models/decision_transformer.py
@@ -8,127 +8,12 @@

import importlib
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any

import torch
import torch.nn as nn

_has_transformers = importlib.util.find_spec("transformers") is not None
import transformers
from transformers.models.gpt2.modeling_gpt2 import (
BaseModelOutputWithPastAndCrossAttentions,
GPT2Model,
)


class ModifiedGPT2Model(GPT2Model):
"""Wrapper around the GPT2Model from transformers.

This class is a modified version of the GPT2Model from transformers
as for the Decision Transformer we dont need the wpe layer.

"""

def __init__(self, config):
super(ModifiedGPT2Model, self).__init__(config)

# Remove the wpe layer
del self.wpe

def forward(
self,
inputs_embeds: Optional[torch.FloatTensor] = None,
):
input_shape = inputs_embeds.size()[:-1]

output_attentions = self.config.output_attentions
output_hidden_states = self.config.output_hidden_states
use_cache = self.config.use_cache
return_dict = self.config.use_return_dict

head_mask = self.get_head_mask(None, self.config.n_layer)

hidden_states = inputs_embeds

hidden_states = self.drop(hidden_states)

output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = (
() if output_attentions and self.config.add_cross_attention else None
)
all_hidden_states = () if output_hidden_states else None
past_key_values = tuple([None] * len(self.h))
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = tuple(
past_state.to(hidden_states.device) for past_state in layer_past
)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

outputs = block(
hidden_states,
layer_past=layer_past,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
)

hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)

if output_attentions:
all_self_attentions = all_self_attentions + (
outputs[2 if use_cache else 1],
)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (
outputs[3 if use_cache else 2],
)

# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))

hidden_states = self.ln_f(hidden_states)

hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(
v
for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)

return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)


class DecisionTransformer(nn.Module):
@@ -138,14 +23,14 @@ class DecisionTransformer(nn.Module):

The transformer utilizes a default config to create the GPT2 model if the user does not provide a specific config.
default_config = {
"n_embd": 256,
"n_layer": 4,
"n_head": 4,
"n_inner": 1024,
"activation": "relu",
"n_positions": 1024,
"resid_pdrop": 0.1,
"attn_pdrop": 0.1,
... "n_embd": 256,
... "n_layer": 4,
... "n_head": 4,
... "n_inner": 1024,
... "activation": "relu",
... "n_positions": 1024,
... "resid_pdrop": 0.1,
... "attn_pdrop": 0.1,
}

Args:
@@ -210,6 +95,8 @@ def __init__(
raise ImportError(
"transformers is not installed. Please install it with `pip install transformers`."
)
import transformers
from transformers.models.gpt2.modeling_gpt2 import GPT2Model

if config is None:
config = self.default_config()
@@ -240,7 +127,7 @@ def __init__(
self.action_dim = action_dim
self.hidden_size = config["n_embd"]

self.transformer = ModifiedGPT2Model(config=gpt_config)
self.transformer = GPT2Model(config=gpt_config)

self.embed_return = torch.nn.Linear(1, self.hidden_size)
self.embed_state = torch.nn.Linear(self.state_dim, self.hidden_size)
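
With this commit the custom ModifiedGPT2Model wrapper is removed and the stock transformers GPT2Model is used directly, imported lazily inside __init__ so that transformers remains an optional dependency. Below is a hedged sketch of the resulting backbone construction using the default config values from the docstring above; mapping the "activation" key to GPT2Config's activation_function argument is an assumption, not taken from the diff.

# Sketch only: config values mirror DecisionTransformer.default_config; vocab_size
# is set to 1 because the backbone is fed embeddings rather than token ids.
import torch
import transformers
from transformers.models.gpt2.modeling_gpt2 import GPT2Model

gpt_config = transformers.GPT2Config(
    n_embd=256,
    n_layer=4,
    n_head=4,
    n_inner=1024,
    activation_function="relu",
    n_positions=1024,
    resid_pdrop=0.1,
    attn_pdrop=0.1,
    vocab_size=1,
)
backbone = GPT2Model(config=gpt_config)

# The Decision Transformer interleaves (return-to-go, state, action) embeddings
# along the sequence dimension before calling the backbone.
context_length = 20
tokens = torch.randn(2, 3 * context_length, 256)  # (batch, 3 * context, n_embd)
hidden = backbone(inputs_embeds=tokens).last_hidden_state
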
16 changes: 8 additions & 8 deletions torchrl/modules/models/models.py
@@ -1289,14 +1289,14 @@ def __init__(
transformer_config["n_embd"], action_dim, device=device
)

# def weight_init(m):
# """Custom weight init for Conv2D and Linear layers."""
# if isinstance(m, torch.nn.Linear):
# nn.init.orthogonal_(m.weight.data)
# if hasattr(m.bias, "data"):
# m.bias.data.fill_(0.0)

# self.apply(weight_init)
def weight_init(m):
"""Custom weight init for Conv2D and Linear layers."""
if isinstance(m, torch.nn.Linear):
nn.init.orthogonal_(m.weight.data)
if hasattr(m.bias, "data"):
m.bias.data.fill_(0.0)

self.action_layer.apply(weight_init)

def forward(
self,
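
The previously commented-out initializer is re-enabled here, but it is now applied only to the action head rather than to the whole module via self.apply. A standalone sketch of the behaviour, with placeholder sizes:

# Standalone sketch of the re-enabled init: orthogonal weights, zero bias,
# applied to a single Linear action head as in the actor module above.
import torch
import torch.nn as nn

def weight_init(m):
    """Custom weight init for Conv2D and Linear layers."""
    if isinstance(m, torch.nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        if hasattr(m.bias, "data"):
            m.bias.data.fill_(0.0)

action_layer = nn.Linear(256, 6)  # placeholder (n_embd, action_dim)
action_layer.apply(weight_init)
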