mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-10 00:10:13 +00:00
Fused y*unary(x) op: Metal
This commit is contained in:
@@ -56,13 +56,18 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_CLAMP,
|
||||
GGML_METAL_KERNEL_TYPE_TANH,
|
||||
GGML_METAL_KERNEL_TYPE_RELU,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_RELU,
|
||||
GGML_METAL_KERNEL_TYPE_SIGMOID,
|
||||
GGML_METAL_KERNEL_TYPE_GELU,
|
||||
GGML_METAL_KERNEL_TYPE_GELU_4,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_GELU,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_GELU_4,
|
||||
GGML_METAL_KERNEL_TYPE_GELU_QUICK,
|
||||
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
||||
GGML_METAL_KERNEL_TYPE_SILU,
|
||||
GGML_METAL_KERNEL_TYPE_SILU_4,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_SILU,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_SILU_4,
|
||||
GGML_METAL_KERNEL_TYPE_SWIGLU,
|
||||
GGML_METAL_KERNEL_TYPE_SWIGLU_4,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
|
||||
@@ -584,13 +589,18 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_RELU, mul_relu, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_GELU, mul_gelu, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_GELU_4, mul_gelu_4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_SILU, mul_silu, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_SILU_4, mul_silu_4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_4, swiglu_4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
|
||||
@@ -921,6 +931,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
return true;
|
||||
case GGML_OP_FUSED_MUL_UNARY:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_SOFTCAP:
|
||||
case GGML_OP_SOFT_CAP_MAX:
|
||||
return true; //ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op);
|
||||
@@ -1648,6 +1660,26 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_FUSED_MUL_UNARY:
|
||||
{
|
||||
int64_t n = ggml_nelements(dst);
|
||||
enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0];
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
if (n % 4 == 0 && op != GGML_UNARY_OP_RELU) {
|
||||
pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU_4].pipeline
|
||||
: ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU_4].pipeline;
|
||||
n /= 4;
|
||||
} else {
|
||||
pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU].pipeline
|
||||
: op == GGML_UNARY_OP_SILU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU].pipeline
|
||||
: ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_RELU].pipeline;
|
||||
}
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SQR:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
@@ -323,6 +323,14 @@ kernel void kernel_relu(
|
||||
dst[tpig] = max(0.0f, src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_mul_relu(
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = max(0.0f, src0[tpig]) * src1[tpig];
|
||||
}
|
||||
|
||||
kernel void kernel_sigmoid(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
@@ -364,6 +372,30 @@ kernel void kernel_gelu_4(
|
||||
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||
}
|
||||
|
||||
kernel void kernel_mul_gelu(
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float & x = src0[tpig];
|
||||
|
||||
dst[tpig] = 0.5f*x*src1[tpig]*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||
}
|
||||
|
||||
kernel void kernel_mul_gelu_4(
|
||||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float4 & x = src0[tpig];
|
||||
|
||||
// BEWARE !!!
|
||||
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
|
||||
// This was observed with Falcon 7B and 40B models
|
||||
//
|
||||
dst[tpig] = 0.5f*x*src1[tpig]*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||
}
|
||||
|
||||
kernel void kernel_gelu_quick(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
@@ -398,6 +430,24 @@ kernel void kernel_silu_4(
|
||||
dst[tpig] = x / (1.0f + exp(-x));
|
||||
}
|
||||
|
||||
kernel void kernel_mul_silu(
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float & x = src0[tpig];
|
||||
dst[tpig] = x * src1[tpig] / (1.0f + exp(-x));
|
||||
}
|
||||
|
||||
kernel void kernel_mul_silu_4(
|
||||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float4 & x = src0[tpig];
|
||||
dst[tpig] = x * src1[tpig] / (1.0f + exp(-x));
|
||||
}
|
||||
|
||||
kernel void kernel_swiglu(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
|
||||
Reference in New Issue
Block a user