From a8e932b734bf7ed492088b2fc322cd0eecbbd629 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 30 Sep 2024 10:48:28 +0300 Subject: [PATCH] Fused y*unary(x) op: Metal --- ggml/src/ggml-metal.m | 32 +++++++++++++++++++++++++ ggml/src/ggml-metal.metal | 50 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index dcdd0efe..4badc7a7 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -56,13 +56,18 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_CLAMP, GGML_METAL_KERNEL_TYPE_TANH, GGML_METAL_KERNEL_TYPE_RELU, + GGML_METAL_KERNEL_TYPE_MUL_RELU, GGML_METAL_KERNEL_TYPE_SIGMOID, GGML_METAL_KERNEL_TYPE_GELU, GGML_METAL_KERNEL_TYPE_GELU_4, + GGML_METAL_KERNEL_TYPE_MUL_GELU, + GGML_METAL_KERNEL_TYPE_MUL_GELU_4, GGML_METAL_KERNEL_TYPE_GELU_QUICK, GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, GGML_METAL_KERNEL_TYPE_SILU, GGML_METAL_KERNEL_TYPE_SILU_4, + GGML_METAL_KERNEL_TYPE_MUL_SILU, + GGML_METAL_KERNEL_TYPE_MUL_SILU_4, GGML_METAL_KERNEL_TYPE_SWIGLU, GGML_METAL_KERNEL_TYPE_SWIGLU_4, GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, @@ -584,13 +589,18 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_RELU, mul_relu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_GELU, mul_gelu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_GELU_4, mul_gelu_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_SILU, mul_silu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_SILU_4, mul_silu_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_4, swiglu_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction); @@ -921,6 +931,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx case GGML_OP_SQR: case GGML_OP_SUM_ROWS: return true; + case GGML_OP_FUSED_MUL_UNARY: + return ggml_is_contiguous(op->src[0]); case GGML_OP_SOFTCAP: case GGML_OP_SOFT_CAP_MAX: return true; //ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op); @@ -1648,6 +1660,26 @@ static enum ggml_status ggml_metal_graph_compute( GGML_ABORT("fatal error"); } } break; + case GGML_OP_FUSED_MUL_UNARY: + { + int64_t n = ggml_nelements(dst); + enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0]; + id pipeline = nil; + if (n % 4 == 0 && op != GGML_UNARY_OP_RELU) { + pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU_4].pipeline + : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU_4].pipeline; + n /= 4; + } else { + pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU].pipeline + : op == GGML_UNARY_OP_SILU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU].pipeline + : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_RELU].pipeline; + } + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; case GGML_OP_SQR: { GGML_ASSERT(ggml_is_contiguous(src0)); diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 225fa5f1..4dbfa089 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -323,6 +323,14 @@ kernel void kernel_relu( dst[tpig] = max(0.0f, src0[tpig]); } +kernel void kernel_mul_relu( + device const float * src0, + device const float * src1, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = max(0.0f, src0[tpig]) * src1[tpig]; +} + kernel void kernel_sigmoid( device const float * src0, device float * dst, @@ -364,6 +372,30 @@ kernel void kernel_gelu_4( dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } +kernel void kernel_mul_gelu( + device const float * src0, + device const float * src1, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = 0.5f*x*src1[tpig]*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_mul_gelu_4( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + // BEWARE !!! + // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! + // This was observed with Falcon 7B and 40B models + // + dst[tpig] = 0.5f*x*src1[tpig]*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + kernel void kernel_gelu_quick( device const float * src0, device float * dst, @@ -398,6 +430,24 @@ kernel void kernel_silu_4( dst[tpig] = x / (1.0f + exp(-x)); } +kernel void kernel_mul_silu( + device const float * src0, + device const float * src1, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = x * src1[tpig] / (1.0f + exp(-x)); +} + +kernel void kernel_mul_silu_4( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + dst[tpig] = x * src1[tpig] / (1.0f + exp(-x)); +} + kernel void kernel_swiglu( device const float * src0, device float * dst,