How to compute MACs or FLOPs of mamba #110
Comments
We calculate FLOPs based on the reference code, though the count is very different from the real speed in practice.
def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
    """
    u: r(B D L)
    delta: r(B D L)
    A: r(D N)
    B: r(B N L)
    C: r(B N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32
    ignores:
        [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
    """
    import numpy as np

    # fvcore.nn.jit_handles
    def get_flops_einsum(input_shapes, equation):
        np_arrs = [np.zeros(s) for s in input_shapes]
        optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
        for line in optim.split("\n"):
            if "optimized flop" in line.lower():
                # divided by 2 because we count MAC (multiply-add counted as one flop)
                flop = float(np.floor(float(line.split(":")[-1]) / 2))
                return flop

    assert not with_complex

    flops = 0  # below code flops = 0
    if False:
        ...
        """
        dtype_in = u.dtype
        u = u.float()
        delta = delta.float()
        if delta_bias is not None:
            delta = delta + delta_bias[..., None].float()
        if delta_softplus:
            delta = F.softplus(delta)
        batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
        is_variable_B = B.dim() >= 3
        is_variable_C = C.dim() >= 3
        if A.is_complex():
            if is_variable_B:
                B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
            if is_variable_C:
                C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
        else:
            B = B.float()
            C = C.float()
        x = A.new_zeros((batch, dim, dstate))
        ys = []
        """

    flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
    if with_Group:
        flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
    else:
        flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")
    if False:
        ...
        """
        deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
        if not is_variable_B:
            deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
        else:
            if B.dim() == 3:
                deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
            else:
                B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
                deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
        if is_variable_C and C.dim() == 4:
            C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
        last_state = None
        """

    # cost of one step of the sequential loop; multiplied by L below
    in_for_flops = B * D * N
    if with_Group:
        in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
    else:
        in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
    flops += L * in_for_flops
    if False:
        ...
        """
        for i in range(u.shape[2]):
            x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
            if not is_variable_C:
                y = torch.einsum('bdn,dn->bd', x, C)
            else:
                if C.dim() == 3:
                    y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
                else:
                    y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
            if i == u.shape[2] - 1:
                last_state = x
            if y.is_complex():
                y = y.real * 2
            ys.append(y)
        y = torch.stack(ys, dim=2) # (batch dim L)
        """

    if with_D:
        flops += B * D * L
    if with_Z:
        flops += B * D * L
    if False:
        ...
        """
        out = y if D is None else y + u * rearrange(D, "d -> d 1")
        if z is not None:
            out = out * F.silu(z)
        out = out.to(dtype=dtype_in)
        """

    return flops

def selective_scan_flop_jit(inputs, outputs):
    # xs, dts, As, Bs, Cs, Ds (skip), z (skip), dt_projs_bias (skip)
    assert inputs[0].debugName().startswith("xs") # (B, D, L)
    assert inputs[2].debugName().startswith("As") # (D, N)
    assert inputs[3].debugName().startswith("Bs") # (D, N)
    with_Group = len(inputs[3].type().sizes()) == 4
    with_D = inputs[5].debugName().startswith("Ds")
    if not with_D:
        with_z = inputs[5].debugName().startswith("z")
    else:
        with_z = inputs[6].debugName().startswith("z")
    B, D, L = inputs[0].type().sizes()
    N = inputs[2].type().sizes()[1]
    flops = flops_selective_scan_ref(B=B, L=L, D=D, N=N, with_D=with_D, with_Z=with_z, with_Group=with_Group)
    return flops

The formula we used is […]. This is a brief explanation:
Note that the cost of computing the input-dependent dt/B/C is baked into the linear layer FLOP counts above.
We ignore […].
The remaining FLOPs are the associative scan on […].
Summing these gives the […].
Thank you for your quick reply. Can you explain why there are 2*L associative operations rather than L?
If you look at the algorithm for associative scan, that's how it works. See https://en.wikipedia.org/wiki/Prefix_sum for example. Also note that the above does not account for the expansion factor of the Mamba block. In other words, the number of channels of the selective SSM scan is […].
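To make the "2*L rather than L" point concrete, here is a minimal sketch of a work-efficient (Blelloch-style) exclusive scan that counts how often the associative operator is applied. It is a plain-Python illustration, not the Mamba kernel; the names `blelloch_scan` and `combine` are hypothetical, the length is assumed to be a power of two, and `combine` composes steps of a recurrence of the form h -> a*h + b.

```python
def blelloch_scan(values, op, identity):
    """Exclusive prefix scan; returns (result, number of `op` calls)."""
    n = len(values)
    assert n > 0 and n & (n - 1) == 0, "sketch assumes a power-of-two length"
    a = list(values)
    calls = 0

    # Up-sweep (reduce): combine pairs up a binary tree, n - 1 calls in total.
    stride = 1
    while stride < n:
        for i in range(2 * stride - 1, n, 2 * stride):
            a[i] = op(a[i - stride], a[i])
            calls += 1
        stride *= 2

    # Down-sweep: push prefixes back down the tree, another n - 1 calls.
    a[n - 1] = identity
    stride = n // 2
    while stride >= 1:
        for i in range(2 * stride - 1, n, 2 * stride):
            left = a[i - stride]
            a[i - stride] = a[i]
            a[i] = op(a[i], left)  # prefix first, then the left subtree
            calls += 1
        stride //= 2
    return a, calls


def combine(p, q):
    # Compose two affine steps h -> a*h + b: apply p first, then q.
    return (p[0] * q[0], q[0] * p[1] + q[1])


sums, n_calls = blelloch_scan([1, 2, 3, 4, 5, 6, 7, 8], lambda x, y: x + y, 0)
steps, step_calls = blelloch_scan([(2, 1), (3, 1), (1, 2), (2, 0)], combine, (1, 0))
```

For a length-n input the two sweeps together apply the operator 2*(n - 1) times, which is where a count of roughly 2*L comes from.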
Many thanks. I think I've got the answer.
Hi @MzeroMiko, were you able to figure out how to calculate FLOPs for selective scan? I used your script, and as you noted, the result is larger than I expected.
@llmexperiment For the full script:
def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
    """
    u: r(B D L)
    delta: r(B D L)
    A: r(D N)
    B: r(B N L)
    C: r(B N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32
    ignores:
        [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
    """
    assert not with_complex
    # https://github.com/state-spaces/mamba/issues/110
    flops = 9 * B * L * D * N
    if with_D:
        flops += B * D * L
    if with_Z:
        flops += B * D * L
    return flops

def selective_scan_flop_jit(inputs, outputs):
    print_jit_input_names(inputs)
    B, D, L = inputs[0].type().sizes()
    N = inputs[2].type().sizes()[1]
    flops = flops_selective_scan_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False, with_Group=True)
    return flops

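As a quick sanity check, the closed-form count from the script above can be evaluated by hand. The sketch below restates only the arithmetic under a hypothetical name (`selective_scan_flops`) and plugs in the script's default sizes; it does not depend on torch or fvcore.

```python
def selective_scan_flops(B=1, L=256, D=768, N=16, with_D=True, with_Z=False):
    # Same arithmetic as the closed-form script: 9 * B * L * D * N for the scan,
    # plus one B * D * L term each for the optional D skip and z gate.
    flops = 9 * B * L * D * N
    if with_D:
        flops += B * D * L
    if with_Z:
        flops += B * D * L
    return flops


default_count = selective_scan_flops()        # 9*1*256*768*16 + 1*768*256
doubled_length = selective_scan_flops(L=512)  # the count is linear in L
print(default_count, doubled_length)
```

With the defaults this gives 28,311,552 for the scan term plus 196,608 for the D skip, 28,508,160 in total, and doubling L doubles the result.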
I have a naive follow-up question. For the associative scan algorithm, the wiki (https://en.wikipedia.org/wiki/Prefix_sum) shows that the work-efficient prefix sum takes only O(T) work, while the shorter-span version takes O(T log T) work. Is the Mamba kernel closer to the work-efficient version or to the shorter-span version? It seems to me that both versions have forward/backward latency on the order of O(\log T), but they require different numbers of cores and their total work grows very differently with sequence length T.
We use the work-efficient version (Blelloch's scan).
In a world with infinite parallelism the lower-span version may be faster by a constant factor. But GPUs have a lot of different constraints; we already max out their parallelism and the bottleneck is compute, so the work-efficient version is much faster.
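The gap in total work between the two scan variants can be illustrated by counting operator applications. This sketch assumes the textbook Hillis-Steele (lower-span) and Blelloch (work-efficient) formulations with T a power of two; the function names are hypothetical, and the counts say nothing about actual kernel timings.

```python
def hillis_steele_ops(T):
    # Lower-span scan: at each doubling offset, every element at index >= offset
    # combines with the element `offset` positions to its left.
    ops, offset = 0, 1
    while offset < T:
        ops += T - offset
        offset *= 2
    return ops


def blelloch_ops(T):
    # Work-efficient scan: the up-sweep and the down-sweep each apply the
    # operator T - 1 times.
    return 2 * (T - 1)


for T in (8, 1024, 65536):
    print(T, hillis_steele_ops(T), blelloch_ops(T))
```

At T = 1024 the lower-span variant applies the operator 9,217 times against 2,046 for the work-efficient one, and the ratio keeps growing with log T, which matters once the hardware is already compute-bound.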
Hi @albertfgu , I find your response very informative, and I am trying to understand deeper. I have two quick questions.
|
|
How should I calculate the FLOPs for a standard Mamba layer, and what would be an approximate value? Thank you very much.
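A rough per-layer estimate can be sketched by adding linear-layer costs to the closed-form scan term from the script earlier in this thread. Everything about the layer shapes here is an assumption based on the default Mamba block (in_proj, depthwise conv1d, x_proj, dt_proj, out_proj, with expand=2, d_conv=4 and dt_rank=ceil(d_model/16)); the function name is hypothetical, and the projections are counted as one multiply-accumulate per weight per token, so the result is an order-of-magnitude guide rather than an official figure.

```python
import math


def mamba_layer_cost(B, L, d_model, d_state=16, d_conv=4, expand=2):
    d_inner = expand * d_model          # scan channels after expansion (assumed)
    dt_rank = math.ceil(d_model / 16)   # assumed default rank of the dt projection
    per_token = {
        "in_proj": d_model * 2 * d_inner,             # produces x and z branches
        "conv1d": d_inner * d_conv,                   # depthwise convolution
        "x_proj": d_inner * (dt_rank + 2 * d_state),  # input-dependent dt, B, C
        "dt_proj": dt_rank * d_inner,
        "out_proj": d_inner * d_model,
        "scan": 9 * d_inner * d_state,                # 9 * B * L * D * N with D = d_inner
        "skip_and_gate": 2 * d_inner,                 # D skip and z gate terms
    }
    cost = {name: B * L * value for name, value in per_token.items()}
    cost["total"] = sum(cost.values())
    return cost


estimate = mamba_layer_cost(B=1, L=256, d_model=768)
for name, value in estimate.items():
    print(f"{name:>14}: {value:,}")
```

Under these assumptions the two large projections dominate, and the scan is a comparatively small share of the total at this sequence length.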
Thanks for your work. Can you explain how to use this to compute the FLOPs of Mamba with an input of shape [B, L, D]? Thank you.
@lth456321
|
I had the same error and solved it by changing this:
at "mamba/mamba_ssm/ops/triton/layer_norm.py" line 365
During training, |
How about FLOPs for Mamba2? Does anyone know how to calculate them manually?
Dear Author: Thanks for your response, I notice you only consider |
How to compute MACs or FLOPs of mamba?