flash_attention_v100 v0.0.1
ZRayZzz committed Feb 19, 2024
0 parents commit 05eace2
Showing 8 changed files with 1,461 additions and 0 deletions.
39 changes: 39 additions & 0 deletions README.md
@@ -0,0 +1,39 @@
# Flash_Attention_V100

FlashAttention only supports GPUs of the Ampere architecture or newer, so it cannot be used on the Volta-architecture V100. Out of personal interest, I followed the CUTLASS tutorials and the FlashAttention-2 paper and wrote this version for the V100. Due to limited time and hardware, I have not been able to tune performance carefully, so this repo does not outperform PyTorch's attention overall: the forward pass is roughly 40% faster than PyTorch, but the backward pass is roughly 20% slower, so the two largely cancel out. In addition, the implementation does not handle boundary conditions, so sequences must be right-padded to a multiple of 32. This has no effect on normal training; simply ignore the padded positions when computing the loss.
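
For illustration, one way to right-pad inputs to a multiple of 32 and keep the padding out of the loss (a minimal sketch, not part of this repo; the token ids and `pad_id` below are made-up examples):

```python
import torch
import torch.nn.functional as F

tokens = torch.randint(0, 1000, (2, 45))            # (batch, seq_len), seq_len = 45
pad_id = 0                                          # hypothetical padding token id
pad_len = (-tokens.shape[1]) % 32                   # distance to the next multiple of 32
tokens = F.pad(tokens, (0, pad_len), value=pad_id)  # right padding -> seq_len = 64

# When computing the loss, skip the padded positions, e.g. with
# F.cross_entropy(..., ignore_index=pad_id) or an explicit boolean mask.
```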

## Installation
Before installing, make sure you have:

- PyTorch >= 2.0.1
- CUDA >= 11.6
- Linux OS
- The CUTLASS source code

Modify line 146 of setup.py so that the include path points to the location of your downloaded CUTLASS source:

```py
include_dirs=[
    Path(this_dir) / "include",
    "/home/user/cutlass/include",
],
```

After making this change, run the following command to build and install from source:
```bash
python setup.py install --user
```
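
Optionally, a quick way to check that the CUDA extension built correctly (this check is not part of the install steps above):

```python
# If the extension compiled and installed correctly, both imports succeed.
import flash_attn_v100_cuda
from flash_attn_v100 import flash_attn_func
print(flash_attn_func)
```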

## Usage

```python
import torch
from flash_attn_v100 import flash_attn_func

Z, H, N_CTX, D_HEAD = 2, 8, 1024, 128   # batch, heads, seq len (multiple of 32), head dim (must be 128)
dtype, sm_scale, causal = torch.float16, D_HEAD ** -0.5, True   # only fp16 is supported
q = torch.empty((Z, N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=1).requires_grad_()
k = torch.empty((Z, N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=1).requires_grad_()
v = torch.empty((Z, N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=1).requires_grad_()
cuda_out = flash_attn_func(q, k, v, sm_scale, causal)
```
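
Optionally, the result can be sanity-checked against PyTorch's built-in attention. A minimal sketch (not from the repo), assuming the variables from the snippet above and that `sm_scale` equals the `1/sqrt(D_HEAD)` default used by `scaled_dot_product_attention`:

```python
import torch.nn.functional as F

# flash_attn_func uses a (batch, seq, heads, head_dim) layout; SDPA expects
# (batch, heads, seq, head_dim), hence the transposes.
ref_out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=causal
).transpose(1, 2)
print(torch.allclose(cuda_out, ref_out, atol=2e-2, rtol=2e-2))
```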

## References
- [Flash-Attention](https://github.com/Dao-AILab/flash-attention)
- [CUTLASS](https://github.com/NVIDIA/cutlass)
5 changes: 5 additions & 0 deletions flash_attn_v100/__init__.py
@@ -0,0 +1,5 @@
__version__ = "0.0.1"

from flash_attn_v100.flash_attn_interface import (
flash_attn_func,
)
110 changes: 110 additions & 0 deletions flash_attn_v100/flash_attn_interface.py
@@ -0,0 +1,110 @@
# Copyright (c) 2023, Tri Dao.

import os
from typing import Optional, Union

import torch
import torch.nn as nn
# We need to import the CUDA kernels after importing torch
import flash_attn_v100_cuda as flash_attn_cuda

def _flash_attn_forward(
q, k, v, softmax_scale, causal
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_max, softmax_sum = flash_attn_cuda.fwd(
q,
k,
v,
None,
softmax_scale,
causal
)
return out, q, k, v, softmax_max, softmax_sum

def _flash_attn_backward(
dout,
q,
k,
v,
out,
softmax_max,
softmax_sum,
dq,
dk,
dv,
softmax_scale,
causal,
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
# D = rowsum(dout * out), the per-row term used in the FlashAttention-2
# backward pass to reconstruct the softmax gradient.
D = (out * dout).sum(-1)
dq, dk, dv = flash_attn_cuda.bwd(
dout,
q,
k,
v,
out,
D,
softmax_sum,
softmax_max,
dq,
dk,
dv,
softmax_scale,
causal,
)
return dq, dk, dv

class FlashAttnFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx, q, k, v, softmax_scale, causal
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, q, k, v, softmax_max, softmax_sum = _flash_attn_forward(
q,
k,
v,
softmax_scale,
causal=causal,
)
ctx.save_for_backward(q, k, v, out, softmax_max, softmax_sum)
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return out

@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_max, softmax_sum = ctx.saved_tensors
dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
_flash_attn_backward(
dout,
q,
k,
v,
out,
softmax_max,
softmax_sum,
dq,
dk,
dv,
ctx.softmax_scale,
ctx.causal,
)
return dq, dk, dv, None, None

def flash_attn_func(
q,
k,
v,
softmax_scale=None,
causal=False,
):
return FlashAttnFunc.apply(
q, k, v, softmax_scale, causal
)
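
A minimal end-to-end sketch of the autograd path above (not part of this commit), assuming fp16 `q`, `k`, `v` tensors built as in the README usage example:

```python
import torch
from flash_attn_v100 import flash_attn_func

# q, k, v: fp16 CUDA tensors of shape (batch, seq, heads, 128) with
# requires_grad=True and seq a multiple of 32, as in the README example.
out = flash_attn_func(q, k, v, None, True)   # softmax_scale defaults to 1/sqrt(head_dim)
out.sum().backward()                         # runs FlashAttnFunc.backward -> fused_mha_backward
print(q.grad.shape, k.grad.shape, v.grad.shape)
```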
6 changes: 6 additions & 0 deletions include/fused_mha.h
@@ -0,0 +1,6 @@
void fused_mha_forward(const void *query_ptr, const void *key_ptr, const void *value_ptr, void *output_ptr, void *max_ptr, void *sum_ptr,
int batch, int head, int m, int n, int k, float scale, bool causal, cudaStream_t stream);

void fused_mha_backward(const void *query_ptr, const void *key_ptr, const void *value_ptr,
void *output_ptr, void *d_output_ptr, void *d_ptr, void *max_ptr, void *sum_ptr,
void *d_query_ptr, void *d_key_ptr, void *d_value_ptr, int batch, int head, int m, int n, int k, float scale, bool causal, cudaStream_t stream);
168 changes: 168 additions & 0 deletions kernel/fused_mha_api.cpp
@@ -0,0 +1,168 @@
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cutlass/numeric_types.h>
#include "fused_mha.h"

#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

std::vector<at::Tensor>
mha_fwd(at::Tensor &q, // batch_size x seqlen x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen x num_heads x head_size
const at::Tensor &v, // batch_size x seqlen x num_heads x head_size
c10::optional<at::Tensor> &out_, // batch_size x seqlen x num_heads x head_size
const float softmax_scale,
bool is_causal)
{
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm70 = dprops->major == 7 && dprops->minor == 0;
TORCH_CHECK(is_sm70, "This repo only supports Volta GPUs.");

auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16,
"This repo only supports fp16 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");

CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);

TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");

const auto sizes = q.sizes();

const int batch_size = sizes[0];
const int num_heads = sizes[2];
const int seqlen_q = sizes[1];
const int head_size = sizes[3];
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size == 128, "current repo only supports head dimension 128, we will support more in the future");
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(k, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(v, batch_size, seqlen_q, num_heads, head_size);

at::Tensor out;
if (out_.has_value()) {
out = out_.value();
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
} else {
out = torch::empty_like(q);
}
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};

auto opts = q.options();

auto softmax_sum = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
auto softmax_max = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
auto stream = at::cuda::getCurrentCUDAStream().stream();
fused_mha_forward(q.data_ptr(), k.data_ptr(), v.data_ptr(), out.data_ptr(), softmax_max.data_ptr(), softmax_sum.data_ptr(),
batch_size, num_heads, seqlen_q, seqlen_q, head_size, softmax_scale, is_causal, stream);

return {out, softmax_max, softmax_sum};
}

std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &D, // batch_size x seqlen_q x num_heads
const at::Tensor &softmax_sum, // b x h x seqlen_q
const at::Tensor &softmax_max, // b x h x seqlen_q
c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
const float softmax_scale,
const bool is_causal)
{
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm70 = dprops->major == 7 && dprops->minor == 0;
TORCH_CHECK(is_sm70, "This repo only supports Volta GPUs.");

auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16,
"This repo only supports fp16 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");

CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_sum); CHECK_DEVICE(softmax_max);

TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");

const auto sizes = q.sizes();

const int batch_size = sizes[0];
const int num_heads = sizes[2];
const int seqlen_q = sizes[1];
const int head_size = sizes[3];
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size == 128, "current repo only supports head dimension 128, we will support more in the future");
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(k, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(v, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(softmax_sum, batch_size, num_heads, seqlen_q);
CHECK_SHAPE(softmax_max, batch_size, num_heads, seqlen_q);
CHECK_SHAPE(D, batch_size, seqlen_q, num_heads);
auto opts = q.options();
at::Tensor dq, dk, dv;
if (dq_.has_value()) {
dq = dq_.value();
TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
CHECK_DEVICE(dq);
TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
} else {
// dq = torch::empty_like(q);
dq = torch::zeros({batch_size, seqlen_q, num_heads, head_size}, opts.dtype(at::kHalf));
}
if (dk_.has_value()) {
dk = dk_.value();
TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
CHECK_DEVICE(dk);
TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
CHECK_SHAPE(dk, batch_size, seqlen_q, num_heads, head_size);
} else {
dk = torch::empty_like(k);
}
if (dv_.has_value()) {
dv = dv_.value();
TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
CHECK_DEVICE(dv);
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
CHECK_SHAPE(dv, batch_size, seqlen_q, num_heads, head_size);
} else {
dv = torch::empty_like(v);
}

at::cuda::CUDAGuard device_guard{(char)q.get_device()};

auto stream = at::cuda::getCurrentCUDAStream().stream();
fused_mha_backward(q.data_ptr(), k.data_ptr(), v.data_ptr(),
out.data_ptr(), dout.data_ptr(), D.data_ptr(), softmax_max.data_ptr(), softmax_sum.data_ptr(),
dq.data_ptr(), dk.data_ptr(), dv.data_ptr(), batch_size, num_heads, seqlen_q, seqlen_q, head_size, softmax_scale, is_causal, stream);
return {dq, dk, dv};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashAttention";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("bwd", &mha_bwd, "Backward pass");
}