[Performance] Faster RNNs (#1732)
vmoens authored Dec 15, 2023
1 parent bc4a72f commit b3d2aa6
Showing 3 changed files with 81 additions and 70 deletions.
1 change: 1 addition & 0 deletions test/test_tensordictmodules.py
@@ -1756,6 +1756,7 @@ def test_multi_consecutive(self, shape, python_based):
         lstm_module_ss = LSTMModule(
             input_size=3,
             hidden_size=12,
+            num_layers=4,
             batch_first=True,
             in_keys=["observation", "hidden0", "hidden1"],
             out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
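For context, here is a minimal usage sketch (not part of the commit) of the multi-layer `LSTMModule` the test configures. The dummy `TensorDict` layout and the `is_init` entry (normally produced by `InitTracker`) are assumptions, not taken from the diff:

```python
import torch
from tensordict import TensorDict
from torchrl.modules import LSTMModule

lstm_module = LSTMModule(
    input_size=3,
    hidden_size=12,
    num_layers=4,  # the updated test exercises a multi-layer stack
    batch_first=True,
    in_keys=["observation", "hidden0", "hidden1"],
    out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
)

td = TensorDict(
    {
        "observation": torch.randn(2, 5, 3),
        "is_init": torch.zeros(2, 5, 1, dtype=torch.bool),
    },
    batch_size=[2, 5],
)
out = lstm_module(td)  # hidden states default to zeros when absent
print(out["intermediate"].shape)  # should be torch.Size([2, 5, 12])
```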
141 changes: 74 additions & 67 deletions torchrl/modules/tensordict_module/rnn.py
@@ -12,7 +12,7 @@
 from tensordict.nn import TensorDictModuleBase as ModuleBase
 
 from tensordict.tensordict import NO_DEFAULT
-from tensordict.utils import prod
+from tensordict.utils import expand_as_right, prod
 
 from torch import nn, Tensor
 from torch.nn.modules.rnn import RNNCellBase
@@ -235,52 +235,57 @@ def _lstm_cell(x, hx, cx, weight_ih, bias_ih, weight_hh, bias_hh):
 
     def _lstm(self, x, hx):
 
-        if self.batch_first is False:
-            x = x.permute(
-                1, 0, 2
-            )  # Change (seq_len, batch, features) to (batch, seq_len, features)
-
         # should check self.batch_first
         bs, seq_len, input_size = x.size()
-        h_t, c_t = [list(h.unbind(0)) for h in hx]
+        h_t, c_t = hx
+        h_t, c_t = h_t.unbind(0), c_t.unbind(0)
 
         outputs = []
-        for t in range(seq_len):
-
-            x_t = x[:, t, :]
-
-            for layer in range(self.num_layers):
-                # Retrieve weights
-                weights = self._all_weights[layer]
-                weight_ih = getattr(self, weights[0])
-                weight_hh = getattr(self, weights[1])
-                if self.bias is True:
-                    bias_ih = getattr(self, weights[2])
-                    bias_hh = getattr(self, weights[3])
-                else:
-                    bias_ih = None
-                    bias_hh = None
-
+        weight_ihs = []
+        weight_hhs = []
+        bias_ihs = []
+        bias_hhs = []
+        for weights in self._all_weights:
+            # Retrieve weights
+            weight_ihs.append(getattr(self, weights[0]))
+            weight_hhs.append(getattr(self, weights[1]))
+            if self.bias:
+                bias_ihs.append(getattr(self, weights[2]))
+                bias_hhs.append(getattr(self, weights[3]))
+            else:
+                bias_ihs.append(None)
+                bias_hhs.append(None)
+
+        for x_t in x.unbind(int(self.batch_first)):
+            h_t_out = []
+            c_t_out = []
+
+            for layer, (
+                weight_ih,
+                bias_ih,
+                weight_hh,
+                bias_hh,
+                _h_t,
+                _c_t,
+            ) in enumerate(zip(weight_ihs, bias_ihs, weight_hhs, bias_hhs, h_t, c_t)):
                 # Run cell
-                h_t[layer], c_t[layer] = self._lstm_cell(
-                    x_t, h_t[layer], c_t[layer], weight_ih, bias_ih, weight_hh, bias_hh
+                _h_t, _c_t = self._lstm_cell(
+                    x_t, _h_t, _c_t, weight_ih, bias_ih, weight_hh, bias_hh
                 )
+                h_t_out.append(_h_t)
+                c_t_out.append(_c_t)
 
                 # Apply dropout if in training mode
                 if layer < self.num_layers - 1 and self.dropout:
-                    x_t = F.dropout(h_t[layer], p=self.dropout, training=self.training)
+                    x_t = F.dropout(_h_t, p=self.dropout, training=self.training)
                 else:  # No dropout after the last layer
-                    x_t = h_t[layer]
-
+                    x_t = _h_t
+            h_t = h_t_out
+            c_t = c_t_out
             outputs.append(x_t)
-
-        outputs = torch.stack(outputs, dim=1)
-        if self.batch_first is False:
-            outputs = outputs.permute(
-                1, 0, 2
-            )  # Change back (batch, seq_len, features) to (seq_len, batch, features)
+        outputs = torch.stack(outputs, dim=int(self.batch_first))
 
-        return outputs, (torch.stack(h_t, 0), torch.stack(c_t, 0))
+        return outputs, (torch.stack(h_t_out, 0), torch.stack(c_t_out, 0))
 
     def forward(self, input, hx=None):  # noqa: F811
         real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
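The rewritten `_lstm` above fetches the per-layer weights once before the time loop and iterates over `x.unbind(...)` instead of indexing `x[:, t, :]` at every step. A standalone sketch of the same loop shape, using `nn.RNNCell` stand-ins instead of raw weight tensors (names and sizes are hypothetical, this is not the module's code):

```python
import torch

def stacked_rnn_sketch(x, cells, hidden):
    """Toy stacked recurrence: cells are fixed up front, time steps come from unbind."""
    # x: (batch, time, features); hidden: list with one state tensor per layer.
    outputs = []
    for x_t in x.unbind(1):          # one (batch, features) slice per time step
        new_hidden = []
        for layer, (cell, h) in enumerate(zip(cells, hidden)):
            h = cell(x_t, h)         # run this layer's cell on the current step
            new_hidden.append(h)
            x_t = h                  # feed this layer's output to the next layer
        hidden = new_hidden
        outputs.append(x_t)
    return torch.stack(outputs, dim=1), torch.stack(hidden, dim=0)

cells = [torch.nn.RNNCell(3, 8), torch.nn.RNNCell(8, 8)]
x = torch.randn(4, 5, 3)
h0 = [torch.zeros(4, 8) for _ in cells]
out, h_n = stacked_rnn_sketch(x, cells, h0)
print(out.shape, h_n.shape)  # torch.Size([4, 5, 8]) torch.Size([2, 4, 8])
```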
@@ -305,12 +310,7 @@ def forward(self, input, hx=None):  # noqa: F811
                 device=input.device,
             )
             hx = (h_zeros, c_zeros)
-        else:
-            self.check_forward_args(input, hx, batch_sizes=None)
-        result = self._lstm(input, hx)
-        output = result[0]
-        hidden = result[1]
-        return output, hidden
+        return self._lstm(input, hx)
 
 
 class LSTMModule(ModuleBase):
@@ -457,7 +457,7 @@ def __init__(
                 raise ValueError("The input lstm must have batch_first=True.")
             if bidirectional:
                 raise ValueError("The input lstm cannot be bidirectional.")
-            if python_based is True:
+            if python_based:
                 lstm = LSTM(
                     input_size=input_size,
                     hidden_size=hidden_size,
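A minimal construction sketch for the flag touched above: with `python_based=True`, `LSTMModule` builds the pure-PyTorch `LSTM` defined in this file instead of `nn.LSTM`. Keyword arguments below mirror the module's documented signature and are assumptions otherwise:

```python
from torchrl.modules import LSTMModule

python_lstm = LSTMModule(
    input_size=3,
    hidden_size=12,
    num_layers=2,
    batch_first=True,
    python_based=True,  # use the Python-level LSTM implemented in rnn.py
    in_keys=["observation", "hidden0", "hidden1"],
    out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
)
```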
@@ -647,8 +647,9 @@ def forward(self, tensordict: TensorDictBase):
         # if splits is not None:
         #     value = torch.nn.utils.rnn.pack_padded_sequence(value, splits, batch_first=True)
         if is_init.any() and hidden0 is not None:
-            hidden0[is_init] = 0
-            hidden1[is_init] = 0
+            is_init_expand = expand_as_right(is_init, hidden0)
+            hidden0 = torch.where(is_init_expand, 0, hidden0)
+            hidden1 = torch.where(is_init_expand, 0, hidden1)
         val, hidden0, hidden1 = self._lstm(
             value, batch, steps, device, dtype, hidden0, hidden1
         )
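The masking change above replaces in-place boolean indexing with an out-of-place `torch.where`. A standalone illustration of `expand_as_right` driving such a reset (the shapes are made up, not taken from the module):

```python
import torch
from tensordict.utils import expand_as_right

# hidden: (batch, steps, num_layers, hidden_size); is_init flags episode starts per step.
hidden = torch.randn(2, 5, 4, 12)
is_init = torch.zeros(2, 5, dtype=torch.bool)
is_init[:, 0] = True  # first step of every trajectory starts a new episode

# Broadcast the mask on the right so it lines up with the trailing dims of `hidden`,
# then zero the hidden state wherever a new episode begins, without mutating in place.
mask = expand_as_right(is_init, hidden)
hidden = torch.where(mask, 0, hidden)
print(hidden[:, 0].abs().sum())  # tensor(0.)
```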
@@ -879,7 +880,7 @@ def __init__(
         dtype=None,
     ) -> None:
 
-        if bidirectional is True:
+        if bidirectional:
             raise NotImplementedError(
                 "Bidirectional LSTMs are not supported yet in this implementation."
             )
@@ -924,31 +925,34 @@ def _gru(self, x, hx):
         bs, seq_len, input_size = x.size()
         h_t = list(hx.unbind(0))
 
-        outputs = []
+        weight_ih = []
+        weight_hh = []
+        bias_ih = []
+        bias_hh = []
+        for layer in range(self.num_layers):
+
+            # Retrieve weights
+            weights = self._all_weights[layer]
+            weight_ih.append(getattr(self, weights[0]))
+            weight_hh.append(getattr(self, weights[1]))
+            if self.bias:
+                bias_ih.append(getattr(self, weights[2]))
+                bias_hh.append(getattr(self, weights[3]))
+            else:
+                bias_ih.append(None)
+                bias_hh.append(None)
 
-        for t in range(seq_len):
-            x_t = x[:, t, :]
+        outputs = []
 
+        for x_t in x.unbind(1):
             for layer in range(self.num_layers):
-
-                # Retrieve weights
-                weights = self._all_weights[layer]
-                weight_ih = getattr(self, weights[0])
-                weight_hh = getattr(self, weights[1])
-                if self.bias is True:
-                    bias_ih = getattr(self, weights[2])
-                    bias_hh = getattr(self, weights[3])
-                else:
-                    bias_ih = None
-                    bias_hh = None
-
                 h_t[layer] = self._gru_cell(
                     x_t,
                     h_t[layer],
-                    weight_ih,
-                    bias_ih,
-                    weight_hh,
-                    bias_hh,
+                    weight_ih[layer],
+                    bias_ih[layer],
+                    weight_hh[layer],
+                    bias_hh[layer],
                 )
 
                 # Apply dropout if in training mode and not the last layer
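As a sanity check of the pattern used above (not a test from the repository), the Python-level GRU should match `nn.GRU` when both carry the same parameters. A rough sketch, assuming the python-based `GRU` class is importable from `torchrl.modules` and keeps `nn.GRU`-compatible parameter names:

```python
import torch
from torch import nn
from torchrl.modules import GRU  # python-based GRU defined in rnn.py (assumed export)

ref = nn.GRU(input_size=3, hidden_size=8, num_layers=2, batch_first=True)
py = GRU(input_size=3, hidden_size=8, num_layers=2, batch_first=True)
py.load_state_dict(ref.state_dict())  # same weights in both implementations

x = torch.randn(4, 5, 3)
h0 = torch.zeros(2, 4, 8)
out_ref, h_ref = ref(x, h0)
out_py, h_py = py(x, h0)
print(torch.allclose(out_ref, out_py, atol=1e-5), torch.allclose(h_ref, h_py, atol=1e-5))
```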
@@ -960,7 +964,7 @@ def _gru(self, x, hx):
             outputs.append(x_t)
 
         outputs = torch.stack(outputs, dim=1)
-        if self.batch_first is False:
+        if not self.batch_first:
             outputs = outputs.permute(
                 1, 0, 2
             )  # Change back (batch, seq_len, features) to (seq_len, batch, features)
@@ -1160,7 +1164,7 @@ def __init__(
             if bidirectional:
                 raise ValueError("The input gru cannot be bidirectional.")
 
-            if python_based is True:
+            if python_based:
                 gru = GRU(
                     input_size=input_size,
                     hidden_size=hidden_size,
@@ -1314,6 +1318,8 @@ def forward(self, tensordict: TensorDictBase):
             )
         else:
             tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1)
+            # TODO: replace by contiguous, or ultimately deprecate the default lazy unsqueeze
+            tensordict_shaped = tensordict_shaped.to_tensordict()
 
         is_init = tensordict_shaped.get("is_init").squeeze(-1)
         splits = None
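For context on the added `to_tensordict()` call: `reshape(-1).unsqueeze(-1)` yields a lazy view, and `to_tensordict()` copies it into a regular, contiguous `TensorDict`. A small illustration (not repository code):

```python
import torch
from tensordict import TensorDict

td = TensorDict({"observation": torch.randn(2, 5, 3)}, batch_size=[2, 5])
lazy = td.reshape(-1).unsqueeze(-1)   # lazy unsqueeze wrapper around the reshaped data
flat = lazy.to_tensordict()           # plain TensorDict with its own contiguous storage
print(type(lazy).__name__, type(flat).__name__, flat.batch_size)
```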
@@ -1342,7 +1348,8 @@ def forward(self, tensordict: TensorDictBase):
         # if splits is not None:
         #     value = torch.nn.utils.rnn.pack_padded_sequence(value, splits, batch_first=True)
         if is_init.any() and hidden is not None:
-            hidden[is_init] = 0
+            is_init_expand = expand_as_right(is_init, hidden)
+            hidden = torch.where(is_init_expand, 0, hidden)
         val, hidden = self._gru(value, batch, steps, device, dtype, hidden)
         tensordict_shaped.set(self.out_keys[0], val)
         tensordict_shaped.set(self.out_keys[1], hidden)
9 changes: 6 additions & 3 deletions torchrl/objectives/sac.py
@@ -1200,11 +1200,12 @@ def _actor_loss(
     ) -> Tuple[Tensor, Dict[str, Tensor]]:
         # get probs and log probs for actions
         with self.actor_network_params.to_module(self.actor_network):
-            dist = self.actor_network.get_dist(tensordict)
+            dist = self.actor_network.get_dist(tensordict.clone(False))
             prob = dist.probs
-            log_prob = torch.log(torch.where(prob == 0, 1e-8, prob))
+            log_prob = prob.clamp_min(torch.finfo(prob.dtype).resolution)
 
         td_q = tensordict.select(*self.qvalue_network.in_keys)
+
         td_q = self._vmap_qnetworkN0(
             td_q, self._cached_detached_qvalue_params  # should we clone?
         )
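The actor-loss change above guards against zero probabilities by clamping at the dtype's resolution rather than patching zeros with `torch.where`. The two guards compared in isolation (illustrative only; an explicit `.log()` is added here to show the effect on the log-probability):

```python
import torch

prob = torch.tensor([0.0, 1e-12, 0.3, 0.7])

# Where-based guard: replace exact zeros before taking the log.
old_log = torch.log(torch.where(prob == 0, 1e-8, prob))

# Clamp-based guard: floor every probability at the dtype's resolution.
eps = torch.finfo(prob.dtype).resolution
new_log = prob.clamp_min(eps).log()

print(old_log)  # finite everywhere, but 1e-12 still gives roughly -27.6
print(new_log)  # bounded below by log(eps), roughly -13.8 for float32
```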
@@ -1234,7 +1235,9 @@ def _alpha_loss(self, log_prob: Tensor) -> Tensor:
     @property
     def _alpha(self):
         if self.min_log_alpha is not None:
-            self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
+            self.log_alpha.data = self.log_alpha.data.clamp(
+                self.min_log_alpha, self.max_log_alpha
+            )
         with torch.no_grad():
             alpha = self.log_alpha.exp()
         return alpha
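The `_alpha` change above rebinds `log_alpha.data` to a clamped copy instead of clamping the storage in place. Both patterns side by side (illustrative only):

```python
import torch
from torch import nn

log_alpha = nn.Parameter(torch.tensor(2.0))

# In-place: mutates the parameter's existing storage.
log_alpha.data.clamp_(-1.0, 1.0)

# Out-of-place, as in the change above: build a clamped copy and rebind .data to it.
log_alpha.data = log_alpha.data.clamp(-1.0, 1.0)

print(log_alpha.item())  # 1.0 either way
```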
