Commit 05eace2 (0 parents): 8 changed files with 1,461 additions and 0 deletions.
@@ -0,0 +1,39 @@
# Flash_Attention_V100

FlashAttention only supports GPUs of the Ampere architecture or newer, so the Volta-based V100 is not supported. Out of interest, I followed the CUTLASS tutorials and the FlashAttention-2 paper and wrote this version for the V100. Because of a busy work schedule and limited hardware, I have not been able to tune performance carefully, so this repo does not match the performance of PyTorch's attention: the forward pass currently takes about 40% less time than PyTorch, but the backward pass takes about 20% more, so the two roughly cancel out. In addition, this implementation does not handle boundary conditions, so sequences must be right-padded to a multiple of 32. This does not affect normal training; simply ignore the padded positions when computing the loss.
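
As an illustration of that padding requirement, here is a minimal sketch (my own helper, not part of this repo) that right-pads a batch of token-id sequences to a multiple of 32 and returns a mask for dropping the padded positions from the loss:

```python
import torch

def right_pad_to_multiple(seqs, multiple=32, pad_id=0):
    """Right-pad a list of token-id lists so the padded length is a multiple of `multiple`."""
    max_len = max(len(s) for s in seqs)
    padded_len = ((max_len + multiple - 1) // multiple) * multiple
    batch = torch.full((len(seqs), padded_len), pad_id, dtype=torch.long)
    mask = torch.zeros(len(seqs), padded_len, dtype=torch.bool)
    for i, s in enumerate(seqs):
        batch[i, :len(s)] = torch.tensor(s, dtype=torch.long)
        mask[i, :len(s)] = True  # True marks real tokens; ignore the rest when computing the loss
    return batch, mask

tokens, loss_mask = right_pad_to_multiple([[5, 8, 13], [2, 3, 4, 5, 6]])
print(tokens.shape)  # torch.Size([2, 32])
```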

## Installation

Before installing, make sure you have:

- PyTorch >= 2.0.1
- CUDA >= 11.6
- a Linux OS
- the CUTLASS source code

Modify line 146 of setup.py so that it points to the location of your downloaded CUTLASS source:
```py
include_dirs=[
    Path(this_dir) / "include",
    "/home/user/cutlass/include",
],
```

After making the change, run the following command to build and install from source:
```bash
python setup.py install --user
```
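
After installation, a quick sanity check (assuming a V100 is visible to PyTorch; `flash_attn_v100` is the package built above) is to import the package and confirm the GPU is Volta:

```python
import torch
import flash_attn_v100

print(flash_attn_v100.__version__)          # "0.0.1" in this commit
print(torch.cuda.get_device_capability(0))  # the kernels require (7, 0), i.e. sm_70 / Volta
```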

## Usage

```python
import torch
from flash_attn_v100 import flash_attn_func

Z, H, N_CTX, D_HEAD = 4, 16, 1024, 128  # example sizes: batch, heads, seq len (multiple of 32), head dim (the kernels require 128)
dtype, sm_scale, causal = torch.float16, D_HEAD ** -0.5, True  # the kernels only support fp16
q = torch.empty((Z, N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=1).requires_grad_()
k = torch.empty((Z, N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=1).requires_grad_()
v = torch.empty((Z, N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=1).requires_grad_()
cuda_out = flash_attn_func(q, k, v, sm_scale, causal)
```
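
To sanity-check the output, one option is to compare it against a plain PyTorch attention computation. The `reference_attention` helper below is my own sketch, not part of this repo; it assumes the (batch, seqlen, heads, head_dim) layout used above and the usual causal masking convention:

```python
import torch

def reference_attention(q, k, v, sm_scale, causal):
    # Work per head: (Z, N_CTX, H, D_HEAD) -> (Z, H, N_CTX, D_HEAD), in fp32 for accuracy.
    q, k, v = (x.transpose(1, 2).float() for x in (q, k, v))
    scores = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
    if causal:
        n = scores.shape[-1]
        mask = torch.triu(torch.ones(n, n, dtype=torch.bool, device=scores.device), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
    out = torch.matmul(torch.softmax(scores, dim=-1), v)
    return out.transpose(1, 2)  # back to (Z, N_CTX, H, D_HEAD), in fp32

ref_out = reference_attention(q, k, v, sm_scale, causal)
print((cuda_out.float() - ref_out).abs().max())  # expect a small fp16-level difference
```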

## References
- [Flash-Attention](https://github.com/Dao-AILab/flash-attention)
- [CUTLASS](https://github.com/NVIDIA/cutlass)
@@ -0,0 +1,5 @@
__version__ = "0.0.1"

from flash_attn_v100.flash_attn_interface import (
    flash_attn_func,
)
@@ -0,0 +1,110 @@
# Copyright (c) 2023, Tri Dao.

import os
from typing import Optional, Union

import torch
import torch.nn as nn
# We need to import the CUDA kernels after importing torch
import flash_attn_v100_cuda as flash_attn_cuda


def _flash_attn_forward(
    q, k, v, softmax_scale, causal
):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_max, softmax_sum = flash_attn_cuda.fwd(
        q,
        k,
        v,
        None,
        softmax_scale,
        causal
    )
    return out, q, k, v, softmax_max, softmax_sum

def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_max,
    softmax_sum,
    dq,
    dk,
    dv,
    softmax_scale,
    causal,
):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    # D = rowsum(out * dout) over the head dimension, as used in the FlashAttention-2 backward pass
    D = out * dout
    D = D.sum(-1)
    dq, dk, dv = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        D,
        softmax_sum,
        softmax_max,
        dq,
        dk,
        dv,
        softmax_scale,
        causal,
    )
    return dq, dk, dv

class FlashAttnFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, q, k, v, softmax_scale, causal
    ):
        if softmax_scale is None:
            # Default to the usual 1/sqrt(head_dim) scaling
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, softmax_max, softmax_sum = _flash_attn_forward(
            q,
            k,
            v,
            softmax_scale,
            causal=causal,
        )
        ctx.save_for_backward(q, k, v, out, softmax_max, softmax_sum)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_max, softmax_sum = ctx.saved_tensors
        dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_max,
            softmax_sum,
            dq,
            dk,
            dv,
            ctx.softmax_scale,
            ctx.causal,
        )
        # One gradient slot per forward input: q, k, v, softmax_scale, causal
        return dq, dk, dv, None, None

def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
):
    return FlashAttnFunc.apply(
        q, k, v, softmax_scale, causal
    )
@@ -0,0 +1,6 @@
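// Shape convention, as used by the mha_fwd/mha_bwd bindings that call these kernels:
// batch = batch size, head = number of heads, m = query sequence length,
// n = key sequence length (the bindings pass the same value for both), k = head dimension;
// max_ptr and sum_ptr point to the fp32 per-row softmax max/sum statistics.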
void fused_mha_forward(const void *query_ptr, const void *key_ptr, const void *value_ptr, void *output_ptr, void *max_ptr, void *sum_ptr,
                       int batch, int head, int m, int n, int k, float scale, bool causal, cudaStream_t stream);

void fused_mha_backward(const void *query_ptr, const void *key_ptr, const void *value_ptr,
                        void *output_ptr, void *d_output_ptr, void *d_ptr, void *max_ptr, void *sum_ptr,
                        void *d_query_ptr, void *d_key_ptr, void *d_value_ptr, int batch, int head, int m, int n, int k, float scale, bool causal, cudaStream_t stream);
@@ -0,0 +1,168 @@
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cutlass/numeric_types.h>
#include "fused_mha.h"

#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

std::vector<at::Tensor>
mha_fwd(at::Tensor &q,                    // batch_size x seqlen x num_heads x head_size
        const at::Tensor &k,              // batch_size x seqlen x num_heads x head_size
        const at::Tensor &v,              // batch_size x seqlen x num_heads x head_size
        c10::optional<at::Tensor> &out_,  // batch_size x seqlen x num_heads x head_size
        const float softmax_scale,
        bool is_causal)
{
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm70 = dprops->major == 7 && dprops->minor == 0;
    TORCH_CHECK(is_sm70, "This repo only supports Volta GPUs.");

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16,
                "This repo only supports fp16 data type");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");

    const auto sizes = q.sizes();

    const int batch_size = sizes[0];
    const int num_heads = sizes[2];
    const int seqlen_q = sizes[1];
    const int head_size = sizes[3];
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size == 128, "the current repo only supports head dimension 128; more sizes will be supported in the future");
    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(k, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(v, batch_size, seqlen_q, num_heads, head_size);

    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
    } else {
        out = torch::empty_like(q);
    }
    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();

    auto softmax_sum = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
    auto softmax_max = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    fused_mha_forward(q.data_ptr(), k.data_ptr(), v.data_ptr(), out.data_ptr(), softmax_max.data_ptr(), softmax_sum.data_ptr(),
                      batch_size, num_heads, seqlen_q, seqlen_q, head_size, softmax_scale, is_causal, stream);

    return {out, softmax_max, softmax_sum};
}

std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout,           // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &q,              // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,              // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,              // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &out,            // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &D,              // batch_size x seqlen_q x num_heads
        const at::Tensor &softmax_sum,    // b x h x seqlen_q
        const at::Tensor &softmax_max,    // b x h x seqlen_q
        c10::optional<at::Tensor> &dq_,   // batch_size x seqlen_q x num_heads x head_size
        c10::optional<at::Tensor> &dk_,   // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &dv_,   // batch_size x seqlen_k x num_heads_k x head_size
        const float softmax_scale,
        const bool is_causal)
{
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm70 = dprops->major == 7 && dprops->minor == 0;
    TORCH_CHECK(is_sm70, "This repo only supports Volta GPUs.");

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16,
                "This repo only supports fp16 data type");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_sum); CHECK_DEVICE(softmax_max);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
    TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");

    const auto sizes = q.sizes();

    const int batch_size = sizes[0];
    const int num_heads = sizes[2];
    const int seqlen_q = sizes[1];
    const int head_size = sizes[3];
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size == 128, "the current repo only supports head dimension 128; more sizes will be supported in the future");
    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(k, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(v, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(softmax_sum, batch_size, num_heads, seqlen_q);
    CHECK_SHAPE(softmax_max, batch_size, num_heads, seqlen_q);
    CHECK_SHAPE(D, batch_size, seqlen_q, num_heads);
    auto opts = q.options();
    at::Tensor dq, dk, dv;
    if (dq_.has_value()) {
        dq = dq_.value();
        TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
        CHECK_DEVICE(dq);
        TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
        CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
    } else {
        // dq = torch::empty_like(q);
        dq = torch::zeros({batch_size, seqlen_q, num_heads, head_size}, opts.dtype(at::kHalf));
    }
    if (dk_.has_value()) {
        dk = dk_.value();
        TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
        CHECK_DEVICE(dk);
        TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
        CHECK_SHAPE(dk, batch_size, seqlen_q, num_heads, head_size);
    } else {
        dk = torch::empty_like(k);
    }
    if (dv_.has_value()) {
        dv = dv_.value();
        TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
        CHECK_DEVICE(dv);
        TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
        CHECK_SHAPE(dv, batch_size, seqlen_q, num_heads, head_size);
    } else {
        dv = torch::empty_like(v);
    }

    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto stream = at::cuda::getCurrentCUDAStream().stream();
    fused_mha_backward(q.data_ptr(), k.data_ptr(), v.data_ptr(),
                       out.data_ptr(), dout.data_ptr(), D.data_ptr(), softmax_max.data_ptr(), softmax_sum.data_ptr(),
                       dq.data_ptr(), dk.data_ptr(), dv.data_ptr(), batch_size, num_heads, seqlen_q, seqlen_q, head_size, softmax_scale, is_causal, stream);
    return {dq, dk, dv};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "FlashAttention";
    m.def("fwd", &mha_fwd, "Forward pass");
    m.def("bwd", &mha_bwd, "Backward pass");
}