Skip to content

Commit

Permalink
add wkv custom operator on GPU & support RWKV pretrained model (minds…
Browse files Browse the repository at this point in the history
  • Loading branch information
lvyufeng authored May 22, 2023
1 parent cf12344 commit 54072ef
Show file tree
Hide file tree
Showing 19 changed files with 1,045 additions and 109 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci_pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ jobs:
pip install -r download.txt
- name: Test with pytest
run: |
pytest -m 'not download' tests/ut
pytest -m 'not download and not gpu_only' tests/ut
release-test:
needs: pylint-check
Expand All @@ -102,7 +102,7 @@ jobs:
pip install mindspore==${{matrix.ms_version}}
- name: Test with pytest
run: |
pytest -m 'not download' --ignore=tests/ut/models tests/ut
pytest -m 'not download and not gpu_only' --ignore=tests/ut/models tests/ut
st-test:
needs: ut-test
Expand Down
227 changes: 227 additions & 0 deletions mindnlp/_csrc/cuda/wkv.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
#include <stdio.h>
#include <assert.h>

#define MIN_VALUE (-1e38)

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C,
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
F *__restrict__ const _y) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;

F u = _u[_c];
F w = _w[_c];
const F *__restrict__ const k = _k + _offset;
const F *__restrict__ const v = _v + _offset;
F *__restrict__ const y = _y + _offset;

F p = 0, q = 0, o = MIN_VALUE;
// p and q are running sums divided by exp(o) (to avoid overflows)
for (int i = 0; i < T; i++) {
const int ii = i * C;

F no = max(o, u + k[ii]);
F A = exp(o - no);
F B = exp(u + k[ii] - no);
y[ii] = (A * p + B * v[ii]) / (A * q + B);

no = max(w + o, k[ii]);
A = exp(w + o - no);
B = exp(k[ii] - no);
p = A * p + B * v[ii];
q = A * q + B;
o = no;
}
}

template <typename F>
__global__ void kernel_forward_with_state(
const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y, F *__restrict__ const _s
) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset_s = _b * C * 3 + _c * 3;
const int _offset = _b * T * C + _c;

F u = _u[_c];
F w = _w[_c];
const F *__restrict__ const k = _k + _offset;
const F *__restrict__ const v = _v + _offset;
F *__restrict__ const y = _y + _offset;
F *__restrict__ const s = _s + _offset_s;

// aa and bb are running sums divided by exp(pp) (to avoid overflow)
F aa = s[0], bb = s[1], pp = s[2];
for (int i = 0; i < T; i++) {
const int ii = i * C;
const F kk = k[ii];
const F vv = v[ii];

F ww = u + kk;
F p = max(pp, ww);
F e1 = exp(pp - p);
F e2 = exp(ww - p);
y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);

ww = w + pp;
p = max(ww, kk);
e1 = exp(ww - p);
e2 = exp(kk - p);
aa = e1 * aa + e2 * vv;
bb = e1 * bb + e2;
pp = p;
}
s[0] = aa;
s[1] = bb;
s[2] = pp;
}

template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;

F u = _u[_c];
F w = _w[_c];
const F *__restrict__ const k = _k + _offset;
const F *__restrict__ const v = _v + _offset;
const F *__restrict__ const gy = _gy + _offset;

F *__restrict__ const gk = _gk + _offset;
F *__restrict__ const gv = _gv + _offset;

F y[Tmax], z[Tmax], zexp[Tmax];

F gw = 0, gu = 0;
F p = 0, q = 0;
F dpdw = 0, dqdw = 0;
F o = MIN_VALUE;
for (int i = 0; i < T; i++) {
const int ii = i * C;
F no = max(o, k[ii] + u);
F A = exp(o - no);
F B = exp(k[ii] + u - no);

F num = A * p + B * v[ii];
F iden = 1 / (A * q + B);

y[i] = num * iden;
z[i] = iden;
zexp[i] = k[ii] + u - no;

gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
gu += gy[ii] * (v[ii] - y[i]) * B * iden;

no = max(w + o, k[ii]);
A = exp(w + o - no);
B = exp(k[ii] - no);
dpdw = A * (p + dpdw);
dqdw = A * (q + dqdw);
p = A * p + B * v[ii];
q = A * q + B;
o = no;
}

F gp = 0, gq = 0;
o = MIN_VALUE;
for (int i = T - 1; i >= 0; i--) {
const int ii = i * C;
F A = gy[ii] * z[i] * exp(zexp[i]);
F B = exp(k[ii] + o);
gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
gv[ii] = A + B * gp;

F no = max(w + o, zexp[i] - k[ii] - u);
A = exp(w + o - no);
B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
gp = A * gp + B;
gq = A * gq - B * y[i];
o = no;
}

// Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass
const int _offsetBC = _b * C + _c;
_gw[_offsetBC] += gw * _w[_c];
_gu[_offsetBC] += gu;
}

