From a3d1111f65ac7cdc9e8f19760ee04298052f633d Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Sat, 28 Sep 2024 11:05:38 +0300
Subject: [PATCH] GGML_UNARY_OP_SWIGLU: Metal implementation

We get ~2% speedup for PP-512(Phi-3.5-mini).
---
 ggml/src/ggml-metal.m | 29 +++++++++++++++++++++++++++++
 ggml/src/ggml-metal.metal | 24 ++++++++++++++++++++++++
 2 files changed, 53 insertions(+)

diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 02794e3c..20ab3a9d 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -63,6 +63,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
     GGML_METAL_KERNEL_TYPE_SILU,
     GGML_METAL_KERNEL_TYPE_SILU_4,
+    GGML_METAL_KERNEL_TYPE_SWIGLU, // selected by the dispatch code below when ne0 % 4 != 0
+    GGML_METAL_KERNEL_TYPE_SWIGLU_4, // selected by the dispatch code below when ne0 % 4 == 0
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
@@ -583,6 +585,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true); // NOTE(review): macro is defined outside this patch; presumably ties this enum slot to kernel_swiglu - confirm
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_4, swiglu_4, true); // same registration for the float4 variant; always enabled (third argument is true)
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
@@ -884,6 +888,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_SWIGLU: // falls through to the contiguity check on op->src[0] shared with the cases above
                     return ggml_is_contiguous(op->src[0]);
                 default:
                     return false;
@@ -1595,6 +1600,30 @@ static enum ggml_status ggml_metal_graph_compute(
                                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                 [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
 
+                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            } break;
+                        case GGML_UNARY_OP_SWIGLU: // encodes one dispatch of the SWIGLU or SWIGLU_4 pipeline over dst
+                            {
+                                int64_t n = ggml_nelements(dst); // ends up as the threadgroup count given to dispatchThreadgroups
+                                GGML_ASSERT(ne0 == src0->ne[0]/2); // ne0 is declared outside this hunk; it must equal half of src0->ne[0]
+
+                                id pipeline = nil; // NOTE(review): untyped id; confirm whether a protocol qualifier was lost when this text was extracted
+
+                                if (ne0 % 4 == 0) { // ne0 divisible by 4: use the float4 pipeline
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_4].pipeline;
+                                    n /= 4; // each thread of the _4 kernel writes one float4, so a quarter as many threads
+                                } else {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
+                                }
+
+                                const int64_t stride = src0->nb[1]/sizeof(float); // src0->nb[1] converted to a count of floats; the kernels multiply the row index by it
+
+                                [encoder setComputePipelineState:pipeline];
+                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                                [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:2]; // presumably the kernel's 'ne0' argument (third in declaration order) - confirm
+                                [encoder setBytes:&stride length:sizeof(stride) atIndex:3]; // presumably the kernel's 'stride' argument - confirm
+
                                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                             } break;
                         default:
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index e2e45029..93f4d6c5 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -398,6 +398,30 @@ kernel void kernel_silu_4(
     dst[tpig] = x / (1.0f + exp(-x));
 }
 
+kernel void kernel_swiglu( // per thread: dst[tpig] = a * b / (1 + exp(-a)) with a = src0[j], b = src0[j + ne0]
+        device const float * src0,
+        device float * dst,
+        constant int64_t & ne0,
+        constant int64_t & stride,
+        uint tpig[[thread_position_in_grid]]) {
+    const int64_t row = tpig/ne0; // quotient of the thread index by ne0
+    const int64_t idx = tpig%ne0; // remainder: position inside that row
+    const int64_t j = row*stride + idx; // rows of src0 are addressed 'stride' floats apart
+    dst[tpig] = src0[j] * src0[j + ne0] / (1.0f + exp(-src0[j])); // second operand sits ne0 floats after the first
+}
+
+kernel void kernel_swiglu_4( // same formula as kernel_swiglu, evaluated on float4 elements
+        device const float4 * src0,
+        device float4 * dst,
+        constant int64_t & ne0,
+        constant int64_t & stride,
+        uint tpig[[thread_position_in_grid]]) {
+    const int64_t row = tpig/(ne0/4); // ne0 is divided by 4 to count float4 elements per row
+    const int64_t idx = tpig%(ne0/4);
+    const int64_t j = row*(stride/4) + idx; // NOTE(review): stride/4 truncates unless stride % 4 == 0; the host code in this patch only checks ne0 % 4
+    dst[tpig] = src0[j] * src0[j + ne0/4] / (1.0f + exp(-src0[j]));
+}
+
 kernel void kernel_sqr(
         device const float * src0,
         device float * dst,