Skip to content

Commit

Permalink
metal : faster q4_0 (ggerganov#1775)
Browse files Browse the repository at this point in the history
* metal : 8% faster q4_0

Avoid copying into local uchar4 anf float4.

* metal : 17% faster Q4_0

Use 64 threads in a thread group.

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
  • Loading branch information
ikawrakow and Kawrakow authored Jun 9, 2023
1 parent 72ff528 commit 245fc3c
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 16 deletions.
2 changes: 1 addition & 1 deletion ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ void ggml_metal_graph_compute(
GGML_ASSERT(ne12 == 1);

nth0 = 8;
nth1 = 4;
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
} break;
case GGML_TYPE_Q2_K:
Expand Down
34 changes: 19 additions & 15 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ kernel void kernel_mul_mat_q4_0_f32(
uint2 tptg[[threads_per_threadgroup]]) {
const int nb = ne00/QK4_0;

const int8_t m8 = 8;

const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;

Expand All @@ -276,33 +278,34 @@ kernel void kernel_mul_mat_q4_0_f32(
const uint nth = tptg.x*tptg.y;
const uint ith = tptg.y*tpitg.x + tpitg.y;

sum[ith] = 0.0f;
const int ix = tpitg.y/4; // 0 or 1
const int iy = tpitg.y - 4*ix; // 0...3

for (int i = tpitg.x; i < nb; i += tptg.x) {
device const uchar4 * x0p = (device const uchar4 *) (x + i)->qs;
device const float4 * y0p = (device const float4 *) (y + i*QK4_0);
const int first = 4 * iy;

float sumf = 0;

const float d = (float)((x + i)->d);
for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {

const uchar4 x0v = *(x0p + tpitg.y);
const float4 y0v = *(y0p + tpitg.y + 0);
const float4 y1v = *(y0p + tpitg.y + 4);
const float d = (float)x[i].d;

float acc = 0.0f;
device const uint8_t * xl = x[i].qs + first;
device const float * yl = y + i * QK4_0 + first;

float2 acc = {0.0f, 0.0f};

for (int j = 0; j < 4; ++j) {
const int x0 = x0v[j] & 0x0F;
const int x1 = x0v[j] >> 4;

const float y0 = y0v[j];
const float y1 = y1v[j];
acc[0] += yl[j+ 0] * ((int8_t)(xl[j] & 0xF) - m8);
acc[1] += yl[j+16] * ((int8_t)(xl[j] >> 4) - m8);

acc += (x0 - 8)*y0 + (x1 - 8)*y1;
}

sum[ith] += acc*d;
sumf += d * (acc[0] + acc[1]);
}

sum[ith] = sumf;

//
// Accumulate the sum from all threads in the threadgroup
// This version is slightly faster than the commented out one below,
Expand Down Expand Up @@ -357,6 +360,7 @@ kernel void kernel_mul_mat_f16_f32(
uint3 tpig[[thread_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]]) {

const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
const int64_t im = tgpig.z;
Expand Down

0 comments on commit 245fc3c

Please sign in to comment.