diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py
index 2a726c64c36..f66a560f2de 100644
--- a/test/test_tensordictmodules.py
+++ b/test/test_tensordictmodules.py
@@ -1756,6 +1756,7 @@ def test_multi_consecutive(self, shape, python_based):
         lstm_module_ss = LSTMModule(
             input_size=3,
             hidden_size=12,
+            num_layers=4,
             batch_first=True,
             in_keys=["observation", "hidden0", "hidden1"],
             out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py
index b705e33474e..e76ee043c4e 100644
--- a/torchrl/modules/tensordict_module/rnn.py
+++ b/torchrl/modules/tensordict_module/rnn.py
@@ -12,7 +12,7 @@
 from tensordict.nn import TensorDictModuleBase as ModuleBase
 from tensordict.tensordict import NO_DEFAULT
-from tensordict.utils import prod
+from tensordict.utils import expand_as_right, prod
 from torch import nn, Tensor
 from torch.nn.modules.rnn import RNNCellBase
@@ -235,52 +235,57 @@ def _lstm_cell(x, hx, cx, weight_ih, bias_ih, weight_hh, bias_hh):
 
     def _lstm(self, x, hx):
-        if self.batch_first is False:
-            x = x.permute(
-                1, 0, 2
-            )  # Change (seq_len, batch, features) to (batch, seq_len, features)
-
-        # should check self.batch_first
-        bs, seq_len, input_size = x.size()
-        h_t, c_t = [list(h.unbind(0)) for h in hx]
+        h_t, c_t = hx
+        h_t, c_t = h_t.unbind(0), c_t.unbind(0)
 
         outputs = []
-        for t in range(seq_len):
-
-            x_t = x[:, t, :]
-
-            for layer in range(self.num_layers):
-                # Retrieve weights
-                weights = self._all_weights[layer]
-                weight_ih = getattr(self, weights[0])
-                weight_hh = getattr(self, weights[1])
-                if self.bias is True:
-                    bias_ih = getattr(self, weights[2])
-                    bias_hh = getattr(self, weights[3])
-                else:
-                    bias_ih = None
-                    bias_hh = None
+        weight_ihs = []
+        weight_hhs = []
+        bias_ihs = []
+        bias_hhs = []
+        for weights in self._all_weights:
+            # Retrieve weights
+            weight_ihs.append(getattr(self, weights[0]))
+            weight_hhs.append(getattr(self, weights[1]))
+            if self.bias:
+                bias_ihs.append(getattr(self, weights[2]))
+                bias_hhs.append(getattr(self, weights[3]))
+            else:
+                bias_ihs.append(None)
+                bias_hhs.append(None)
+
+        for x_t in x.unbind(int(self.batch_first)):
+            h_t_out = []
+            c_t_out = []
+
+            for layer, (
+                weight_ih,
+                bias_ih,
+                weight_hh,
+                bias_hh,
+                _h_t,
+                _c_t,
+            ) in enumerate(zip(weight_ihs, bias_ihs, weight_hhs, bias_hhs, h_t, c_t)):
 
                 # Run cell
-                h_t[layer], c_t[layer] = self._lstm_cell(
-                    x_t, h_t[layer], c_t[layer], weight_ih, bias_ih, weight_hh, bias_hh
+                _h_t, _c_t = self._lstm_cell(
+                    x_t, _h_t, _c_t, weight_ih, bias_ih, weight_hh, bias_hh
                 )
+                h_t_out.append(_h_t)
+                c_t_out.append(_c_t)
 
                 # Apply dropout if in training mode
                 if layer < self.num_layers - 1 and self.dropout:
-                    x_t = F.dropout(h_t[layer], p=self.dropout, training=self.training)
+                    x_t = F.dropout(_h_t, p=self.dropout, training=self.training)
                 else:
                     # No dropout after the last layer
-                    x_t = h_t[layer]
-
+                    x_t = _h_t
+            h_t = h_t_out
+            c_t = c_t_out
             outputs.append(x_t)
 
-        outputs = torch.stack(outputs, dim=1)
-        if self.batch_first is False:
-            outputs = outputs.permute(
-                1, 0, 2
-            )  # Change back (batch, seq_len, features) to (seq_len, batch, features)
+        outputs = torch.stack(outputs, dim=int(self.batch_first))
 
-        return outputs, (torch.stack(h_t, 0), torch.stack(c_t, 0))
+        return outputs, (torch.stack(h_t_out, 0), torch.stack(c_t_out, 0))
 
     def forward(self, input, hx=None):  # noqa: F811
         real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
@@ -305,12 +310,7 @@ def forward(self, input, hx=None):  # noqa: F811
                 device=input.device,
             )
             hx = (h_zeros, c_zeros)
-        else:
-            self.check_forward_args(input, hx, batch_sizes=None)
-        result = self._lstm(input, hx)
-        output = result[0]
-        hidden = result[1]
-        return output, hidden
+        return self._lstm(input, hx)
 
 
 class LSTMModule(ModuleBase):
@@ -457,7 +457,7 @@ def __init__(
                 raise ValueError("The input lstm must have batch_first=True.")
             if bidirectional:
                 raise ValueError("The input lstm cannot be bidirectional.")
-            if python_based is True:
+            if python_based:
                 lstm = LSTM(
                     input_size=input_size,
                     hidden_size=hidden_size,
@@ -647,8 +647,9 @@ def forward(self, tensordict: TensorDictBase):
         # if splits is not None:
         #     value = torch.nn.utils.rnn.pack_padded_sequence(value, splits, batch_first=True)
         if is_init.any() and hidden0 is not None:
-            hidden0[is_init] = 0
-            hidden1[is_init] = 0
+            is_init_expand = expand_as_right(is_init, hidden0)
+            hidden0 = torch.where(is_init_expand, 0, hidden0)
+            hidden1 = torch.where(is_init_expand, 0, hidden1)
         val, hidden0, hidden1 = self._lstm(
             value, batch, steps, device, dtype, hidden0, hidden1
         )
@@ -879,7 +880,7 @@ def __init__(
         dtype=None,
     ) -> None:
 
-        if bidirectional is True:
+        if bidirectional:
             raise NotImplementedError(
                 "Bidirectional LSTMs are not supported yet in this implementation."
             )
@@ -924,31 +925,34 @@ def _gru(self, x, hx):
 
         bs, seq_len, input_size = x.size()
         h_t = list(hx.unbind(0))
 
-        outputs = []
+        weight_ih = []
+        weight_hh = []
+        bias_ih = []
+        bias_hh = []
+        for layer in range(self.num_layers):
+
+            # Retrieve weights
+            weights = self._all_weights[layer]
+            weight_ih.append(getattr(self, weights[0]))
+            weight_hh.append(getattr(self, weights[1]))
+            if self.bias:
+                bias_ih.append(getattr(self, weights[2]))
+                bias_hh.append(getattr(self, weights[3]))
+            else:
+                bias_ih.append(None)
+                bias_hh.append(None)
 
-        for t in range(seq_len):
-            x_t = x[:, t, :]
+        outputs = []
+        for x_t in x.unbind(1):
             for layer in range(self.num_layers):
-
-                # Retrieve weights
-                weights = self._all_weights[layer]
-                weight_ih = getattr(self, weights[0])
-                weight_hh = getattr(self, weights[1])
-                if self.bias is True:
-                    bias_ih = getattr(self, weights[2])
-                    bias_hh = getattr(self, weights[3])
-                else:
-                    bias_ih = None
-                    bias_hh = None
-
                 h_t[layer] = self._gru_cell(
                     x_t,
                     h_t[layer],
-                    weight_ih,
-                    bias_ih,
-                    weight_hh,
-                    bias_hh,
+                    weight_ih[layer],
+                    bias_ih[layer],
+                    weight_hh[layer],
+                    bias_hh[layer],
                 )
 
                 # Apply dropout if in training mode and not the last layer
@@ -960,7 +964,7 @@ def _gru(self, x, hx):
             outputs.append(x_t)
 
         outputs = torch.stack(outputs, dim=1)
-        if self.batch_first is False:
+        if not self.batch_first:
             outputs = outputs.permute(
                 1, 0, 2
             )  # Change back (batch, seq_len, features) to (seq_len, batch, features)
@@ -1160,7 +1164,7 @@ def __init__(
             if bidirectional:
                 raise ValueError("The input gru cannot be bidirectional.")
 
-            if python_based is True:
+            if python_based:
                 gru = GRU(
                     input_size=input_size,
                     hidden_size=hidden_size,
@@ -1314,6 +1318,8 @@ def forward(self, tensordict: TensorDictBase):
             )
         else:
             tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1)
+        # TODO: replace by contiguous, or ultimately deprecate the default lazy unsqueeze
+        tensordict_shaped = tensordict_shaped.to_tensordict()
 
         is_init = tensordict_shaped.get("is_init").squeeze(-1)
         splits = None
@@ -1342,7 +1348,8 @@ def forward(self, tensordict: TensorDictBase):
         # if splits is not None:
        #     value = torch.nn.utils.rnn.pack_padded_sequence(value, splits, batch_first=True)
         if is_init.any() and hidden is not None:
-            hidden[is_init] = 0
+            is_init_expand = expand_as_right(is_init, hidden)
+            hidden = torch.where(is_init_expand, 0, hidden)
         val, hidden = self._gru(value, batch, steps, device, dtype, hidden)
         tensordict_shaped.set(self.out_keys[0], val)
         tensordict_shaped.set(self.out_keys[1], hidden)
diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index d0617dedc74..6d97b9fc7be 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -1200,11 +1200,12 @@ def _actor_loss(
     ) -> Tuple[Tensor, Dict[str, Tensor]]:
         # get probs and log probs for actions
         with self.actor_network_params.to_module(self.actor_network):
-            dist = self.actor_network.get_dist(tensordict)
+            dist = self.actor_network.get_dist(tensordict.clone(False))
             prob = dist.probs
-            log_prob = torch.log(torch.where(prob == 0, 1e-8, prob))
+            log_prob = prob.clamp_min(torch.finfo(prob.dtype).resolution).log()
 
         td_q = tensordict.select(*self.qvalue_network.in_keys)
+
         td_q = self._vmap_qnetworkN0(
             td_q, self._cached_detached_qvalue_params  # should we clone?
         )
@@ -1234,7 +1235,9 @@ def _alpha_loss(self, log_prob: Tensor) -> Tensor:
     @property
     def _alpha(self):
         if self.min_log_alpha is not None:
-            self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
+            self.log_alpha.data = self.log_alpha.data.clamp(
+                self.min_log_alpha, self.max_log_alpha
+            )
         with torch.no_grad():
            alpha = self.log_alpha.exp()
        return alpha
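Note (not part of the patch): a minimal sketch of the masked-reset pattern the diff introduces, showing how expand_as_right broadcasts a per-batch is_init flag before torch.where zeroes the matching hidden entries out of place; the tensor names and shapes below are hypothetical.

import torch
from tensordict.utils import expand_as_right

# Hypothetical shapes: one reset flag per batch element, hidden states per layer.
batch, num_layers, hidden_size = 4, 2, 8
hidden = torch.randn(batch, num_layers, hidden_size)
is_init = torch.tensor([True, False, True, False])

# Broadcast the (batch,) mask to the full (batch, num_layers, hidden_size) shape.
mask = expand_as_right(is_init, hidden)
# Out-of-place reset: unlike `hidden[is_init] = 0`, this does not mutate the input tensor.
hidden = torch.where(mask, 0, hidden)

assert (hidden[is_init] == 0).all()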