From 7d004a0a7dac34d794b4e281c9dc30ae5932760d Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 31 Mar 2023 09:38:09 +0200 Subject: [PATCH 001/104] set struc --- examples/decision_transformer/dt_online.py | 0 examples/decision_transformer/utils.py | 18 ++++++ torchrl/modules/__init__.py | 1 + torchrl/modules/models/__init__.py | 1 + torchrl/modules/models/gpt2_transformer.py | 0 torchrl/modules/models/models.py | 64 ++++++++++++++++++++++ 6 files changed, 84 insertions(+) create mode 100644 examples/decision_transformer/dt_online.py create mode 100644 examples/decision_transformer/utils.py create mode 100644 torchrl/modules/models/gpt2_transformer.py diff --git a/examples/decision_transformer/dt_online.py b/examples/decision_transformer/dt_online.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py new file mode 100644 index 00000000000..507d3f70815 --- /dev/null +++ b/examples/decision_transformer/utils.py @@ -0,0 +1,18 @@ +# from torchrl.modules import DTActor + + +# def make_decision_transformer(cfg): + +# transformer = transformer() +# actor_head = DTActor() + +# # actor = ProbabilisticActor( +# # spec=action_spec, +# # in_keys=["loc", "scale"], +# # module=actor_module, +# # distribution_class=dist_class, +# # distribution_kwargs=dist_kwargs, +# # default_interaction_mode="random", +# # cash +# # return_log_prob=False, +# # ) diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 5a3f4fdbb2b..152316dcb63 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -21,6 +21,7 @@ DdpgMlpQNet, DistributionalDQNnet, DreamerActor, + DTActor, DuelingCnnDQNet, LSTMNet, MLP, diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index 8654d338c18..257eb9628ad 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -13,6 +13,7 @@ DdpgMlpActor, DdpgMlpQNet, DistributionalDQNnet, + DTActor, DuelingCnnDQNet, LSTMNet, MLP, diff --git a/torchrl/modules/models/gpt2_transformer.py b/torchrl/modules/models/gpt2_transformer.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 575c12daa74..777aaf28888 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -1105,3 +1105,67 @@ def forward( input = self.mlp(input) return self._lstm(input, hidden0_in, hidden1_in) + + +class DTActor(nn.Module): + """Decision Transformer Actor class. + + Presented in "Online Decision Transformer", + https://arxiv.org/abs/2202.05607.pdf + + The DDPG Actor takes as input an observation vector and returns an action from it. + It is trained to maximise the value returned by the DDPG Q Value network. + + Args: + action_dim (int): length of the action vector + mlp_net_kwargs (dict, optional): kwargs for MLP. + Default: { + 'in_features': None, + 'out_features': action_dim, + 'depth': 2, + 'num_cells': [400, 300], + 'activation_class': nn.ELU, + 'bias_last_layer': True, + } + device (Optional[DEVICE_TYPING]): device to create the module on. 
+ """ + + def __init__( + self, + action_dim: int, + mlp_net_kwargs: Optional[dict] = None, + device: Optional[DEVICE_TYPING] = None, + ): + super().__init__() + mlp_net_default_kwargs = { + "out_features": action_dim, + "depth": 1, + "num_cells": [512], + "activation_class": nn.ReLU, + "bias_last_layer": True, + } + # log_std_bounds: Tuple[float, float] = [-5.0, 2.0], + log_std_bounds = [-5.0, 2.0] + self.log_std_bounds = log_std_bounds + mlp_net_kwargs = mlp_net_kwargs if mlp_net_kwargs is not None else {} + mlp_net_default_kwargs.update(mlp_net_kwargs) + self.mlp = MLP(device=device, **mlp_net_default_kwargs) + self.apply(dt_actor_weight_init) + + def forward(self, observation: torch.Tensor) -> torch.Tensor: + out = self.mlp(observation) + mu, log_std = out.chunk(2, -1) + log_std = torch.tanh(log_std) + log_std = min(self.log_std_bounds) + 0.5 * ( + max(self.log_std_bounds) - min(self.log_std_bounds) + ) * (log_std + 1.0) + std = torch.exp(log_std) + return (mu, std) + + +def dt_actor_weight_init(m): + """Weight init used in the Decision Transformer for the actor layers.""" + if isinstance(m, torch.nn.Linear): + nn.init.orthogonal_(m.weight.data) + if hasattr(m.bias, "data"): + m.bias.data.fill_(0.0) From 520b8fbf2b692d133748d52eb84a4c893ba3b74d Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 4 Apr 2023 08:29:47 +0200 Subject: [PATCH 002/104] architecture test --- examples/decision_transformer/dt_online.py | 84 +++ .../modules/models/decision_transformer.py | 99 ++++ torchrl/modules/models/gpt2_transformer.py | 504 ++++++++++++++++++ torchrl/modules/models/models.py | 9 +- 4 files changed, 691 insertions(+), 5 deletions(-) create mode 100644 torchrl/modules/models/decision_transformer.py diff --git a/examples/decision_transformer/dt_online.py b/examples/decision_transformer/dt_online.py index e69de29bb2d..2f728c35d3a 100644 --- a/examples/decision_transformer/dt_online.py +++ b/examples/decision_transformer/dt_online.py @@ -0,0 +1,84 @@ +import torch +from tensordict import TensorDict +from tensordict.nn import TensorDictModule +from torchrl.envs import EnvCreator, ParallelEnv +from torchrl.envs.libs.gym import GymEnv +from torchrl.modules import ProbabilisticActor +from torchrl.modules.distributions import TanhNormal +from torchrl.modules.models import DTActor +from torchrl.modules.models.decision_transformer import DecisionTransformer + + +def env_maker(env_name, frame_skip=1, device="cpu", from_pixels=False): + return GymEnv( + env_name, "run", device=device, frame_skip=frame_skip, from_pixels=from_pixels + ) + + +def env_factory(num_workers): + """Creates an instance of the environment.""" + + # 1.2 Create env vector + vec_env = ParallelEnv( + create_env_fn=EnvCreator(lambda: env_maker(env_name="Pendulum-v1")), + num_workers=num_workers, + ) + + return vec_env + + +# Sanity check +test_env = env_factory(num_workers=1) +action_spec = test_env.action_spec + +in_keys = ["observation", "action", "returns_to_go", "timesteps", "padding_mask"] +transformer = DecisionTransformer( + state_dim=5, action_dim=2, hidden_size=512, max_ep_len=1000, ordering=False +) + +actor_head = DTActor(action_dim=2) + +actor_net = torch.nn.ModuleList([transformer, actor_head]) + +dist_class = TanhNormal +dist_kwargs = { + "min": -1.0, + "max": 1.0, + "tanh_loc": False, +} + +actor_module = TensorDictModule( + actor_net, + in_keys=in_keys, + out_keys=["loc", "scale", "hidden_state"], +) +actor = ProbabilisticActor( + spec=action_spec, + in_keys=["loc", "scale", "hidden_state"], + out_keys=["action", "log_prob", 
"hidden_state"], + module=actor_module, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs, + default_interaction_mode="random", + cache_dist=True, + return_log_prob=False, +) + +print(transformer) + +observation = torch.rand(1, 20, 5) +action = torch.rand(1, 20, 2) +reward_to_go = torch.rand(1, 20, 1) +padding_mask = torch.ones(1, 20, 1) +timesteps = torch.arange(1, 21).unsqueeze(0).unsqueeze(-1) + +td = TensorDict( + { + "observation": observation, + "action": action, + "returns_to_go": reward_to_go, + "padding_mask": padding_mask, + "timesteps": timesteps, + }, + batch_size=1, +) diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py new file mode 100644 index 00000000000..6219dca8406 --- /dev/null +++ b/torchrl/modules/models/decision_transformer.py @@ -0,0 +1,99 @@ +from typing import Optional + +import torch +import torch.nn as nn +import transformers +from torchrl.modules.models.gpt2_transformer import GPT2Model + + +class DecisionTransformer(nn.Module): + """Decion Transformer as described in https://arxiv.org/abs/2202.05607 .""" + + def __init__( + self, state_dim, action_dim, hidden_size=512, max_ep_len=1000, ordering=False + ): + super(DecisionTransformer, self).__init__() + assert hidden_size == 512, "Only hidden_size=512 is supported" + gpt_config = transformers.GPT2Config( + n_embd=512, + n_layer=4, + n_head=4, + n_inner=4 * 512, + activation_function="relu", + n_positions=1024, + resid_pdrop=0.1, + attn_pdrop=0.1, + vocab_size=1, + ) + self.state_dim = state_dim + self.action_dim = action_dim + self.hidden_size = hidden_size + self.ordering = ordering + + self.transformer = GPT2Model(config=gpt_config) + if ordering: + self.embed_ordering = nn.Embedding(max_ep_len, hidden_size) + self.embed_return = torch.nn.Linear(1, hidden_size) + self.embed_state = torch.nn.Linear(self.state_dim, hidden_size) + self.embed_action = torch.nn.Linear(self.action_dim, hidden_size) + + self.embed_ln = nn.LayerNorm(hidden_size) + + def forward( + self, + observations: torch.Tensor, + actions: torch.Tensor, + returns_to_go: torch.Tensor, + timesteps: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + ): + batch_size, seq_length = observations.shape[0], observations.shape[1] + + if padding_mask is None: + # attention mask for GPT: 1 if can be attended to, 0 if not + padding_mask = torch.ones((batch_size, seq_length), dtype=torch.long) + + # embed each modality with a different head + state_embeddings = self.embed_state(observations) + action_embeddings = self.embed_action(actions) + returns_embeddings = self.embed_return(returns_to_go) + + if self.ordering: + order_embeddings = self.embed_ordering(timesteps) + else: + order_embeddings = 0.0 + + state_embeddings = state_embeddings + order_embeddings + action_embeddings = action_embeddings + order_embeddings + returns_embeddings = returns_embeddings + order_embeddings + + # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...) 
+ # which works nice in an autoregressive sense since states predict actions + stacked_inputs = ( + torch.stack( + (returns_embeddings, state_embeddings, action_embeddings), dim=1 + ) + .permute(0, 2, 1, 3) + .reshape(batch_size, 3 * seq_length, self.hidden_size) + ) + stacked_inputs = self.embed_ln(stacked_inputs) + + # to make the attention mask fit the stacked inputs, have to stack it as well + stacked_padding_mask = ( + torch.stack((padding_mask, padding_mask, padding_mask), dim=1) + .permute(0, 2, 1) + .reshape(batch_size, 3 * seq_length) + ) + + # we feed in the input embeddings (not word indices as in NLP) to the model + transformer_outputs = self.transformer( + inputs_embeds=stacked_inputs, + attention_mask=stacked_padding_mask, + ) + x = transformer_outputs["last_hidden_state"] + + # reshape x so that the second dimension corresponds to the original + # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t + x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3) + + return x[:, 1] diff --git a/torchrl/modules/models/gpt2_transformer.py b/torchrl/modules/models/gpt2_transformer.py index e69de29bb2d..b7b6c60aeb5 100644 --- a/torchrl/modules/models/gpt2_transformer.py +++ b/torchrl/modules/models/gpt2_transformer.py @@ -0,0 +1,504 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OpenAI GPT-2 model.""" + +import warnings + +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions + +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "gpt2" +_CONFIG_FOR_DOC = "GPT2Config" + +from transformers import GPT2PreTrainedModel +from transformers.models.gpt2.modeling_gpt2 import GPT2Block + + +GPT2_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GPT2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + +GPT2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for + `past_key_values`. In other words, the `attention_mask` always has to have the length: + `len(past_key_values) + len(input_ids)` + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + + If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see + `past_key_values`). + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. 
+ return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the + following number of attention modules: + + - gpt2: 12 + - gpt2-medium: 24 + - gpt2-large: 36 + - gpt2-xl: 48 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: + model = GPT2LMHeadModel.from_pretrained("gpt2-xl") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], + 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with gpt2-large: + model = GPT2LMHeadModel.from_pretrained("gpt2-large") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7], + 1: [8, 9, 10, 11, 12, 13, 14, 15], + 2: [16, 17, 18, 19, 20, 21, 22, 23], + 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +@add_start_docstrings( + "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", + GPT2_START_DOCSTRING, +) +class GPT2Model(GPT2PreTrainedModel): + """GPT2 Model transformer.""" + + _keys_to_ignore_on_load_missing = ["attn.masked_bias"] + + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + # self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList( + [GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)] + ) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + warnings.warn( + "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your" + " model with `device_map='balanced'` in the call to `from_pretrained`. 
You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1," + " ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = ( + "cpu" + if "cpu" in self.device_map.keys() + else "cuda:" + str(min(self.device_map.keys())) + ) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + self.wpe = self.wpe.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + self.wpe = self.wpe.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def _prune_heads(self, heads_to_prune): + """Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}.""" + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + 
batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange( + past_length, + input_shape[-1] + past_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + # position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds # + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple( + past_state.to(hidden_states.device) for past_state in layer_past + ) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + ( + outputs[2 if use_cache else 1], + ) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + ( + outputs[3 if use_cache else 2], + ) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 777aaf28888..26cd8c44b8e 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -1102,7 +1102,6 @@ def forward( hidden0_in: Optional[torch.Tensor] = None, hidden1_in: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - input = self.mlp(input) return self._lstm(input, hidden0_in, hidden1_in) @@ -1150,17 +1149,17 @@ def __init__( mlp_net_kwargs = mlp_net_kwargs if mlp_net_kwargs is not None else {} 
mlp_net_default_kwargs.update(mlp_net_kwargs) self.mlp = MLP(device=device, **mlp_net_default_kwargs) - self.apply(dt_actor_weight_init) + # self.apply(dt_actor_weight_init) - def forward(self, observation: torch.Tensor) -> torch.Tensor: - out = self.mlp(observation) + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + out = self.mlp(hidden_state) mu, log_std = out.chunk(2, -1) log_std = torch.tanh(log_std) log_std = min(self.log_std_bounds) + 0.5 * ( max(self.log_std_bounds) - min(self.log_std_bounds) ) * (log_std + 1.0) std = torch.exp(log_std) - return (mu, std) + return (mu, std, hidden_state) def dt_actor_weight_init(m): From 18d303505a91df7514fb36dd1cc6ede7955337a6 Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 21 Apr 2023 16:06:13 +0200 Subject: [PATCH 003/104] update dt transforms --- examples/decision_transformer/config.yaml | 60 +++ examples/decision_transformer/dt_online.py | 118 ++--- examples/decision_transformer/utils.py | 438 +++++++++++++++++- .../modules/models/decision_transformer.py | 14 +- torchrl/modules/models/models.py | 23 +- 5 files changed, 562 insertions(+), 91 deletions(-) create mode 100644 examples/decision_transformer/config.yaml diff --git a/examples/decision_transformer/config.yaml b/examples/decision_transformer/config.yaml new file mode 100644 index 00000000000..56781ea0175 --- /dev/null +++ b/examples/decision_transformer/config.yaml @@ -0,0 +1,60 @@ +# Task and env +env: + env_name: Pendulum-v1 + env_task: "" + env_library: gym + record_video: 0 + stacked_frames: 5 + n_samples_stats: 1000 + frame_skip: 1 + from_pixels: False + num_envs: 1 + reward_scaling: + noop: 1 + seed: 0 + +# Collector +collector: + async_collection: 1 + frames_per_batch: 1000 + total_frames: 1000000 + multi_step: 0 + init_random_frames: 25000 + collector_devices: cpu # ,cpu,cpu,cpu] + num_collectors: 1 + max_frames_per_traj: 1000 + +# logger +logger: + backend: tensorboard + exp_name: td3_cheetah_gym + log_interval: 10000 # record interval in frames + eval_steps: 1000 + +# Buffer +replay_buffer: + prb: 0 + buffer_prefetch: 64 + capacity: 1_000_000 + +# Optimization +optim: + device: cpu + lr: 3e-4 + weight_decay: 0.0 + batch_size: 256 + lr_scheduler: "" + optim_steps_per_batch: 1000 + policy_update_delay: 2 + +# Policy and model +model: + ou_exploration: 0 + noisy: False + activation: relu + +# loss +loss: + loss_function: smooth_l1 + gamma: 0.99 + tau: 0.05 diff --git a/examples/decision_transformer/dt_online.py b/examples/decision_transformer/dt_online.py index 2f728c35d3a..2a6204630db 100644 --- a/examples/decision_transformer/dt_online.py +++ b/examples/decision_transformer/dt_online.py @@ -1,84 +1,68 @@ +import hydra import torch -from tensordict import TensorDict + from tensordict.nn import TensorDictModule -from torchrl.envs import EnvCreator, ParallelEnv -from torchrl.envs.libs.gym import GymEnv + from torchrl.modules import ProbabilisticActor from torchrl.modules.distributions import TanhNormal from torchrl.modules.models import DTActor -from torchrl.modules.models.decision_transformer import DecisionTransformer - - -def env_maker(env_name, frame_skip=1, device="cpu", from_pixels=False): - return GymEnv( - env_name, "run", device=device, frame_skip=frame_skip, from_pixels=from_pixels - ) +from utils import make_test_env -def env_factory(num_workers): - """Creates an instance of the environment.""" - - # 1.2 Create env vector - vec_env = ParallelEnv( - create_env_fn=EnvCreator(lambda: env_maker(env_name="Pendulum-v1")), - num_workers=num_workers, - ) - - 
return vec_env +@hydra.main(config_path=".", config_name="config") +def main(cfg: "DictConfig"): # noqa: F821 + # Sanity check + test_env = make_test_env(cfg.env) + # test_env = make_transformed_env(test_env) + action_spec = test_env.action_spec -# Sanity check -test_env = env_factory(num_workers=1) -action_spec = test_env.action_spec + in_keys = ["observation", "action", "return_to_go", "timesteps"] + # transformer = DecisionTransformer( + # state_dim=5, action_dim=2, hidden_size=512, max_ep_len=1000, ordering=False + # ) -in_keys = ["observation", "action", "returns_to_go", "timesteps", "padding_mask"] -transformer = DecisionTransformer( - state_dim=5, action_dim=2, hidden_size=512, max_ep_len=1000, ordering=False -) + actor_net = DTActor(action_dim=1) -actor_head = DTActor(action_dim=2) + # actor_net = torch.nn.ModuleList([transformer, actor_head]) -actor_net = torch.nn.ModuleList([transformer, actor_head]) + dist_class = TanhNormal + dist_kwargs = { + "min": -1.0, + "max": 1.0, + "tanh_loc": False, + } -dist_class = TanhNormal -dist_kwargs = { - "min": -1.0, - "max": 1.0, - "tanh_loc": False, -} + actor_module = TensorDictModule( + actor_net, in_keys=in_keys, out_keys=["loc", "scale"] # , "hidden_state"], + ) + actor = ProbabilisticActor( + spec=action_spec, + in_keys=["loc", "scale"], # , "hidden_state"], + out_keys=["action", "log_prob"], # , "hidden_state"], + module=actor_module, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs, + default_interaction_mode="random", + cache_dist=True, + return_log_prob=False, + ) -actor_module = TensorDictModule( - actor_net, - in_keys=in_keys, - out_keys=["loc", "scale", "hidden_state"], -) -actor = ProbabilisticActor( - spec=action_spec, - in_keys=["loc", "scale", "hidden_state"], - out_keys=["action", "log_prob", "hidden_state"], - module=actor_module, - distribution_class=dist_class, - distribution_kwargs=dist_kwargs, - default_interaction_mode="random", - cache_dist=True, - return_log_prob=False, -) + print(actor) -print(transformer) + with torch.no_grad(): + test_env.eval() + actor.eval() + # Generate a complete episode + td_test = test_env.rollout( + policy=actor, + max_steps=100, + auto_reset=True, + auto_cast_to_device=True, + break_when_any_done=True, + ).clone() + print(td_test) -observation = torch.rand(1, 20, 5) -action = torch.rand(1, 20, 2) -reward_to_go = torch.rand(1, 20, 1) -padding_mask = torch.ones(1, 20, 1) -timesteps = torch.arange(1, 21).unsqueeze(0).unsqueeze(-1) -td = TensorDict( - { - "observation": observation, - "action": action, - "returns_to_go": reward_to_go, - "padding_mask": padding_mask, - "timesteps": timesteps, - }, - batch_size=1, -) +if __name__ == "__main__": + main() diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index 507d3f70815..ffe2dc84ec3 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -1,18 +1,428 @@ -# from torchrl.modules import DTActor +import torch.nn +import torch.optim +from tensordict.nn import TensorDictModule +from torchrl.collectors import MultiaSyncDataCollector, MultiSyncDataCollector +from torchrl.data import ( + CompositeSpec, + LazyMemmapStorage, + MultiStep, + TensorDictReplayBuffer, +) +from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler +from torchrl.envs import ( + CatFrames, + DoubleToFloat, + EnvCreator, + GrayScale, + NoopResetEnv, + ObservationNorm, + ParallelEnv, + RenameTransform, + Resize, + RewardScaling, + StepCounter, + 
TargetReturn, + TensorDictPrimer, + ToTensorImage, + TransformedEnv, + UnsqueezeTransform, +) +from torchrl.envs.libs.dm_control import DMControlEnv +from torchrl.envs.utils import set_exploration_mode +from torchrl.modules import ( + AdditiveGaussianWrapper, + ConvNet, + MLP, + OrnsteinUhlenbeckProcessWrapper, + ProbabilisticActor, + TanhDelta, + ValueOperator, +) +from torchrl.objectives import SoftUpdate, TD3Loss +from torchrl.record.loggers import generate_exp_name, get_logger +from torchrl.trainers.helpers.envs import LIBS +from torchrl.trainers.helpers.models import ACTIVATIONS -# def make_decision_transformer(cfg): -# transformer = transformer() -# actor_head = DTActor() +DEFAULT_REWARD_SCALING = { + "Hopper-v1": 5, + "Walker2d-v1": 5, + "HalfCheetah-v1": 5, + "cheetah": 5, + "Ant-v2": 5, + "Humanoid-v2": 20, + "humanoid": 100, +} -# # actor = ProbabilisticActor( -# # spec=action_spec, -# # in_keys=["loc", "scale"], -# # module=actor_module, -# # distribution_class=dist_class, -# # distribution_kwargs=dist_kwargs, -# # default_interaction_mode="random", -# # cash -# # return_log_prob=False, -# # ) +# ==================================================================== +# Environment utils +# ----------------- + + +def make_base_env(env_cfg, from_pixels=None): + env_library = LIBS[env_cfg.env_library] + env_name = env_cfg.env_name + frame_skip = env_cfg.frame_skip + if from_pixels is None: + from_pixels = env_cfg.from_pixels + + env_kwargs = { + "env_name": env_name, + "frame_skip": frame_skip, + "from_pixels": from_pixels, # for rendering + "pixels_only": False, + } + if env_library is DMControlEnv: + env_task = env_cfg.env_task + env_kwargs.update({"task_name": env_task}) + env = env_library(**env_kwargs) + if env_cfg.noop > 1: + env = TransformedEnv(env, NoopResetEnv(env_cfg.noop)) + return env + + +def make_transformed_env(base_env, env_cfg): + from_pixels = env_cfg.from_pixels + if from_pixels: + return make_transformed_env_pixels(base_env, env_cfg) + else: + return make_transformed_env_states(base_env, env_cfg) + + +def make_transformed_env_pixels(base_env, env_cfg): + if not isinstance(env_cfg.reward_scaling, float): + env_cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(env_cfg.env_name, 5.0) + + env_library = LIBS[env_cfg.env_library] + env = TransformedEnv(base_env) + + reward_scaling = env_cfg.reward_scaling + + env.append_transform(RewardScaling(0.0, reward_scaling)) + + double_to_float_list = [] + double_to_float_inv_list = [] + + # + env.append_transform(ToTensorImage()) + env.append_transform(GrayScale()) + env.append_transform(Resize(84, 84)) + env.append_transform(CatFrames(N=4, dim=-3)) + + obs_norm = ObservationNorm(in_keys=["pixels"]) + env.append_transform(obs_norm) + + if env_library is DMControlEnv: + double_to_float_list += [ + "reward", + ] + double_to_float_list += [ + "action", + ] + double_to_float_inv_list += ["action"] # DMControl requires double-precision + double_to_float_list += ["observation_vector"] + else: + double_to_float_list += ["observation_vector"] + env.append_transform( + DoubleToFloat( + in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list + ) + ) + return env + + +def make_transformed_env_states(base_env, env_cfg): + transformed_env = TransformedEnv(base_env) + + transformed_env.append_transform(StepCounter()) + transformed_env.append_transform(RenameTransform(["step_count"], ["timesteps"])) + transformed_env.append_transform( + TargetReturn(200 * 0.01, out_keys=["return_to_go"]) + ) + # 
transformed_env.append_transform(SCALE) + transformed_env.append_transform(TensorDictPrimer(action=base_env.action_spec)) + # transformed_env.append_transform(TensorDictPrimer(padding_mask=env.action_spec)) + + transformed_env.append_transform(UnsqueezeTransform(-2, in_keys=["observation"])) + transformed_env.append_transform( + CatFrames(in_keys=["observation"], N=env_cfg.stacked_frames, dim=-2) + ) + + transformed_env.append_transform(UnsqueezeTransform(-2, in_keys=["action"])) + transformed_env.append_transform( + CatFrames(in_keys=["action"], N=env_cfg.stacked_frames, dim=-2) + ) + + transformed_env.append_transform(UnsqueezeTransform(-2, in_keys=["return_to_go"])) + transformed_env.append_transform( + CatFrames(in_keys=["return_to_go"], N=env_cfg.stacked_frames, dim=-2) + ) + + transformed_env.append_transform(UnsqueezeTransform(-2, in_keys=["done"])) + transformed_env.append_transform( + CatFrames(in_keys=["done"], N=env_cfg.stacked_frames, dim=-2) + ) + + transformed_env.append_transform(UnsqueezeTransform(-2, in_keys=["timesteps"])) + transformed_env.append_transform( + CatFrames(in_keys=["timesteps"], N=env_cfg.stacked_frames, dim=-2) + ) + return transformed_env + + +def make_parallel_env(env_cfg, state_dict): + num_envs = env_cfg.num_envs + env = make_transformed_env( + ParallelEnv(num_envs, EnvCreator(lambda: make_base_env(env_cfg))), env_cfg + ) + for t in env.transform: + if isinstance(t, ObservationNorm): + t.init_stats(3, cat_dim=1, reduce_dim=[0, 1]) + env.load_state_dict(state_dict) + return env + + +def make_test_env(env_cfg): + env_cfg.num_envs = 1 + state_dict = get_stats(env_cfg) + env = make_parallel_env(env_cfg, state_dict=state_dict) + return env + + +def get_stats(env_cfg): + from_pixels = env_cfg.from_pixels + env = make_transformed_env(make_base_env(env_cfg), env_cfg) + init_stats(env, env_cfg.n_samples_stats, from_pixels) + return env.state_dict() + + +def init_stats(env, n_samples_stats, from_pixels): + for t in env.transform: + if isinstance(t, ObservationNorm): + if from_pixels: + t.init_stats( + n_samples_stats, + cat_dim=-3, + reduce_dim=(-1, -2, -3), + keep_dims=(-1, -2, -3), + ) + else: + t.init_stats(n_samples_stats) + + +# ==================================================================== +# Collector and replay buffer +# --------------------------- + + +def make_collector(cfg, state_dict, policy): + env_cfg = cfg.env + loss_cfg = cfg.loss + collector_cfg = cfg.collector + if collector_cfg.async_collection: + collector_class = MultiaSyncDataCollector + else: + collector_class = MultiSyncDataCollector + if collector_cfg.multi_step: + ms = MultiStep(gamma=loss_cfg.gamma, n_steps=collector_cfg.multi_step) + else: + ms = None + collector = collector_class( + [make_parallel_env(env_cfg, state_dict=state_dict)] + * collector_cfg.num_collectors, + policy, + frames_per_batch=collector_cfg.frames_per_batch, + total_frames=collector_cfg.total_frames, + postproc=ms, + device=collector_cfg.collector_devices, + init_random_frames=collector_cfg.init_random_frames, + max_frames_per_traj=collector_cfg.max_frames_per_traj, + ) + return collector + + +def make_replay_buffer(rb_cfg): + if rb_cfg.prb: + sampler = PrioritizedSampler(max_capacity=rb_cfg.capacity, alpha=0.7, beta=0.5) + else: + sampler = RandomSampler() + return TensorDictReplayBuffer( + storage=LazyMemmapStorage(rb_cfg.capacity), sampler=sampler + ) + + +# ==================================================================== +# Model +# ----- +# +# We give one version of the model for learning from pixels, 
and one for state. +# TorchRL comes in handy at this point, as the high-level interactions with +# these models is unchanged, regardless of the modality. +# + + +def make_td3_model(cfg): + env_cfg = cfg.env + model_cfg = cfg.model + proof_environment = make_transformed_env(make_base_env(env_cfg), env_cfg) + # we must initialize the observation norm transform + init_stats(proof_environment, n_samples_stats=3, from_pixels=env_cfg.from_pixels) + + env_specs = proof_environment.specs + from_pixels = env_cfg.from_pixels + + if not from_pixels: + actor_net, q_net = make_td3_modules_state(model_cfg, proof_environment) + in_keys = ["observation_vector"] + out_keys = ["param"] + else: + actor_net, q_net = make_td3_modules_pixels(model_cfg, proof_environment) + in_keys = ["pixels"] + out_keys = ["param", "hidden"] + + actor_module = TensorDictModule(actor_net, in_keys=in_keys, out_keys=out_keys) + + # We use a ProbabilisticActor to make sure that we map the + # network output to the right space using a TanhDelta + # distribution. + actor = ProbabilisticActor( + module=actor_module, + in_keys=["param"], + spec=CompositeSpec(action=env_specs["input_spec"]["action"]), + safe=False, + distribution_class=TanhDelta, + distribution_kwargs={ + "min": env_specs["input_spec"]["action"].space.minimum, + "max": env_specs["input_spec"]["action"].space.maximum, + }, + ) + + if not from_pixels: + in_keys = ["observation_vector", "action"] + else: + in_keys = ["pixels", "action"] + + out_keys = ["state_action_value"] + qvalue = ValueOperator( + in_keys=in_keys, + out_keys=out_keys, + module=q_net, + ) + + # init the lazy layers + with torch.no_grad(), set_exploration_mode("random"): + # for t in proof_environment.transform: + # if isinstance(t, ObservationNorm): + # t.init_stats(2) + td = proof_environment.rollout(max_steps=1000) + print(td) + actor(td) + qvalue(td) + + return actor, qvalue + + +def make_td3_modules_state(model_cfg, proof_environment): + env_specs = proof_environment.specs + out_features = env_specs["input_spec"]["action"].shape[0] + + actor_net_kwargs = { + "num_cells": [256, 256], + "out_features": out_features, + "activation_class": ACTIVATIONS[model_cfg.activation], + } + actor_net = MLP(**actor_net_kwargs) + + qvalue_net_kwargs = { + "num_cells": [256, 256], + "out_features": 1, + "activation_class": ACTIVATIONS[model_cfg.activation], + } + + q_net = MLP(**qvalue_net_kwargs) + return actor_net, q_net + + +def make_td3_modules_pixels(model_cfg, proof_environment): + env_specs = proof_environment.specs + out_features = env_specs["input_spec"]["action"].shape[0] + + actor_net = torch.nn.ModuleList() + + actor_convnet_kwargs = {"activation_class": ACTIVATIONS[model_cfg.activation]} + actor_net.append(ConvNet(**actor_convnet_kwargs)) + + actor_net_kwargs = { + "num_cells": [256, 256], + "out_features": out_features, + "activation_class": ACTIVATIONS[model_cfg.activation], + } + actor_net.append(MLP(**actor_net_kwargs)) + + q_net = torch.nn.ModuleList() + + q_net_convnet_kwargs = {"activation_class": ACTIVATIONS[model_cfg.activation]} + q_net.append(ConvNet(**q_net_convnet_kwargs)) + + qvalue_net_kwargs = { + "num_cells": [256, 256], + "out_features": 1, + "activation_class": ACTIVATIONS[model_cfg.activation], + } + + q_net.append(MLP(**qvalue_net_kwargs)) + + return actor_net, q_net + + +def make_policy(model_cfg, actor): + if model_cfg.ou_exploration: + return OrnsteinUhlenbeckProcessWrapper(actor) + else: + return AdditiveGaussianWrapper(actor) + + +# 
==================================================================== +# TD3 Loss +# --------- + + +def make_loss(loss_cfg, actor_network, qvalue_network): + loss = TD3Loss( + actor_network, + qvalue_network, + gamma=loss_cfg.gamma, + loss_function=loss_cfg.loss_function, + policy_noise=0.2, + noise_clip=0.5, + ) + target_net_updater = SoftUpdate(loss, 1 - loss_cfg.tau) + target_net_updater.init_() + return loss, target_net_updater + + +def make_td3_optimizer(optim_cfg, actor_network, qvalue_network): + actor_optim = torch.optim.Adam( + actor_network.parameters(), + lr=optim_cfg.lr, + weight_decay=optim_cfg.weight_decay, + ) + critic_optim = torch.optim.Adam( + qvalue_network.parameters(), + lr=optim_cfg.lr, + weight_decay=optim_cfg.weight_decay, + ) + return actor_optim, critic_optim + + +# ==================================================================== +# Logging and recording +# --------------------- + + +def make_logger(logger_cfg): + exp_name = generate_exp_name("TD3", logger_cfg.exp_name) + logger_cfg.exp_name = exp_name + logger = get_logger(logger_cfg.backend, logger_name="td3", experiment_name=exp_name) + return logger diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py index 6219dca8406..50ac04177ef 100644 --- a/torchrl/modules/models/decision_transformer.py +++ b/torchrl/modules/models/decision_transformer.py @@ -41,22 +41,22 @@ def __init__( def forward( self, - observations: torch.Tensor, - actions: torch.Tensor, - returns_to_go: torch.Tensor, + observation: torch.Tensor, + action: torch.Tensor, + return_to_go: torch.Tensor, timesteps: torch.Tensor, padding_mask: Optional[torch.Tensor] = None, ): - batch_size, seq_length = observations.shape[0], observations.shape[1] + batch_size, seq_length = observation.shape[0], observation.shape[1] if padding_mask is None: # attention mask for GPT: 1 if can be attended to, 0 if not padding_mask = torch.ones((batch_size, seq_length), dtype=torch.long) # embed each modality with a different head - state_embeddings = self.embed_state(observations) - action_embeddings = self.embed_action(actions) - returns_embeddings = self.embed_return(returns_to_go) + state_embeddings = self.embed_state(observation) + action_embeddings = self.embed_action(action) + returns_embeddings = self.embed_return(return_to_go) if self.ordering: order_embeddings = self.embed_ordering(timesteps) diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 12b72d5ad80..57f6b746ed0 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -1135,6 +1135,9 @@ def forward( return self._lstm(input, hidden0_in, hidden1_in) +from torchrl.modules.models.decision_transformer import DecisionTransformer + + class DTActor(nn.Module): """Decision Transformer Actor class. 
@@ -1166,12 +1169,19 @@ def __init__( ): super().__init__() mlp_net_default_kwargs = { - "out_features": action_dim, + "out_features": action_dim * 2, "depth": 1, "num_cells": [512], "activation_class": nn.ReLU, "bias_last_layer": True, } + self.transformer = DecisionTransformer( + state_dim=3, + action_dim=action_dim, + hidden_size=512, + max_ep_len=1000, + ordering=False, + ) # log_std_bounds: Tuple[float, float] = [-5.0, 2.0], log_std_bounds = [-5.0, 2.0] self.log_std_bounds = log_std_bounds @@ -1180,7 +1190,14 @@ def __init__( self.mlp = MLP(device=device, **mlp_net_default_kwargs) # self.apply(dt_actor_weight_init) - def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + def forward( + self, + observation: torch.Tensor, + action: torch.Tensor, + return_to_go: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + hidden_state = self.transformer(observation, action, return_to_go, timesteps) out = self.mlp(hidden_state) mu, log_std = out.chunk(2, -1) log_std = torch.tanh(log_std) @@ -1188,7 +1205,7 @@ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: max(self.log_std_bounds) - min(self.log_std_bounds) ) * (log_std + 1.0) std = torch.exp(log_std) - return (mu, std, hidden_state) + return (mu, std) def dt_actor_weight_init(m): From d521fa26e176096becd81973e6b809396f9edd74 Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 21 Apr 2023 17:58:46 +0200 Subject: [PATCH 004/104] update padding --- examples/decision_transformer/utils.py | 4 +- .../modules/models/decision_transformer.py | 42 ++++++++++++++++++- torchrl/modules/models/models.py | 2 +- 3 files changed, 45 insertions(+), 3 deletions(-) diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index ffe2dc84ec3..6b913cbbd27 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -136,7 +136,9 @@ def make_transformed_env_states(base_env, env_cfg): transformed_env = TransformedEnv(base_env) transformed_env.append_transform(StepCounter()) - transformed_env.append_transform(RenameTransform(["step_count"], ["timesteps"])) + transformed_env.append_transform( + RenameTransform(["step_count"], ["timesteps"], create_copy=True) + ) transformed_env.append_transform( TargetReturn(200 * 0.01, out_keys=["return_to_go"]) ) diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py index 50ac04177ef..9943bb5506b 100644 --- a/torchrl/modules/models/decision_transformer.py +++ b/torchrl/modules/models/decision_transformer.py @@ -29,6 +29,8 @@ def __init__( self.action_dim = action_dim self.hidden_size = hidden_size self.ordering = ordering + self.train_context = 20 + self.inference_context = 5 self.transformer = GPT2Model(config=gpt_config) if ordering: @@ -49,6 +51,11 @@ def forward( ): batch_size, seq_length = observation.shape[0], observation.shape[1] + if seq_length == self.inference_context: + observation, action, return_to_go, timesteps, seq_length = self.pad_context( + observation, action, return_to_go, timesteps + ) + if padding_mask is None: # attention mask for GPT: 1 if can be attended to, 0 if not padding_mask = torch.ones((batch_size, seq_length), dtype=torch.long) @@ -96,4 +103,37 @@ def forward( # returns (0), states (1), or actions (2); i.e. 
x[:,1,t] is the token for s_t x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3) - return x[:, 1] + return x[:, 1] # only state tokens + + def pad_context( + self, + observation: torch.Tensor, + action: torch.Tensor, + return_to_go: torch.Tensor, + timesteps: torch.Tensor, + ): + observation = torch.nn.functional.pad( + observation, + (0, 0, self.train_context - self.inference_context, 0), + mode="constant", + value=0, + ) + action = torch.nn.functional.pad( + action, + (0, 0, self.train_context - self.inference_context - 1, 1), + mode="constant", + value=0, + ) # pad first action with 0 + return_to_go = torch.nn.functional.pad( + return_to_go, + (0, 0, self.train_context - self.inference_context - 1, 1), + mode="constant", + value=0, + ) + timesteps = torch.nn.functional.pad( + timesteps, + (0, 0, self.train_context - self.inference_context, 0), + mode="constant", + value=0, + ) + return observation, action, return_to_go, timesteps, self.train_context diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 57f6b746ed0..be6c038bc28 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -1198,7 +1198,7 @@ def forward( timesteps: torch.Tensor, ) -> torch.Tensor: hidden_state = self.transformer(observation, action, return_to_go, timesteps) - out = self.mlp(hidden_state) + out = self.mlp(hidden_state)[:, -1] mu, log_std = out.chunk(2, -1) log_std = torch.tanh(log_std) log_std = min(self.log_std_bounds) + 0.5 * ( From c123fe0ae6877cea900ed656cc8b394f8be7b272 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 24 Apr 2023 09:19:17 +0200 Subject: [PATCH 005/104] take off outputhead --- examples/decision_transformer/dt_online.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/decision_transformer/dt_online.py b/examples/decision_transformer/dt_online.py index 2a6204630db..ebedd22fd84 100644 --- a/examples/decision_transformer/dt_online.py +++ b/examples/decision_transformer/dt_online.py @@ -24,8 +24,6 @@ def main(cfg: "DictConfig"): # noqa: F821 actor_net = DTActor(action_dim=1) - # actor_net = torch.nn.ModuleList([transformer, actor_head]) - dist_class = TanhNormal dist_kwargs = { "min": -1.0, From cfcc073b51810ff2916f33877e45757d9aa33aa4 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 26 Apr 2023 08:48:28 +0200 Subject: [PATCH 006/104] update target and testscript --- examples/decision_transformer/dt_online.py | 21 +++-- examples/decision_transformer/utils.py | 105 +++------------------ torchrl/envs/transforms/transforms.py | 23 ++--- 3 files changed, 40 insertions(+), 109 deletions(-) diff --git a/examples/decision_transformer/dt_online.py b/examples/decision_transformer/dt_online.py index ebedd22fd84..aad153b709d 100644 --- a/examples/decision_transformer/dt_online.py +++ b/examples/decision_transformer/dt_online.py @@ -6,21 +6,22 @@ from torchrl.modules import ProbabilisticActor from torchrl.modules.distributions import TanhNormal from torchrl.modules.models import DTActor -from utils import make_test_env +from utils import make_collector, make_replay_buffer, make_test_env @hydra.main(config_path=".", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 - # Sanity check test_env = make_test_env(cfg.env) # test_env = make_transformed_env(test_env) action_spec = test_env.action_spec - in_keys = ["observation", "action", "return_to_go", "timesteps"] - # transformer = DecisionTransformer( - # state_dim=5, action_dim=2, hidden_size=512, max_ep_len=1000, ordering=False - # ) + in_keys = [ 
+ "observation", + "action", + "return_to_go", + "timesteps", + ] # return_to_go, timesteps actor_net = DTActor(action_dim=1) @@ -54,13 +55,19 @@ def main(cfg: "DictConfig"): # noqa: F821 # Generate a complete episode td_test = test_env.rollout( policy=actor, - max_steps=100, + max_steps=30, auto_reset=True, auto_cast_to_device=True, break_when_any_done=True, ).clone() print(td_test) + collector = make_collector(cfg, policy=actor) + replay_buffer = make_replay_buffer(cfg.replay_buffer) + for data in collector: + data_view = data.reshape(-1) + replay_buffer.extend(data_view) + if __name__ == "__main__": main() diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index 6b913cbbd27..281a59d9721 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -2,29 +2,19 @@ import torch.optim from tensordict.nn import TensorDictModule -from torchrl.collectors import MultiaSyncDataCollector, MultiSyncDataCollector -from torchrl.data import ( - CompositeSpec, - LazyMemmapStorage, - MultiStep, - TensorDictReplayBuffer, -) +from torchrl.collectors import SyncDataCollector +from torchrl.data import CompositeSpec, LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler from torchrl.envs import ( CatFrames, - DoubleToFloat, EnvCreator, - GrayScale, NoopResetEnv, ObservationNorm, ParallelEnv, RenameTransform, - Resize, - RewardScaling, StepCounter, TargetReturn, TensorDictPrimer, - ToTensorImage, TransformedEnv, UnsqueezeTransform, ) @@ -83,53 +73,7 @@ def make_base_env(env_cfg, from_pixels=None): def make_transformed_env(base_env, env_cfg): - from_pixels = env_cfg.from_pixels - if from_pixels: - return make_transformed_env_pixels(base_env, env_cfg) - else: - return make_transformed_env_states(base_env, env_cfg) - - -def make_transformed_env_pixels(base_env, env_cfg): - if not isinstance(env_cfg.reward_scaling, float): - env_cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(env_cfg.env_name, 5.0) - - env_library = LIBS[env_cfg.env_library] - env = TransformedEnv(base_env) - - reward_scaling = env_cfg.reward_scaling - - env.append_transform(RewardScaling(0.0, reward_scaling)) - - double_to_float_list = [] - double_to_float_inv_list = [] - - # - env.append_transform(ToTensorImage()) - env.append_transform(GrayScale()) - env.append_transform(Resize(84, 84)) - env.append_transform(CatFrames(N=4, dim=-3)) - - obs_norm = ObservationNorm(in_keys=["pixels"]) - env.append_transform(obs_norm) - - if env_library is DMControlEnv: - double_to_float_list += [ - "reward", - ] - double_to_float_list += [ - "action", - ] - double_to_float_inv_list += ["action"] # DMControl requires double-precision - double_to_float_list += ["observation_vector"] - else: - double_to_float_list += ["observation_vector"] - env.append_transform( - DoubleToFloat( - in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list - ) - ) - return env + return make_transformed_env_states(base_env, env_cfg) def make_transformed_env_states(base_env, env_cfg): @@ -140,7 +84,9 @@ def make_transformed_env_states(base_env, env_cfg): RenameTransform(["step_count"], ["timesteps"], create_copy=True) ) transformed_env.append_transform( - TargetReturn(200 * 0.01, out_keys=["return_to_go"]) + TargetReturn( + 200 * 0.01, out_keys=["return_to_go"] + ) # WATCH OUT FOR THE SCALING! 
) # transformed_env.append_transform(SCALE) transformed_env.append_transform(TensorDictPrimer(action=base_env.action_spec)) @@ -161,15 +107,11 @@ def make_transformed_env_states(base_env, env_cfg): CatFrames(in_keys=["return_to_go"], N=env_cfg.stacked_frames, dim=-2) ) - transformed_env.append_transform(UnsqueezeTransform(-2, in_keys=["done"])) - transformed_env.append_transform( - CatFrames(in_keys=["done"], N=env_cfg.stacked_frames, dim=-2) - ) - transformed_env.append_transform(UnsqueezeTransform(-2, in_keys=["timesteps"])) transformed_env.append_transform( CatFrames(in_keys=["timesteps"], N=env_cfg.stacked_frames, dim=-2) ) + return transformed_env @@ -193,24 +135,15 @@ def make_test_env(env_cfg): def get_stats(env_cfg): - from_pixels = env_cfg.from_pixels env = make_transformed_env(make_base_env(env_cfg), env_cfg) - init_stats(env, env_cfg.n_samples_stats, from_pixels) + init_stats(env, env_cfg.n_samples_stats) return env.state_dict() -def init_stats(env, n_samples_stats, from_pixels): +def init_stats(env, n_samples_stats): for t in env.transform: if isinstance(t, ObservationNorm): - if from_pixels: - t.init_stats( - n_samples_stats, - cat_dim=-3, - reduce_dim=(-1, -2, -3), - keep_dims=(-1, -2, -3), - ) - else: - t.init_stats(n_samples_stats) + t.init_stats(n_samples_stats) # ==================================================================== @@ -218,27 +151,17 @@ def init_stats(env, n_samples_stats, from_pixels): # --------------------------- -def make_collector(cfg, state_dict, policy): +def make_collector(cfg, policy): env_cfg = cfg.env - loss_cfg = cfg.loss collector_cfg = cfg.collector - if collector_cfg.async_collection: - collector_class = MultiaSyncDataCollector - else: - collector_class = MultiSyncDataCollector - if collector_cfg.multi_step: - ms = MultiStep(gamma=loss_cfg.gamma, n_steps=collector_cfg.multi_step) - else: - ms = None + collector_class = SyncDataCollector + state_dict = get_stats(env_cfg) collector = collector_class( - [make_parallel_env(env_cfg, state_dict=state_dict)] - * collector_cfg.num_collectors, + make_parallel_env(env_cfg, state_dict=state_dict), policy, frames_per_batch=collector_cfg.frames_per_batch, total_frames=collector_cfg.total_frames, - postproc=ms, device=collector_cfg.collector_devices, - init_random_frames=collector_cfg.init_random_frames, max_frames_per_traj=collector_cfg.max_frames_per_traj, ) return collector diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 9c7c2a04159..00b19f7dc55 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1056,6 +1056,7 @@ def _call(self, tensordict: TensorDict) -> TensorDict: return tensordict def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + # sets the current root target as next target return for out_key in self.out_keys: if isinstance(out_key, str): out_key = (out_key,) @@ -1066,10 +1067,10 @@ def _apply_transform( self, reward: torch.Tensor, target_return: torch.Tensor ) -> torch.Tensor: if self.mode == "reduce": - target_return = target_return - reward + target_return = target_return[:, -1] - reward return target_return elif self.mode == "constant": - return target_return + return target_return[:, -1] else: raise ValueError("Unknown mode: {}".format(self.mode)) @@ -1085,15 +1086,15 @@ def transform_observation_spec( raise ValueError( f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead." 
) - - target_return_spec = BoundedTensorSpec( - minimum=-float("inf"), - maximum=self.target_return, - shape=self.parent.reward_spec.shape, - dtype=self.parent.reward_spec.dtype, - device=self.parent.reward_spec.device, - ) - observation_spec["target_return"] = target_return_spec + for key in self.out_keys: + target_return_spec = BoundedTensorSpec( + minimum=-float("inf"), + maximum=self.target_return, + shape=self.parent.reward_spec.shape, + dtype=self.parent.reward_spec.dtype, + device=self.parent.reward_spec.device, + ) + observation_spec[key] = target_return_spec return observation_spec From e377ae8360028f2233dc86fd852f11c6001782fa Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 26 Apr 2023 13:57:54 +0200 Subject: [PATCH 007/104] add r2g --- examples/decision_transformer/dt_online.py | 1 + examples/decision_transformer/utils.py | 11 +++++------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/decision_transformer/dt_online.py b/examples/decision_transformer/dt_online.py index aad153b709d..0373f4683d3 100644 --- a/examples/decision_transformer/dt_online.py +++ b/examples/decision_transformer/dt_online.py @@ -63,6 +63,7 @@ def main(cfg: "DictConfig"): # noqa: F821 print(td_test) collector = make_collector(cfg, policy=actor) + replay_buffer = make_replay_buffer(cfg.replay_buffer) for data in collector: data_view = data.reshape(-1) diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index 281a59d9721..4879c59ad7d 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -4,7 +4,7 @@ from torchrl.collectors import SyncDataCollector from torchrl.data import CompositeSpec, LazyMemmapStorage, TensorDictReplayBuffer -from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler +from torchrl.data.replay_buffers.samplers import RandomSampler from torchrl.envs import ( CatFrames, EnvCreator, @@ -12,6 +12,7 @@ ObservationNorm, ParallelEnv, RenameTransform, + Reward2GoTransform, StepCounter, TargetReturn, TensorDictPrimer, @@ -168,12 +169,10 @@ def make_collector(cfg, policy): def make_replay_buffer(rb_cfg): - if rb_cfg.prb: - sampler = PrioritizedSampler(max_capacity=rb_cfg.capacity, alpha=0.7, beta=0.5) - else: - sampler = RandomSampler() + r2g = Reward2GoTransform(gamma=1.0, out_keys=["return_to_go"]) + sampler = RandomSampler() return TensorDictReplayBuffer( - storage=LazyMemmapStorage(rb_cfg.capacity), sampler=sampler + storage=LazyMemmapStorage(rb_cfg.capacity), sampler=sampler, transform=r2g ) From 8b69d6a01d769a1f5ed4a45b901640d9e85119c5 Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 28 Apr 2023 17:38:55 +0200 Subject: [PATCH 008/104] update context mask --- examples/decision_transformer/config.yaml | 2 +- examples/decision_transformer/utils.py | 10 +++- torchrl/envs/transforms/transforms.py | 5 +- .../modules/models/decision_transformer.py | 60 ++++++------------- torchrl/modules/models/models.py | 5 +- 5 files changed, 32 insertions(+), 50 deletions(-) diff --git a/examples/decision_transformer/config.yaml b/examples/decision_transformer/config.yaml index 56781ea0175..6793d55628a 100644 --- a/examples/decision_transformer/config.yaml +++ b/examples/decision_transformer/config.yaml @@ -4,7 +4,7 @@ env: env_task: "" env_library: gym record_video: 0 - stacked_frames: 5 + stacked_frames: 20 n_samples_stats: 1000 frame_skip: 1 from_pixels: False diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index 4879c59ad7d..0a278343443 
100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -7,7 +7,9 @@ from torchrl.data.replay_buffers.samplers import RandomSampler from torchrl.envs import ( CatFrames, + Compose, EnvCreator, + ExcludeTransform, NoopResetEnv, ObservationNorm, ParallelEnv, @@ -157,6 +159,8 @@ def make_collector(cfg, policy): collector_cfg = cfg.collector collector_class = SyncDataCollector state_dict = get_stats(env_cfg) + # to exclude inference target returns + exclude = ExcludeTransform("return_to_go") # next return to go collector = collector_class( make_parallel_env(env_cfg, state_dict=state_dict), policy, @@ -164,15 +168,19 @@ def make_collector(cfg, policy): total_frames=collector_cfg.total_frames, device=collector_cfg.collector_devices, max_frames_per_traj=collector_cfg.max_frames_per_traj, + postproc=exclude, ) return collector def make_replay_buffer(rb_cfg): r2g = Reward2GoTransform(gamma=1.0, out_keys=["return_to_go"]) + transforms = [r2g] sampler = RandomSampler() return TensorDictReplayBuffer( - storage=LazyMemmapStorage(rb_cfg.capacity), sampler=sampler, transform=r2g + storage=LazyMemmapStorage(rb_cfg.capacity), + sampler=sampler, + transform=Compose(*transforms), ) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index f6df146092c..7bddc90a78b 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3956,10 +3956,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: item = self._inv_apply_transform( tensordict.get(in_key), done_or_truncated ) - tensordict.set( - out_key, - item, - ) + tensordict.set(out_key, item, inplace=True) if not found: raise KeyError(f"Could not find any of the input keys {self.in_keys}.") return tensordict diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py index 9943bb5506b..39407348baa 100644 --- a/torchrl/modules/models/decision_transformer.py +++ b/torchrl/modules/models/decision_transformer.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch import torch.nn as nn import transformers @@ -47,18 +45,18 @@ def forward( action: torch.Tensor, return_to_go: torch.Tensor, timesteps: torch.Tensor, - padding_mask: Optional[torch.Tensor] = None, + mask_context: bool = True, ): batch_size, seq_length = observation.shape[0], observation.shape[1] - if seq_length == self.inference_context: - observation, action, return_to_go, timesteps, seq_length = self.pad_context( - observation, action, return_to_go, timesteps - ) - - if padding_mask is None: - # attention mask for GPT: 1 if can be attended to, 0 if not - padding_mask = torch.ones((batch_size, seq_length), dtype=torch.long) + if mask_context: + ( + observation, + action, + return_to_go, + timesteps, + seq_length, + ) = self.mask_context(observation, action, return_to_go, timesteps) # embed each modality with a different head state_embeddings = self.embed_state(observation) @@ -85,17 +83,9 @@ def forward( ) stacked_inputs = self.embed_ln(stacked_inputs) - # to make the attention mask fit the stacked inputs, have to stack it as well - stacked_padding_mask = ( - torch.stack((padding_mask, padding_mask, padding_mask), dim=1) - .permute(0, 2, 1) - .reshape(batch_size, 3 * seq_length) - ) - # we feed in the input embeddings (not word indices as in NLP) to the model transformer_outputs = self.transformer( inputs_embeds=stacked_inputs, - attention_mask=stacked_padding_mask, ) x = transformer_outputs["last_hidden_state"] @@ 
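        # Token layout reminder: stacked_inputs interleaves (R_t, s_t, a_t) along the
        # sequence axis, so last_hidden_state has shape [batch, 3 * seq_len, hidden].
        # Rough sketch of how the state-token outputs are recovered below
        # (illustration only):
        #   x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)
        #   state_tokens = x[:, 1]  # index 0: return tokens, 1: state tokens, 2: action tokens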
-105,35 +95,19 @@ def forward( return x[:, 1] # only state tokens - def pad_context( + def mask_context( self, observation: torch.Tensor, action: torch.Tensor, return_to_go: torch.Tensor, timesteps: torch.Tensor, ): - observation = torch.nn.functional.pad( - observation, - (0, 0, self.train_context - self.inference_context, 0), - mode="constant", - value=0, - ) - action = torch.nn.functional.pad( - action, - (0, 0, self.train_context - self.inference_context - 1, 1), - mode="constant", - value=0, - ) # pad first action with 0 - return_to_go = torch.nn.functional.pad( - return_to_go, - (0, 0, self.train_context - self.inference_context - 1, 1), - mode="constant", - value=0, - ) - timesteps = torch.nn.functional.pad( - timesteps, - (0, 0, self.train_context - self.inference_context, 0), - mode="constant", - value=0, + """Mask the context of the input sequences.""" + observation[:, : -self.inference_context, :] = 0 + action[:, : -self.inference_context, :] = 0 + action = torch.cat( + [action[:, 1:], torch.zeros(action.shape[0], 1, self.action_dim)], dim=-2 ) + return_to_go[:, : -self.inference_context, :] = 0 + timesteps[:, : -self.inference_context] = 0 return observation, action, return_to_go, timesteps, self.train_context diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 202723865db..3e0f5f285bf 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -1200,8 +1200,11 @@ def forward( action: torch.Tensor, return_to_go: torch.Tensor, timesteps: torch.Tensor, + mask_context: bool = True, ) -> torch.Tensor: - hidden_state = self.transformer(observation, action, return_to_go, timesteps) + hidden_state = self.transformer( + observation, action, return_to_go, timesteps, mask_context + ) out = self.mlp(hidden_state)[:, -1] mu, log_std = out.chunk(2, -1) log_std = torch.tanh(log_std) From 7b9d029cbe0f29368e4022d7361acdbdee716c34 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 2 May 2023 11:20:47 +0200 Subject: [PATCH 009/104] add offline example script first tests --- examples/decision_transformer/config.yaml | 2 + examples/decision_transformer/dt_offline.py | 86 +++++++++ examples/decision_transformer/utils.py | 203 +++++++------------- 3 files changed, 157 insertions(+), 134 deletions(-) create mode 100644 examples/decision_transformer/dt_offline.py diff --git a/examples/decision_transformer/config.yaml b/examples/decision_transformer/config.yaml index 6793d55628a..279715b46b5 100644 --- a/examples/decision_transformer/config.yaml +++ b/examples/decision_transformer/config.yaml @@ -33,6 +33,8 @@ logger: # Buffer replay_buffer: + dataset: hopper-medium-v2 + batch_size: 256 prb: 0 buffer_prefetch: 64 capacity: 1_000_000 diff --git a/examples/decision_transformer/dt_offline.py b/examples/decision_transformer/dt_offline.py new file mode 100644 index 00000000000..70af977dfce --- /dev/null +++ b/examples/decision_transformer/dt_offline.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Decision Transformer Example. +This is a self-contained example of an offline Decision Transformer training script. +The helper functions are coded in the utils.py associated with this script. 
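+
+A typical invocation would look like the following (the overrides are illustrative;
+any key in config.yaml can be overridden on the Hydra command line):
+
+    python dt_offline.py optim.device=cuda:0 replay_buffer.dataset=hopper-medium-v2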
+""" + +import hydra +import torch +import tqdm +from torchrl.envs.utils import set_exploration_mode + +from utils import ( + # get_stats, + make_decision_transformer_model, + make_dt_optimizer, + # make_logger, + make_loss, + make_offline_replay_buffer, + # make_parallel_env, + make_test_env, +) + + +@hydra.main(config_path=".", config_name="config") +def main(cfg: "DictConfig"): # noqa: F821 + model_device = cfg.optim.device + + # state_dict = get_stats(cfg.env) + evaluation_env = make_test_env(cfg.env) + # logger = make_logger(cfg.logger) + replay_buffer = make_offline_replay_buffer(cfg.replay_buffer) + + actor = make_decision_transformer_model(cfg) + policy = actor.to(model_device) + + loss, target_net_updater = make_loss(cfg.loss, policy) + optim = make_dt_optimizer(cfg.optim, policy) + + pbar = tqdm.tqdm(total=cfg.optim.gradient_steps) + + r0 = None + l0 = None + + for i in range(cfg.optim.gradient_steps): + pbar.update(i) + data = replay_buffer.sample() + # loss + loss_vals = loss(data) + # backprop + actor_loss = loss_vals["loss_actor"] + q_loss = loss_vals["loss_qvalue"] + value_loss = loss_vals["loss_value"] + loss_val = actor_loss + q_loss + value_loss + + optim.zero_grad() + loss_val.backward() + optim.step() + target_net_updater.step() + + # evaluation + if i % cfg.env.evaluation_interval == 0: + with set_exploration_mode("random"), torch.no_grad(): + eval_td = evaluation_env.rollout( + max_steps=1000, policy=policy, auto_cast_to_device=True + ) + + if r0 is None: + r0 = eval_td["next", "reward"].sum(1).mean().item() + if l0 is None: + l0 = loss_val.item() + + # for key, value in loss_vals.items(): + # logger.log_scalar(key, value.item(), i) + # eval_reward = eval_td["next", "reward"].sum(1).mean().item() + # logger.log_scalar("evaluation reward", eval_reward, i) + + # pbar.set_description( + # f"loss: {loss_val.item(): 4.4f} (init: {l0: 4.4f}), evaluation reward: {eval_reward: 4.4f} (init={r0: 4.4f})" + # ) + + +if __name__ == "__main__": + main() diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index 0a278343443..787ecbebf78 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -3,7 +3,9 @@ from tensordict.nn import TensorDictModule from torchrl.collectors import SyncDataCollector -from torchrl.data import CompositeSpec, LazyMemmapStorage, TensorDictReplayBuffer +from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer +from torchrl.data.datasets.d4rl import D4RLExperienceReplay +from torchrl.data.replay_buffers import SamplerWithoutReplacement from torchrl.data.replay_buffers.samplers import RandomSampler from torchrl.envs import ( CatFrames, @@ -13,7 +15,6 @@ NoopResetEnv, ObservationNorm, ParallelEnv, - RenameTransform, Reward2GoTransform, StepCounter, TargetReturn, @@ -23,19 +24,10 @@ ) from torchrl.envs.libs.dm_control import DMControlEnv from torchrl.envs.utils import set_exploration_mode -from torchrl.modules import ( - AdditiveGaussianWrapper, - ConvNet, - MLP, - OrnsteinUhlenbeckProcessWrapper, - ProbabilisticActor, - TanhDelta, - ValueOperator, -) +from torchrl.modules import DTActor, ProbabilisticActor, TanhNormal from torchrl.objectives import SoftUpdate, TD3Loss from torchrl.record.loggers import generate_exp_name, get_logger from torchrl.trainers.helpers.envs import LIBS -from torchrl.trainers.helpers.models import ACTIVATIONS DEFAULT_REWARD_SCALING = { @@ -83,9 +75,10 @@ def make_transformed_env_states(base_env, env_cfg): transformed_env = 
TransformedEnv(base_env) transformed_env.append_transform(StepCounter()) - transformed_env.append_transform( - RenameTransform(["step_count"], ["timesteps"], create_copy=True) - ) + # Only needed if ordering True -> Default is False + # transformed_env.append_transform( + # RenameTransform(["step_count"], ["timesteps"], create_copy=True) + # ) transformed_env.append_transform( TargetReturn( 200 * 0.01, out_keys=["return_to_go"] @@ -109,11 +102,11 @@ def make_transformed_env_states(base_env, env_cfg): transformed_env.append_transform( CatFrames(in_keys=["return_to_go"], N=env_cfg.stacked_frames, dim=-2) ) - - transformed_env.append_transform(UnsqueezeTransform(-2, in_keys=["timesteps"])) - transformed_env.append_transform( - CatFrames(in_keys=["timesteps"], N=env_cfg.stacked_frames, dim=-2) - ) + # Only needed if ordering True -> Default is False + # transformed_env.append_transform(UnsqueezeTransform(-2, in_keys=["timesteps"])) + # transformed_env.append_transform( + # CatFrames(in_keys=["timesteps"], N=env_cfg.stacked_frames, dim=-2) + # ) return transformed_env @@ -184,6 +177,28 @@ def make_replay_buffer(rb_cfg): ) +def make_offline_replay_buffer(rb_cfg): + r2g = Reward2GoTransform(gamma=1.0, out_keys=["return_to_go"]) + data = D4RLExperienceReplay( + rb_cfg.dataset, + split_trajs=False, + batch_size=rb_cfg.batch_size, + sampler=SamplerWithoutReplacement(drop_last=False), + transform=r2g, + ) + # data.append_transform( + # Reward2GoTransform(gamma=1.0, out_keys=["return_to_go"]) + # ) + # data.append_transform( + + # ) + # data.append_transform( + + # ) + + return data + + # ==================================================================== # Model # ----- @@ -194,158 +209,78 @@ def make_replay_buffer(rb_cfg): # -def make_td3_model(cfg): +def make_decision_transformer_model(cfg): env_cfg = cfg.env - model_cfg = cfg.model + # model_cfg = cfg.model proof_environment = make_transformed_env(make_base_env(env_cfg), env_cfg) # we must initialize the observation norm transform init_stats(proof_environment, n_samples_stats=3, from_pixels=env_cfg.from_pixels) - env_specs = proof_environment.specs - from_pixels = env_cfg.from_pixels + action_spec = proof_environment.action_spec + + in_keys = [ + "observation", + "action", + "return_to_go", + # "timesteps", + ] # return_to_go, timesteps - if not from_pixels: - actor_net, q_net = make_td3_modules_state(model_cfg, proof_environment) - in_keys = ["observation_vector"] - out_keys = ["param"] - else: - actor_net, q_net = make_td3_modules_pixels(model_cfg, proof_environment) - in_keys = ["pixels"] - out_keys = ["param", "hidden"] + actor_net = DTActor(action_dim=1) - actor_module = TensorDictModule(actor_net, in_keys=in_keys, out_keys=out_keys) + dist_class = TanhNormal + dist_kwargs = { + "min": -1.0, + "max": 1.0, + "tanh_loc": False, + } - # We use a ProbabilisticActor to make sure that we map the - # network output to the right space using a TanhDelta - # distribution. 
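+    # The TensorDictModule below reads the stacked context tensors listed in
+    # `in_keys` and writes the Gaussian parameters ("loc", "scale"); the
+    # ProbabilisticActor then samples the action through a TanhNormal so that it
+    # stays within the [-1, 1] bounds given in `dist_kwargs`.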
+ actor_module = TensorDictModule( + actor_net, in_keys=in_keys, out_keys=["loc", "scale"] # , "hidden_state"], + ) actor = ProbabilisticActor( + spec=action_spec, + in_keys=["loc", "scale"], # , "hidden_state"], + out_keys=["action", "log_prob"], # , "hidden_state"], module=actor_module, - in_keys=["param"], - spec=CompositeSpec(action=env_specs["input_spec"]["action"]), - safe=False, - distribution_class=TanhDelta, - distribution_kwargs={ - "min": env_specs["input_spec"]["action"].space.minimum, - "max": env_specs["input_spec"]["action"].space.maximum, - }, - ) - - if not from_pixels: - in_keys = ["observation_vector", "action"] - else: - in_keys = ["pixels", "action"] - - out_keys = ["state_action_value"] - qvalue = ValueOperator( - in_keys=in_keys, - out_keys=out_keys, - module=q_net, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs, + default_interaction_mode="random", + cache_dist=True, + return_log_prob=False, ) # init the lazy layers with torch.no_grad(), set_exploration_mode("random"): - # for t in proof_environment.transform: - # if isinstance(t, ObservationNorm): - # t.init_stats(2) td = proof_environment.rollout(max_steps=1000) print(td) actor(td) - qvalue(td) - - return actor, qvalue - -def make_td3_modules_state(model_cfg, proof_environment): - env_specs = proof_environment.specs - out_features = env_specs["input_spec"]["action"].shape[0] - - actor_net_kwargs = { - "num_cells": [256, 256], - "out_features": out_features, - "activation_class": ACTIVATIONS[model_cfg.activation], - } - actor_net = MLP(**actor_net_kwargs) - - qvalue_net_kwargs = { - "num_cells": [256, 256], - "out_features": 1, - "activation_class": ACTIVATIONS[model_cfg.activation], - } - - q_net = MLP(**qvalue_net_kwargs) - return actor_net, q_net - - -def make_td3_modules_pixels(model_cfg, proof_environment): - env_specs = proof_environment.specs - out_features = env_specs["input_spec"]["action"].shape[0] - - actor_net = torch.nn.ModuleList() - - actor_convnet_kwargs = {"activation_class": ACTIVATIONS[model_cfg.activation]} - actor_net.append(ConvNet(**actor_convnet_kwargs)) - - actor_net_kwargs = { - "num_cells": [256, 256], - "out_features": out_features, - "activation_class": ACTIVATIONS[model_cfg.activation], - } - actor_net.append(MLP(**actor_net_kwargs)) - - q_net = torch.nn.ModuleList() - - q_net_convnet_kwargs = {"activation_class": ACTIVATIONS[model_cfg.activation]} - q_net.append(ConvNet(**q_net_convnet_kwargs)) - - qvalue_net_kwargs = { - "num_cells": [256, 256], - "out_features": 1, - "activation_class": ACTIVATIONS[model_cfg.activation], - } - - q_net.append(MLP(**qvalue_net_kwargs)) - - return actor_net, q_net - - -def make_policy(model_cfg, actor): - if model_cfg.ou_exploration: - return OrnsteinUhlenbeckProcessWrapper(actor) - else: - return AdditiveGaussianWrapper(actor) + return actor # ==================================================================== -# TD3 Loss +# Decision Transformer Loss # --------- -def make_loss(loss_cfg, actor_network, qvalue_network): +def make_loss(loss_cfg, actor_network): loss = TD3Loss( actor_network, - qvalue_network, gamma=loss_cfg.gamma, loss_function=loss_cfg.loss_function, - policy_noise=0.2, - noise_clip=0.5, ) target_net_updater = SoftUpdate(loss, 1 - loss_cfg.tau) target_net_updater.init_() return loss, target_net_updater -def make_td3_optimizer(optim_cfg, actor_network, qvalue_network): - actor_optim = torch.optim.Adam( +def make_dt_optimizer(optim_cfg, actor_network): + optimizer = torch.optim.Adam( actor_network.parameters(), 
lr=optim_cfg.lr, weight_decay=optim_cfg.weight_decay, ) - critic_optim = torch.optim.Adam( - qvalue_network.parameters(), - lr=optim_cfg.lr, - weight_decay=optim_cfg.weight_decay, - ) - return actor_optim, critic_optim + return optimizer # ==================================================================== From e2fb927115e01dff571c60d1477b4a3cec7e9243 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 4 May 2023 11:21:09 +0200 Subject: [PATCH 010/104] Update objective loss --- examples/decision_transformer/utils.py | 15 +-- torchrl/objectives/decision_transformer.py | 146 +++++++++++++++++++++ 2 files changed, 153 insertions(+), 8 deletions(-) create mode 100644 torchrl/objectives/decision_transformer.py diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index 787ecbebf78..9c8a53e1f4d 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -25,7 +25,7 @@ from torchrl.envs.libs.dm_control import DMControlEnv from torchrl.envs.utils import set_exploration_mode from torchrl.modules import DTActor, ProbabilisticActor, TanhNormal -from torchrl.objectives import SoftUpdate, TD3Loss +from torchrl.objectives import OnlineDTLoss from torchrl.record.loggers import generate_exp_name, get_logger from torchrl.trainers.helpers.envs import LIBS @@ -259,22 +259,21 @@ def make_decision_transformer_model(cfg): # ==================================================================== -# Decision Transformer Loss +# Online Decision Transformer Loss # --------- def make_loss(loss_cfg, actor_network): - loss = TD3Loss( + loss = OnlineDTLoss( actor_network, gamma=loss_cfg.gamma, loss_function=loss_cfg.loss_function, ) - target_net_updater = SoftUpdate(loss, 1 - loss_cfg.tau) - target_net_updater.init_() - return loss, target_net_updater + return loss def make_dt_optimizer(optim_cfg, actor_network): + # Should be Lambda Optimizer optimizer = torch.optim.Adam( actor_network.parameters(), lr=optim_cfg.lr, @@ -289,7 +288,7 @@ def make_dt_optimizer(optim_cfg, actor_network): def make_logger(logger_cfg): - exp_name = generate_exp_name("TD3", logger_cfg.exp_name) + exp_name = generate_exp_name("OnlineDecisionTransformer", logger_cfg.exp_name) logger_cfg.exp_name = exp_name - logger = get_logger(logger_cfg.backend, logger_name="td3", experiment_name=exp_name) + logger = get_logger(logger_cfg.backend, logger_name="oDT", experiment_name=exp_name) return logger diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py new file mode 100644 index 00000000000..28aeb611e37 --- /dev/null +++ b/torchrl/objectives/decision_transformer.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import numpy as np + +import torch +from tensordict.tensordict import TensorDict, TensorDictBase + +from torchrl.modules import ProbabilisticActor + +from .common import LossModule + + +class OnlineDTLoss(LossModule): + r"""TorchRL implementation of the Online Decision Transformer loss. + + Presented in "Online Decision Transformer" https://arxiv.org/abs/2202.05607 + Args: + actor_network (ProbabilisticActor): stochastic actor + qvalue_network (SafeModule): Q(s, a) parametric model + value_network (SafeModule, optional): V(s) parametric model. If not + provided, the second version of SAC is assumed. 
+ qvalue_network_bis (ProbabilisticTDModule, optional): if required, the + Q-value can be computed twice independently using two separate + networks. The minimum predicted value will then be used for + inference. + gamma (number, optional): discount for return computation + Default is 0.99 + priority_key (str, optional): tensordict key where to write the + priority (for prioritized replay buffer usage). Default is + `"td_error"`. + loss_function (str, optional): loss function to be used with + the value function loss. Default is `"smooth_l1"`. + temperature (float, optional): Inverse temperature (beta). + For smaller hyperparameter values, the objective behaves similarly to + behavioral cloning, while for larger values, it attempts to recover the + maximum of the Q-function. + expectile (float, optional): expectile :math:`\tau`. A larger value of :math:`\tau` is crucial + for antmaze tasks that require dynamical programming ("stichting"). + + """ + + def __init__( + self, + actor_network: ProbabilisticActor, + alpha_init: float = 1.0, + min_alpha: float = 0.1, + max_alpha: float = 10.0, + ) -> None: + super().__init__() + + # Actor Network + self.convert_to_functional( + actor_network, + "actor_network", + create_target_params=False, + funs_to_decorate=["forward", "get_dist"], + ) + try: + device = next(self.parameters()).device + except AttributeError: + device = torch.device("cpu") + self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) + self.register_buffer( + "min_log_alpha", torch.tensor(min_alpha, device=device).log() + ) + self.register_buffer( + "max_log_alpha", torch.tensor(max_alpha, device=device).log() + ) + + self.register_parameter( + "log_alpha", + torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), + ) + + target_entropy = -float(np.prod(actor_network.spec["action"].shape)) + self.register_buffer( + "target_entropy", torch.tensor(target_entropy, device=device) + ) + + @property + def device(self) -> torch.device: + for p in self.parameters(): + return p.device + raise RuntimeError( + "At least one of the networks of SACLoss must have trainable " "parameters." + ) + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + """Compute the loss for the Online Decision Transformer. 
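+
+        Sketch of the objective, following the online-DT formulation (the exact
+        coefficients used below may still differ while this is work in progress):
+
+            loss        = -E[ log pi(a_t | context) ] - alpha * H[pi]
+            loss_alpha  = alpha * ( H[pi] - target_entropy )
+
+        where alpha = exp(log_alpha) is the learnable temperature, clamped to
+        [min_alpha, max_alpha], and target_entropy defaults to -action_dim.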
+ + # a_hat is a SquashedNormal Distribution + log_likelihood = a_hat_dist.log_likelihood(a)[attention_mask > 0].mean() + + entropy = a_hat_dist.entropy().mean() + loss = -(log_likelihood + entropy_reg * entropy) + + return ( + loss, + -log_likelihood, + entropy, + ) + dist.log_prob(x).sum(axis=2) + """ + shape = None + if tensordict.ndimension() > 1: + shape = tensordict.shape + tensordict_reshape = tensordict.reshape(-1) + else: + tensordict_reshape = tensordict + + # device = self.device + # td_device = tensordict_reshape.to(device) + + out_td = self.actor_network(tensordict) + + target_actions = tensordict["action"] + + # log_prob = out_td["log_prob"] + action_dist = out_td["distribution"] + loss_log_likelihood = action_dist.log_prob(target_actions).sum(axis=2) + entropy = action_dist.entropy().mean() + loss = -(loss_log_likelihood + self.target_entropy.detach() * entropy) + + loss_alpha = self.log_alpha.exp() * (entropy - self.target_entropy).detach() + if shape: + tensordict.update(tensordict_reshape.view(shape)) + out = { + "loss": loss.mean(), + "loss_log_likelihood": loss_log_likelihood.mean(), + "entropy": entropy.mean(), + "loss_alpha": loss_alpha.mean(), + "alpha": self._alpha, + } + return TensorDict(out, []) + + @property + def _alpha(self): + self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) + with torch.no_grad(): + alpha = self.log_alpha.exp() + return alpha From 69b0974ad4f204a32834f4b21865b44fe90e9a06 Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 5 May 2023 13:50:04 +0200 Subject: [PATCH 011/104] updates --- examples/decision_transformer/config.yaml | 10 ++-- examples/decision_transformer/dt_offline.py | 37 ++++++------- examples/decision_transformer/utils.py | 52 ++++++++----------- torchrl/envs/transforms/transforms.py | 14 ++++- .../modules/models/decision_transformer.py | 31 +++++------ torchrl/modules/models/models.py | 14 ++--- torchrl/objectives/__init__.py | 1 + torchrl/objectives/decision_transformer.py | 24 ++------- 8 files changed, 84 insertions(+), 99 deletions(-) diff --git a/examples/decision_transformer/config.yaml b/examples/decision_transformer/config.yaml index 279715b46b5..80d723354c6 100644 --- a/examples/decision_transformer/config.yaml +++ b/examples/decision_transformer/config.yaml @@ -1,6 +1,6 @@ # Task and env env: - env_name: Pendulum-v1 + env_name: Hopper-v2 env_task: "" env_library: gym record_video: 0 @@ -36,6 +36,7 @@ replay_buffer: dataset: hopper-medium-v2 batch_size: 256 prb: 0 + stacked_frames: 20 buffer_prefetch: 64 capacity: 1_000_000 @@ -48,6 +49,7 @@ optim: lr_scheduler: "" optim_steps_per_batch: 1000 policy_update_delay: 2 + gradient_steps: 1000 # Policy and model model: @@ -57,6 +59,6 @@ model: # loss loss: - loss_function: smooth_l1 - gamma: 0.99 - tau: 0.05 + alpha_init: 1.0 + min_alpha: 0.1 + max_alpha: 10.0 diff --git a/examples/decision_transformer/dt_offline.py b/examples/decision_transformer/dt_offline.py index 70af977dfce..c7d6d911515 100644 --- a/examples/decision_transformer/dt_offline.py +++ b/examples/decision_transformer/dt_offline.py @@ -8,9 +8,7 @@ """ import hydra -import torch import tqdm -from torchrl.envs.utils import set_exploration_mode from utils import ( # get_stats, @@ -20,7 +18,7 @@ make_loss, make_offline_replay_buffer, # make_parallel_env, - make_test_env, + # make_test_env, ) @@ -29,20 +27,20 @@ def main(cfg: "DictConfig"): # noqa: F821 model_device = cfg.optim.device # state_dict = get_stats(cfg.env) - evaluation_env = make_test_env(cfg.env) + # evaluation_env = 
make_test_env(cfg.env) # logger = make_logger(cfg.logger) replay_buffer = make_offline_replay_buffer(cfg.replay_buffer) actor = make_decision_transformer_model(cfg) policy = actor.to(model_device) - loss, target_net_updater = make_loss(cfg.loss, policy) + loss = make_loss(cfg.loss, policy) optim = make_dt_optimizer(cfg.optim, policy) pbar = tqdm.tqdm(total=cfg.optim.gradient_steps) - r0 = None - l0 = None + # r0 = None + # l0 = None for i in range(cfg.optim.gradient_steps): pbar.update(i) @@ -58,19 +56,18 @@ def main(cfg: "DictConfig"): # noqa: F821 optim.zero_grad() loss_val.backward() optim.step() - target_net_updater.step() - - # evaluation - if i % cfg.env.evaluation_interval == 0: - with set_exploration_mode("random"), torch.no_grad(): - eval_td = evaluation_env.rollout( - max_steps=1000, policy=policy, auto_cast_to_device=True - ) - - if r0 is None: - r0 = eval_td["next", "reward"].sum(1).mean().item() - if l0 is None: - l0 = loss_val.item() + + # # evaluation + # if i % cfg.env.evaluation_interval == 0: + # with set_exploration_mode("random"), torch.no_grad(): + # eval_td = evaluation_env.rollout( + # max_steps=1000, policy=policy, auto_cast_to_device=True + # ) + + # if r0 is None: + # r0 = eval_td["next", "reward"].sum(1).mean().item() + # if l0 is None: + # l0 = loss_val.item() # for key, value in loss_vals.items(): # logger.log_scalar(key, value.item(), i) diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index 9c8a53e1f4d..a4f2f67d209 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -16,7 +16,6 @@ ObservationNorm, ParallelEnv, Reward2GoTransform, - StepCounter, TargetReturn, TensorDictPrimer, TransformedEnv, @@ -30,16 +29,6 @@ from torchrl.trainers.helpers.envs import LIBS -DEFAULT_REWARD_SCALING = { - "Hopper-v1": 5, - "Walker2d-v1": 5, - "HalfCheetah-v1": 5, - "cheetah": 5, - "Ant-v2": 5, - "Humanoid-v2": 20, - "humanoid": 100, -} - # ==================================================================== # Environment utils # ----------------- @@ -74,11 +63,6 @@ def make_transformed_env(base_env, env_cfg): def make_transformed_env_states(base_env, env_cfg): transformed_env = TransformedEnv(base_env) - transformed_env.append_transform(StepCounter()) - # Only needed if ordering True -> Default is False - # transformed_env.append_transform( - # RenameTransform(["step_count"], ["timesteps"], create_copy=True) - # ) transformed_env.append_transform( TargetReturn( 200 * 0.01, out_keys=["return_to_go"] @@ -102,11 +86,10 @@ def make_transformed_env_states(base_env, env_cfg): transformed_env.append_transform( CatFrames(in_keys=["return_to_go"], N=env_cfg.stacked_frames, dim=-2) ) - # Only needed if ordering True -> Default is False - # transformed_env.append_transform(UnsqueezeTransform(-2, in_keys=["timesteps"])) - # transformed_env.append_transform( - # CatFrames(in_keys=["timesteps"], N=env_cfg.stacked_frames, dim=-2) - # ) + + # transformed_env.append_transform(UnsqueezeTransform(0, in_keys=["return_to_go"], allow_positive_dim=True)) + # transformed_env.append_transform(UnsqueezeTransform(0, in_keys=["observation"], allow_positive_dim=True)) + # transformed_env.append_transform(UnsqueezeTransform(0, in_keys=["action"], allow_positive_dim=True)) return transformed_env @@ -179,12 +162,17 @@ def make_replay_buffer(rb_cfg): def make_offline_replay_buffer(rb_cfg): r2g = Reward2GoTransform(gamma=1.0, out_keys=["return_to_go"]) + cat_r2g = CatFrames(in_keys=["return_to_go"], 
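+    # With gamma=1.0 the reward-to-go of step t is the undiscounted sum of the
+    # remaining rewards of the trajectory, i.e. roughly (sketch of what
+    # Reward2GoTransform computes, for illustration only):
+    #   return_to_go = reward.flip(0).cumsum(0).flip(0)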
N=rb_cfg.stacked_frames, dim=-2) + cat_obs = CatFrames(in_keys=["observation"], N=rb_cfg.stacked_frames, dim=-2) + cat_actions = CatFrames(in_keys=["action"], N=rb_cfg.stacked_frames, dim=-2) + exclude_next_obs = ExcludeTransform("next_observations") + transforms = Compose(r2g, cat_r2g, cat_obs, cat_actions, exclude_next_obs) data = D4RLExperienceReplay( rb_cfg.dataset, split_trajs=False, batch_size=rb_cfg.batch_size, sampler=SamplerWithoutReplacement(drop_last=False), - transform=r2g, + transform=transforms, ) # data.append_transform( # Reward2GoTransform(gamma=1.0, out_keys=["return_to_go"]) @@ -213,11 +201,11 @@ def make_decision_transformer_model(cfg): env_cfg = cfg.env # model_cfg = cfg.model proof_environment = make_transformed_env(make_base_env(env_cfg), env_cfg) - # we must initialize the observation norm transform - init_stats(proof_environment, n_samples_stats=3, from_pixels=env_cfg.from_pixels) action_spec = proof_environment.action_spec - + for key, value in proof_environment.observation_spec.items(): + if key == "observation": + state_dim = value.shape[-1] in_keys = [ "observation", "action", @@ -225,7 +213,7 @@ def make_decision_transformer_model(cfg): # "timesteps", ] # return_to_go, timesteps - actor_net = DTActor(action_dim=1) + actor_net = DTActor(state_dim=state_dim, action_dim=action_spec.shape[-1]) dist_class = TanhNormal dist_kwargs = { @@ -245,13 +233,14 @@ def make_decision_transformer_model(cfg): distribution_class=dist_class, distribution_kwargs=dist_kwargs, default_interaction_mode="random", - cache_dist=True, - return_log_prob=False, + cache_dist=False, + return_log_prob=True, ) # init the lazy layers with torch.no_grad(), set_exploration_mode("random"): - td = proof_environment.rollout(max_steps=1000) + td = proof_environment.rollout(max_steps=100) + td["action"] = td["next", "action"] print(td) actor(td) @@ -266,8 +255,9 @@ def make_decision_transformer_model(cfg): def make_loss(loss_cfg, actor_network): loss = OnlineDTLoss( actor_network, - gamma=loss_cfg.gamma, - loss_function=loss_cfg.loss_function, + loss_cfg.alpha_init, + loss_cfg.min_alpha, + loss_cfg.max_alpha, ) return loss diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index fb3fce9e0ef..35508d298ab 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1103,10 +1103,20 @@ def _apply_transform( self, reward: torch.Tensor, target_return: torch.Tensor ) -> torch.Tensor: if self.mode == "reduce": - target_return = target_return[:, -1] - reward + if reward.ndim == 1 and target_return.ndim == 2: + # if target is stacked + target_return = target_return[-1] - reward + else: + # reward.ndim == 2 and target_return.ndim == 2: + target_return = target_return - reward return target_return elif self.mode == "constant": - return target_return[:, -1] + if reward.ndim == 1 and target_return.ndim == 2: + # if target is stacked + target_return = target_return[-1] - reward + else: + # reward.ndim == target_return.ndim + target_return = target_return - reward else: raise ValueError("Unknown mode: {}".format(self.mode)) diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py index 39407348baa..593f6319789 100644 --- a/torchrl/modules/models/decision_transformer.py +++ b/torchrl/modules/models/decision_transformer.py @@ -44,29 +44,26 @@ def forward( observation: torch.Tensor, action: torch.Tensor, return_to_go: torch.Tensor, - timesteps: torch.Tensor, - mask_context: bool = True, ): 
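        # Expected shapes, given the CatFrames transforms applied upstream (assumption):
        #   observation:  [batch, context_len, state_dim]
        #   action:       [batch, context_len, action_dim]
        #   return_to_go: [batch, context_len, 1]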
batch_size, seq_length = observation.shape[0], observation.shape[1] - if mask_context: - ( - observation, - action, - return_to_go, - timesteps, - seq_length, - ) = self.mask_context(observation, action, return_to_go, timesteps) + # if mask_context: + # ( observation, + # action, + # return_to_go, + # # timesteps, + # seq_length, + # ) = self.mask_context(observation, action, return_to_go)# timesteps # embed each modality with a different head state_embeddings = self.embed_state(observation) action_embeddings = self.embed_action(action) returns_embeddings = self.embed_return(return_to_go) - if self.ordering: - order_embeddings = self.embed_ordering(timesteps) - else: - order_embeddings = 0.0 + # if self.ordering: + # order_embeddings = self.embed_ordering(timesteps) + # else: + order_embeddings = 0.0 state_embeddings = state_embeddings + order_embeddings action_embeddings = action_embeddings + order_embeddings @@ -100,7 +97,7 @@ def mask_context( observation: torch.Tensor, action: torch.Tensor, return_to_go: torch.Tensor, - timesteps: torch.Tensor, + # timesteps: torch.Tensor, ): """Mask the context of the input sequences.""" observation[:, : -self.inference_context, :] = 0 @@ -109,5 +106,5 @@ def mask_context( [action[:, 1:], torch.zeros(action.shape[0], 1, self.action_dim)], dim=-2 ) return_to_go[:, : -self.inference_context, :] = 0 - timesteps[:, : -self.inference_context] = 0 - return observation, action, return_to_go, timesteps, self.train_context + # timesteps[:, : -self.inference_context] = 0 + return observation, action, return_to_go, self.train_context # timesteps diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 7f6f3876d24..431de7a39b9 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -1167,6 +1167,7 @@ class DTActor(nn.Module): def __init__( self, + state_dim: int, action_dim: int, mlp_net_kwargs: Optional[dict] = None, device: Optional[DEVICE_TYPING] = None, @@ -1180,7 +1181,7 @@ def __init__( "bias_last_layer": True, } self.transformer = DecisionTransformer( - state_dim=3, + state_dim=state_dim, action_dim=action_dim, hidden_size=512, max_ep_len=1000, @@ -1199,12 +1200,13 @@ def forward( observation: torch.Tensor, action: torch.Tensor, return_to_go: torch.Tensor, - timesteps: torch.Tensor, - mask_context: bool = True, ) -> torch.Tensor: - hidden_state = self.transformer( - observation, action, return_to_go, timesteps, mask_context - ) + + if observation.ndim == 2: + observation = observation.unsqueeze(0).float() + action = action.unsqueeze(0) + return_to_go = return_to_go.unsqueeze(0) + hidden_state = self.transformer(observation, action, return_to_go) # timesteps out = self.mlp(hidden_state)[:, -1] mu, log_std = out.chunk(2, -1) log_std = torch.tanh(log_std) diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index 70d794e6495..16e3cb73b5c 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -6,6 +6,7 @@ from .a2c import A2CLoss from .common import LossModule from .ddpg import DDPGLoss +from .decision_transformer import OnlineDTLoss from .dqn import DistributionalDQNLoss, DQNLoss from .dreamer import DreamerActorLoss, DreamerModelLoss, DreamerValueLoss from .iql import IQLLoss diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 28aeb611e37..96b40d74c3a 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -9,7 +9,6 @@ import torch from 
tensordict.tensordict import TensorDict, TensorDictBase - from torchrl.modules import ProbabilisticActor from .common import LossModule @@ -102,33 +101,20 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: return ( loss, -log_likelihood, - entropy, + entrop y, ) dist.log_prob(x).sum(axis=2) """ - shape = None - if tensordict.ndimension() > 1: - shape = tensordict.shape - tensordict_reshape = tensordict.reshape(-1) - else: - tensordict_reshape = tensordict - - # device = self.device - # td_device = tensordict_reshape.to(device) - out_td = self.actor_network(tensordict) - target_actions = tensordict["action"] + target_actions = tensordict["target_actions"] - # log_prob = out_td["log_prob"] - action_dist = out_td["distribution"] - loss_log_likelihood = action_dist.log_prob(target_actions).sum(axis=2) - entropy = action_dist.entropy().mean() + loss_log_likelihood = out_td["action_log_prob"](target_actions).sum(axis=2) + entropy = 0 # action_dist.entropy().mean() loss = -(loss_log_likelihood + self.target_entropy.detach() * entropy) loss_alpha = self.log_alpha.exp() * (entropy - self.target_entropy).detach() - if shape: - tensordict.update(tensordict_reshape.view(shape)) + out = { "loss": loss.mean(), "loss_log_likelihood": loss_log_likelihood.mean(), From a5e5da7734b4b985f5fdd31f5dc513068b59532b Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 11 May 2023 16:39:04 +0200 Subject: [PATCH 012/104] add objective --- examples/decision_transformer/config.yaml | 56 +++-- examples/decision_transformer/dt_offline.py | 201 +++++++++++---- examples/decision_transformer/dt_online.py | 74 ------ examples/decision_transformer/utils.py | 234 ++++++++++++------ torchrl/envs/transforms/transforms.py | 5 +- .../modules/models/decision_transformer.py | 75 ++---- torchrl/modules/models/models.py | 80 +++--- torchrl/modules/tensordict_module/__init__.py | 1 + torchrl/modules/tensordict_module/actors.py | 102 +++++++- .../modules/tensordict_module/exploration.py | 1 - torchrl/objectives/decision_transformer.py | 89 ++----- 11 files changed, 521 insertions(+), 397 deletions(-) delete mode 100644 examples/decision_transformer/dt_online.py diff --git a/examples/decision_transformer/config.yaml b/examples/decision_transformer/config.yaml index 80d723354c6..3d0e885ad5f 100644 --- a/examples/decision_transformer/config.yaml +++ b/examples/decision_transformer/config.yaml @@ -5,30 +5,35 @@ env: env_library: gym record_video: 0 stacked_frames: 20 - n_samples_stats: 1000 + n_samples_stats: 2000 frame_skip: 1 from_pixels: False - num_envs: 1 - reward_scaling: + num_train_envs: 1 + num_eval_envs: 1 + reward_scaling: 0.001 noop: 1 seed: 0 + eval_target_return: 3600 + collect_target_return: 7200 + total_online_frames: 1000000 + # Collector collector: async_collection: 1 frames_per_batch: 1000 total_frames: 1000000 - multi_step: 0 - init_random_frames: 25000 + init_random_frames: 0 collector_devices: cpu # ,cpu,cpu,cpu] num_collectors: 1 max_frames_per_traj: 1000 # logger logger: - backend: tensorboard - exp_name: td3_cheetah_gym - log_interval: 10000 # record interval in frames + backend: wandb + exp_name: oDT-Hopper-medium-v2 + pretrain_log_interval: 500 # record interval in frames + fintune_log_interval: 1 eval_steps: 1000 # Buffer @@ -39,26 +44,31 @@ replay_buffer: stacked_frames: 20 buffer_prefetch: 64 capacity: 1_000_000 + buffer_scratch_dir: "/tmp/" + device: cpu + prefetch: 3 # Optimization optim: - device: cpu - lr: 3e-4 - weight_decay: 0.0 + device: cuda:0 + lr: 1.0e-4 + weight_decay: 5.0e-4 batch_size: 
256 lr_scheduler: "" - optim_steps_per_batch: 1000 - policy_update_delay: 2 - gradient_steps: 1000 - -# Policy and model -model: - ou_exploration: 0 - noisy: False - activation: relu + pretrain_gradient_steps: 3000 + updates_per_episode: 300 + warmup_steps: 10000 # loss loss: - alpha_init: 1.0 - min_alpha: 0.1 - max_alpha: 10.0 + alpha_init: 0.1 + +transformer: + n_embd: 512 + n_layer: 4 + n_head: 4 + n_inner: 2048 # 4*512 + activation: relu + n_positions: 1024 + resid_pdrop: 0.1 + attn_pdrop: 0.1 diff --git a/examples/decision_transformer/dt_offline.py b/examples/decision_transformer/dt_offline.py index c7d6d911515..aa973e60b96 100644 --- a/examples/decision_transformer/dt_offline.py +++ b/examples/decision_transformer/dt_offline.py @@ -8,17 +8,19 @@ """ import hydra +import torch import tqdm +from torchrl.envs.utils import ExplorationType, set_exploration_type from utils import ( - # get_stats, + make_collector, make_decision_transformer_model, make_dt_optimizer, - # make_logger, + make_env, + make_logger, make_loss, make_offline_replay_buffer, - # make_parallel_env, - # make_test_env, + make_online_replay_buffer, ) @@ -26,57 +28,162 @@ def main(cfg: "DictConfig"): # noqa: F821 model_device = cfg.optim.device - # state_dict = get_stats(cfg.env) - # evaluation_env = make_test_env(cfg.env) - # logger = make_logger(cfg.logger) - replay_buffer = make_offline_replay_buffer(cfg.replay_buffer) + test_env = make_env(cfg.env) + logger = make_logger(cfg.logger) + offline_buffer = make_offline_replay_buffer( + cfg.replay_buffer, cfg.env.reward_scaling + ) - actor = make_decision_transformer_model(cfg) + inference_actor, actor = make_decision_transformer_model(cfg) policy = actor.to(model_device) + inference_policy = inference_actor.to(model_device) - loss = make_loss(cfg.loss, policy) - optim = make_dt_optimizer(cfg.optim, policy) + loss_module = make_loss(cfg.loss, actor) + transformer_optim, temperature_optim, scheduler = make_dt_optimizer( + cfg.optim, policy, loss_module + ) - pbar = tqdm.tqdm(total=cfg.optim.gradient_steps) + pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps) - # r0 = None - # l0 = None - - for i in range(cfg.optim.gradient_steps): + r0 = None + l0 = None + print(" ***Pretraining*** ") + # Pretraining + for i in range(cfg.optim.pretrain_gradient_steps): pbar.update(i) - data = replay_buffer.sample() + data = offline_buffer.sample() # loss - loss_vals = loss(data) + loss_vals = loss_module(data) # backprop - actor_loss = loss_vals["loss_actor"] - q_loss = loss_vals["loss_qvalue"] - value_loss = loss_vals["loss_value"] - loss_val = actor_loss + q_loss + value_loss - - optim.zero_grad() - loss_val.backward() - optim.step() - - # # evaluation - # if i % cfg.env.evaluation_interval == 0: - # with set_exploration_mode("random"), torch.no_grad(): - # eval_td = evaluation_env.rollout( - # max_steps=1000, policy=policy, auto_cast_to_device=True - # ) - - # if r0 is None: - # r0 = eval_td["next", "reward"].sum(1).mean().item() - # if l0 is None: - # l0 = loss_val.item() - - # for key, value in loss_vals.items(): - # logger.log_scalar(key, value.item(), i) - # eval_reward = eval_td["next", "reward"].sum(1).mean().item() - # logger.log_scalar("evaluation reward", eval_reward, i) - - # pbar.set_description( - # f"loss: {loss_val.item(): 4.4f} (init: {l0: 4.4f}), evaluation reward: {eval_reward: 4.4f} (init={r0: 4.4f})" - # ) + transformer_loss = loss_vals["loss"] + temperature_loss = loss_vals["loss_alpha"] + + transformer_optim.zero_grad() + 
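+        # clip the gradients of the full policy (transformer backbone + actor
+        # head) before stepping; a small max-norm such as the 0.25 used here is
+        # common when training decision transformers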
torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.25) + transformer_loss.backward() + transformer_optim.step() + + temperature_optim.zero_grad() + temperature_loss.backward() + temperature_optim.step() + + scheduler.step() + + # evaluation + with set_exploration_type(ExplorationType.MEAN), torch.no_grad(): + if i % cfg.logger.pretrain_log_interval == 0: + eval_td = test_env.rollout( + max_steps=cfg.logger.eval_steps, + policy=inference_policy, + auto_cast_to_device=True, + ) + if r0 is None: + r0 = eval_td["next", "reward"].sum(1).mean().item() / cfg.env.reward_scaling + if l0 is None: + l0 = transformer_loss.item() + + for key, value in loss_vals.items(): + logger.log_scalar(key, value.item(), i) + eval_reward = ( + eval_td["next", "reward"].sum(1).mean().item() / cfg.env.reward_scaling + ) + logger.log_scalar("evaluation reward", eval_reward, i) + + pbar.set_description( + f"[Pre-Training] loss: {transformer_loss.item(): 4.4f} (init: {l0: 4.4f}), evaluation reward: {eval_reward: 4.4f} (init={r0: 4.4f})" + ) + print("\n ***Online Finetuning*** ") + collector = make_collector(cfg, inference_policy) + online_buffer = make_online_replay_buffer( + offline_buffer, cfg.replay_buffer, cfg.env.reward_scaling + ) + # online_buffer = offline_buffer + collected_frames = 0 + + pbar = tqdm.tqdm(total=cfg.env.total_online_frames) + r0 = None + + for j, tensordict in enumerate(collector): + # update weights of the inference policy + collector.update_policy_weights_() + + episode_reward = ( + tensordict["next", "episode_reward"][tensordict["next", "done"]] + .mean() + .item() + / cfg.env.reward_scaling + ) + if r0 is None: + r0 = episode_reward + + current_frames = tensordict.numel() + pbar.update(current_frames) + + tensordict = tensordict.reshape(-1) + tensordict.del_("episode_reward") + + online_buffer.extend(tensordict.cpu().clone().detach()) + collected_frames += current_frames + + # optimization steps + for _ in range(int(cfg.optim.updates_per_episode)): + sampled_tensordict = online_buffer.sample().clone() + + loss_vals = loss_module(sampled_tensordict) + + # backprop + transformer_loss = loss_vals["loss"] + temperature_loss = loss_vals["loss_alpha"] + + transformer_optim.zero_grad() + torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.25) + transformer_loss.backward() + transformer_optim.step() + + temperature_optim.zero_grad() + temperature_loss.backward() + temperature_optim.step() + + scheduler.step() + + train_target_return = ( + tensordict["return_to_go"][:, 0].mean() / cfg.env.reward_scaling + ) + train_log = { + "collect reward": episode_reward, + "collected_frames": collected_frames, + "collect target_return": train_target_return.item() + / cfg.env.reward_scaling, + } + + for key, value in train_log.items(): + logger.log_scalar(key, value, step=j) + + with set_exploration_type(ExplorationType.MEAN), torch.no_grad(): + if j % cfg.logger.fintune_log_interval == 0: + eval_td = test_env.rollout( + max_steps=cfg.logger.eval_steps * cfg.env.num_eval_envs, + policy=inference_policy, + auto_cast_to_device=True, + ) + eval_reward = ( + eval_td["next", "reward"].sum(1).mean().item() / cfg.env.reward_scaling + ) + eval_target_return = ( + eval_td["return_to_go"][:, 0].mean() / cfg.env.reward_scaling + ) + eval_log = { + "fine-tune evaluation reward": eval_reward, + "evaluation target_return": eval_target_return.item() + / cfg.env.reward_scaling, + } + for key, value in eval_log.items(): + logger.log_scalar(key, value, step=j) + pbar.set_description( + f"[Fine-Tuning] loss: 
{transformer_loss.item(): 4.4f} (init: {l0: 4.4f}), evaluation reward: {eval_reward: 4.4f} (init={r0: 4.4f})" + ) + + collector.shutdown() if __name__ == "__main__": diff --git a/examples/decision_transformer/dt_online.py b/examples/decision_transformer/dt_online.py deleted file mode 100644 index 0373f4683d3..00000000000 --- a/examples/decision_transformer/dt_online.py +++ /dev/null @@ -1,74 +0,0 @@ -import hydra -import torch - -from tensordict.nn import TensorDictModule - -from torchrl.modules import ProbabilisticActor -from torchrl.modules.distributions import TanhNormal -from torchrl.modules.models import DTActor -from utils import make_collector, make_replay_buffer, make_test_env - - -@hydra.main(config_path=".", config_name="config") -def main(cfg: "DictConfig"): # noqa: F821 - # Sanity check - test_env = make_test_env(cfg.env) - # test_env = make_transformed_env(test_env) - action_spec = test_env.action_spec - - in_keys = [ - "observation", - "action", - "return_to_go", - "timesteps", - ] # return_to_go, timesteps - - actor_net = DTActor(action_dim=1) - - dist_class = TanhNormal - dist_kwargs = { - "min": -1.0, - "max": 1.0, - "tanh_loc": False, - } - - actor_module = TensorDictModule( - actor_net, in_keys=in_keys, out_keys=["loc", "scale"] # , "hidden_state"], - ) - actor = ProbabilisticActor( - spec=action_spec, - in_keys=["loc", "scale"], # , "hidden_state"], - out_keys=["action", "log_prob"], # , "hidden_state"], - module=actor_module, - distribution_class=dist_class, - distribution_kwargs=dist_kwargs, - default_interaction_mode="random", - cache_dist=True, - return_log_prob=False, - ) - - print(actor) - - with torch.no_grad(): - test_env.eval() - actor.eval() - # Generate a complete episode - td_test = test_env.rollout( - policy=actor, - max_steps=30, - auto_reset=True, - auto_cast_to_device=True, - break_when_any_done=True, - ).clone() - print(td_test) - - collector = make_collector(cfg, policy=actor) - - replay_buffer = make_replay_buffer(cfg.replay_buffer) - for data in collector: - data_view = data.reshape(-1) - replay_buffer.extend(data_view) - - -if __name__ == "__main__": - main() diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index a4f2f67d209..2cae0b15b7b 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -6,16 +6,18 @@ from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.datasets.d4rl import D4RLExperienceReplay from torchrl.data.replay_buffers import SamplerWithoutReplacement -from torchrl.data.replay_buffers.samplers import RandomSampler from torchrl.envs import ( CatFrames, Compose, + DoubleToFloat, EnvCreator, ExcludeTransform, NoopResetEnv, ObservationNorm, ParallelEnv, Reward2GoTransform, + RewardScaling, + RewardSum, TargetReturn, TensorDictPrimer, TransformedEnv, @@ -24,6 +26,7 @@ from torchrl.envs.libs.dm_control import DMControlEnv from torchrl.envs.utils import set_exploration_mode from torchrl.modules import DTActor, ProbabilisticActor, TanhNormal +from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper from torchrl.objectives import OnlineDTLoss from torchrl.record.loggers import generate_exp_name, get_logger from torchrl.trainers.helpers.envs import LIBS @@ -34,18 +37,14 @@ # ----------------- -def make_base_env(env_cfg, from_pixels=None): +def make_base_env(env_cfg): env_library = LIBS[env_cfg.env_library] env_name = env_cfg.env_name frame_skip = env_cfg.frame_skip - if from_pixels is None: - 
from_pixels = env_cfg.from_pixels env_kwargs = { "env_name": env_name, "frame_skip": frame_skip, - "from_pixels": from_pixels, # for rendering - "pixels_only": False, } if env_library is DMControlEnv: env_task = env_cfg.env_task @@ -56,22 +55,37 @@ def make_base_env(env_cfg, from_pixels=None): return env -def make_transformed_env(base_env, env_cfg): - return make_transformed_env_states(base_env, env_cfg) - - -def make_transformed_env_states(base_env, env_cfg): +def make_transformed_env(base_env, env_cfg, train=False): transformed_env = TransformedEnv(base_env) - + if train: + transformed_env.append_transform( + TargetReturn(env_cfg.collect_target_return, out_keys=["return_to_go"]) + ) + else: + transformed_env.append_transform( + TargetReturn(env_cfg.eval_target_return, out_keys=["return_to_go"]) + ) transformed_env.append_transform( - TargetReturn( - 200 * 0.01, out_keys=["return_to_go"] - ) # WATCH OUT FOR THE SCALING! + RewardScaling( + loc=0, + scale=env_cfg.reward_scaling, + in_keys="return_to_go", + standard_normal=False, + ) + ) + transformed_env.append_transform( + RewardScaling( + loc=0, scale=env_cfg.reward_scaling, in_keys="reward", standard_normal=False + ) ) - # transformed_env.append_transform(SCALE) transformed_env.append_transform(TensorDictPrimer(action=base_env.action_spec)) - # transformed_env.append_transform(TensorDictPrimer(padding_mask=env.action_spec)) + transformed_env.append_transform( + DoubleToFloat( + in_keys=["observation"], + in_keys_inv=[], + ) + ) transformed_env.append_transform(UnsqueezeTransform(-2, in_keys=["observation"])) transformed_env.append_transform( CatFrames(in_keys=["observation"], N=env_cfg.stacked_frames, dim=-2) @@ -86,18 +100,21 @@ def make_transformed_env_states(base_env, env_cfg): transformed_env.append_transform( CatFrames(in_keys=["return_to_go"], N=env_cfg.stacked_frames, dim=-2) ) - - # transformed_env.append_transform(UnsqueezeTransform(0, in_keys=["return_to_go"], allow_positive_dim=True)) - # transformed_env.append_transform(UnsqueezeTransform(0, in_keys=["observation"], allow_positive_dim=True)) - # transformed_env.append_transform(UnsqueezeTransform(0, in_keys=["action"], allow_positive_dim=True)) + if train: + transformed_env.append_transform(RewardSum()) return transformed_env -def make_parallel_env(env_cfg, state_dict): - num_envs = env_cfg.num_envs +def make_parallel_env(env_cfg, state_dict, train=False): + if train: + num_envs = env_cfg.num_train_envs + else: + num_envs = env_cfg.num_eval_envs env = make_transformed_env( - ParallelEnv(num_envs, EnvCreator(lambda: make_base_env(env_cfg))), env_cfg + ParallelEnv(num_envs, EnvCreator(lambda: make_base_env(env_cfg))), + env_cfg, + train, ) for t in env.transform: if isinstance(t, ObservationNorm): @@ -106,10 +123,9 @@ def make_parallel_env(env_cfg, state_dict): return env -def make_test_env(env_cfg): - env_cfg.num_envs = 1 +def make_env(env_cfg, train=False): state_dict = get_stats(env_cfg) - env = make_parallel_env(env_cfg, state_dict=state_dict) + env = make_parallel_env(env_cfg, state_dict=state_dict, train=train) return env @@ -131,42 +147,67 @@ def init_stats(env, n_samples_stats): def make_collector(cfg, policy): - env_cfg = cfg.env + exclude_target_return = ExcludeTransform( + "return_to_go", + ("next", "return_to_go"), + ("next", "action"), + ("next", "observation"), + "scale", + "loc", + ) + cat = CatFrames(in_keys=["action"], N=20, dim=-2, padding="zeros") + transforms = Compose( + exclude_target_return, + cat, + ) collector_cfg = cfg.collector collector_class = 
SyncDataCollector - state_dict = get_stats(env_cfg) - # to exclude inference target returns - exclude = ExcludeTransform("return_to_go") # next return to go collector = collector_class( - make_parallel_env(env_cfg, state_dict=state_dict), + make_env(cfg.env, train=True), policy, frames_per_batch=collector_cfg.frames_per_batch, total_frames=collector_cfg.total_frames, device=collector_cfg.collector_devices, max_frames_per_traj=collector_cfg.max_frames_per_traj, - postproc=exclude, + postproc=transforms, ) return collector -def make_replay_buffer(rb_cfg): +def make_offline_replay_buffer(rb_cfg, reward_scaling): r2g = Reward2GoTransform(gamma=1.0, out_keys=["return_to_go"]) - transforms = [r2g] - sampler = RandomSampler() - return TensorDictReplayBuffer( - storage=LazyMemmapStorage(rb_cfg.capacity), - sampler=sampler, - transform=Compose(*transforms), + reward_scale = RewardScaling( + loc=0, scale=reward_scaling, in_keys="return_to_go", standard_normal=False + ) + catframes = CatFrames( + in_keys=["action", "observation", "return_to_go"], + N=rb_cfg.stacked_frames, + dim=-2, + padding="zeros", ) + d2f = DoubleToFloat( + in_keys=["observation", ("next", "observation")], + in_keys_inv=[], + ) + exclude = ExcludeTransform( + "next_observations", + "timeout", + "terminal", + "info", + ("next", "timeout"), + ("next", "terminal"), + ("next", "observation"), + ("next", "info"), + ) -def make_offline_replay_buffer(rb_cfg): - r2g = Reward2GoTransform(gamma=1.0, out_keys=["return_to_go"]) - cat_r2g = CatFrames(in_keys=["return_to_go"], N=rb_cfg.stacked_frames, dim=-2) - cat_obs = CatFrames(in_keys=["observation"], N=rb_cfg.stacked_frames, dim=-2) - cat_actions = CatFrames(in_keys=["action"], N=rb_cfg.stacked_frames, dim=-2) - exclude_next_obs = ExcludeTransform("next_observations") - transforms = Compose(r2g, cat_r2g, cat_obs, cat_actions, exclude_next_obs) + transforms = Compose( + d2f, + r2g, + reward_scale, + catframes, + exclude, + ) data = D4RLExperienceReplay( rb_cfg.dataset, split_trajs=False, @@ -174,32 +215,52 @@ def make_offline_replay_buffer(rb_cfg): sampler=SamplerWithoutReplacement(drop_last=False), transform=transforms, ) - # data.append_transform( - # Reward2GoTransform(gamma=1.0, out_keys=["return_to_go"]) - # ) - # data.append_transform( + # TODO: add obsnorm here - # ) - # data.append_transform( + return data - # ) - return data +def make_online_replay_buffer(offline_buffer, rb_cfg, reward_scaling=0.001): + offline_data = offline_buffer.sample(100000) + offline_data.del_("return_to_go") + offline_data.del_("index") # delete + + r2g = Reward2GoTransform(gamma=1.0, out_keys=["return_to_go"]) + reward_scale = RewardScaling( + loc=0, scale=reward_scaling, in_keys="return_to_go", standard_normal=False + ) + catframes = CatFrames( + in_keys=["return_to_go"], N=rb_cfg.stacked_frames, dim=-2, padding="zeros" + ) + transforms = Compose( + r2g, + reward_scale, + catframes, + ) + storage = LazyMemmapStorage( + rb_cfg.capacity, rb_cfg.buffer_scratch_dir, device=rb_cfg.device + ) + + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=rb_cfg.prefetch, + transform=transforms, + storage=storage, + batch_size=rb_cfg.batch_size, + ) + # init buffer with offline data + # replay_buffer.extend(offline_data.clone().detach().to_tensordict()) + + return replay_buffer # ==================================================================== # Model # ----- -# -# We give one version of the model for learning from pixels, and one for state. 
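For reference, with gamma=1.0 the Reward2GoTransform used by the buffers above reduces to a reversed cumulative sum over each trajectory. A minimal standalone sketch (not part of the patch):

import torch

rewards = torch.tensor([[1.0], [0.0], [2.0], [1.0]])  # (T, 1) rewards of one trajectory
return_to_go = rewards.flip(0).cumsum(0).flip(0)      # suffix sums, i.e. gamma == 1.0
# -> tensor([[4.], [3.], [3.], [1.]])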
-# TorchRL comes in handy at this point, as the high-level interactions with -# these models is unchanged, regardless of the modality. -# def make_decision_transformer_model(cfg): env_cfg = cfg.env - # model_cfg = cfg.model proof_environment = make_transformed_env(make_base_env(env_cfg), env_cfg) action_spec = proof_environment.action_spec @@ -210,11 +271,22 @@ def make_decision_transformer_model(cfg): "observation", "action", "return_to_go", - # "timesteps", - ] # return_to_go, timesteps + ] - actor_net = DTActor(state_dim=state_dim, action_dim=action_spec.shape[-1]) + actor_net = DTActor( + state_dim=state_dim, + action_dim=action_spec.shape[-1], + transformer_config=cfg.transformer, + ) + actor_module = TensorDictModule( + actor_net, + in_keys=in_keys, + out_keys=[ + "loc", + "scale", + ], + ) dist_class = TanhNormal dist_kwargs = { "min": -1.0, @@ -222,29 +294,28 @@ def make_decision_transformer_model(cfg): "tanh_loc": False, } - actor_module = TensorDictModule( - actor_net, in_keys=in_keys, out_keys=["loc", "scale"] # , "hidden_state"], - ) actor = ProbabilisticActor( spec=action_spec, - in_keys=["loc", "scale"], # , "hidden_state"], - out_keys=["action", "log_prob"], # , "hidden_state"], + in_keys=["loc", "scale"], + out_keys=["action", "log_prob"], module=actor_module, distribution_class=dist_class, distribution_kwargs=dist_kwargs, default_interaction_mode="random", cache_dist=False, - return_log_prob=True, + return_log_prob=False, ) # init the lazy layers with torch.no_grad(), set_exploration_mode("random"): td = proof_environment.rollout(max_steps=100) td["action"] = td["next", "action"] - print(td) actor(td) - return actor + inference_actor = DecisionTransformerInferenceWrapper( + actor, + ) + return inference_actor, actor # ==================================================================== @@ -256,20 +327,29 @@ def make_loss(loss_cfg, actor_network): loss = OnlineDTLoss( actor_network, loss_cfg.alpha_init, - loss_cfg.min_alpha, - loss_cfg.max_alpha, ) return loss -def make_dt_optimizer(optim_cfg, actor_network): +def make_dt_optimizer(optim_cfg, actor_network, loss): # Should be Lambda Optimizer - optimizer = torch.optim.Adam( + dt_optimizer = torch.optim.Adam( actor_network.parameters(), lr=optim_cfg.lr, weight_decay=optim_cfg.weight_decay, + eps=1.0e-8, + ) + scheduler = torch.optim.lr_scheduler.LambdaLR( + dt_optimizer, lambda steps: min((steps + 1) / optim_cfg.warmup_steps, 1) ) - return optimizer + + log_temp_optimizer = torch.optim.Adam( + [loss.log_alpha], + lr=1e-4, + betas=[0.9, 0.999], + ) + + return dt_optimizer, log_temp_optimizer, scheduler # ==================================================================== diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 35508d298ab..4b086a73c70 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1107,7 +1107,6 @@ def _apply_transform( # if target is stacked target_return = target_return[-1] - reward else: - # reward.ndim == 2 and target_return.ndim == 2: target_return = target_return - reward return target_return elif self.mode == "constant": @@ -1115,7 +1114,6 @@ def _apply_transform( # if target is stacked target_return = target_return[-1] - reward else: - # reward.ndim == target_return.ndim target_return = target_return - reward else: raise ValueError("Unknown mode: {}".format(self.mode)) @@ -2109,7 +2107,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: # If so, we must add an offset data = 
tensordict.get(in_key) if isinstance(in_key, tuple) and in_key[0] == "next": - # let's get the out_key we have already processed prev_out_key = dict(zip(self.in_keys, self.out_keys))[in_key[1]] prev_val = tensordict.get(prev_out_key) @@ -4166,7 +4163,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: item = self._inv_apply_transform( tensordict.get(in_key), done_or_truncated ) - tensordict.set(out_key, item, inplace=True) + tensordict.set(out_key, item) if not found: raise KeyError(f"Could not find any of the input keys {self.in_keys}.") return tensordict diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py index 593f6319789..0528d01d993 100644 --- a/torchrl/modules/models/decision_transformer.py +++ b/torchrl/modules/models/decision_transformer.py @@ -5,39 +5,38 @@ class DecisionTransformer(nn.Module): - """Decion Transformer as described in https://arxiv.org/abs/2202.05607 .""" + """online Decion Transformer as described in https://arxiv.org/abs/2202.05607 .""" def __init__( - self, state_dim, action_dim, hidden_size=512, max_ep_len=1000, ordering=False + self, + state_dim, + action_dim, + config, ): super(DecisionTransformer, self).__init__() - assert hidden_size == 512, "Only hidden_size=512 is supported" + gpt_config = transformers.GPT2Config( - n_embd=512, - n_layer=4, - n_head=4, - n_inner=4 * 512, - activation_function="relu", - n_positions=1024, - resid_pdrop=0.1, - attn_pdrop=0.1, + n_embd=config.n_embd, + n_layer=config.n_layer, + n_head=config.n_head, + n_inner=config.n_inner, + activation_function=config.activation, + n_positions=config.n_positions, + resid_pdrop=config.resid_pdrop, + attn_pdrop=config.attn_pdrop, vocab_size=1, ) self.state_dim = state_dim self.action_dim = action_dim - self.hidden_size = hidden_size - self.ordering = ordering - self.train_context = 20 - self.inference_context = 5 + self.hidden_size = config.n_embd self.transformer = GPT2Model(config=gpt_config) - if ordering: - self.embed_ordering = nn.Embedding(max_ep_len, hidden_size) - self.embed_return = torch.nn.Linear(1, hidden_size) - self.embed_state = torch.nn.Linear(self.state_dim, hidden_size) - self.embed_action = torch.nn.Linear(self.action_dim, hidden_size) - self.embed_ln = nn.LayerNorm(hidden_size) + self.embed_return = torch.nn.Linear(1, config.n_embd) + self.embed_state = torch.nn.Linear(self.state_dim, config.n_embd) + self.embed_action = torch.nn.Linear(self.action_dim, config.n_embd) + + self.embed_ln = nn.LayerNorm(config.n_embd) def forward( self, @@ -47,28 +46,11 @@ def forward( ): batch_size, seq_length = observation.shape[0], observation.shape[1] - # if mask_context: - # ( observation, - # action, - # return_to_go, - # # timesteps, - # seq_length, - # ) = self.mask_context(observation, action, return_to_go)# timesteps - # embed each modality with a different head state_embeddings = self.embed_state(observation) action_embeddings = self.embed_action(action) returns_embeddings = self.embed_return(return_to_go) - # if self.ordering: - # order_embeddings = self.embed_ordering(timesteps) - # else: - order_embeddings = 0.0 - - state_embeddings = state_embeddings + order_embeddings - action_embeddings = action_embeddings + order_embeddings - returns_embeddings = returns_embeddings + order_embeddings - # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...) 
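        # (illustration only, not part of the patch: for r, s, a embeddings of shape
        #  (B, T, H), the interleaving below is equivalent to
        #      torch.stack((r, s, a), dim=1).permute(0, 2, 1, 3).reshape(B, 3 * T, H)
        #  so tokens[:, 0::3] are return tokens, [:, 1::3] state tokens and
        #  [:, 2::3] action tokens)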
# which works nice in an autoregressive sense since states predict actions stacked_inputs = ( @@ -91,20 +73,3 @@ def forward( x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3) return x[:, 1] # only state tokens - - def mask_context( - self, - observation: torch.Tensor, - action: torch.Tensor, - return_to_go: torch.Tensor, - # timesteps: torch.Tensor, - ): - """Mask the context of the input sequences.""" - observation[:, : -self.inference_context, :] = 0 - action[:, : -self.inference_context, :] = 0 - action = torch.cat( - [action[:, 1:], torch.zeros(action.shape[0], 1, self.action_dim)], dim=-2 - ) - return_to_go[:, : -self.inference_context, :] = 0 - # timesteps[:, : -self.inference_context] = 0 - return observation, action, return_to_go, self.train_context # timesteps diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 431de7a39b9..502efdc4872 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -13,6 +13,7 @@ from torchrl._utils import prod from torchrl.data.utils import DEVICE_TYPING +from torchrl.modules.models.decision_transformer import DecisionTransformer from torchrl.modules.models.utils import ( _find_depth, create_on_device, @@ -1139,61 +1140,42 @@ def forward( return self._lstm(input, hidden0_in, hidden1_in) -from torchrl.modules.models.decision_transformer import DecisionTransformer - - class DTActor(nn.Module): """Decision Transformer Actor class. Presented in "Online Decision Transformer", https://arxiv.org/abs/2202.05607.pdf - The DDPG Actor takes as input an observation vector and returns an action from it. - It is trained to maximise the value returned by the DDPG Q Value network. - Args: - action_dim (int): length of the action vector - mlp_net_kwargs (dict, optional): kwargs for MLP. - Default: { - 'in_features': None, - 'out_features': action_dim, - 'depth': 2, - 'num_cells': [400, 300], - 'activation_class': nn.ELU, - 'bias_last_layer': True, - } - device (Optional[DEVICE_TYPING]): device to create the module on. 
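For reference, the [-5.0, 2.0] log-std bounding kept in the reworked DTActor.forward maps a tanh-squashed value back onto that interval. A standalone sketch (not part of the patch):

import torch

log_std_min, log_std_max = -5.0, 2.0
raw = torch.randn(4)
log_std = torch.tanh(raw)  # in [-1, 1]
log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1.0)  # in [-5, 2]
std = log_std.exp()
assert ((log_std >= log_std_min) & (log_std <= log_std_max)).all()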
""" def __init__( self, state_dim: int, action_dim: int, - mlp_net_kwargs: Optional[dict] = None, + transformer_config: Dict, device: Optional[DEVICE_TYPING] = None, ): super().__init__() - mlp_net_default_kwargs = { - "out_features": action_dim * 2, - "depth": 1, - "num_cells": [512], - "activation_class": nn.ReLU, - "bias_last_layer": True, - } self.transformer = DecisionTransformer( state_dim=state_dim, action_dim=action_dim, - hidden_size=512, - max_ep_len=1000, - ordering=False, + config=transformer_config, + ) + self.action_layer = nn.Linear( + transformer_config.n_embd, action_dim * 2, device=device ) - # log_std_bounds: Tuple[float, float] = [-5.0, 2.0], - log_std_bounds = [-5.0, 2.0] - self.log_std_bounds = log_std_bounds - mlp_net_kwargs = mlp_net_kwargs if mlp_net_kwargs is not None else {} - mlp_net_default_kwargs.update(mlp_net_kwargs) - self.mlp = MLP(device=device, **mlp_net_default_kwargs) - # self.apply(dt_actor_weight_init) + + self.log_std_min, self.log_std_max = -5.0, 2.0 + + def weight_init(m): + """Custom weight init for Conv2D and Linear layers.""" + if isinstance(m, torch.nn.Linear): + nn.init.orthogonal_(m.weight.data) + if hasattr(m.bias, "data"): + m.bias.data.fill_(0.0) + + self.apply(weight_init) def forward( self, @@ -1201,25 +1183,19 @@ def forward( action: torch.Tensor, return_to_go: torch.Tensor, ) -> torch.Tensor: - if observation.ndim == 2: - observation = observation.unsqueeze(0).float() + observation = observation.unsqueeze(0) action = action.unsqueeze(0) return_to_go = return_to_go.unsqueeze(0) - hidden_state = self.transformer(observation, action, return_to_go) # timesteps - out = self.mlp(hidden_state)[:, -1] - mu, log_std = out.chunk(2, -1) + hidden_state = self.transformer(observation, action, return_to_go) + out = self.action_layer(hidden_state) + mu, log_std = torch.chunk(out, 2, -1) log_std = torch.tanh(log_std) - log_std = min(self.log_std_bounds) + 0.5 * ( - max(self.log_std_bounds) - min(self.log_std_bounds) - ) * (log_std + 1.0) - std = torch.exp(log_std) - return (mu, std) - + # log_std is the output of tanh so it will be between [-1, 1] + # map it to be between [log_std_min, log_std_max] + log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * ( + log_std + 1.0 + ) + std = log_std.exp() -def dt_actor_weight_init(m): - """Weight init used in the Decision Transformer for the actor layers.""" - if isinstance(m, torch.nn.Linear): - nn.init.orthogonal_(m.weight.data) - if hasattr(m.bias, "data"): - m.bias.data.fill_(0.0) + return (mu, std) diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index 36ee2045950..484abbc71c5 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -8,6 +8,7 @@ ActorCriticOperator, ActorCriticWrapper, ActorValueOperator, + DecisionTransformerInferenceWrapper, DistributionalQValueActor, DistributionalQValueHook, DistributionalQValueModule, diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index fdbd47f3f94..5adda6470b1 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -295,7 +295,6 @@ def __init__( in_keys: Optional[Sequence[str]] = None, out_keys: Optional[Sequence[str]] = None, ) -> None: - if in_keys is None: in_keys = ["observation"] if out_keys is None: @@ -1103,7 +1102,6 @@ def __init__( action_value_key: str = "action_value", make_log_softmax: bool = True, ): - 
action_space, spec = _process_action_space_spec(action_space, spec) self.action_space = action_space self.action_value_key = action_value_key @@ -1579,3 +1577,103 @@ def get_value_operator(self) -> SafeSequential: get_policy_head = get_policy_operator get_value_head = get_value_operator + + +class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper): + """Inference Action Wrapper for the Decision Transformer. + + A wrapper specifically designed for the Decision Transformer, which will mask the context of the + input tensordict to the inferece context. The output will be a TensorDict with be the input TensorDict + and the last predicted action of the predicted output sequence. + + Args: + policy (TensorDictModule): The policy module that takes in + observations and produces an action value + inference_context (int): The number of previous actions that will not be masked in the context. + For example for an observation input of shape [batch_size, context, obs_dim] with context=20 and inference_context=5, the first 15 entries + of the context will be masked. + observation_key (str): The key of the observation in the input TensorDict + action_key (str): The key of the action in the input TensorDict + return_to_go_key (str): The key of the return to go in the input TensorDict + spec (Optional[TensorSpec]): The spec of the input TensorDict. If None, it will be inferred from the policy module. + """ + + def __init__( + self, + policy: TensorDictModule, + *, + inference_context: int = 5, + observation_key: str = "observation", + action_key: str = "action", + return_to_go_key: str = "return_to_go", + spec: Optional[TensorSpec] = None, + ): + super().__init__(policy) + self.observation_key = observation_key + self.action_key = action_key + self.return_to_go_key = return_to_go_key + self.inference_context = inference_context + if spec is not None: + if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: + spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) + self._spec = spec + elif hasattr(self.td_module, "_spec"): + self._spec = self.td_module._spec.clone() + if action_key not in self._spec.keys(): + self._spec[action_key] = None + elif hasattr(self.td_module, "spec"): + self._spec = self.td_module.spec.clone() + if action_key not in self._spec.keys(): + self._spec[action_key] = None + else: + self._spec = CompositeSpec({key: None for key in policy.out_keys}) + + def step(self, frames: int = 1) -> None: + pass + + @staticmethod + def _check_tensor_dims(reward, obs, action): + if not (reward.shape[:-1] == obs.shape[:-1] == action.shape[:-1]): + raise ValueError( + "Mismatched tensor dimensions. 
This is not supported yet, file an issue on torchrl" + ) + + def mask_context(self, tensordict: TensorDictBase) -> TensorDictBase: + """Mask the context of the input sequences.""" + observation = tensordict.get(self.observation_key) + action = tensordict.get(self.action_key) + return_to_go = tensordict.get(self.return_to_go_key) + self._check_tensor_dims(return_to_go, observation, action) + + observation[..., : -self.inference_context, :] = 0 + action[..., : -self.inference_context, :] = 0 + action = torch.cat( + [ + action[:, 1:], + torch.zeros(action.shape[0], 1, action.shape[-1], device=action.device), + ], + dim=-2, + ) + return_to_go[..., : -self.inference_context, :] = 0 + + tensordict.set(self.observation_key, observation) + tensordict.set(self.action_key, action) + tensordict.set(self.return_to_go_key, return_to_go) + return tensordict + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + """Forward pass of the inference wrapper.""" + unmasked_tensordict = tensordict.clone() + # Mask the context of the input sequences + tensordict = self.mask_context(tensordict) + # forward pass + tensordict = self.td_module.forward(tensordict) + # get last action prediciton + out_action = tensordict.get(self.action_key)[:, -1] + tensordict.set(self.action_key, out_action) + out_rtg = tensordict.get(self.return_to_go_key)[:, -1] + tensordict.set(self.return_to_go_key, out_rtg) + tensordict.set( + self.observation_key, unmasked_tensordict.get(self.observation_key) + ) + return tensordict diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 86cf8b8bbc6..3ac9ef8ba01 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -561,7 +561,6 @@ def _make_noise_pair(self, tensordict: TensorDictBase, is_init=None) -> None: def add_sample( self, tensordict: TensorDictBase, eps: float = 1.0 ) -> TensorDictBase: - if self.noise_key not in tensordict.keys(): self._make_noise_pair(tensordict) is_init = tensordict.get("is_init", None) diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 96b40d74c3a..d50ede5d4b8 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -9,6 +9,7 @@ import torch from tensordict.tensordict import TensorDict, TensorDictBase +from torch import distributions as d from torchrl.modules import ProbabilisticActor from .common import LossModule @@ -20,26 +21,8 @@ class OnlineDTLoss(LossModule): Presented in "Online Decision Transformer" https://arxiv.org/abs/2202.05607 Args: actor_network (ProbabilisticActor): stochastic actor - qvalue_network (SafeModule): Q(s, a) parametric model - value_network (SafeModule, optional): V(s) parametric model. If not - provided, the second version of SAC is assumed. - qvalue_network_bis (ProbabilisticTDModule, optional): if required, the - Q-value can be computed twice independently using two separate - networks. The minimum predicted value will then be used for - inference. - gamma (number, optional): discount for return computation - Default is 0.99 - priority_key (str, optional): tensordict key where to write the - priority (for prioritized replay buffer usage). Default is - `"td_error"`. - loss_function (str, optional): loss function to be used with - the value function loss. Default is `"smooth_l1"`. - temperature (float, optional): Inverse temperature (beta). 
- For smaller hyperparameter values, the objective behaves similarly to - behavioral cloning, while for larger values, it attempts to recover the - maximum of the Q-function. - expectile (float, optional): expectile :math:`\tau`. A larger value of :math:`\tau` is crucial - for antmaze tasks that require dynamical programming ("stichting"). + alpha_init: + samples_mc_entropy: """ @@ -47,8 +30,7 @@ def __init__( self, actor_network: ProbabilisticActor, alpha_init: float = 1.0, - min_alpha: float = 0.1, - max_alpha: float = 10.0, + samples_mc_entropy: int = 1, ) -> None: super().__init__() @@ -64,13 +46,6 @@ def __init__( except AttributeError: device = torch.device("cpu") self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) - self.register_buffer( - "min_log_alpha", torch.tensor(min_alpha, device=device).log() - ) - self.register_buffer( - "max_log_alpha", torch.tensor(max_alpha, device=device).log() - ) - self.register_parameter( "log_alpha", torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), @@ -80,53 +55,43 @@ def __init__( self.register_buffer( "target_entropy", torch.tensor(target_entropy, device=device) ) + self.samples_mc_entropy = samples_mc_entropy @property def device(self) -> torch.device: for p in self.parameters(): return p.device raise RuntimeError( - "At least one of the networks of SACLoss must have trainable " "parameters." + "At least one of the networks of OnlineDTLoss must have trainable " + "parameters." ) - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - """Compute the loss for the Online Decision Transformer. - - # a_hat is a SquashedNormal Distribution - log_likelihood = a_hat_dist.log_likelihood(a)[attention_mask > 0].mean() - - entropy = a_hat_dist.entropy().mean() - loss = -(log_likelihood + entropy_reg * entropy) + def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: + x = dist.rsample((self.samples_mc_entropy,)) + log_p = dist.log_prob(x) + # log_p: (batch_size, context_len, + return -log_p.mean(axis=0) - return ( - loss, - -log_likelihood, - entrop y, - ) - dist.log_prob(x).sum(axis=2) - """ - out_td = self.actor_network(tensordict) + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + """Compute the loss for the Online Decision Transformer.""" + # extract action targets + target_actions = torch.clone(tensordict["action"].detach()).to(self.device) - target_actions = tensordict["target_actions"] + action_dist = self.actor_network.get_dist( + tensordict.to(self.device), params=self.actor_network_params + ) - loss_log_likelihood = out_td["action_log_prob"](target_actions).sum(axis=2) - entropy = 0 # action_dist.entropy().mean() - loss = -(loss_log_likelihood + self.target_entropy.detach() * entropy) + loss_log_likelihood = action_dist.log_prob(target_actions).mean() + entropy = self.get_entropy_bonus(action_dist).mean() + loss = -(loss_log_likelihood + self.log_alpha.exp().detach() * entropy) loss_alpha = self.log_alpha.exp() * (entropy - self.target_entropy).detach() out = { - "loss": loss.mean(), - "loss_log_likelihood": loss_log_likelihood.mean(), - "entropy": entropy.mean(), - "loss_alpha": loss_alpha.mean(), - "alpha": self._alpha, + "loss": loss, + "loss_log_likelihood": -loss_log_likelihood, + "entropy": entropy, + "loss_alpha": loss_alpha, + "alpha": self.log_alpha.exp(), } return TensorDict(out, []) - - @property - def _alpha(self): - self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) - with torch.no_grad(): - alpha = self.log_alpha.exp() - return 
alpha From 34fc6e884eb952ea46051fc948f560bf904bf9e6 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 11 May 2023 16:39:55 +0200 Subject: [PATCH 013/104] fix --- examples/decision_transformer/dt_offline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/decision_transformer/dt_offline.py b/examples/decision_transformer/dt_offline.py index aa973e60b96..3ac22f06c84 100644 --- a/examples/decision_transformer/dt_offline.py +++ b/examples/decision_transformer/dt_offline.py @@ -97,7 +97,6 @@ def main(cfg: "DictConfig"): # noqa: F821 online_buffer = make_online_replay_buffer( offline_buffer, cfg.replay_buffer, cfg.env.reward_scaling ) - # online_buffer = offline_buffer collected_frames = 0 pbar = tqdm.tqdm(total=cfg.env.total_online_frames) From 0200e29403705e4681d6b6dd465ee630c7a94cc6 Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 12 May 2023 09:58:24 +0200 Subject: [PATCH 014/104] small fixes --- examples/decision_transformer/dt_offline.py | 1 + examples/decision_transformer/utils.py | 15 +++++++-------- torchrl/modules/tensordict_module/actors.py | 1 + 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/decision_transformer/dt_offline.py b/examples/decision_transformer/dt_offline.py index 3ac22f06c84..cb7529ea71a 100644 --- a/examples/decision_transformer/dt_offline.py +++ b/examples/decision_transformer/dt_offline.py @@ -119,6 +119,7 @@ def main(cfg: "DictConfig"): # noqa: F821 pbar.update(current_frames) tensordict = tensordict.reshape(-1) + # only used for logging tensordict.del_("episode_reward") online_buffer.extend(tensordict.cpu().clone().detach()) diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index 2cae0b15b7b..c1946395379 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -202,10 +202,10 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling): ) transforms = Compose( - d2f, r2g, reward_scale, catframes, + d2f, exclude, ) data = D4RLExperienceReplay( @@ -221,10 +221,6 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling): def make_online_replay_buffer(offline_buffer, rb_cfg, reward_scaling=0.001): - offline_data = offline_buffer.sample(100000) - offline_data.del_("return_to_go") - offline_data.del_("index") # delete - r2g = Reward2GoTransform(gamma=1.0, out_keys=["return_to_go"]) reward_scale = RewardScaling( loc=0, scale=reward_scaling, in_keys="return_to_go", standard_normal=False @@ -235,7 +231,7 @@ def make_online_replay_buffer(offline_buffer, rb_cfg, reward_scaling=0.001): transforms = Compose( r2g, reward_scale, - catframes, + catframes, # TODO: cat frames is not an inverse transform doesnt get triggered! 
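        # note: replay-buffer transforms apply inv() when data is written via extend()
        # and forward() when data is sampled, so without an inverse implementation this
        # CatFrames only takes effect at sampling time (which is what the TODO refers to).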
) storage = LazyMemmapStorage( rb_cfg.capacity, rb_cfg.buffer_scratch_dir, device=rb_cfg.device @@ -244,12 +240,15 @@ def make_online_replay_buffer(offline_buffer, rb_cfg, reward_scaling=0.001): replay_buffer = TensorDictReplayBuffer( pin_memory=False, prefetch=rb_cfg.prefetch, - transform=transforms, storage=storage, batch_size=rb_cfg.batch_size, ) # init buffer with offline data - # replay_buffer.extend(offline_data.clone().detach().to_tensordict()) + offline_data = offline_buffer.sample(100000) + offline_data.del_("index") + replay_buffer.extend(offline_data.clone().detach().to_tensordict()) + # add transforms after offline data extension to not trigger reward-to-go calculation + replay_buffer.append_transform(transforms) return replay_buffer diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 5adda6470b1..c94192e1265 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -1673,6 +1673,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict.set(self.action_key, out_action) out_rtg = tensordict.get(self.return_to_go_key)[:, -1] tensordict.set(self.return_to_go_key, out_rtg) + # set unmasked observation tensordict.set( self.observation_key, unmasked_tensordict.get(self.observation_key) ) From 94707978545ce9ab36a4f34f6b9f4bf80b77d084 Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 12 May 2023 12:19:15 +0200 Subject: [PATCH 015/104] update DT loss docstring --- torchrl/objectives/decision_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index d50ede5d4b8..4d246267a6a 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -21,8 +21,8 @@ class OnlineDTLoss(LossModule): Presented in "Online Decision Transformer" https://arxiv.org/abs/2202.05607 Args: actor_network (ProbabilisticActor): stochastic actor - alpha_init: - samples_mc_entropy: + alpha_init (float): initial value of the temperature parameter + samples_mc_entropy (int): number of samples to estimate the entropy """ From 6b8185db9e79d933c414d3f524be57c1cc0dea51 Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 12 May 2023 13:23:18 +0200 Subject: [PATCH 016/104] update dt inference wrapper docstring with example --- torchrl/modules/tensordict_module/actors.py | 57 +++++++++++++++++++-- 1 file changed, 54 insertions(+), 3 deletions(-) diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index c94192e1265..1f1b0bfd811 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -1582,9 +1582,10 @@ def get_value_operator(self) -> SafeSequential: class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper): """Inference Action Wrapper for the Decision Transformer. - A wrapper specifically designed for the Decision Transformer, which will mask the context of the - input tensordict to the inferece context. The output will be a TensorDict with be the input TensorDict - and the last predicted action of the predicted output sequence. + A wrapper specifically designed for the Decision Transformer, which will mask the + input tensordict sequences to the inferece context. + The output will be a TensorDict with the same keys as the input, but with only the last + action of the predicted action sequence and the last return to go. 
Args: policy (TensorDictModule): The policy module that takes in @@ -1596,6 +1597,56 @@ class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper): action_key (str): The key of the action in the input TensorDict return_to_go_key (str): The key of the return to go in the input TensorDict spec (Optional[TensorSpec]): The spec of the input TensorDict. If None, it will be inferred from the policy module. + + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> from tensordict.nn import TensorDictModule + >>> from torchrl.modules import ( + ... ProbabilisticActor, + ... DTActor, + ... TanhNormal, + ... DecisionTransformerInferenceWrapper, + ... ) + + >>> actor_module = TensorDictModule( + DTActor(state_dim=4, action_dim=2), + in_keys=in_keys, + out_keys=[ + "loc", + "scale",]) + >>> dist_class = TanhNormal + >>> dist_kwargs = { + "min": -1.0, + "max": 1.0, + "tanh_loc": False, + } + >>> actor = ProbabilisticActor( + in_keys=["loc", "scale"], + out_keys=["action", "log_prob"], + module=actor_module, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs) + + >>> inference_actor = DecisionTransformerInferenceWrapper(actor) + >>> print(inference_actor) + >>> sequence_length = 20 + >>> td = TensorDict({"observation": torch.randn(1, sequence_length, 4), + "action": torch.randn(1, sequence_length, 2), + "return_to_go": torch.randn(1, sequence_length, 1)}, [1,]) + + >>> print(inference_actor(td.clone())) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False), + loc: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False), + observation: Tensor(shape=torch.Size([1, 20, 4]), device=cpu, dtype=torch.float32, is_shared=False), + sample_log_prob: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + scale: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False), + return_to_go: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([1]), + device=None, + is_shared=False) """ def __init__( From 76e3a274754e50d46a6837d9c3d445dfdcc8a79d Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 12 May 2023 14:55:25 +0200 Subject: [PATCH 017/104] add odt cost tests --- test/test_cost.py | 156 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 146 insertions(+), 10 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index b0acc66af30..aee439acfb1 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -81,6 +81,7 @@ DreamerValueLoss, IQLLoss, KLPENPPOLoss, + OnlineDTLoss, PPOLoss, SACLoss, TD3Loss, @@ -1508,7 +1509,6 @@ def test_discrete_sac( target_entropy, td_est, ): - torch.manual_seed(self.seed) td = self._create_mock_data_sac(device=device) @@ -1856,7 +1856,6 @@ def _create_seq_mock_data_redq( @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_redq(self, delay_qvalue, num_qvalue, device, td_est): - torch.manual_seed(self.seed) td = self._create_mock_data_redq(device=device) @@ -1945,7 +1944,6 @@ def test_redq(self, delay_qvalue, num_qvalue, device, td_est): @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) @pytest.mark.parametrize("device", get_available_devices()) def test_redq_shared(self, delay_qvalue, num_qvalue, device): - torch.manual_seed(self.seed) td = self._create_mock_data_redq(device=device) @@ -2050,7 +2048,6 @@ def test_redq_shared(self, delay_qvalue, 
num_qvalue, device): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est): - torch.manual_seed(self.seed) td = self._create_mock_data_redq(device=device) @@ -3317,6 +3314,146 @@ def test_dreamer_value(self, device, discount_loss): loss_module.zero_grad() +class TestOnlineDT: + seed = 0 + + def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + # Actor + action_spec = BoundedTensorSpec( + -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) + ) + net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) + actor = ProbabilisticActor( + module=module, + distribution_class=TanhNormal, + in_keys=["loc", "scale"], + spec=action_spec, + ) + return actor.to(device) + + def _create_mock_data_odt(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + # create a tensordict + obs = torch.randn(batch, obs_dim, device=device) + action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) + reward2go = torch.randn(batch, 1, device=device) + td = TensorDict( + batch_size=(batch,), + source={ + "observation": obs, + "action": action, + "reward2go": reward2go, + }, + device=device, + ) + return td + + def _create_seq_mock_data_odt( + self, batch=2, T=4, obs_dim=3, action_dim=4, device="cpu" + ): + # create a tensordict + obs = torch.randn(batch, T, obs_dim, device=device) + action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) + reward2go = torch.randn(batch, T, 1, device=device) + + td = TensorDict( + batch_size=(batch, T), + source={ + "observation": obs, + "reward": reward2go, + "action": action, + }, + device=device, + ) + return td + + @pytest.mark.parametrize("device", get_available_devices()) + def test_odt(self, device): + torch.manual_seed(self.seed) + td = self._create_mock_data_odt(device=device) + + actor = self._create_mock_actor(device=device) + + loss_fn = OnlineDTLoss(actor) + loss = loss_fn(td) + loss_transformer = loss["loss"] + loss_alpha = loss["loss_alpha"] + loss_transformer.backward(retain_graph=True) + named_parameters = loss_fn.named_parameters() + + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" in name + assert "alpha" not in name + if p.grad is None: + assert "actor" not in name + assert "alpha" in name + loss_fn.zero_grad() + loss_alpha.backward(retain_graph=True) + named_parameters = loss_fn.named_parameters() + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" not in name + assert "alpha" in name + if p.grad is None: + assert "actor" in name + assert "alpha" not in name + loss_fn.zero_grad() + + sum([loss_transformer, loss_alpha]).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) + + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) + + for name, p in named_parameters: + assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" + + @pytest.mark.parametrize("device", get_available_devices()) + def test_seq_odt(self, device): + torch.manual_seed(self.seed) + td = self._create_seq_mock_data_odt(device=device) + + actor = self._create_mock_actor(device=device) + + loss_fn = OnlineDTLoss(actor) + loss = loss_fn(td) + loss_transformer 
= loss["loss"] + loss_alpha = loss["loss_alpha"] + loss_transformer.backward(retain_graph=True) + named_parameters = loss_fn.named_parameters() + + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" in name + assert "alpha" not in name + if p.grad is None: + assert "actor" not in name + assert "alpha" in name + loss_fn.zero_grad() + loss_alpha.backward(retain_graph=True) + named_parameters = loss_fn.named_parameters() + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" not in name + assert "alpha" in name + if p.grad is None: + assert "actor" in name + assert "alpha" not in name + loss_fn.zero_grad() + + sum([loss_transformer, loss_alpha]).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) + + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) + + for name, p in named_parameters: + assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" + + class TestIQL: seed = 0 @@ -3439,7 +3576,6 @@ def test_iql( expectile, td_est, ): - torch.manual_seed(self.seed) td = self._create_mock_data_iql(device=device) @@ -3736,7 +3872,7 @@ def __init__(self): # total dist d0 = 0.0 - for (key, source_val) in upd._sources.items(True, True): + for key, source_val in upd._sources.items(True, True): if not isinstance(key, tuple): key = (key,) key = ("target_" + key[0], *key[1:]) @@ -3752,7 +3888,7 @@ def __init__(self): for i in range(value_network_update_interval + 1): # test that no update is occuring until value_network_update_interval d1 = 0.0 - for (key, source_val) in upd._sources.items(True, True): + for key, source_val in upd._sources.items(True, True): if not isinstance(key, tuple): key = (key,) key = ("target_" + key[0], *key[1:]) @@ -3767,7 +3903,7 @@ def __init__(self): assert upd.counter == 0 # test that a new update has occured d1 = 0.0 - for (key, source_val) in upd._sources.items(True, True): + for key, source_val in upd._sources.items(True, True): if not isinstance(key, tuple): key = (key,) key = ("target_" + key[0], *key[1:]) @@ -3780,7 +3916,7 @@ def __init__(self): elif mode == "soft": upd.step() d1 = 0.0 - for (key, source_val) in upd._sources.items(True, True): + for key, source_val in upd._sources.items(True, True): if not isinstance(key, tuple): key = (key,) key = ("target_" + key[0], *key[1:]) @@ -3793,7 +3929,7 @@ def __init__(self): upd.init_() upd.step() d2 = 0.0 - for (key, source_val) in upd._sources.items(True, True): + for key, source_val in upd._sources.items(True, True): if not isinstance(key, tuple): key = (key,) key = ("target_" + key[0], *key[1:]) From 082a75ec3cdeb9d0989f890c133ff342bfd98d3a Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 18 May 2023 14:21:55 +0200 Subject: [PATCH 018/104] try to add inverse catframes --- torchrl/envs/transforms/transforms.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index b0b1b57eb80..65cf29876c5 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1974,6 +1974,7 @@ def __init__( in_keys: Optional[Sequence[str]] = None, out_keys: Optional[Sequence[str]] = None, padding="same", + as_inverse=False, ): if in_keys is None: in_keys = IMAGE_KEYS @@ -1996,6 +1997,7 @@ def __init__( ) # keeps track of calls to _reset since it's only _call that will 
populate the buffer self._just_reset = False + self.as_inverse = as_inverse def reset(self, tensordict: TensorDictBase) -> TensorDictBase: """Resets _buffers.""" @@ -2089,6 +2091,12 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec observation_spec.shape = torch.Size(shape) return observation_spec + def _inv_apply_transform(self, tensordict: TensorDictBase) -> torch.Tensor: + if self.as_inverse: + return self.forward(tensordict) + else: + raise KeyError("Inverse transform not implemented for this transform.") + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: # it is assumed that the last dimension of the tensordict is the time dimension if not tensordict.ndim or ( From 2b636a6c1d1f2e34df3a0943dbd5c54a04929f6e Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 19 May 2023 09:48:11 +0200 Subject: [PATCH 019/104] as_inverse add to catframes --- examples/decision_transformer/utils.py | 22 ++++++++++--------- torchrl/data/replay_buffers/replay_buffers.py | 2 +- torchrl/envs/transforms/transforms.py | 2 +- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index c1946395379..f135a144827 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -86,20 +86,17 @@ def make_transformed_env(base_env, env_cfg, train=False): in_keys_inv=[], ) ) - transformed_env.append_transform(UnsqueezeTransform(-2, in_keys=["observation"])) transformed_env.append_transform( - CatFrames(in_keys=["observation"], N=env_cfg.stacked_frames, dim=-2) + UnsqueezeTransform(-2, in_keys=["observation", "action", "return_to_go"]) ) - - transformed_env.append_transform(UnsqueezeTransform(-2, in_keys=["action"])) transformed_env.append_transform( - CatFrames(in_keys=["action"], N=env_cfg.stacked_frames, dim=-2) + CatFrames( + in_keys=["observation", "action", "return_to_go"], + N=env_cfg.stacked_frames, + dim=-2, + ) ) - transformed_env.append_transform(UnsqueezeTransform(-2, in_keys=["return_to_go"])) - transformed_env.append_transform( - CatFrames(in_keys=["return_to_go"], N=env_cfg.stacked_frames, dim=-2) - ) if train: transformed_env.append_transform(RewardSum()) @@ -184,6 +181,7 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling): N=rb_cfg.stacked_frames, dim=-2, padding="zeros", + as_inverse=True, ) d2f = DoubleToFloat( @@ -226,7 +224,11 @@ def make_online_replay_buffer(offline_buffer, rb_cfg, reward_scaling=0.001): loc=0, scale=reward_scaling, in_keys="return_to_go", standard_normal=False ) catframes = CatFrames( - in_keys=["return_to_go"], N=rb_cfg.stacked_frames, dim=-2, padding="zeros" + in_keys=["return_to_go"], + N=rb_cfg.stacked_frames, + dim=-2, + padding="zeros", + as_inverse=True, ) transforms = Compose( r2g, diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index cfd84af832e..d73ee7e8678 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -311,7 +311,7 @@ def extend(self, data: Sequence) -> torch.Tensor: Indices of the data added to the replay buffer. 
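A rough usage sketch of the code path this method exercises (not part of the patch; classes and key names as used elsewhere in this series): transforms attached to a TensorDictReplayBuffer have their inv() applied to incoming data inside extend() before it reaches the storage, which is how Reward2GoTransform can write "return_to_go" at insertion time.

from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.envs import Reward2GoTransform

rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(10_000),
    transform=Reward2GoTransform(gamma=1.0, out_keys=["return_to_go"]),
)
# rb.extend(trajectory_td)  # inv() runs here, before the data is written to storage
#                           # (trajectory_td is a placeholder TensorDict of collected frames)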
""" if self._transform is not None and is_tensor_collection(data): - data = self._transform.inv(data) + data = self._transform.inv(data.get("_data")) elif self._transform is not None and len(self._transform): # Accepts transforms that act on "data" key data = self._transform.inv(TensorDict({"data": data}, [])).get("data") diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 65cf29876c5..6667c4d57be 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2091,7 +2091,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec observation_spec.shape = torch.Size(shape) return observation_spec - def _inv_apply_transform(self, tensordict: TensorDictBase) -> torch.Tensor: + def _inv_call(self, tensordict: TensorDictBase) -> torch.Tensor: if self.as_inverse: return self.forward(tensordict) else: From b1788f501b03ff284f9cb6a8883d1103fe4b2fa3 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 22 May 2023 09:53:02 +0200 Subject: [PATCH 020/104] make dt / odt split --- examples/decision_transformer/config.yaml | 5 +- examples/decision_transformer/dt_offline.py | 190 ------------ examples/decision_transformer/utils.py | 309 ++++++++++++++++++-- torchrl/envs/transforms/transforms.py | 11 +- torchrl/modules/__init__.py | 1 + torchrl/modules/models/__init__.py | 1 + torchrl/modules/models/models.py | 54 +++- torchrl/objectives/__init__.py | 2 +- torchrl/objectives/decision_transformer.py | 52 ++++ 9 files changed, 400 insertions(+), 225 deletions(-) delete mode 100644 examples/decision_transformer/dt_offline.py diff --git a/examples/decision_transformer/config.yaml b/examples/decision_transformer/config.yaml index 3d0e885ad5f..503565909bf 100644 --- a/examples/decision_transformer/config.yaml +++ b/examples/decision_transformer/config.yaml @@ -5,11 +5,12 @@ env: env_library: gym record_video: 0 stacked_frames: 20 + inference_context: 20 # 5 n_samples_stats: 2000 frame_skip: 1 from_pixels: False num_train_envs: 1 - num_eval_envs: 1 + num_eval_envs: 10 reward_scaling: 0.001 noop: 1 seed: 0 @@ -55,7 +56,7 @@ optim: weight_decay: 5.0e-4 batch_size: 256 lr_scheduler: "" - pretrain_gradient_steps: 3000 + pretrain_gradient_steps: 5000 updates_per_episode: 300 warmup_steps: 10000 diff --git a/examples/decision_transformer/dt_offline.py b/examples/decision_transformer/dt_offline.py deleted file mode 100644 index cb7529ea71a..00000000000 --- a/examples/decision_transformer/dt_offline.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -"""Decision Transformer Example. -This is a self-contained example of an offline Decision Transformer training script. -The helper functions are coded in the utils.py associated with this script. 
-""" - -import hydra -import torch -import tqdm -from torchrl.envs.utils import ExplorationType, set_exploration_type - -from utils import ( - make_collector, - make_decision_transformer_model, - make_dt_optimizer, - make_env, - make_logger, - make_loss, - make_offline_replay_buffer, - make_online_replay_buffer, -) - - -@hydra.main(config_path=".", config_name="config") -def main(cfg: "DictConfig"): # noqa: F821 - model_device = cfg.optim.device - - test_env = make_env(cfg.env) - logger = make_logger(cfg.logger) - offline_buffer = make_offline_replay_buffer( - cfg.replay_buffer, cfg.env.reward_scaling - ) - - inference_actor, actor = make_decision_transformer_model(cfg) - policy = actor.to(model_device) - inference_policy = inference_actor.to(model_device) - - loss_module = make_loss(cfg.loss, actor) - transformer_optim, temperature_optim, scheduler = make_dt_optimizer( - cfg.optim, policy, loss_module - ) - - pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps) - - r0 = None - l0 = None - print(" ***Pretraining*** ") - # Pretraining - for i in range(cfg.optim.pretrain_gradient_steps): - pbar.update(i) - data = offline_buffer.sample() - # loss - loss_vals = loss_module(data) - # backprop - transformer_loss = loss_vals["loss"] - temperature_loss = loss_vals["loss_alpha"] - - transformer_optim.zero_grad() - torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.25) - transformer_loss.backward() - transformer_optim.step() - - temperature_optim.zero_grad() - temperature_loss.backward() - temperature_optim.step() - - scheduler.step() - - # evaluation - with set_exploration_type(ExplorationType.MEAN), torch.no_grad(): - if i % cfg.logger.pretrain_log_interval == 0: - eval_td = test_env.rollout( - max_steps=cfg.logger.eval_steps, - policy=inference_policy, - auto_cast_to_device=True, - ) - if r0 is None: - r0 = eval_td["next", "reward"].sum(1).mean().item() / cfg.env.reward_scaling - if l0 is None: - l0 = transformer_loss.item() - - for key, value in loss_vals.items(): - logger.log_scalar(key, value.item(), i) - eval_reward = ( - eval_td["next", "reward"].sum(1).mean().item() / cfg.env.reward_scaling - ) - logger.log_scalar("evaluation reward", eval_reward, i) - - pbar.set_description( - f"[Pre-Training] loss: {transformer_loss.item(): 4.4f} (init: {l0: 4.4f}), evaluation reward: {eval_reward: 4.4f} (init={r0: 4.4f})" - ) - print("\n ***Online Finetuning*** ") - collector = make_collector(cfg, inference_policy) - online_buffer = make_online_replay_buffer( - offline_buffer, cfg.replay_buffer, cfg.env.reward_scaling - ) - collected_frames = 0 - - pbar = tqdm.tqdm(total=cfg.env.total_online_frames) - r0 = None - - for j, tensordict in enumerate(collector): - # update weights of the inference policy - collector.update_policy_weights_() - - episode_reward = ( - tensordict["next", "episode_reward"][tensordict["next", "done"]] - .mean() - .item() - / cfg.env.reward_scaling - ) - if r0 is None: - r0 = episode_reward - - current_frames = tensordict.numel() - pbar.update(current_frames) - - tensordict = tensordict.reshape(-1) - # only used for logging - tensordict.del_("episode_reward") - - online_buffer.extend(tensordict.cpu().clone().detach()) - collected_frames += current_frames - - # optimization steps - for _ in range(int(cfg.optim.updates_per_episode)): - sampled_tensordict = online_buffer.sample().clone() - - loss_vals = loss_module(sampled_tensordict) - - # backprop - transformer_loss = loss_vals["loss"] - temperature_loss = loss_vals["loss_alpha"] - - transformer_optim.zero_grad() - 
torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.25) - transformer_loss.backward() - transformer_optim.step() - - temperature_optim.zero_grad() - temperature_loss.backward() - temperature_optim.step() - - scheduler.step() - - train_target_return = ( - tensordict["return_to_go"][:, 0].mean() / cfg.env.reward_scaling - ) - train_log = { - "collect reward": episode_reward, - "collected_frames": collected_frames, - "collect target_return": train_target_return.item() - / cfg.env.reward_scaling, - } - - for key, value in train_log.items(): - logger.log_scalar(key, value, step=j) - - with set_exploration_type(ExplorationType.MEAN), torch.no_grad(): - if j % cfg.logger.fintune_log_interval == 0: - eval_td = test_env.rollout( - max_steps=cfg.logger.eval_steps * cfg.env.num_eval_envs, - policy=inference_policy, - auto_cast_to_device=True, - ) - eval_reward = ( - eval_td["next", "reward"].sum(1).mean().item() / cfg.env.reward_scaling - ) - eval_target_return = ( - eval_td["return_to_go"][:, 0].mean() / cfg.env.reward_scaling - ) - eval_log = { - "fine-tune evaluation reward": eval_reward, - "evaluation target_return": eval_target_return.item() - / cfg.env.reward_scaling, - } - for key, value in eval_log.items(): - logger.log_scalar(key, value, step=j) - pbar.set_description( - f"[Fine-Tuning] loss: {transformer_loss.item(): 4.4f} (init: {l0: 4.4f}), evaluation reward: {eval_reward: 4.4f} (init={r0: 4.4f})" - ) - - collector.shutdown() - - -if __name__ == "__main__": - main() diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index f135a144827..7061ecedfff 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -25,9 +25,15 @@ ) from torchrl.envs.libs.dm_control import DMControlEnv from torchrl.envs.utils import set_exploration_mode -from torchrl.modules import DTActor, ProbabilisticActor, TanhNormal +from torchrl.modules import ( + DTActor, + OnlineDTActor, + ProbabilisticActor, + TanhDelta, + TanhNormal, +) from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper -from torchrl.objectives import OnlineDTLoss +from torchrl.objectives import DTLoss, OnlineDTLoss from torchrl.record.loggers import generate_exp_name, get_logger from torchrl.trainers.helpers.envs import LIBS @@ -96,6 +102,11 @@ def make_transformed_env(base_env, env_cfg, train=False): dim=-2, ) ) + loc, std = get_loc_std("hopper-medium-v2") + obsnorm = ObservationNorm( + loc=loc, scale=std, in_keys="observation", standard_normal=True + ) + transformed_env.append_transform(obsnorm) if train: transformed_env.append_transform(RewardSum()) @@ -103,7 +114,7 @@ def make_transformed_env(base_env, env_cfg, train=False): return transformed_env -def make_parallel_env(env_cfg, state_dict, train=False): +def make_parallel_env(env_cfg, train=False): if train: num_envs = env_cfg.num_train_envs else: @@ -113,31 +124,14 @@ def make_parallel_env(env_cfg, state_dict, train=False): env_cfg, train, ) - for t in env.transform: - if isinstance(t, ObservationNorm): - t.init_stats(3, cat_dim=1, reduce_dim=[0, 1]) - env.load_state_dict(state_dict) return env def make_env(env_cfg, train=False): - state_dict = get_stats(env_cfg) - env = make_parallel_env(env_cfg, state_dict=state_dict, train=train) + env = make_parallel_env(env_cfg, train=train) return env -def get_stats(env_cfg): - env = make_transformed_env(make_base_env(env_cfg), env_cfg) - init_stats(env, env_cfg.n_samples_stats) - return env.state_dict() - - -def init_stats(env, n_samples_stats): 
- for t in env.transform: - if isinstance(t, ObservationNorm): - t.init_stats(n_samples_stats) - - # ==================================================================== # Collector and replay buffer # --------------------------- @@ -171,8 +165,19 @@ def make_collector(cfg, policy): return collector +def get_loc_std(env_name): + import d4rl # noqa + import gym + + env = gym.make(env_name) + data = env.get_dataset() + loc = torch.from_numpy(data["observations"].mean(axis=0)).float() + std = torch.from_numpy(data["observations"].std(axis=0)).float() + return loc, std + + def make_offline_replay_buffer(rb_cfg, reward_scaling): - r2g = Reward2GoTransform(gamma=1.0, out_keys=["return_to_go"]) + r2g = Reward2GoTransform(gamma=1.0, in_keys=["reward"], out_keys=["return_to_go"]) reward_scale = RewardScaling( loc=0, scale=reward_scaling, in_keys="return_to_go", standard_normal=False ) @@ -188,6 +193,10 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling): in_keys=["observation", ("next", "observation")], in_keys_inv=[], ) + loc, std = get_loc_std(rb_cfg.dataset) + obsnorm = ObservationNorm( + loc=loc, scale=std, in_keys="observation", standard_normal=True + ) exclude = ExcludeTransform( "next_observations", "timeout", @@ -200,11 +209,14 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling): ) transforms = Compose( + # inverse transforms are called reversed + # therefore catframes before r2g + catframes, r2g, reward_scale, - catframes, d2f, exclude, + obsnorm, ) data = D4RLExperienceReplay( rb_cfg.dataset, @@ -260,7 +272,7 @@ def make_online_replay_buffer(offline_buffer, rb_cfg, reward_scaling=0.001): # ----- -def make_decision_transformer_model(cfg): +def make_odt_model(cfg): env_cfg = cfg.env proof_environment = make_transformed_env(make_base_env(env_cfg), env_cfg) @@ -274,7 +286,7 @@ def make_decision_transformer_model(cfg): "return_to_go", ] - actor_net = DTActor( + actor_net = OnlineDTActor( state_dim=state_dim, action_dim=action_spec.shape[-1], transformer_config=cfg.transformer, @@ -307,6 +319,62 @@ def make_decision_transformer_model(cfg): return_log_prob=False, ) + # init the lazy layers + with torch.no_grad(), set_exploration_mode("random"): + td = proof_environment.rollout(max_steps=100) + td["action"] = td["next", "action"] + actor(td) + + inference_actor = DecisionTransformerInferenceWrapper( + actor, + inference_context=cfg.env.inference_context, + ) + return inference_actor, actor + + +def make_dt_model(cfg): + env_cfg = cfg.env + proof_environment = make_transformed_env(make_base_env(env_cfg), env_cfg) + + action_spec = proof_environment.action_spec + for key, value in proof_environment.observation_spec.items(): + if key == "observation": + state_dim = value.shape[-1] + in_keys = [ + "observation", + "action", + "return_to_go", + ] + + actor_net = DTActor( + state_dim=state_dim, + action_dim=action_spec.shape[-1], + transformer_config=cfg.transformer, + ) + + actor_module = TensorDictModule( + actor_net, + in_keys=in_keys, + out_keys=["param"], + ) + dist_class = TanhDelta + dist_kwargs = { + "min": action_spec.space.minimum, + "max": action_spec.space.maximum, + } + + actor = ProbabilisticActor( + spec=action_spec, + in_keys=["param"], + out_keys=["action"], + module=actor_module, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs, + default_interaction_mode="random", + cache_dist=False, + return_log_prob=False, + ) + # init the lazy layers with torch.no_grad(), set_exploration_mode("random"): td = proof_environment.rollout(max_steps=100) @@ -324,7 +392,7 @@ 
def make_decision_transformer_model(cfg): # --------- -def make_loss(loss_cfg, actor_network): +def make_odt_loss(loss_cfg, actor_network): loss = OnlineDTLoss( actor_network, loss_cfg.alpha_init, @@ -332,9 +400,17 @@ def make_loss(loss_cfg, actor_network): return loss -def make_dt_optimizer(optim_cfg, actor_network, loss): +def make_dt_loss(actor_network): + loss = DTLoss( + actor_network, + ) + return loss + + +def make_odt_optimizer(optim_cfg, actor_network, loss): # Should be Lambda Optimizer - dt_optimizer = torch.optim.Adam( + + dt_optimizer = Lamb( actor_network.parameters(), lr=optim_cfg.lr, weight_decay=optim_cfg.weight_decay, @@ -353,6 +429,21 @@ def make_dt_optimizer(optim_cfg, actor_network, loss): return dt_optimizer, log_temp_optimizer, scheduler +def make_dt_optimizer(optim_cfg, actor_network, loss): + # Should be Lambda Optimizer + dt_optimizer = torch.optim.Adam( + actor_network.parameters(), + lr=optim_cfg.lr, + weight_decay=optim_cfg.weight_decay, + eps=1.0e-8, + ) + scheduler = torch.optim.lr_scheduler.LambdaLR( + dt_optimizer, lambda steps: min((steps + 1) / optim_cfg.warmup_steps, 1) + ) + + return dt_optimizer, scheduler + + # ==================================================================== # Logging and recording # --------------------- @@ -363,3 +454,165 @@ def make_logger(logger_cfg): logger_cfg.exp_name = exp_name logger = get_logger(logger_cfg.backend, logger_name="oDT", experiment_name=exp_name) return logger + + +import math + +import torch +from torch.optim import Optimizer + + +class Lamb(Optimizer): + """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB + reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py + LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its norm. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + grad_averaging (bool, optional): whether apply (1-beta2) to grad when + calculating running averages of gradient. (default: True) + max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0) + trust_clip (bool): enable LAMBC trust ratio clipping (default: False) + always_adapt (boolean, optional): Apply adaptive learning rate to 0.0 + weight decay parameter (default: False) + .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=0.01, + grad_averaging=True, + max_grad_norm=1.0, + trust_clip=False, + always_adapt=False, + ): + defaults = { + "lr": lr, + "bias_correction": bias_correction, + "betas": betas, + "eps": eps, + "weight_decay": weight_decay, + "grad_averaging": grad_averaging, + "max_grad_norm": max_grad_norm, + "trust_clip": trust_clip, + "always_adapt": always_adapt, + } + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + device = self.param_groups[0]["params"][0].device + one_tensor = torch.tensor( + 1.0, device=device + ) # because torch.where doesn't handle scalars correctly + global_grad_norm = torch.zeros(1, device=device) + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "Lamb does not support sparse gradients, consider SparseAdam instad." + ) + global_grad_norm.add_(grad.pow(2).sum()) + + global_grad_norm = torch.sqrt(global_grad_norm) + # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes + # scalar types properly https://github.com/pytorch/pytorch/issues/9190 + max_grad_norm = torch.tensor(self.defaults["max_grad_norm"], device=device) + clip_global_grad_norm = torch.where( + global_grad_norm > max_grad_norm, + global_grad_norm / max_grad_norm, + one_tensor, + ) + + for group in self.param_groups: + bias_correction = 1 if group["bias_correction"] else 0 + beta1, beta2 = group["betas"] + grad_averaging = 1 if group["grad_averaging"] else 0 + beta3 = 1 - beta1 if grad_averaging else 1.0 + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel + if "step" in group: + group["step"] += 1 + else: + group["step"] = 1 + + if bias_correction: + bias_correction1 = 1 - beta1 ** group["step"] + bias_correction2 = 1 - beta2 ** group["step"] + else: + bias_correction1, bias_correction2 = 1.0, 1.0 + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.div_(clip_global_grad_norm) + state = self.state[p] + + # State initialization + if len(state) == 0: + # Exponential moving average of gradient valuesa + state["exp_avg"] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( + group["eps"] + ) + update = (exp_avg / bias_correction1).div_(denom) + + weight_decay = group["weight_decay"] + if weight_decay != 0: + update.add_(p, alpha=weight_decay) + + if weight_decay != 0 or group["always_adapt"]: + # Layer-wise LR adaptation. By default, skip adaptation on parameters that are + # excluded from weight decay, unless always_adapt == True, then always enabled. 
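                    # The Adam-style step `update` computed above is rescaled by the
                    # LAMB trust ratio ||w|| / ||update|| below, so each layer's step
                    # size tracks its weight norm; `trust_clip` caps the ratio at 1.0
                    # (LAMBC variant).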
+ w_norm = p.norm(2.0) + g_norm = update.norm(2.0) + # FIXME nested where required since logical and/or not working in PT XLA + trust_ratio = torch.where( + w_norm > 0, + torch.where(g_norm > 0, w_norm / g_norm, one_tensor), + one_tensor, + ) + if group["trust_clip"]: + # LAMBC trust clipping, upper bound fixed at one + trust_ratio = torch.minimum(trust_ratio, one_tensor) + update.mul_(trust_ratio) + + p.add_(update, alpha=-group["lr"]) + + return loss diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 6667c4d57be..f792eddf8f6 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2093,11 +2093,18 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec def _inv_call(self, tensordict: TensorDictBase) -> torch.Tensor: if self.as_inverse: - return self.forward(tensordict) + return self.unfolding(tensordict) else: - raise KeyError("Inverse transform not implemented for this transform.") + return tensordict + # raise KeyError("Inverse transform not implemented for this transform.") def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + if self.as_inverse: + return tensordict + else: + return self.unfolding(tensordict) + + def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase: # it is assumed that the last dimension of the tensordict is the time dimension if not tensordict.ndim or ( tensordict.names[-1] is not None and tensordict.names[-1] != "time" diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 4ab677a65b7..3dd66d64ef8 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -30,6 +30,7 @@ NoisyLinear, ObsDecoder, ObsEncoder, + OnlineDTActor, reset_noise, RSSMPosterior, RSSMPrior, diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index 257eb9628ad..f2972cab4cd 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -17,5 +17,6 @@ DuelingCnnDQNet, LSTMNet, MLP, + OnlineDTActor, ) from .utils import Squeeze2dLayer, SqueezeLayer diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 502efdc4872..057c596fee7 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -1140,8 +1140,8 @@ def forward( return self._lstm(input, hidden0_in, hidden1_in) -class DTActor(nn.Module): - """Decision Transformer Actor class. +class OnlineDTActor(nn.Module): + """Online Decision Transformer Actor class. Presented in "Online Decision Transformer", https://arxiv.org/abs/2202.05607.pdf @@ -1199,3 +1199,53 @@ def forward( std = log_std.exp() return (mu, std) + + +class DTActor(nn.Module): + """Decision Transformer Actor class. 
+ + Presented in "Decision Transformer", + https://arxiv.org/abs/2202.05607.pdf + + + """ + + def __init__( + self, + state_dim: int, + action_dim: int, + transformer_config: Dict, + device: Optional[DEVICE_TYPING] = None, + ): + super().__init__() + self.transformer = DecisionTransformer( + state_dim=state_dim, + action_dim=action_dim, + config=transformer_config, + ) + self.action_layer = nn.Linear( + transformer_config.n_embd, action_dim, device=device + ) + + def weight_init(m): + """Custom weight init for Conv2D and Linear layers.""" + if isinstance(m, torch.nn.Linear): + nn.init.orthogonal_(m.weight.data) + if hasattr(m.bias, "data"): + m.bias.data.fill_(0.0) + + self.apply(weight_init) + + def forward( + self, + observation: torch.Tensor, + action: torch.Tensor, + return_to_go: torch.Tensor, + ) -> torch.Tensor: + if observation.ndim == 2: + observation = observation.unsqueeze(0) + action = action.unsqueeze(0) + return_to_go = return_to_go.unsqueeze(0) + hidden_state = self.transformer(observation, action, return_to_go) + out = self.action_layer(hidden_state) + return out diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index 16e3cb73b5c..0df17762ef6 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -6,7 +6,7 @@ from .a2c import A2CLoss from .common import LossModule from .ddpg import DDPGLoss -from .decision_transformer import OnlineDTLoss +from .decision_transformer import DTLoss, OnlineDTLoss from .dqn import DistributionalDQNLoss, DQNLoss from .dreamer import DreamerActorLoss, DreamerModelLoss, DreamerValueLoss from .iql import IQLLoss diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 4d246267a6a..3afa435813c 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -11,6 +11,7 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torch import distributions as d from torchrl.modules import ProbabilisticActor +from torchrl.objectives.utils import distance_loss from .common import LossModule @@ -95,3 +96,54 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: "alpha": self.log_alpha.exp(), } return TensorDict(out, []) + + +class DTLoss(LossModule): + r"""TorchRL implementation of the Online Decision Transformer loss. + + Presented in "Decision Transformer" https://arxiv.org/abs/2202.05607 + Args: + actor_network (ProbabilisticActor): stochastic actor + + """ + + def __init__( + self, + actor_network: ProbabilisticActor, + ) -> None: + super().__init__() + + # Actor Network + self.convert_to_functional( + actor_network, + "actor_network", + create_target_params=False, + funs_to_decorate=["forward"], + ) + + @property + def device(self) -> torch.device: + for p in self.parameters(): + return p.device + raise RuntimeError( + "At least one of the networks of OnlineDTLoss must have trainable " + "parameters." 
+ ) + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + """Compute the loss for the Online Decision Transformer.""" + # extract action targets + target_actions = torch.clone(tensordict["action"].detach()).to(self.device) + + pred_actions = self.actor_network( + tensordict.to(self.device), params=self.actor_network_params + ).get("action") + loss = distance_loss( + pred_actions, + target_actions, + loss_function="l2", + ).mean() + out = { + "loss": loss, + } + return TensorDict(out, []) From c6e3229620b038c7695c078f1b6d0ae56614584f Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 22 May 2023 09:58:11 +0200 Subject: [PATCH 021/104] add dt odt script --- examples/decision_transformer/dt.py | 89 ++++++++++ examples/decision_transformer/online_dt.py | 194 +++++++++++++++++++++ 2 files changed, 283 insertions(+) create mode 100644 examples/decision_transformer/dt.py create mode 100644 examples/decision_transformer/online_dt.py diff --git a/examples/decision_transformer/dt.py b/examples/decision_transformer/dt.py new file mode 100644 index 00000000000..d4e1a44e38f --- /dev/null +++ b/examples/decision_transformer/dt.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Decision Transformer Example. +This is a self-contained example of an offline Decision Transformer training script. +The helper functions are coded in the utils.py associated with this script. +""" + +import hydra +import torch +import tqdm +from torchrl.envs.utils import ExplorationType, set_exploration_type + +from utils import ( + make_dt_loss, + make_dt_model, + make_dt_optimizer, + make_env, + make_logger, + make_offline_replay_buffer, +) + + +@hydra.main(config_path=".", config_name="config") +def main(cfg: "DictConfig"): # noqa: F821 + model_device = cfg.optim.device + + test_env = make_env(cfg.env) + logger = make_logger(cfg.logger) + offline_buffer = make_offline_replay_buffer( + cfg.replay_buffer, cfg.env.reward_scaling + ) + + inference_actor, actor = make_dt_model(cfg) + policy = actor.to(model_device) + inference_policy = inference_actor.to(model_device) + + loss_module = make_dt_loss(actor) + transformer_optim, scheduler = make_dt_optimizer(cfg.optim, policy, loss_module) + + pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps) + + r0 = None + l0 = None + print(" ***Pretraining*** ") + # Pretraining + for i in range(cfg.optim.pretrain_gradient_steps): + pbar.update(i) + data = offline_buffer.sample() + # loss + loss_vals = loss_module(data) + # backprop + transformer_loss = loss_vals["loss"] + + transformer_optim.zero_grad() + torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.25) + transformer_loss.backward() + transformer_optim.step() + + scheduler.step() + + # evaluation + with set_exploration_type(ExplorationType.MEAN), torch.no_grad(): + if i % cfg.logger.pretrain_log_interval == 0: + eval_td = test_env.rollout( + max_steps=cfg.logger.eval_steps, + policy=inference_policy, + auto_cast_to_device=True, + ) + if r0 is None: + r0 = eval_td["next", "reward"].sum(1).mean().item() / cfg.env.reward_scaling + if l0 is None: + l0 = transformer_loss.item() + + for key, value in loss_vals.items(): + logger.log_scalar(key, value.item(), i) + eval_reward = ( + eval_td["next", "reward"].sum(1).mean().item() / cfg.env.reward_scaling + ) + logger.log_scalar("evaluation reward", eval_reward, i) + + pbar.set_description( + f"[Pre-Training] loss: 
{transformer_loss.item(): 4.4f} (init: {l0: 4.4f}), evaluation reward: {eval_reward: 4.4f} (init={r0: 4.4f})" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/decision_transformer/online_dt.py b/examples/decision_transformer/online_dt.py new file mode 100644 index 00000000000..68cc7076cf0 --- /dev/null +++ b/examples/decision_transformer/online_dt.py @@ -0,0 +1,194 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Decision Transformer Example. +This is a self-contained example of an offline Decision Transformer training script. +The helper functions are coded in the utils.py associated with this script. +""" + +import hydra +import torch +import tqdm +from torchrl.envs.utils import ExplorationType, set_exploration_type + +from utils import ( + # get_loc_std, + # make_collector, + make_env, + make_logger, + make_odt_loss, + make_odt_model, + make_odt_optimizer, + make_offline_replay_buffer, + # make_online_replay_buffer, +) + + +@hydra.main(config_path=".", config_name="config") +def main(cfg: "DictConfig"): # noqa: F821 + model_device = cfg.optim.device + + # loc, std = get_loc_std(cfg.replay_buffer.dataset) + test_env = make_env(cfg.env) # , loc, std) + logger = make_logger(cfg.logger) + offline_buffer = make_offline_replay_buffer( + cfg.replay_buffer, cfg.env.reward_scaling + ) # , loc, std + + inference_actor, actor = make_odt_model(cfg) + policy = actor.to(model_device) + inference_policy = inference_actor.to(model_device) + + loss_module = make_odt_loss(cfg.loss, actor) + transformer_optim, temperature_optim, scheduler = make_odt_optimizer( + cfg.optim, policy, loss_module + ) + + pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps) + + r0 = None + l0 = None + print(" ***Pretraining*** ") + # Pretraining + for i in range(cfg.optim.pretrain_gradient_steps): + pbar.update(i) + data = offline_buffer.sample() + # loss + loss_vals = loss_module(data) + # backprop + transformer_loss = loss_vals["loss"] + temperature_loss = loss_vals["loss_alpha"] + + transformer_optim.zero_grad() + torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.25) + transformer_loss.backward() + transformer_optim.step() + + temperature_optim.zero_grad() + temperature_loss.backward() + temperature_optim.step() + + scheduler.step() + + # evaluation + with set_exploration_type(ExplorationType.MEAN), torch.no_grad(): + inference_policy.eval() + if i % cfg.logger.pretrain_log_interval == 0: + eval_td = test_env.rollout( + max_steps=cfg.logger.eval_steps, + policy=inference_policy, + auto_cast_to_device=True, + ) + inference_policy.train() + if r0 is None: + r0 = eval_td["next", "reward"].sum(1).mean().item() / cfg.env.reward_scaling + if l0 is None: + l0 = transformer_loss.item() + + for key, value in loss_vals.items(): + logger.log_scalar(key, value.item(), i) + eval_reward = ( + eval_td["next", "reward"].sum(1).mean().item() / cfg.env.reward_scaling + ) + logger.log_scalar("evaluation reward", eval_reward, i) + + pbar.set_description( + f"[Pre-Training] loss: {transformer_loss.item(): 4.4f} (init: {l0: 4.4f}), evaluation reward: {eval_reward: 4.4f} (init={r0: 4.4f})" + ) + # print("\n ***Online Finetuning*** ") + # collector = make_collector(cfg, inference_policy) + # online_buffer = make_online_replay_buffer( + # offline_buffer, cfg.replay_buffer, cfg.env.reward_scaling + # ) + # collected_frames = 0 + + # pbar = tqdm.tqdm(total=cfg.env.total_online_frames) 
+ # r0 = None + + # for j, tensordict in enumerate(collector): + # # update weights of the inference policy + # collector.update_policy_weights_() + + # episode_reward = ( + # tensordict["next", "episode_reward"][tensordict["next", "done"]] + # .mean() + # .item() + # / cfg.env.reward_scaling + # ) + # if r0 is None: + # r0 = episode_reward + + # current_frames = tensordict.numel() + # pbar.update(current_frames) + + # tensordict = tensordict.reshape(-1) + # # only used for logging + # tensordict.del_("episode_reward") + + # online_buffer.extend(tensordict.cpu().clone().detach()) + # collected_frames += current_frames + + # # optimization steps + # for _ in range(int(cfg.optim.updates_per_episode)): + # sampled_tensordict = online_buffer.sample().clone() + + # loss_vals = loss_module(sampled_tensordict) + + # # backprop + # transformer_loss = loss_vals["loss"] + # temperature_loss = loss_vals["loss_alpha"] + + # transformer_optim.zero_grad() + # torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.25) + # transformer_loss.backward() + # transformer_optim.step() + + # temperature_optim.zero_grad() + # temperature_loss.backward() + # temperature_optim.step() + + # scheduler.step() + + # train_target_return = ( + # tensordict["return_to_go"][:, 0].mean() / cfg.env.reward_scaling + # ) + # train_log = { + # "collect reward": episode_reward, + # "collected_frames": collected_frames, + # "collect target_return": train_target_return.item() + # / cfg.env.reward_scaling, + # } + + # for key, value in train_log.items(): + # logger.log_scalar(key, value, step=j) + + # with set_exploration_type(ExplorationType.MEAN), torch.no_grad(): + # if j % cfg.logger.fintune_log_interval == 0: + # eval_td = test_env.rollout( + # max_steps=cfg.logger.eval_steps * cfg.env.num_eval_envs, + # policy=inference_policy, + # auto_cast_to_device=True, + # ) + # eval_reward = ( + # eval_td["next", "reward"].sum(1).mean().item() / cfg.env.reward_scaling + # ) + # eval_target_return = ( + # eval_td["return_to_go"][:, 0].mean() / cfg.env.reward_scaling + # ) + # eval_log = { + # "fine-tune evaluation reward": eval_reward, + # "evaluation target_return": eval_target_return.item() + # / cfg.env.reward_scaling, + # } + # for key, value in eval_log.items(): + # logger.log_scalar(key, value, step=j) + # pbar.set_description( + # f"[Fine-Tuning] loss: {transformer_loss.item(): 4.4f} (init: {l0: 4.4f}), evaluation reward: {eval_reward: 4.4f} (init={r0: 4.4f})" + # ) + + # collector.shutdown() + + +if __name__ == "__main__": + main() From aaa09dd3ecf003034f0b13c9c06f8bbf75c9f385 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 23 May 2023 09:31:16 +0200 Subject: [PATCH 022/104] add dt config --- examples/decision_transformer/config.yaml | 30 +++++++++++------------ 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/examples/decision_transformer/config.yaml b/examples/decision_transformer/config.yaml index 503565909bf..92af468c36b 100644 --- a/examples/decision_transformer/config.yaml +++ b/examples/decision_transformer/config.yaml @@ -1,21 +1,21 @@ # Task and env env: - env_name: Hopper-v2 + env_name: HalfCheetah-v3 env_task: "" env_library: gym record_video: 0 stacked_frames: 20 - inference_context: 20 # 5 + inference_context: 5 # for Hopper n_samples_stats: 2000 frame_skip: 1 from_pixels: False num_train_envs: 1 num_eval_envs: 10 - reward_scaling: 0.001 + reward_scaling: 0.001 # for r2g noop: 1 - seed: 0 - eval_target_return: 3600 - collect_target_return: 7200 + seed: 2 + eval_target_return: 6000 # 3600 + 
collect_target_return: 12000 # 7200 total_online_frames: 1000000 @@ -32,15 +32,15 @@ collector: # logger logger: backend: wandb - exp_name: oDT-Hopper-medium-v2 + exp_name: DT-HalfCheetah-medium-v2 pretrain_log_interval: 500 # record interval in frames fintune_log_interval: 1 eval_steps: 1000 # Buffer replay_buffer: - dataset: hopper-medium-v2 - batch_size: 256 + dataset: halfcheetah-medium-v2 + batch_size: 64 # odt 256 prb: 0 stacked_frames: 20 buffer_prefetch: 64 @@ -54,9 +54,9 @@ optim: device: cuda:0 lr: 1.0e-4 weight_decay: 5.0e-4 - batch_size: 256 + batch_size: 64 # odt 256 lr_scheduler: "" - pretrain_gradient_steps: 5000 + pretrain_gradient_steps: 55000 updates_per_episode: 300 warmup_steps: 10000 @@ -65,10 +65,10 @@ loss: alpha_init: 0.1 transformer: - n_embd: 512 - n_layer: 4 - n_head: 4 - n_inner: 2048 # 4*512 + n_embd: 128 # odt 512 + n_layer: 3 # odt 4 + n_head: 1 # odt 4 + n_inner: 512 # odt 2048 # 4*512 activation: relu n_positions: 1024 resid_pdrop: 0.1 From 45cbd61d3ae00a364f14b284ca30db2f55a5ad90 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 1 Jun 2023 10:48:43 +0200 Subject: [PATCH 023/104] split config --- examples/decision_transformer/config.yaml | 75 -------------- examples/decision_transformer/dt.py | 2 +- examples/decision_transformer/online_dt.py | 109 ++------------------- examples/decision_transformer/utils.py | 35 +++---- torchrl/objectives/decision_transformer.py | 2 +- 5 files changed, 26 insertions(+), 197 deletions(-) delete mode 100644 examples/decision_transformer/config.yaml diff --git a/examples/decision_transformer/config.yaml b/examples/decision_transformer/config.yaml deleted file mode 100644 index 92af468c36b..00000000000 --- a/examples/decision_transformer/config.yaml +++ /dev/null @@ -1,75 +0,0 @@ -# Task and env -env: - env_name: HalfCheetah-v3 - env_task: "" - env_library: gym - record_video: 0 - stacked_frames: 20 - inference_context: 5 # for Hopper - n_samples_stats: 2000 - frame_skip: 1 - from_pixels: False - num_train_envs: 1 - num_eval_envs: 10 - reward_scaling: 0.001 # for r2g - noop: 1 - seed: 2 - eval_target_return: 6000 # 3600 - collect_target_return: 12000 # 7200 - total_online_frames: 1000000 - - -# Collector -collector: - async_collection: 1 - frames_per_batch: 1000 - total_frames: 1000000 - init_random_frames: 0 - collector_devices: cpu # ,cpu,cpu,cpu] - num_collectors: 1 - max_frames_per_traj: 1000 - -# logger -logger: - backend: wandb - exp_name: DT-HalfCheetah-medium-v2 - pretrain_log_interval: 500 # record interval in frames - fintune_log_interval: 1 - eval_steps: 1000 - -# Buffer -replay_buffer: - dataset: halfcheetah-medium-v2 - batch_size: 64 # odt 256 - prb: 0 - stacked_frames: 20 - buffer_prefetch: 64 - capacity: 1_000_000 - buffer_scratch_dir: "/tmp/" - device: cpu - prefetch: 3 - -# Optimization -optim: - device: cuda:0 - lr: 1.0e-4 - weight_decay: 5.0e-4 - batch_size: 64 # odt 256 - lr_scheduler: "" - pretrain_gradient_steps: 55000 - updates_per_episode: 300 - warmup_steps: 10000 - -# loss -loss: - alpha_init: 0.1 - -transformer: - n_embd: 128 # odt 512 - n_layer: 3 # odt 4 - n_head: 1 # odt 4 - n_inner: 512 # odt 2048 # 4*512 - activation: relu - n_positions: 1024 - resid_pdrop: 0.1 - attn_pdrop: 0.1 diff --git a/examples/decision_transformer/dt.py b/examples/decision_transformer/dt.py index d4e1a44e38f..745920bc376 100644 --- a/examples/decision_transformer/dt.py +++ b/examples/decision_transformer/dt.py @@ -22,7 +22,7 @@ ) -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", 
config_name="dt_config") def main(cfg: "DictConfig"): # noqa: F821 model_device = cfg.optim.device diff --git a/examples/decision_transformer/online_dt.py b/examples/decision_transformer/online_dt.py index 68cc7076cf0..add21a4cfdc 100644 --- a/examples/decision_transformer/online_dt.py +++ b/examples/decision_transformer/online_dt.py @@ -2,8 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -"""Decision Transformer Example. -This is a self-contained example of an offline Decision Transformer training script. +"""Online Decision Transformer Example. +This is a self-contained example of an Online Decision Transformer training script. The helper functions are coded in the utils.py associated with this script. """ @@ -13,28 +13,24 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type from utils import ( - # get_loc_std, - # make_collector, make_env, make_logger, make_odt_loss, make_odt_model, make_odt_optimizer, make_offline_replay_buffer, - # make_online_replay_buffer, ) -@hydra.main(config_path=".", config_name="config") +@hydra.main(config_path=".", config_name="odt_config") def main(cfg: "DictConfig"): # noqa: F821 model_device = cfg.optim.device - # loc, std = get_loc_std(cfg.replay_buffer.dataset) - test_env = make_env(cfg.env) # , loc, std) logger = make_logger(cfg.logger) - offline_buffer = make_offline_replay_buffer( + offline_buffer, obs_loc, obs_std = make_offline_replay_buffer( cfg.replay_buffer, cfg.env.reward_scaling - ) # , loc, std + ) + test_env = make_env(cfg.env, obs_loc, obs_std) inference_actor, actor = make_odt_model(cfg) policy = actor.to(model_device) @@ -56,7 +52,6 @@ def main(cfg: "DictConfig"): # noqa: F821 data = offline_buffer.sample() # loss loss_vals = loss_module(data) - # backprop transformer_loss = loss_vals["loss"] temperature_loss = loss_vals["loss_alpha"] @@ -96,98 +91,6 @@ def main(cfg: "DictConfig"): # noqa: F821 pbar.set_description( f"[Pre-Training] loss: {transformer_loss.item(): 4.4f} (init: {l0: 4.4f}), evaluation reward: {eval_reward: 4.4f} (init={r0: 4.4f})" ) - # print("\n ***Online Finetuning*** ") - # collector = make_collector(cfg, inference_policy) - # online_buffer = make_online_replay_buffer( - # offline_buffer, cfg.replay_buffer, cfg.env.reward_scaling - # ) - # collected_frames = 0 - - # pbar = tqdm.tqdm(total=cfg.env.total_online_frames) - # r0 = None - - # for j, tensordict in enumerate(collector): - # # update weights of the inference policy - # collector.update_policy_weights_() - - # episode_reward = ( - # tensordict["next", "episode_reward"][tensordict["next", "done"]] - # .mean() - # .item() - # / cfg.env.reward_scaling - # ) - # if r0 is None: - # r0 = episode_reward - - # current_frames = tensordict.numel() - # pbar.update(current_frames) - - # tensordict = tensordict.reshape(-1) - # # only used for logging - # tensordict.del_("episode_reward") - - # online_buffer.extend(tensordict.cpu().clone().detach()) - # collected_frames += current_frames - - # # optimization steps - # for _ in range(int(cfg.optim.updates_per_episode)): - # sampled_tensordict = online_buffer.sample().clone() - - # loss_vals = loss_module(sampled_tensordict) - - # # backprop - # transformer_loss = loss_vals["loss"] - # temperature_loss = loss_vals["loss_alpha"] - - # transformer_optim.zero_grad() - # torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.25) - # transformer_loss.backward() - # transformer_optim.step() - - # temperature_optim.zero_grad() - # 
temperature_loss.backward() - # temperature_optim.step() - - # scheduler.step() - - # train_target_return = ( - # tensordict["return_to_go"][:, 0].mean() / cfg.env.reward_scaling - # ) - # train_log = { - # "collect reward": episode_reward, - # "collected_frames": collected_frames, - # "collect target_return": train_target_return.item() - # / cfg.env.reward_scaling, - # } - - # for key, value in train_log.items(): - # logger.log_scalar(key, value, step=j) - - # with set_exploration_type(ExplorationType.MEAN), torch.no_grad(): - # if j % cfg.logger.fintune_log_interval == 0: - # eval_td = test_env.rollout( - # max_steps=cfg.logger.eval_steps * cfg.env.num_eval_envs, - # policy=inference_policy, - # auto_cast_to_device=True, - # ) - # eval_reward = ( - # eval_td["next", "reward"].sum(1).mean().item() / cfg.env.reward_scaling - # ) - # eval_target_return = ( - # eval_td["return_to_go"][:, 0].mean() / cfg.env.reward_scaling - # ) - # eval_log = { - # "fine-tune evaluation reward": eval_reward, - # "evaluation target_return": eval_target_return.item() - # / cfg.env.reward_scaling, - # } - # for key, value in eval_log.items(): - # logger.log_scalar(key, value, step=j) - # pbar.set_description( - # f"[Fine-Tuning] loss: {transformer_loss.item(): 4.4f} (init: {l0: 4.4f}), evaluation reward: {eval_reward: 4.4f} (init={r0: 4.4f})" - # ) - - # collector.shutdown() if __name__ == "__main__": diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index 7061ecedfff..b7e9a5fa7a7 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -44,8 +44,8 @@ def make_base_env(env_cfg): - env_library = LIBS[env_cfg.env_library] - env_name = env_cfg.env_name + env_library = LIBS[env_cfg.library] + env_name = env_cfg.name frame_skip = env_cfg.frame_skip env_kwargs = { @@ -53,7 +53,7 @@ def make_base_env(env_cfg): "frame_skip": frame_skip, } if env_library is DMControlEnv: - env_task = env_cfg.env_task + env_task = env_cfg.task env_kwargs.update({"task_name": env_task}) env = env_library(**env_kwargs) if env_cfg.noop > 1: @@ -61,7 +61,7 @@ def make_base_env(env_cfg): return env -def make_transformed_env(base_env, env_cfg, train=False): +def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False): transformed_env = TransformedEnv(base_env) if train: transformed_env.append_transform( @@ -102,9 +102,8 @@ def make_transformed_env(base_env, env_cfg, train=False): dim=-2, ) ) - loc, std = get_loc_std("hopper-medium-v2") obsnorm = ObservationNorm( - loc=loc, scale=std, in_keys="observation", standard_normal=True + loc=obs_loc, scale=obs_std, in_keys="observation", standard_normal=True ) transformed_env.append_transform(obsnorm) @@ -114,7 +113,7 @@ def make_transformed_env(base_env, env_cfg, train=False): return transformed_env -def make_parallel_env(env_cfg, train=False): +def make_parallel_env(env_cfg, obs_loc, obs_std, train=False): if train: num_envs = env_cfg.num_train_envs else: @@ -122,13 +121,15 @@ def make_parallel_env(env_cfg, train=False): env = make_transformed_env( ParallelEnv(num_envs, EnvCreator(lambda: make_base_env(env_cfg))), env_cfg, + obs_loc, + obs_std, train, ) return env -def make_env(env_cfg, train=False): - env = make_parallel_env(env_cfg, train=train) +def make_env(env_cfg, obs_loc, obs_std, train=False): + env = make_parallel_env(env_cfg, obs_loc, obs_std, train=train) return env @@ -227,7 +228,7 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling): ) # TODO: add obsnorm here - return data + return 
data, loc, std def make_online_replay_buffer(offline_buffer, rb_cfg, reward_scaling=0.001): @@ -407,8 +408,7 @@ def make_dt_loss(actor_network): return loss -def make_odt_optimizer(optim_cfg, actor_network, loss): - # Should be Lambda Optimizer +def make_odt_optimizer(optim_cfg, actor_network, loss_module): dt_optimizer = Lamb( actor_network.parameters(), @@ -421,7 +421,7 @@ def make_odt_optimizer(optim_cfg, actor_network, loss): ) log_temp_optimizer = torch.optim.Adam( - [loss.log_alpha], + [loss_module.log_alpha], lr=1e-4, betas=[0.9, 0.999], ) @@ -429,8 +429,7 @@ def make_odt_optimizer(optim_cfg, actor_network, loss): return dt_optimizer, log_temp_optimizer, scheduler -def make_dt_optimizer(optim_cfg, actor_network, loss): - # Should be Lambda Optimizer +def make_dt_optimizer(optim_cfg, actor_network): dt_optimizer = torch.optim.Adam( actor_network.parameters(), lr=optim_cfg.lr, @@ -450,9 +449,11 @@ def make_dt_optimizer(optim_cfg, actor_network, loss): def make_logger(logger_cfg): - exp_name = generate_exp_name("OnlineDecisionTransformer", logger_cfg.exp_name) + exp_name = generate_exp_name(logger_cfg.model_name, logger_cfg.exp_name) logger_cfg.exp_name = exp_name - logger = get_logger(logger_cfg.backend, logger_name="oDT", experiment_name=exp_name) + logger = get_logger( + logger_cfg.backend, logger_name=logger_cfg.model_name, experiment_name=exp_name + ) return logger diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 3afa435813c..00d09012c8b 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -101,7 +101,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: class DTLoss(LossModule): r"""TorchRL implementation of the Online Decision Transformer loss. 
- Presented in "Decision Transformer" https://arxiv.org/abs/2202.05607 + Presented in "Decision Transformer: Reinforcement Learning via Sequence Modeling" https://arxiv.org/abs/2106.01345 Args: actor_network (ProbabilisticActor): stochastic actor From 86ddc44228191e0f8d885c6d69847a0f66dfee4c Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 2 Jun 2023 14:10:15 +0200 Subject: [PATCH 024/104] fix --- test/test_cost.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_cost.py b/test/test_cost.py index dee2bbbb1a9..eb55edaa24e 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -4249,6 +4249,9 @@ def test_seq_odt(self, device): assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" +@pytest.mark.skipif( + not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" +) class TestIQL: seed = 0 From 170ab13100b48441666b6c2cbd155e89640a0cc6 Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 2 Jun 2023 14:16:41 +0200 Subject: [PATCH 025/104] fix --- test/test_cost.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_cost.py b/test/test_cost.py index eb55edaa24e..ebee3e811ee 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -4252,7 +4252,7 @@ def test_seq_odt(self, device): @pytest.mark.skipif( not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" ) -class TestIQL: +class TestIQL(LossModuleTestBase): seed = 0 def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): From 1fcbf0eaee7f76f82dcd8d1b8c4dd8268363da0f Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 2 Jun 2023 14:23:40 +0200 Subject: [PATCH 026/104] description catframes --- torchrl/envs/transforms/transforms.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 61482644fe3..ac876d7d91c 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1960,6 +1960,7 @@ class CatFrames(ObservationTransform): has to be written. Defaults to the value of `in_keys`. padding (str, optional): the padding method. One of ``"same"`` or ``"zeros"``. Defaults to ``"same"``, ie. the first value is uesd for padding. + as_inverse (bool, optional): if ``True``, the transform is applied as an inverse transform. 
Examples: >>> from torchrl.envs.libs.gym import GymEnv From 165459d26bd788755a7af078d2b91f19361cc863 Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 2 Jun 2023 14:40:42 +0200 Subject: [PATCH 027/104] add dt test --- test/test_cost.py | 126 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 124 insertions(+), 2 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 716606dee98..2ac803f989c 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -54,7 +54,11 @@ SafeSequential, WorldModelWrapper, ) -from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal +from torchrl.modules.distributions.continuous import ( + NormalParamWrapper, + TanhDelta, + TanhNormal, +) from torchrl.modules.models.model_based import ( DreamerActor, ObsDecoder, @@ -82,6 +86,7 @@ DreamerActorLoss, DreamerModelLoss, DreamerValueLoss, + DTLoss, IQLLoss, KLPENPPOLoss, OnlineDTLoss, @@ -4097,7 +4102,6 @@ def test_dreamer_value_tensordict_keys(self, device): self.tensordict_keys_test(loss_fn, default_keys=default_keys) - class TestOnlineDT: seed = 0 @@ -4238,6 +4242,124 @@ def test_seq_odt(self, device): assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" +class TestDT: + seed = 0 + + def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + # Actor + action_spec = BoundedTensorSpec( + -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) + ) + net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + module = SafeModule(net, in_keys=["observation"], out_keys=["param"]) + actor = ProbabilisticActor( + module=module, + distribution_class=TanhDelta, + in_keys=["param"], + spec=action_spec, + ) + return actor.to(device) + + def _create_mock_data_dt(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + # create a tensordict + obs = torch.randn(batch, obs_dim, device=device) + action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) + reward2go = torch.randn(batch, 1, device=device) + td = TensorDict( + batch_size=(batch,), + source={ + "observation": obs, + "action": action, + "reward2go": reward2go, + }, + device=device, + ) + return td + + def _create_seq_mock_data_dt( + self, batch=2, T=4, obs_dim=3, action_dim=4, device="cpu" + ): + # create a tensordict + obs = torch.randn(batch, T, obs_dim, device=device) + action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) + reward2go = torch.randn(batch, T, 1, device=device) + + td = TensorDict( + batch_size=(batch, T), + source={ + "observation": obs, + "reward": reward2go, + "action": action, + }, + device=device, + ) + return td + + @pytest.mark.parametrize("device", get_available_devices()) + def test_dt(self, device): + torch.manual_seed(self.seed) + td = self._create_mock_data_dt(device=device) + + actor = self._create_mock_actor(device=device) + + loss_fn = DTLoss(actor) + loss = loss_fn(td) + loss_transformer = loss["loss"] + loss_transformer.backward(retain_graph=True) + named_parameters = loss_fn.named_parameters() + + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" in name + assert "alpha" not in name + if p.grad is None: + assert "actor" not in name + assert "alpha" in name + loss_fn.zero_grad() + + sum([loss_transformer]).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) + + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) + + 
for name, p in named_parameters: + assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" + + @pytest.mark.parametrize("device", get_available_devices()) + def test_seq_dt(self, device): + torch.manual_seed(self.seed) + td = self._create_seq_mock_data_dt(device=device) + + actor = self._create_mock_actor(device=device) + + loss_fn = DTLoss(actor) + loss = loss_fn(td) + loss_transformer = loss["loss"] + loss_transformer.backward(retain_graph=True) + named_parameters = loss_fn.named_parameters() + + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" in name + assert "alpha" not in name + if p.grad is None: + assert "actor" not in name + assert "alpha" in name + loss_fn.zero_grad() + + sum([loss_transformer]).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) + + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) + + for name, p in named_parameters: + assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" + + @pytest.mark.skipif( not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" ) From 50f0aa80dfe1659c6b6f6d1c0d1f30a61307c2cb Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 2 Jun 2023 14:54:19 +0200 Subject: [PATCH 028/104] add cfg to logger --- examples/decision_transformer/dt.py | 2 +- examples/decision_transformer/online_dt.py | 2 +- examples/decision_transformer/utils.py | 12 +++++++----- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/decision_transformer/dt.py b/examples/decision_transformer/dt.py index 3d49ba3d3da..1f8a8570a46 100644 --- a/examples/decision_transformer/dt.py +++ b/examples/decision_transformer/dt.py @@ -27,7 +27,7 @@ def main(cfg: "DictConfig"): # noqa: F821 model_device = cfg.optim.device test_env = make_env(cfg.env) - logger = make_logger(cfg.logger) + logger = make_logger(cfg) offline_buffer = make_offline_replay_buffer( cfg.replay_buffer, cfg.env.reward_scaling ) diff --git a/examples/decision_transformer/online_dt.py b/examples/decision_transformer/online_dt.py index 3b6aea3b8fb..522c26c1064 100644 --- a/examples/decision_transformer/online_dt.py +++ b/examples/decision_transformer/online_dt.py @@ -26,7 +26,7 @@ def main(cfg: "DictConfig"): # noqa: F821 model_device = cfg.optim.device - logger = make_logger(cfg.logger) + logger = make_logger(cfg) offline_buffer, obs_loc, obs_std = make_offline_replay_buffer( cfg.replay_buffer, cfg.env.reward_scaling ) diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index b7e9a5fa7a7..151d9cc0f06 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -409,7 +409,6 @@ def make_dt_loss(actor_network): def make_odt_optimizer(optim_cfg, actor_network, loss_module): - dt_optimizer = Lamb( actor_network.parameters(), lr=optim_cfg.lr, @@ -448,11 +447,14 @@ def make_dt_optimizer(optim_cfg, actor_network): # --------------------- -def make_logger(logger_cfg): - exp_name = generate_exp_name(logger_cfg.model_name, logger_cfg.exp_name) - logger_cfg.exp_name = exp_name +def make_logger(cfg): + exp_name = generate_exp_name(cfg.logger.model_name, cfg.logger.exp_name) + cfg.logger.exp_name = exp_name logger = get_logger( - logger_cfg.backend, logger_name=logger_cfg.model_name, experiment_name=exp_name + cfg.logger.backend, + logger_name=cfg.logger.model_name, + experiment_name=exp_name, + 
wandb_kwargs={"config": cfg}, ) return logger From 3cc456e9ed20466deadba70d3485ba7543520265 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 8 Jun 2023 14:13:19 +0200 Subject: [PATCH 029/104] take off detach --- torchrl/objectives/decision_transformer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 00d09012c8b..6907ea3e99f 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -76,10 +76,10 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Compute the loss for the Online Decision Transformer.""" # extract action targets - target_actions = torch.clone(tensordict["action"].detach()).to(self.device) + target_actions = tensordict["action"].detach() action_dist = self.actor_network.get_dist( - tensordict.to(self.device), params=self.actor_network_params + tensordict, params=self.actor_network_params ) loss_log_likelihood = action_dist.log_prob(target_actions).mean() @@ -91,9 +91,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: out = { "loss": loss, "loss_log_likelihood": -loss_log_likelihood, - "entropy": entropy, + "entropy": entropy.detach(), "loss_alpha": loss_alpha, - "alpha": self.log_alpha.exp(), + "alpha": self.log_alpha.exp().detach(), } return TensorDict(out, []) @@ -133,10 +133,10 @@ def device(self) -> torch.device: def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Compute the loss for the Online Decision Transformer.""" # extract action targets - target_actions = torch.clone(tensordict["action"].detach()).to(self.device) + target_actions = tensordict["action"].detach() pred_actions = self.actor_network( - tensordict.to(self.device), params=self.actor_network_params + tensordict, params=self.actor_network_params ).get("action") loss = distance_loss( pred_actions, From 04974495c8287623716c79a534f264f5d79ea922 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 12 Jun 2023 10:33:33 +0200 Subject: [PATCH 030/104] add loss to docs --- docs/source/reference/objectives.rst | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index ed2d5c3cff7..ad653d001d6 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -136,6 +136,24 @@ CQL CQLLoss +DT +---- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + DTLoss + +OnlineDT +---- + +.. 
autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + OnlineDTLoss + TD3 ---- From b24a7f8639d33f07652b41a326a0e2cc0f723b01 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 12 Jun 2023 10:35:29 +0200 Subject: [PATCH 031/104] update proof_env creation --- examples/decision_transformer/utils.py | 8 +++++-- torchrl/objectives/decision_transformer.py | 28 +++++----------------- 2 files changed, 12 insertions(+), 24 deletions(-) diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index 151d9cc0f06..af63d7a6612 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -275,7 +275,9 @@ def make_online_replay_buffer(offline_buffer, rb_cfg, reward_scaling=0.001): def make_odt_model(cfg): env_cfg = cfg.env - proof_environment = make_transformed_env(make_base_env(env_cfg), env_cfg) + proof_environment = make_transformed_env( + make_base_env(env_cfg), env_cfg, obs_loc=0, obs_std=1 + ) action_spec = proof_environment.action_spec for key, value in proof_environment.observation_spec.items(): @@ -335,7 +337,9 @@ def make_odt_model(cfg): def make_dt_model(cfg): env_cfg = cfg.env - proof_environment = make_transformed_env(make_base_env(env_cfg), env_cfg) + proof_environment = make_transformed_env( + make_base_env(env_cfg), env_cfg, obs_loc=0, obs_std=1 + ) action_spec = proof_environment.action_spec for key, value in proof_environment.observation_spec.items(): diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 6907ea3e99f..fc372fd2e27 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -11,15 +11,16 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torch import distributions as d from torchrl.modules import ProbabilisticActor -from torchrl.objectives.utils import distance_loss -from .common import LossModule +from torchrl.objectives.common import LossModule +from torchrl.objectives.utils import distance_loss class OnlineDTLoss(LossModule): r"""TorchRL implementation of the Online Decision Transformer loss. - Presented in "Online Decision Transformer" https://arxiv.org/abs/2202.05607 + Presented in `"Online Decision Transformer" ` + Args: actor_network (ProbabilisticActor): stochastic actor alpha_init (float): initial value of the temperature parameter @@ -58,15 +59,6 @@ def __init__( ) self.samples_mc_entropy = samples_mc_entropy - @property - def device(self) -> torch.device: - for p in self.parameters(): - return p.device - raise RuntimeError( - "At least one of the networks of OnlineDTLoss must have trainable " - "parameters." - ) - def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: x = dist.rsample((self.samples_mc_entropy,)) log_p = dist.log_prob(x) @@ -101,7 +93,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: class DTLoss(LossModule): r"""TorchRL implementation of the Online Decision Transformer loss. - Presented in "Decision Transformer: Reinforcement Learning via Sequence Modeling" https://arxiv.org/abs/2106.01345 + Presented in `"Decision Transformer: Reinforcement Learning via Sequence Modeling" ` + Args: actor_network (ProbabilisticActor): stochastic actor @@ -121,15 +114,6 @@ def __init__( funs_to_decorate=["forward"], ) - @property - def device(self) -> torch.device: - for p in self.parameters(): - return p.device - raise RuntimeError( - "At least one of the networks of OnlineDTLoss must have trainable " - "parameters." 
- ) - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Compute the loss for the Online Decision Transformer.""" # extract action targets From 8e04add8bf0d6aed2865c8d7f9011f5b8d321c00 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 12 Jun 2023 10:36:13 +0200 Subject: [PATCH 032/104] move batch to device --- examples/decision_transformer/online_dt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/decision_transformer/online_dt.py b/examples/decision_transformer/online_dt.py index 522c26c1064..e495df3ba8e 100644 --- a/examples/decision_transformer/online_dt.py +++ b/examples/decision_transformer/online_dt.py @@ -57,7 +57,7 @@ def main(cfg: "DictConfig"): # noqa: F821 pbar.update(i) data = offline_buffer.sample() # loss - loss_vals = loss_module(data) + loss_vals = loss_module(data.to(model_device)) transformer_loss = loss_vals["loss"] temperature_loss = loss_vals["loss_alpha"] From 2ad7af549295ad4ff1a24c69ce61d1a64deb26d5 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 12 Jun 2023 11:27:23 +0200 Subject: [PATCH 033/104] remove gpt2model and import directly from hf --- .../modules/models/decision_transformer.py | 2 +- torchrl/modules/models/gpt2_transformer.py | 504 ------------------ 2 files changed, 1 insertion(+), 505 deletions(-) delete mode 100644 torchrl/modules/models/gpt2_transformer.py diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py index 0528d01d993..10d3c56c06c 100644 --- a/torchrl/modules/models/decision_transformer.py +++ b/torchrl/modules/models/decision_transformer.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import transformers -from torchrl.modules.models.gpt2_transformer import GPT2Model +from transformers.models.gpt2.modeling_gpt2 import GPT2Model class DecisionTransformer(nn.Module): diff --git a/torchrl/modules/models/gpt2_transformer.py b/torchrl/modules/models/gpt2_transformer.py deleted file mode 100644 index b7b6c60aeb5..00000000000 --- a/torchrl/modules/models/gpt2_transformer.py +++ /dev/null @@ -1,504 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""PyTorch OpenAI GPT-2 model.""" - -import warnings - -from typing import Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from torch import nn - -from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions - -from transformers.utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, -) -from transformers.utils.model_parallel_utils import assert_device_map, get_device_map - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "gpt2" -_CONFIG_FOR_DOC = "GPT2Config" - -from transformers import GPT2PreTrainedModel -from transformers.models.gpt2.modeling_gpt2 import GPT2Block - - -GPT2_START_DOCSTRING = r""" - - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`GPT2Config`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -GPT2_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input - sequence tokens in the vocabulary. - - If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as - `input_ids`. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): - Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see - `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have - their past given to this model should not be passed as `input_ids` as they have already been computed. - attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for - `past_key_values`. In other words, the `attention_mask` always has to have the length: - `len(past_key_values) + len(input_ids)` - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. 
- - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - - If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see - `past_key_values`). - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" -PARALLELIZE_DOCSTRING = r""" - This is an experimental feature and is a subject to change at a moment's notice. - - Uses a device map to distribute attention modules of the model across several devices. If no device map is given, - it will evenly distribute blocks across all devices. - - Args: - device_map (`Dict[int, list]`, optional, defaults to None): - A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always - automatically mapped to the first device (for esoteric reasons). That means that the first device should - have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the - following number of attention modules: - - - gpt2: 12 - - gpt2-medium: 24 - - gpt2-large: 36 - - gpt2-xl: 48 - - Example: - - ```python - # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: - model = GPT2LMHeadModel.from_pretrained("gpt2-xl") - device_map = { - 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], - 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], - 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], - 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], - } - model.parallelize(device_map) - ``` -""" -DEPARALLELIZE_DOCSTRING = r""" - Moves the model to cpu from a model parallel state. 
- - Example: - - ```python - # On a 4 GPU machine with gpt2-large: - model = GPT2LMHeadModel.from_pretrained("gpt2-large") - device_map = { - 0: [0, 1, 2, 3, 4, 5, 6, 7], - 1: [8, 9, 10, 11, 12, 13, 14, 15], - 2: [16, 17, 18, 19, 20, 21, 22, 23], - 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], - } - model.parallelize(device_map) # Splits the model across several devices - model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() - ``` -""" - - -@add_start_docstrings( - "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", - GPT2_START_DOCSTRING, -) -class GPT2Model(GPT2PreTrainedModel): - """GPT2 Model transformer.""" - - _keys_to_ignore_on_load_missing = ["attn.masked_bias"] - - def __init__(self, config): - super().__init__(config) - - self.embed_dim = config.hidden_size - - self.wte = nn.Embedding(config.vocab_size, self.embed_dim) - # self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) - - self.drop = nn.Dropout(config.embd_pdrop) - self.h = nn.ModuleList( - [GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)] - ) - self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) - - # Model parallel - self.model_parallel = False - self.device_map = None - self.gradient_checkpointing = False - - # Initialize weights and apply final processing - self.post_init() - - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def parallelize(self, device_map=None): - # Check validity of device_map - warnings.warn( - "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your" - " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" - " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1," - " ...}", - FutureWarning, - ) - self.device_map = ( - get_device_map(len(self.h), range(torch.cuda.device_count())) - if device_map is None - else device_map - ) - assert_device_map(self.device_map, len(self.h)) - self.model_parallel = True - self.first_device = ( - "cpu" - if "cpu" in self.device_map.keys() - else "cuda:" + str(min(self.device_map.keys())) - ) - self.last_device = "cuda:" + str(max(self.device_map.keys())) - self.wte = self.wte.to(self.first_device) - self.wpe = self.wpe.to(self.first_device) - # Load onto devices - for k, v in self.device_map.items(): - for block in v: - cuda_device = "cuda:" + str(k) - self.h[block] = self.h[block].to(cuda_device) - # ln_f to last - self.ln_f = self.ln_f.to(self.last_device) - - @add_start_docstrings(DEPARALLELIZE_DOCSTRING) - def deparallelize(self): - warnings.warn( - "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", - FutureWarning, - ) - self.model_parallel = False - self.device_map = None - self.first_device = "cpu" - self.last_device = "cpu" - self.wte = self.wte.to("cpu") - self.wpe = self.wpe.to("cpu") - for index in range(len(self.h)): - self.h[index] = self.h[index].to("cpu") - self.ln_f = self.ln_f.to("cpu") - torch.cuda.empty_cache() - - def get_input_embeddings(self): - return self.wte - - def set_input_embeddings(self, new_embeddings): - self.wte = new_embeddings - - def _prune_heads(self, heads_to_prune): - """Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}.""" - for layer, heads in heads_to_prune.items(): - self.h[layer].attn.prune_heads(heads) - - @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=BaseModelOutputWithPastAndCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" - ) - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if position_ids is not None: - position_ids = position_ids.view(-1, input_shape[-1]) - - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) - if position_ids is None: - position_ids = torch.arange( - past_length, - input_shape[-1] + past_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - - # GPT2Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. 
- # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.add_cross_attention and encoder_hidden_states is not None: - ( - encoder_batch_size, - encoder_sequence_length, - _, - ) = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # head_mask has shape n_layer x batch x n_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - # position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds # + position_embeds - - if token_type_ids is not None: - token_type_embeds = self.wte(token_type_ids) - hidden_states = hidden_states + token_type_embeds - - hidden_states = self.drop(hidden_states) - - output_shape = input_shape + (hidden_states.size(-1),) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = ( - () if output_attentions and self.config.add_cross_attention else None - ) - all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure layer_past is on same device as hidden_states (might not be correct) - if layer_past is not None: - layer_past = tuple( - past_state.to(hidden_states.device) for past_state in layer_past - ) - # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) - if isinstance(head_mask, torch.Tensor): - head_mask = head_mask.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) 
- - if output_attentions: - all_self_attentions = all_self_attentions + ( - outputs[2 if use_cache else 1], - ) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + ( - outputs[3 if use_cache else 2], - ) - - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) - - hidden_states = self.ln_f(hidden_states) - - hidden_states = hidden_states.view(output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - presents, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) From aeccb22b862972c6682c7a775191f8089490839a Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 12 Jun 2023 11:27:53 +0200 Subject: [PATCH 034/104] update docstring --- torchrl/modules/tensordict_module/actors.py | 103 ++++++++++---------- 1 file changed, 51 insertions(+), 52 deletions(-) diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 8b1c4ce0d88..3271d9ad219 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -1590,63 +1590,62 @@ class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper): Args: policy (TensorDictModule): The policy module that takes in observations and produces an action value + + Keyword Args: inference_context (int): The number of previous actions that will not be masked in the context. For example for an observation input of shape [batch_size, context, obs_dim] with context=20 and inference_context=5, the first 15 entries - of the context will be masked. - observation_key (str): The key of the observation in the input TensorDict - action_key (str): The key of the action in the input TensorDict - return_to_go_key (str): The key of the return to go in the input TensorDict + of the context will be masked. Defaults to 5. + observation_key (str): The key of the observation in the input TensorDict, defaults to "observation". + action_key (str): The key of the action in the input TensorDict, defaults to "action". + return_to_go_key (str): The key of the return to go in the input TensorDict, defaults to "return_to_go". spec (Optional[TensorSpec]): The spec of the input TensorDict. If None, it will be inferred from the policy module. Examples: - >>> import torch - >>> from tensordict import TensorDict - >>> from tensordict.nn import TensorDictModule - >>> from torchrl.modules import ( - ... ProbabilisticActor, - ... DTActor, - ... TanhNormal, - ... DecisionTransformerInferenceWrapper, - ... 
) - - >>> actor_module = TensorDictModule( - DTActor(state_dim=4, action_dim=2), - in_keys=in_keys, - out_keys=[ - "loc", - "scale",]) - >>> dist_class = TanhNormal - >>> dist_kwargs = { - "min": -1.0, - "max": 1.0, - "tanh_loc": False, - } - >>> actor = ProbabilisticActor( - in_keys=["loc", "scale"], - out_keys=["action", "log_prob"], - module=actor_module, - distribution_class=dist_class, - distribution_kwargs=dist_kwargs) - - >>> inference_actor = DecisionTransformerInferenceWrapper(actor) - >>> print(inference_actor) - >>> sequence_length = 20 - >>> td = TensorDict({"observation": torch.randn(1, sequence_length, 4), - "action": torch.randn(1, sequence_length, 2), - "return_to_go": torch.randn(1, sequence_length, 1)}, [1,]) - - >>> print(inference_actor(td.clone())) - TensorDict( - fields={ - action: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False), - loc: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False), - observation: Tensor(shape=torch.Size([1, 20, 4]), device=cpu, dtype=torch.float32, is_shared=False), - sample_log_prob: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), - scale: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False), - return_to_go: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([1]), - device=None, - is_shared=False) + >>> import torch + >>> from tensordict import TensorDict + >>> from tensordict.nn import TensorDictModule + >>> from torchrl.modules import ( + ... ProbabilisticActor, + ... DTActor, + ... TanhNormal, + ... DecisionTransformerInferenceWrapper, + ... ) + >>> actor_module = TensorDictModule( + DTActor(state_dim=4, action_dim=2), + in_keys=in_keys, + out_keys=[ + "loc", + "scale",]) + >>> dist_class = TanhNormal + >>> dist_kwargs = { + "min": -1.0, + "max": 1.0, + "tanh_loc": False, + } + >>> actor = ProbabilisticActor( + in_keys=["loc", "scale"], + out_keys=["action", "log_prob"], + module=actor_module, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs) + >>> inference_actor = DecisionTransformerInferenceWrapper(actor) + >>> print(inference_actor) + >>> sequence_length = 20 + >>> td = TensorDict({"observation": torch.randn(1, sequence_length, 4), + "action": torch.randn(1, sequence_length, 2), + "return_to_go": torch.randn(1, sequence_length, 1)}, [1,]) + >>> print(inference_actor(td.clone())) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False), + loc: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False), + observation: Tensor(shape=torch.Size([1, 20, 4]), device=cpu, dtype=torch.float32, is_shared=False), + sample_log_prob: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + scale: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False), + return_to_go: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([1]), + device=None, + is_shared=False) """ def __init__( From 2414e9b8f18c592d2b3f9cdc5d9ce0169f09fb4e Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 12 Jun 2023 11:43:32 +0200 Subject: [PATCH 035/104] update actor docstring --- torchrl/modules/models/models.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 26619554017..24914cebbaf 
100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -1143,9 +1143,13 @@ def forward( class OnlineDTActor(nn.Module): """Online Decision Transformer Actor class. - Presented in "Online Decision Transformer", - https://arxiv.org/abs/2202.05607.pdf + Actor class for the Online Decision Transformer to sample actions from gaussian distribution as presented inresented in `"Online Decision Transformer" `. + Args: + state_dim (int): state dimension. + action_dim (int): action dimension. + transformer_config (Dict): config for the GPT2 transformer. + device (Optional[DEVICE_TYPING], optional): device to use. Defaults to None. """ @@ -1204,10 +1208,14 @@ def forward( class DTActor(nn.Module): """Decision Transformer Actor class. - Presented in "Decision Transformer", - https://arxiv.org/abs/2202.05607.pdf + Actor class for the Decision Transformer to output deterministic action as presented in `"Decision Transformer" `. + Args: + state_dim (int): state dimension. + action_dim (int): action dimension. + transformer_config (Dict): config for the GPT2 transformer. + device (Optional[DEVICE_TYPING], optional): device to use. Defaults to None. """ def __init__( From b03f3fe10c93c8071262cf88441ea101920985ca Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 12 Jun 2023 13:37:08 +0200 Subject: [PATCH 036/104] add dispach, in-out-keys --- examples/decision_transformer/odt_config.yaml | 1 + examples/decision_transformer/utils.py | 3 +- torchrl/modules/tensordict_module/actors.py | 9 +- torchrl/objectives/decision_transformer.py | 209 +++++++++++++++++- 4 files changed, 208 insertions(+), 14 deletions(-) diff --git a/examples/decision_transformer/odt_config.yaml b/examples/decision_transformer/odt_config.yaml index 76329210a0d..c55b7c87e49 100644 --- a/examples/decision_transformer/odt_config.yaml +++ b/examples/decision_transformer/odt_config.yaml @@ -52,6 +52,7 @@ optim: # loss loss: alpha_init: 0.1 + target_entropy: auto transformer: n_embd: 512 diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index af63d7a6612..bbb26fff751 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -400,7 +400,8 @@ def make_dt_model(cfg): def make_odt_loss(loss_cfg, actor_network): loss = OnlineDTLoss( actor_network, - loss_cfg.alpha_init, + alpha_init=loss_cfg.alpha_init, + target_entropy=loss_cfg.target_entropy, ) return loss diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 3271d9ad219..d6e33b4e4c0 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -1719,9 +1719,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: # forward pass tensordict = self.td_module.forward(tensordict) # get last action prediciton - out_action = tensordict.get(self.action_key)[:, -1] + out_action = tensordict.get(self.action_key) + idx = (slice(None),) * tensordict.ndim + (-1,) + out_action = out_action[idx] tensordict.set(self.action_key, out_action) - out_rtg = tensordict.get(self.return_to_go_key)[:, -1] + # out_rtg = tensordict.get(self.return_to_go_key)[:, -1] + out_rtg = tensordict.get(self.return_to_go_key) + idx = (slice(None),) * tensordict.ndim + (-1,) + out_rtg = out_rtg[idx] tensordict.set(self.return_to_go_key, out_rtg) # set unmasked observation tensordict.set( diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 
fc372fd2e27..a6775b90e36 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -4,11 +4,16 @@ # LICENSE file in the root directory of this source tree. import math +from dataclasses import dataclass +from typing import Union import numpy as np import torch +from tensordict.nn import dispatch from tensordict.tensordict import TensorDict, TensorDictBase +from tensordict.utils import NestedKey + from torch import distributions as d from torchrl.modules import ProbabilisticActor @@ -23,17 +28,62 @@ class OnlineDTLoss(LossModule): Args: actor_network (ProbabilisticActor): stochastic actor - alpha_init (float): initial value of the temperature parameter + alpha_init (float, optional): initial entropy multiplier. + Default is 1.0. + min_alpha (float, optional): min value of alpha. + Default is None (no minimum value). + max_alpha (float, optional): max value of alpha. + Default is None (no maximum value). + fixed_alpha (bool, optional): if ``True``, alpha will be fixed to its + initial value. Otherwise, alpha will be optimized to + match the 'target_entropy' value. + Default is ``False``. + target_entropy (float or str, optional): Target entropy for the + stochastic policy. Default is "auto", where target entropy is + computed as :obj:`-prod(n_actions)`. samples_mc_entropy (int): number of samples to estimate the entropy """ + @dataclass + class _AcceptedKeys: + """Maintains default values for all configurable tensordict keys. + + This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their + default values. + + Attributes: + action (NestedKey): The input tensordict key where the action is expected. + Defaults to ``"action"``. + observation (NestedKey): The input tensordict key where the observation is expected. + Defaults to ``"observation"``. + return_to_go (NestedKey): The input tensordict key where the return_to_go is expected. + Defaults to ``"return_to_go"``. + done (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is done. Will be used for the underlying value estimator. + Defaults to ``"done"``. 
+ """ + + action: NestedKey = "action" + observation: NestedKey = "observation" + return_to_go: NestedKey = "return_to_go" + done: NestedKey = "done" + + default_keys = _AcceptedKeys() + def __init__( self, actor_network: ProbabilisticActor, + *, alpha_init: float = 1.0, + min_alpha: float = None, + max_alpha: float = None, + fixed_alpha: bool = False, + target_entropy: Union[str, float] = "auto", samples_mc_entropy: int = 1, ) -> None: + self._in_keys = None + self._out_keys = None super().__init__() # Actor Network @@ -47,17 +97,91 @@ def __init__( device = next(self.parameters()).device except AttributeError: device = torch.device("cpu") + self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) - self.register_parameter( - "log_alpha", - torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), - ) + if bool(min_alpha) ^ bool(max_alpha): + min_alpha = min_alpha if min_alpha else 0.0 + if max_alpha == 0: + raise ValueError("max_alpha must be either None or greater than 0.") + max_alpha = max_alpha if max_alpha else 1e9 + if min_alpha: + self.register_buffer( + "min_log_alpha", torch.tensor(min_alpha, device=device).log() + ) + else: + self.min_log_alpha = None + if max_alpha: + self.register_buffer( + "max_log_alpha", torch.tensor(max_alpha, device=device).log() + ) + else: + self.max_log_alpha = None + self.fixed_alpha = fixed_alpha + if fixed_alpha: + self.register_buffer( + "log_alpha", torch.tensor(math.log(alpha_init), device=device) + ) + else: + self.register_parameter( + "log_alpha", + torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), + ) - target_entropy = -float(np.prod(actor_network.spec["action"].shape)) + if target_entropy == "auto": + if actor_network.spec is None: + raise RuntimeError( + "Cannot infer the dimensionality of the action. Consider providing " + "the target entropy explicitely or provide the spec of the " + "action tensor in the actor network." 
+ ) + target_entropy = -float(np.prod(actor_network.spec["action"].shape)) self.register_buffer( "target_entropy", torch.tensor(target_entropy, device=device) ) + self.samples_mc_entropy = samples_mc_entropy + self._set_in_keys() + + def _set_in_keys(self): + keys = [ + self.tensor_keys.action, + ("next", self.tensor_keys.return_to_go), + ("next", self.tensor_keys.done), + *self.tensor_keys.action, + *[("next", key) for key in self.tensor_keys.action], + *self.tensor_keys.observation, + ] + + self._in_keys = list(set(keys)) + + @property + def alpha(self): + if self.min_log_alpha is not None: + self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) + with torch.no_grad(): + alpha = self.log_alpha.exp() + return alpha + + @property + def in_keys(self): + if self._in_keys is None: + self._set_in_keys() + return self._in_keys + + @in_keys.setter + def in_keys(self, values): + self._in_keys = values + + @property + def out_keys(self): + if self._out_keys is None: + keys = ["loss", "loss_log_likelihood", "loss_alpha", "alpha", "entropy"] + self._out_keys = keys + return self._out_keys + + @out_keys.setter + def out_keys(self, values): + self._out_keys = values def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: x = dist.rsample((self.samples_mc_entropy,)) @@ -65,10 +189,11 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: # log_p: (batch_size, context_len, return -log_p.mean(axis=0) + @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Compute the loss for the Online Decision Transformer.""" # extract action targets - target_actions = tensordict["action"].detach() + target_actions = tensordict.get(self.tensor_keys.action).detach() action_dist = self.actor_network.get_dist( tensordict, params=self.actor_network_params @@ -76,7 +201,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: loss_log_likelihood = action_dist.log_prob(target_actions).mean() entropy = self.get_entropy_bonus(action_dist).mean() - loss = -(loss_log_likelihood + self.log_alpha.exp().detach() * entropy) + loss = -(loss_log_likelihood + self.alpha.detach() * entropy) loss_alpha = self.log_alpha.exp() * (entropy - self.target_entropy).detach() @@ -85,7 +210,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: "loss_log_likelihood": -loss_log_likelihood, "entropy": entropy.detach(), "loss_alpha": loss_alpha, - "alpha": self.log_alpha.exp().detach(), + "alpha": self.alpha.detach(), } return TensorDict(out, []) @@ -100,10 +225,38 @@ class DTLoss(LossModule): """ + @dataclass + class _AcceptedKeys: + """Maintains default values for all configurable tensordict keys. + + This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their + default values. + + Attributes: + action (NestedKey): The input tensordict key where the action is expected. + Defaults to ``"action"``. + observation (NestedKey): The input tensordict key where the observation is expected. + Defaults to ``"observation"``. + return_to_go (NestedKey): The input tensordict key where the return_to_go is expected. + Defaults to ``"return_to_go"``. + done (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is done. Will be used for the underlying value estimator. + Defaults to ``"done"``. 
+ """ + + action: NestedKey = "action" + observation: NestedKey = "observation" + return_to_go: NestedKey = "return_to_go" + done: NestedKey = "done" + + default_keys = _AcceptedKeys() + def __init__( self, actor_network: ProbabilisticActor, ) -> None: + self._in_keys = None + self._out_keys = None super().__init__() # Actor Network @@ -114,14 +267,48 @@ def __init__( funs_to_decorate=["forward"], ) + def _set_in_keys(self): + keys = [ + self.tensor_keys.action, + ("next", self.tensor_keys.return_to_go), + ("next", self.tensor_keys.done), + *self.tensor_keys.action, + *[("next", key) for key in self.tensor_keys.action], + *self.tensor_keys.observation, + ] + + self._in_keys = list(set(keys)) + + @property + def in_keys(self): + if self._in_keys is None: + self._set_in_keys() + return self._in_keys + + @in_keys.setter + def in_keys(self, values): + self._in_keys = values + + @property + def out_keys(self): + if self._out_keys is None: + keys = ["loss"] + self._out_keys = keys + return self._out_keys + + @out_keys.setter + def out_keys(self, values): + self._out_keys = values + + @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Compute the loss for the Online Decision Transformer.""" # extract action targets - target_actions = tensordict["action"].detach() + target_actions = tensordict.get(self.tensor_keys.action).detach() pred_actions = self.actor_network( tensordict, params=self.actor_network_params - ).get("action") + ).get(self.tensor_keys.action) loss = distance_loss( pred_actions, target_actions, From e5c4575e216f61f583060c46d8ac89beca1ca434 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 12 Jun 2023 16:03:09 +0200 Subject: [PATCH 037/104] update inference actor inputs --- examples/decision_transformer/dt.py | 18 +++++++++++------- examples/decision_transformer/online_dt.py | 9 +++++++-- examples/decision_transformer/utils.py | 13 +++---------- torchrl/modules/tensordict_module/actors.py | 20 +++++++++----------- 4 files changed, 30 insertions(+), 30 deletions(-) diff --git a/examples/decision_transformer/dt.py b/examples/decision_transformer/dt.py index 1f8a8570a46..5816022a42f 100644 --- a/examples/decision_transformer/dt.py +++ b/examples/decision_transformer/dt.py @@ -11,6 +11,7 @@ import torch import tqdm from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper from utils import ( make_dt_loss, @@ -26,18 +27,21 @@ def main(cfg: "DictConfig"): # noqa: F821 model_device = cfg.optim.device - test_env = make_env(cfg.env) logger = make_logger(cfg) - offline_buffer = make_offline_replay_buffer( + offline_buffer, obs_loc, obs_std = make_offline_replay_buffer( cfg.replay_buffer, cfg.env.reward_scaling ) - - inference_actor, actor = make_dt_model(cfg) + test_env = make_env(cfg.env, obs_loc, obs_std) + actor = make_dt_model(cfg) policy = actor.to(model_device) - inference_policy = inference_actor.to(model_device) loss_module = make_dt_loss(actor) - transformer_optim, scheduler = make_dt_optimizer(cfg.optim, policy, loss_module) + transformer_optim, scheduler = make_dt_optimizer(cfg.optim, policy) + inference_policy = DecisionTransformerInferenceWrapper( + policy=policy, + loss_module=loss_module, + inference_context=cfg.env.inference_context, + ).to(model_device) pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps) @@ -56,7 +60,7 @@ def main(cfg: "DictConfig"): # noqa: F821 pbar.update(i) data = offline_buffer.sample() # loss - loss_vals = loss_module(data) + 
loss_vals = loss_module(data.to(model_device)) # backprop transformer_loss = loss_vals["loss"] diff --git a/examples/decision_transformer/online_dt.py b/examples/decision_transformer/online_dt.py index e495df3ba8e..d61e4535a8b 100644 --- a/examples/decision_transformer/online_dt.py +++ b/examples/decision_transformer/online_dt.py @@ -11,6 +11,7 @@ import torch import tqdm from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper from utils import ( make_env, @@ -32,14 +33,18 @@ def main(cfg: "DictConfig"): # noqa: F821 ) test_env = make_env(cfg.env, obs_loc, obs_std) - inference_actor, actor = make_odt_model(cfg) + actor = make_odt_model(cfg) policy = actor.to(model_device) - inference_policy = inference_actor.to(model_device) loss_module = make_odt_loss(cfg.loss, actor) transformer_optim, temperature_optim, scheduler = make_odt_optimizer( cfg.optim, policy, loss_module ) + inference_policy = DecisionTransformerInferenceWrapper( + policy=policy, + loss_module=loss_module, + inference_context=cfg.env.inference_context, + ).to(model_device) pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps) diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index bbb26fff751..7fda0dfce5d 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -32,7 +32,7 @@ TanhDelta, TanhNormal, ) -from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper + from torchrl.objectives import DTLoss, OnlineDTLoss from torchrl.record.loggers import generate_exp_name, get_logger from torchrl.trainers.helpers.envs import LIBS @@ -328,11 +328,7 @@ def make_odt_model(cfg): td["action"] = td["next", "action"] actor(td) - inference_actor = DecisionTransformerInferenceWrapper( - actor, - inference_context=cfg.env.inference_context, - ) - return inference_actor, actor + return actor def make_dt_model(cfg): @@ -386,10 +382,7 @@ def make_dt_model(cfg): td["action"] = td["next", "action"] actor(td) - inference_actor = DecisionTransformerInferenceWrapper( - actor, - ) - return inference_actor, actor + return actor # ==================================================================== diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index d6e33b4e4c0..25ded51cdc7 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -1652,29 +1652,27 @@ def __init__( self, policy: TensorDictModule, *, + loss_module: TensorDictModule, inference_context: int = 5, - observation_key: str = "observation", - action_key: str = "action", - return_to_go_key: str = "return_to_go", spec: Optional[TensorSpec] = None, ): super().__init__(policy) - self.observation_key = observation_key - self.action_key = action_key - self.return_to_go_key = return_to_go_key + self.observation_key = loss_module.tensor_keys.observation + self.action_key = loss_module.tensor_keys.action + self.return_to_go_key = loss_module.tensor_keys.return_to_go self.inference_context = inference_context if spec is not None: if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: - spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) + spec = CompositeSpec({self.action_key: spec}, shape=spec.shape[:-1]) self._spec = spec elif hasattr(self.td_module, "_spec"): self._spec = self.td_module._spec.clone() - if action_key not in self._spec.keys(): - self._spec[action_key] 
= None + if self.action_key not in self._spec.keys(): + self._spec[self.action_key] = None elif hasattr(self.td_module, "spec"): self._spec = self.td_module.spec.clone() - if action_key not in self._spec.keys(): - self._spec[action_key] = None + if self.action_key not in self._spec.keys(): + self._spec[self.action_key] = None else: self._spec = CompositeSpec({key: None for key in policy.out_keys}) From a5213ceb24f965ddcc861bb6468d7d110a766828 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 12 Jun 2023 16:27:35 +0200 Subject: [PATCH 038/104] add inference wrapper to docs --- docs/source/reference/modules.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 0325624b6fc..79224bc0658 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -261,7 +261,7 @@ without shared parameters. It is mainly intended as a replacement for ActorCriticWrapper ActorValueOperator ValueOperator - + DecisionTransformerInferenceWrapper Other modules ~~~~~~~~~~~~~ From 1f9f885314343a63ba5d832c2acf48a70f6063ad Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 27 Jun 2023 11:51:48 +0200 Subject: [PATCH 039/104] fix _data --- torchrl/data/replay_buffers/replay_buffers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index e6a4c66432f..0035769efc7 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -272,7 +272,7 @@ def extend(self, data: Sequence) -> torch.Tensor: Indices of the data added to the replay buffer. """ if self._transform is not None and is_tensor_collection(data): - data = self._transform.inv(data.get("_data")) + data = self._transform.inv(data) # test elif self._transform is not None and len(self._transform): data = self._transform.inv(data) return self._extend(data) From 792d35ced3dc9e4b95c19dab973ed52c7eb737aa Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 27 Jun 2023 13:10:49 +0200 Subject: [PATCH 040/104] extract lamb opti --- examples/decision_transformer/lamb.py | 208 +++++++++++++++++++++++++ examples/decision_transformer/utils.py | 163 +------------------ 2 files changed, 209 insertions(+), 162 deletions(-) create mode 100644 examples/decision_transformer/lamb.py diff --git a/examples/decision_transformer/lamb.py b/examples/decision_transformer/lamb.py new file mode 100644 index 00000000000..a3324614051 --- /dev/null +++ b/examples/decision_transformer/lamb.py @@ -0,0 +1,208 @@ +""" PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb +This optimizer code was adapted from the following (starting with latest) +* https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py +* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py +* https://github.com/cybertronai/pytorch-lamb +Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is +similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX. +In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU. +Original copyrights for above sources are below. +Modifications Copyright 2021 Ross Wightman +""" +# Copyright (c) 2021, Habana Labs Ltd. All rights reserved. 
+ +# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# MIT License +# +# Copyright (c) 2019 cybertronai +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +import math + +import torch +from torch.optim import Optimizer + + +class Lamb(Optimizer): + """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB + reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py + LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its norm. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + grad_averaging (bool, optional): whether apply (1-beta2) to grad when + calculating running averages of gradient. (default: True) + max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0) + trust_clip (bool): enable LAMBC trust ratio clipping (default: False) + always_adapt (boolean, optional): Apply adaptive learning rate to 0.0 + weight decay parameter (default: False) + .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=0.01, + grad_averaging=True, + max_grad_norm=1.0, + trust_clip=False, + always_adapt=False, + ): + defaults = { + "lr": lr, + "bias_correction": bias_correction, + "betas": betas, + "eps": eps, + "weight_decay": weight_decay, + "grad_averaging": grad_averaging, + "max_grad_norm": max_grad_norm, + "trust_clip": trust_clip, + "always_adapt": always_adapt, + } + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + device = self.param_groups[0]["params"][0].device + one_tensor = torch.tensor( + 1.0, device=device + ) # because torch.where doesn't handle scalars correctly + global_grad_norm = torch.zeros(1, device=device) + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "Lamb does not support sparse gradients, consider SparseAdam instad." + ) + global_grad_norm.add_(grad.pow(2).sum()) + + global_grad_norm = torch.sqrt(global_grad_norm) + # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes + # scalar types properly https://github.com/pytorch/pytorch/issues/9190 + max_grad_norm = torch.tensor(self.defaults["max_grad_norm"], device=device) + clip_global_grad_norm = torch.where( + global_grad_norm > max_grad_norm, + global_grad_norm / max_grad_norm, + one_tensor, + ) + + for group in self.param_groups: + bias_correction = 1 if group["bias_correction"] else 0 + beta1, beta2 = group["betas"] + grad_averaging = 1 if group["grad_averaging"] else 0 + beta3 = 1 - beta1 if grad_averaging else 1.0 + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel + if "step" in group: + group["step"] += 1 + else: + group["step"] = 1 + + if bias_correction: + bias_correction1 = 1 - beta1 ** group["step"] + bias_correction2 = 1 - beta2 ** group["step"] + else: + bias_correction1, bias_correction2 = 1.0, 1.0 + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.div_(clip_global_grad_norm) + state = self.state[p] + + # State initialization + if len(state) == 0: + # Exponential moving average of gradient valuesa + state["exp_avg"] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( + group["eps"] + ) + update = (exp_avg / bias_correction1).div_(denom) + + weight_decay = group["weight_decay"] + if weight_decay != 0: + update.add_(p, alpha=weight_decay) + + if weight_decay != 0 or group["always_adapt"]: + # Layer-wise LR adaptation. By default, skip adaptation on parameters that are + # excluded from weight decay, unless always_adapt == True, then always enabled. 
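+                    # The trust ratio below rescales the Adam-style update per layer by
+                    # ||w|| / ||update||, so layers whose weights are large relative to
+                    # their computed update take proportionally larger steps, as in the
+                    # LAMB paper referenced in the class docstring above.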
+ w_norm = p.norm(2.0) + g_norm = update.norm(2.0) + # FIXME nested where required since logical and/or not working in PT XLA + trust_ratio = torch.where( + w_norm > 0, + torch.where(g_norm > 0, w_norm / g_norm, one_tensor), + one_tensor, + ) + if group["trust_clip"]: + # LAMBC trust clipping, upper bound fixed at one + trust_ratio = torch.minimum(trust_ratio, one_tensor) + update.mul_(trust_ratio) + + p.add_(update, alpha=-group["lr"]) + + return loss diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index 7fda0dfce5d..e5ed3137e9d 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -1,5 +1,6 @@ import torch.nn import torch.optim +from lamb import Lamb from tensordict.nn import TensorDictModule from torchrl.collectors import SyncDataCollector @@ -455,165 +456,3 @@ def make_logger(cfg): wandb_kwargs={"config": cfg}, ) return logger - - -import math - -import torch -from torch.optim import Optimizer - - -class Lamb(Optimizer): - """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB - reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py - LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining parameter groups. - lr (float, optional): learning rate. (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its norm. (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability. (default: 1e-8) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - grad_averaging (bool, optional): whether apply (1-beta2) to grad when - calculating running averages of gradient. (default: True) - max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0) - trust_clip (bool): enable LAMBC trust ratio clipping (default: False) - always_adapt (boolean, optional): Apply adaptive learning rate to 0.0 - weight decay parameter (default: False) - .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: - https://arxiv.org/abs/1904.00962 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__( - self, - params, - lr=1e-3, - bias_correction=True, - betas=(0.9, 0.999), - eps=1e-6, - weight_decay=0.01, - grad_averaging=True, - max_grad_norm=1.0, - trust_clip=False, - always_adapt=False, - ): - defaults = { - "lr": lr, - "bias_correction": bias_correction, - "betas": betas, - "eps": eps, - "weight_decay": weight_decay, - "grad_averaging": grad_averaging, - "max_grad_norm": max_grad_norm, - "trust_clip": trust_clip, - "always_adapt": always_adapt, - } - super().__init__(params, defaults) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. 
- """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - device = self.param_groups[0]["params"][0].device - one_tensor = torch.tensor( - 1.0, device=device - ) # because torch.where doesn't handle scalars correctly - global_grad_norm = torch.zeros(1, device=device) - for group in self.param_groups: - for p in group["params"]: - if p.grad is None: - continue - grad = p.grad - if grad.is_sparse: - raise RuntimeError( - "Lamb does not support sparse gradients, consider SparseAdam instad." - ) - global_grad_norm.add_(grad.pow(2).sum()) - - global_grad_norm = torch.sqrt(global_grad_norm) - # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes - # scalar types properly https://github.com/pytorch/pytorch/issues/9190 - max_grad_norm = torch.tensor(self.defaults["max_grad_norm"], device=device) - clip_global_grad_norm = torch.where( - global_grad_norm > max_grad_norm, - global_grad_norm / max_grad_norm, - one_tensor, - ) - - for group in self.param_groups: - bias_correction = 1 if group["bias_correction"] else 0 - beta1, beta2 = group["betas"] - grad_averaging = 1 if group["grad_averaging"] else 0 - beta3 = 1 - beta1 if grad_averaging else 1.0 - - # assume same step across group now to simplify things - # per parameter step can be easily support by making it tensor, or pass list into kernel - if "step" in group: - group["step"] += 1 - else: - group["step"] = 1 - - if bias_correction: - bias_correction1 = 1 - beta1 ** group["step"] - bias_correction2 = 1 - beta2 ** group["step"] - else: - bias_correction1, bias_correction2 = 1.0, 1.0 - - for p in group["params"]: - if p.grad is None: - continue - grad = p.grad.div_(clip_global_grad_norm) - state = self.state[p] - - # State initialization - if len(state) == 0: - # Exponential moving average of gradient valuesa - state["exp_avg"] = torch.zeros_like(p) - # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like(p) - - exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t - - denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( - group["eps"] - ) - update = (exp_avg / bias_correction1).div_(denom) - - weight_decay = group["weight_decay"] - if weight_decay != 0: - update.add_(p, alpha=weight_decay) - - if weight_decay != 0 or group["always_adapt"]: - # Layer-wise LR adaptation. By default, skip adaptation on parameters that are - # excluded from weight decay, unless always_adapt == True, then always enabled. 
-                    w_norm = p.norm(2.0)
-                    g_norm = update.norm(2.0)
-                    # FIXME nested where required since logical and/or not working in PT XLA
-                    trust_ratio = torch.where(
-                        w_norm > 0,
-                        torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
-                        one_tensor,
-                    )
-                    if group["trust_clip"]:
-                        # LAMBC trust clipping, upper bound fixed at one
-                        trust_ratio = torch.minimum(trust_ratio, one_tensor)
-                    update.mul_(trust_ratio)
-
-                p.add_(update, alpha=-group["lr"])
-
-        return loss

From 0d9fa42c7bd3c460533f7aedc21fed57e1e41ce7 Mon Sep 17 00:00:00 2001
From: BY571
Date: Tue, 27 Jun 2023 15:25:11 +0200
Subject: [PATCH 041/104] add DT args and example in docstring

---
 .../modules/models/decision_transformer.py    | 64 +++++++++++++++----
 1 file changed, 50 insertions(+), 14 deletions(-)

diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py
index 10d3c56c06c..96d3fa9921b 100644
--- a/torchrl/modules/models/decision_transformer.py
+++ b/torchrl/modules/models/decision_transformer.py
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
 import torch
 import torch.nn as nn
 import transformers
@@ -5,7 +11,37 @@
 
 
 class DecisionTransformer(nn.Module):
-    """online Decion Transformer as described in https://arxiv.org/abs/2202.05607 ."""
+    """Online Decision Transformer.
+
+    Described in https://arxiv.org/abs/2202.05607 .
+
+    Args:
+        state_dim (int): dimension of the state space
+        action_dim (int): dimension of the action space
+        config (dict): transformer architecture configuration, used to create the GPT2Config from transformers.
+
+
+    Example:
+        >>> config = {
+        >>>     "n_embd": 256,
+        >>>     "n_layer": 4,
+        >>>     "n_head": 4,
+        >>>     "n_inner": 1024,
+        >>>     "activation": "relu",
+        >>>     "n_positions": 1024,
+        >>>     "resid_pdrop": 0.1,
+        >>>     "attn_pdrop": 0.1,
+        >>>     }
+        >>> model = DecisionTransformer(state_dim=4, action_dim=2, config=config)
+        >>> observation = torch.randn(32, 10, 4)
+        >>> action = torch.randn(32, 10, 2)
+        >>> return_to_go = torch.randn(32, 10, 1)
+        >>> output = model(observation, action, return_to_go)
+        >>> output.shape
+        torch.Size([32, 10, 256])
+
+
+    """
 
     def __init__(
         self,
@@ -16,27 +52,27 @@ def __init__(
         super(DecisionTransformer, self).__init__()
 
         gpt_config = transformers.GPT2Config(
-            n_embd=config.n_embd,
-            n_layer=config.n_layer,
-            n_head=config.n_head,
-            n_inner=config.n_inner,
-            activation_function=config.activation,
-            n_positions=config.n_positions,
-            resid_pdrop=config.resid_pdrop,
-            attn_pdrop=config.attn_pdrop,
+            n_embd=config["n_embd"],
+            n_layer=config["n_layer"],
+            n_head=config["n_head"],
+            n_inner=["config.n_inner"],
+            activation_function=config["activation"],
+            n_positions=config["n_positions"],
+            resid_pdrop=config["resid_pdrop"],
+            attn_pdrop=config["attn_pdrop"],
             vocab_size=1,
         )
         self.state_dim = state_dim
         self.action_dim = action_dim
-        self.hidden_size = config.n_embd
+        self.hidden_size = config["n_embd"]
 
         self.transformer = GPT2Model(config=gpt_config)
 
-        self.embed_return = torch.nn.Linear(1, config.n_embd)
-        self.embed_state = torch.nn.Linear(self.state_dim, config.n_embd)
-        self.embed_action = torch.nn.Linear(self.action_dim, config.n_embd)
+        self.embed_return = torch.nn.Linear(1, self.hidden_size)
+        self.embed_state = torch.nn.Linear(self.state_dim, self.hidden_size)
+        self.embed_action = torch.nn.Linear(self.action_dim, self.hidden_size)
 
-        self.embed_ln = nn.LayerNorm(config.n_embd)
+        self.embed_ln = 
nn.LayerNorm(self.hidden_size) def forward( self, From 83642c7661a119c3db6ce3da9e3a48df88fdc7b7 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 27 Jun 2023 15:46:45 +0200 Subject: [PATCH 042/104] update constant target return and reduction --- torchrl/envs/transforms/transforms.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 34c24221262..0b5f1da4eb3 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1180,11 +1180,7 @@ def _apply_transform( target_return = target_return - reward return target_return elif self.mode == "constant": - if reward.ndim == 1 and target_return.ndim == 2: - # if target is stacked - target_return = target_return[-1] - reward - else: - target_return = target_return - reward + return target_return else: raise ValueError("Unknown mode: {}".format(self.mode)) @@ -2176,10 +2172,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase: # it is assumed that the last dimension of the tensordict is the time dimension - if not tensordict.ndim or tensordict.names[-1] != "time": - raise ValueError( - "The last dimension of the tensordict must be marked as 'time'." - ) + # if not tensordict.ndim or tensordict.names[-1] != "time": + # raise ValueError( + # "The last dimension of the tensordict must be marked as 'time'." + # ) # first sort the in_keys with strings and non-strings in_keys = list( zip( From 39dda0096b39d435936590d354cd4d4c1de25105 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 27 Jun 2023 15:59:20 +0200 Subject: [PATCH 043/104] fixes for target return transform --- examples/decision_transformer/dt_config.yaml | 1 + examples/decision_transformer/odt_config.yaml | 1 + examples/decision_transformer/utils.py | 14 +++++++++++--- torchrl/envs/transforms/transforms.py | 5 +++++ torchrl/modules/models/decision_transformer.py | 3 +-- 5 files changed, 19 insertions(+), 5 deletions(-) diff --git a/examples/decision_transformer/dt_config.yaml b/examples/decision_transformer/dt_config.yaml index c0f7d75706f..1ff21af8d3e 100644 --- a/examples/decision_transformer/dt_config.yaml +++ b/examples/decision_transformer/dt_config.yaml @@ -12,6 +12,7 @@ env: reward_scaling: 0.001 # for r2g noop: 1 seed: 1 + target_return_mode: constant eval_target_return: 6000 collect_target_return: 12000 total_online_frames: 1000000 diff --git a/examples/decision_transformer/odt_config.yaml b/examples/decision_transformer/odt_config.yaml index c55b7c87e49..912a29b5cbf 100644 --- a/examples/decision_transformer/odt_config.yaml +++ b/examples/decision_transformer/odt_config.yaml @@ -12,6 +12,7 @@ env: reward_scaling: 0.001 # for r2g noop: 1 seed: 2 + target_return_mode: constant eval_target_return: 6000 collect_target_return: 12000 diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index e5ed3137e9d..be7a4c8fb6d 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -66,11 +66,19 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False): transformed_env = TransformedEnv(base_env) if train: transformed_env.append_transform( - TargetReturn(env_cfg.collect_target_return, out_keys=["return_to_go"]) + TargetReturn( + env_cfg.collect_target_return, + out_keys=["return_to_go"], + mode=env_cfg.target_return_mode, + ) ) else: transformed_env.append_transform( - 
TargetReturn(env_cfg.eval_target_return, out_keys=["return_to_go"]) + TargetReturn( + env_cfg.eval_target_return, + out_keys=["return_to_go"], + mode=env_cfg.target_return_mode, + ) ) transformed_env.append_transform( RewardScaling( @@ -314,7 +322,7 @@ def make_odt_model(cfg): actor = ProbabilisticActor( spec=action_spec, in_keys=["loc", "scale"], - out_keys=["action", "log_prob"], + out_keys=["action"], module=actor_module, distribution_class=dist_class, distribution_kwargs=dist_kwargs, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 0b5f1da4eb3..69a65e0de86 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1180,6 +1180,11 @@ def _apply_transform( target_return = target_return - reward return target_return elif self.mode == "constant": + if reward.ndim == 1 and target_return.ndim == 2: + # if target is stacked + target_return = target_return[-1] + else: + target_return = target_return return target_return else: raise ValueError("Unknown mode: {}".format(self.mode)) diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py index 96d3fa9921b..f42ffef2271 100644 --- a/torchrl/modules/models/decision_transformer.py +++ b/torchrl/modules/models/decision_transformer.py @@ -40,7 +40,6 @@ class DecisionTransformer(nn.Module): >>> output.shape torch.Size([32, 10, 256]) - """ def __init__( @@ -55,7 +54,7 @@ def __init__( n_embd=config["n_embd"], n_layer=config["n_layer"], n_head=config["n_head"], - n_inner=["config.n_inner"], + n_inner=config["n_inner"], activation_function=config["activation"], n_positions=config["n_positions"], resid_pdrop=config["resid_pdrop"], From c5c71e6e56edf923b54ae113fe8fbccabba937e0 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 27 Jun 2023 16:07:39 +0200 Subject: [PATCH 044/104] update add transformers installed check --- torchrl/modules/models/decision_transformer.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py index f42ffef2271..9c970e2097c 100644 --- a/torchrl/modules/models/decision_transformer.py +++ b/torchrl/modules/models/decision_transformer.py @@ -3,12 +3,15 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import importlib import torch import torch.nn as nn import transformers from transformers.models.gpt2.modeling_gpt2 import GPT2Model +_has_transformers = importlib.util.find_spec("transformers") is not None + class DecisionTransformer(nn.Module): """Online Decion Transformer. @@ -48,6 +51,10 @@ def __init__( action_dim, config, ): + if not _has_transformers: + raise ImportError( + "transformers is not installed. Please install it with `pip install transformers`." 
+ ) super(DecisionTransformer, self).__init__() gpt_config = transformers.GPT2Config( From ca36a0f90a15952db6bf6a8a1b5fcbe8f7d9ba78 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 27 Jun 2023 16:10:51 +0200 Subject: [PATCH 045/104] update docstring actor DT --- torchrl/modules/tensordict_module/actors.py | 22 ++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 3ba591dcf4b..02bfba95f0d 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -1611,17 +1611,17 @@ class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper): ... DecisionTransformerInferenceWrapper, ... ) >>> actor_module = TensorDictModule( - DTActor(state_dim=4, action_dim=2), - in_keys=in_keys, - out_keys=[ - "loc", - "scale",]) + ... DTActor(state_dim=4, action_dim=2), + ... in_keys=in_keys, + ... out_keys=[ + ... "loc", + ... "scale",]) >>> dist_class = TanhNormal >>> dist_kwargs = { - "min": -1.0, - "max": 1.0, - "tanh_loc": False, - } + ... "min": -1.0, + ... "max": 1.0, + ... "tanh_loc": False, + ... } >>> actor = ProbabilisticActor( in_keys=["loc", "scale"], out_keys=["action", "log_prob"], @@ -1632,8 +1632,8 @@ class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper): >>> print(inference_actor) >>> sequence_length = 20 >>> td = TensorDict({"observation": torch.randn(1, sequence_length, 4), - "action": torch.randn(1, sequence_length, 2), - "return_to_go": torch.randn(1, sequence_length, 1)}, [1,]) + ... "action": torch.randn(1, sequence_length, 2), + ... "return_to_go": torch.randn(1, sequence_length, 1)}, [1,]) >>> print(inference_actor(td.clone())) TensorDict( fields={ From 9c0dfbbd5dc03f2e9c260ded63bc897de0e0ba70 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 27 Jun 2023 16:20:54 +0200 Subject: [PATCH 046/104] add docstring for modules and examples --- .../modules/models/decision_transformer.py | 18 ++++---- torchrl/modules/models/models.py | 43 ++++++++++++++++++- 2 files changed, 51 insertions(+), 10 deletions(-) diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py index 9c970e2097c..c3db7cb71f2 100644 --- a/torchrl/modules/models/decision_transformer.py +++ b/torchrl/modules/models/decision_transformer.py @@ -26,15 +26,15 @@ class DecisionTransformer(nn.Module): Example: >>> config = { - >>> "n_embd": 256, - >>> "n_layer": 4, - >>> "n_head": 4, - >>> "n_inner": 1024, - >>> "activation": "relu", - >>> "n_positions": 1024, - >>> "resid_pdrop": 0.1, - >>> "attn_pdrop": 0.1, - >>> } + ... "n_embd": 256, + ... "n_layer": 4, + ... "n_head": 4, + ... "n_inner": 1024, + ... "activation": "relu", + ... "n_positions": 1024, + ... "resid_pdrop": 0.1, + ... "attn_pdrop": 0.1, + ... } >>> model = DecisionTransformer(state_dim=4, action_dim=2, config=config) >>> observation = torch.randn(32, 10, 4) >>> action = torch.randn(32, 10, 2) diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 24914cebbaf..4e274e43f8e 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -1144,6 +1144,7 @@ class OnlineDTActor(nn.Module): """Online Decision Transformer Actor class. Actor class for the Online Decision Transformer to sample actions from gaussian distribution as presented inresented in `"Online Decision Transformer" `. + Returns mu and sigma for the gaussian distribution to sample actions from. 
Args: state_dim (int): state dimension. @@ -1151,6 +1152,26 @@ class OnlineDTActor(nn.Module): transformer_config (Dict): config for the GPT2 transformer. device (Optional[DEVICE_TYPING], optional): device to use. Defaults to None. + Examples: + >>> config = { + ... "n_embd": 256, + ... "n_layer": 4, + ... "n_head": 4, + ... "n_inner": 1024, + ... "activation": "relu", + ... "n_positions": 1024, + ... "resid_pdrop": 0.1, + ... "attn_pdrop": 0.1, + ... } + >>> model = OnlineDTActor(state_dim=4, action_dim=2, config=config) + >>> observation = torch.randn(32, 10, 4) + >>> action = torch.randn(32, 10, 2) + >>> return_to_go = torch.randn(32, 10, 1) + >>> (mu, std) = model(observation, action, return_to_go) + >>> mu.shape + torch.Size([32, 10, 2]) + >>> std.shape + torch.Size([32, 10, 2]) """ def __init__( @@ -1209,13 +1230,33 @@ class DTActor(nn.Module): """Decision Transformer Actor class. Actor class for the Decision Transformer to output deterministic action as presented in `"Decision Transformer" `. - + Returns the deterministic actions. Args: state_dim (int): state dimension. action_dim (int): action dimension. transformer_config (Dict): config for the GPT2 transformer. device (Optional[DEVICE_TYPING], optional): device to use. Defaults to None. + + Examples: + >>> config = { + ... "n_embd": 256, + ... "n_layer": 4, + ... "n_head": 4, + ... "n_inner": 1024, + ... "activation": "relu", + ... "n_positions": 1024, + ... "resid_pdrop": 0.1, + ... "attn_pdrop": 0.1, + ... } + >>> model = DTActor(state_dim=4, action_dim=2, config=config) + >>> observation = torch.randn(32, 10, 4) + >>> action = torch.randn(32, 10, 2) + >>> return_to_go = torch.randn(32, 10, 1) + >>> output = model(observation, action, return_to_go) + >>> output.shape + torch.Size([32, 10, 2]) + """ def __init__( From ddb284ed299eb50ea049763b7d59582c37f0cfe4 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 27 Jun 2023 16:56:24 +0200 Subject: [PATCH 047/104] udpate config --- examples/decision_transformer/dt_config.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/decision_transformer/dt_config.yaml b/examples/decision_transformer/dt_config.yaml index 1ff21af8d3e..3b4033f13da 100644 --- a/examples/decision_transformer/dt_config.yaml +++ b/examples/decision_transformer/dt_config.yaml @@ -15,8 +15,6 @@ env: target_return_mode: constant eval_target_return: 6000 collect_target_return: 12000 - total_online_frames: 1000000 - # logger logger: From cf5de9ad6f4e79802cda4e0419bb4937effb8f18 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 28 Jun 2023 12:10:32 +0200 Subject: [PATCH 048/104] take off unsqueeze in models --- torchrl/modules/models/models.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 4e274e43f8e..6ad68afb831 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -1208,10 +1208,6 @@ def forward( action: torch.Tensor, return_to_go: torch.Tensor, ) -> torch.Tensor: - if observation.ndim == 2: - observation = observation.unsqueeze(0) - action = action.unsqueeze(0) - return_to_go = return_to_go.unsqueeze(0) hidden_state = self.transformer(observation, action, return_to_go) out = self.action_layer(hidden_state) mu, log_std = torch.chunk(out, 2, -1) @@ -1291,10 +1287,6 @@ def forward( action: torch.Tensor, return_to_go: torch.Tensor, ) -> torch.Tensor: - if observation.ndim == 2: - observation = observation.unsqueeze(0) - action = action.unsqueeze(0) - return_to_go = return_to_go.unsqueeze(0) 
hidden_state = self.transformer(observation, action, return_to_go) out = self.action_layer(hidden_state) return out From c3d0ffa15b34bf918c8838fa200b940e27027904 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 29 Jun 2023 12:30:06 +0200 Subject: [PATCH 049/104] add loss function to config --- examples/decision_transformer/dt.py | 2 +- examples/decision_transformer/dt_config.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/decision_transformer/dt.py b/examples/decision_transformer/dt.py index 5816022a42f..9004925afde 100644 --- a/examples/decision_transformer/dt.py +++ b/examples/decision_transformer/dt.py @@ -35,7 +35,7 @@ def main(cfg: "DictConfig"): # noqa: F821 actor = make_dt_model(cfg) policy = actor.to(model_device) - loss_module = make_dt_loss(actor) + loss_module = make_dt_loss(cfg.loss, actor) transformer_optim, scheduler = make_dt_optimizer(cfg.optim, policy) inference_policy = DecisionTransformerInferenceWrapper( policy=policy, diff --git a/examples/decision_transformer/dt_config.yaml b/examples/decision_transformer/dt_config.yaml index 3b4033f13da..65ba26e4664 100644 --- a/examples/decision_transformer/dt_config.yaml +++ b/examples/decision_transformer/dt_config.yaml @@ -51,7 +51,7 @@ optim: # loss loss: - alpha_init: 0.1 + loss_function: "l2" transformer: n_embd: 128 From e4ea2785f8218e2872ed885bf8e357d2686ba20b Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 29 Jun 2023 12:30:46 +0200 Subject: [PATCH 050/104] add loss function to config --- examples/decision_transformer/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index be7a4c8fb6d..4e3d3573f82 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -408,9 +408,10 @@ def make_odt_loss(loss_cfg, actor_network): return loss -def make_dt_loss(actor_network): +def make_dt_loss(loss_cfg, actor_network): loss = DTLoss( actor_network, + loss_function=loss_cfg.loss_function, ) return loss From d2c1b08edad20a7fd48aa030d27228d09f959c51 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 29 Jun 2023 12:31:31 +0200 Subject: [PATCH 051/104] update loss module --- torchrl/objectives/decision_transformer.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index a6775b90e36..a7e132a477f 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -28,6 +28,8 @@ class OnlineDTLoss(LossModule): Args: actor_network (ProbabilisticActor): stochastic actor + + Keyword Args: alpha_init (float, optional): initial entropy multiplier. Default is 1.0. min_alpha (float, optional): min value of alpha. @@ -147,9 +149,8 @@ def _set_in_keys(self): self.tensor_keys.action, ("next", self.tensor_keys.return_to_go), ("next", self.tensor_keys.done), - *self.tensor_keys.action, - *[("next", key) for key in self.tensor_keys.action], - *self.tensor_keys.observation, + self.tensor_keys.action, + self.tensor_keys.observation, ] self._in_keys = list(set(keys)) @@ -223,6 +224,9 @@ class DTLoss(LossModule): Args: actor_network (ProbabilisticActor): stochastic actor + Keyword Args: + loss_function (str): loss function to use. Defaults to ``"l2"``. 
+ """ @dataclass @@ -254,6 +258,8 @@ class _AcceptedKeys: def __init__( self, actor_network: ProbabilisticActor, + *, + loss_function: str = "l2", ) -> None: self._in_keys = None self._out_keys = None @@ -266,15 +272,15 @@ def __init__( create_target_params=False, funs_to_decorate=["forward"], ) + self.loss_function = loss_function def _set_in_keys(self): keys = [ self.tensor_keys.action, ("next", self.tensor_keys.return_to_go), ("next", self.tensor_keys.done), - *self.tensor_keys.action, - *[("next", key) for key in self.tensor_keys.action], - *self.tensor_keys.observation, + self.tensor_keys.action, + self.tensor_keys.observation, ] self._in_keys = list(set(keys)) @@ -312,7 +318,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: loss = distance_loss( pred_actions, target_actions, - loss_function="l2", + loss_function=self.loss_function, ).mean() out = { "loss": loss, From a62a64738960399b6f0d862457ec789f9cd7c07b Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 29 Jun 2023 12:33:17 +0200 Subject: [PATCH 052/104] udpate DT actor docstring --- torchrl/modules/tensordict_module/actors.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 02bfba95f0d..5ce0a6b622b 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -1592,12 +1592,10 @@ class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper): observations and produces an action value Keyword Args: + loss_module (TensorDictModule): The loss module that computes the DT loss to receive the input keys of the model. inference_context (int): The number of previous actions that will not be masked in the context. For example for an observation input of shape [batch_size, context, obs_dim] with context=20 and inference_context=5, the first 15 entries of the context will be masked. Defaults to 5. - observation_key (str): The key of the observation in the input TensorDict, defaults to "observation". - action_key (str): The key of the action in the input TensorDict, defaults to "action". - return_to_go_key (str): The key of the return to go in the input TensorDict, defaults to "return_to_go". spec (Optional[TensorSpec]): The spec of the input TensorDict. If None, it will be inferred from the policy module. Examples: @@ -1623,11 +1621,11 @@ class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper): ... "tanh_loc": False, ... } >>> actor = ProbabilisticActor( - in_keys=["loc", "scale"], - out_keys=["action", "log_prob"], - module=actor_module, - distribution_class=dist_class, - distribution_kwargs=dist_kwargs) + ... in_keys=["loc", "scale"], + ... out_keys=["action", "log_prob"], + ... module=actor_module, + ... distribution_class=dist_class, + ... 
distribution_kwargs=dist_kwargs) >>> inference_actor = DecisionTransformerInferenceWrapper(actor) >>> print(inference_actor) >>> sequence_length = 20 From 200906085b8988bc6e4c52cec879cdb792d67604 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 29 Jun 2023 14:44:07 +0200 Subject: [PATCH 053/104] add default transformer config --- .../modules/models/decision_transformer.py | 47 ++++++++++++++----- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py index c3db7cb71f2..4934d76e9c4 100644 --- a/torchrl/modules/models/decision_transformer.py +++ b/torchrl/modules/models/decision_transformer.py @@ -18,6 +18,18 @@ class DecisionTransformer(nn.Module): Desdescribed in https://arxiv.org/abs/2202.05607 . + The transformer utilizes a default config to create the GPT2 model if the user does not provide a specific config. + default_config = { + "n_embd": 256, + "n_layer": 4, + "n_head": 4, + "n_inner": 1024, + "activation": "relu", + "n_positions": 1024, + "resid_pdrop": 0.1, + "attn_pdrop": 0.1, + } + Args: state_dim (int): dimension of the state space action_dim (int): dimension of the action space @@ -31,7 +43,7 @@ class DecisionTransformer(nn.Module): ... "n_head": 4, ... "n_inner": 1024, ... "activation": "relu", - ... "n_positions": 1024, + ... "n_positions": 1024,clear ... "resid_pdrop": 0.1, ... "attn_pdrop": 0.1, ... } @@ -45,11 +57,22 @@ class DecisionTransformer(nn.Module): """ + default_config = { + "n_embd": 256, + "n_layer": 4, + "n_head": 4, + "n_inner": 1024, + "activation": "relu", + "n_positions": 1024, + "resid_pdrop": 0.1, + "attn_pdrop": 0.1, + } + def __init__( self, state_dim, action_dim, - config, + config: dict = default_config, ): if not _has_transformers: raise ImportError( @@ -57,20 +80,22 @@ def __init__( ) super(DecisionTransformer, self).__init__() + self.default_config.update(config) + gpt_config = transformers.GPT2Config( - n_embd=config["n_embd"], - n_layer=config["n_layer"], - n_head=config["n_head"], - n_inner=config["n_inner"], - activation_function=config["activation"], - n_positions=config["n_positions"], - resid_pdrop=config["resid_pdrop"], - attn_pdrop=config["attn_pdrop"], + n_embd=self.default_config["n_embd"], + n_layer=self.default_config["n_layer"], + n_head=self.default_config["n_head"], + n_inner=self.default_config["n_inner"], + activation_function=self.default_config["activation"], + n_positions=self.default_config["n_positions"], + resid_pdrop=self.default_config["resid_pdrop"], + attn_pdrop=self.default_config["attn_pdrop"], vocab_size=1, ) self.state_dim = state_dim self.action_dim = action_dim - self.hidden_size = config["n_embd"] + self.hidden_size = self.default_config["n_embd"] self.transformer = GPT2Model(config=gpt_config) From 77630bd488b83e1f519760a4e5d9830594838775 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 3 Jul 2023 14:45:59 +0100 Subject: [PATCH 054/104] amend --- examples/decision_transformer/online_dt.py | 2 +- test/test_cost.py | 114 +++++++++++++++++++-- torchrl/objectives/decision_transformer.py | 65 +++++------- 3 files changed, 132 insertions(+), 49 deletions(-) diff --git a/examples/decision_transformer/online_dt.py b/examples/decision_transformer/online_dt.py index d61e4535a8b..5dff3390371 100644 --- a/examples/decision_transformer/online_dt.py +++ b/examples/decision_transformer/online_dt.py @@ -63,7 +63,7 @@ def main(cfg: "DictConfig"): # noqa: F821 data = offline_buffer.sample() # loss loss_vals = 
loss_module(data.to(model_device)) - transformer_loss = loss_vals["loss"] + transformer_loss = loss_vals["loss_log_likelihood"] + loss_vals["loss_entropy"] temperature_loss = loss_vals["loss_alpha"] transformer_optim.zero_grad() diff --git a/test/test_cost.py b/test/test_cost.py index 4d9215b430d..4ac0081ed5d 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -12,6 +12,8 @@ from dataclasses import asdict, dataclass from packaging import version as pack_version +from tensordict._tensordict import unravel_keys + from tensordict.nn import ( InteractionType, ProbabilisticTensorDictModule as ProbMod, @@ -164,6 +166,12 @@ def get_devices(): class LossModuleTestBase: + def _flatten_in_keys(self, in_keys): + return [ + in_key if isinstance(in_key, str) else "_".join(list(unravel_keys(in_key))) + for in_key in in_keys + ] + def tensordict_keys_test(self, loss_fn, default_keys, td_est=None): self.tensordict_keys_unknown_key_test(loss_fn) self.tensordict_keys_default_values_test(loss_fn, default_keys) @@ -6074,7 +6082,7 @@ def test_dreamer_value_tensordict_keys(self, device): self.tensordict_keys_test(loss_fn, default_keys=default_keys) -class TestOnlineDT: +class TestOnlineDT(LossModuleTestBase): seed = 0 def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): @@ -6136,7 +6144,11 @@ def test_odt(self, device): loss_fn = OnlineDTLoss(actor) loss = loss_fn(td) - loss_transformer = loss["loss"] + loss_transformer = sum( + loss[key] + for key in loss.keys() + if key.startswith("loss") and key != "loss_alpha" + ) loss_alpha = loss["loss_alpha"] loss_transformer.backward(retain_graph=True) named_parameters = loss_fn.named_parameters() @@ -6179,7 +6191,11 @@ def test_seq_odt(self, device): loss_fn = OnlineDTLoss(actor) loss = loss_fn(td) - loss_transformer = loss["loss"] + loss_transformer = sum( + loss[key] + for key in loss.keys() + if key.startswith("loss") and key != "loss_alpha" + ) loss_alpha = loss["loss_alpha"] loss_transformer.backward(retain_graph=True) named_parameters = loss_fn.named_parameters() @@ -6213,8 +6229,56 @@ def test_seq_odt(self, device): for name, p in named_parameters: assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" + def test_onlinedt_tensordict_keys(self): + actor = self._create_mock_actor() + loss_fn = OnlineDTLoss(actor) -class TestDT: + default_keys = { + "action": "action", + } + + self.tensordict_keys_test( + loss_fn, + default_keys=default_keys, + ) + + @pytest.mark.parametrize("device", get_default_devices()) + def test_onlinedt_notensordict(self, device): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(device=device) + td = self._create_mock_data_odt(device=device) + loss_fn = OnlineDTLoss(actor) + + in_keys = self._flatten_in_keys(loss_fn.in_keys) + kwargs = dict(td.flatten_keys("_").select(*in_keys)) + + torch.manual_seed(0) + loss_val_td = loss_fn(td) + torch.manual_seed(0) + loss_log_likelihood, loss_entropy, loss_alpha, alpha, entropy = loss_fn( + **kwargs + ) + torch.testing.assert_close( + loss_val_td.get("loss_log_likelihood"), loss_log_likelihood + ) + torch.testing.assert_close(loss_val_td.get("loss_entropy"), loss_entropy) + torch.testing.assert_close(loss_val_td.get("loss_alpha"), loss_alpha) + # test select + torch.manual_seed(0) + loss_fn.select_out_keys("loss_entropy") + if torch.__version__ >= "2.0.0": + loss_entropy = loss_fn(**kwargs) + else: + with pytest.raises( + RuntimeError, + match="You are likely using tensordict.nn.dispatch with keyword arguments", + ): + loss_entropy = 
loss_fn(**kwargs) + return + assert loss_entropy == loss_val_td["loss_entropy"] + + +class TestDT(LossModuleTestBase): seed = 0 def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): @@ -6242,7 +6306,6 @@ def _create_mock_data_dt(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): source={ "observation": obs, "action": action, - "reward2go": reward2go, }, device=device, ) @@ -6254,19 +6317,56 @@ def _create_seq_mock_data_dt( # create a tensordict obs = torch.randn(batch, T, obs_dim, device=device) action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) - reward2go = torch.randn(batch, T, 1, device=device) td = TensorDict( batch_size=(batch, T), source={ "observation": obs, - "reward": reward2go, "action": action, }, device=device, ) return td + def test_dt_tensordict_keys(self): + actor = self._create_mock_actor() + loss_fn = DTLoss(actor) + + default_keys = { + "action": "action", + } + + self.tensordict_keys_test( + loss_fn, + default_keys=default_keys, + ) + + @pytest.mark.parametrize("device", get_default_devices()) + def test_dt_notensordict(self, device): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(device=device) + td = self._create_mock_data_dt(device=device) + loss_fn = DTLoss(actor) + + in_keys = self._flatten_in_keys(loss_fn.in_keys) + kwargs = dict(td.flatten_keys("_").select(*in_keys)) + + loss_val_td = loss_fn(td) + loss_val = loss_fn(**kwargs) + torch.testing.assert_close(loss_val_td.get("loss"), loss_val) + # test select + loss_fn.select_out_keys("loss") + if torch.__version__ >= "2.0.0": + loss_actor = loss_fn(**kwargs) + else: + with pytest.raises( + RuntimeError, + match="You are likely using tensordict.nn.dispatch with keyword arguments", + ): + loss_actor = loss_fn(**kwargs) + return + assert loss_actor == loss_val_td["loss"] + @pytest.mark.parametrize("device", get_available_devices()) def test_dt(self, device): torch.manual_seed(self.seed) diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index a7e132a477f..9a1693c74a0 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -57,19 +57,10 @@ class _AcceptedKeys: Attributes: action (NestedKey): The input tensordict key where the action is expected. Defaults to ``"action"``. - observation (NestedKey): The input tensordict key where the observation is expected. - Defaults to ``"observation"``. - return_to_go (NestedKey): The input tensordict key where the return_to_go is expected. - Defaults to ``"return_to_go"``. - done (NestedKey): The key in the input TensorDict that indicates - whether a trajectory is done. Will be used for the underlying value estimator. - Defaults to ``"done"``. 
+ """ action: NestedKey = "action" - observation: NestedKey = "observation" - return_to_go: NestedKey = "return_to_go" - done: NestedKey = "done" default_keys = _AcceptedKeys() @@ -145,15 +136,13 @@ def __init__( self._set_in_keys() def _set_in_keys(self): - keys = [ - self.tensor_keys.action, - ("next", self.tensor_keys.return_to_go), - ("next", self.tensor_keys.done), - self.tensor_keys.action, - self.tensor_keys.observation, - ] + keys = self.actor_network.in_keys + keys = set(keys) + keys.add(self.tensor_keys.action) + self._in_keys = sorted(keys, key=str) - self._in_keys = list(set(keys)) + def _forward_value_estimator_keys(self, **kwargs): + pass @property def alpha(self): @@ -176,7 +165,13 @@ def in_keys(self, values): @property def out_keys(self): if self._out_keys is None: - keys = ["loss", "loss_log_likelihood", "loss_alpha", "alpha", "entropy"] + keys = [ + "loss_log_likelihood", + "loss_entropy", + "loss_alpha", + "alpha", + "entropy", + ] self._out_keys = keys return self._out_keys @@ -202,15 +197,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: loss_log_likelihood = action_dist.log_prob(target_actions).mean() entropy = self.get_entropy_bonus(action_dist).mean() - loss = -(loss_log_likelihood + self.alpha.detach() * entropy) + loss_entropy = self.alpha.detach() * entropy loss_alpha = self.log_alpha.exp() * (entropy - self.target_entropy).detach() out = { - "loss": loss, "loss_log_likelihood": -loss_log_likelihood, - "entropy": entropy.detach(), + "loss_entropy": loss_entropy, "loss_alpha": loss_alpha, + "entropy": entropy.detach(), "alpha": self.alpha.detach(), } return TensorDict(out, []) @@ -239,19 +234,9 @@ class _AcceptedKeys: Attributes: action (NestedKey): The input tensordict key where the action is expected. Defaults to ``"action"``. - observation (NestedKey): The input tensordict key where the observation is expected. - Defaults to ``"observation"``. - return_to_go (NestedKey): The input tensordict key where the return_to_go is expected. - Defaults to ``"return_to_go"``. - done (NestedKey): The key in the input TensorDict that indicates - whether a trajectory is done. Will be used for the underlying value estimator. - Defaults to ``"done"``. """ action: NestedKey = "action" - observation: NestedKey = "observation" - return_to_go: NestedKey = "return_to_go" - done: NestedKey = "done" default_keys = _AcceptedKeys() @@ -275,15 +260,13 @@ def __init__( self.loss_function = loss_function def _set_in_keys(self): - keys = [ - self.tensor_keys.action, - ("next", self.tensor_keys.return_to_go), - ("next", self.tensor_keys.done), - self.tensor_keys.action, - self.tensor_keys.observation, - ] - - self._in_keys = list(set(keys)) + keys = self.actor_network.in_keys + keys = set(keys) + keys.add(self.tensor_keys.action) + self._in_keys = sorted(keys, key=str) + + def _forward_value_estimator_keys(self, **kwargs) -> None: + pass @property def in_keys(self): From f2defcbbe5df04f1a70a62e49944f5f67aa4603b Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 3 Jul 2023 14:46:32 +0100 Subject: [PATCH 055/104] doc --- docs/source/reference/objectives.rst | 8 -------- 1 file changed, 8 deletions(-) diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index ad653d001d6..5fddc9655ef 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -144,14 +144,6 @@ DT :template: rl_template_noinherit.rst DTLoss - -OnlineDT ----- - -.. 
autosummary:: - :toctree: generated/ - :template: rl_template_noinherit.rst - OnlineDTLoss TD3 From f891bd20dda784312f317b6734b68b9076f233bd Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 3 Jul 2023 15:10:15 +0100 Subject: [PATCH 056/104] tests --- .circleci/unittest/linux_examples/scripts/environment.yml | 1 + .circleci/unittest/linux_examples/scripts/run_test.sh | 8 ++++++++ .circleci/unittest/linux_examples/scripts/setup_env.sh | 5 ++--- torchrl/modules/models/decision_transformer.py | 4 ++-- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/.circleci/unittest/linux_examples/scripts/environment.yml b/.circleci/unittest/linux_examples/scripts/environment.yml index 7a91696ca46..adfc5436625 100644 --- a/.circleci/unittest/linux_examples/scripts/environment.yml +++ b/.circleci/unittest/linux_examples/scripts/environment.yml @@ -27,3 +27,4 @@ dependencies: - mlflow - av - coverage + - mujoco-py diff --git a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh index 3e555d41df9..e3338d66a7f 100755 --- a/.circleci/unittest/linux_examples/scripts/run_test.sh +++ b/.circleci/unittest/linux_examples/scripts/run_test.sh @@ -29,6 +29,14 @@ python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_ python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 # With batched environments +python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/dt.py \ + optim.pretrain_gradient_steps=55 \ + optim.updates_per_episode=3 \ + optim.warmup_steps=10 +python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/online_td.py \ + optim.pretrain_gradient_steps=55 \ + optim.updates_per_episode=3 \ + optim.warmup_steps=10 python .circleci/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \ env.num_envs=1 \ env.device=cuda:0 \ diff --git a/.circleci/unittest/linux_examples/scripts/setup_env.sh b/.circleci/unittest/linux_examples/scripts/setup_env.sh index c79f25a6979..55bd61a5e13 100755 --- a/.circleci/unittest/linux_examples/scripts/setup_env.sh +++ b/.circleci/unittest/linux_examples/scripts/setup_env.sh @@ -112,7 +112,6 @@ if [[ $OSTYPE != 'darwin'* ]]; then pip install ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl rm ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl fi - pip install "gymnasium[atari,accept-rom-license]" -else - pip install "gymnasium[atari,accept-rom-license]" fi +pip install "gymnasium[atari,accept-rom-license]" +pip intall gym==0.23 # for D4RL's sake... diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py index 4934d76e9c4..c137be9b8f2 100644 --- a/torchrl/modules/models/decision_transformer.py +++ b/torchrl/modules/models/decision_transformer.py @@ -7,8 +7,6 @@ import torch import torch.nn as nn -import transformers -from transformers.models.gpt2.modeling_gpt2 import GPT2Model _has_transformers = importlib.util.find_spec("transformers") is not None @@ -78,6 +76,8 @@ def __init__( raise ImportError( "transformers is not installed. Please install it with `pip install transformers`." 
) + import transformers + from transformers.models.gpt2.modeling_gpt2 import GPT2Model super(DecisionTransformer, self).__init__() self.default_config.update(config) From cf5bc01df8eb0e8f27d78416682abb07d584f203 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 3 Jul 2023 15:11:28 +0100 Subject: [PATCH 057/104] lint --- torchrl/modules/models/decision_transformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py index c137be9b8f2..ab5095b3956 100644 --- a/torchrl/modules/models/decision_transformer.py +++ b/torchrl/modules/models/decision_transformer.py @@ -78,6 +78,7 @@ def __init__( ) import transformers from transformers.models.gpt2.modeling_gpt2 import GPT2Model + super(DecisionTransformer, self).__init__() self.default_config.update(config) From 6d4b591bcf7198c8ea69fc8467ac17cb7af5551b Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 3 Jul 2023 15:35:31 +0100 Subject: [PATCH 058/104] fix --- .circleci/unittest/linux_examples/scripts/setup_env.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/unittest/linux_examples/scripts/setup_env.sh b/.circleci/unittest/linux_examples/scripts/setup_env.sh index 55bd61a5e13..db7a54fd424 100755 --- a/.circleci/unittest/linux_examples/scripts/setup_env.sh +++ b/.circleci/unittest/linux_examples/scripts/setup_env.sh @@ -114,4 +114,4 @@ if [[ $OSTYPE != 'darwin'* ]]; then fi fi pip install "gymnasium[atari,accept-rom-license]" -pip intall gym==0.23 # for D4RL's sake... +pip install gym==0.23 # for D4RL's sake... From 7c0df55a99dcdefe26fc8ac7f59077d4963e86bc Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 3 Jul 2023 15:52:29 +0100 Subject: [PATCH 059/104] amend --- .../unittest/linux_examples/scripts/environment.yml | 1 - .../unittest/linux_examples/scripts/setup_env.sh | 11 ++++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/.circleci/unittest/linux_examples/scripts/environment.yml b/.circleci/unittest/linux_examples/scripts/environment.yml index adfc5436625..7a91696ca46 100644 --- a/.circleci/unittest/linux_examples/scripts/environment.yml +++ b/.circleci/unittest/linux_examples/scripts/environment.yml @@ -27,4 +27,3 @@ dependencies: - mlflow - av - coverage - - mujoco-py diff --git a/.circleci/unittest/linux_examples/scripts/setup_env.sh b/.circleci/unittest/linux_examples/scripts/setup_env.sh index db7a54fd424..ed38977d529 100755 --- a/.circleci/unittest/linux_examples/scripts/setup_env.sh +++ b/.circleci/unittest/linux_examples/scripts/setup_env.sh @@ -38,7 +38,7 @@ if [ ! -d "${env_dir}" ]; then fi conda activate "${env_dir}" -# 3. Install mujoco +# 3a. Install mujoco printf "* Installing mujoco and related\n" mkdir -p $root_dir/.mujoco cd $root_dir/.mujoco/ @@ -48,6 +48,14 @@ wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz tar -xf mujoco210-linux-x86_64.tar.gz cd $this_dir +# 3b. install mujoco-py +mkdir third_party +cd third_party +git clone https://github.com/openai/mujoco-py +cd mujoco-py +pip install -e . +cd ../.. + # 4. Install Conda dependencies printf "* Installing dependencies (except PyTorch)\n" echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" @@ -115,3 +123,4 @@ if [[ $OSTYPE != 'darwin'* ]]; then fi pip install "gymnasium[atari,accept-rom-license]" pip install gym==0.23 # for D4RL's sake... 
+python -c """import gym""" From d3a3d77ea250908cc3af72aa0fb3f419d8c9f182 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 4 Jul 2023 15:58:42 +0100 Subject: [PATCH 060/104] amend --- .circleci/unittest/linux_examples/scripts/setup_env.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.circleci/unittest/linux_examples/scripts/setup_env.sh b/.circleci/unittest/linux_examples/scripts/setup_env.sh index ed38977d529..f5605409820 100755 --- a/.circleci/unittest/linux_examples/scripts/setup_env.sh +++ b/.circleci/unittest/linux_examples/scripts/setup_env.sh @@ -6,6 +6,7 @@ # Do not install PyTorch and torchvision here, otherwise they also get cached. set -e +set -v this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" # Avoid error: "fatal: unsafe repository" @@ -123,4 +124,5 @@ if [[ $OSTYPE != 'darwin'* ]]; then fi pip install "gymnasium[atari,accept-rom-license]" pip install gym==0.23 # for D4RL's sake... -python -c """import gym""" +pip install git+https://github.com/Farama-Foundation/d4rl.git +python -c """import gym;import d4rl""" From b1c73dab80a553267d3491a27f17efd532e808b7 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 6 Jul 2023 16:12:50 +0100 Subject: [PATCH 061/104] amend --- .circleci/unittest/linux_examples/scripts/run_all.sh | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/.circleci/unittest/linux_examples/scripts/run_all.sh b/.circleci/unittest/linux_examples/scripts/run_all.sh index 0f77fbba72d..d131a371142 100755 --- a/.circleci/unittest/linux_examples/scripts/run_all.sh +++ b/.circleci/unittest/linux_examples/scripts/run_all.sh @@ -118,6 +118,9 @@ elif [[ $PY_VERSION == *"3.10"* ]]; then fi pip install "gymnasium[atari,accept-rom-license]" +# install d4rl +pip install git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl + # ============================================================================================ # # ================================ PyTorch & TorchRL ========================================= # @@ -170,6 +173,15 @@ python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_ python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 # With batched environments +python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/dt.py \ + optim.pretrain_gradient_steps=55 \ + optim.updates_per_episode=3 \ + optim.warmup_steps=10 +python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/online_td.py \ + optim.pretrain_gradient_steps=55 \ + optim.updates_per_episode=3 \ + optim.warmup_steps=10 + python .circleci/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \ env.num_envs=1 \ env.device=cuda:0 \ From 2ec7b0fe92c40ced65385870a43b89577166fdf6 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 6 Jul 2023 16:57:12 +0100 Subject: [PATCH 062/104] fix tests --- test/test_cost.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 412901cf29f..c1aba9fbac7 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -5517,7 +5517,6 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est retain_graph=True, allow_unused=False, ) - print(advantage, gradient_mode, delay_value, td_est) @pytest.mark.parametrize( "td_est", @@ -6243,7 +6242,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) 
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) - module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) + module = TensorDictModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) actor = ProbabilisticActor( module=module, distribution_class=TanhNormal, @@ -6439,7 +6438,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) - module = SafeModule(net, in_keys=["observation"], out_keys=["param"]) + module = TensorDictModule(net, in_keys=["observation"], out_keys=["param"]) actor = ProbabilisticActor( module=module, distribution_class=TanhDelta, From 8b8f7b1d4ebd87e1c2a08241e8faea7ef92eaaa6 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 6 Jul 2023 17:01:13 +0100 Subject: [PATCH 063/104] fix tests --- .../linux_examples/scripts/run_all.sh | 36 +++++++++---------- test/test_cost.py | 4 ++- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/.circleci/unittest/linux_examples/scripts/run_all.sh b/.circleci/unittest/linux_examples/scripts/run_all.sh index d131a371142..a418df1e0f5 100755 --- a/.circleci/unittest/linux_examples/scripts/run_all.sh +++ b/.circleci/unittest/linux_examples/scripts/run_all.sh @@ -7,29 +7,17 @@ set -v # ================================ Init ============================================== # -if [[ $OSTYPE != 'darwin'* ]]; then - apt-get update && apt-get upgrade -y - apt-get install -y vim git wget - - apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev - apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2 - - if [ "${CU_VERSION:-}" == cpu ] ; then - # solves version `GLIBCXX_3.4.29' not found for tensorboard -# apt-get install -y gcc-4.9 - apt-get upgrade -y libstdc++6 - apt-get dist-upgrade -y - else - apt-get install -y g++ gcc - fi +apt-get update && apt-get upgrade -y +apt-get install -y vim git wget -fi +apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev +apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2 + +apt-get install -y g++ gcc this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -if [[ $OSTYPE != 'darwin'* ]]; then - # from cudagl docker image - cp $this_dir/10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json -fi +# from cudagl docker image +cp $this_dir/10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json # ==================================================================================== # @@ -80,6 +68,14 @@ printf "* Installing dependencies (except PyTorch)\n" echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" cat "${this_dir}/environment.yml" +export MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 +export DISPLAY=unix:0.0 +export MJLIB_PATH=$root_dir/.mujoco/mujoco-2.1.1/lib/libmujoco.so.2.1.1 +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$root_dir/.mujoco/mujoco210/bin +export SDL_VIDEODRIVER=dummy +export MUJOCO_GL=egl +export PYOPENGL_PLATFORM=egl + conda env config vars set MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 \ DISPLAY=unix:0.0 \ MJLIB_PATH=$root_dir/.mujoco/mujoco-2.1.1/lib/libmujoco.so.2.1.1 \ diff --git a/test/test_cost.py b/test/test_cost.py index c1aba9fbac7..8f210a483c9 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -6242,7 +6242,9 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = 
NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) - module = TensorDictModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) + module = TensorDictModule( + net, in_keys=["observation"], out_keys=["loc", "scale"] + ) actor = ProbabilisticActor( module=module, distribution_class=TanhNormal, From f49d07d14c255c7667b8ab04f75c4f2d74a998f6 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 6 Jul 2023 17:32:50 +0100 Subject: [PATCH 064/104] mesalib glew glfw libosmesa6-dev --- .circleci/unittest/linux_examples/scripts/run_all.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.circleci/unittest/linux_examples/scripts/run_all.sh b/.circleci/unittest/linux_examples/scripts/run_all.sh index a418df1e0f5..1bc0c2077ce 100755 --- a/.circleci/unittest/linux_examples/scripts/run_all.sh +++ b/.circleci/unittest/linux_examples/scripts/run_all.sh @@ -10,8 +10,9 @@ set -v apt-get update && apt-get upgrade -y apt-get install -y vim git wget -apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev +apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libosmesa6-dev apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2 +apt-get install -y mesalib glew glfw apt-get install -y g++ gcc From ff4c34af8a9acb987ab6d3de4f1e24f9c55e9dde Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 6 Jul 2023 17:43:12 +0100 Subject: [PATCH 065/104] libosmesa6-dev --- .circleci/unittest/linux_examples/scripts/run_all.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/unittest/linux_examples/scripts/run_all.sh b/.circleci/unittest/linux_examples/scripts/run_all.sh index 1bc0c2077ce..fdac4eac21a 100755 --- a/.circleci/unittest/linux_examples/scripts/run_all.sh +++ b/.circleci/unittest/linux_examples/scripts/run_all.sh @@ -12,7 +12,7 @@ apt-get install -y vim git wget apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libosmesa6-dev apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2 -apt-get install -y mesalib glew glfw +#apt-get install -y mesalib glew glfw apt-get install -y g++ gcc From 40024edb9ae747c22db69a869cea70ad357dbb64 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 6 Jul 2023 19:13:56 +0100 Subject: [PATCH 066/104] patchelf --- .circleci/unittest/linux_examples/scripts/run_all.sh | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.circleci/unittest/linux_examples/scripts/run_all.sh b/.circleci/unittest/linux_examples/scripts/run_all.sh index fdac4eac21a..4730bf9cc39 100755 --- a/.circleci/unittest/linux_examples/scripts/run_all.sh +++ b/.circleci/unittest/linux_examples/scripts/run_all.sh @@ -12,9 +12,7 @@ apt-get install -y vim git wget apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libosmesa6-dev apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2 -#apt-get install -y mesalib glew glfw - -apt-get install -y g++ gcc +apt-get install -y g++ gcc patchelf this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" # from cudagl docker image From dffe5fc812c7d00c43eb17ffceb181a6e088b755 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 6 Jul 2023 20:31:34 +0100 Subject: [PATCH 067/104] temp hiding --- .../unittest/linux_examples/scripts/run_all.sh | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/.circleci/unittest/linux_examples/scripts/run_all.sh b/.circleci/unittest/linux_examples/scripts/run_all.sh index 4730bf9cc39..07b1eaca6ce 100755 --- a/.circleci/unittest/linux_examples/scripts/run_all.sh +++ 
b/.circleci/unittest/linux_examples/scripts/run_all.sh @@ -167,15 +167,15 @@ export MKL_THREADING_LAYER=GNU python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200 python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 -# With batched environments -python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/dt.py \ - optim.pretrain_gradient_steps=55 \ - optim.updates_per_episode=3 \ - optim.warmup_steps=10 -python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/online_td.py \ - optim.pretrain_gradient_steps=55 \ - optim.updates_per_episode=3 \ - optim.warmup_steps=10 +## With batched environments +#python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/dt.py \ +# optim.pretrain_gradient_steps=55 \ +# optim.updates_per_episode=3 \ +# optim.warmup_steps=10 +#python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/online_td.py \ +# optim.pretrain_gradient_steps=55 \ +# optim.updates_per_episode=3 \ +# optim.warmup_steps=10 python .circleci/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \ env.num_envs=1 \ From 81d9b3422e6ab25588a54b2da3dd5516aa835630 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 7 Jul 2023 11:20:47 +0100 Subject: [PATCH 068/104] amend --- .../unittest/linux_examples/scripts/run_all.sh | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/.circleci/unittest/linux_examples/scripts/run_all.sh b/.circleci/unittest/linux_examples/scripts/run_all.sh index 07b1eaca6ce..4730bf9cc39 100755 --- a/.circleci/unittest/linux_examples/scripts/run_all.sh +++ b/.circleci/unittest/linux_examples/scripts/run_all.sh @@ -167,15 +167,15 @@ export MKL_THREADING_LAYER=GNU python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200 python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 -## With batched environments -#python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/dt.py \ -# optim.pretrain_gradient_steps=55 \ -# optim.updates_per_episode=3 \ -# optim.warmup_steps=10 -#python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/online_td.py \ -# optim.pretrain_gradient_steps=55 \ -# optim.updates_per_episode=3 \ -# optim.warmup_steps=10 +# With batched environments +python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/dt.py \ + optim.pretrain_gradient_steps=55 \ + optim.updates_per_episode=3 \ + optim.warmup_steps=10 +python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/online_td.py \ + optim.pretrain_gradient_steps=55 \ + optim.updates_per_episode=3 \ + optim.warmup_steps=10 python .circleci/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \ env.num_envs=1 \ From c75eb39ba74db76d8adf75496fab02d04af1d464 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 7 Jul 2023 12:16:10 +0100 Subject: [PATCH 069/104] amend --- .../unittest/linux_examples/scripts/run_all.sh | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/.circleci/unittest/linux_examples/scripts/run_all.sh b/.circleci/unittest/linux_examples/scripts/run_all.sh index 4730bf9cc39..d70e78e61a7 100755 --- a/.circleci/unittest/linux_examples/scripts/run_all.sh +++ 
b/.circleci/unittest/linux_examples/scripts/run_all.sh @@ -56,8 +56,8 @@ conda activate "${env_dir}" printf "* Installing mujoco and related\n" mkdir -p $root_dir/.mujoco cd $root_dir/.mujoco/ -wget https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-x86_64.tar.gz -tar -xf mujoco-2.1.1-linux-x86_64.tar.gz +#wget https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-x86_64.tar.gz +#tar -xf mujoco-2.1.1-linux-x86_64.tar.gz wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz tar -xf mujoco210-linux-x86_64.tar.gz cd "${root_dir}" @@ -69,7 +69,7 @@ cat "${this_dir}/environment.yml" export MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 export DISPLAY=unix:0.0 -export MJLIB_PATH=$root_dir/.mujoco/mujoco-2.1.1/lib/libmujoco.so.2.1.1 +#export MJLIB_PATH=$root_dir/.mujoco/mujoco-2.1.1/lib/libmujoco.so.2.1.1 export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$root_dir/.mujoco/mujoco210/bin export SDL_VIDEODRIVER=dummy export MUJOCO_GL=egl @@ -77,7 +77,7 @@ export PYOPENGL_PLATFORM=egl conda env config vars set MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 \ DISPLAY=unix:0.0 \ - MJLIB_PATH=$root_dir/.mujoco/mujoco-2.1.1/lib/libmujoco.so.2.1.1 \ +# MJLIB_PATH=$root_dir/.mujoco/mujoco-2.1.1/lib/libmujoco.so.2.1.1 \ LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$root_dir/.mujoco/mujoco210/bin \ SDL_VIDEODRIVER=dummy \ MUJOCO_GL=egl \ @@ -90,6 +90,12 @@ conda env update --file "${this_dir}/environment.yml" --prune conda deactivate conda activate "${env_dir}" +# install d4rl +pip install free-mujoco-py +pip install git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl + +python -c """import gym;import d4rl""" + # install ale-py: manylinux names are broken for CentOS so we need to manually download and # rename them PY_VERSION=$(python --version) @@ -113,9 +119,6 @@ elif [[ $PY_VERSION == *"3.10"* ]]; then fi pip install "gymnasium[atari,accept-rom-license]" -# install d4rl -pip install git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl - # ============================================================================================ # # ================================ PyTorch & TorchRL ========================================= # From 540d82b629dfe605e451497aadb894488e8adade Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 7 Jul 2023 12:27:12 +0100 Subject: [PATCH 070/104] amend --- .circleci/unittest/linux_examples/scripts/run_all.sh | 6 ++++++ .circleci/unittest/linux_libs/scripts_habitat/run_test.sh | 3 +-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/.circleci/unittest/linux_examples/scripts/run_all.sh b/.circleci/unittest/linux_examples/scripts/run_all.sh index d70e78e61a7..af94d1d234e 100755 --- a/.circleci/unittest/linux_examples/scripts/run_all.sh +++ b/.circleci/unittest/linux_examples/scripts/run_all.sh @@ -94,6 +94,12 @@ conda activate "${env_dir}" pip install free-mujoco-py pip install git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl +# TODO: move this down -- will break torchrl installation +conda install -y -c conda-forge libstdcxx-ng=12 +## find libstdc +STDC_LOC=$(find conda/ -name "libstdc++.so.6" | head -1) +conda env config vars set LD_PRELOAD=$LD_PRELOAD:$STDC_LOC + python -c """import gym;import d4rl""" # install ale-py: manylinux names are broken for CentOS so we need to manually download and diff --git a/.circleci/unittest/linux_libs/scripts_habitat/run_test.sh b/.circleci/unittest/linux_libs/scripts_habitat/run_test.sh index 1f916fea9c1..1c2f7e19cb0 100755 --- 
a/.circleci/unittest/linux_libs/scripts_habitat/run_test.sh +++ b/.circleci/unittest/linux_libs/scripts_habitat/run_test.sh @@ -10,10 +10,9 @@ conda activate ./env # https://stackoverflow.com/questions/72540359/glibcxx-3-4-30-not-found-for-librosa-in-conda-virtual-environment-after-tryin #conda install -y -c conda-forge gcc=12.1.0 conda install -y -c conda-forge libstdcxx-ng=12 -conda env config vars set LD_PRELOAD=$LD_PRELOAD:$STDC_LOC - ## find libstdc STDC_LOC=$(find conda/ -name "libstdc++.so.6" | head -1) +conda env config vars set LD_PRELOAD=$LD_PRELOAD:$STDC_LOC export PYTORCH_TEST_WITH_SLOW='1' python -m torch.utils.collect_env From 091a119e47b86b52f98368f4e3d91d0803a16461 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 7 Jul 2023 13:32:37 +0100 Subject: [PATCH 071/104] amend --- .circleci/unittest/linux_examples/scripts/run_all.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/unittest/linux_examples/scripts/run_all.sh b/.circleci/unittest/linux_examples/scripts/run_all.sh index af94d1d234e..4d34d3144c5 100755 --- a/.circleci/unittest/linux_examples/scripts/run_all.sh +++ b/.circleci/unittest/linux_examples/scripts/run_all.sh @@ -98,7 +98,7 @@ pip install git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl conda install -y -c conda-forge libstdcxx-ng=12 ## find libstdc STDC_LOC=$(find conda/ -name "libstdc++.so.6" | head -1) -conda env config vars set LD_PRELOAD=$LD_PRELOAD:$STDC_LOC +conda env config vars set LD_PRELOAD=$STDC_LOC python -c """import gym;import d4rl""" From 87866a718693d3a206c8287d84108c3325f48c3d Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 7 Jul 2023 14:08:30 +0100 Subject: [PATCH 072/104] amend --- .circleci/unittest/linux_examples/scripts/run_all.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/.circleci/unittest/linux_examples/scripts/run_all.sh b/.circleci/unittest/linux_examples/scripts/run_all.sh index 4d34d3144c5..1e17c02ccc5 100755 --- a/.circleci/unittest/linux_examples/scripts/run_all.sh +++ b/.circleci/unittest/linux_examples/scripts/run_all.sh @@ -100,6 +100,7 @@ conda install -y -c conda-forge libstdcxx-ng=12 STDC_LOC=$(find conda/ -name "libstdc++.so.6" | head -1) conda env config vars set LD_PRELOAD=$STDC_LOC +# compile mujoco-py (bc it's done at runtime for whatever reason someone thought it was a good idea) python -c """import gym;import d4rl""" # install ale-py: manylinux names are broken for CentOS so we need to manually download and From 24c129fa70801ea9cdd93113c9c8a652c5e923df Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 7 Jul 2023 14:53:44 +0100 Subject: [PATCH 073/104] amend --- .circleci/unittest/linux_examples/scripts/run_test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh index 4c84b220552..218c26ae6a1 100755 --- a/.circleci/unittest/linux_examples/scripts/run_test.sh +++ b/.circleci/unittest/linux_examples/scripts/run_test.sh @@ -26,7 +26,7 @@ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir export MKL_THREADING_LAYER=GNU python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200 -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 +#python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 # With batched environments python .circleci/unittest/helpers/coverage_run_parallel.py 
examples/decision_transformer/dt.py \ From ad8d4125777d6689cac96e22d8193eabeac5ed88 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 7 Jul 2023 16:22:47 +0100 Subject: [PATCH 074/104] empty From 4a6471691a7a92f766623dcfa434970ab33f7bee Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 7 Jul 2023 17:23:40 +0100 Subject: [PATCH 075/104] fix wandb --- .circleci/unittest/linux_examples/scripts/environment.yml | 2 -- .circleci/unittest/linux_examples/scripts/run_all.sh | 3 +-- .circleci/unittest/linux_examples/scripts/run_test.sh | 6 ++++-- examples/decision_transformer/dt.py | 8 ++++---- examples/decision_transformer/online_dt.py | 7 ++++--- examples/decision_transformer/utils.py | 2 ++ 6 files changed, 15 insertions(+), 13 deletions(-) diff --git a/.circleci/unittest/linux_examples/scripts/environment.yml b/.circleci/unittest/linux_examples/scripts/environment.yml index 7a91696ca46..0b58952e0e8 100644 --- a/.circleci/unittest/linux_examples/scripts/environment.yml +++ b/.circleci/unittest/linux_examples/scripts/environment.yml @@ -20,9 +20,7 @@ dependencies: - pyyaml - scipy - hydra-core - - tensorboard - imageio==2.26.0 - - wandb - dm_control - mlflow - av diff --git a/.circleci/unittest/linux_examples/scripts/run_all.sh b/.circleci/unittest/linux_examples/scripts/run_all.sh index 6e76424a986..8d47bef9998 100755 --- a/.circleci/unittest/linux_examples/scripts/run_all.sh +++ b/.circleci/unittest/linux_examples/scripts/run_all.sh @@ -77,7 +77,6 @@ export PYOPENGL_PLATFORM=egl conda env config vars set MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 \ DISPLAY=unix:0.0 \ -# MJLIB_PATH=$root_dir/.mujoco/mujoco-2.1.1/lib/libmujoco.so.2.1.1 \ LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$root_dir/.mujoco/mujoco210/bin \ SDL_VIDEODRIVER=dummy \ MUJOCO_GL=egl \ @@ -98,7 +97,7 @@ pip install git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl conda install -y -c conda-forge libstdcxx-ng=12 ## find libstdc STDC_LOC=$(find conda/ -name "libstdc++.so.6" | head -1) -conda env config vars set LD_PRELOAD=$STDC_LOC +conda env config vars set LD_PRELOAD=${root_dir}/$STDC_LOC # compile mujoco-py (bc it's done at runtime for whatever reason someone thought it was a good idea) python -c """import gym;import d4rl""" diff --git a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh index 218c26ae6a1..df82875e520 100755 --- a/.circleci/unittest/linux_examples/scripts/run_test.sh +++ b/.circleci/unittest/linux_examples/scripts/run_test.sh @@ -32,11 +32,13 @@ python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_ python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/dt.py \ optim.pretrain_gradient_steps=55 \ optim.updates_per_episode=3 \ - optim.warmup_steps=10 + optim.warmup_steps=10 \ + logger.backend= python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/online_td.py \ optim.pretrain_gradient_steps=55 \ optim.updates_per_episode=3 \ - optim.warmup_steps=10 + optim.warmup_steps=10 \ + logger.backend= python .circleci/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \ env.num_envs=1 \ env.device=cuda:0 \ diff --git a/examples/decision_transformer/dt.py b/examples/decision_transformer/dt.py index 9004925afde..674ad4ffc50 100644 --- a/examples/decision_transformer/dt.py +++ b/examples/decision_transformer/dt.py @@ -26,7 +26,6 @@ @hydra.main(config_path=".", config_name="dt_config") def main(cfg: "DictConfig"): # noqa: F821 model_device = cfg.optim.device - 
logger = make_logger(cfg) offline_buffer, obs_loc, obs_std = make_offline_replay_buffer( cfg.replay_buffer, cfg.env.reward_scaling @@ -84,10 +83,11 @@ def main(cfg: "DictConfig"): # noqa: F821 if l0 is None: l0 = transformer_loss.item() - for key, value in loss_vals.items(): - logger.log_scalar(key, value.item(), i) eval_reward = eval_td["next", "reward"].sum(1).mean().item() / reward_scaling - logger.log_scalar("evaluation reward", eval_reward, i) + if logger is not None: + for key, value in loss_vals.items(): + logger.log_scalar(key, value.item(), i) + logger.log_scalar("evaluation reward", eval_reward, i) pbar.set_description( f"[Pre-Training] loss: {transformer_loss.item(): 4.4f} (init: {l0: 4.4f}), evaluation reward: {eval_reward: 4.4f} (init={r0: 4.4f})" diff --git a/examples/decision_transformer/online_dt.py b/examples/decision_transformer/online_dt.py index 5dff3390371..d67ef5bb166 100644 --- a/examples/decision_transformer/online_dt.py +++ b/examples/decision_transformer/online_dt.py @@ -92,10 +92,11 @@ def main(cfg: "DictConfig"): # noqa: F821 if l0 is None: l0 = transformer_loss.item() - for key, value in loss_vals.items(): - logger.log_scalar(key, value.item(), i) eval_reward = eval_td["next", "reward"].sum(1).mean().item() / reward_scaling - logger.log_scalar("evaluation reward", eval_reward, i) + if logger is not None: + for key, value in loss_vals.items(): + logger.log_scalar(key, value.item(), i) + logger.log_scalar("evaluation reward", eval_reward, i) pbar.set_description( f"[Pre-Training] loss: {transformer_loss.item(): 4.4f} (init: {l0: 4.4f}), evaluation reward: {eval_reward: 4.4f} (init={r0: 4.4f})" diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index 4e3d3573f82..179275716cb 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -456,6 +456,8 @@ def make_dt_optimizer(optim_cfg, actor_network): def make_logger(cfg): + if not cfg.logger.backend: + return None exp_name = generate_exp_name(cfg.logger.model_name, cfg.logger.exp_name) cfg.logger.exp_name = exp_name logger = get_logger( From edaa7b539d322a57c93a0d895849451e02b7de93 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 7 Jul 2023 17:58:48 +0100 Subject: [PATCH 076/104] lint --- examples/decision_transformer/dt.py | 3 +++ examples/decision_transformer/online_dt.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/examples/decision_transformer/dt.py b/examples/decision_transformer/dt.py index 674ad4ffc50..90b8faf5592 100644 --- a/examples/decision_transformer/dt.py +++ b/examples/decision_transformer/dt.py @@ -10,6 +10,8 @@ import hydra import torch import tqdm + +from torchrl.envs.libs.gym import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper @@ -23,6 +25,7 @@ ) +@set_gym_backend("gym") # D4RL uses gym so we make sure gymnasium is hidden @hydra.main(config_path=".", config_name="dt_config") def main(cfg: "DictConfig"): # noqa: F821 model_device = cfg.optim.device diff --git a/examples/decision_transformer/online_dt.py b/examples/decision_transformer/online_dt.py index d67ef5bb166..6dc4e83809c 100644 --- a/examples/decision_transformer/online_dt.py +++ b/examples/decision_transformer/online_dt.py @@ -10,6 +10,8 @@ import hydra import torch import tqdm + +from torchrl.envs.libs.gym import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type from 
torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper @@ -23,6 +25,7 @@ ) +@set_gym_backend("gym") # D4RL uses gym so we make sure gymnasium is hidden @hydra.main(config_path=".", config_name="odt_config") def main(cfg: "DictConfig"): # noqa: F821 model_device = cfg.optim.device From 8abb8f34287e2f8d59f3a3b940344866ed05bb49 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 7 Jul 2023 18:31:59 +0100 Subject: [PATCH 077/104] amend --- examples/decision_transformer/utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index 179275716cb..c27bb7e299c 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -25,6 +25,7 @@ UnsqueezeTransform, ) from torchrl.envs.libs.dm_control import DMControlEnv +from torchrl.envs.libs.gym import set_gym_backend from torchrl.envs.utils import set_exploration_mode from torchrl.modules import ( DTActor, @@ -127,8 +128,13 @@ def make_parallel_env(env_cfg, obs_loc, obs_std, train=False): num_envs = env_cfg.num_train_envs else: num_envs = env_cfg.num_eval_envs + + def make_env(): + with set_gym_backend("gym"): + return make_base_env(env_cfg) + env = make_transformed_env( - ParallelEnv(num_envs, EnvCreator(lambda: make_base_env(env_cfg))), + ParallelEnv(num_envs, EnvCreator(make_env)), env_cfg, obs_loc, obs_std, From dfcff63d8b6a589606a2c686ec32b408689959eb Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 7 Jul 2023 21:53:52 +0100 Subject: [PATCH 078/104] amend --- .../linux_examples/scripts/run_all.sh | 23 -------------- .../linux_examples/scripts/run_test.sh | 30 +++++++++++++++++++ 2 files changed, 30 insertions(+), 23 deletions(-) diff --git a/.circleci/unittest/linux_examples/scripts/run_all.sh b/.circleci/unittest/linux_examples/scripts/run_all.sh index 8d47bef9998..5efa1976509 100755 --- a/.circleci/unittest/linux_examples/scripts/run_all.sh +++ b/.circleci/unittest/linux_examples/scripts/run_all.sh @@ -102,29 +102,6 @@ conda env config vars set LD_PRELOAD=${root_dir}/$STDC_LOC # compile mujoco-py (bc it's done at runtime for whatever reason someone thought it was a good idea) python -c """import gym;import d4rl""" -# install ale-py: manylinux names are broken for CentOS so we need to manually download and -# rename them -PY_VERSION=$(python --version) -if [[ $PY_VERSION == *"3.7"* ]]; then - wget https://files.pythonhosted.org/packages/ab/fd/6615982d9460df7f476cad265af1378057eee9daaa8e0026de4cedbaffbd/ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pip install ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - rm ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -elif [[ $PY_VERSION == *"3.8"* ]]; then - wget https://files.pythonhosted.org/packages/0f/8a/feed20571a697588bc4bfef05d6a487429c84f31406a52f8af295a0346a2/ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pip install ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - rm ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -elif [[ $PY_VERSION == *"3.9"* ]]; then - wget https://files.pythonhosted.org/packages/a0/98/4316c1cedd9934f9a91b6e27a9be126043b4445594b40cfa391c8de2e5e8/ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pip install ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - rm ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -elif 
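The hunks above switch the D4RL-facing entry points to the gym backend twice: once as a decorator on main() and once as a context manager inside the env factory. A minimal, self-contained sketch of the factory pattern (using Pendulum-v1 as a stand-in for the configured environment): the backend switch lives inside the factory so that it also takes effect in the subprocesses spawned by ParallelEnv.

    from torchrl.envs import EnvCreator, ParallelEnv
    from torchrl.envs.libs.gym import GymEnv, set_gym_backend

    def make_env():
        # the backend is set inside the factory so each ParallelEnv worker
        # builds its environment under gym rather than gymnasium
        with set_gym_backend("gym"):
            return GymEnv("Pendulum-v1")

    env = ParallelEnv(2, EnvCreator(make_env))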
[[ $PY_VERSION == *"3.10"* ]]; then - wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - mv ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pip install ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - rm ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -fi -pip install "gymnasium[atari,accept-rom-license]" - # ============================================================================================ # # ================================ PyTorch & TorchRL ========================================= # diff --git a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh index df82875e520..11a14bb9a61 100755 --- a/.circleci/unittest/linux_examples/scripts/run_test.sh +++ b/.circleci/unittest/linux_examples/scripts/run_test.sh @@ -28,6 +28,9 @@ export MKL_THREADING_LAYER=GNU python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200 #python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 +# ==================================================================================== # +# ================================ gym 0.23 ========================================== # + # With batched environments python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/dt.py \ optim.pretrain_gradient_steps=55 \ @@ -39,6 +42,33 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_tra optim.updates_per_episode=3 \ optim.warmup_steps=10 \ logger.backend= + +# ==================================================================================== # +# ================================ Gymnasium ========================================= # + +# install ale-py: manylinux names are broken for CentOS so we need to manually download and +# rename them +PY_VERSION=$(python --version) +if [[ $PY_VERSION == *"3.7"* ]]; then + wget https://files.pythonhosted.org/packages/ab/fd/6615982d9460df7f476cad265af1378057eee9daaa8e0026de4cedbaffbd/ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + pip install ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + rm ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +elif [[ $PY_VERSION == *"3.8"* ]]; then + wget https://files.pythonhosted.org/packages/0f/8a/feed20571a697588bc4bfef05d6a487429c84f31406a52f8af295a0346a2/ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + pip install ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + rm ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +elif [[ $PY_VERSION == *"3.9"* ]]; then + wget https://files.pythonhosted.org/packages/a0/98/4316c1cedd9934f9a91b6e27a9be126043b4445594b40cfa391c8de2e5e8/ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + pip install ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + rm ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +elif [[ $PY_VERSION == *"3.10"* ]]; then + wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl 
+ mv ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + pip install ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + rm ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +fi +pip install "gymnasium[atari,accept-rom-license]" + python .circleci/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \ env.num_envs=1 \ env.device=cuda:0 \ From 395456c04c41d0977697386480df07d31f61fe1a Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 7 Jul 2023 22:05:27 +0100 Subject: [PATCH 079/104] amend --- .circleci/unittest/linux_examples/scripts/environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.circleci/unittest/linux_examples/scripts/environment.yml b/.circleci/unittest/linux_examples/scripts/environment.yml index 0b58952e0e8..fde5980bad1 100644 --- a/.circleci/unittest/linux_examples/scripts/environment.yml +++ b/.circleci/unittest/linux_examples/scripts/environment.yml @@ -25,3 +25,4 @@ dependencies: - mlflow - av - coverage + - transformers From d58675eb88f75f331ae45c7cff0c58e97e1652d1 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 7 Jul 2023 22:25:42 +0100 Subject: [PATCH 080/104] amend --- examples/decision_transformer/dt.py | 1 - examples/decision_transformer/online_dt.py | 1 - torchrl/modules/tensordict_module/actors.py | 28 +++++++++++++++++---- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/examples/decision_transformer/dt.py b/examples/decision_transformer/dt.py index 90b8faf5592..3f9ca9c02e3 100644 --- a/examples/decision_transformer/dt.py +++ b/examples/decision_transformer/dt.py @@ -41,7 +41,6 @@ def main(cfg: "DictConfig"): # noqa: F821 transformer_optim, scheduler = make_dt_optimizer(cfg.optim, policy) inference_policy = DecisionTransformerInferenceWrapper( policy=policy, - loss_module=loss_module, inference_context=cfg.env.inference_context, ).to(model_device) diff --git a/examples/decision_transformer/online_dt.py b/examples/decision_transformer/online_dt.py index 6dc4e83809c..07b5e298fe5 100644 --- a/examples/decision_transformer/online_dt.py +++ b/examples/decision_transformer/online_dt.py @@ -45,7 +45,6 @@ def main(cfg: "DictConfig"): # noqa: F821 ) inference_policy = DecisionTransformerInferenceWrapper( policy=policy, - loss_module=loss_module, inference_context=cfg.env.inference_context, ).to(model_device) diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index c67ffac7843..e2b29f88251 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -1624,7 +1624,6 @@ class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper): observations and produces an action value Keyword Args: - loss_module (TensorDictModule): The loss module that computes the DT loss to receive the input keys of the model. inference_context (int): The number of previous actions that will not be masked in the context. For example for an observation input of shape [batch_size, context, obs_dim] with context=20 and inference_context=5, the first 15 entries of the context will be masked. Defaults to 5. 
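The inference_context rule described in the docstring above can be made concrete with a small standalone sketch (toy tensors only, not the wrapper's internal code): with a context of 20 and inference_context=5, everything but the trailing 5 steps of the action history is zeroed before the forward pass.

    import torch

    context, inference_context = 20, 5
    action = torch.randn(1, context, 2)  # [batch_size, context, action_dim]
    masked = action.clone()
    # zero every step except the trailing `inference_context` ones
    masked[..., :-inference_context, :] = 0
    assert (masked[..., : context - inference_context, :] == 0).all()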
@@ -1682,14 +1681,13 @@ def __init__( self, policy: TensorDictModule, *, - loss_module: TensorDictModule, inference_context: int = 5, spec: Optional[TensorSpec] = None, ): super().__init__(policy) - self.observation_key = loss_module.tensor_keys.observation - self.action_key = loss_module.tensor_keys.action - self.return_to_go_key = loss_module.tensor_keys.return_to_go + self.observation_key = "observation" + self.action_key = "action" + self.return_to_go_key = "return_to_go" self.inference_context = inference_context if spec is not None: if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: @@ -1706,6 +1704,26 @@ def __init__( else: self._spec = CompositeSpec({key: None for key in policy.out_keys}) + def set_tensor_keys(self, **kwargs): + """Sets the input keys of the module. + + Keyword Args: + observation (NestedKey, optional): The observation key. + action (NestedKey, optional): The action key. + return_to_go (NestedKey, optional): The return_to_go key. + + """ + observation_key = kwargs.pop("observation", None) + action_key = kwargs.pop("action", None) + return_to_go_key = kwargs.pop("return_to_go", None) + if kwargs: + raise TypeError( + f"Got unknown input(s) {kwargs.keys()}. Accepted keys are 'action', 'return_to_go' and 'observation'." + ) + self.observation_key = observation_key + self.action_key = action_key + self.return_to_go_key = return_to_go_key + def step(self, frames: int = 1) -> None: pass From 244e42973402f67fdb71ba232484b33df1a954e2 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 7 Jul 2023 22:42:48 +0100 Subject: [PATCH 081/104] amend --- torchrl/modules/__init__.py | 1 + torchrl/modules/models/models.py | 62 ++++++++++++--------- torchrl/modules/tensordict_module/actors.py | 30 ++++++---- 3 files changed, 56 insertions(+), 37 deletions(-) diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 346a65348ad..0f16362cb1a 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -43,6 +43,7 @@ ActorCriticWrapper, ActorValueOperator, AdditiveGaussianWrapper, + DecisionTransformerInferenceWrapper, DistributionalQValueActor, DistributionalQValueHook, DistributionalQValueModule, diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 6ad68afb831..f3ec1d247d6 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -1153,21 +1153,12 @@ class OnlineDTActor(nn.Module): device (Optional[DEVICE_TYPING], optional): device to use. Defaults to None. Examples: - >>> config = { - ... "n_embd": 256, - ... "n_layer": 4, - ... "n_head": 4, - ... "n_inner": 1024, - ... "activation": "relu", - ... "n_positions": 1024, - ... "resid_pdrop": 0.1, - ... "attn_pdrop": 0.1, - ... } - >>> model = OnlineDTActor(state_dim=4, action_dim=2, config=config) + >>> model = OnlineDTActor(state_dim=4, action_dim=2, + ... 
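The new set_tensor_keys method accepts only the three names listed in its docstring and raises a TypeError for anything else. A hedged usage sketch, assuming inference_actor is an already-built DecisionTransformerInferenceWrapper and that the custom key names ("obs", "act", "rtg") come from the surrounding pipeline:

    # remap the wrapper's tensordict keys; all three names are passed explicitly
    inference_actor.set_tensor_keys(
        observation="obs",
        action="act",
        return_to_go="rtg",
    )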
transformer_config=OnlineDTActor.get_default_config()) >>> observation = torch.randn(32, 10, 4) >>> action = torch.randn(32, 10, 2) >>> return_to_go = torch.randn(32, 10, 1) - >>> (mu, std) = model(observation, action, return_to_go) + >>> mu, std = model(observation, action, return_to_go) >>> mu.shape torch.Size([32, 10, 2]) >>> std.shape @@ -1188,7 +1179,7 @@ def __init__( config=transformer_config, ) self.action_layer = nn.Linear( - transformer_config.n_embd, action_dim * 2, device=device + transformer_config['n_embd'], action_dim * 2, device=device ) self.log_std_min, self.log_std_max = -5.0, 2.0 @@ -1219,7 +1210,21 @@ def forward( ) std = log_std.exp() - return (mu, std) + return mu, std + + @classmethod + def get_default_config(cls): + """Default configuration for :class:`~.OnlineDTActor`""" + return { + "n_embd": 256, + "n_layer": 4, + "n_head": 4, + "n_inner": 1024, + "activation": "relu", + "n_positions": 1024, + "resid_pdrop": 0.1, + "attn_pdrop": 0.1, + } class DTActor(nn.Module): @@ -1235,17 +1240,8 @@ class DTActor(nn.Module): device (Optional[DEVICE_TYPING], optional): device to use. Defaults to None. Examples: - >>> config = { - ... "n_embd": 256, - ... "n_layer": 4, - ... "n_head": 4, - ... "n_inner": 1024, - ... "activation": "relu", - ... "n_positions": 1024, - ... "resid_pdrop": 0.1, - ... "attn_pdrop": 0.1, - ... } - >>> model = DTActor(state_dim=4, action_dim=2, config=config) + >>> model = DTActor(state_dim=4, action_dim=2, + ... transformer_config=DTActor.get_default_config()) >>> observation = torch.randn(32, 10, 4) >>> action = torch.randn(32, 10, 2) >>> return_to_go = torch.randn(32, 10, 1) @@ -1269,7 +1265,7 @@ def __init__( config=transformer_config, ) self.action_layer = nn.Linear( - transformer_config.n_embd, action_dim, device=device + transformer_config['n_embd'], action_dim, device=device ) def weight_init(m): @@ -1290,3 +1286,17 @@ def forward( hidden_state = self.transformer(observation, action, return_to_go) out = self.action_layer(hidden_state) return out + + @classmethod + def get_default_config(cls): + """Default configuration for :class:`~.DTActor`""" + return { + "n_embd": 256, + "n_layer": 4, + "n_head": 4, + "n_inner": 1024, + "activation": "relu", + "n_positions": 1024, + "resid_pdrop": 0.1, + "attn_pdrop": 0.1, + } diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index e2b29f88251..8225bdeec07 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -1635,35 +1635,36 @@ class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper): >>> from tensordict.nn import TensorDictModule >>> from torchrl.modules import ( ... ProbabilisticActor, + ... TanhDelta, ... DTActor, - ... TanhNormal, ... DecisionTransformerInferenceWrapper, ... ) + >>> dtactor = DTActor(state_dim=4, action_dim=2, + ... transformer_config=DTActor.get_default_config() + ... ) >>> actor_module = TensorDictModule( - ... DTActor(state_dim=4, action_dim=2), - ... in_keys=in_keys, - ... out_keys=[ - ... "loc", - ... "scale",]) - >>> dist_class = TanhNormal + ... nn.dtactor, + ... in_keys=["observation", "action", "return_to_go"], + ... out_keys=["param"]) + >>> dist_class = TanhDelta >>> dist_kwargs = { ... "min": -1.0, ... "max": 1.0, ... "tanh_loc": False, ... } >>> actor = ProbabilisticActor( - ... in_keys=["loc", "scale"], - ... out_keys=["action", "log_prob"], + ... in_keys=["param"], + ... out_keys=["action"], ... module=actor_module, ... distribution_class=dist_class, ... 
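Since get_default_config returns a plain dict, it can presumably be copied and tweaked before constructing the actor. A small sketch under that assumption; the modified entry, "n_layer", is one of the keys shown above.

    from torchrl.modules import DTActor  # exported from torchrl.modules in this PR

    config = DTActor.get_default_config()
    config["n_layer"] = 3  # shrink the transformer relative to the default of 4
    actor = DTActor(state_dim=4, action_dim=2, transformer_config=config)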
distribution_kwargs=dist_kwargs) >>> inference_actor = DecisionTransformerInferenceWrapper(actor) - >>> print(inference_actor) >>> sequence_length = 20 >>> td = TensorDict({"observation": torch.randn(1, sequence_length, 4), ... "action": torch.randn(1, sequence_length, 2), ... "return_to_go": torch.randn(1, sequence_length, 1)}, [1,]) - >>> print(inference_actor(td.clone())) + >>> result = inference_actor(td) + >>> print(result) TensorDict( fields={ action: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False), @@ -1704,6 +1705,13 @@ def __init__( else: self._spec = CompositeSpec({key: None for key in policy.out_keys}) + @property + def in_keys(self): + return [self.observation_key, self.action_key, self.return_to_go_key] + @property + def out_keys(self): + return [self.observation_key, self.action_key, self.return_to_go_key] + def set_tensor_keys(self, **kwargs): """Sets the input keys of the module. From c9338b303c1d9c94f4a37a2c5c542669005cd23c Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 7 Jul 2023 22:43:55 +0100 Subject: [PATCH 082/104] amend --- torchrl/modules/tensordict_module/actors.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 8225bdeec07..cf87065c0e5 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -1641,16 +1641,15 @@ class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper): ... ) >>> dtactor = DTActor(state_dim=4, action_dim=2, ... transformer_config=DTActor.get_default_config() - ... ) + ... ) >>> actor_module = TensorDictModule( - ... nn.dtactor, + ... dtactor, ... in_keys=["observation", "action", "return_to_go"], ... out_keys=["param"]) >>> dist_class = TanhDelta >>> dist_kwargs = { ... "min": -1.0, ... "max": 1.0, - ... "tanh_loc": False, ... } >>> actor = ProbabilisticActor( ... 
in_keys=["param"], @@ -1668,10 +1667,8 @@ class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper): TensorDict( fields={ action: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False), - loc: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False), observation: Tensor(shape=torch.Size([1, 20, 4]), device=cpu, dtype=torch.float32, is_shared=False), - sample_log_prob: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), - scale: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False), + param: Tensor(shape=torch.Size([1, 20, 2]), device=cpu, dtype=torch.float32, is_shared=False), return_to_go: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([1]), device=None, From cdadf46cf1f2ffdede1fbd68691eed88698a5150 Mon Sep 17 00:00:00 2001 From: Mateusz Guzek Date: Mon, 10 Jul 2023 15:59:37 +0200 Subject: [PATCH 083/104] Added list of D4RL datasets --- torchrl/data/datasets/d4rl.py | 205 ++++++++++++++++++++++++++++++++-- 1 file changed, 197 insertions(+), 8 deletions(-) diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index d6c32083e23..017fb31aa5b 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -93,6 +93,188 @@ class D4RLExperienceReplay(TensorDictReplayBuffer): """ + D4RL_DATASETS = { + "maze2d-open-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-sparse.hdf5", + "maze2d-umaze-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse-v1.hdf5", + "maze2d-medium-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse-v1.hdf5", + "maze2d-large-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-sparse-v1.hdf5", + "maze2d-eval-umaze-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-sparse-v1.hdf5", + "maze2d-eval-medium-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-sparse-v1.hdf5", + "maze2d-eval-large-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-sparse-v1.hdf5", + "maze2d-open-dense-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-dense.hdf5", + "maze2d-umaze-dense-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-dense-v1.hdf5", + "maze2d-medium-dense-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-dense-v1.hdf5", + "maze2d-large-dense-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-dense-v1.hdf5", + "maze2d-eval-umaze-dense-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-dense-v1.hdf5", + "maze2d-eval-medium-dense-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-dense-v1.hdf5", + "maze2d-eval-large-dense-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-dense-v1.hdf5", + "minigrid-fourrooms-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms.hdf5", + "minigrid-fourrooms-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms_random.hdf5", + "pen-human-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_demos_clipped.hdf5", + "pen-cloned-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-demos-v0-bc-combined.hdf5", + "pen-expert-v0": 
"http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_expert_clipped.hdf5", + "hammer-human-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_demos_clipped.hdf5", + "hammer-cloned-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-demos-v0-bc-combined.hdf5", + "hammer-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_expert_clipped.hdf5", + "relocate-human-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_demos_clipped.hdf5", + "relocate-cloned-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-demos-v0-bc-combined.hdf5", + "relocate-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_expert_clipped.hdf5", + "door-human-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_demos_clipped.hdf5", + "door-cloned-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-demos-v0-bc-combined.hdf5", + "door-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_expert_clipped.hdf5", + "halfcheetah-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_random.hdf5", + "halfcheetah-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium.hdf5", + "halfcheetah-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_expert.hdf5", + "halfcheetah-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_mixed.hdf5", + "halfcheetah-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium_expert.hdf5", + "walker2d-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_random.hdf5", + "walker2d-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium.hdf5", + "walker2d-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_expert.hdf5", + "walker2d-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker_mixed.hdf5", + "walker2d-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium_expert.hdf5", + "hopper-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_random.hdf5", + "hopper-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium.hdf5", + "hopper-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_expert.hdf5", + "hopper-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_mixed.hdf5", + "hopper-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium_expert.hdf5", + "ant-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random.hdf5", + "ant-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium.hdf5", + "ant-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_expert.hdf5", + "ant-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_mixed.hdf5", + "ant-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium_expert.hdf5", + "ant-random-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random_expert.hdf5", + "antmaze-umaze-v0": 
"http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse.hdf5", + "antmaze-umaze-diverse-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse.hdf5", + "antmaze-medium-play-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse.hdf5", + "antmaze-medium-diverse-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse.hdf5", + "antmaze-large-play-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse.hdf5", + "antmaze-large-diverse-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse.hdf5", + "antmaze-umaze-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse_fixed.hdf5", + "antmaze-umaze-diverse-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5", + "antmaze-medium-play-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse_fixed.hdf5", + "antmaze-medium-diverse-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5", + "antmaze-large-play-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse_fixed.hdf5", + "antmaze-large-diverse-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5", + "flow-ring-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-random.hdf5", + "flow-ring-controller-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-idm.hdf5", + "flow-merge-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-random.hdf5", + "flow-merge-controller-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-idm.hdf5", + "kitchen-complete-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/mini_kitchen_microwave_kettle_light_slider-v0.hdf5", + "kitchen-partial-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_light_slider-v0.hdf5", + "kitchen-mixed-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_bottomburner_light-v0.hdf5", + "carla-lane-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow_flat-v0.hdf5", + "carla-town-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_subsamp_flat-v0.hdf5", + "carla-town-full-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5", + "bullet-halfcheetah-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_random.hdf5", + "bullet-halfcheetah-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium.hdf5", + "bullet-halfcheetah-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_expert.hdf5", + "bullet-halfcheetah-medium-expert-v0": 
"http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_expert.hdf5", + "bullet-halfcheetah-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_replay.hdf5", + "bullet-hopper-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_random.hdf5", + "bullet-hopper-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium.hdf5", + "bullet-hopper-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_expert.hdf5", + "bullet-hopper-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_expert.hdf5", + "bullet-hopper-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_replay.hdf5", + "bullet-ant-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_random.hdf5", + "bullet-ant-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium.hdf5", + "bullet-ant-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_expert.hdf5", + "bullet-ant-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_expert.hdf5", + "bullet-ant-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_replay.hdf5", + "bullet-walker2d-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_random.hdf5", + "bullet-walker2d-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium.hdf5", + "bullet-walker2d-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_expert.hdf5", + "bullet-walker2d-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_expert.hdf5", + "bullet-walker2d-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_replay.hdf5", + "bullet-maze2d-open-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-open-sparse.hdf5", + "bullet-maze2d-umaze-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-umaze-sparse.hdf5", + "bullet-maze2d-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-medium-sparse.hdf5", + "bullet-maze2d-large-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-large-sparse.hdf5", + "halfcheetah-random-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/halfcheetah_random-v1.hdf5", + "halfcheetah-random-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/halfcheetah_random-v2.hdf5", + "halfcheetah-medium-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/halfcheetah_medium-v1.hdf5", + "halfcheetah-medium-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/halfcheetah_medium-v2.hdf5", + "halfcheetah-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/halfcheetah_expert-v1.hdf5", + "halfcheetah-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/halfcheetah_expert-v2.hdf5", + "halfcheetah-medium-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/halfcheetah_medium_replay-v1.hdf5", + "halfcheetah-medium-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/halfcheetah_medium_replay-v2.hdf5", + "halfcheetah-full-replay-v1": 
"http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/halfcheetah_full_replay-v1.hdf5", + "halfcheetah-full-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/halfcheetah_full_replay-v2.hdf5", + "halfcheetah-medium-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/halfcheetah_medium_expert-v1.hdf5", + "halfcheetah-medium-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/halfcheetah_medium_expert-v2.hdf5", + "hopper-random-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/hopper_random-v1.hdf5", + "hopper-random-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/hopper_random-v2.hdf5", + "hopper-medium-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/hopper_medium-v1.hdf5", + "hopper-medium-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/hopper_medium-v2.hdf5", + "hopper-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/hopper_expert-v1.hdf5", + "hopper-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/hopper_expert-v2.hdf5", + "hopper-medium-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/hopper_medium_replay-v1.hdf5", + "hopper-medium-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/hopper_medium_replay-v2.hdf5", + "hopper-full-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/hopper_full_replay-v1.hdf5", + "hopper-full-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/hopper_full_replay-v2.hdf5", + "hopper-medium-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/hopper_medium_expert-v1.hdf5", + "hopper-medium-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/hopper_medium_expert-v2.hdf5", + "walker2d-random-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/walker2d_random-v1.hdf5", + "walker2d-random-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/walker2d_random-v2.hdf5", + "walker2d-medium-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/walker2d_medium-v1.hdf5", + "walker2d-medium-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/walker2d_medium-v2.hdf5", + "walker2d-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/walker2d_expert-v1.hdf5", + "walker2d-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/walker2d_expert-v2.hdf5", + "walker2d-medium-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/walker2d_medium_replay-v1.hdf5", + "walker2d-medium-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/walker2d_medium_replay-v2.hdf5", + "walker2d-full-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/walker2d_full_replay-v1.hdf5", + "walker2d-full-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/walker2d_full_replay-v2.hdf5", + "walker2d-medium-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/walker2d_medium_expert-v1.hdf5", + "walker2d-medium-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/walker2d_medium_expert-v2.hdf5", + "ant-random-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/ant_random-v1.hdf5", + "ant-random-v2": 
"http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/ant_random-v2.hdf5", + "ant-medium-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/ant_medium-v1.hdf5", + "ant-medium-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/ant_medium-v2.hdf5", + "ant-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/ant_expert-v1.hdf5", + "ant-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/ant_expert-v2.hdf5", + "ant-medium-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/ant_medium_replay-v1.hdf5", + "ant-medium-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/ant_medium_replay-v2.hdf5", + "ant-full-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/ant_full_replay-v1.hdf5", + "ant-full-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/ant_full_replay-v2.hdf5", + "ant-medium-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/ant_medium_expert-v1.hdf5", + "ant-medium-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/ant_medium_expert-v2.hdf5", + "hammer-human-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/hammer-human-v1.hdf5", + "hammer-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/hammer-expert-v1.hdf5", + "hammer-cloned-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/hammer-cloned-v1.hdf5", + "pen-human-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/pen-human-v1.hdf5", + "pen-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/pen-expert-v1.hdf5", + "pen-cloned-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/pen-cloned-v1.hdf5", + "relocate-human-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/relocate-human-v1.hdf5", + "relocate-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/relocate-expert-v1.hdf5", + "relocate-cloned-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/relocate-cloned-v1.hdf5", + "door-human-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/door-human-v1.hdf5", + "door-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/door-expert-v1.hdf5", + "door-cloned-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/door-cloned-v1.hdf5", + "antmaze-umaze-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_umaze_noisy_multistart_False_multigoal_False_sparse.hdf5", + "antmaze-umaze-diverse-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_umaze_noisy_multistart_True_multigoal_True_sparse.hdf5", + "antmaze-medium-play-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_medium_noisy_multistart_True_multigoal_False_sparse.hdf5", + "antmaze-medium-diverse-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_medium_noisy_multistart_True_multigoal_True_sparse.hdf5", + "antmaze-large-diverse-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_large_noisy_multistart_True_multigoal_True_sparse.hdf5", + "antmaze-large-play-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_large_noisy_multistart_True_multigoal_False_sparse.hdf5", + "antmaze-eval-umaze-v0": 
"http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_umaze_eval_noisy_multistart_True_multigoal_False_sparse.hdf5", + "antmaze-eval-umaze-diverse-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_umaze_eval_noisy_multistart_True_multigoal_True_sparse.hdf5", + "antmaze-eval-medium-play-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_medium_eval_noisy_multistart_True_multigoal_True_sparse.hdf5", + "antmaze-eval-medium-diverse-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_medium_eval_noisy_multistart_True_multigoal_False_sparse.hdf5", + "antmaze-eval-large-diverse-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_large_eval_noisy_multistart_True_multigoal_False_sparse.hdf5", + "antmaze-eval-large-play-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_large_eval_noisy_multistart_True_multigoal_True_sparse.hdf5", + "door-human-longhorizon-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_demos_clipped.hdf5", + "hammer-human-longhorizon-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_demos_clipped.hdf5", + "pen-human-longhorizon-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_demos_clipped.hdf5", + "relocate-human-longhorizon-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_demos_clipped.hdf5", + "maze2d-umaze-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse.hdf5", + "maze2d-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse.hdf5", + "maze2d-large-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-sparse.hdf5", + "maze2d-umaze-dense-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-dense.hdf5", + "maze2d-medium-dense-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-dense.hdf5", + "maze2d-large-dense-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-dense.hdf5", + "carla-lane-render-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow-v0.hdf5", + "carla-town-render-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5", + } + D4RL_ERR = None @classmethod @@ -118,19 +300,23 @@ def __init__( split_trajs: bool = False, from_env: bool = True, use_timeout_as_done: bool = True, + direct_download: bool = False, **env_kwargs, ): - - type(self)._import_d4rl() - - if not self._has_d4rl: - raise ImportError("Could not import d4rl") from self.D4RL_ERR self.from_env = from_env self.use_timeout_as_done = use_timeout_as_done - if from_env: - dataset = self._get_dataset_from_env(name, env_kwargs) + if not direct_download: + type(self)._import_d4rl() + + if not self._has_d4rl: + raise ImportError("Could not import d4rl") from self.D4RL_ERR + + if from_env: + dataset = self._get_dataset_from_env(name, env_kwargs) + else: + dataset = self._get_dataset_direct(name, env_kwargs) else: - dataset = self._get_dataset_direct(name, env_kwargs) + dataset = self._get_dataset_direct_download(name, env_kwargs) # Fill unknown next states with 0 dataset["next", "observation"][dataset["next", "done"].squeeze()] = 0 @@ -149,6 +335,9 @@ def __init__( ) self.extend(dataset) + def _get_dataset_direct_download(self, name, env_kwargs): + """Directly download and use a D4RL dataset.""" + def _get_dataset_direct(self, name, env_kwargs): 
from torchrl.envs.libs.gym import GymWrapper From 311d00d140be88e4c7c0a109801e860e787d395d Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 10 Jul 2023 14:45:27 -0400 Subject: [PATCH 084/104] minor --- torchrl/modules/tensordict_module/actors.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index cf87065c0e5..6cfefb70972 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -1750,8 +1750,8 @@ def mask_context(self, tensordict: TensorDictBase) -> TensorDictBase: action[..., : -self.inference_context, :] = 0 action = torch.cat( [ - action[:, 1:], - torch.zeros(action.shape[0], 1, action.shape[-1], device=action.device), + action[..., 1:, :], + torch.zeros(*action.shape[:-2], 1, action.shape[-1], device=action.device), ], dim=-2, ) @@ -1764,7 +1764,7 @@ def mask_context(self, tensordict: TensorDictBase) -> TensorDictBase: def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Forward pass of the inference wrapper.""" - unmasked_tensordict = tensordict.clone() + unmasked_tensordict = tensordict.clone(False) # Mask the context of the input sequences tensordict = self.mask_context(tensordict) # forward pass From 587cff66146e894d059042ee4a8b9a84dab2dfa5 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 10 Jul 2023 15:30:57 -0400 Subject: [PATCH 085/104] amend --- torchrl/data/datasets/d4rl.py | 83 ++++++++++++++++++++++++++++++----- 1 file changed, 72 insertions(+), 11 deletions(-) diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index 017fb31aa5b..c7a0918badc 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -2,13 +2,17 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - +import os +import urllib +from pathlib import Path from typing import Callable, Optional import numpy as np import torch -from tensordict.tensordict import make_tensordict + +from tensordict import PersistentTensorDict +from tensordict.tensordict import make_tensordict, TensorDict from torchrl.collectors.utils import split_trajectories from torchrl.data.replay_buffers import TensorDictReplayBuffer @@ -75,7 +79,8 @@ class D4RLExperienceReplay(TensorDictReplayBuffer): differ. In particular, the ``"timeout"`` key (used to determine the end of an episode) may be absent when ``from_env=False`` but present otherwise, leading to a different slicing when ``traj_splits`` is enabled. - + direct_download (bool): if ``True`` (default), the data will be downloaded without + requiring D4RL. This is not compatible with ``from_env=True``. use_timeout_as_done (bool, optional): if ``True``, ``done = terminal | timeout``. Otherwise, only the ``terminal`` key is used. Defaults to ``True``. 
**env_kwargs (key-value pairs): additional kwargs for @@ -86,7 +91,7 @@ class D4RLExperienceReplay(TensorDictReplayBuffer): Examples: >>> from torchrl.data.datasets.d4rl import D4RLExperienceReplay >>> from torchrl.envs import ObservationNorm - >>> data = D4RLExperienceReplay("maze2d-umaze-v1") + >>> data = D4RLExperienceReplay("maze2d-umaze-v1", 128) >>> # we can append transforms to the dataset >>> data.append_transform(ObservationNorm(loc=-1, scale=1.0)) >>> data.sample(128) @@ -300,13 +305,13 @@ def __init__( split_trajs: bool = False, from_env: bool = True, use_timeout_as_done: bool = True, - direct_download: bool = False, + direct_download: bool = True, **env_kwargs, ): self.from_env = from_env self.use_timeout_as_done = use_timeout_as_done if not direct_download: - type(self)._import_d4rl() + self._import_d4rl() if not self._has_d4rl: raise ImportError("Could not import d4rl") from self.D4RL_ERR @@ -337,6 +342,18 @@ def __init__( def _get_dataset_direct_download(self, name, env_kwargs): """Directly download and use a D4RL dataset.""" + if env_kwargs: + raise RuntimeError("Cannot pass env_kwargs when `direct_download=True`.") + url = self.D4RL_DATASETS.get(name, None) + if url is None: + raise KeyError(f"Env {name} not found.") + h5path = _download_dataset_from_url(url) + # h5path_parent = Path(h5path).parent + dataset = PersistentTensorDict.from_h5(h5path) + dataset = dataset.to_tensordict() + with dataset.unlock_(): + dataset = self._process_data_from_env(dataset) + return dataset def _get_dataset_direct(self, name, env_kwargs): from torchrl.envs.libs.gym import GymWrapper @@ -428,6 +445,10 @@ def _get_dataset_from_env(self, name, env_kwargs): } ) dataset = dataset.unflatten_keys("/") + dataset = self._process_data_from_env(dataset, env) + return dataset + + def _process_data_from_env(self, dataset, env=None): if "metadata" in dataset.keys(): metadata = dataset.get("metadata") dataset = dataset.exclude("metadata") @@ -458,10 +479,11 @@ def _get_dataset_from_env(self, name, env_kwargs): pass # let's make sure that the dtypes match what's expected - for key, spec in env.observation_spec.items(True, True): - dataset[key] = dataset[key].to(spec.dtype) - dataset["action"] = dataset["action"].to(env.action_spec.dtype) - dataset["reward"] = dataset["reward"].to(env.reward_spec.dtype) + if env is not None: + for key, spec in env.observation_spec.items(True, True): + dataset[key] = dataset[key].to(spec.dtype) + dataset["action"] = dataset["action"].to(env.action_spec.dtype) + dataset["reward"] = dataset["reward"].to(env.reward_spec.dtype) dataset["done"] = dataset["done"].bool() dataset["done"] = dataset["done"].unsqueeze(-1) @@ -478,7 +500,10 @@ def _get_dataset_from_env(self, name, env_kwargs): dataset.clone() ) # make sure that all tensors have a different data_ptr self._shift_reward_done(dataset) - self.specs = env.specs.clone() + if env is not None: + self.specs = env.specs.clone() + else: + self.specs = None return dataset def _shift_reward_done(self, dataset): @@ -488,3 +513,39 @@ def _shift_reward_done(self, dataset): dataset["done"][1:] = dataset["done"][:-1].clone() dataset["reward"][0] = 0 dataset["done"][0] = 0 + + +def _download_dataset_from_url(dataset_url): + dataset_filepath = _filepath_from_url(dataset_url) + if not os.path.exists(dataset_filepath): + print("Downloading dataset:", dataset_url, "to", dataset_filepath) + urllib.request.urlretrieve(dataset_url, dataset_filepath) + if not os.path.exists(dataset_filepath): + raise IOError("Failed to download dataset from %s" % 
dataset_url) + return dataset_filepath + + +def _filepath_from_url(dataset_url): + _, dataset_name = os.path.split(dataset_url) + dataset_filepath = os.path.join(DATASET_PATH, dataset_name) + return dataset_filepath + + +def _set_dataset_path(path): + global DATASET_PATH + DATASET_PATH = path + os.makedirs(path, exist_ok=True) + + +_set_dataset_path( + os.environ.get( + "D4RL_DATASET_DIR", os.path.expanduser("~/.cache/torchrl/data/d4rl/datasets") + ) +) + +if __name__ == "__main__": + data = D4RLExperienceReplay("kitchen-partial-v0", batch_size=128) + print(data) + for sample in data: + print(sample) + break From 7342c83e822b432749ec1ea4da8208ad62c6dd99 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 10 Jul 2023 16:25:49 -0400 Subject: [PATCH 086/104] amend --- examples/decision_transformer/utils.py | 12 +++++------- torchrl/data/datasets/d4rl.py | 3 +-- torchrl/modules/models/models.py | 8 ++++---- torchrl/modules/tensordict_module/actors.py | 5 ++++- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index c27bb7e299c..ab5d8339e76 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -182,13 +182,11 @@ def make_collector(cfg, policy): def get_loc_std(env_name): - import d4rl # noqa - import gym - - env = gym.make(env_name) - data = env.get_dataset() - loc = torch.from_numpy(data["observations"].mean(axis=0)).float() - std = torch.from_numpy(data["observations"].std(axis=0)).float() + data = D4RLExperienceReplay(env_name, 1024) + for sample in data: + loc = sample.get("observation").mean() + std = sample.get("observation").std() + break return loc, std diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index c7a0918badc..e334261535d 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. 
import os import urllib -from pathlib import Path from typing import Callable, Optional import numpy as np @@ -12,7 +11,7 @@ import torch from tensordict import PersistentTensorDict -from tensordict.tensordict import make_tensordict, TensorDict +from tensordict.tensordict import make_tensordict from torchrl.collectors.utils import split_trajectories from torchrl.data.replay_buffers import TensorDictReplayBuffer diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index f3ec1d247d6..45b661ae604 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -1179,7 +1179,7 @@ def __init__( config=transformer_config, ) self.action_layer = nn.Linear( - transformer_config['n_embd'], action_dim * 2, device=device + transformer_config["n_embd"], action_dim * 2, device=device ) self.log_std_min, self.log_std_max = -5.0, 2.0 @@ -1214,7 +1214,7 @@ def forward( @classmethod def get_default_config(cls): - """Default configuration for :class:`~.OnlineDTActor`""" + """Default configuration for :class:`~.OnlineDTActor`.""" return { "n_embd": 256, "n_layer": 4, @@ -1265,7 +1265,7 @@ def __init__( config=transformer_config, ) self.action_layer = nn.Linear( - transformer_config['n_embd'], action_dim, device=device + transformer_config["n_embd"], action_dim, device=device ) def weight_init(m): @@ -1289,7 +1289,7 @@ def forward( @classmethod def get_default_config(cls): - """Default configuration for :class:`~.DTActor`""" + """Default configuration for :class:`~.DTActor`.""" return { "n_embd": 256, "n_layer": 4, diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 6cfefb70972..34c5b2c9e48 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -1705,6 +1705,7 @@ def __init__( @property def in_keys(self): return [self.observation_key, self.action_key, self.return_to_go_key] + @property def out_keys(self): return [self.observation_key, self.action_key, self.return_to_go_key] @@ -1751,7 +1752,9 @@ def mask_context(self, tensordict: TensorDictBase) -> TensorDictBase: action = torch.cat( [ action[..., 1:, :], - torch.zeros(*action.shape[:-2], 1, action.shape[-1], device=action.device), + torch.zeros( + *action.shape[:-2], 1, action.shape[-1], device=action.device + ), ], dim=-2, ) From 18c6b00f35616f2fbf59a8c5c232ea78fe342081 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 10 Jul 2023 17:12:21 -0400 Subject: [PATCH 087/104] amend --- .../linux_examples/scripts/run_all.sh | 23 +++++++++++++++++ .../linux_examples/scripts/run_test.sh | 25 ++----------------- examples/decision_transformer/dt.py | 2 -- examples/decision_transformer/utils.py | 1 + torchrl/envs/common.py | 2 +- 5 files changed, 27 insertions(+), 26 deletions(-) diff --git a/.circleci/unittest/linux_examples/scripts/run_all.sh b/.circleci/unittest/linux_examples/scripts/run_all.sh index 5efa1976509..8d47bef9998 100755 --- a/.circleci/unittest/linux_examples/scripts/run_all.sh +++ b/.circleci/unittest/linux_examples/scripts/run_all.sh @@ -102,6 +102,29 @@ conda env config vars set LD_PRELOAD=${root_dir}/$STDC_LOC # compile mujoco-py (bc it's done at runtime for whatever reason someone thought it was a good idea) python -c """import gym;import d4rl""" +# install ale-py: manylinux names are broken for CentOS so we need to manually download and +# rename them +PY_VERSION=$(python --version) +if [[ $PY_VERSION == *"3.7"* ]]; then + wget 
https://files.pythonhosted.org/packages/ab/fd/6615982d9460df7f476cad265af1378057eee9daaa8e0026de4cedbaffbd/ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + pip install ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + rm ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +elif [[ $PY_VERSION == *"3.8"* ]]; then + wget https://files.pythonhosted.org/packages/0f/8a/feed20571a697588bc4bfef05d6a487429c84f31406a52f8af295a0346a2/ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + pip install ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + rm ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +elif [[ $PY_VERSION == *"3.9"* ]]; then + wget https://files.pythonhosted.org/packages/a0/98/4316c1cedd9934f9a91b6e27a9be126043b4445594b40cfa391c8de2e5e8/ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + pip install ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + rm ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +elif [[ $PY_VERSION == *"3.10"* ]]; then + wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl + mv ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + pip install ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + rm ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +fi +pip install "gymnasium[atari,accept-rom-license]" + # ============================================================================================ # # ================================ PyTorch & TorchRL ========================================= # diff --git a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh index 11a14bb9a61..6d60dadb7c1 100755 --- a/.circleci/unittest/linux_examples/scripts/run_test.sh +++ b/.circleci/unittest/linux_examples/scripts/run_test.sh @@ -36,39 +36,18 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_tra optim.pretrain_gradient_steps=55 \ optim.updates_per_episode=3 \ optim.warmup_steps=10 \ + optim.device=cuda:0 \ logger.backend= python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/online_td.py \ optim.pretrain_gradient_steps=55 \ optim.updates_per_episode=3 \ optim.warmup_steps=10 \ + optim.device=cuda:0 \ logger.backend= # ==================================================================================== # # ================================ Gymnasium ========================================= # -# install ale-py: manylinux names are broken for CentOS so we need to manually download and -# rename them -PY_VERSION=$(python --version) -if [[ $PY_VERSION == *"3.7"* ]]; then - wget https://files.pythonhosted.org/packages/ab/fd/6615982d9460df7f476cad265af1378057eee9daaa8e0026de4cedbaffbd/ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pip install ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - rm ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -elif [[ $PY_VERSION == *"3.8"* ]]; then - wget 
https://files.pythonhosted.org/packages/0f/8a/feed20571a697588bc4bfef05d6a487429c84f31406a52f8af295a0346a2/ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pip install ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - rm ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -elif [[ $PY_VERSION == *"3.9"* ]]; then - wget https://files.pythonhosted.org/packages/a0/98/4316c1cedd9934f9a91b6e27a9be126043b4445594b40cfa391c8de2e5e8/ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pip install ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - rm ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -elif [[ $PY_VERSION == *"3.10"* ]]; then - wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - mv ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pip install ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - rm ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -fi -pip install "gymnasium[atari,accept-rom-license]" - python .circleci/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \ env.num_envs=1 \ env.device=cuda:0 \ diff --git a/examples/decision_transformer/dt.py b/examples/decision_transformer/dt.py index 3f9ca9c02e3..046b29b4b76 100644 --- a/examples/decision_transformer/dt.py +++ b/examples/decision_transformer/dt.py @@ -11,7 +11,6 @@ import torch import tqdm -from torchrl.envs.libs.gym import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper @@ -25,7 +24,6 @@ ) -@set_gym_backend("gym") # D4RL uses gym so we make sure gymnasium is hidden @hydra.main(config_path=".", config_name="dt_config") def main(cfg: "DictConfig"): # noqa: F821 model_device = cfg.optim.device diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index ab5d8339e76..24d537ac67b 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -45,6 +45,7 @@ # ----------------- +@set_gym_backend("gym") # D4RL uses gym so we make sure gymnasium is hidden def make_base_env(env_cfg): env_library = LIBS[env_cfg.library] env_name = env_cfg.name diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 15168f20411..ebac648e0fc 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -1249,7 +1249,7 @@ def policy(td): tensordict = step_mdp( tensordict, keep_other=True, - exclude_action=True, + exclude_action=False, exclude_reward=True, reward_key=self.reward_key, action_key=self.action_key, From c4c02e675fa9943771ef0b28864b9c28c19b2fa6 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 10 Jul 2023 17:13:28 -0400 Subject: [PATCH 088/104] revert d4rl --- torchrl/data/datasets/d4rl.py | 281 ++-------------------------------- 1 file changed, 16 insertions(+), 265 deletions(-) diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index e334261535d..d6c32083e23 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -2,15 +2,12 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import os -import urllib + from typing import Callable, Optional import numpy as np import torch - -from tensordict import PersistentTensorDict from tensordict.tensordict import make_tensordict from torchrl.collectors.utils import split_trajectories @@ -78,8 +75,7 @@ class D4RLExperienceReplay(TensorDictReplayBuffer): differ. In particular, the ``"timeout"`` key (used to determine the end of an episode) may be absent when ``from_env=False`` but present otherwise, leading to a different slicing when ``traj_splits`` is enabled. - direct_download (bool): if ``True`` (default), the data will be downloaded without - requiring D4RL. This is not compatible with ``from_env=True``. + use_timeout_as_done (bool, optional): if ``True``, ``done = terminal | timeout``. Otherwise, only the ``terminal`` key is used. Defaults to ``True``. **env_kwargs (key-value pairs): additional kwargs for @@ -90,195 +86,13 @@ class D4RLExperienceReplay(TensorDictReplayBuffer): Examples: >>> from torchrl.data.datasets.d4rl import D4RLExperienceReplay >>> from torchrl.envs import ObservationNorm - >>> data = D4RLExperienceReplay("maze2d-umaze-v1", 128) + >>> data = D4RLExperienceReplay("maze2d-umaze-v1") >>> # we can append transforms to the dataset >>> data.append_transform(ObservationNorm(loc=-1, scale=1.0)) >>> data.sample(128) """ - D4RL_DATASETS = { - "maze2d-open-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-sparse.hdf5", - "maze2d-umaze-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse-v1.hdf5", - "maze2d-medium-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse-v1.hdf5", - "maze2d-large-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-sparse-v1.hdf5", - "maze2d-eval-umaze-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-sparse-v1.hdf5", - "maze2d-eval-medium-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-sparse-v1.hdf5", - "maze2d-eval-large-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-sparse-v1.hdf5", - "maze2d-open-dense-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-dense.hdf5", - "maze2d-umaze-dense-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-dense-v1.hdf5", - "maze2d-medium-dense-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-dense-v1.hdf5", - "maze2d-large-dense-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-dense-v1.hdf5", - "maze2d-eval-umaze-dense-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-dense-v1.hdf5", - "maze2d-eval-medium-dense-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-dense-v1.hdf5", - "maze2d-eval-large-dense-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-dense-v1.hdf5", - "minigrid-fourrooms-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms.hdf5", - "minigrid-fourrooms-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms_random.hdf5", - "pen-human-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_demos_clipped.hdf5", - "pen-cloned-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-demos-v0-bc-combined.hdf5", - "pen-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_expert_clipped.hdf5", - "hammer-human-v0": 
"http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_demos_clipped.hdf5", - "hammer-cloned-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-demos-v0-bc-combined.hdf5", - "hammer-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_expert_clipped.hdf5", - "relocate-human-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_demos_clipped.hdf5", - "relocate-cloned-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-demos-v0-bc-combined.hdf5", - "relocate-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_expert_clipped.hdf5", - "door-human-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_demos_clipped.hdf5", - "door-cloned-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-demos-v0-bc-combined.hdf5", - "door-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_expert_clipped.hdf5", - "halfcheetah-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_random.hdf5", - "halfcheetah-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium.hdf5", - "halfcheetah-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_expert.hdf5", - "halfcheetah-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_mixed.hdf5", - "halfcheetah-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium_expert.hdf5", - "walker2d-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_random.hdf5", - "walker2d-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium.hdf5", - "walker2d-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_expert.hdf5", - "walker2d-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker_mixed.hdf5", - "walker2d-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium_expert.hdf5", - "hopper-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_random.hdf5", - "hopper-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium.hdf5", - "hopper-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_expert.hdf5", - "hopper-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_mixed.hdf5", - "hopper-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium_expert.hdf5", - "ant-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random.hdf5", - "ant-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium.hdf5", - "ant-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_expert.hdf5", - "ant-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_mixed.hdf5", - "ant-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium_expert.hdf5", - "ant-random-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random_expert.hdf5", - "antmaze-umaze-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse.hdf5", - "antmaze-umaze-diverse-v0": 
"http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse.hdf5", - "antmaze-medium-play-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse.hdf5", - "antmaze-medium-diverse-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse.hdf5", - "antmaze-large-play-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse.hdf5", - "antmaze-large-diverse-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse.hdf5", - "antmaze-umaze-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse_fixed.hdf5", - "antmaze-umaze-diverse-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5", - "antmaze-medium-play-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse_fixed.hdf5", - "antmaze-medium-diverse-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5", - "antmaze-large-play-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse_fixed.hdf5", - "antmaze-large-diverse-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5", - "flow-ring-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-random.hdf5", - "flow-ring-controller-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-idm.hdf5", - "flow-merge-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-random.hdf5", - "flow-merge-controller-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-idm.hdf5", - "kitchen-complete-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/mini_kitchen_microwave_kettle_light_slider-v0.hdf5", - "kitchen-partial-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_light_slider-v0.hdf5", - "kitchen-mixed-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_bottomburner_light-v0.hdf5", - "carla-lane-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow_flat-v0.hdf5", - "carla-town-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_subsamp_flat-v0.hdf5", - "carla-town-full-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5", - "bullet-halfcheetah-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_random.hdf5", - "bullet-halfcheetah-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium.hdf5", - "bullet-halfcheetah-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_expert.hdf5", - "bullet-halfcheetah-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_expert.hdf5", - "bullet-halfcheetah-medium-replay-v0": 
"http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_replay.hdf5", - "bullet-hopper-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_random.hdf5", - "bullet-hopper-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium.hdf5", - "bullet-hopper-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_expert.hdf5", - "bullet-hopper-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_expert.hdf5", - "bullet-hopper-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_replay.hdf5", - "bullet-ant-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_random.hdf5", - "bullet-ant-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium.hdf5", - "bullet-ant-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_expert.hdf5", - "bullet-ant-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_expert.hdf5", - "bullet-ant-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_replay.hdf5", - "bullet-walker2d-random-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_random.hdf5", - "bullet-walker2d-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium.hdf5", - "bullet-walker2d-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_expert.hdf5", - "bullet-walker2d-medium-expert-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_expert.hdf5", - "bullet-walker2d-medium-replay-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_replay.hdf5", - "bullet-maze2d-open-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-open-sparse.hdf5", - "bullet-maze2d-umaze-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-umaze-sparse.hdf5", - "bullet-maze2d-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-medium-sparse.hdf5", - "bullet-maze2d-large-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-large-sparse.hdf5", - "halfcheetah-random-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/halfcheetah_random-v1.hdf5", - "halfcheetah-random-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/halfcheetah_random-v2.hdf5", - "halfcheetah-medium-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/halfcheetah_medium-v1.hdf5", - "halfcheetah-medium-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/halfcheetah_medium-v2.hdf5", - "halfcheetah-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/halfcheetah_expert-v1.hdf5", - "halfcheetah-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/halfcheetah_expert-v2.hdf5", - "halfcheetah-medium-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/halfcheetah_medium_replay-v1.hdf5", - "halfcheetah-medium-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/halfcheetah_medium_replay-v2.hdf5", - "halfcheetah-full-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/halfcheetah_full_replay-v1.hdf5", - "halfcheetah-full-replay-v2": 
"http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/halfcheetah_full_replay-v2.hdf5", - "halfcheetah-medium-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/halfcheetah_medium_expert-v1.hdf5", - "halfcheetah-medium-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/halfcheetah_medium_expert-v2.hdf5", - "hopper-random-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/hopper_random-v1.hdf5", - "hopper-random-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/hopper_random-v2.hdf5", - "hopper-medium-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/hopper_medium-v1.hdf5", - "hopper-medium-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/hopper_medium-v2.hdf5", - "hopper-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/hopper_expert-v1.hdf5", - "hopper-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/hopper_expert-v2.hdf5", - "hopper-medium-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/hopper_medium_replay-v1.hdf5", - "hopper-medium-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/hopper_medium_replay-v2.hdf5", - "hopper-full-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/hopper_full_replay-v1.hdf5", - "hopper-full-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/hopper_full_replay-v2.hdf5", - "hopper-medium-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/hopper_medium_expert-v1.hdf5", - "hopper-medium-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/hopper_medium_expert-v2.hdf5", - "walker2d-random-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/walker2d_random-v1.hdf5", - "walker2d-random-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/walker2d_random-v2.hdf5", - "walker2d-medium-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/walker2d_medium-v1.hdf5", - "walker2d-medium-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/walker2d_medium-v2.hdf5", - "walker2d-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/walker2d_expert-v1.hdf5", - "walker2d-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/walker2d_expert-v2.hdf5", - "walker2d-medium-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/walker2d_medium_replay-v1.hdf5", - "walker2d-medium-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/walker2d_medium_replay-v2.hdf5", - "walker2d-full-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/walker2d_full_replay-v1.hdf5", - "walker2d-full-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/walker2d_full_replay-v2.hdf5", - "walker2d-medium-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/walker2d_medium_expert-v1.hdf5", - "walker2d-medium-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/walker2d_medium_expert-v2.hdf5", - "ant-random-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/ant_random-v1.hdf5", - "ant-random-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/ant_random-v2.hdf5", - "ant-medium-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/ant_medium-v1.hdf5", - 
"ant-medium-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/ant_medium-v2.hdf5", - "ant-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/ant_expert-v1.hdf5", - "ant-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/ant_expert-v2.hdf5", - "ant-medium-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/ant_medium_replay-v1.hdf5", - "ant-medium-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/ant_medium_replay-v2.hdf5", - "ant-full-replay-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/ant_full_replay-v1.hdf5", - "ant-full-replay-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/ant_full_replay-v2.hdf5", - "ant-medium-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/ant_medium_expert-v1.hdf5", - "ant-medium-expert-v2": "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/ant_medium_expert-v2.hdf5", - "hammer-human-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/hammer-human-v1.hdf5", - "hammer-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/hammer-expert-v1.hdf5", - "hammer-cloned-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/hammer-cloned-v1.hdf5", - "pen-human-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/pen-human-v1.hdf5", - "pen-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/pen-expert-v1.hdf5", - "pen-cloned-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/pen-cloned-v1.hdf5", - "relocate-human-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/relocate-human-v1.hdf5", - "relocate-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/relocate-expert-v1.hdf5", - "relocate-cloned-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/relocate-cloned-v1.hdf5", - "door-human-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/door-human-v1.hdf5", - "door-expert-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/door-expert-v1.hdf5", - "door-cloned-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/door-cloned-v1.hdf5", - "antmaze-umaze-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_umaze_noisy_multistart_False_multigoal_False_sparse.hdf5", - "antmaze-umaze-diverse-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_umaze_noisy_multistart_True_multigoal_True_sparse.hdf5", - "antmaze-medium-play-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_medium_noisy_multistart_True_multigoal_False_sparse.hdf5", - "antmaze-medium-diverse-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_medium_noisy_multistart_True_multigoal_True_sparse.hdf5", - "antmaze-large-diverse-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_large_noisy_multistart_True_multigoal_True_sparse.hdf5", - "antmaze-large-play-v1": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_large_noisy_multistart_True_multigoal_False_sparse.hdf5", - "antmaze-eval-umaze-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_umaze_eval_noisy_multistart_True_multigoal_False_sparse.hdf5", - "antmaze-eval-umaze-diverse-v0": 
"http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_umaze_eval_noisy_multistart_True_multigoal_True_sparse.hdf5", - "antmaze-eval-medium-play-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_medium_eval_noisy_multistart_True_multigoal_True_sparse.hdf5", - "antmaze-eval-medium-diverse-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_medium_eval_noisy_multistart_True_multigoal_False_sparse.hdf5", - "antmaze-eval-large-diverse-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_large_eval_noisy_multistart_True_multigoal_False_sparse.hdf5", - "antmaze-eval-large-play-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_large_eval_noisy_multistart_True_multigoal_True_sparse.hdf5", - "door-human-longhorizon-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_demos_clipped.hdf5", - "hammer-human-longhorizon-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_demos_clipped.hdf5", - "pen-human-longhorizon-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_demos_clipped.hdf5", - "relocate-human-longhorizon-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_demos_clipped.hdf5", - "maze2d-umaze-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse.hdf5", - "maze2d-medium-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse.hdf5", - "maze2d-large-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-sparse.hdf5", - "maze2d-umaze-dense-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-dense.hdf5", - "maze2d-medium-dense-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-dense.hdf5", - "maze2d-large-dense-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-dense.hdf5", - "carla-lane-render-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow-v0.hdf5", - "carla-town-render-v0": "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5", - } - D4RL_ERR = None @classmethod @@ -304,23 +118,19 @@ def __init__( split_trajs: bool = False, from_env: bool = True, use_timeout_as_done: bool = True, - direct_download: bool = True, **env_kwargs, ): - self.from_env = from_env - self.use_timeout_as_done = use_timeout_as_done - if not direct_download: - self._import_d4rl() - if not self._has_d4rl: - raise ImportError("Could not import d4rl") from self.D4RL_ERR + type(self)._import_d4rl() - if from_env: - dataset = self._get_dataset_from_env(name, env_kwargs) - else: - dataset = self._get_dataset_direct(name, env_kwargs) + if not self._has_d4rl: + raise ImportError("Could not import d4rl") from self.D4RL_ERR + self.from_env = from_env + self.use_timeout_as_done = use_timeout_as_done + if from_env: + dataset = self._get_dataset_from_env(name, env_kwargs) else: - dataset = self._get_dataset_direct_download(name, env_kwargs) + dataset = self._get_dataset_direct(name, env_kwargs) # Fill unknown next states with 0 dataset["next", "observation"][dataset["next", "done"].squeeze()] = 0 @@ -339,21 +149,6 @@ def __init__( ) self.extend(dataset) - def _get_dataset_direct_download(self, name, env_kwargs): - """Directly download and use a D4RL dataset.""" - if env_kwargs: - raise RuntimeError("Cannot pass env_kwargs when `direct_download=True`.") - url = self.D4RL_DATASETS.get(name, None) - if url is None: - 
raise KeyError(f"Env {name} not found.") - h5path = _download_dataset_from_url(url) - # h5path_parent = Path(h5path).parent - dataset = PersistentTensorDict.from_h5(h5path) - dataset = dataset.to_tensordict() - with dataset.unlock_(): - dataset = self._process_data_from_env(dataset) - return dataset - def _get_dataset_direct(self, name, env_kwargs): from torchrl.envs.libs.gym import GymWrapper @@ -444,10 +239,6 @@ def _get_dataset_from_env(self, name, env_kwargs): } ) dataset = dataset.unflatten_keys("/") - dataset = self._process_data_from_env(dataset, env) - return dataset - - def _process_data_from_env(self, dataset, env=None): if "metadata" in dataset.keys(): metadata = dataset.get("metadata") dataset = dataset.exclude("metadata") @@ -478,11 +269,10 @@ def _process_data_from_env(self, dataset, env=None): pass # let's make sure that the dtypes match what's expected - if env is not None: - for key, spec in env.observation_spec.items(True, True): - dataset[key] = dataset[key].to(spec.dtype) - dataset["action"] = dataset["action"].to(env.action_spec.dtype) - dataset["reward"] = dataset["reward"].to(env.reward_spec.dtype) + for key, spec in env.observation_spec.items(True, True): + dataset[key] = dataset[key].to(spec.dtype) + dataset["action"] = dataset["action"].to(env.action_spec.dtype) + dataset["reward"] = dataset["reward"].to(env.reward_spec.dtype) dataset["done"] = dataset["done"].bool() dataset["done"] = dataset["done"].unsqueeze(-1) @@ -499,10 +289,7 @@ def _process_data_from_env(self, dataset, env=None): dataset.clone() ) # make sure that all tensors have a different data_ptr self._shift_reward_done(dataset) - if env is not None: - self.specs = env.specs.clone() - else: - self.specs = None + self.specs = env.specs.clone() return dataset def _shift_reward_done(self, dataset): @@ -512,39 +299,3 @@ def _shift_reward_done(self, dataset): dataset["done"][1:] = dataset["done"][:-1].clone() dataset["reward"][0] = 0 dataset["done"][0] = 0 - - -def _download_dataset_from_url(dataset_url): - dataset_filepath = _filepath_from_url(dataset_url) - if not os.path.exists(dataset_filepath): - print("Downloading dataset:", dataset_url, "to", dataset_filepath) - urllib.request.urlretrieve(dataset_url, dataset_filepath) - if not os.path.exists(dataset_filepath): - raise IOError("Failed to download dataset from %s" % dataset_url) - return dataset_filepath - - -def _filepath_from_url(dataset_url): - _, dataset_name = os.path.split(dataset_url) - dataset_filepath = os.path.join(DATASET_PATH, dataset_name) - return dataset_filepath - - -def _set_dataset_path(path): - global DATASET_PATH - DATASET_PATH = path - os.makedirs(path, exist_ok=True) - - -_set_dataset_path( - os.environ.get( - "D4RL_DATASET_DIR", os.path.expanduser("~/.cache/torchrl/data/d4rl/datasets") - ) -) - -if __name__ == "__main__": - data = D4RLExperienceReplay("kitchen-partial-v0", batch_size=128) - print(data) - for sample in data: - print(sample) - break From d67a8220b2613e1ce641e234ac55dfc3061f0640 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 10 Jul 2023 20:05:16 -0400 Subject: [PATCH 089/104] amend --- .../unittest/linux/scripts/environment.yml | 1 + .../linux_examples/scripts/run_test.sh | 2 +- docs/source/reference/modules.rst | 11 +- test/test_modules.py | 82 +++++++++++- test/test_tensordictmodules.py | 69 ++++++++++ torchrl/modules/__init__.py | 1 + torchrl/modules/models/__init__.py | 1 + .../modules/models/decision_transformer.py | 122 +++++++++++------- torchrl/modules/models/models.py | 74 ++++++----- 
torchrl/modules/tensordict_module/actors.py | 68 +++++++--- 10 files changed, 334 insertions(+), 97 deletions(-) diff --git a/.circleci/unittest/linux/scripts/environment.yml b/.circleci/unittest/linux/scripts/environment.yml index f27bae7da6c..ed33f11b27f 100644 --- a/.circleci/unittest/linux/scripts/environment.yml +++ b/.circleci/unittest/linux/scripts/environment.yml @@ -28,3 +28,4 @@ dependencies: - av - coverage - ray + - transformers diff --git a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh index 6d60dadb7c1..bcf4688412f 100755 --- a/.circleci/unittest/linux_examples/scripts/run_test.sh +++ b/.circleci/unittest/linux_examples/scripts/run_test.sh @@ -38,7 +38,7 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_tra optim.warmup_steps=10 \ optim.device=cuda:0 \ logger.backend= -python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/online_td.py \ +python .circleci/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/online_dt.py \ optim.pretrain_gradient_steps=55 \ optim.updates_per_episode=3 \ optim.warmup_steps=10 \ diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 7a782ab52eb..56457fc7e5a 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -322,18 +322,21 @@ algorithms, such as DQN, DDPG or Dreamer. :toctree: generated/ :template: rl_template_noinherit.rst - DuelingCnnDQNet - DistributionalDQNnet + DTActor DdpgCnnActor DdpgCnnQNet DdpgMlpActor DdpgMlpQNet + DecisionTransformer + DistributionalDQNnet DreamerActor + DuelingCnnDQNet LSTMModule - ObsEncoder ObsDecoder - RSSMPrior + ObsEncoder + OnlineDTActor RSSMPosterior + RSSMPrior Exploration diff --git a/test/test_modules.py b/test/test_modules.py index 2481ec09f69..b4c9770e8d9 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -14,9 +14,21 @@ from tensordict import TensorDict from torch import nn from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec -from torchrl.modules import CEMPlanner, LSTMNet, SafeModule, TanhModule, ValueOperator +from torchrl.modules import ( + CEMPlanner, + DTActor, + LSTMNet, + OnlineDTActor, + SafeModule, + TanhModule, + ValueOperator, +) from torchrl.modules.distributions.utils import safeatanh, safetanh from torchrl.modules.models import ConvNet, MLP, NoisyLazyLinear, NoisyLinear +from torchrl.modules.models.decision_transformer import ( + _has_transformers, + DecisionTransformer, +) from torchrl.modules.models.model_based import ( DreamerActor, ObsDecoder, @@ -737,6 +749,74 @@ def test_tanh_atanh(use_vmap, scale): torch.testing.assert_close(x.grad, torch.ones_like(x)) +@pytest.mark.skipif( + not _has_transformers, reason="transformers needed for TestDecisionTransformer" +) +class TestDecisionTransformer: + def test_init(self): + DecisionTransformer( + 3, + 4, + ) + with pytest.raises(TypeError): + DecisionTransformer(3, 4, config="some_str") + DecisionTransformer( + 3, + 4, + config=DecisionTransformer.DTConfig( + n_layer=2, n_embd=16, n_positions=16, n_inner=16, n_head=2 + ), + ) + + @pytest.mark.parametrize("batch_dims", [[], [3], [3, 4]]) + def test_exec(self, batch_dims, T=5): + observations = torch.randn(*batch_dims, T, 3) + actions = torch.randn(*batch_dims, T, 4) + r2go = torch.randn(*batch_dims, T, 1) + model = DecisionTransformer( + 3, + 4, + config=DecisionTransformer.DTConfig( + n_layer=2, n_embd=16, n_positions=16, n_inner=16, n_head=2 + ), 
+ ) + out = model(observations, actions, r2go) + assert out.shape == torch.Size([*batch_dims, T, 16]) + + @pytest.mark.parametrize("batch_dims", [[], [3], [3, 4]]) + def test_dtactor(self, batch_dims, T=5): + dtactor = DTActor( + 3, + 4, + transformer_config=DecisionTransformer.DTConfig( + n_layer=2, n_embd=16, n_positions=16, n_inner=16, n_head=2 + ), + ) + observations = torch.randn(*batch_dims, T, 3) + actions = torch.randn(*batch_dims, T, 4) + r2go = torch.randn(*batch_dims, T, 1) + out = dtactor(observations, actions, r2go) + assert out.shape == torch.Size([*batch_dims, T, 4]) + + @pytest.mark.parametrize("batch_dims", [[], [3], [3, 4]]) + def test_onlinedtactor(self, batch_dims, T=5): + dtactor = OnlineDTActor( + 3, + 4, + transformer_config=DecisionTransformer.DTConfig( + n_layer=2, n_embd=16, n_positions=16, n_inner=16, n_head=2 + ), + ) + observations = torch.randn(*batch_dims, T, 3) + actions = torch.randn(*batch_dims, T, 4) + r2go = torch.randn(*batch_dims, T, 1) + mu, sig = dtactor(observations, actions, r2go) + assert mu.shape == torch.Size([*batch_dims, T, 4]) + assert sig.shape == torch.Size([*batch_dims, T, 4]) + assert (dtactor.log_std_min < sig.log()).all() + assert (dtactor.log_std_max > sig.log()).all() + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index ddcdf0f7535..38d7ac56cc7 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -18,9 +18,14 @@ from torchrl.envs.utils import set_exploration_type, step_mdp from torchrl.modules import ( AdditiveGaussianWrapper, + DecisionTransformerInferenceWrapper, + DTActor, LSTMModule, NormalParamWrapper, + OnlineDTActor, + ProbabilisticActor, SafeModule, + TanhDelta, TanhNormal, ValueOperator, ) @@ -1786,6 +1791,70 @@ def test_vmapmodule(): assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all() +class TestDecisionTransformerInferenceWrapper: + @pytest.mark.parametrize("online", [True, False]) + def test_dt_inference_wrapper(self, online): + action_key = ("nested", ("action",)) + if online: + dtactor = OnlineDTActor( + state_dim=4, action_dim=2, transformer_config=DTActor.default_config() + ) + in_keys = ["loc", "scale"] + actor_module = TensorDictModule( + dtactor, + in_keys=["observation", action_key, "return_to_go"], + out_keys=in_keys, + ) + dist_class = TanhNormal + else: + dtactor = DTActor( + state_dim=4, action_dim=2, transformer_config=DTActor.default_config() + ) + in_keys = ["param"] + actor_module = TensorDictModule( + dtactor, + in_keys=["observation", action_key, "return_to_go"], + out_keys=in_keys, + ) + dist_class = TanhDelta + dist_kwargs = { + "min": -1.0, + "max": 1.0, + } + actor = ProbabilisticActor( + in_keys=in_keys, + out_keys=[action_key], + module=actor_module, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs, + ) + inference_actor = DecisionTransformerInferenceWrapper(actor) + sequence_length = 20 + td = TensorDict( + { + "observation": torch.randn(1, sequence_length, 4), + action_key: torch.randn(1, sequence_length, 2), + "return_to_go": torch.randn(1, sequence_length, 1), + }, + [1], + ) + with pytest.raises( + ValueError, + match="The action key action was not found in the policy out_keys", + ): + result = inference_actor(td) + inference_actor.set_tensor_keys(action=action_key) + result = inference_actor(td) + # checks that the seq length has disappeared + assert 
result.get(action_key).shape == torch.Size([1, 2]) + assert inference_actor.out_keys == unravel_key_list( + sorted([action_key, *in_keys, "observation", "return_to_go"], key=str) + ) + assert set(result.keys(True, True)) - set(td.keys(True, True)) == set( + inference_actor.out_keys + ) - set(inference_actor.in_keys) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 0f16362cb1a..327ef1674ad 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -20,6 +20,7 @@ DdpgCnnQNet, DdpgMlpActor, DdpgMlpQNet, + DecisionTransformer, DistributionalDQNnet, DreamerActor, DTActor, diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index f2972cab4cd..ed3acf94edb 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. +from .decision_transformer import DecisionTransformer from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise from .model_based import DreamerActor, ObsDecoder, ObsEncoder, RSSMPosterior, RSSMPrior from .models import ( diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py index ab5095b3956..628e7f79b7b 100644 --- a/torchrl/modules/models/decision_transformer.py +++ b/torchrl/modules/models/decision_transformer.py @@ -2,8 +2,13 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import dataclasses import importlib +from dataclasses import dataclass +from typing import Any import torch import torch.nn as nn @@ -31,46 +36,60 @@ class DecisionTransformer(nn.Module): Args: state_dim (int): dimension of the state space action_dim (int): dimension of the action space - config (dict): transformer architecture configuration, used to create the GPT2Config from transformers. + config (:obj:`~.DTConfig` or dict, optional): transformer architecture configuration, + used to create the GPT2Config from transformers. + Defaults to :obj:`~.default_config`. Example: - >>> config = { - ... "n_embd": 256, - ... "n_layer": 4, - ... "n_head": 4, - ... "n_inner": 1024, - ... "activation": "relu", - ... "n_positions": 1024,clear - ... "resid_pdrop": 0.1, - ... "attn_pdrop": 0.1, - ... 
} + >>> config = DecisionTransformer.default_config() + >>> config.n_embd = 128 + >>> print(config) + DTConfig(n_embd: 128, n_layer: 4, n_head: 4, n_inner: 1024, activation: relu, n_positions: 1024, resid_pdrop: 0.1, attn_pdrop: 0.1) + >>> # alternatively + >>> config = DecisionTransformer.DTConfig(n_embd=128) >>> model = DecisionTransformer(state_dim=4, action_dim=2, config=config) - >>> observation = torch.randn(32, 10, 4) - >>> action = torch.randn(32, 10, 2) - >>> return_to_go = torch.randn(32, 10, 1) + >>> batch_size = [3, 32] + >>> length = 10 + >>> observation = torch.randn(*batch_size, length, 4) + >>> action = torch.randn(*batch_size, length, 2) + >>> return_to_go = torch.randn(*batch_size, length, 1) >>> output = model(observation, action, return_to_go) >>> output.shape - torch.Size([32, 10, 256]) + torch.Size([3, 32, 10, 128]) """ - default_config = { - "n_embd": 256, - "n_layer": 4, - "n_head": 4, - "n_inner": 1024, - "activation": "relu", - "n_positions": 1024, - "resid_pdrop": 0.1, - "attn_pdrop": 0.1, - } + @dataclass + class DTConfig: + """Default configuration for DecisionTransformer.""" + + n_embd: Any = 256 + n_layer: Any = 4 + n_head: Any = 4 + n_inner: Any = 1024 + activation: Any = "relu" + n_positions: Any = 1024 + resid_pdrop: Any = 0.1 + attn_pdrop: Any = 0.1 + + def __repr__(self): + fields = [] + for f in dataclasses.fields(self): + value = getattr(self, f.name) + fields.append(f"{f.name}: {value}") + fields = ", ".join(fields) + return f"{self.__class__.__name__}({fields})" + + @classmethod + def default_config(cls): + return cls.DTConfig() def __init__( self, state_dim, action_dim, - config: dict = default_config, + config: dict | DTConfig = None, ): if not _has_transformers: raise ImportError( @@ -79,24 +98,29 @@ def __init__( import transformers from transformers.models.gpt2.modeling_gpt2 import GPT2Model - super(DecisionTransformer, self).__init__() + if config is None: + config = self.default_config() + if isinstance(config, self.DTConfig): + config = dataclasses.asdict(config) + if not isinstance(config, dict): + raise TypeError(f"Config of type {type(config)} is not supported.") - self.default_config.update(config) + super(DecisionTransformer, self).__init__() gpt_config = transformers.GPT2Config( - n_embd=self.default_config["n_embd"], - n_layer=self.default_config["n_layer"], - n_head=self.default_config["n_head"], - n_inner=self.default_config["n_inner"], - activation_function=self.default_config["activation"], - n_positions=self.default_config["n_positions"], - resid_pdrop=self.default_config["resid_pdrop"], - attn_pdrop=self.default_config["attn_pdrop"], + n_embd=config["n_embd"], + n_layer=config["n_layer"], + n_head=config["n_head"], + n_inner=config["n_inner"], + activation_function=config["activation"], + n_positions=config["n_positions"], + resid_pdrop=config["resid_pdrop"], + attn_pdrop=config["attn_pdrop"], vocab_size=1, ) self.state_dim = state_dim self.action_dim = action_dim - self.hidden_size = self.default_config["n_embd"] + self.hidden_size = config["n_embd"] self.transformer = GPT2Model(config=gpt_config) @@ -112,7 +136,14 @@ def forward( action: torch.Tensor, return_to_go: torch.Tensor, ): - batch_size, seq_length = observation.shape[0], observation.shape[1] + batch_size, seq_length = observation.shape[:-2], observation.shape[-2] + batch_size_orig = batch_size + if len(batch_size) != 1: + # TODO: vmap over transformer once this is possible + observation = observation.view(-1, *observation.shape[-2:]) + action = action.view(-1, 
*action.shape[-2:]) + return_to_go = return_to_go.view(-1, *return_to_go.shape[-2:]) + batch_size = torch.Size([batch_size.numel()]) # embed each modality with a different head state_embeddings = self.embed_state(observation) @@ -123,10 +154,10 @@ def forward( # which works nice in an autoregressive sense since states predict actions stacked_inputs = ( torch.stack( - (returns_embeddings, state_embeddings, action_embeddings), dim=1 + (returns_embeddings, state_embeddings, action_embeddings), dim=-2 ) - .permute(0, 2, 1, 3) - .reshape(batch_size, 3 * seq_length, self.hidden_size) + .permute(*range(len(batch_size)), -2, -3, -1) + .reshape(*batch_size, 3 * seq_length, self.hidden_size) ) stacked_inputs = self.embed_ln(stacked_inputs) @@ -138,6 +169,9 @@ def forward( # reshape x so that the second dimension corresponds to the original # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t - x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3) - - return x[:, 1] # only state tokens + x = x.reshape(*batch_size, seq_length, 3, self.hidden_size).permute( + *range(len(batch_size)), -2, -3, -1 + ) + if batch_size_orig is batch_size: + return x[..., 1, :, :] # only state tokens + return x[..., 1, :, :].view(*batch_size_orig, *x.shape[-2:]) diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 45b661ae604..d8a8bfb606e 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -2,6 +2,10 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import dataclasses + import warnings from numbers import Number from typing import Dict, List, Optional, Sequence, Tuple, Type, Union @@ -1149,12 +1153,14 @@ class OnlineDTActor(nn.Module): Args: state_dim (int): state dimension. action_dim (int): action dimension. - transformer_config (Dict): config for the GPT2 transformer. + transformer_config (Dict or :class:`DecisionTransformer.DTConfig`): + config for the GPT2 transformer. + Defaults to :meth:`~.default_config`. device (Optional[DEVICE_TYPING], optional): device to use. Defaults to None. Examples: >>> model = OnlineDTActor(state_dim=4, action_dim=2, - ... transformer_config=OnlineDTActor.get_default_config()) + ... 
transformer_config=OnlineDTActor.default_config()) >>> observation = torch.randn(32, 10, 4) >>> action = torch.randn(32, 10, 2) >>> return_to_go = torch.randn(32, 10, 1) @@ -1169,10 +1175,14 @@ def __init__( self, state_dim: int, action_dim: int, - transformer_config: Dict, + transformer_config: Dict | DecisionTransformer.DTConfig = None, device: Optional[DEVICE_TYPING] = None, ): super().__init__() + if transformer_config is None: + transformer_config = self.default_config() + if isinstance(transformer_config, DecisionTransformer.DTConfig): + transformer_config = dataclasses.asdict(transformer_config) self.transformer = DecisionTransformer( state_dim=state_dim, action_dim=action_dim, @@ -1198,7 +1208,7 @@ def forward( observation: torch.Tensor, action: torch.Tensor, return_to_go: torch.Tensor, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: hidden_state = self.transformer(observation, action, return_to_go) out = self.action_layer(hidden_state) mu, log_std = torch.chunk(out, 2, -1) @@ -1213,18 +1223,18 @@ def forward( return mu, std @classmethod - def get_default_config(cls): + def default_config(cls): """Default configuration for :class:`~.OnlineDTActor`.""" - return { - "n_embd": 256, - "n_layer": 4, - "n_head": 4, - "n_inner": 1024, - "activation": "relu", - "n_positions": 1024, - "resid_pdrop": 0.1, - "attn_pdrop": 0.1, - } + return DecisionTransformer.DTConfig( + n_embd=256, + n_layer=4, + n_head=4, + n_inner=1024, + activation="relu", + n_positions=1024, + resid_pdrop=0.1, + attn_pdrop=0.1, + ) class DTActor(nn.Module): @@ -1236,12 +1246,14 @@ class DTActor(nn.Module): Args: state_dim (int): state dimension. action_dim (int): action dimension. - transformer_config (Dict): config for the GPT2 transformer. + transformer_config (Dict or :class:`DecisionTransformer.DTConfig`, optional): + config for the GPT2 transformer. + Defaults to :meth:`~.default_config`. device (Optional[DEVICE_TYPING], optional): device to use. Defaults to None. Examples: >>> model = DTActor(state_dim=4, action_dim=2, - ... transformer_config=DTActor.get_default_config()) + ... 
transformer_config=DTActor.default_config()) >>> observation = torch.randn(32, 10, 4) >>> action = torch.randn(32, 10, 2) >>> return_to_go = torch.randn(32, 10, 1) @@ -1255,10 +1267,14 @@ def __init__( self, state_dim: int, action_dim: int, - transformer_config: Dict, + transformer_config: Dict | DecisionTransformer.DTConfig = None, device: Optional[DEVICE_TYPING] = None, ): super().__init__() + if transformer_config is None: + transformer_config = self.default_config() + if isinstance(transformer_config, DecisionTransformer.DTConfig): + transformer_config = dataclasses.asdict(transformer_config) self.transformer = DecisionTransformer( state_dim=state_dim, action_dim=action_dim, @@ -1288,15 +1304,15 @@ def forward( return out @classmethod - def get_default_config(cls): + def default_config(cls): """Default configuration for :class:`~.DTActor`.""" - return { - "n_embd": 256, - "n_layer": 4, - "n_head": 4, - "n_inner": 1024, - "activation": "relu", - "n_positions": 1024, - "resid_pdrop": 0.1, - "attn_pdrop": 0.1, - } + return DecisionTransformer.DTConfig( + n_embd=256, + n_layer=4, + n_head=4, + n_inner=1024, + activation="relu", + n_positions=1024, + resid_pdrop=0.1, + attn_pdrop=0.1, + ) diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 34c5b2c9e48..e7cc2a020c9 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -7,7 +7,7 @@ import torch -from tensordict import TensorDictBase +from tensordict import TensorDictBase, unravel_key from tensordict.nn import ( dispatch, TensorDictModule, @@ -1619,6 +1619,18 @@ class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper): The output will be a TensorDict with the same keys as the input, but with only the last action of the predicted action sequence and the last return to go. + This module creates returns a modified copy of the tensordict, ie. it does + **not** modify the tensordict in-place. + + .. note:: If the action, observation or reward-to-go key is not standard, + the method :meth:`~.set_tensor_keys` should be used, e.g. + + >>> dt_inference_wrapper.set_tensor_keys(action="foo", observation="bar", return_to_go="baz") + + The in_keys are the observation, action and return-to-go keys. The out-keys + match the in-keys, with the addition of any other out-key from the policy + (eg., parameters of the distribution or hidden values). + Args: policy (TensorDictModule): The policy module that takes in observations and produces an action value @@ -1640,7 +1652,7 @@ class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper): ... DecisionTransformerInferenceWrapper, ... ) >>> dtactor = DTActor(state_dim=4, action_dim=2, - ... transformer_config=DTActor.get_default_config() + ... transformer_config=DTActor.default_config() ... ) >>> actor_module = TensorDictModule( ... dtactor, @@ -1701,6 +1713,7 @@ def __init__( self._spec[self.action_key] = None else: self._spec = CompositeSpec({key: None for key in policy.out_keys}) + self.checked = False @property def in_keys(self): @@ -1708,7 +1721,12 @@ def in_keys(self): @property def out_keys(self): - return [self.observation_key, self.action_key, self.return_to_go_key] + return sorted( + set(self.td_module.out_keys).union( + {self.observation_key, self.action_key, self.return_to_go_key} + ), + key=str, + ) def set_tensor_keys(self, **kwargs): """Sets the input keys of the module. 
@@ -1719,13 +1737,19 @@ def set_tensor_keys(self, **kwargs): return_to_go (NestedKey, optional): The return_to_go key. """ - observation_key = kwargs.pop("observation", None) - action_key = kwargs.pop("action", None) - return_to_go_key = kwargs.pop("return_to_go", None) + observation_key = unravel_key(kwargs.pop("observation", self.observation_key)) + action_key = unravel_key(kwargs.pop("action", self.action_key)) + return_to_go_key = unravel_key( + kwargs.pop("return_to_go", self.return_to_go_key) + ) if kwargs: raise TypeError( f"Got unknown input(s) {kwargs.keys()}. Accepted keys are 'action', 'return_to_go' and 'observation'." ) + if action_key not in self.td_module.out_keys: + raise ValueError( + f"The action key {action_key} was not found in the policy out_keys {self.td_module.out_keys}." + ) self.observation_key = observation_key self.action_key = action_key self.return_to_go_key = return_to_go_key @@ -1742,9 +1766,9 @@ def _check_tensor_dims(reward, obs, action): def mask_context(self, tensordict: TensorDictBase) -> TensorDictBase: """Mask the context of the input sequences.""" - observation = tensordict.get(self.observation_key) - action = tensordict.get(self.action_key) - return_to_go = tensordict.get(self.return_to_go_key) + observation = tensordict.get(self.observation_key).clone() + action = tensordict.get(self.action_key).clone() + return_to_go = tensordict.get(self.return_to_go_key).clone() self._check_tensor_dims(return_to_go, observation, action) observation[..., : -self.inference_context, :] = 0 @@ -1765,27 +1789,35 @@ def mask_context(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict.set(self.return_to_go_key, return_to_go) return tensordict + def check_keys(self): + # an exception will be raised if the action key mismatch + self.set_tensor_keys() + self.checked = True + + @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + if not self.checked: + self.check_keys() """Forward pass of the inference wrapper.""" - unmasked_tensordict = tensordict.clone(False) + tensordict = tensordict.clone(False) + obs = tensordict.get(self.observation_key) # Mask the context of the input sequences tensordict = self.mask_context(tensordict) # forward pass tensordict = self.td_module.forward(tensordict) - # get last action prediciton + # get last action predicton out_action = tensordict.get(self.action_key) - idx = (slice(None),) * tensordict.ndim + (-1,) - out_action = out_action[idx] + if tensordict.ndim == out_action.ndim - 1: + # then time dimension is in the TD's dimensions, and we must get rid of it + tensordict.batch_size = tensordict.batch_size[:-1] + out_action = out_action[..., -1, :] tensordict.set(self.action_key, out_action) # out_rtg = tensordict.get(self.return_to_go_key)[:, -1] out_rtg = tensordict.get(self.return_to_go_key) - idx = (slice(None),) * tensordict.ndim + (-1,) - out_rtg = out_rtg[idx] + out_rtg = out_rtg[..., -1, :] tensordict.set(self.return_to_go_key, out_rtg) # set unmasked observation - tensordict.set( - self.observation_key, unmasked_tensordict.get(self.observation_key) - ) + tensordict.set(self.observation_key, obs) return tensordict From 29a106792d5e3b2b501734e6d0b6b98c17792666 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 11 Jul 2023 08:09:39 -0400 Subject: [PATCH 090/104] amend --- benchmarks/conftest.py | 22 ++++++++++++++++++- test/conftest.py | 22 ++++++++++++++++++- test/test_actors.py | 5 +++-- test/test_tensordictmodules.py | 4 ++++ .../modules/models/decision_transformer.py | 7 +++++- 5 files changed, 55 
insertions(+), 5 deletions(-) diff --git a/benchmarks/conftest.py b/benchmarks/conftest.py index d786cc4244d..7f320ff2e8d 100644 --- a/benchmarks/conftest.py +++ b/benchmarks/conftest.py @@ -57,7 +57,7 @@ def pytest_addoption(parser): parser.addoption("--rank", action="store") -@pytest.fixture(autouse=True) +@pytest.fixture(scope="session", autouse=True) def set_warnings() -> None: warnings.filterwarnings( "ignore", @@ -69,3 +69,23 @@ def set_warnings() -> None: category=UserWarning, message=r"Couldn't cast the policy onto the desired device on remote process", ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message=r"Deprecated call to `pkg_resources.declare_namespace", + ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message=r"Using or importing the ABCs", + ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message=r"Please use `coo_matrix` from the `scipy.sparse` namespace", + ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message=r"jax.tree_util.register_keypaths is deprecated|jax.ShapedArray is deprecated", + ) diff --git a/test/conftest.py b/test/conftest.py index c5cfdd680e7..048b9e6c49e 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -54,7 +54,7 @@ def fin(): request.addfinalizer(fin) -@pytest.fixture(autouse=True) +@pytest.fixture(scope="session", autouse=True) def set_warnings() -> None: warnings.filterwarnings( "ignore", @@ -66,3 +66,23 @@ def set_warnings() -> None: category=UserWarning, message=r"Couldn't cast the policy onto the desired device on remote process", ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message=r"Deprecated call to `pkg_resources.declare_namespace", + ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message=r"Using or importing the ABCs", + ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message=r"Please use `coo_matrix` from the `scipy.sparse` namespace", + ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message=r"jax.tree_util.register_keypaths is deprecated|jax.ShapedArray is deprecated", + ) diff --git a/test/test_actors.py b/test/test_actors.py index ee358cbe25a..7837ea74e7a 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -759,7 +759,7 @@ def test_lmhead_actorvalueoperator(device): from transformers import AutoModelForCausalLM base_model = AutoModelForCausalLM.from_pretrained("gpt2", return_dict=False) - aco = LMHeadActorValueOperator(base_model) + aco = LMHeadActorValueOperator(base_model).to(device) # check common assert aco.module[0][0].module is base_model.transformer @@ -786,7 +786,8 @@ def test_lmhead_actorvalueoperator(device): batch_size=[ 4, ], - ).to(device) + device=device, + ) td_total = aco(td.clone()) policy_op = aco.get_policy_operator() td_policy = policy_op(td.clone()) diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 38d7ac56cc7..1edd61ccac8 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -29,6 +29,7 @@ TanhNormal, ValueOperator, ) +from torchrl.modules.models.decision_transformer import _has_transformers from torchrl.modules.tensordict_module.common import ( ensure_tensordict_compatible, is_tensordict_compatible, @@ -1791,6 +1792,9 @@ def test_vmapmodule(): assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all() +@pytest.mark.skipif( + not _has_transformers, reason="transformers needed to test DT classes" +) class 
TestDecisionTransformerInferenceWrapper: @pytest.mark.parametrize("online", [True, False]) def test_dt_inference_wrapper(self, online): diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py index 628e7f79b7b..befe0861da7 100644 --- a/torchrl/modules/models/decision_transformer.py +++ b/torchrl/modules/models/decision_transformer.py @@ -103,7 +103,12 @@ def __init__( if isinstance(config, self.DTConfig): config = dataclasses.asdict(config) if not isinstance(config, dict): - raise TypeError(f"Config of type {type(config)} is not supported.") + try: + config = dict(config) + except Exception as err: + raise TypeError( + f"Config of type {type(config)} is not supported." + ) from err super(DecisionTransformer, self).__init__() From 0b8d564e973d97b66d6669dd165675179fcc0e0a Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 11 Jul 2023 09:39:41 -0400 Subject: [PATCH 091/104] amend --- .../linux_examples/scripts/run_test.sh | 3 ++- examples/iql/iql_online.py | 24 ++++++++++++------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh index bcf4688412f..db39010e48e 100755 --- a/.circleci/unittest/linux_examples/scripts/run_test.sh +++ b/.circleci/unittest/linux_examples/scripts/run_test.sh @@ -155,7 +155,8 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/iql/iql_onli env_per_collector=2 \ collector_device=cuda:0 \ device=cuda:0 \ - mode=offline + mode=offline \ + logger= # With single envs python .circleci/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \ diff --git a/examples/iql/iql_online.py b/examples/iql/iql_online.py index cbe9f697a65..5d75badcace 100644 --- a/examples/iql/iql_online.py +++ b/examples/iql/iql_online.py @@ -76,12 +76,14 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.device) exp_name = generate_exp_name("Online_IQL", cfg.exp_name) - logger = get_logger( - logger_type=cfg.logger, - logger_name="iql_logging", - experiment_name=exp_name, - wandb_kwargs={"mode": cfg.mode}, - ) + logger = None + if cfg.logger is not None: + logger = get_logger( + logger_type=cfg.logger, + logger_name="iql_logging", + experiment_name=exp_name, + wandb_kwargs={"mode": cfg.mode}, + ) torch.manual_seed(cfg.seed) np.random.seed(cfg.seed) @@ -300,8 +302,9 @@ def env_factory(num_workers): "value_loss": np.mean(value_losses), } ) - for key, value in train_log.items(): - logger.log_scalar(key, value, step=collected_frames) + if logger is not None: + for key, value in train_log.items(): + logger.log_scalar(key, value, step=collected_frames) with set_exploration_type(ExplorationType.MEAN), torch.no_grad(): eval_rollout = test_env.rollout( @@ -312,7 +315,10 @@ def env_factory(num_workers): eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() rewards_eval.append((i, eval_reward)) eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})" - logger.log_scalar("test_reward", rewards_eval[-1][1], step=collected_frames) + if logger is not None: + logger.log_scalar( + "test_reward", rewards_eval[-1][1], step=collected_frames + ) if len(rewards_eval): pbar.set_description( f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str From b08d3d451a8440dccf2f76f43ab37ea811ddfe46 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 11 Jul 2023 11:18:58 -0400 Subject: [PATCH 092/104] fix --- 
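Note (reviewer aside placed after the `---` marker, which `git am` ignores): patch 091 guarded logger creation with `cfg.logger is not None`, while the example test scripts disable logging by passing `logger=`. Assuming that override resolves to an empty string rather than `None`, only a truthiness check skips logger creation in both cases, which is what the one-line change below does. A minimal sketch of the difference, not part of the patch:

# Illustration only. Assumption: a `logger=` override yields "" (or None),
# never a logger name.
for cfg_logger in (None, "", "wandb"):
    old_guard = cfg_logger is not None  # "" still builds a logger
    new_guard = bool(cfg_logger)        # "" and None both skip logging
    print(repr(cfg_logger), old_guard, new_guard)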
.circleci/unittest/linux_examples/scripts/run_test.sh | 3 ++- examples/iql/iql_online.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh index db39010e48e..881b38669a7 100755 --- a/.circleci/unittest/linux_examples/scripts/run_test.sh +++ b/.circleci/unittest/linux_examples/scripts/run_test.sh @@ -254,7 +254,8 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/iql/iql_onli env_per_collector=1 \ mode=offline \ device=cuda:0 \ - collector_device=cuda:0 + collector_device=cuda:0 \ + logger= python .circleci/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ diff --git a/examples/iql/iql_online.py b/examples/iql/iql_online.py index 5d75badcace..16014f4f3ec 100644 --- a/examples/iql/iql_online.py +++ b/examples/iql/iql_online.py @@ -77,7 +77,7 @@ def main(cfg: "DictConfig"): # noqa: F821 exp_name = generate_exp_name("Online_IQL", cfg.exp_name) logger = None - if cfg.logger is not None: + if cfg.logger: logger = get_logger( logger_type=cfg.logger, logger_name="iql_logging", From a522db0b624a15b8d45b37105992d8c215f47b45 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 12 Jul 2023 19:33:22 +0200 Subject: [PATCH 093/104] fix reward scale, reduce target return config --- examples/decision_transformer/dt_config.yaml | 2 +- examples/decision_transformer/odt_config.yaml | 2 +- examples/decision_transformer/online_dt.py | 2 +- examples/decision_transformer/utils.py | 23 +++++++------------ torchrl/objectives/decision_transformer.py | 2 +- 5 files changed, 12 insertions(+), 19 deletions(-) diff --git a/examples/decision_transformer/dt_config.yaml b/examples/decision_transformer/dt_config.yaml index 65ba26e4664..69ced6be5d8 100644 --- a/examples/decision_transformer/dt_config.yaml +++ b/examples/decision_transformer/dt_config.yaml @@ -12,7 +12,7 @@ env: reward_scaling: 0.001 # for r2g noop: 1 seed: 1 - target_return_mode: constant + target_return_mode: reduce eval_target_return: 6000 collect_target_return: 12000 diff --git a/examples/decision_transformer/odt_config.yaml b/examples/decision_transformer/odt_config.yaml index 912a29b5cbf..5a6084bc4e1 100644 --- a/examples/decision_transformer/odt_config.yaml +++ b/examples/decision_transformer/odt_config.yaml @@ -12,7 +12,7 @@ env: reward_scaling: 0.001 # for r2g noop: 1 seed: 2 - target_return_mode: constant + target_return_mode: reduce eval_target_return: 6000 collect_target_return: 12000 diff --git a/examples/decision_transformer/online_dt.py b/examples/decision_transformer/online_dt.py index 07b5e298fe5..1e25ea7e4f9 100644 --- a/examples/decision_transformer/online_dt.py +++ b/examples/decision_transformer/online_dt.py @@ -39,7 +39,7 @@ def main(cfg: "DictConfig"): # noqa: F821 actor = make_odt_model(cfg) policy = actor.to(model_device) - loss_module = make_odt_loss(cfg.loss, actor) + loss_module = make_odt_loss(cfg.loss, policy) transformer_optim, temperature_optim, scheduler = make_odt_optimizer( cfg.optim, policy, loss_module ) diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index 24d537ac67b..fe57a352d1f 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -66,10 +66,15 @@ def make_base_env(env_cfg): def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False): transformed_env = TransformedEnv(base_env) + 
transformed_env.append_transform( + RewardScaling( + loc=0, scale=env_cfg.reward_scaling, in_keys="reward", standard_normal=False + ) + ) if train: transformed_env.append_transform( TargetReturn( - env_cfg.collect_target_return, + env_cfg.collect_target_return * env_cfg.reward_scaling, out_keys=["return_to_go"], mode=env_cfg.target_return_mode, ) @@ -77,24 +82,12 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False): else: transformed_env.append_transform( TargetReturn( - env_cfg.eval_target_return, + env_cfg.eval_target_return * env_cfg.reward_scaling, out_keys=["return_to_go"], mode=env_cfg.target_return_mode, ) ) - transformed_env.append_transform( - RewardScaling( - loc=0, - scale=env_cfg.reward_scaling, - in_keys="return_to_go", - standard_normal=False, - ) - ) - transformed_env.append_transform( - RewardScaling( - loc=0, scale=env_cfg.reward_scaling, in_keys="reward", standard_normal=False - ) - ) + transformed_env.append_transform(TensorDictPrimer(action=base_env.action_spec)) transformed_env.append_transform( diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 9a1693c74a0..24a975d4ee9 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -203,7 +203,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: out = { "loss_log_likelihood": -loss_log_likelihood, - "loss_entropy": loss_entropy, + "loss_entropy": -loss_entropy, "loss_alpha": loss_alpha, "entropy": entropy.detach(), "alpha": self.alpha.detach(), From aefbf61a9b02439664f7305c42c6e4ec858da662 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 13 Jul 2023 08:59:05 -0400 Subject: [PATCH 094/104] amend --- .circleci/unittest/linux_libs/scripts_habitat/run_test.sh | 3 ++- torchrl/data/replay_buffers/replay_buffers.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.circleci/unittest/linux_libs/scripts_habitat/run_test.sh b/.circleci/unittest/linux_libs/scripts_habitat/run_test.sh index 1c2f7e19cb0..1f916fea9c1 100755 --- a/.circleci/unittest/linux_libs/scripts_habitat/run_test.sh +++ b/.circleci/unittest/linux_libs/scripts_habitat/run_test.sh @@ -10,9 +10,10 @@ conda activate ./env # https://stackoverflow.com/questions/72540359/glibcxx-3-4-30-not-found-for-librosa-in-conda-virtual-environment-after-tryin #conda install -y -c conda-forge gcc=12.1.0 conda install -y -c conda-forge libstdcxx-ng=12 +conda env config vars set LD_PRELOAD=$LD_PRELOAD:$STDC_LOC + ## find libstdc STDC_LOC=$(find conda/ -name "libstdc++.so.6" | head -1) -conda env config vars set LD_PRELOAD=$LD_PRELOAD:$STDC_LOC export PYTORCH_TEST_WITH_SLOW='1' python -m torch.utils.collect_env diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 0035769efc7..3369e2004bc 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -272,7 +272,7 @@ def extend(self, data: Sequence) -> torch.Tensor: Indices of the data added to the replay buffer. 
""" if self._transform is not None and is_tensor_collection(data): - data = self._transform.inv(data) # test + data = self._transform.inv(data) elif self._transform is not None and len(self._transform): data = self._transform.inv(data) return self._extend(data) From 1c7cbbf28a2c6a086a8ba054025c7aa6c2730e29 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 13 Jul 2023 08:59:32 -0400 Subject: [PATCH 095/104] amend --- torchrl/envs/transforms/transforms.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 1056914ade5..f7ca64d0923 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1684,15 +1684,15 @@ class ObservationNorm(ObservationTransform): Args: loc (number or tensor): location of the affine transform scale (number or tensor): scale of the affine transform - in_keys (seuqence of NestedKey, optional): entries to be normalized. Defaults to ["observation", "pixels"]. + in_keys (sequence of NestedKey, optional): entries to be normalized. Defaults to ["observation", "pixels"]. All entries will be normalized with the same values: if a different behaviour is desired (e.g. a different normalization for pixels and states) different :obj:`ObservationNorm` objects should be used. - out_keys (seuqence of NestedKey, optional): output entries. Defaults to the value of `in_keys`. - in_keys_inv (seuqence of NestedKey, optional): ObservationNorm also supports inverse transforms. This will + out_keys (sequence of NestedKey, optional): output entries. Defaults to the value of `in_keys`. + in_keys_inv (sequence of NestedKey, optional): ObservationNorm also supports inverse transforms. This will only occur if a list of keys is provided to :obj:`in_keys_inv`. If none is provided, only the forward transform will be called. - out_keys_inv (seuqence of NestedKey, optional): output entries for the inverse transform. + out_keys_inv (sequence of NestedKey, optional): output entries for the inverse transform. Defaults to the value of `in_keys_inv`. standard_normal (bool, optional): if ``True``, the transform will be @@ -1964,9 +1964,9 @@ class CatFrames(ObservationTransform): dim (int): dimension along which concatenate the observations. Should be negative, to ensure that it is compatible with environments of different batch_size. - in_keys (seuqence of NestedKey, optional): keys pointing to the frames that have + in_keys (sequence of NestedKey, optional): keys pointing to the frames that have to be concatenated. Defaults to ["pixels"]. - out_keys (seuqence of NestedKey, optional): keys pointing to where the output + out_keys (sequence of NestedKey, optional): keys pointing to where the output has to be written. Defaults to the value of `in_keys`. padding (str, optional): the padding method. One of ``"same"`` or ``"zeros"``. Defaults to ``"same"``, ie. the first value is uesd for padding. 
From 11d87792bf886594758ca43132907639d3880669 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 26 Jul 2023 11:20:03 +0200 Subject: [PATCH 096/104] zero padding, fix obs loc, std for normalization --- examples/decision_transformer/utils.py | 18 +++++++++--------- torchrl/modules/models/models.py | 20 ++++++++++++-------- torchrl/objectives/decision_transformer.py | 2 +- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index fe57a352d1f..46d3506758f 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -96,6 +96,10 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False): in_keys_inv=[], ) ) + obsnorm = ObservationNorm( + loc=obs_loc, scale=obs_std, in_keys="observation", standard_normal=True + ) + transformed_env.append_transform(obsnorm) transformed_env.append_transform( UnsqueezeTransform(-2, in_keys=["observation", "action", "return_to_go"]) ) @@ -104,12 +108,9 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False): in_keys=["observation", "action", "return_to_go"], N=env_cfg.stacked_frames, dim=-2, + padding="zeros", ) ) - obsnorm = ObservationNorm( - loc=obs_loc, scale=obs_std, in_keys="observation", standard_normal=True - ) - transformed_env.append_transform(obsnorm) if train: transformed_env.append_transform(RewardSum()) @@ -176,11 +177,10 @@ def make_collector(cfg, policy): def get_loc_std(env_name): - data = D4RLExperienceReplay(env_name, 1024) - for sample in data: - loc = sample.get("observation").mean() - std = sample.get("observation").std() - break + buffer = D4RLExperienceReplay(env_name, 1024) + full_data = buffer._get_dataset_from_env(env_name, {}) + loc = full_data["observation"].mean(axis=0).float() + std = full_data["observation"].std(axis=0).float() return loc, std diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index d8a8bfb606e..63b84a097a0 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -1188,8 +1188,11 @@ def __init__( action_dim=action_dim, config=transformer_config, ) - self.action_layer = nn.Linear( - transformer_config["n_embd"], action_dim * 2, device=device + self.action_layer_mean = nn.Linear( + transformer_config["n_embd"], action_dim, device=device + ) + self.action_layer_logstd = nn.Linear( + transformer_config["n_embd"], action_dim, device=device ) self.log_std_min, self.log_std_max = -5.0, 2.0 @@ -1210,8 +1213,9 @@ def forward( return_to_go: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: hidden_state = self.transformer(observation, action, return_to_go) - out = self.action_layer(hidden_state) - mu, log_std = torch.chunk(out, 2, -1) + mu = self.action_layer_mean(hidden_state) + log_std = self.action_layer_logstd(hidden_state) + # mu, log_std = torch.chunk(out, 2, -1) log_std = torch.tanh(log_std) # log_std is the output of tanh so it will be between [-1, 1] # map it to be between [log_std_min, log_std_max] @@ -1226,10 +1230,10 @@ def forward( def default_config(cls): """Default configuration for :class:`~.OnlineDTActor`.""" return DecisionTransformer.DTConfig( - n_embd=256, + n_embd=512, n_layer=4, n_head=4, - n_inner=1024, + n_inner=2048, activation="relu", n_positions=1024, resid_pdrop=0.1, @@ -1307,10 +1311,10 @@ def forward( def default_config(cls): """Default configuration for :class:`~.DTActor`.""" return DecisionTransformer.DTConfig( - n_embd=256, + n_embd=512, n_layer=4, n_head=4, - 
n_inner=1024, + n_inner=2048, activation="relu", n_positions=1024, resid_pdrop=0.1, diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 24a975d4ee9..6543652338a 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -182,7 +182,7 @@ def out_keys(self, values): def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: x = dist.rsample((self.samples_mc_entropy,)) log_p = dist.log_prob(x) - # log_p: (batch_size, context_len, + # log_p: (batch_size, context_len) return -log_p.mean(axis=0) @dispatch From 3ff2fc6fac86d928d60374f20f789814aa320fde Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 30 Jul 2023 21:08:34 +0100 Subject: [PATCH 097/104] temp - SerialEnv --- examples/decision_transformer/odt_config.yaml | 2 +- examples/decision_transformer/utils.py | 4 +++- torchrl/objectives/decision_transformer.py | 8 ++++---- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/decision_transformer/odt_config.yaml b/examples/decision_transformer/odt_config.yaml index 5a6084bc4e1..72302fdc5cd 100644 --- a/examples/decision_transformer/odt_config.yaml +++ b/examples/decision_transformer/odt_config.yaml @@ -40,7 +40,7 @@ replay_buffer: # Optimization optim: - device: cuda:0 + device: cpu lr: 1.0e-4 weight_decay: 5.0e-4 batch_size: 256 diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index 46d3506758f..34f67307706 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -16,6 +16,7 @@ NoopResetEnv, ObservationNorm, ParallelEnv, + SerialEnv, Reward2GoTransform, RewardScaling, RewardSum, @@ -129,7 +130,8 @@ def make_env(): return make_base_env(env_cfg) env = make_transformed_env( - ParallelEnv(num_envs, EnvCreator(make_env)), + # ParallelEnv(num_envs, EnvCreator(make_env)), + SerialEnv(num_envs, EnvCreator(make_env)), env_cfg, obs_loc, obs_std, diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 6543652338a..768e05de9c6 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -195,15 +195,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict, params=self.actor_network_params ) - loss_log_likelihood = action_dist.log_prob(target_actions).mean() + log_likelihood = action_dist.log_prob(target_actions).mean() entropy = self.get_entropy_bonus(action_dist).mean() - loss_entropy = self.alpha.detach() * entropy + entropy_bonus = self.alpha.detach() * entropy loss_alpha = self.log_alpha.exp() * (entropy - self.target_entropy).detach() out = { - "loss_log_likelihood": -loss_log_likelihood, - "loss_entropy": -loss_entropy, + "loss_log_likelihood": -log_likelihood, + "loss_entropy": -entropy_bonus, "loss_alpha": loss_alpha, "entropy": entropy.detach(), "alpha": self.alpha.detach(), From 9135fa7ba13d58dd1163f1dcd474d7a5a3d48d95 Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 4 Aug 2023 08:49:00 +0200 Subject: [PATCH 098/104] fix obs norm, fix action context --- examples/decision_transformer/utils.py | 22 +++++++-------------- torchrl/modules/tensordict_module/actors.py | 4 +++- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index 46d3506758f..a4c0b70e424 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -176,14 +176,6 @@ def 
make_collector(cfg, policy): return collector -def get_loc_std(env_name): - buffer = D4RLExperienceReplay(env_name, 1024) - full_data = buffer._get_dataset_from_env(env_name, {}) - loc = full_data["observation"].mean(axis=0).float() - std = full_data["observation"].std(axis=0).float() - return loc, std - - def make_offline_replay_buffer(rb_cfg, reward_scaling): r2g = Reward2GoTransform(gamma=1.0, in_keys=["reward"], out_keys=["return_to_go"]) reward_scale = RewardScaling( @@ -201,10 +193,6 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling): in_keys=["observation", ("next", "observation")], in_keys_inv=[], ) - loc, std = get_loc_std(rb_cfg.dataset) - obsnorm = ObservationNorm( - loc=loc, scale=std, in_keys="observation", standard_normal=True - ) exclude = ExcludeTransform( "next_observations", "timeout", @@ -224,7 +212,6 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling): reward_scale, d2f, exclude, - obsnorm, ) data = D4RLExperienceReplay( rb_cfg.dataset, @@ -233,8 +220,13 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling): sampler=SamplerWithoutReplacement(drop_last=False), transform=transforms, ) - # TODO: add obsnorm here - + full_data = data._get_dataset_from_env(rb_cfg.dataset, {}) + loc = full_data["observation"].mean(axis=0).float() + std = full_data["observation"].std(axis=0).float() + obsnorm = ObservationNorm( + loc=loc, scale=std, in_keys="observation", standard_normal=True + ) + data.append_transform(obsnorm) return data, loc, std diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 574ff5ed43c..1e214330ce5 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -1778,7 +1778,9 @@ def mask_context(self, tensordict: TensorDictBase) -> TensorDictBase: self._check_tensor_dims(return_to_go, observation, action) observation[..., : -self.inference_context, :] = 0 - action[..., : -self.inference_context, :] = 0 + action[ + ..., : -(self.inference_context - 1), : + ] = 0 # as we add zeros to the end of the action action = torch.cat( [ action[..., 1:, :], From c3a67c88a5041f92cc3f4a1baafceeab9fd1d74c Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 7 Aug 2023 09:56:00 +0200 Subject: [PATCH 099/104] update buffer transforms to not use catframes --- examples/decision_transformer/utils.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index a4c0b70e424..ac8487a02a1 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -16,6 +16,7 @@ NoopResetEnv, ObservationNorm, ParallelEnv, + RandomCropTensorDict, Reward2GoTransform, RewardScaling, RewardSum, @@ -181,13 +182,7 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling): reward_scale = RewardScaling( loc=0, scale=reward_scaling, in_keys="return_to_go", standard_normal=False ) - catframes = CatFrames( - in_keys=["action", "observation", "return_to_go"], - N=rb_cfg.stacked_frames, - dim=-2, - padding="zeros", - as_inverse=True, - ) + crop_seq = RandomCropTensorDict(sub_seq_len=rb_cfg.stacked_frames, sample_dim=-1) d2f = DoubleToFloat( in_keys=["observation", ("next", "observation")], @@ -195,7 +190,7 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling): ) exclude = ExcludeTransform( "next_observations", - "timeout", + # "timeout", "terminal", "info", ("next", "timeout"), @@ -205,20 +200,19 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling): 
) transforms = Compose( - # inverse transforms are called reversed - # therefore catframes before r2g - catframes, r2g, + crop_seq, reward_scale, d2f, exclude, ) data = D4RLExperienceReplay( rb_cfg.dataset, - split_trajs=False, + split_trajs=True, batch_size=rb_cfg.batch_size, sampler=SamplerWithoutReplacement(drop_last=False), transform=transforms, + use_timeout_as_done=True, ) full_data = data._get_dataset_from_env(rb_cfg.dataset, {}) loc = full_data["observation"].mean(axis=0).float() From ca505ebad787f470856c867db83e9baf8e6c4e29 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 14 Aug 2023 10:14:19 +0200 Subject: [PATCH 100/104] test dist, small fixes --- examples/decision_transformer/online_dt.py | 1 + examples/decision_transformer/utils.py | 107 ++++++++++++++++++--- torchrl/data/datasets/d4rl.py | 3 + torchrl/modules/distributions/utils.py | 2 +- torchrl/modules/models/models.py | 19 ++-- torchrl/objectives/decision_transformer.py | 12 ++- 6 files changed, 120 insertions(+), 24 deletions(-) diff --git a/examples/decision_transformer/online_dt.py b/examples/decision_transformer/online_dt.py index 1e25ea7e4f9..e18d7fe2593 100644 --- a/examples/decision_transformer/online_dt.py +++ b/examples/decision_transformer/online_dt.py @@ -87,6 +87,7 @@ def main(cfg: "DictConfig"): # noqa: F821 max_steps=eval_steps, policy=inference_policy, auto_cast_to_device=True, + break_when_any_done=False, ) inference_policy.train() if r0 is None: diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index ac8487a02a1..de7b1725f24 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -1,7 +1,13 @@ +import math + import torch.nn + +import torch.nn.functional as F import torch.optim from lamb import Lamb from tensordict.nn import TensorDictModule +from torch import distributions as pyd +from torch.distributions import constraints from torchrl.collectors import SyncDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer @@ -15,11 +21,12 @@ ExcludeTransform, NoopResetEnv, ObservationNorm, - ParallelEnv, + # ParallelEnv, RandomCropTensorDict, Reward2GoTransform, RewardScaling, RewardSum, + SerialEnv, TargetReturn, TensorDictPrimer, TransformedEnv, @@ -33,14 +40,13 @@ OnlineDTActor, ProbabilisticActor, TanhDelta, - TanhNormal, + # TanhNormal, ) from torchrl.objectives import DTLoss, OnlineDTLoss from torchrl.record.loggers import generate_exp_name, get_logger from torchrl.trainers.helpers.envs import LIBS - # ==================================================================== # Environment utils # ----------------- @@ -130,7 +136,7 @@ def make_env(): return make_base_env(env_cfg) env = make_transformed_env( - ParallelEnv(num_envs, EnvCreator(make_env)), + SerialEnv(num_envs, EnvCreator(make_env)), env_cfg, obs_loc, obs_std, @@ -296,12 +302,12 @@ def make_odt_model(cfg): "scale", ], ) - dist_class = TanhNormal - dist_kwargs = { - "min": -1.0, - "max": 1.0, - "tanh_loc": False, - } + dist_class = SquashedNormal # TanhNormal + # dist_kwargs = { + # "min": -1.0, + # "max": 1.0, + # "tanh_loc": False, + # } actor = ProbabilisticActor( spec=action_spec, @@ -309,7 +315,7 @@ def make_odt_model(cfg): out_keys=["action"], module=actor_module, distribution_class=dist_class, - distribution_kwargs=dist_kwargs, + # distribution_kwargs=dist_kwargs, default_interaction_mode="random", cache_dist=False, return_log_prob=False, @@ -451,3 +457,82 @@ def make_logger(cfg): wandb_kwargs={"config": cfg}, ) return logger + + +class 
TanhTransform(pyd.transforms.Transform): + domain = pyd.constraints.real + codomain = pyd.constraints.interval(-1.0, 1.0) + bijective = True + sign = +1 + + def __init__(self, cache_size=1): + super().__init__(cache_size=cache_size) + + @staticmethod + def atanh(x): + return 0.5 * (x.log1p() - (-x).log1p()) + + def __eq__(self, other): + return isinstance(other, TanhTransform) + + def _call(self, x): + return x.tanh() + + def _inverse(self, y): + # We do not clamp to the boundary here as it may degrade the performance of certain algorithms. + # one should use `cache_size=1` instead + return self.atanh(y) + + def log_abs_det_jacobian(self, x, y): + # We use a formula that is more numerically stable, see details in the following link + # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7 + return 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x)) + + +class SquashedNormal( + pyd.transformed_distribution.TransformedDistribution +): # FasterTransformedDistribution): # + """ + Squashed Normal Distribution(s) + + If loc/std is of size (batch_size, sequence length, d), + this returns batch_size * sequence length * d + independent squashed univariate normal distributions. + """ + + arg_constraints = { + "loc": constraints.real, + "scale": constraints.greater_than(1e-6), + } + + def __init__(self, loc, scale, **kwargs): + self.loc = loc + self.scale = scale + self.base_dist = pyd.Normal(loc, scale) + + transforms = [TanhTransform()] + super().__init__(self.base_dist, transforms) + + @property + def mode(self): + mu = self.loc + for tr in self.transforms: + mu = tr(mu) + return mu + + def entropy(self, N=1): + # sample from the distribution and then compute + # the empirical entropy: + x = self.rsample((N,)) + x = torch.clamp(x, -0.99999, 0.99999) + log_p = self.log_prob(x) + + # log_p: (batch_size, context_len, action_dim), + return -log_p.mean(axis=0).sum(axis=2) + + def log_likelihood(self, x): + # log_prob(x): (batch_size, context_len, action_dim) + # sum up along the action dimensions + # Return tensor shape: (batch_size, context_len) + x = torch.clamp(x, -0.99999, 0.99999) + return self.log_prob(x).sum(axis=2) diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index d6c32083e23..ab82c06219f 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -136,6 +136,8 @@ def __init__( if split_trajs: dataset = split_trajectories(dataset) + dataset["next", "done"][:, -1] = True + storage = LazyMemmapStorage(dataset.shape[0]) super().__init__( batch_size=batch_size, @@ -261,6 +263,7 @@ def _get_dataset_from_env(self, name, env_kwargs): ) else: dataset.set("done", dataset.get("terminal")) + dataset.rename_key("rewards", "reward") dataset.rename_key("actions", "action") try: diff --git a/torchrl/modules/distributions/utils.py b/torchrl/modules/distributions/utils.py index 3fc64b4c03f..267632c4fd9 100644 --- a/torchrl/modules/distributions/utils.py +++ b/torchrl/modules/distributions/utils.py @@ -46,7 +46,7 @@ def __init__(self, base_distribution, transforms, validate_args=None): transforms, ] elif isinstance(transforms, list): - raise ValueError("Mae a ComposeTransform first.") + raise ValueError("Make a ComposeTransform first.") else: raise ValueError( "transforms must be a Transform or list, but was {}".format(transforms) diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 63b84a097a0..7e422d856b4 100644 --- a/torchrl/modules/models/models.py +++ 
b/torchrl/modules/models/models.py @@ -1204,7 +1204,8 @@ def weight_init(m): if hasattr(m.bias, "data"): m.bias.data.fill_(0.0) - self.apply(weight_init) + self.action_layer_mean.apply(weight_init) + self.action_layer_logstd.apply(weight_init) def forward( self, @@ -1215,7 +1216,7 @@ def forward( hidden_state = self.transformer(observation, action, return_to_go) mu = self.action_layer_mean(hidden_state) log_std = self.action_layer_logstd(hidden_state) - # mu, log_std = torch.chunk(out, 2, -1) + log_std = torch.tanh(log_std) # log_std is the output of tanh so it will be between [-1, 1] # map it to be between [log_std_min, log_std_max] @@ -1288,14 +1289,14 @@ def __init__( transformer_config["n_embd"], action_dim, device=device ) - def weight_init(m): - """Custom weight init for Conv2D and Linear layers.""" - if isinstance(m, torch.nn.Linear): - nn.init.orthogonal_(m.weight.data) - if hasattr(m.bias, "data"): - m.bias.data.fill_(0.0) + # def weight_init(m): + # """Custom weight init for Conv2D and Linear layers.""" + # if isinstance(m, torch.nn.Linear): + # nn.init.orthogonal_(m.weight.data) + # if hasattr(m.bias, "data"): + # m.bias.data.fill_(0.0) - self.apply(weight_init) + # self.apply(weight_init) def forward( self, diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 6543652338a..781ab166082 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -190,13 +190,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Compute the loss for the Online Decision Transformer.""" # extract action targets target_actions = tensordict.get(self.tensor_keys.action).detach() - + target_actions = torch.clamp(target_actions, -0.99999, 0.99999) action_dist = self.actor_network.get_dist( tensordict, params=self.actor_network_params ) - loss_log_likelihood = action_dist.log_prob(target_actions).mean() - entropy = self.get_entropy_bonus(action_dist).mean() + # loss_log_likelihood = action_dist.log_prob(target_actions).mean() + # entropy = self.get_entropy_bonus(action_dist).mean() + loss_log_likelihood = action_dist.log_likelihood(target_actions).mean() + entropy = action_dist.entropy().mean() + loss_entropy = self.alpha.detach() * entropy loss_alpha = self.log_alpha.exp() * (entropy - self.target_entropy).detach() @@ -207,6 +210,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: "loss_alpha": loss_alpha, "entropy": entropy.detach(), "alpha": self.alpha.detach(), + "return_to_go-mean": tensordict["return_to_go"].mean(), + "action_dist_mean": action_dist.loc.mean(), + "action_dist_std": action_dist.scale.mean(), } return TensorDict(out, []) From b26078515314a7acf0746076ce1ea9f4efdbc7e3 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 22 Aug 2023 19:46:47 +0200 Subject: [PATCH 101/104] update utils --- examples/decision_transformer/dt.py | 2 +- examples/decision_transformer/utils.py | 109 ++------------------- torchrl/objectives/decision_transformer.py | 11 +-- 3 files changed, 14 insertions(+), 108 deletions(-) diff --git a/examples/decision_transformer/dt.py b/examples/decision_transformer/dt.py index 046b29b4b76..30e19608cf7 100644 --- a/examples/decision_transformer/dt.py +++ b/examples/decision_transformer/dt.py @@ -36,7 +36,7 @@ def main(cfg: "DictConfig"): # noqa: F821 policy = actor.to(model_device) loss_module = make_dt_loss(cfg.loss, actor) - transformer_optim, scheduler = make_dt_optimizer(cfg.optim, policy) + transformer_optim, scheduler = 
make_dt_optimizer(cfg.optim, loss_module) inference_policy = DecisionTransformerInferenceWrapper( policy=policy, inference_context=cfg.env.inference_context, diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index de7b1725f24..b37fedebd33 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -1,18 +1,13 @@ -import math - import torch.nn -import torch.nn.functional as F import torch.optim from lamb import Lamb from tensordict.nn import TensorDictModule -from torch import distributions as pyd -from torch.distributions import constraints from torchrl.collectors import SyncDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.datasets.d4rl import D4RLExperienceReplay -from torchrl.data.replay_buffers import SamplerWithoutReplacement +from torchrl.data.replay_buffers import RandomSampler from torchrl.envs import ( CatFrames, Compose, @@ -21,7 +16,6 @@ ExcludeTransform, NoopResetEnv, ObservationNorm, - # ParallelEnv, RandomCropTensorDict, Reward2GoTransform, RewardScaling, @@ -40,7 +34,7 @@ OnlineDTActor, ProbabilisticActor, TanhDelta, - # TanhNormal, + TanhNormal, ) from torchrl.objectives import DTLoss, OnlineDTLoss @@ -216,7 +210,7 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling): rb_cfg.dataset, split_trajs=True, batch_size=rb_cfg.batch_size, - sampler=SamplerWithoutReplacement(drop_last=False), + sampler=RandomSampler(), # SamplerWithoutReplacement(drop_last=False), transform=transforms, use_timeout_as_done=True, ) @@ -302,12 +296,8 @@ def make_odt_model(cfg): "scale", ], ) - dist_class = SquashedNormal # TanhNormal - # dist_kwargs = { - # "min": -1.0, - # "max": 1.0, - # "tanh_loc": False, - # } + dist_class = TanhNormal + dist_kwargs = {"min": -1.0, "max": 1.0, "tanh_loc": False, "upscale": 1.0} actor = ProbabilisticActor( spec=action_spec, @@ -315,7 +305,7 @@ def make_odt_model(cfg): out_keys=["action"], module=actor_module, distribution_class=dist_class, - # distribution_kwargs=dist_kwargs, + distribution_kwargs=dist_kwargs, default_interaction_mode="random", cache_dist=False, return_log_prob=False, @@ -406,9 +396,9 @@ def make_dt_loss(loss_cfg, actor_network): return loss -def make_odt_optimizer(optim_cfg, actor_network, loss_module): +def make_odt_optimizer(optim_cfg, loss_module): dt_optimizer = Lamb( - actor_network.parameters(), + loss_module.actor_network_params.flatten_keys().values(), lr=optim_cfg.lr, weight_decay=optim_cfg.weight_decay, eps=1.0e-8, @@ -426,9 +416,9 @@ def make_odt_optimizer(optim_cfg, actor_network, loss_module): return dt_optimizer, log_temp_optimizer, scheduler -def make_dt_optimizer(optim_cfg, actor_network): +def make_dt_optimizer(optim_cfg, loss_module): dt_optimizer = torch.optim.Adam( - actor_network.parameters(), + loss_module.actor_network_params.flatten_keys().values(), lr=optim_cfg.lr, weight_decay=optim_cfg.weight_decay, eps=1.0e-8, @@ -457,82 +447,3 @@ def make_logger(cfg): wandb_kwargs={"config": cfg}, ) return logger - - -class TanhTransform(pyd.transforms.Transform): - domain = pyd.constraints.real - codomain = pyd.constraints.interval(-1.0, 1.0) - bijective = True - sign = +1 - - def __init__(self, cache_size=1): - super().__init__(cache_size=cache_size) - - @staticmethod - def atanh(x): - return 0.5 * (x.log1p() - (-x).log1p()) - - def __eq__(self, other): - return isinstance(other, TanhTransform) - - def _call(self, x): - return x.tanh() - - def _inverse(self, y): - # We do not clamp to the boundary 
here as it may degrade the performance of certain algorithms. - # one should use `cache_size=1` instead - return self.atanh(y) - - def log_abs_det_jacobian(self, x, y): - # We use a formula that is more numerically stable, see details in the following link - # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7 - return 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x)) - - -class SquashedNormal( - pyd.transformed_distribution.TransformedDistribution -): # FasterTransformedDistribution): # - """ - Squashed Normal Distribution(s) - - If loc/std is of size (batch_size, sequence length, d), - this returns batch_size * sequence length * d - independent squashed univariate normal distributions. - """ - - arg_constraints = { - "loc": constraints.real, - "scale": constraints.greater_than(1e-6), - } - - def __init__(self, loc, scale, **kwargs): - self.loc = loc - self.scale = scale - self.base_dist = pyd.Normal(loc, scale) - - transforms = [TanhTransform()] - super().__init__(self.base_dist, transforms) - - @property - def mode(self): - mu = self.loc - for tr in self.transforms: - mu = tr(mu) - return mu - - def entropy(self, N=1): - # sample from the distribution and then compute - # the empirical entropy: - x = self.rsample((N,)) - x = torch.clamp(x, -0.99999, 0.99999) - log_p = self.log_prob(x) - - # log_p: (batch_size, context_len, action_dim), - return -log_p.mean(axis=0).sum(axis=2) - - def log_likelihood(self, x): - # log_prob(x): (batch_size, context_len, action_dim) - # sum up along the action dimensions - # Return tensor shape: (batch_size, context_len) - x = torch.clamp(x, -0.99999, 0.99999) - return self.log_prob(x).sum(axis=2) diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 781ab166082..46121e75c10 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -190,15 +190,13 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Compute the loss for the Online Decision Transformer.""" # extract action targets target_actions = tensordict.get(self.tensor_keys.action).detach() - target_actions = torch.clamp(target_actions, -0.99999, 0.99999) + action_dist = self.actor_network.get_dist( tensordict, params=self.actor_network_params ) - # loss_log_likelihood = action_dist.log_prob(target_actions).mean() - # entropy = self.get_entropy_bonus(action_dist).mean() - loss_log_likelihood = action_dist.log_likelihood(target_actions).mean() - entropy = action_dist.entropy().mean() + loss_log_likelihood = action_dist.log_prob(target_actions).mean() + entropy = self.get_entropy_bonus(action_dist).mean() loss_entropy = self.alpha.detach() * entropy @@ -210,9 +208,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: "loss_alpha": loss_alpha, "entropy": entropy.detach(), "alpha": self.alpha.detach(), - "return_to_go-mean": tensordict["return_to_go"].mean(), - "action_dist_mean": action_dist.loc.mean(), - "action_dist_std": action_dist.scale.mean(), } return TensorDict(out, []) From a820015925fb93fb1fb8310c892889240640befe Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 23 Aug 2023 10:35:24 +0200 Subject: [PATCH 102/104] update and fixes --- examples/decision_transformer/odt_config.yaml | 4 +- examples/decision_transformer/online_dt.py | 7 +- examples/decision_transformer/utils.py | 2 +- .../modules/models/decision_transformer.py | 123 +++++++++++++++++- 4 files changed, 125 insertions(+), 11 
deletions(-) diff --git a/examples/decision_transformer/odt_config.yaml b/examples/decision_transformer/odt_config.yaml index 5a6084bc4e1..1326ec29684 100644 --- a/examples/decision_transformer/odt_config.yaml +++ b/examples/decision_transformer/odt_config.yaml @@ -11,7 +11,7 @@ env: num_eval_envs: 10 reward_scaling: 0.001 # for r2g noop: 1 - seed: 2 + seed: 42 target_return_mode: reduce eval_target_return: 6000 collect_target_return: 12000 @@ -45,7 +45,7 @@ optim: weight_decay: 5.0e-4 batch_size: 256 lr_scheduler: "" - pretrain_gradient_steps: 55000 + pretrain_gradient_steps: 10000 updates_per_episode: 300 warmup_steps: 10000 clip_grad: 0.25 diff --git a/examples/decision_transformer/online_dt.py b/examples/decision_transformer/online_dt.py index e18d7fe2593..8ca1d85f599 100644 --- a/examples/decision_transformer/online_dt.py +++ b/examples/decision_transformer/online_dt.py @@ -12,7 +12,8 @@ import tqdm from torchrl.envs.libs.gym import set_gym_backend -from torchrl.envs.utils import ExplorationType, set_exploration_type + +# from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper from utils import ( @@ -41,7 +42,7 @@ def main(cfg: "DictConfig"): # noqa: F821 loss_module = make_odt_loss(cfg.loss, policy) transformer_optim, temperature_optim, scheduler = make_odt_optimizer( - cfg.optim, policy, loss_module + cfg.optim, loss_module ) inference_policy = DecisionTransformerInferenceWrapper( policy=policy, @@ -80,7 +81,7 @@ def main(cfg: "DictConfig"): # noqa: F821 scheduler.step() # evaluation - with set_exploration_type(ExplorationType.MEAN), torch.no_grad(): + with torch.no_grad(): inference_policy.eval() if i % pretrain_log_interval == 0: eval_td = test_env.rollout( diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index b37fedebd33..dbe0f6380c0 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -297,7 +297,7 @@ def make_odt_model(cfg): ], ) dist_class = TanhNormal - dist_kwargs = {"min": -1.0, "max": 1.0, "tanh_loc": False, "upscale": 1.0} + dist_kwargs = {"min": -1.0, "max": 1.0, "tanh_loc": False, "upscale": 5.0} actor = ProbabilisticActor( spec=action_spec, diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py index befe0861da7..1a6eaa63235 100644 --- a/torchrl/modules/models/decision_transformer.py +++ b/torchrl/modules/models/decision_transformer.py @@ -8,12 +8,127 @@ import importlib from dataclasses import dataclass -from typing import Any +from typing import Any, Optional import torch import torch.nn as nn _has_transformers = importlib.util.find_spec("transformers") is not None +import transformers +from transformers.models.gpt2.modeling_gpt2 import ( + BaseModelOutputWithPastAndCrossAttentions, + GPT2Model, +) + + +class ModifiedGPT2Model(GPT2Model): + """Wrapper around the GPT2Model from transformers. + + This class is a modified version of the GPT2Model from transformers + as for the Decision Transformer we dont need the wpe layer. 
+ + """ + + def __init__(self, config): + super(ModifiedGPT2Model, self).__init__(config) + + # Remove the wpe layer + del self.wpe + + def forward( + self, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + input_shape = inputs_embeds.size()[:-1] + + output_attentions = self.config.output_attentions + output_hidden_states = self.config.output_hidden_states + use_cache = self.config.use_cache + return_dict = self.config.use_return_dict + + head_mask = self.get_head_mask(None, self.config.n_layer) + + hidden_states = inputs_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + all_hidden_states = () if output_hidden_states else None + past_key_values = tuple([None] * len(self.h)) + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple( + past_state.to(hidden_states.device) for past_state in layer_past + ) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = block( + hidden_states, + layer_past=layer_past, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + ( + outputs[2 if use_cache else 1], + ) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + ( + outputs[3 if use_cache else 2], + ) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) class DecisionTransformer(nn.Module): @@ -95,8 +210,6 @@ def __init__( raise ImportError( "transformers is not installed. Please install it with `pip install transformers`." 
) - import transformers - from transformers.models.gpt2.modeling_gpt2 import GPT2Model if config is None: config = self.default_config() @@ -127,7 +240,7 @@ def __init__( self.action_dim = action_dim self.hidden_size = config["n_embd"] - self.transformer = GPT2Model(config=gpt_config) + self.transformer = ModifiedGPT2Model(config=gpt_config) self.embed_return = torch.nn.Linear(1, self.hidden_size) self.embed_state = torch.nn.Linear(self.state_dim, self.hidden_size) @@ -159,7 +272,7 @@ def forward( # which works nice in an autoregressive sense since states predict actions stacked_inputs = ( torch.stack( - (returns_embeddings, state_embeddings, action_embeddings), dim=-2 + (returns_embeddings, state_embeddings, action_embeddings), dim=-3 ) .permute(*range(len(batch_size)), -2, -3, -1) .reshape(*batch_size, 3 * seq_length, self.hidden_size) From 17093b7ed8f8d0245a6dc4e70362e47c3464c5e0 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 26 Aug 2023 15:28:25 -0400 Subject: [PATCH 103/104] running examples --- examples/multiagent/utils/logging.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/multiagent/utils/logging.py b/examples/multiagent/utils/logging.py index 03276329422..352d0addc51 100644 --- a/examples/multiagent/utils/logging.py +++ b/examples/multiagent/utils/logging.py @@ -6,7 +6,6 @@ import numpy as np import torch -import wandb from tensordict import TensorDictBase from torchrl.envs.libs.vmas import VmasEnv from torchrl.record.loggers import generate_exp_name, get_logger, Logger @@ -134,6 +133,8 @@ def log_evaluation( ).unsqueeze(0) if isinstance(logger, WandbLogger): + import wandb + logger.experiment.log(to_log, commit=False) logger.experiment.log( { From a717c8e1411ca3c5b14ef735b089eed51d26f685 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 28 Aug 2023 12:08:05 +0200 Subject: [PATCH 104/104] update header, docs and delete dtwrapper --- examples/decision_transformer/lamb.py | 51 +------ examples/decision_transformer/odt_config.yaml | 4 +- examples/decision_transformer/online_dt.py | 4 +- examples/decision_transformer/utils.py | 5 + .../modules/models/decision_transformer.py | 137 ++---------------- torchrl/modules/models/models.py | 16 +- 6 files changed, 33 insertions(+), 184 deletions(-) diff --git a/examples/decision_transformer/lamb.py b/examples/decision_transformer/lamb.py index a3324614051..7f874b6e049 100644 --- a/examples/decision_transformer/lamb.py +++ b/examples/decision_transformer/lamb.py @@ -1,51 +1,8 @@ -""" PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb -This optimizer code was adapted from the following (starting with latest) -* https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py -* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py -* https://github.com/cybertronai/pytorch-lamb -Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is -similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX. -In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU. -Original copyrights for above sources are below. -Modifications Copyright 2021 Ross Wightman -""" -# Copyright (c) 2021, Habana Labs Ltd. All rights reserved. - -# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) Meta Platforms, Inc. and affiliates. 
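For reference, the stacking change above (dim=-2 to dim=-3) is what interleaves the return, state and action embeddings into a single token sequence before they are fed to GPT2. Below is a minimal sketch of the resulting layout, with illustrative sizes B, T, H and a single batch dimension, so the negative-index permute from the hunk reduces to permute(0, 2, 1, 3):

import torch

B, T, H = 2, 4, 8  # illustrative batch size, context length, embedding size
returns_emb = torch.randn(B, T, H)
state_emb = torch.randn(B, T, H)
action_emb = torch.randn(B, T, H)

# stacking along dim=-3 introduces a token-type dimension: (B, 3, T, H)
stacked = torch.stack((returns_emb, state_emb, action_emb), dim=-3)
# move the type dimension next to time, (B, T, 3, H), then flatten to (B, 3 * T, H)
stacked = stacked.permute(0, 2, 1, 3).reshape(B, 3 * T, H)

# the token order fed to the transformer is (R_1, s_1, a_1, R_2, s_2, a_2, ...)
assert torch.equal(stacked[:, 0], returns_emb[:, 0])
assert torch.equal(stacked[:, 1], state_emb[:, 0])
assert torch.equal(stacked[:, 2], action_emb[:, 0])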
# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# MIT License -# -# Copyright (c) 2019 cybertronai -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# Lamb optimizer directly copied from https://github.com/facebookresearch/online-dt import math import torch diff --git a/examples/decision_transformer/odt_config.yaml b/examples/decision_transformer/odt_config.yaml index bf2f74654d2..de8d5ffb6af 100644 --- a/examples/decision_transformer/odt_config.yaml +++ b/examples/decision_transformer/odt_config.yaml @@ -35,12 +35,12 @@ replay_buffer: buffer_prefetch: 64 capacity: 1_000_000 buffer_scratch_dir: "/tmp/" - device: cpu + device: cuda:0 prefetch: 3 # Optimization optim: - device: cpu + device: cuda:0 lr: 1.0e-4 weight_decay: 5.0e-4 batch_size: 256 diff --git a/examples/decision_transformer/online_dt.py b/examples/decision_transformer/online_dt.py index 8ca1d85f599..01ab12dfabd 100644 --- a/examples/decision_transformer/online_dt.py +++ b/examples/decision_transformer/online_dt.py @@ -13,7 +13,7 @@ from torchrl.envs.libs.gym import set_gym_backend -# from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper from utils import ( @@ -81,7 +81,7 @@ def main(cfg: "DictConfig"): # noqa: F821 scheduler.step() # evaluation - with torch.no_grad(): + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): inference_policy.eval() if i % pretrain_log_interval == 0: eval_td = test_env.rollout( diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index dbe0f6380c0..c181b32ca5d 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
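The lamb.py header above notes that the Lamb optimizer is copied from facebookresearch/online-dt. As a rough, self-contained sketch of how such an optimizer is typically driven together with the linear warmup implied by cfg.optim.warmup_steps, with AdamW standing in for Lamb and the LambdaLR schedule being an assumption (make_odt_optimizer is not part of this hunk):

import torch
from torch import nn

model = nn.Linear(8, 8)  # placeholder for the Decision Transformer policy
warmup_steps = 10_000    # mirrors warmup_steps in odt_config.yaml
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0e-4, weight_decay=5.0e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: min((step + 1) / warmup_steps, 1.0)
)

for _ in range(3):  # one gradient step per training iteration
    loss = model(torch.randn(256, 8)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()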
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import torch.nn import torch.optim diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py index 1a6eaa63235..8eb72f1f9ea 100644 --- a/torchrl/modules/models/decision_transformer.py +++ b/torchrl/modules/models/decision_transformer.py @@ -8,127 +8,12 @@ import importlib from dataclasses import dataclass -from typing import Any, Optional +from typing import Any import torch import torch.nn as nn _has_transformers = importlib.util.find_spec("transformers") is not None -import transformers -from transformers.models.gpt2.modeling_gpt2 import ( - BaseModelOutputWithPastAndCrossAttentions, - GPT2Model, -) - - -class ModifiedGPT2Model(GPT2Model): - """Wrapper around the GPT2Model from transformers. - - This class is a modified version of the GPT2Model from transformers - as for the Decision Transformer we dont need the wpe layer. - - """ - - def __init__(self, config): - super(ModifiedGPT2Model, self).__init__(config) - - # Remove the wpe layer - del self.wpe - - def forward( - self, - inputs_embeds: Optional[torch.FloatTensor] = None, - ): - input_shape = inputs_embeds.size()[:-1] - - output_attentions = self.config.output_attentions - output_hidden_states = self.config.output_hidden_states - use_cache = self.config.use_cache - return_dict = self.config.use_return_dict - - head_mask = self.get_head_mask(None, self.config.n_layer) - - hidden_states = inputs_embeds - - hidden_states = self.drop(hidden_states) - - output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = ( - () if output_attentions and self.config.add_cross_attention else None - ) - all_hidden_states = () if output_hidden_states else None - past_key_values = tuple([None] * len(self.h)) - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure layer_past is on same device as hidden_states (might not be correct) - if layer_past is not None: - layer_past = tuple( - past_state.to(hidden_states.device) for past_state in layer_past - ) - if isinstance(head_mask, torch.Tensor): - head_mask = head_mask.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - outputs = block( - hidden_states, - layer_past=layer_past, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + ( - outputs[2 if use_cache else 1], - ) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + ( - outputs[3 if use_cache else 2], - ) - - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) - - hidden_states = self.ln_f(hidden_states) - - hidden_states = hidden_states.view(output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in 
[ - hidden_states, - presents, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) class DecisionTransformer(nn.Module): @@ -138,14 +23,14 @@ class DecisionTransformer(nn.Module): The transformer utilizes a default config to create the GPT2 model if the user does not provide a specific config. default_config = { - "n_embd": 256, - "n_layer": 4, - "n_head": 4, - "n_inner": 1024, - "activation": "relu", - "n_positions": 1024, - "resid_pdrop": 0.1, - "attn_pdrop": 0.1, + ... "n_embd": 256, + ... "n_layer": 4, + ... "n_head": 4, + ... "n_inner": 1024, + ... "activation": "relu", + ... "n_positions": 1024, + ... "resid_pdrop": 0.1, + ... "attn_pdrop": 0.1, } Args: @@ -210,6 +95,8 @@ def __init__( raise ImportError( "transformers is not installed. Please install it with `pip install transformers`." ) + import transformers + from transformers.models.gpt2.modeling_gpt2 import GPT2Model if config is None: config = self.default_config() @@ -240,7 +127,7 @@ def __init__( self.action_dim = action_dim self.hidden_size = config["n_embd"] - self.transformer = ModifiedGPT2Model(config=gpt_config) + self.transformer = GPT2Model(config=gpt_config) self.embed_return = torch.nn.Linear(1, self.hidden_size) self.embed_state = torch.nn.Linear(self.state_dim, self.hidden_size) diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 7e422d856b4..ba6344f3f03 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -1289,14 +1289,14 @@ def __init__( transformer_config["n_embd"], action_dim, device=device ) - # def weight_init(m): - # """Custom weight init for Conv2D and Linear layers.""" - # if isinstance(m, torch.nn.Linear): - # nn.init.orthogonal_(m.weight.data) - # if hasattr(m.bias, "data"): - # m.bias.data.fill_(0.0) - - # self.apply(weight_init) + def weight_init(m): + """Custom weight init for Conv2D and Linear layers.""" + if isinstance(m, torch.nn.Linear): + nn.init.orthogonal_(m.weight.data) + if hasattr(m.bias, "data"): + m.bias.data.fill_(0.0) + + self.action_layer.apply(weight_init) def forward( self,
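A standalone sketch of the orthogonal initialization that the models.py hunk above re-enables for the action head; the layer width and action dimension below are placeholders:

import torch
from torch import nn

def weight_init(m):
    # orthogonal weights and zero bias for every Linear layer it is applied to
    if isinstance(m, torch.nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        if hasattr(m.bias, "data"):
            m.bias.data.fill_(0.0)

action_layer = nn.Linear(512, 6)  # placeholder sizes for the DTActor output layer
action_layer.apply(weight_init)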