Skip to content

Commit

Permalink
metal : add Q4_1 implementation (ggerganov#1785)
Browse files Browse the repository at this point in the history
23.3 ms / token, so just ~1% slower than q4_0.
Achieves 290 GB/s memory throughput.

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
  • Loading branch information
ikawrakow and Kawrakow authored Jun 10, 2023
1 parent 4f0154b commit e9b66ee
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 1 deletion.
16 changes: 15 additions & 1 deletion ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,14 @@
GGML_METAL_DECL_KERNEL(diag_mask_inf);
GGML_METAL_DECL_KERNEL(get_rows_f16);
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
GGML_METAL_DECL_KERNEL(get_rows_q2_k);
GGML_METAL_DECL_KERNEL(get_rows_q4_k);
GGML_METAL_DECL_KERNEL(get_rows_q6_k);
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
Expand Down Expand Up @@ -141,12 +143,14 @@
GGML_METAL_ADD_KERNEL(diag_mask_inf);
GGML_METAL_ADD_KERNEL(get_rows_f16);
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
GGML_METAL_ADD_KERNEL(get_rows_q2_k);
GGML_METAL_ADD_KERNEL(get_rows_q4_k);
GGML_METAL_ADD_KERNEL(get_rows_q6_k);
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
Expand Down Expand Up @@ -545,6 +549,15 @@ void ggml_metal_graph_compute(
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
} break;
case GGML_TYPE_Q4_1:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);

nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
} break;
case GGML_TYPE_Q2_K:
{
GGML_ASSERT(ne02 == 1);
Expand Down Expand Up @@ -596,7 +609,7 @@ void ggml_metal_graph_compute(
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];

if (src0t == GGML_TYPE_Q4_0) {
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else if (src0t == GGML_TYPE_Q2_K) {
Expand All @@ -623,6 +636,7 @@ void ggml_metal_graph_compute(
switch (src0->type) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
Expand Down
123 changes: 123 additions & 0 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ typedef struct {
uint8_t qs[QK4_0 / 2]; // nibbles / quants
} block_q4_0;

#define QK4_1 32
typedef struct {
half d; // delta
half m; // min
uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;

static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
const int qk = QK4_0;

Expand All @@ -31,6 +38,27 @@ static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, i
}
}

static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) {
const int qk = QK4_1;

assert(k % qk == 0);

const int nb = k / qk;

for (int i = 0; i < nb; i++) {
const half d = x[i].d;
const half m = x[i].m;

for (int j = 0; j < qk/2; ++j) {
const int x0 = (x[i].qs[j] & 0x0F);
const int x1 = (x[i].qs[j] >> 4);

y[i*qk + j + 0 ] = x0*d + m;
y[i*qk + j + qk/2] = x1*d + m;
}
}
}

kernel void kernel_add(
device const float * src0,
device const float * src1,
Expand Down Expand Up @@ -212,6 +240,22 @@ kernel void kernel_get_rows_q4_0(
(device float *) ((device char *) dst + i*nb1), ne00);
}

kernel void kernel_get_rows_q4_1(
device const void * src0,
device const int * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb1,
uint tpig[[thread_position_in_grid]]) {
const int i = tpig;
const int r = ((device int32_t *) src1)[i];

dequantize_row_q4_1(
(device const block_q4_1 *) ((device char *) src0 + r*nb01),
(device float *) ((device char *) dst + i*nb1), ne00);
}

kernel void kernel_rms_norm(
device const void * src0,
device float * dst,
Expand Down Expand Up @@ -350,6 +394,85 @@ kernel void kernel_mul_mat_q4_0_f32(
//}
}

kernel void kernel_mul_mat_q4_1_f32(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
threadgroup float * sum [[threadgroup(0)]],
uint2 tgpig[[threadgroup_position_in_grid]],
uint2 tpig[[thread_position_in_grid]],
uint2 tpitg[[thread_position_in_threadgroup]],
uint2 tptg[[threads_per_threadgroup]]) {
const int nb = ne00/QK4_1;

const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;

device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
device const float * y = (device const float *) src1 + r1*ne10;

const uint nth = tptg.x*tptg.y;
const uint ith = tptg.y*tpitg.x + tpitg.y;

const int ix = tpitg.y/4; // 0 or 1
const int iy = tpitg.y - 4*ix; // 0...3

const int first = 4 * iy;

float sumf = 0;

for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {

const float d = (float)x[i].d;
const float m = (float)x[i].m;

device const uint8_t * xl = x[i].qs + first;
device const float * yl = y + i * QK4_1 + first;

float2 acc = {0.0f, 0.0f};

for (int j = 0; j < 4; ++j) {

acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m);

}

sumf += acc[0] + acc[1];
}

sum[ith] = sumf;

//
// Accumulate the sum from all threads in the threadgroup
//
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith%4 == 0) {
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith%16 == 0) {
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith == 0) {
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
dst[r1*ne0 + r0] = sum[0];
}
}

kernel void kernel_mul_mat_f16_f32(
device const char * src0,
device const char * src1,
Expand Down

0 comments on commit e9b66ee

Please sign in to comment.