extern "C" {

int wkv_forward(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream,
void *extra) {
cudaStream_t custream = static_cast<cudaStream_t>(stream);
if (nparam != 5) return 1;
float *w = static_cast<float *>(params[0]);
float *u = static_cast<float *>(params[1]);
float *k = static_cast<float *>(params[2]);
float *v = static_cast<float *>(params[3]);
float *y = static_cast<float *>(params[4]);

int B = static_cast<int>(shapes[2][0]);
int T = static_cast<int>(shapes[2][1]);
int C = static_cast<int>(shapes[2][2]);

dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_forward<<<numBlocks, threadsPerBlock, 0, custream>>>(B, T, C, w, u, k, v, y);
return 0;
}

int wkv_forward_with_state(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream,
void *extra) {
cudaStream_t custream = static_cast<cudaStream_t>(stream);
if (nparam != 6) return 1;
float *w = static_cast<float *>(params[0]);
float *u = static_cast<float *>(params[1]);
float *k = static_cast<float *>(params[2]);
float *v = static_cast<float *>(params[3]);
float *s = static_cast<float *>(params[4]);
float *y = static_cast<float *>(params[5]);

int B = static_cast<int>(shapes[2][0]);
int T = static_cast<int>(shapes[2][1]);
int C = static_cast<int>(shapes[2][2]);

dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_forward_with_state<<<numBlocks, threadsPerBlock, 0, custream>>>(B, T, C, w, u, k, v, y, s);
return 0;
}

int wkv_backward(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream,
void *extra) {
cudaStream_t custream = static_cast<cudaStream_t>(stream);
if (nparam != 9) return 1;
float *w = static_cast<float *>(params[0]);
float *u = static_cast<float *>(params[1]);
float *k = static_cast<float *>(params[2]);
float *v = static_cast<float *>(params[3]);
float *gy = static_cast<float *>(params[4]);
float *gw = static_cast<float *>(params[5]);
float *gu = static_cast<float *>(params[6]);
float *gk = static_cast<float *>(params[7]);
float *gv = static_cast<float *>(params[8]);

int B = static_cast<int>(shapes[2][0]);
int T = static_cast<int>(shapes[2][1]);
int C = static_cast<int>(shapes[2][2]);

dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_backward<<<numBlocks, threadsPerBlock, 0, custream>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
return 0;
}

}
2 changes: 0 additions & 2 deletions mindnlp/models/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,10 @@
from mindspore import Tensor, Parameter
from mindspore.common.initializer import Normal, initializer

from mindnlp.models.utils import logging
from mindnlp.models.utils.activations import ACT2FN
from mindnlp.models.codegen.codegen_config import CodeGenConfig
from mindnlp.abc import PreTrainedModel

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "Salesforce/codegen-2B-mono"
_CONFIG_FOR_DOC = "CodeGenConfig"
Expand Down
6 changes: 3 additions & 3 deletions mindnlp/models/gpt_neo/gpt_neo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@
import os
from typing import Union, Optional, Tuple
from functools import partial
import mindspore
import numpy as np
import mindspore
from mindspore import ops, nn, Parameter, Tensor, dtype_to_nptype
from mindspore.common.initializer import initializer, Normal
from mindnlp.models.utils import logging
from mindspore import log as logger

from mindnlp.models.utils.activations import ACT2FN
from mindnlp.models.gpt_neo.gpt_neo_config import GPTNeoConfig
from ...abc import PreTrainedModel

logger = logging.get_logger(__name__)


class GPTNeoSelfAttention(nn.Cell):
Expand Down
4 changes: 2 additions & 2 deletions mindnlp/models/llama/llama_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@
import numpy as np
from mindspore import nn, ops, Parameter, numpy, Tensor
from mindspore.common.initializer import initializer, Normal
from mindspore import log as logger

from mindnlp.abc import PreTrainedModel
from mindnlp.models.utils import logging
from .llama_hf_config import LlamaConfig
from ..utils.activations import ACT2FN

logger = logging.get_logger(__name__)


def _make_causal_mask(
Expand Down
3 changes: 1 addition & 2 deletions mindnlp/models/longformer/longformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,11 @@
from mindspore import Tensor
from mindspore import ops
from mindspore.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
from mindspore import log as logger

from ..utils.activations import ACT2FN
from .longformer_config import LongformerConfig
from ..utils import logging
from ...abc import PreTrainedModel
logger = logging.get_logger(__name__)

def apply_chunking_to_forward(
forward_fn: Callable[..., mindspore.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
Expand Down
3 changes: 0 additions & 3 deletions mindnlp/models/luke/luke.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,9 @@
from mindspore.common.initializer import Normal, initializer

from mindnlp.models.luke.luke_config import LukeConfig
from ..utils import logging
from ..utils.activations import ACT2FN
from ...abc import PreTrainedModel

logger = logging.get_logger(__name__)


class LukeEmbeddings(nn.Cell):
"""
Expand Down
4 changes: 1 addition & 3 deletions mindnlp/models/nezha/nezha.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,13 @@
from mindspore import nn
from mindspore import ops
from mindspore import Tensor, Parameter
from mindspore import log as logger
from .nezha_config import NezhaConfig
from ..utils import logging
from ...abc import PreTrainedModel
from ..utils.utils import prune_linear_layer, find_pruneable_heads_and_indices, apply_chunking_to_forward
from ..utils.activations import ACT2FN


logger = logging.get_logger(__name__)

class NezhaRelativePositionsEncoding(nn.Cell):
"""Implement the Functional Relative Position Encoding"""

Expand Down
Empty file added mindnlp/models/rwkv/__init__.py
Empty file.
Loading

0 comments on commit 54072ef

Please sign in to comment.