Skip to content

Commit

Permalink
metal : add f16 support
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Jun 6, 2023
1 parent d5b111f commit 44f906e
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 11 deletions.
23 changes: 13 additions & 10 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@
GGML_METAL_DECL_KERNEL(relu);
GGML_METAL_DECL_KERNEL(soft_max);
GGML_METAL_DECL_KERNEL(diag_mask_inf);
GGML_METAL_DECL_KERNEL(get_rows_f16);
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DECL_KERNEL(rope);
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
Expand Down Expand Up @@ -130,10 +131,11 @@
GGML_METAL_ADD_KERNEL(relu);
GGML_METAL_ADD_KERNEL(soft_max);
GGML_METAL_ADD_KERNEL(diag_mask_inf);
GGML_METAL_ADD_KERNEL(get_rows_f16);
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_ADD_KERNEL(rope);
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
Expand Down Expand Up @@ -498,6 +500,14 @@ void ggml_metal_graph_compute(

// use custom matrix x vector kernel
switch (src0t) {
case GGML_TYPE_F16:
{
GGML_ASSERT(ne02 == ne12);

nth0 = 64;
nth1 = 1;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
} break;
case GGML_TYPE_Q4_0:
{
GGML_ASSERT(ne02 == 1);
Expand All @@ -507,14 +517,6 @@ void ggml_metal_graph_compute(
nth1 = 4;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
} break;
case GGML_TYPE_F16:
{
GGML_ASSERT(ne02 == ne12);

nth0 = 32;
nth1 = 1;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
} break;
default: GGML_ASSERT(false && "not implemented");
};

Expand Down Expand Up @@ -551,6 +553,7 @@ void ggml_metal_graph_compute(
}

switch (src0->type) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
default: GGML_ASSERT(false && "not implemented");
}
Expand Down
16 changes: 16 additions & 0 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,22 @@ kernel void kernel_diag_mask_inf(
}
}

kernel void kernel_get_rows_f16(
device const void * src0,
device const int * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb1,
uint tpig[[thread_position_in_grid]]) {
const int i = tpig;
const int r = ((device int32_t *) src1)[i];

for (int j = 0; j < ne00; j++) {
dst[i*nb1 + j] = ((device half *) ((device char *) src0 + r*nb01))[j];
}
}

kernel void kernel_get_rows_q4_0(
device const void * src0,
device const int * src1,
Expand Down
3 changes: 2 additions & 1 deletion llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,6 @@ static void llama_model_load_internal(
model.hparams = ml->file_loaders.at(0)->hparams;
llama_file_version file_version = ml->file_loaders.at(0)->file_version;
auto & hparams = model.hparams;
uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;

{
switch (hparams.n_layer) {
Expand All @@ -975,6 +974,8 @@ static void llama_model_load_internal(
hparams.n_ctx = n_ctx;
}

const uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;

{
fprintf(stderr, "%s: format = %s\n", __func__, llama_file_version_name(file_version));
fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
Expand Down

0 comments on commit 44f906e

Please sign in to comment.