mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 19:31:48 +00:00
GGML_UNARY_OP_SWIGLU: Metal implementation
We get ~2% speedup for PP-512(Phi-3.5-mini).
This commit is contained in:
@@ -63,6 +63,8 @@ enum ggml_metal_kernel_type {
|
|||||||
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
||||||
GGML_METAL_KERNEL_TYPE_SILU,
|
GGML_METAL_KERNEL_TYPE_SILU,
|
||||||
GGML_METAL_KERNEL_TYPE_SILU_4,
|
GGML_METAL_KERNEL_TYPE_SILU_4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_SWIGLU,
|
||||||
|
GGML_METAL_KERNEL_TYPE_SWIGLU_4,
|
||||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
|
||||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
|
||||||
@@ -583,6 +585,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_4, swiglu_4, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
|
||||||
@@ -884,6 +888,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
|||||||
case GGML_UNARY_OP_GELU:
|
case GGML_UNARY_OP_GELU:
|
||||||
case GGML_UNARY_OP_GELU_QUICK:
|
case GGML_UNARY_OP_GELU_QUICK:
|
||||||
case GGML_UNARY_OP_SILU:
|
case GGML_UNARY_OP_SILU:
|
||||||
|
case GGML_UNARY_OP_SWIGLU:
|
||||||
return ggml_is_contiguous(op->src[0]);
|
return ggml_is_contiguous(op->src[0]);
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
@@ -1595,6 +1600,30 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
|
} break;
|
||||||
|
case GGML_UNARY_OP_SWIGLU:
|
||||||
|
{
|
||||||
|
int64_t n = ggml_nelements(dst);
|
||||||
|
GGML_ASSERT(ne0 == src0->ne[0]/2);
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
|
if (ne0 % 4 == 0) {
|
||||||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_4].pipeline;
|
||||||
|
n /= 4;
|
||||||
|
} else {
|
||||||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t stride = src0->nb[1]/sizeof(float);
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:2];
|
||||||
|
[encoder setBytes:&stride length:sizeof(stride) atIndex:3];
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -398,6 +398,30 @@ kernel void kernel_silu_4(
|
|||||||
dst[tpig] = x / (1.0f + exp(-x));
|
dst[tpig] = x / (1.0f + exp(-x));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_swiglu(
|
||||||
|
device const float * src0,
|
||||||
|
device float * dst,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & stride,
|
||||||
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
|
const int64_t row = tpig/ne0;
|
||||||
|
const int64_t idx = tpig%ne0;
|
||||||
|
const int64_t j = row*stride + idx;
|
||||||
|
dst[tpig] = src0[j] * src0[j + ne0] / (1.0f + exp(-src0[j]));
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel void kernel_swiglu_4(
|
||||||
|
device const float4 * src0,
|
||||||
|
device float4 * dst,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & stride,
|
||||||
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
|
const int64_t row = tpig/(ne0/4);
|
||||||
|
const int64_t idx = tpig%(ne0/4);
|
||||||
|
const int64_t j = row*(stride/4) + idx;
|
||||||
|
dst[tpig] = src0[j] * src0[j + ne0/4] / (1.0f + exp(-src0[j]));
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_sqr(
|
kernel void kernel_sqr(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
|||||||
Reference in New Issue
Block a user