Adding add_4, mul_4, div_4 kernels to Metal

This gives ~2% speedup for Bitnet on Metal
This commit is contained in:
Kawrakow
2024-06-24 10:22:10 +02:00
parent c9ddaf2fa3
commit f2a82090df
2 changed files with 51 additions and 0 deletions

View File

@@ -30,10 +30,13 @@ struct ggml_metal_kernel {
enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_ADD,
GGML_METAL_KERNEL_TYPE_ADD_4,
GGML_METAL_KERNEL_TYPE_ADD_ROW,
GGML_METAL_KERNEL_TYPE_MUL,
GGML_METAL_KERNEL_TYPE_MUL_4,
GGML_METAL_KERNEL_TYPE_MUL_ROW,
GGML_METAL_KERNEL_TYPE_DIV,
GGML_METAL_KERNEL_TYPE_DIV_4,
GGML_METAL_KERNEL_TYPE_DIV_ROW,
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
@@ -496,10 +499,13 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
// simd_sum and simd_max requires MTLGPUFamilyApple7
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_4, add_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_4, mul_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_4, div_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
@@ -1100,6 +1106,29 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
break;
}
else if (ggml_is_contiguous(dst->src[0]) && ggml_is_contiguous(dst->src[1]) && ggml_is_contiguous(dst) &&
dst->src[0]->ne[0] == dst->src[1]->ne[0] && dst->src[0]->ne[0] == dst->ne[0] &&
dst->src[0]->ne[1] == dst->src[1]->ne[1] && dst->src[0]->ne[1] == dst->ne[1] &&
dst->src[0]->ne[2] == dst->src[1]->ne[2] && dst->src[0]->ne[2] == dst->ne[2] &&
dst->src[0]->ne[3] == dst->src[1]->ne[3] && ggml_nelements(dst)%4 == 0) {
switch (dst->op) {
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_4].pipeline; break;
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_4].pipeline; break;
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_4].pipeline; break;
default: GGML_ASSERT(false);
}
int64_t n = ggml_nelements(dst)/4;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
break;
}
else if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
GGML_ASSERT(ggml_is_contiguous(src0));

View File

@@ -225,6 +225,13 @@ kernel void kernel_add_row(
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] + src1[tpig % nb];
}
kernel void kernel_add_4(
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] + src1[tpig];
}
kernel void kernel_mul_row(
device const float4 * src0,
@@ -235,6 +242,14 @@ kernel void kernel_mul_row(
dst[tpig] = src0[tpig] * src1[tpig % nb];
}
kernel void kernel_mul_4(
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * src1[tpig];
}
kernel void kernel_div_row(
device const float4 * src0,
device const float4 * src1,
@@ -243,6 +258,13 @@ kernel void kernel_div_row(
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] / src1[tpig % nb];
}
kernel void kernel_div_4(
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] / src1[tpig];
}
kernel void kernel_scale(
device const float * src0,