mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 19:31:48 +00:00
Adding add_4, mul_4, div_4 kernels to Metal
This gives ~2% speedup for Bitnet on Metal
This commit is contained in:
29
ggml-metal.m
29
ggml-metal.m
@@ -30,10 +30,13 @@ struct ggml_metal_kernel {
|
|||||||
|
|
||||||
enum ggml_metal_kernel_type {
|
enum ggml_metal_kernel_type {
|
||||||
GGML_METAL_KERNEL_TYPE_ADD,
|
GGML_METAL_KERNEL_TYPE_ADD,
|
||||||
|
GGML_METAL_KERNEL_TYPE_ADD_4,
|
||||||
GGML_METAL_KERNEL_TYPE_ADD_ROW,
|
GGML_METAL_KERNEL_TYPE_ADD_ROW,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL,
|
GGML_METAL_KERNEL_TYPE_MUL,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_4,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_ROW,
|
GGML_METAL_KERNEL_TYPE_MUL_ROW,
|
||||||
GGML_METAL_KERNEL_TYPE_DIV,
|
GGML_METAL_KERNEL_TYPE_DIV,
|
||||||
|
GGML_METAL_KERNEL_TYPE_DIV_4,
|
||||||
GGML_METAL_KERNEL_TYPE_DIV_ROW,
|
GGML_METAL_KERNEL_TYPE_DIV_ROW,
|
||||||
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
|
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
|
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
|
||||||
@@ -496,10 +499,13 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
||||||
|
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_4, add_4, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_4, mul_4, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_4, div_4, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
|
||||||
@@ -1100,6 +1106,29 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
else if (ggml_is_contiguous(dst->src[0]) && ggml_is_contiguous(dst->src[1]) && ggml_is_contiguous(dst) &&
|
||||||
|
dst->src[0]->ne[0] == dst->src[1]->ne[0] && dst->src[0]->ne[0] == dst->ne[0] &&
|
||||||
|
dst->src[0]->ne[1] == dst->src[1]->ne[1] && dst->src[0]->ne[1] == dst->ne[1] &&
|
||||||
|
dst->src[0]->ne[2] == dst->src[1]->ne[2] && dst->src[0]->ne[2] == dst->ne[2] &&
|
||||||
|
dst->src[0]->ne[3] == dst->src[1]->ne[3] && ggml_nelements(dst)%4 == 0) {
|
||||||
|
|
||||||
|
switch (dst->op) {
|
||||||
|
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_4].pipeline; break;
|
||||||
|
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_4].pipeline; break;
|
||||||
|
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_4].pipeline; break;
|
||||||
|
default: GGML_ASSERT(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t n = ggml_nelements(dst)/4;
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
|
break;
|
||||||
|
}
|
||||||
else if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
else if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
|
||||||
|
|||||||
@@ -225,6 +225,13 @@ kernel void kernel_add_row(
|
|||||||
uint tpig[[thread_position_in_grid]]) {
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
||||||
}
|
}
|
||||||
|
// Element-wise addition over float4 buffers.
//
// Each thread processes exactly one float4: it loads the element of src0 and
// the element of src1 found at its own grid position and stores their sum at
// the same position in dst.
//
//   src0 - first operand buffer (read only)
//   src1 - second operand buffer (read only)
//   dst  - output buffer
//   tpig - this thread's position in the grid, used directly as the index
//
// The kernel does no bounds checking on tpig.
kernel void kernel_add_4(
        device const float4 * src0,
        device const float4 * src1,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float4 lhs = src0[tpig];
    const float4 rhs = src1[tpig];

    dst[tpig] = lhs + rhs;
}
|
||||||
|
|
||||||
kernel void kernel_mul_row(
|
kernel void kernel_mul_row(
|
||||||
device const float4 * src0,
|
device const float4 * src0,
|
||||||
@@ -235,6 +242,14 @@ kernel void kernel_mul_row(
|
|||||||
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Element-wise multiplication over float4 buffers.
//
// Each thread processes exactly one float4: it loads the element of src0 and
// the element of src1 found at its own grid position and stores their
// component-wise product at the same position in dst.
//
//   src0 - first operand buffer (read only)
//   src1 - second operand buffer (read only)
//   dst  - output buffer
//   tpig - this thread's position in the grid, used directly as the index
//
// The kernel does no bounds checking on tpig.
kernel void kernel_mul_4(
        device const float4 * src0,
        device const float4 * src1,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float4 lhs = src0[tpig];
    const float4 rhs = src1[tpig];

    dst[tpig] = lhs * rhs;
}
|
||||||
|
|
||||||
kernel void kernel_div_row(
|
kernel void kernel_div_row(
|
||||||
device const float4 * src0,
|
device const float4 * src0,
|
||||||
device const float4 * src1,
|
device const float4 * src1,
|
||||||
@@ -243,6 +258,13 @@ kernel void kernel_div_row(
|
|||||||
uint tpig[[thread_position_in_grid]]) {
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
dst[tpig] = src0[tpig] / src1[tpig % nb];
|
dst[tpig] = src0[tpig] / src1[tpig % nb];
|
||||||
}
|
}
|
||||||
|
// Element-wise division over float4 buffers.
//
// Each thread processes exactly one float4: it loads the element of src0
// (numerator) and the element of src1 (denominator) found at its own grid
// position and stores their component-wise quotient at the same position in
// dst.
//
//   src0 - numerator buffer (read only)
//   src1 - denominator buffer (read only)
//   dst  - output buffer
//   tpig - this thread's position in the grid, used directly as the index
//
// The kernel does no bounds checking on tpig and does not guard against zero
// denominators.
kernel void kernel_div_4(
        device const float4 * src0,
        device const float4 * src1,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float4 numer = src0[tpig];
    const float4 denom = src1[tpig];

    dst[tpig] = numer / denom;
}
|
||||||
|
|
||||||
kernel void kernel_scale(
|
kernel void kernel_scale(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
|
|||||||
Reference in New Issue
Block a user