Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gelu implementation to metal. #1770

Merged
merged 1 commit into from
Jun 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Add gelu implementation to metal.
  • Loading branch information
manyoso committed Jun 9, 2023
commit 45d34f4041e3ed4c9cc2cebe2db66d39dfa594bf
16 changes: 16 additions & 0 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
GGML_METAL_DECL_KERNEL(scale);
GGML_METAL_DECL_KERNEL(silu);
GGML_METAL_DECL_KERNEL(relu);
GGML_METAL_DECL_KERNEL(gelu);
GGML_METAL_DECL_KERNEL(soft_max);
GGML_METAL_DECL_KERNEL(diag_mask_inf);
GGML_METAL_DECL_KERNEL(get_rows_f16);
Expand Down Expand Up @@ -129,6 +130,7 @@
GGML_METAL_ADD_KERNEL(scale);
GGML_METAL_ADD_KERNEL(silu);
GGML_METAL_ADD_KERNEL(relu);
GGML_METAL_ADD_KERNEL(gelu);
GGML_METAL_ADD_KERNEL(soft_max);
GGML_METAL_ADD_KERNEL(diag_mask_inf);
GGML_METAL_ADD_KERNEL(get_rows_f16);
Expand Down Expand Up @@ -408,6 +410,20 @@ void ggml_metal_graph_compute(

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_GELU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}

[encoder setComputePipelineState:ctx->pipeline_gelu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];

const int64_t n = ggml_nelements(dst);

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SOFT_MAX:
{
if (encoder == nil) {
Expand Down
11 changes: 11 additions & 0 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,17 @@ kernel void kernel_relu(
dst[tpig] = max(0.0f, src0[tpig]);
}

constant float GELU_COEF_A = 0.044715f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

kernel void kernel_gelu(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
float x = src0[tpig];
dst[tpig] = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

kernel void kernel_soft_max(
device const float * src0,
device float * dst,
Expand Down