mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-22 07:29:23 +00:00
multi_add: Metal
This commit is contained in:
@@ -39,6 +39,8 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_ADD,
|
||||
GGML_METAL_KERNEL_TYPE_ADD_4,
|
||||
GGML_METAL_KERNEL_TYPE_ADD_ROW,
|
||||
GGML_METAL_KERNEL_TYPE_MULTI_ADD,
|
||||
GGML_METAL_KERNEL_TYPE_MULTI_ADD_4,
|
||||
GGML_METAL_KERNEL_TYPE_MUL,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_4,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_ROW,
|
||||
@@ -577,6 +579,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_4, add_4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MULTI_ADD, multi_add, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MULTI_ADD_4, multi_add_4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_4, mul_4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
|
||||
@@ -932,6 +936,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_CONCAT:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_MULTI_ADD:
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
@@ -1349,6 +1354,36 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_MULTI_ADD:
|
||||
{
|
||||
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dstt == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ne02 == 1 && ne03 == 1);
|
||||
GGML_ASSERT(nb0 == sizeof(float) && nb00 == sizeof(float));
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
|
||||
int n_expert = dst->op_params[0];
|
||||
GGML_ASSERT(n_expert >= 2);
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
int64_t n = ne0*ne1;
|
||||
if (ne0%4 == 0) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD_4].pipeline;
|
||||
n /= 4;
|
||||
} else {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD].pipeline;
|
||||
}
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:2];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:3];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:4];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
||||
[encoder setBytes:&n_expert length:sizeof(n_expert) atIndex:6];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_REPEAT:
|
||||
{
|
||||
id<MTLComputePipelineState> pipeline;
|
||||
|
||||
@@ -479,6 +479,44 @@ kernel void kernel_sqr(
|
||||
dst[tpig] = src0[tpig] * src0[tpig];
|
||||
}
|
||||
|
||||
kernel void kernel_multi_add_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & nb1,
|
||||
constant int64_t & nb01,
|
||||
constant int & n_expert,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
|
||||
int64_t i0 = tpig % (ne0/4);
|
||||
int64_t i1 = tpig / (ne0/4);
|
||||
device float4 * dst_ptr = dst + i1*(nb1/16) + i0;
|
||||
device const float4 * src_ptr = src0 + i1*(nb01/16) + i0;
|
||||
float4 sum = src_ptr[0] + src_ptr[ne0/4];
|
||||
for (int i = 2; i < n_expert; ++i) sum += src_ptr[i*ne0/4];
|
||||
dst_ptr[0] = sum;
|
||||
}
|
||||
|
||||
kernel void kernel_multi_add(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & nb1,
|
||||
constant int64_t & nb01,
|
||||
constant int & n_expert,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
|
||||
int64_t i0 = tpig % ne0;
|
||||
int64_t i1 = tpig / ne0;
|
||||
device float * dst_ptr = dst + i1*nb1/4 + i0;
|
||||
device const float * src_ptr = src0 + i1*nb01/4 + i0;
|
||||
float sum = src_ptr[0] + src_ptr[ne0];
|
||||
for (int i = 2; i < n_expert; ++i) sum += src_ptr[i*ne0];
|
||||
dst_ptr[0] = sum;
|
||||
}
|
||||
|
||||
kernel void kernel_sum_rows(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
|
||||
Reference in New Issue
Block a user