mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-12 23:10:01 +00:00
Adding add_4, mul_4, div_4 kernels to Metal
This gives ~2% speedup for Bitnet on Metal
This commit is contained in:
29
ggml-metal.m
29
ggml-metal.m
@@ -30,10 +30,13 @@ struct ggml_metal_kernel {
|
||||
|
||||
enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_ADD,
|
||||
GGML_METAL_KERNEL_TYPE_ADD_4,
|
||||
GGML_METAL_KERNEL_TYPE_ADD_ROW,
|
||||
GGML_METAL_KERNEL_TYPE_MUL,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_4,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_ROW,
|
||||
GGML_METAL_KERNEL_TYPE_DIV,
|
||||
GGML_METAL_KERNEL_TYPE_DIV_4,
|
||||
GGML_METAL_KERNEL_TYPE_DIV_ROW,
|
||||
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
|
||||
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
|
||||
@@ -496,10 +499,13 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
||||
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_4, add_4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_4, mul_4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_4, div_4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
|
||||
@@ -1100,6 +1106,29 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
break;
|
||||
}
|
||||
else if (ggml_is_contiguous(dst->src[0]) && ggml_is_contiguous(dst->src[1]) && ggml_is_contiguous(dst) &&
|
||||
dst->src[0]->ne[0] == dst->src[1]->ne[0] && dst->src[0]->ne[0] == dst->ne[0] &&
|
||||
dst->src[0]->ne[1] == dst->src[1]->ne[1] && dst->src[0]->ne[1] == dst->ne[1] &&
|
||||
dst->src[0]->ne[2] == dst->src[1]->ne[2] && dst->src[0]->ne[2] == dst->ne[2] &&
|
||||
dst->src[0]->ne[3] == dst->src[1]->ne[3] && ggml_nelements(dst)%4 == 0) {
|
||||
|
||||
switch (dst->op) {
|
||||
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_4].pipeline; break;
|
||||
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_4].pipeline; break;
|
||||
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_4].pipeline; break;
|
||||
default: GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
int64_t n = ggml_nelements(dst)/4;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
break;
|
||||
}
|
||||
else if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
|
||||
@@ -225,6 +225,13 @@ kernel void kernel_add_row(
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
||||
}
|
||||
kernel void kernel_add_4(
|
||||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] + src1[tpig];
|
||||
}
|
||||
|
||||
kernel void kernel_mul_row(
|
||||
device const float4 * src0,
|
||||
@@ -235,6 +242,14 @@ kernel void kernel_mul_row(
|
||||
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
||||
}
|
||||
|
||||
kernel void kernel_mul_4(
|
||||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] * src1[tpig];
|
||||
}
|
||||
|
||||
kernel void kernel_div_row(
|
||||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
@@ -243,6 +258,13 @@ kernel void kernel_div_row(
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] / src1[tpig % nb];
|
||||
}
|
||||
kernel void kernel_div_4(
|
||||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] / src1[tpig];
|
||||
}
|
||||
|
||||
kernel void kernel_scale(
|
||||
device const float * src0,
|
||||
|
||||
Reference in New Issue
Block a user