Skip to content

Commit

Permalink
1. accelerated Gemm applied at the end of classification models
Browse files Browse the repository at this point in the history
2. slightly accelerated depthwise convolution
  • Loading branch information
vpisarev committed Aug 6, 2022
1 parent d0a9d42 commit 05581bf
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 19 deletions.
28 changes: 21 additions & 7 deletions lib/NN/ConstFold.fx
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,42 @@
// the result of operation is represented as the constant.

import Dynvec
import Ast, InferShapes, RunOp
import Ast, InferShapes, RunOp, OpPermute

fun cfold(model: Ast.nnmodel_t)
{
val graph = cfold_graph(model, model.graph)
val usecounts = model.use_counts()
val graph = cfold_graph(model, model.graph, usecounts)
model.{graph = graph}
}

fun cfold_graph(model: Ast.nnmodel_t, graph: Ast.nngraph_t)
fun cfold_graph(model: Ast.nnmodel_t, graph: Ast.nngraph_t, usecounts: int [])
{
val new_prog = Dynvec.create(0, Ast.NN_Nop)
var have_changes = false
for op <- graph.prog {
val opt_op = match op {
| Ast.NN_If {name, then_branch, else_branch, t_inp, t_out} =>
val then_branch = cfold_graph(model, then_branch)
val else_branch = cfold_graph(model, else_branch)
val then_branch = cfold_graph(model, then_branch, usecounts)
val else_branch = cfold_graph(model, else_branch, usecounts)
Some(Ast.NN_If {name=name, then_branch=then_branch,
else_branch=else_branch, t_inp=t_inp, t_out=t_out})
| Ast.NN_Loop {name, body, t_trip_count, t_cond_in, t_v_in, t_v_out} =>
val body = cfold_graph(model, body)
val body = cfold_graph(model, body, usecounts)
Some(Ast.NN_Loop {name=name, body=body, t_trip_count=t_trip_count,
t_cond_in=t_cond_in, t_v_in=t_v_in, t_v_out=t_v_out})
| Ast.NN_Gemm {name, alpha, beta, transA, transB=true, t_A, t_B, t_bias, t_out}
when model.isconst(t_B) && usecounts[t_B] == 1 =>
val B = model.tensors[t_B]
val Bt_shape = Ast.nnshape_t {shape=[B.shape.shape[1], B.shape.shape[0]], layout=B.shape.layout}
val Bt = Ast.mktensor(Bt_shape, B.elemtype())
OpPermute.run_transpose(B, [1, 0], Bt)
println(f"B_shape={B.shape}, Bt_shape={Bt.shape}")
model.tensors[t_B] = Bt
model.args[t_B].shape = Bt.shape
have_changes = true
Some(Ast.NN_Gemm {name=name, alpha=alpha, beta=beta, transA=transA,
transB=false, t_A=t_A, t_B=t_B, t_bias=t_bias, t_out=t_out})
| _ =>
val (inps, outs) = op.get_inputs_outputs()
if !all(for t_inp <- inps {model.isconst(t_inp)}) ||
Expand All @@ -57,7 +71,7 @@ fun cfold_graph(model: Ast.nnmodel_t, graph: Ast.nngraph_t)
| _ => {}
}
}
if new_prog.count == graph.prog.size() {graph}
if new_prog.count == graph.prog.size() && !have_changes {graph}
else {
val {name, inpargs, outargs} = graph
Ast.NN_Graph {name=name, inpargs=inpargs, outargs=outargs,
Expand Down
44 changes: 37 additions & 7 deletions lib/NN/OpConv_Depthwise.fx
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,12 @@ static int _fx_depthwise_conv2d_f32(int N, int C, int Hi, int Wi, int H0, int W0
if (useSIMD) {
if (is3x3) {
if (dy0 == 3) {
for (; x0 <= x1 - FX_VEC_NLANES; x0 += FX_VEC_NLANES) {
for (; x0 < x1; x0 += FX_VEC_NLANES) {
if (x0 + FX_VEC_NLANES > x1) {
if (x0 <= inner_xleft)
break;
x0 = x1 - FX_VEC_NLANES;
}
int xi_ = x0*stride_x - pad_left;
const float* inptr_xi = inptr + Wi*yi_ + xi_;
float32x4_t s0, s1, s2;
Expand Down Expand Up @@ -173,7 +178,12 @@ static int _fx_depthwise_conv2d_f32(int N, int C, int Hi, int Wi, int H0, int W0
vst1q_f32(outptr + W0*2 + x0, s2);
}
} else {
for (; x0 <= x1 - FX_VEC_NLANES; x0 += FX_VEC_NLANES) {
for (; x0 < x1; x0 += FX_VEC_NLANES) {
if (x0 + FX_VEC_NLANES > x1) {
if (x0 <= inner_xleft)
break;
x0 = x1 - FX_VEC_NLANES;
}
int xi_ = x0*stride_x - pad_left;
const float* inptr_xi = inptr + Wi*yi_ + xi_;
float32x4_t s0 = vfmaq_f32(vbias, vld1q_f32(inptr_xi + ofstab[0]), w0);
Expand All @@ -194,7 +204,12 @@ static int _fx_depthwise_conv2d_f32(int N, int C, int Hi, int Wi, int H0, int W0
}
}
} else {
for (; x0 <= x1 - FX_VEC_NLANES; x0 += FX_VEC_NLANES) {
for (; x0 < x1; x0 += FX_VEC_NLANES) {
if (x0 + FX_VEC_NLANES > x1) {
if (x0 <= inner_xleft)
break;
x0 = x1 - FX_VEC_NLANES;
}
int xi_ = x0*stride_x - pad_left, k = 0;
const float* inptr_xi = inptr + Wi*yi_ + xi_;
float32x4_t s0 = vbias;
Expand Down Expand Up @@ -359,8 +374,13 @@ static int _fx_depthwise_conv2d_f16(int N, int C, int Hi, int Wi, int H0, int W0
if (useSIMD) {
if (is3x3) {
if (dy0 == 3) {
for (; x0 <= x1 - FX_VEC_F16_NLANES; x0 += FX_VEC_F16_NLANES) {
int xi_ = x0*stride_x - pad_left;
for (; x0 < x1; x0 += FX_VEC_F16_NLANES) {
if (x0 + FX_VEC_F16_NLANES > x1) {
if (x0 <= inner_xleft)
break;
x0 = x1 - FX_VEC_F16_NLANES;
}
int xi_ = x0 - pad_left;
const fx_f16* inptr_xi = inptr + Wi*yi_ + xi_;
float16x8_t s0, s1, s2;
float16x8_t x00 = vld1q_f16(inptr_xi);
Expand Down Expand Up @@ -427,7 +447,12 @@ static int _fx_depthwise_conv2d_f16(int N, int C, int Hi, int Wi, int H0, int W0
vst1q_f16(outptr + W0*2 + x0, s2);
}
} else {
for (; x0 <= x1 - FX_VEC_F16_NLANES; x0 += FX_VEC_F16_NLANES) {
for (; x0 < x1; x0 += FX_VEC_F16_NLANES) {
if (x0 + FX_VEC_F16_NLANES > x1) {
if (x0 <= inner_xleft)
break;
x0 = x1 - FX_VEC_F16_NLANES;
}
int xi_ = x0*stride_x - pad_left;
const fx_f16* inptr_xi = inptr + Wi*yi_ + xi_;
float16x8_t s0 = vfmaq_f16(vbias, vld1q_f16(inptr_xi + ofstab[0]), w0);
Expand All @@ -448,7 +473,12 @@ static int _fx_depthwise_conv2d_f16(int N, int C, int Hi, int Wi, int H0, int W0
}
}
} else {
for (; x0 <= x1 - FX_VEC_F16_NLANES; x0 += FX_VEC_F16_NLANES) {
for (; x0 < x1; x0 += FX_VEC_F16_NLANES) {
if (x0 + FX_VEC_F16_NLANES > x1) {
if (x0 <= inner_xleft)
break;
x0 = x1 - FX_VEC_F16_NLANES;
}
int xi_ = x0*stride_x - pad_left, k = 0;
const fx_f16* inptr_xi = inptr + Wi*yi_ + xi_;
float16x8_t s0 = vbias;
Expand Down
30 changes: 25 additions & 5 deletions lib/NN/OpPooling.fx
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,12 @@ static void _fx_maxpool_2d_f32(int nc, const char* inptr_, char* outptr_,
#ifdef __ARM_NEON
if (useSIMD) {
if (is3x3) {
for (; x0 <= x1 - vec_nlanes; x0 += vec_nlanes) {
for (; x0 < x1; x0 += vec_nlanes) {
if (x0 + vec_nlanes > x1) {
if (x0 <= inner_x0)
break;
x0 = x1 - vec_nlanes;
}
int xi_ = x0*stride_x - pad_left;
const float* inptr_xi = inptr + Wi*yi_ + xi_;
float32x4_t s0 = vld1q_f32(inptr_xi + ofstab[0]);
Expand All @@ -185,7 +190,12 @@ static void _fx_maxpool_2d_f32(int nc, const char* inptr_, char* outptr_,
vst1q_f32(outptr + x0, s0);
}
} else {
for (; x0 <= x1 - vec_nlanes; x0 += vec_nlanes) {
for (; x0 < x1; x0 += vec_nlanes) {
if (x0 + vec_nlanes > x1) {
if (x0 <= inner_x0)
break;
x0 = x1 - vec_nlanes;
}
int xi_ = x0*stride_x - pad_left, k = 0;
const float* inptr_xi = inptr + Wi*yi_ + xi_;
float32x4_t s0 = vld1q_f32(inptr_xi + ofstab[0]);
Expand Down Expand Up @@ -226,7 +236,7 @@ static void _fx_maxpool_2d_f16(int nc, const char* inptr_, char* outptr_,
int pad_top = pool->pad_top, pad_left = pool->pad_left;
const int* yxtab = pool->yxtab;
const int* ofstab = pool->ofstab;
const int vec_nlanes = FX_VEC_NLANES*2;
const int vec_nlanes = FX_VEC_F16_NLANES;

bool useSIMD = stride_x == 1 && inner_x0 < W0;
bool is3x3 = pool->Hk == 3 && pool->Wk == 3;
Expand Down Expand Up @@ -256,7 +266,12 @@ static void _fx_maxpool_2d_f16(int nc, const char* inptr_, char* outptr_,
x1 = inner_x1;
if (useSIMD) {
if (is3x3) {
for (; x0 <= x1 - vec_nlanes; x0 += vec_nlanes) {
for (; x0 < x1; x0 += vec_nlanes) {
if (x0 + vec_nlanes > x1) {
if (x0 <= inner_x0)
break;
x0 = x1 - vec_nlanes;
}
int xi_ = x0*stride_x - pad_left;
const __fp16* inptr_xi = inptr + Wi*yi_ + xi_;
float16x8_t s0 = vld1q_f16(inptr_xi + ofstab[0]);
Expand All @@ -275,7 +290,12 @@ static void _fx_maxpool_2d_f16(int nc, const char* inptr_, char* outptr_,
vst1q_f16(outptr + x0, s0);
}
} else {
for (; x0 <= x1 - vec_nlanes; x0 += vec_nlanes) {
for (; x0 < x1; x0 += vec_nlanes) {
if (x0 + vec_nlanes > x1) {
if (x0 <= inner_x0)
break;
x0 = x1 - vec_nlanes;
}
int xi_ = x0*stride_x - pad_left, k = 0;
const __fp16* inptr_xi = inptr + Wi*yi_ + xi_;
float16x8_t s0 = vld1q_f16(inptr_xi + ofstab[0]);
Expand Down

0 comments on commit 05581bf

Please sign in to comment.