I tried running the following script and found that S5 is far slower than PyTorch's LSTM. Is this supposed to be the case? Perhaps the scale at which I'm testing it is too small to realize the benefit?
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"  # must be set before torch loads its OpenMP runtime

from datetime import datetime

import torch
from s5 import S5

L = 1200      # sequence length
B = 256       # batch size
x_dim = 128   # input feature dimension

m = S5(x_dim, 512).cuda()
# batch_first=True so the LSTM reads x as (batch, seq, features), the same layout S5 receives.
lstm = torch.nn.LSTM(x_dim, 512, batch_first=True).cuda()
x = torch.randn(B, L, x_dim).cuda()

t0 = datetime.now()
for i in range(10):
    y, _ = lstm(x)
    torch.sum(y).backward()
t1 = datetime.now()
print(t1 - t0)

t2 = datetime.now()
for i in range(10):
    y = m(x)
    torch.sum(y).backward()
t3 = datetime.now()
print(t3 - t2)
I would greatly appreciate any comment on this. Thanks in advance, and thanks for the implementation!
Yes, it seems like in its current state it's ~2.3-6x (forward) or ~3.4-9x (forward+backward) slower on CUDA (S5 gets faster with increasing sequence length; tested 1200-144000). On CPU, S5 is 1.07-1.44x faster (forward+backward) or 1.03-1.4x slower (forward).
This is likely because the associative scan function is not yet optimized for GPU, which means a lot of communication between CPU and GPU, whereas the LSTM has an optimized CUDA kernel. Checking nvtop, this seems to be accurate: LSTM=(RX: ~4GiB/s, TX: ~400MiB/s), S5=(RX: ~60MiB/s, TX: ~10MiB/s); lower bandwidth correlates with more individual requests/blocks.
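For context, the scan itself is conceptually tiny: S5's recurrence x_k = A * x_{k-1} + B u_k (diagonal A) only needs an associative combine function, roughly the sketch below (the name s5_scan_op is mine for illustration, not the actual code in this repo), so essentially all of the cost is in how the scan tree gets scheduled on the GPU rather than in the arithmetic:

def s5_scan_op(left, right):
    # Combine two segments (A, Bu) of the recurrence x_k = A * x_{k-1} + Bu_k with diagonal A.
    # Because this operator is associative, the recurrence can be evaluated as a
    # parallel (tree) scan of depth O(log L) instead of a sequential loop of length L.
    A_l, bu_l = left
    A_r, bu_r = right
    return A_r * A_l, A_r * bu_l + bu_r

With torch tensors these are just elementwise multiply-adds, so Python/dispatch overhead per combine step easily dominates on GPU.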
Additionally, the 'depth' of the graph for S5 should in theory be shallower, which could be why the backward pass is faster on CPU.
There are some newer associative-scan implementations since I wrote my own (see the PyTorch thread: pytorch/pytorch#95408), but reports on speed are mixed; profiling shows that quite a bit of time is spent on stack/interleave functions that don't do any computation. From the paper, an optimized version of S5 could potentially be ~10-60x faster than a GRU (which should be comparable to an LSTM), though the reported baseline figures could be for a naive implementation rather than an optimized kernel.
Note that wall-clock benchmarking on CUDA is not reliable because kernel launches are asynchronous, so I adapted your example to account for that:
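Roughly like this (a sketch, not necessarily the exact script I ran; the key changes are a warm-up pass and torch.cuda.synchronize() before each timestamp, and the bench helper name is just for illustration):

from datetime import datetime
import torch
from s5 import S5

L, B, x_dim, H = 1200, 256, 128, 512

m = S5(x_dim, H).cuda()
lstm = torch.nn.LSTM(x_dim, H, batch_first=True).cuda()
x = torch.randn(B, L, x_dim).cuda()

def bench(step, iters=10):
    step()                    # warm-up (kernel caching, cuDNN autotuning)
    torch.cuda.synchronize()
    t0 = datetime.now()
    for _ in range(iters):
        step()
    torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock
    return datetime.now() - t0

print("LSTM:", bench(lambda: torch.sum(lstm(x)[0]).backward()))
print("S5:  ", bench(lambda: torch.sum(m(x)).backward()))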