From 2a5552830b38e39bdb75ce7f650fec299335f4a0 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 2 Apr 2025 15:26:19 +0200 Subject: [PATCH] MoE improvements on Metal This version beats mainline, there are things I don't understand: * Mianline has effectively gone to GEMV for MUL_MAT_ID. We can do the same, but we are 30% slower. Why? * Using actual GEMM, we beat mainline with ubtach size of 128. But then performance degrades. Why? --- ggml/src/ggml-metal.m | 6756 +++++++++++++++++++++++++------------ ggml/src/ggml-metal.metal | 147 +- 2 files changed, 4756 insertions(+), 2147 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 0498be1f..007a13f3 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -297,8 +297,9 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_COUNT }; +#define GGML_METAL_MAX_COMMAND_BUFFERS 8 + struct ggml_backend_metal_context { - int n_cb; id device; id queue; @@ -307,6 +308,26 @@ struct ggml_backend_metal_context { struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT]; + // capture state + bool capture_next_compute; + bool capture_started; + + id capture_scope; + + // command buffer state + int n_cb; // number of extra threads used to submit the command buffers + int n_nodes_0; // number of nodes submitted by the main thread + int n_nodes_1; // remaining number of nodes submitted by the n_cb threads + int n_nodes_per_cb; + + struct ggml_cgraph * gf; + + // the callback given to the thread pool + void (^encode_async)(size_t ith); + + // n_cb command buffers + 1 used by the main thread + id command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1]; + bool support_simdgroup_reduction; bool support_simdgroup_mm; @@ -373,7 +394,6 @@ static void * ggml_metal_host_malloc(size_t n) { const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n); if (result != 0) { GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__); - return NULL; } #endif @@ -533,6 +553,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); ctx->should_capture_next_compute = false; + ctx->capture_started = false; + ctx->capture_scope = nil; + + ctx->gf = nil; + ctx->encode_async = nil; + for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { + ctx->command_buffers[i] = nil; + } #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) if (@available(macOS 10.12, iOS 16.0, *)) { @@ -846,6 +874,8 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { [ctx->kernels[i].pipeline release]; } + Block_release(ctx->encode_async); + [ctx->queue release]; [ctx->device release]; @@ -1027,934 +1057,871 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx } } -static enum ggml_status ggml_metal_graph_compute( +static void ggml_metal_encode_node( struct ggml_backend_metal_context * ctx, - struct ggml_cgraph * gf) { + struct ggml_tensor * node, + id encoder) { - @autoreleasepool { - MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; - edesc.dispatchType = MTLDispatchTypeSerial; - // create multiple command buffers and enqueue them - // then, we encode the graph into the command buffers in parallel + struct ggml_tensor * src0 = node->src[0]; + struct ggml_tensor * src1 = node->src[1]; + struct ggml_tensor * src2 = node->src[2]; + struct ggml_tensor * dst = node; - const int n_nodes = gf->n_nodes; - const int n_cb = ctx->n_cb; - const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; - - const bool should_capture = ctx->should_capture_next_compute; - if (should_capture) { - ctx->should_capture_next_compute = false; - - MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; - descriptor.captureObject = ctx->queue; - - NSError * error = nil; - if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { - GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]); - GGML_ABORT("capture failed"); - } + if (ggml_is_empty(dst)) { + return; } - id command_buffer_builder[n_cb]; - for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { - id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; - command_buffer_builder[cb_idx] = command_buffer; - - // always enqueue the first two command buffers - // enqueue all of the command buffers if we don't need to abort - if (cb_idx < 2 || ctx->abort_callback == NULL) { - [command_buffer enqueue]; - } + switch (dst->op) { + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_TRANSPOSE: + case GGML_OP_PERMUTE: return; // noop + default: break; } - const id *command_buffers = command_buffer_builder; + if (!ggml_metal_supports_op(ctx, dst)) { + GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst)); + GGML_ABORT("unsupported op"); + } - dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) { - const int cb_idx = iter; + const int64_t ne00 = src0 ? src0->ne[0] : 0; + const int64_t ne01 = src0 ? src0->ne[1] : 0; + const int64_t ne02 = src0 ? src0->ne[2] : 0; + const int64_t ne03 = src0 ? src0->ne[3] : 0; - size_t offs_src0 = 0; - size_t offs_src1 = 0; - size_t offs_src2 = 0; - size_t offs_dst = 0; + const uint64_t nb00 = src0 ? src0->nb[0] : 0; + const uint64_t nb01 = src0 ? src0->nb[1] : 0; + const uint64_t nb02 = src0 ? src0->nb[2] : 0; + const uint64_t nb03 = src0 ? src0->nb[3] : 0; - id command_buffer = command_buffers[cb_idx]; - id encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; + const int64_t ne10 = src1 ? src1->ne[0] : 0; + const int64_t ne11 = src1 ? src1->ne[1] : 0; + const int64_t ne12 = src1 ? src1->ne[2] : 0; + const int64_t ne13 = src1 ? src1->ne[3] : 0; - const int node_start = (cb_idx + 0) * n_nodes_per_cb; - const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes); + const uint64_t nb10 = src1 ? src1->nb[0] : 0; + const uint64_t nb11 = src1 ? src1->nb[1] : 0; + const uint64_t nb12 = src1 ? src1->nb[2] : 0; + const uint64_t nb13 = src1 ? src1->nb[3] : 0; - for (int i = node_start; i < node_end; ++i) { - if (i == -1) { - [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]; - continue; - } + const int64_t ne20 = src2 ? src2->ne[0] : 0; + const int64_t ne21 = src2 ? src2->ne[1] : 0; + const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22); + const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); - //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op)); + const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); + const uint64_t nb21 = src2 ? src2->nb[1] : 0; + const uint64_t nb22 = src2 ? src2->nb[2] : 0; + const uint64_t nb23 = src2 ? src2->nb[3] : 0; - struct ggml_tensor * src0 = gf->nodes[i]->src[0]; - struct ggml_tensor * src1 = gf->nodes[i]->src[1]; - struct ggml_tensor * src2 = gf->nodes[i]->src[2]; - struct ggml_tensor * dst = gf->nodes[i]; + const int64_t ne0 = dst ? dst->ne[0] : 0; + const int64_t ne1 = dst ? dst->ne[1] : 0; + const int64_t ne2 = dst ? dst->ne[2] : 0; + const int64_t ne3 = dst ? dst->ne[3] : 0; - if (ggml_is_empty(dst)) { - continue; - } + const uint64_t nb0 = dst ? dst->nb[0] : 0; + const uint64_t nb1 = dst ? dst->nb[1] : 0; + const uint64_t nb2 = dst ? dst->nb[2] : 0; + const uint64_t nb3 = dst ? dst->nb[3] : 0; - switch (dst->op) { - case GGML_OP_NONE: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_TRANSPOSE: - case GGML_OP_PERMUTE: - { - // noop -> next node - } continue; + const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; + const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; + + size_t offs_src0 = 0; + size_t offs_src1 = 0; + size_t offs_src2 = 0; + size_t offs_dst = 0; + + id id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil; + id id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil; + id id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; + id id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; + + //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); + //if (src0) { + // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, + // ggml_is_contiguous(src0), src0->name); + //} + //if (src1) { + // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, + // ggml_is_contiguous(src1), src1->name); + //} + //if (dst) { + // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, + // dst->name); + //} + + switch (dst->op) { + case GGML_OP_CONCAT: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; + + const int32_t dim = ((int32_t *) dst->op_params)[0]; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + [encoder setBytes:&dim length:sizeof(dim) atIndex:27]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ADD: + case GGML_OP_MUL: + case GGML_OP_DIV: + { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + + const size_t offs = 0; + + bool bcast_row = false; + + int64_t nb = ne00; // used by the "row" kernels + + id pipeline = nil; + + if (dst->op == GGML_OP_MUL && ggml_nelements(src1) == 1 && ggml_is_contiguous(src0)) { + float scale; + memcpy(&scale, src1->data, sizeof(float)); + //printf("Replacing op_mul with op_scale. scale = %g\n", (double)scale); + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; + + int64_t n = ggml_nelements(dst); + + if (n % 4 == 0) { + n /= 4; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + break; + } + else if (ggml_is_contiguous(dst->src[0]) && ggml_is_contiguous(dst->src[1]) && ggml_is_contiguous(dst) && + dst->src[0]->ne[0] == dst->src[1]->ne[0] && dst->src[0]->ne[0] == dst->ne[0] && + dst->src[0]->ne[1] == dst->src[1]->ne[1] && dst->src[0]->ne[1] == dst->ne[1] && + dst->src[0]->ne[2] == dst->src[1]->ne[2] && dst->src[0]->ne[2] == dst->ne[2] && + dst->src[0]->ne[3] == dst->src[1]->ne[3] && ggml_nelements(dst)%4 == 0) { + + switch (dst->op) { + case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_4].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_4].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_4].pipeline; break; + default: GGML_ASSERT(false); + } + + int64_t n = ggml_nelements(dst)/4; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + break; + } + else if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + nb = ne00 / 4; + switch (dst->op) { + case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; + default: GGML_ABORT("fatal error"); + } + + bcast_row = true; + } else { + switch (dst->op) { + case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; + default: GGML_ABORT("fatal error"); + } + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + [encoder setBytes:&nb length:sizeof(nb) atIndex:28]; + + if (bcast_row) { + const int64_t n = ggml_nelements(dst)/4; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } else { + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } + } break; + case GGML_OP_MULTI_ADD: + { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(dstt == GGML_TYPE_F32); + GGML_ASSERT(ne02 == 1 && ne03 == 1); + GGML_ASSERT(nb0 == sizeof(float) && nb00 == sizeof(float)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + int n_expert = dst->op_params[0]; + GGML_ASSERT(n_expert >= 2); + + id pipeline = nil; + int64_t n = ne0*ne1; + if (ne0%4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD].pipeline; + } + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:2]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:3]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:4]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; + [encoder setBytes:&n_expert length:sizeof(n_expert) atIndex:6]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_REPEAT: + { + id pipeline; + + switch (src0t) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break; + case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break; + case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break; + default: GGML_ABORT("fatal error"); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ACC: + { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + GGML_ASSERT(dstt == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + const size_t pnb1 = ((int32_t *) dst->op_params)[0]; + const size_t pnb2 = ((int32_t *) dst->op_params)[1]; + const size_t pnb3 = ((int32_t *) dst->op_params)[2]; + const size_t offs = ((int32_t *) dst->op_params)[3]; + + const bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + + if (!inplace) { + // run a separete kernel to cpy src->dst + // not sure how to avoid this + // TODO: make a simpler cpy_bytes kernel + + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } + + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8]; + [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9]; + [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24]; + [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25]; + [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; + [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); + + [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_SCALE: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + float scale; + memcpy(&scale, dst->op_params, sizeof(scale)); + + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + n /= 4; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SOFTCAP: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + float scales[2]; + memcpy(scales, dst->op_params, sizeof(scales)); + + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + n /= 4; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&scales[0] length:sizeof(float) atIndex:2]; + [encoder setBytes:&scales[1] length:sizeof(float) atIndex:3]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_CLAMP: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline; + + float min; + float max; + memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float)); + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&min length:sizeof(min) atIndex:2]; + [encoder setBytes:&max length:sizeof(max) atIndex:3]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(node)) { + // we are not taking into account the strides, so for now require contiguous tensors + GGML_ASSERT(ggml_is_contiguous(src0)); + + case GGML_UNARY_OP_TANH: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_RELU: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_SIGMOID: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_GELU: + { + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_GELU_QUICK: + { + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_SILU: + { + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_SWIGLU: + { + int64_t n = ggml_nelements(dst); + GGML_ASSERT(ne0 == src0->ne[0]/2); + + id pipeline = nil; + + uint32_t n_per_row = ne0; + uint32_t stride = src0->nb[1]/sizeof(float); + + if (ne0 % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_4].pipeline; + n /= 4; + n_per_row /= 4; + stride /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&n_per_row length:sizeof(n_per_row) atIndex:2]; + [encoder setBytes:&stride length:sizeof(stride) atIndex:3]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; default: - { - } break; - } - - if (!ggml_metal_supports_op(ctx, dst)) { - GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst)); - GGML_ABORT("unsupported op"); - } - - if (should_capture) { - [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]]; - } - - const int64_t ne00 = src0 ? src0->ne[0] : 0; - const int64_t ne01 = src0 ? src0->ne[1] : 0; - const int64_t ne02 = src0 ? src0->ne[2] : 0; - const int64_t ne03 = src0 ? src0->ne[3] : 0; - - const uint64_t nb00 = src0 ? src0->nb[0] : 0; - const uint64_t nb01 = src0 ? src0->nb[1] : 0; - const uint64_t nb02 = src0 ? src0->nb[2] : 0; - const uint64_t nb03 = src0 ? src0->nb[3] : 0; - - const int64_t ne10 = src1 ? src1->ne[0] : 0; - const int64_t ne11 = src1 ? src1->ne[1] : 0; - const int64_t ne12 = src1 ? src1->ne[2] : 0; - const int64_t ne13 = src1 ? src1->ne[3] : 0; - - const uint64_t nb10 = src1 ? src1->nb[0] : 0; - const uint64_t nb11 = src1 ? src1->nb[1] : 0; - const uint64_t nb12 = src1 ? src1->nb[2] : 0; - const uint64_t nb13 = src1 ? src1->nb[3] : 0; - - const int64_t ne20 = src2 ? src2->ne[0] : 0; - const int64_t ne21 = src2 ? src2->ne[1] : 0; - const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22); - const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); - - const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); - const uint64_t nb21 = src2 ? src2->nb[1] : 0; - const uint64_t nb22 = src2 ? src2->nb[2] : 0; - const uint64_t nb23 = src2 ? src2->nb[3] : 0; - - const int64_t ne0 = dst ? dst->ne[0] : 0; - const int64_t ne1 = dst ? dst->ne[1] : 0; - const int64_t ne2 = dst ? dst->ne[2] : 0; - const int64_t ne3 = dst ? dst->ne[3] : 0; - - const uint64_t nb0 = dst ? dst->nb[0] : 0; - const uint64_t nb1 = dst ? dst->nb[1] : 0; - const uint64_t nb2 = dst ? dst->nb[2] : 0; - const uint64_t nb3 = dst ? dst->nb[3] : 0; - - const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; - const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; - const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; - - id id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil; - id id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil; - id id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; - id id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; - - //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); - //if (src0) { - // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, - // ggml_is_contiguous(src0), src0->name); - //} - //if (src1) { - // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, - // ggml_is_contiguous(src1), src1->name); - //} - //if (dst) { - // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, - // dst->name); - //} - - switch (dst->op) { - case GGML_OP_CONCAT: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; - - const int32_t dim = ((int32_t *) dst->op_params)[0]; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&dim length:sizeof(dim) atIndex:27]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ADD: - case GGML_OP_MUL: - case GGML_OP_DIV: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(src1t == GGML_TYPE_F32); - - const size_t offs = 0; - - bool bcast_row = false; - - int64_t nb = ne00; // used by the "row" kernels - - id pipeline = nil; - - if (dst->op == GGML_OP_MUL && ggml_nelements(src1) == 1 && ggml_is_contiguous(src0)) { - float scale; - memcpy(&scale, src1->data, sizeof(float)); - //printf("Replacing op_mul with op_scale. scale = %g\n", (double)scale); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; - - int64_t n = ggml_nelements(dst); - - if (n % 4 == 0) { - n /= 4; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - break; - } - else if (ggml_is_contiguous(dst->src[0]) && ggml_is_contiguous(dst->src[1]) && ggml_is_contiguous(dst) && - dst->src[0]->ne[0] == dst->src[1]->ne[0] && dst->src[0]->ne[0] == dst->ne[0] && - dst->src[0]->ne[1] == dst->src[1]->ne[1] && dst->src[0]->ne[1] == dst->ne[1] && - dst->src[0]->ne[2] == dst->src[1]->ne[2] && dst->src[0]->ne[2] == dst->ne[2] && - dst->src[0]->ne[3] == dst->src[1]->ne[3] && ggml_nelements(dst)%4 == 0) { - - switch (dst->op) { - case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_4].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_4].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_4].pipeline; break; - default: GGML_ASSERT(false); - } - - int64_t n = ggml_nelements(dst)/4; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - break; - } - else if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { - GGML_ASSERT(ggml_is_contiguous(src0)); - - // src1 is a row - GGML_ASSERT(ne11 == 1); - - nb = ne00 / 4; - switch (dst->op) { - case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; - default: GGML_ABORT("fatal error"); - } - - bcast_row = true; - } else { - switch (dst->op) { - case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; - default: GGML_ABORT("fatal error"); - } - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:28]; - - if (bcast_row) { - const int64_t n = ggml_nelements(dst)/4; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } else { - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } - } break; - case GGML_OP_MULTI_ADD: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(dstt == GGML_TYPE_F32); - GGML_ASSERT(ne02 == 1 && ne03 == 1); - GGML_ASSERT(nb0 == sizeof(float) && nb00 == sizeof(float)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - int n_expert = dst->op_params[0]; - GGML_ASSERT(n_expert >= 2); - - id pipeline = nil; - int64_t n = ne0*ne1; - if (ne0%4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD].pipeline; - } - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:2]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:3]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:4]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; - [encoder setBytes:&n_expert length:sizeof(n_expert) atIndex:6]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_REPEAT: - { - id pipeline; - - switch (src0t) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break; - case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break; - case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break; - default: GGML_ABORT("fatal error"); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ACC: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(src1t == GGML_TYPE_F32); - GGML_ASSERT(dstt == GGML_TYPE_F32); - - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - - const size_t pnb1 = ((int32_t *) dst->op_params)[0]; - const size_t pnb2 = ((int32_t *) dst->op_params)[1]; - const size_t pnb3 = ((int32_t *) dst->op_params)[2]; - const size_t offs = ((int32_t *) dst->op_params)[3]; - - const bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - - if (!inplace) { - // run a separete kernel to cpy src->dst - // not sure how to avoid this - // TODO: make a simpler cpy_bytes kernel - - const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } - - const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); - - [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_SCALE: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - float scale; - memcpy(&scale, dst->op_params, sizeof(scale)); - - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - n /= 4; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SOFTCAP: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - float scales[2]; - memcpy(scales, dst->op_params, sizeof(scales)); - - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - n /= 4; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&scales[0] length:sizeof(float) atIndex:2]; - [encoder setBytes:&scales[1] length:sizeof(float) atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_CLAMP: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline; - - float min; - float max; - memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float)); - memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float)); - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&min length:sizeof(min) atIndex:2]; - [encoder setBytes:&max length:sizeof(max) atIndex:3]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_UNARY: - switch (ggml_get_unary_op(gf->nodes[i])) { - // we are not taking into account the strides, so for now require contiguous tensors - GGML_ASSERT(ggml_is_contiguous(src0)); - - case GGML_UNARY_OP_TANH: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_RELU: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_SIGMOID: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_GELU: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_GELU_QUICK: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_SILU: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_SWIGLU: - { - int64_t n = ggml_nelements(dst); - GGML_ASSERT(ne0 == src0->ne[0]/2); - - id pipeline = nil; - - uint32_t n_per_row = ne0; - uint32_t stride = src0->nb[1]/sizeof(float); - - if (ne0 % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_4].pipeline; - n /= 4; - n_per_row /= 4; - stride /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&n_per_row length:sizeof(n_per_row) atIndex:2]; - [encoder setBytes:&stride length:sizeof(stride) atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - default: - { - GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); - GGML_ABORT("fatal error"); - } - } break; - case GGML_OP_FUSED_MUL_UNARY: - { - int64_t n = ggml_nelements(dst); - enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0]; - id pipeline = nil; - if (n % 4 == 0 && op != GGML_UNARY_OP_RELU) { - pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU_4].pipeline - : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU_4].pipeline; - n /= 4; - } else { - pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU].pipeline - : op == GGML_UNARY_OP_SILU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU].pipeline - : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_RELU].pipeline; - } - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SQR: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SUM_ROWS: - { - GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SOFT_MAX: - { - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); - - int nth = 32; // SIMD width - - id pipeline = nil; - - const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); - - if (ne00%4 == 0) { - while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - if (use_f16) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; - } - } else { - while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - if (use_f16) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; - } - } - - float scale; - float max_bias; - - memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); - memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); - - const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src0->ne[1]; - - const uint32_t n_head = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - if (id_src1) { - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; - [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_SOFT_CAP_MAX: - { - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); - - int nth = 32; // SIMD width - - id pipeline = nil; - - const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); - - if (ne00%4 == 0) { - while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - if (use_f16) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4].pipeline; - } - } else { - while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - if (use_f16) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32].pipeline; - } - } - - float scale; - float max_bias; - float s_before; - float s_after; - - memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); - memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); - memcpy(&s_before, ((int32_t *) dst->op_params) + 2, sizeof(s_before)); - memcpy(&s_after, ((int32_t *) dst->op_params) + 3, sizeof(s_after)); - - const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src0->ne[1]; - - const uint32_t n_head = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - if (id_src1) { - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; - [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; - [encoder setBytes:&s_before length:sizeof(s_before) atIndex:10]; - [encoder setBytes:&s_after length:sizeof(s_after ) atIndex:11]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:12]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_DIAG_MASK_INF: - { - const int n_past = ((int32_t *)(dst->op_params))[0]; - - id pipeline = nil; - - if (ne00%8 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; - - if (ne00%8 == 0) { - [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } - else { - [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } - } break; - case GGML_OP_MUL_MAT: - { - GGML_ASSERT(ne00 == ne10); - - GGML_ASSERT(ne12 % ne02 == 0); - GGML_ASSERT(ne13 % ne03 == 0); - - const uint r2 = ne12/ne02; - const uint r3 = ne13/ne03; - - // find the break-even point where the matrix-matrix kernel becomes more efficient compared - // to the matrix-vector kernel - int ne11_mm_min = 1; + { + GGML_METAL_LOG_WARN("%s: node %s, op = %8s not implemented\n", __func__, dst->name, ggml_op_name(dst->op)); + GGML_ABORT("fatal error"); + } + } break; + case GGML_OP_FUSED_MUL_UNARY: + { + int64_t n = ggml_nelements(dst); + enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0]; + id pipeline = nil; + if (n % 4 == 0 && op != GGML_UNARY_OP_RELU) { + pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU_4].pipeline + : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU_4].pipeline; + n /= 4; + } else { + pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU].pipeline + : op == GGML_UNARY_OP_SILU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU].pipeline + : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_RELU].pipeline; + } + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SQR: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SUM_ROWS: + { + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SOFT_MAX: + { + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); + + int nth = 32; // SIMD width + + id pipeline = nil; + + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + + if (ne00%4 == 0) { + while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; + } + } else { + while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; + } + } + + float scale; + float max_bias; + + memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); + + const int64_t nrows_x = ggml_nrows(src0); + const int64_t nrows_y = src0->ne[1]; + + const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_src1) { + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; + [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_SOFT_CAP_MAX: + { + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); + + int nth = 32; // SIMD width + + id pipeline = nil; + + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + + if (ne00%4 == 0) { + while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4].pipeline; + } + } else { + while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32].pipeline; + } + } + + float scale; + float max_bias; + float s_before; + float s_after; + + memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); + memcpy(&s_before, ((int32_t *) dst->op_params) + 2, sizeof(s_before)); + memcpy(&s_after, ((int32_t *) dst->op_params) + 3, sizeof(s_after)); + + const int64_t nrows_x = ggml_nrows(src0); + const int64_t nrows_y = src0->ne[1]; + + const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_src1) { + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; + [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; + [encoder setBytes:&s_before length:sizeof(s_before) atIndex:10]; + [encoder setBytes:&s_after length:sizeof(s_after ) atIndex:11]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:12]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_DIAG_MASK_INF: + { + const int n_past = ((int32_t *)(dst->op_params))[0]; + + id pipeline = nil; + + if (ne00%8 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; + + if (ne00%8 == 0) { + [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } + else { + [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } + } break; + case GGML_OP_MUL_MAT: + { + GGML_ASSERT(ne00 == ne10); + + GGML_ASSERT(ne12 % ne02 == 0); + GGML_ASSERT(ne13 % ne03 == 0); + + const uint r2 = ne12/ne02; + const uint r3 = ne13/ne03; + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + int ne11_mm_min = 4; #if 0 - // the numbers below are measured on M2 Ultra for 7B and 13B models - // these numbers do not translate to other devices or model sizes - // TODO: need to find a better approach + // the numbers below are measured on M2 Ultra for 7B and 13B models + // these numbers do not translate to other devices or model sizes + // TODO: need to find a better approach if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) { switch (src0t) { case GGML_TYPE_F16: ne11_mm_min = 2; break; @@ -1976,11 +1943,11 @@ static enum ggml_status ggml_metal_graph_compute( // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && - !ggml_is_transposed(src0) && - !ggml_is_transposed(src1) && - src1t == GGML_TYPE_F32 && - ne00 % 32 == 0 && ne00 >= 64 && - (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) { + !ggml_is_transposed(src0) && + !ggml_is_transposed(src1) && + src1t == GGML_TYPE_F32 && + ne00 % 32 == 0 && ne00 >= 64 && + (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) { //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); // some Metal matrix data types require aligned pointers @@ -2305,9 +2272,9 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&r3 length:sizeof(r3) atIndex:18]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S|| - src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_Q6_0) { + src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || + src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S|| + src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_Q6_0) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) { @@ -2326,8 +2293,8 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K || - src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS|| - src0t == GGML_TYPE_IQ4_KSS) { + src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS|| + src0t == GGML_TYPE_IQ4_KSS) { const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float) : GGML_TYPE_IQ5_K ? 64*sizeof(float) : 32*sizeof(float); [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -2348,1182 +2315,3765 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } } - } break; - case GGML_OP_MUL_MAT_ID: - { - const int n_as = src0->ne[2]; - - // src2 = ids - const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); - - GGML_ASSERT(src2t == GGML_TYPE_I32); - - GGML_ASSERT(!ggml_is_transposed(src0)); - GGML_ASSERT(!ggml_is_transposed(src1)); - - GGML_ASSERT(src1t == GGML_TYPE_F32); - - // find the break-even point where the matrix-matrix kernel becomes more efficient compared - // to the matrix-vector kernel - // ne20 = n_used_experts - // ne21 = n_rows - const int dst_rows = ne20*ne21; - const int dst_rows_min = n_as; - const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4; - - // max size of the rowids array in the kernel shared buffer - GGML_ASSERT(dst_rows <= dst_rows_max); - - // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs - // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - // !!! - // TODO: for now, always use mat-vec kernels until we figure out how to improve the - // indirect matrix multiplication - // !!! - if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && - ne00 % 32 == 0 && ne00 >= 64 && - dst_rows > dst_rows_min) { - - // some Metal matrix data types require aligned pointers - // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) - switch (src0->type) { - case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; - case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; - default: break; - } - - id pipeline = nil; - - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break; - case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_0_F32 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break; - case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32 ].pipeline; break; - case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32 ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; - case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32 ].pipeline; break; - case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KSS_F32].pipeline; break; - case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32 ].pipeline; break; - case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KS_F32 ].pipeline; break; - case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_K_F32 ].pipeline; break; - case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32 ].pipeline; break; - case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_K_F32 ].pipeline; break; - case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ6_K_F32 ].pipeline; break; - default: GGML_ABORT("MUL_MAT_ID not implemented"); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; - - [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; - } else { - int nth0 = 32; - int nth1 = 1; - int nrows = 1; - //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - - id pipeline = nil; - - // use custom matrix x vector kernel - switch (src0t) { - case GGML_TYPE_F32: - { - GGML_ASSERT(src1t == GGML_TYPE_F32); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; - } break; - case GGML_TYPE_F16: - { - GGML_ASSERT(src1t == GGML_TYPE_F32); - nth0 = 32; - nth1 = 1; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; - } break; - case GGML_TYPE_BF16: - { - GGML_ASSERT(src1t == GGML_TYPE_F32); - nth0 = 32; - nth1 = 1; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline; - } break; - case GGML_TYPE_Q4_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; - } break; - case GGML_TYPE_Q4_1: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; - } break; - case GGML_TYPE_Q5_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; - } break; - case GGML_TYPE_Q5_1: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; - } break; - case GGML_TYPE_Q6_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_0_F32].pipeline; - } break; - case GGML_TYPE_Q8_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; - } break; - case GGML_TYPE_Q2_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; - } break; - case GGML_TYPE_Q3_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; - } break; - case GGML_TYPE_Q4_K: - { - nth0 = 4; //1; - nth1 = 8; //32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; - } break; - case GGML_TYPE_Q5_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; - } break; - case GGML_TYPE_Q6_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; - } break; - case GGML_TYPE_IQ2_XXS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; - } break; - case GGML_TYPE_IQ2_XS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_XXS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline; - } break; - case GGML_TYPE_IQ2_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline; - } break; - case GGML_TYPE_IQ1_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; - } break; - case GGML_TYPE_IQ1_M: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; - } break; - case GGML_TYPE_IQ1_BN: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32].pipeline; - } break; - case GGML_TYPE_IQ2_BN: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32].pipeline; - } break; - case GGML_TYPE_IQ4_NL: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; - } break; - case GGML_TYPE_IQ4_XS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; - } break; - case GGML_TYPE_IQ4_KS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32].pipeline; - } break; - case GGML_TYPE_IQ4_KSS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KSS_F32].pipeline; - } break; - case GGML_TYPE_IQ2_K: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_K_F32].pipeline; - } break; - case GGML_TYPE_IQ2_KS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_KS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_K: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_K_F32].pipeline; - } break; - case GGML_TYPE_IQ4_K: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_K_F32].pipeline; - } break; - case GGML_TYPE_IQ5_K: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_K_F32].pipeline; - } break; - case GGML_TYPE_IQ6_K: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ6_K_F32].pipeline; - } break; - default: - { - GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); - GGML_ABORT("not implemented"); - } - }; - - if (ggml_is_quantized(src0t)) { - GGML_ASSERT(ne00 >= nth0*nth1); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22]; - - const int64_t _ne1 = 1; - const int tgz = dst_rows; - - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_Q6_0 || - src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) { - const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : src0t == GGML_TYPE_IQ3_K ? 32*sizeof(float) : 16*sizeof(float); - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { - const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { - const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K || - src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS|| - src0t == GGML_TYPE_IQ4_KSS) { - const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float) : GGML_TYPE_IQ5_K ? 64*sizeof(float) : 32*sizeof(float); - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q3_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else { - const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1 - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - } - } break; - case GGML_OP_GET_ROWS: - { - id pipeline = nil; - - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break; - case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_0 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break; - case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN ].pipeline; break; - case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; - case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS ].pipeline; break; - case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KSS].pipeline; break; - case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K ].pipeline; break; - case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KS ].pipeline; break; - case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_K ].pipeline; break; - case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K ].pipeline; break; - case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_K ].pipeline; break; - case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ6_K ].pipeline; break; - case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; - default: GGML_ABORT("not implemented"); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5]; - [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; - } break; - case GGML_OP_RMS_NORM: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_1(src0)); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - int nth = 32; // SIMD width - - while (nth < ne00/4 && nth < 1024) { - nth *= 2; - } - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_FUSED_RMS_NORM: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_1(src0)); - GGML_ASSERT(src1->ne[0] == src0->ne[0]); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_nrows(src1) == 1); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - int nth = 32; // SIMD width - - while (nth < ne00/4 && nth < 1024) { - nth *= 2; - } - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; - [encoder setBytes:&eps length:sizeof( float) atIndex:5]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_GROUP_NORM: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous(src0)); - - float eps; - memcpy(&eps, dst->op_params + 1, sizeof(float)); - - const int32_t n_groups = ((int32_t *) dst->op_params)[0]; - - int nth = 32; // SIMD width - - //while (nth < ne00/4 && nth < 1024) { - // nth *= 2; - //} - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8]; - [encoder setBytes:&eps length:sizeof( float) atIndex:9]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_NORM: - { - GGML_ASSERT(ggml_is_contiguous_1(src0)); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - const int nth = MIN(256, ne00); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0]; - - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ROPE: - { - GGML_ASSERT(ne10 == ne02); - - const int nth = MIN(1024, ne00); - - const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal - const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; - - float freq_base; - float freq_scale; - float ext_factor; - float attn_factor; - float beta_fast; - float beta_slow; - - memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - - const bool is_neox = mode & 2; - - id pipeline = nil; - - if (!is_neox) { - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - } else { - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - if (id_src2 != nil) { - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; - [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22]; - [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; - [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; - [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; - [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; - [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; - [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_IM2COL: - { - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); - - const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; - - const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; - - const int32_t N = src1->ne[is_2D ? 3 : 2]; - const int32_t IC = src1->ne[is_2D ? 2 : 1]; - const int32_t IH = is_2D ? src1->ne[1] : 1; - const int32_t IW = src1->ne[0]; - - const int32_t KH = is_2D ? src0->ne[1] : 1; - const int32_t KW = src0->ne[0]; - - const int32_t OH = is_2D ? dst->ne[2] : 1; - const int32_t OW = dst->ne[1]; - - const int32_t CHW = IC * KH * KW; - - const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; - const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; - - id pipeline = nil; - - switch (dst->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2]; - [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3]; - [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4]; - [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5]; - [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6]; - [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7]; - [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8]; - [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9]; - [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10]; - [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11]; - [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12]; - - [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; - } break; - case GGML_OP_UPSCALE: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - const float sf0 = (float)ne0/src0->ne[0]; - const float sf1 = (float)ne1/src0->ne[1]; - const float sf2 = (float)ne2/src0->ne[2]; - const float sf3 = (float)ne3/src0->ne[3]; - - const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18]; - [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19]; - [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20]; - [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_PAD: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ARANGE: - { - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - float start; - float step; - - memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float)); - memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1]; - [encoder setBytes:&start length:sizeof(start) atIndex:2]; - [encoder setBytes:&step length:sizeof(step) atIndex:3]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_TIMESTEP_EMBEDDING: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - const int dim = dst->op_params[0]; - const int max_period = dst->op_params[1]; - - const int half = dim / 2; - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2]; - [encoder setBytes:&dim length:sizeof(dim) atIndex:3]; - [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4]; - - const int nth = MIN(1024, half); - - [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ARGSORT: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_I32); - - const int nrows = ggml_nrows(src0); - - enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; - - // bitonic sort requires the number of elements to be power of 2 - int64_t ne00_padded = 1; - while (ne00_padded < ne00) { - ne00_padded *= 2; - } - - // Metal kernels require the buffer size to be multiple of 16 bytes - // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength - const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16); - - id pipeline = nil; - - switch (order) { - case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break; - case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3]; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)]; - } break; - case GGML_OP_LEAKY_RELU: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - float slope; - memcpy(&slope, dst->op_params, sizeof(float)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&slope length:sizeof(slope) atIndex:2]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_FLASH_ATTN_EXT: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ne11 % 32 == 0); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - GGML_ASSERT(ggml_are_same_shape (src1, src2)); - - struct ggml_tensor * src3 = gf->nodes[i]->src[3]; - - size_t offs_src3 = 0; - - id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; - - GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); - GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && - "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); - - const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); - //const int64_t ne31 = src3 ? src3->ne[1] : 0; - const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); - const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); - - const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); - const uint64_t nb31 = src3 ? src3->nb[1] : 0; - const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); - const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); - - const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); - - float scale; - float max_bias; - float softcap; - - memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); - memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); - memcpy(&softcap, ((int32_t *) dst->op_params) + 2, sizeof(softcap)); - if (softcap != 0.0f) { - scale /= softcap; - } - - const uint32_t n_head = src0->ne[2]; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - id pipeline = nil; - - bool use_vec_kernel = false; - - if (ne01 >= 4 || (ne00%128 != 0)) { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; - //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; - default: - { - GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_METAL_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } else { - use_vec_kernel = true; - - switch (ne00) { - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; - //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; - default: - { - GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_METAL_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - if (id_src3) { - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22]; - [encoder setBytes:&scale length:sizeof( float) atIndex:23]; - [encoder setBytes:&max_bias length:sizeof( float) atIndex:24]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:26]; - [encoder setBytes:&softcap length:sizeof(softcap) atIndex:27]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:28]; - - if (!use_vec_kernel) { - // half8x8 kernel - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - - GGML_ASSERT(nqptg <= 32); - GGML_ASSERT(nqptg % 8 == 0); - GGML_ASSERT(ncpsg % 32 == 0); - - int64_t nsgmax = 2; - - while (true) { - const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2); - if (smem > ctx->device.maxThreadgroupMemoryLength) { - break; - } - nsgmax *= 2; - } - nsgmax /= 2; - - // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; - - const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); - - //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); - GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); - - [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; - } else { - // half1x4 kernel - const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - - GGML_ASSERT(nqptg <= 32); - GGML_ASSERT(nqptg % 1 == 0); - GGML_ASSERT(ncpsg % 32 == 0); - - // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); - - int64_t nsg = 1; - while (nsg <= nsgt) { - nsg *= 2; - } - nsg /= 2; - - const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); - - //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); - GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; - } - } break; - case GGML_OP_DUP: - case GGML_OP_CPY: - case GGML_OP_CONT: - { - GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); - - int nth = MIN(1024, ne00/ggml_blck_size(src0->type)); - - id pipeline = nil; - - switch (src0t) { - case GGML_TYPE_F32: - { - GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0); - - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break; - case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - case GGML_TYPE_F16: - { - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - default: GGML_ABORT("not implemented"); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - default: - { - GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); - GGML_ABORT("fatal error"); + } break; + case GGML_OP_MUL_MAT_ID: + { + const int n_as = src0->ne[2]; + + // src2 = ids + const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); + + GGML_ASSERT(src2t == GGML_TYPE_I32); + + GGML_ASSERT(!ggml_is_transposed(src0)); + GGML_ASSERT(!ggml_is_transposed(src1)); + + GGML_ASSERT(src1t == GGML_TYPE_F32); + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + // ne20 = n_used_experts + // ne21 = n_rows + const int dst_rows = ne20*ne21; + const int dst_rows_min = n_as; + const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength/2 - 8192)/4; + //const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4; + + // max size of the rowids array in the kernel shared buffer + //GGML_ASSERT(dst_rows <= dst_rows_max); + + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs + // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel + // !!! + // TODO: for now, always use mat-vec kernels until we figure out how to improve the + // indirect matrix multiplication + // !!! + //printf("%s: ne00 = %d, dst_rows = %d, dst_rows_min = %d max = %d, ne20 = %d ne21 = %d\n", src0->name, (int)ne00, (int)dst_rows, (int)dst_rows_min, (int)dst_rows_max, (int)ne20, (int)ne21); + if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && + ne00 % 32 == 0 && ne00 >= 64 && + dst_rows > dst_rows_min && + dst_rows <= dst_rows_max) { + + //printf("Using mul_mm_id for %s\n", dst->name); + + // some Metal matrix data types require aligned pointers + // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) + switch (src0->type) { + case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; + case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; + default: break; } - } - if (should_capture) { - [encoder popDebugGroup]; + id pipeline = nil; + + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break; + case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_0_F32 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break; + case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break; + case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32 ].pipeline; break; + case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32 ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; + case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32 ].pipeline; break; + case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KSS_F32].pipeline; break; + case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32 ].pipeline; break; + case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KS_F32 ].pipeline; break; + case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_K_F32 ].pipeline; break; + case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32 ].pipeline; break; + case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_K_F32 ].pipeline; break; + case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ6_K_F32 ].pipeline; break; + default: GGML_ABORT("MUL_MAT_ID not implemented"); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; + [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; + + [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + } else { + //printf("Using mul_mv_id for %s\n", dst->name); + int nth0 = 32; + int nth1 = 1; + int nrows = 1; + //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + + id pipeline = nil; + + // use custom matrix x vector kernel + switch (src0t) { + case GGML_TYPE_F32: + { + GGML_ASSERT(src1t == GGML_TYPE_F32); + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(src1t == GGML_TYPE_F32); + nth0 = 32; + nth1 = 1; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; + } break; + case GGML_TYPE_BF16: + { + GGML_ASSERT(src1t == GGML_TYPE_F32); + nth0 = 32; + nth1 = 1; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline; + } break; + case GGML_TYPE_Q4_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; + } break; + case GGML_TYPE_Q4_1: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; + } break; + case GGML_TYPE_Q5_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; + } break; + case GGML_TYPE_Q5_1: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; + } break; + case GGML_TYPE_Q6_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_0_F32].pipeline; + } break; + case GGML_TYPE_Q8_0: + { + nth0 = 32; + nth1 = 2; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; + } break; + case GGML_TYPE_Q2_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; + } break; + case GGML_TYPE_Q3_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; + } break; + case GGML_TYPE_Q4_K: + { + nth0 = 4; //1; + nth1 = 8; //32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; + } break; + case GGML_TYPE_Q5_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; + } break; + case GGML_TYPE_Q6_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; + } break; + case GGML_TYPE_IQ2_XXS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; + } break; + case GGML_TYPE_IQ2_XS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; + } break; + case GGML_TYPE_IQ3_XXS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; + } break; + case GGML_TYPE_IQ3_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline; + } break; + case GGML_TYPE_IQ2_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline; + } break; + case GGML_TYPE_IQ1_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; + } break; + case GGML_TYPE_IQ1_M: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; + } break; + case GGML_TYPE_IQ1_BN: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32].pipeline; + } break; + case GGML_TYPE_IQ2_BN: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32].pipeline; + } break; + case GGML_TYPE_IQ4_NL: + { + nth0 = 32; + nth1 = 2; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; + } break; + case GGML_TYPE_IQ4_XS: + { + nth0 = 32; + nth1 = 2; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; + } break; + case GGML_TYPE_IQ4_KS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32].pipeline; + } break; + case GGML_TYPE_IQ4_KSS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KSS_F32].pipeline; + } break; + case GGML_TYPE_IQ2_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_K_F32].pipeline; + } break; + case GGML_TYPE_IQ2_KS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_KS_F32].pipeline; + } break; + case GGML_TYPE_IQ3_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_K_F32].pipeline; + } break; + case GGML_TYPE_IQ4_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_K_F32].pipeline; + } break; + case GGML_TYPE_IQ5_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_K_F32].pipeline; + } break; + case GGML_TYPE_IQ6_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ6_K_F32].pipeline; + } break; + default: + { + GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); + GGML_ABORT("not implemented"); + } + }; + + if (ggml_is_quantized(src0t)) { + GGML_ASSERT(ne00 >= nth0*nth1); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; + [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22]; + + const int64_t _ne1 = 1; + const int tgz = dst_rows; + + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || + src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || + src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_Q6_0 || + src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) { + const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : src0t == GGML_TYPE_IQ3_K ? 32*sizeof(float) : 16*sizeof(float); + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { + const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { + const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K || + src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS|| + src0t == GGML_TYPE_IQ4_KSS) { + const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float) : GGML_TYPE_IQ5_K ? 64*sizeof(float) : 32*sizeof(float); + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q3_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q5_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q6_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else { + const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1 + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + } + } break; + case GGML_OP_GET_ROWS: + { + id pipeline = nil; + + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break; + case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_0 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break; + case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break; + case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN ].pipeline; break; + case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; + case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS ].pipeline; break; + case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KSS].pipeline; break; + case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K ].pipeline; break; + case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KS ].pipeline; break; + case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_K ].pipeline; break; + case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K ].pipeline; break; + case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_K ].pipeline; break; + case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ6_K ].pipeline; break; + case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; + default: GGML_ABORT("not implemented"); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + } break; + case GGML_OP_RMS_NORM: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < 1024) { + nth *= 2; + } + + nth = MIN(nth, ne00/4); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_FUSED_RMS_NORM: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(src1->ne[0] == src0->ne[0]); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_nrows(src1) == 1); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < 1024) { + nth *= 2; + } + + nth = MIN(nth, ne00/4); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; + [encoder setBytes:&eps length:sizeof( float) atIndex:5]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_GROUP_NORM: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ggml_is_contiguous(src0)); + + float eps; + memcpy(&eps, dst->op_params + 1, sizeof(float)); + + const int32_t n_groups = ((int32_t *) dst->op_params)[0]; + + int nth = 32; // SIMD width + + //while (nth < ne00/4 && nth < 1024) { + // nth *= 2; + //} + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8]; + [encoder setBytes:&eps length:sizeof( float) atIndex:9]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_NORM: + { + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + const int nth = MIN(256, ne00); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ROPE: + { + GGML_ASSERT(ne10 == ne02); + + const int nth = MIN(1024, ne00); + + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + + const bool is_neox = mode & 2; + + id pipeline = nil; + + if (!is_neox) { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } else { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + if (id_src2 != nil) { + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; + [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; + [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22]; + [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; + [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; + [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; + [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; + [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; + [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_IM2COL: + { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int32_t N = src1->ne[is_2D ? 3 : 2]; + const int32_t IC = src1->ne[is_2D ? 2 : 1]; + const int32_t IH = is_2D ? src1->ne[1] : 1; + const int32_t IW = src1->ne[0]; + + const int32_t KH = is_2D ? src0->ne[1] : 1; + const int32_t KW = src0->ne[0]; + + const int32_t OH = is_2D ? dst->ne[2] : 1; + const int32_t OW = dst->ne[1]; + + const int32_t CHW = IC * KH * KW; + + const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; + const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; + + id pipeline = nil; + + switch (dst->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2]; + [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3]; + [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4]; + [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5]; + [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6]; + [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7]; + [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8]; + [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9]; + [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10]; + [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11]; + [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12]; + + [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; + } break; + case GGML_OP_UPSCALE: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + const float sf0 = (float)ne0/src0->ne[0]; + const float sf1 = (float)ne1/src0->ne[1]; + const float sf2 = (float)ne2/src0->ne[2]; + const float sf3 = (float)ne3/src0->ne[3]; + + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18]; + [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19]; + [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20]; + [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_PAD: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ARANGE: + { + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + float start; + float step; + + memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1]; + [encoder setBytes:&start length:sizeof(start) atIndex:2]; + [encoder setBytes:&step length:sizeof(step) atIndex:3]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_TIMESTEP_EMBEDDING: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + const int dim = dst->op_params[0]; + const int max_period = dst->op_params[1]; + + const int half = dim / 2; + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2]; + [encoder setBytes:&dim length:sizeof(dim) atIndex:3]; + [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4]; + + const int nth = MIN(1024, half); + + [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ARGSORT: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + + const int nrows = ggml_nrows(src0); + + enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; + + // bitonic sort requires the number of elements to be power of 2 + int64_t ne00_padded = 1; + while (ne00_padded < ne00) { + ne00_padded *= 2; + } + + // Metal kernels require the buffer size to be multiple of 16 bytes + // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength + const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16); + + id pipeline = nil; + + switch (order) { + case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break; + case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3]; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)]; + } break; + case GGML_OP_LEAKY_RELU: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + float slope; + memcpy(&slope, dst->op_params, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&slope length:sizeof(slope) atIndex:2]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_FLASH_ATTN_EXT: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ne11 % 32 == 0); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_are_same_shape (src1, src2)); + + struct ggml_tensor * src3 = node->src[3]; + + size_t offs_src3 = 0; + + id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; + + GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); + GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && + "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + + const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); + //const int64_t ne31 = src3 ? src3->ne[1] : 0; + const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); + const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); + + const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); + const uint64_t nb31 = src3 ? src3->nb[1] : 0; + const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); + const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); + + const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); + + float scale; + float max_bias; + float softcap; + + memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); + memcpy(&softcap, ((int32_t *) dst->op_params) + 2, sizeof(softcap)); + if (softcap != 0.0f) { + scale /= softcap; + } + + const uint32_t n_head = src0->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + id pipeline = nil; + + bool use_vec_kernel = false; + + if (ne01 >= 4 || (ne00%128 != 0)) { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } else { + use_vec_kernel = true; + + switch (ne00) { + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; + //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + if (id_src3) { + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22]; + [encoder setBytes:&scale length:sizeof( float) atIndex:23]; + [encoder setBytes:&max_bias length:sizeof( float) atIndex:24]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:26]; + [encoder setBytes:&softcap length:sizeof(softcap) atIndex:27]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:28]; + + if (!use_vec_kernel) { + // half8x8 kernel + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + int64_t nsgmax = 2; + + while (true) { + const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2); + if (smem > ctx->device.maxThreadgroupMemoryLength) { + break; + } + nsgmax *= 2; + } + nsgmax /= 2; + + // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; + + const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + + [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else { + // half1x4 kernel + const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 1 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + + int64_t nsg = 1; + while (nsg <= nsgt) { + nsg *= 2; + } + nsg /= 2; + + const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } + } break; + case GGML_OP_DUP: + case GGML_OP_CPY: + case GGML_OP_CONT: + { + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + + int nth = MIN(1024, ne00/ggml_blck_size(src0->type)); + + id pipeline = nil; + + switch (src0t) { + case GGML_TYPE_F32: + { + GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0); + + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break; + case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_F16: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + default: GGML_ABORT("not implemented"); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + default: + { + GGML_METAL_LOG_ERROR("%s: error: node %s, op = %8s not implemented\n", __func__, dst->name, ggml_op_name(dst->op)); + GGML_ABORT("fatal error"); + } + } + +} + +//static void ggml_metal_encode_node( +// struct ggml_backend_metal_context * ctx, +// struct ggml_tensor * node, +// id encoder) { +static enum ggml_status ggml_metal_graph_compute( + struct ggml_backend_metal_context * ctx, + struct ggml_cgraph * gf) { + + // number of nodes encoded by the main thread (empirically determined) + const int n_main = 128; + + // number of threads in addition to the main thread + const int n_cb = ctx->n_cb; + + @autoreleasepool { + ctx->gf = gf; + + ctx->n_nodes_0 = MIN(n_main, gf->n_nodes); + ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0; + + ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb; + + const bool should_capture = ctx->capture_next_compute; + if (should_capture) { + ctx->capture_next_compute = false; + + if (!ctx->capture_started) { + // create capture scope + ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device]; //ctx_dev->mtl_device]; + + MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; + descriptor.captureObject = ctx->capture_scope; + descriptor.destination = MTLCaptureDestinationGPUTraceDocument; + descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]]; + + NSError * error = nil; + if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { + printf("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]); + } else { + [ctx->capture_scope beginScope]; + ctx->capture_started = true; + } } } - [encoder endEncoding]; + // the main thread commits the first few commands immediately + // command_buffer[n_cb] + { + id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; + ctx->command_buffers[n_cb] = command_buffer; - if (cb_idx < 2 || ctx->abort_callback == NULL) { - [command_buffer commit]; + [command_buffer enqueue]; + ctx->encode_async(n_cb); } - }); - // Wait for completion and check status of each command buffer - // needed to detect if the device ran out-of-memory for example (#1881) + // prepare the rest of the command buffers asynchronously + // command_buffer[0.. n_cb) + for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { + id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; + ctx->command_buffers[cb_idx] = command_buffer; - for (int i = 0; i < n_cb; ++i) { - id command_buffer = command_buffers[i]; - [command_buffer waitUntilCompleted]; + // always enqueue the first two command buffers + // enqueue all of the command buffers if we don't need to abort + if (cb_idx < 2 || ctx->abort_callback == NULL) { + [command_buffer enqueue]; + } + } - MTLCommandBufferStatus status = [command_buffer status]; - if (status != MTLCommandBufferStatusCompleted) { - GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); - if (status == MTLCommandBufferStatusError) { - NSString * error_code = [command_buffer error].localizedDescription; - GGML_METAL_LOG_INFO("error: %s\n", [error_code UTF8String]); + dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async); + + // wait for completion and check status of each command buffer + // needed to detect if the device ran out-of-memory for example (#1881) + { + id command_buffer = ctx->command_buffers[n_cb]; + [command_buffer waitUntilCompleted]; + + MTLCommandBufferStatus status = [command_buffer status]; + if (status != MTLCommandBufferStatusCompleted) { + printf("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status); + if (status == MTLCommandBufferStatusError) { + printf("error: %s\n", [[command_buffer error].localizedDescription UTF8String]); + } + + return GGML_STATUS_FAILED; + } + } + for (int i = 0; i < n_cb; ++i) { + id command_buffer = ctx->command_buffers[i]; + [command_buffer waitUntilCompleted]; + + MTLCommandBufferStatus status = [command_buffer status]; + if (status != MTLCommandBufferStatusCompleted) { + printf("%s: command buffer %d failed with status %lu\n", __func__, i, status); + if (status == MTLCommandBufferStatusError) { + printf("error: %s\n", [[command_buffer error].localizedDescription UTF8String]); + } + + return GGML_STATUS_FAILED; } - return GGML_STATUS_FAILED; + id next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil); + if (!next_buffer) { + continue; + } + + const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); + if (next_queued) { + continue; + } + + if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { + printf("%s: command buffer %d aborted", __func__, i); + return GGML_STATUS_ABORTED; + } + + [next_buffer commit]; } - id next_buffer = (i + 1 < n_cb ? command_buffers[i + 1] : nil); - if (!next_buffer) { - continue; + if (!should_capture && ctx->capture_started) { + [ctx->capture_scope endScope]; + [[MTLCaptureManager sharedCaptureManager] stopCapture]; } - - bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); - if (next_queued) { - continue; - } - - if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { - GGML_METAL_LOG_INFO("%s: command buffer %d aborted", __func__, i); - return GGML_STATUS_ABORTED; - } - - [next_buffer commit]; } - if (should_capture) { - [[MTLCaptureManager sharedCaptureManager] stopCapture]; - } - - } return GGML_STATUS_SUCCESS; } +//static enum ggml_status ggml_metal_graph_compute( +// struct ggml_backend_metal_context * ctx, +// struct ggml_cgraph * gf) { +// +// @autoreleasepool { +// MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; +// edesc.dispatchType = MTLDispatchTypeSerial; +// +// // create multiple command buffers and enqueue them +// // then, we encode the graph into the command buffers in parallel +// +// const int n_nodes = gf->n_nodes; +// const int n_cb = ctx->n_cb; +// const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; +// +// const bool should_capture = ctx->should_capture_next_compute; +// if (should_capture) { +// ctx->should_capture_next_compute = false; +// +// MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; +// descriptor.captureObject = ctx->queue; +// +// NSError * error = nil; +// if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { +// GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]); +// GGML_ABORT("capture failed"); +// } +// } +// +// id command_buffer_builder[n_cb]; +// for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { +// id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; +// command_buffer_builder[cb_idx] = command_buffer; +// +// // always enqueue the first two command buffers +// // enqueue all of the command buffers if we don't need to abort +// if (cb_idx < 2 || ctx->abort_callback == NULL) { +// [command_buffer enqueue]; +// } +// } +// +// const id *command_buffers = command_buffer_builder; +// +// dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) { +// const int cb_idx = iter; +// +// size_t offs_src0 = 0; +// size_t offs_src1 = 0; +// size_t offs_src2 = 0; +// size_t offs_dst = 0; +// +// id command_buffer = command_buffers[cb_idx]; +// id encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; +// +// const int node_start = (cb_idx + 0) * n_nodes_per_cb; +// const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes); +// +// for (int i = node_start; i < node_end; ++i) { +// if (i == -1) { +// [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]; +// continue; +// } +// +// //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op)); +// +// struct ggml_tensor * src0 = gf->nodes[i]->src[0]; +// struct ggml_tensor * src1 = gf->nodes[i]->src[1]; +// struct ggml_tensor * src2 = gf->nodes[i]->src[2]; +// struct ggml_tensor * dst = gf->nodes[i]; +// +// if (ggml_is_empty(dst)) { +// continue; +// } +// +// switch (dst->op) { +// case GGML_OP_NONE: +// case GGML_OP_RESHAPE: +// case GGML_OP_VIEW: +// case GGML_OP_TRANSPOSE: +// case GGML_OP_PERMUTE: +// { +// // noop -> next node +// } continue; +// default: +// { +// } break; +// } +// +// if (!ggml_metal_supports_op(ctx, dst)) { +// GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst)); +// GGML_ABORT("unsupported op"); +// } +// +// if (should_capture) { +// [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]]; +// } +// +// const int64_t ne00 = src0 ? src0->ne[0] : 0; +// const int64_t ne01 = src0 ? src0->ne[1] : 0; +// const int64_t ne02 = src0 ? src0->ne[2] : 0; +// const int64_t ne03 = src0 ? src0->ne[3] : 0; +// +// const uint64_t nb00 = src0 ? src0->nb[0] : 0; +// const uint64_t nb01 = src0 ? src0->nb[1] : 0; +// const uint64_t nb02 = src0 ? src0->nb[2] : 0; +// const uint64_t nb03 = src0 ? src0->nb[3] : 0; +// +// const int64_t ne10 = src1 ? src1->ne[0] : 0; +// const int64_t ne11 = src1 ? src1->ne[1] : 0; +// const int64_t ne12 = src1 ? src1->ne[2] : 0; +// const int64_t ne13 = src1 ? src1->ne[3] : 0; +// +// const uint64_t nb10 = src1 ? src1->nb[0] : 0; +// const uint64_t nb11 = src1 ? src1->nb[1] : 0; +// const uint64_t nb12 = src1 ? src1->nb[2] : 0; +// const uint64_t nb13 = src1 ? src1->nb[3] : 0; +// +// const int64_t ne20 = src2 ? src2->ne[0] : 0; +// const int64_t ne21 = src2 ? src2->ne[1] : 0; +// const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22); +// const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); +// +// const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); +// const uint64_t nb21 = src2 ? src2->nb[1] : 0; +// const uint64_t nb22 = src2 ? src2->nb[2] : 0; +// const uint64_t nb23 = src2 ? src2->nb[3] : 0; +// +// const int64_t ne0 = dst ? dst->ne[0] : 0; +// const int64_t ne1 = dst ? dst->ne[1] : 0; +// const int64_t ne2 = dst ? dst->ne[2] : 0; +// const int64_t ne3 = dst ? dst->ne[3] : 0; +// +// const uint64_t nb0 = dst ? dst->nb[0] : 0; +// const uint64_t nb1 = dst ? dst->nb[1] : 0; +// const uint64_t nb2 = dst ? dst->nb[2] : 0; +// const uint64_t nb3 = dst ? dst->nb[3] : 0; +// +// const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; +// const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; +// const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; +// +// id id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil; +// id id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil; +// id id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; +// id id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; +// +// //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); +// //if (src0) { +// // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, +// // ggml_is_contiguous(src0), src0->name); +// //} +// //if (src1) { +// // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, +// // ggml_is_contiguous(src1), src1->name); +// //} +// //if (dst) { +// // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, +// // dst->name); +// //} +// +// switch (dst->op) { +// case GGML_OP_CONCAT: +// { +// id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; +// +// const int32_t dim = ((int32_t *) dst->op_params)[0]; +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; +// [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; +// [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; +// [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; +// [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; +// [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; +// [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; +// [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; +// [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; +// [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; +// [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; +// [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; +// [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; +// [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; +// [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; +// [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; +// [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; +// [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; +// [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; +// [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; +// [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; +// [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; +// [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; +// [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; +// [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; +// [encoder setBytes:&dim length:sizeof(dim) atIndex:27]; +// +// const int nth = MIN(1024, ne0); +// +// [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; +// } break; +// case GGML_OP_ADD: +// case GGML_OP_MUL: +// case GGML_OP_DIV: +// { +// GGML_ASSERT(src0t == GGML_TYPE_F32); +// GGML_ASSERT(src1t == GGML_TYPE_F32); +// +// const size_t offs = 0; +// +// bool bcast_row = false; +// +// int64_t nb = ne00; // used by the "row" kernels +// +// id pipeline = nil; +// +// if (dst->op == GGML_OP_MUL && ggml_nelements(src1) == 1 && ggml_is_contiguous(src0)) { +// float scale; +// memcpy(&scale, src1->data, sizeof(float)); +// //printf("Replacing op_mul with op_scale. scale = %g\n", (double)scale); +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; +// +// int64_t n = ggml_nelements(dst); +// +// if (n % 4 == 0) { +// n /= 4; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; +// } else { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; +// } +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; +// +// [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; +// break; +// } +// else if (ggml_is_contiguous(dst->src[0]) && ggml_is_contiguous(dst->src[1]) && ggml_is_contiguous(dst) && +// dst->src[0]->ne[0] == dst->src[1]->ne[0] && dst->src[0]->ne[0] == dst->ne[0] && +// dst->src[0]->ne[1] == dst->src[1]->ne[1] && dst->src[0]->ne[1] == dst->ne[1] && +// dst->src[0]->ne[2] == dst->src[1]->ne[2] && dst->src[0]->ne[2] == dst->ne[2] && +// dst->src[0]->ne[3] == dst->src[1]->ne[3] && ggml_nelements(dst)%4 == 0) { +// +// switch (dst->op) { +// case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_4].pipeline; break; +// case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_4].pipeline; break; +// case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_4].pipeline; break; +// default: GGML_ASSERT(false); +// } +// +// int64_t n = ggml_nelements(dst)/4; +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; +// +// [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; +// break; +// } +// else if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { +// GGML_ASSERT(ggml_is_contiguous(src0)); +// +// // src1 is a row +// GGML_ASSERT(ne11 == 1); +// +// nb = ne00 / 4; +// switch (dst->op) { +// case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; +// case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; +// case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; +// default: GGML_ABORT("fatal error"); +// } +// +// bcast_row = true; +// } else { +// switch (dst->op) { +// case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; +// case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; +// case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; +// default: GGML_ABORT("fatal error"); +// } +// } +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; +// [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; +// [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; +// [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; +// [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; +// [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; +// [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; +// [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; +// [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; +// [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; +// [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; +// [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; +// [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; +// [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; +// [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; +// [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; +// [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; +// [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; +// [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; +// [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; +// [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; +// [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; +// [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; +// [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; +// [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; +// [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; +// [encoder setBytes:&nb length:sizeof(nb) atIndex:28]; +// +// if (bcast_row) { +// const int64_t n = ggml_nelements(dst)/4; +// +// [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; +// } else { +// const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); +// +// [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; +// } +// } break; +// case GGML_OP_MULTI_ADD: +// { +// GGML_ASSERT(src0t == GGML_TYPE_F32); +// GGML_ASSERT(dstt == GGML_TYPE_F32); +// GGML_ASSERT(ne02 == 1 && ne03 == 1); +// GGML_ASSERT(nb0 == sizeof(float) && nb00 == sizeof(float)); +// GGML_ASSERT(ggml_are_same_shape(src0, dst)); +// +// int n_expert = dst->op_params[0]; +// GGML_ASSERT(n_expert >= 2); +// +// id pipeline = nil; +// int64_t n = ne0*ne1; +// if (ne0%4 == 0) { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD_4].pipeline; +// n /= 4; +// } else { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD].pipeline; +// } +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:2]; +// [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:3]; +// [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:4]; +// [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; +// [encoder setBytes:&n_expert length:sizeof(n_expert) atIndex:6]; +// +// [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; +// } break; +// case GGML_OP_REPEAT: +// { +// id pipeline; +// +// switch (src0t) { +// case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break; +// case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break; +// case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break; +// case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break; +// default: GGML_ABORT("fatal error"); +// } +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; +// [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; +// [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; +// [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; +// [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; +// [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; +// [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; +// [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; +// [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; +// [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; +// [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; +// [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; +// [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; +// [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; +// [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; +// [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; +// +// const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); +// +// [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; +// } break; +// case GGML_OP_ACC: +// { +// GGML_ASSERT(src0t == GGML_TYPE_F32); +// GGML_ASSERT(src1t == GGML_TYPE_F32); +// GGML_ASSERT(dstt == GGML_TYPE_F32); +// +// GGML_ASSERT(ggml_is_contiguous(src0)); +// GGML_ASSERT(ggml_is_contiguous(src1)); +// +// const size_t pnb1 = ((int32_t *) dst->op_params)[0]; +// const size_t pnb2 = ((int32_t *) dst->op_params)[1]; +// const size_t pnb3 = ((int32_t *) dst->op_params)[2]; +// const size_t offs = ((int32_t *) dst->op_params)[3]; +// +// const bool inplace = (bool) ((int32_t *) dst->op_params)[4]; +// +// if (!inplace) { +// // run a separete kernel to cpy src->dst +// // not sure how to avoid this +// // TODO: make a simpler cpy_bytes kernel +// +// const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; +// [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; +// [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; +// [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; +// [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; +// [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; +// [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; +// [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; +// [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; +// [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; +// [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; +// [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; +// [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; +// [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; +// [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; +// [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; +// +// const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); +// +// [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; +// } +// +// const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; +// [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; +// [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; +// [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; +// [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; +// [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; +// [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8]; +// [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9]; +// [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10]; +// [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; +// [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; +// [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; +// [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; +// [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; +// [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; +// [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; +// [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; +// [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; +// [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; +// [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; +// [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; +// [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; +// [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24]; +// [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25]; +// [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; +// [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; +// +// const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); +// +// [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; +// } break; +// case GGML_OP_SCALE: +// { +// GGML_ASSERT(ggml_is_contiguous(src0)); +// +// float scale; +// memcpy(&scale, dst->op_params, sizeof(scale)); +// +// int64_t n = ggml_nelements(dst); +// +// id pipeline = nil; +// +// if (n % 4 == 0) { +// n /= 4; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; +// } else { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; +// } +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; +// +// [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; +// } break; +// case GGML_OP_SOFTCAP: +// { +// GGML_ASSERT(ggml_is_contiguous(src0)); +// +// float scales[2]; +// memcpy(scales, dst->op_params, sizeof(scales)); +// +// int64_t n = ggml_nelements(dst); +// +// id pipeline = nil; +// +// if (n % 4 == 0) { +// n /= 4; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP_4].pipeline; +// } else { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP].pipeline; +// } +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// [encoder setBytes:&scales[0] length:sizeof(float) atIndex:2]; +// [encoder setBytes:&scales[1] length:sizeof(float) atIndex:3]; +// +// [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; +// } break; +// case GGML_OP_CLAMP: +// { +// id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline; +// +// float min; +// float max; +// memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float)); +// memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float)); +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// [encoder setBytes:&min length:sizeof(min) atIndex:2]; +// [encoder setBytes:&max length:sizeof(max) atIndex:3]; +// +// const int64_t n = ggml_nelements(dst); +// +// [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; +// } break; +// case GGML_OP_UNARY: +// switch (ggml_get_unary_op(gf->nodes[i])) { +// // we are not taking into account the strides, so for now require contiguous tensors +// GGML_ASSERT(ggml_is_contiguous(src0)); +// +// case GGML_UNARY_OP_TANH: +// { +// id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline; +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// +// const int64_t n = ggml_nelements(dst); +// +// [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; +// } break; +// case GGML_UNARY_OP_RELU: +// { +// id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline; +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// +// const int64_t n = ggml_nelements(dst); +// +// [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; +// } break; +// case GGML_UNARY_OP_SIGMOID: +// { +// id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline; +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// +// const int64_t n = ggml_nelements(dst); +// +// [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; +// } break; +// case GGML_UNARY_OP_GELU: +// { +// int64_t n = ggml_nelements(dst); +// +// id pipeline = nil; +// +// if (n % 4 == 0) { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline; +// n /= 4; +// } else { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; +// } +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// +// [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; +// } break; +// case GGML_UNARY_OP_GELU_QUICK: +// { +// int64_t n = ggml_nelements(dst); +// +// id pipeline = nil; +// +// if (n % 4 == 0) { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline; +// n /= 4; +// } else { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; +// } +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// +// [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; +// } break; +// case GGML_UNARY_OP_SILU: +// { +// int64_t n = ggml_nelements(dst); +// +// id pipeline = nil; +// +// if (n % 4 == 0) { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline; +// n /= 4; +// } else { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; +// } +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// +// [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; +// } break; +// case GGML_UNARY_OP_SWIGLU: +// { +// int64_t n = ggml_nelements(dst); +// GGML_ASSERT(ne0 == src0->ne[0]/2); +// +// id pipeline = nil; +// +// uint32_t n_per_row = ne0; +// uint32_t stride = src0->nb[1]/sizeof(float); +// +// if (ne0 % 4 == 0) { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_4].pipeline; +// n /= 4; +// n_per_row /= 4; +// stride /= 4; +// } else { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline; +// } +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// [encoder setBytes:&n_per_row length:sizeof(n_per_row) atIndex:2]; +// [encoder setBytes:&stride length:sizeof(stride) atIndex:3]; +// +// [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; +// } break; +// default: +// { +// GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); +// GGML_ABORT("fatal error"); +// } +// } break; +// case GGML_OP_FUSED_MUL_UNARY: +// { +// int64_t n = ggml_nelements(dst); +// enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0]; +// id pipeline = nil; +// if (n % 4 == 0 && op != GGML_UNARY_OP_RELU) { +// pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU_4].pipeline +// : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU_4].pipeline; +// n /= 4; +// } else { +// pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU].pipeline +// : op == GGML_UNARY_OP_SILU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU].pipeline +// : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_RELU].pipeline; +// } +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; +// [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; +// } break; +// case GGML_OP_SQR: +// { +// GGML_ASSERT(ggml_is_contiguous(src0)); +// +// id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline; +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// +// const int64_t n = ggml_nelements(dst); +// +// [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; +// } break; +// case GGML_OP_SUM_ROWS: +// { +// GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); +// +// id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; +// [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; +// [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; +// [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; +// [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; +// [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; +// [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; +// [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; +// [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; +// [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; +// [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; +// [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; +// [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; +// [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; +// [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; +// [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; +// [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18]; +// [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19]; +// [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20]; +// [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21]; +// [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22]; +// [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23]; +// [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24]; +// [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25]; +// +// [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; +// } break; +// case GGML_OP_SOFT_MAX: +// { +// GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); +// +// int nth = 32; // SIMD width +// +// id pipeline = nil; +// +// const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); +// +// if (ne00%4 == 0) { +// while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { +// nth *= 2; +// } +// if (use_f16) { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; +// } else { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; +// } +// } else { +// while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { +// nth *= 2; +// } +// if (use_f16) { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; +// } else { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; +// } +// } +// +// float scale; +// float max_bias; +// +// memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); +// memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); +// +// const int64_t nrows_x = ggml_nrows(src0); +// const int64_t nrows_y = src0->ne[1]; +// +// const uint32_t n_head = nrows_x/nrows_y; +// const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); +// +// const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); +// const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// if (id_src1) { +// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; +// } else { +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; +// } +// [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; +// [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; +// [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; +// [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; +// [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; +// [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; +// [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; +// [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; +// [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; +// [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; +// +// [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; +// } break; +// case GGML_OP_SOFT_CAP_MAX: +// { +// GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); +// +// int nth = 32; // SIMD width +// +// id pipeline = nil; +// +// const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); +// +// if (ne00%4 == 0) { +// while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { +// nth *= 2; +// } +// if (use_f16) { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4].pipeline; +// } else { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4].pipeline; +// } +// } else { +// while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { +// nth *= 2; +// } +// if (use_f16) { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16].pipeline; +// } else { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32].pipeline; +// } +// } +// +// float scale; +// float max_bias; +// float s_before; +// float s_after; +// +// memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); +// memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); +// memcpy(&s_before, ((int32_t *) dst->op_params) + 2, sizeof(s_before)); +// memcpy(&s_after, ((int32_t *) dst->op_params) + 3, sizeof(s_after)); +// +// const int64_t nrows_x = ggml_nrows(src0); +// const int64_t nrows_y = src0->ne[1]; +// +// const uint32_t n_head = nrows_x/nrows_y; +// const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); +// +// const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); +// const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// if (id_src1) { +// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; +// } else { +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; +// } +// [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; +// [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; +// [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; +// [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; +// [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; +// [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; +// [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; +// [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; +// [encoder setBytes:&s_before length:sizeof(s_before) atIndex:10]; +// [encoder setBytes:&s_after length:sizeof(s_after ) atIndex:11]; +// [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:12]; +// [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; +// +// [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; +// } break; +// case GGML_OP_DIAG_MASK_INF: +// { +// const int n_past = ((int32_t *)(dst->op_params))[0]; +// +// id pipeline = nil; +// +// if (ne00%8 == 0) { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline; +// } else { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; +// } +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; +// [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; +// [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; +// +// if (ne00%8 == 0) { +// [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; +// } +// else { +// [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; +// } +// } break; +// case GGML_OP_MUL_MAT: +// { +// GGML_ASSERT(ne00 == ne10); +// +// GGML_ASSERT(ne12 % ne02 == 0); +// GGML_ASSERT(ne13 % ne03 == 0); +// +// const uint r2 = ne12/ne02; +// const uint r3 = ne13/ne03; +// +// // find the break-even point where the matrix-matrix kernel becomes more efficient compared +// // to the matrix-vector kernel +// int ne11_mm_min = 1; +// +//#if 0 +// // the numbers below are measured on M2 Ultra for 7B and 13B models +// // these numbers do not translate to other devices or model sizes +// // TODO: need to find a better approach +// if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) { +// switch (src0t) { +// case GGML_TYPE_F16: ne11_mm_min = 2; break; +// case GGML_TYPE_Q8_0: ne11_mm_min = 7; break; +// case GGML_TYPE_Q2_K: ne11_mm_min = 15; break; +// case GGML_TYPE_Q3_K: ne11_mm_min = 7; break; +// case GGML_TYPE_Q4_0: +// case GGML_TYPE_Q4_1: ne11_mm_min = 15; break; +// case GGML_TYPE_Q4_K: ne11_mm_min = 11; break; +// case GGML_TYPE_Q5_0: // not tested yet +// case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet +// case GGML_TYPE_Q5_K: ne11_mm_min = 7; break; +// case GGML_TYPE_Q6_K: ne11_mm_min = 7; break; +// default: ne11_mm_min = 1; break; +// } +// } +//#endif +// +// // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs +// // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel +// if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && +// !ggml_is_transposed(src0) && +// !ggml_is_transposed(src1) && +// src1t == GGML_TYPE_F32 && +// ne00 % 32 == 0 && ne00 >= 64 && +// (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) { +// //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); +// +// // some Metal matrix data types require aligned pointers +// // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) +// switch (src0->type) { +// case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; +// case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; +// default: break; +// } +// +// id pipeline = nil; +// +// switch (src0->type) { +// case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break; +// case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break; +// case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break; +// case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break; +// case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break; +// case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break; +// case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F32 ].pipeline; break; +// case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break; +// case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break; +// case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break; +// case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break; +// case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break; +// case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break; +// case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break; +// case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break; +// case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break; +// case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break; +// case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break; +// case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break; +// case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break; +// case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break; +// case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32 ].pipeline; break; +// case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32 ].pipeline; break; +// case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; +// case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; +// case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32 ].pipeline; break; +// case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F32].pipeline; break; +// case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32 ].pipeline; break; +// case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F32 ].pipeline; break; +// case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F32 ].pipeline; break; +// case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32 ].pipeline; break; +// case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32 ].pipeline; break; +// case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32 ].pipeline; break; +// default: GGML_ABORT("MUL MAT-MAT not implemented"); +// } +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; +// [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; +// [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; +// [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; +// [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; +// [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; +// [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8]; +// [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9]; +// [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; +// [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; +// [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; +// [encoder setBytes:&r2 length:sizeof(r2) atIndex:13]; +// [encoder setBytes:&r3 length:sizeof(r3) atIndex:14]; +// [encoder setThreadgroupMemoryLength:8192 atIndex:0]; +// [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; +// } else { +// int nth0 = 32; +// int nth1 = 1; +// int nrows = 1; +// //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); +// +// id pipeline = nil; +// +// // use custom matrix x vector kernel +// switch (src0t) { +// case GGML_TYPE_F32: +// { +// GGML_ASSERT(src1t == GGML_TYPE_F32); +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline; +// nrows = 4; +// } break; +// case GGML_TYPE_F16: +// { +// nth0 = 32; +// nth1 = 1; +// if (src1t == GGML_TYPE_F32) { +// if (ne11 * ne12 < 4) { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline; +// } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline; +// nrows = ne11; +// } else { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline; +// nrows = 4; +// } +// } else { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline; +// nrows = 4; +// } +// } break; +// case GGML_TYPE_BF16: +// { +// if (src1t == GGML_TYPE_F32) { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline; +// } +// else if (src1t == GGML_TYPE_F16) { +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F16].pipeline; +// } +// else { +// GGML_ABORT("not implemented"); +// } +// nrows = 4; +// } break; +// case GGML_TYPE_Q4_0: +// { +// nth0 = 8; +// nth1 = 8; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline; +// } break; +// case GGML_TYPE_Q4_1: +// { +// nth0 = 8; +// nth1 = 8; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline; +// } break; +// case GGML_TYPE_Q5_0: +// { +// nth0 = 8; +// nth1 = 8; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline; +// } break; +// case GGML_TYPE_Q5_1: +// { +// nth0 = 8; +// nth1 = 8; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline; +// } break; +// case GGML_TYPE_Q6_0: +// { +// nth0 = 8; +// nth1 = 8; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_0_F32].pipeline; +// } break; +// case GGML_TYPE_Q8_0: +// { +// nth0 = 8; +// nth1 = 8; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; +// } break; +// case GGML_TYPE_Q2_K: +// { +// nth0 = 2; +// nth1 = 32; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline; +// } break; +// case GGML_TYPE_Q3_K: +// { +// nth0 = 2; +// nth1 = 32; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline; +// } break; +// case GGML_TYPE_Q4_K: +// { +// nth0 = 4; //1; +// nth1 = 8; //32; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline; +// } break; +// case GGML_TYPE_Q5_K: +// { +// nth0 = 2; +// nth1 = 32; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline; +// } break; +// case GGML_TYPE_Q6_K: +// { +// nth0 = 2; +// nth1 = 32; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline; +// } break; +// case GGML_TYPE_IQ2_XXS: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline; +// } break; +// case GGML_TYPE_IQ2_XS: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline; +// } break; +// case GGML_TYPE_IQ3_XXS: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline; +// } break; +// case GGML_TYPE_IQ3_S: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline; +// } break; +// case GGML_TYPE_IQ2_S: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline; +// } break; +// case GGML_TYPE_IQ1_S: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline; +// } break; +// case GGML_TYPE_IQ1_M: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline; +// } break; +// case GGML_TYPE_IQ1_BN: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32].pipeline; +// } break; +// case GGML_TYPE_IQ2_BN: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32].pipeline; +// } break; +// case GGML_TYPE_IQ4_NL: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline; +// } break; +// case GGML_TYPE_IQ4_XS: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; +// } break; +// case GGML_TYPE_IQ4_KS: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KS_F32].pipeline; +// } break; +// case GGML_TYPE_IQ4_KSS: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KSS_F32].pipeline; +// } break; +// case GGML_TYPE_IQ2_K: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_K_F32].pipeline; +// } break; +// case GGML_TYPE_IQ2_KS: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_KS_F32].pipeline; +// } break; +// case GGML_TYPE_IQ3_K: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_K_F32].pipeline; +// } break; +// case GGML_TYPE_IQ4_K: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_K_F32].pipeline; +// } break; +// case GGML_TYPE_IQ5_K: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ5_K_F32].pipeline; +// } break; +// case GGML_TYPE_IQ6_K: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ6_K_F32].pipeline; +// } break; +// default: +// { +// GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t); +// GGML_ABORT("not implemented"); +// } +// }; +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; +// [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; +// [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; +// [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; +// [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; +// [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; +// [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; +// [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; +// [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; +// [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11]; +// [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12]; +// [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13]; +// [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; +// [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; +// [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; +// [encoder setBytes:&r2 length:sizeof(r2) atIndex:17]; +// [encoder setBytes:&r3 length:sizeof(r3) atIndex:18]; +// +// if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || +// src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || +// src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S|| +// src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_Q6_0) { +// [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +// } +// else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) { +// const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : src0t == GGML_TYPE_IQ3_K ? 32*sizeof(float) : 16*sizeof(float); +// [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; +// [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +// } +// else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { +// const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; +// [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; +// [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +// } +// else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { +// const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; +// [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; +// [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +// } +// else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K || +// src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS|| +// src0t == GGML_TYPE_IQ4_KSS) { +// const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float) : GGML_TYPE_IQ5_K ? 64*sizeof(float) : 32*sizeof(float); +// [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; +// [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +// } +// else if (src0t == GGML_TYPE_Q4_K) { +// [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +// } +// else if (src0t == GGML_TYPE_Q3_K) { +// [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +// } +// else if (src0t == GGML_TYPE_Q5_K) { +// [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +// } +// else if (src0t == GGML_TYPE_Q6_K) { +// [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +// } else { +// const int64_t ny = (ne11 + nrows - 1)/nrows; +// [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +// } +// } +// } break; +// case GGML_OP_MUL_MAT_ID: +// { +// const int n_as = src0->ne[2]; +// +// // src2 = ids +// const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); +// +// GGML_ASSERT(src2t == GGML_TYPE_I32); +// +// GGML_ASSERT(!ggml_is_transposed(src0)); +// GGML_ASSERT(!ggml_is_transposed(src1)); +// +// GGML_ASSERT(src1t == GGML_TYPE_F32); +// +// // find the break-even point where the matrix-matrix kernel becomes more efficient compared +// // to the matrix-vector kernel +// // ne20 = n_used_experts +// // ne21 = n_rows +// const int dst_rows = ne20*ne21; +// const int dst_rows_min = n_as; +// const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength/2 - 8192)/4; +// +// // max size of the rowids array in the kernel shared buffer +// //GGML_ASSERT(dst_rows <= dst_rows_max); +// +// // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs +// // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel +// // !!! +// // TODO: for now, always use mat-vec kernels until we figure out how to improve the +// // indirect matrix multiplication +// // !!! +// //printf("%s: ne00 = %d, dst_rows = %d, dst_rows_min = %d max = %d, ne20 = %d ne21 = %d\n", src0->name, (int)ne00, (int)dst_rows, (int)dst_rows_min, (int)dst_rows_max, (int)ne20, (int)ne21); +// if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && +// ne00 % 32 == 0 && ne00 >= 64 && +// dst_rows > dst_rows_min && +// dst_rows <= dst_rows_max) { +// +// // some Metal matrix data types require aligned pointers +// // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) +// switch (src0->type) { +// case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; +// case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; +// default: break; +// } +// +// id pipeline = nil; +// +// switch (src0->type) { +// case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; +// case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; +// case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break; +// case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; +// case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; +// case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; +// case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break; +// case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_0_F32 ].pipeline; break; +// case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break; +// case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break; +// case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break; +// case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break; +// case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break; +// case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break; +// case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break; +// case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; +// case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break; +// case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break; +// case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break; +// case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; +// case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break; +// case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32 ].pipeline; break; +// case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32 ].pipeline; break; +// case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; +// case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; +// case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32 ].pipeline; break; +// case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KSS_F32].pipeline; break; +// case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32 ].pipeline; break; +// case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KS_F32 ].pipeline; break; +// case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_K_F32 ].pipeline; break; +// case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32 ].pipeline; break; +// case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_K_F32 ].pipeline; break; +// case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ6_K_F32 ].pipeline; break; +// default: GGML_ABORT("MUL_MAT_ID not implemented"); +// } +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; +// [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; +// [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; +// [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; +// [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; +// [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; +// [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8]; +// [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; +// [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; +// [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; +// [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; +// [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; +// [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; +// [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; +// [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; +// [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; +// [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; +// [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; +// +// [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; +// +// [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; +// } else { +// //printf("Using gemv for mul_mat_id\n"); +// int nth0 = 32; +// int nth1 = 1; +// int nrows = 1; +// //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); +// +// id pipeline = nil; +// +// // use custom matrix x vector kernel +// switch (src0t) { +// case GGML_TYPE_F32: +// { +// GGML_ASSERT(src1t == GGML_TYPE_F32); +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; +// } break; +// case GGML_TYPE_F16: +// { +// GGML_ASSERT(src1t == GGML_TYPE_F32); +// nth0 = 32; +// nth1 = 1; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; +// } break; +// case GGML_TYPE_BF16: +// { +// GGML_ASSERT(src1t == GGML_TYPE_F32); +// nth0 = 32; +// nth1 = 1; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline; +// } break; +// case GGML_TYPE_Q4_0: +// { +// nth0 = 8; +// nth1 = 8; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; +// } break; +// case GGML_TYPE_Q4_1: +// { +// nth0 = 8; +// nth1 = 8; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; +// } break; +// case GGML_TYPE_Q5_0: +// { +// nth0 = 8; +// nth1 = 8; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; +// } break; +// case GGML_TYPE_Q5_1: +// { +// nth0 = 8; +// nth1 = 8; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; +// } break; +// case GGML_TYPE_Q6_0: +// { +// nth0 = 8; +// nth1 = 8; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_0_F32].pipeline; +// } break; +// case GGML_TYPE_Q8_0: +// { +// nth0 = 32; +// nth1 = 2; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; +// } break; +// case GGML_TYPE_Q2_K: +// { +// nth0 = 2; +// nth1 = 32; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; +// } break; +// case GGML_TYPE_Q3_K: +// { +// nth0 = 2; +// nth1 = 32; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; +// } break; +// case GGML_TYPE_Q4_K: +// { +// nth0 = 4; //1; +// nth1 = 8; //32; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; +// } break; +// case GGML_TYPE_Q5_K: +// { +// nth0 = 2; +// nth1 = 32; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; +// } break; +// case GGML_TYPE_Q6_K: +// { +// nth0 = 2; +// nth1 = 32; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; +// } break; +// case GGML_TYPE_IQ2_XXS: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; +// } break; +// case GGML_TYPE_IQ2_XS: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; +// } break; +// case GGML_TYPE_IQ3_XXS: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; +// } break; +// case GGML_TYPE_IQ3_S: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline; +// } break; +// case GGML_TYPE_IQ2_S: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline; +// } break; +// case GGML_TYPE_IQ1_S: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; +// } break; +// case GGML_TYPE_IQ1_M: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; +// } break; +// case GGML_TYPE_IQ1_BN: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32].pipeline; +// } break; +// case GGML_TYPE_IQ2_BN: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32].pipeline; +// } break; +// case GGML_TYPE_IQ4_NL: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; +// } break; +// case GGML_TYPE_IQ4_XS: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; +// } break; +// case GGML_TYPE_IQ4_KS: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32].pipeline; +// } break; +// case GGML_TYPE_IQ4_KSS: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KSS_F32].pipeline; +// } break; +// case GGML_TYPE_IQ2_K: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_K_F32].pipeline; +// } break; +// case GGML_TYPE_IQ2_KS: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_KS_F32].pipeline; +// } break; +// case GGML_TYPE_IQ3_K: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_K_F32].pipeline; +// } break; +// case GGML_TYPE_IQ4_K: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_K_F32].pipeline; +// } break; +// case GGML_TYPE_IQ5_K: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_K_F32].pipeline; +// } break; +// case GGML_TYPE_IQ6_K: +// { +// nth0 = 4; +// nth1 = 16; +// pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ6_K_F32].pipeline; +// } break; +// default: +// { +// GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); +// GGML_ABORT("not implemented"); +// } +// }; +// +// if (ggml_is_quantized(src0t)) { +// GGML_ASSERT(ne00 >= nth0*nth1); +// } +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; +// [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; +// [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; +// [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; +// [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; +// [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; +// [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8]; +// [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9]; +// [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10]; +// [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11]; +// [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12]; +// [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13]; +// [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14]; +// [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15]; +// [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16]; +// [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17]; +// [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; +// [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; +// [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20]; +// [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21]; +// [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22]; +// +// const int64_t _ne1 = 1; +// const int tgz = dst_rows; +// +// if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || +// src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || +// src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_Q6_0 || +// src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K) { +// [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +// } +// else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) { +// const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : src0t == GGML_TYPE_IQ3_K ? 32*sizeof(float) : 16*sizeof(float); +// [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; +// [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +// } +// else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { +// const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; +// [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; +// [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +// } +// else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { +// const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; +// [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; +// [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +// } +// else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K || +// src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS|| +// src0t == GGML_TYPE_IQ4_KSS) { +// const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float) : GGML_TYPE_IQ5_K ? 64*sizeof(float) : 32*sizeof(float); +// [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; +// [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +// } +// else if (src0t == GGML_TYPE_Q4_K) { +// [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +// } +// else if (src0t == GGML_TYPE_Q3_K) { +// [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +// } +// else if (src0t == GGML_TYPE_Q5_K) { +// [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +// } +// else if (src0t == GGML_TYPE_Q6_K) { +// [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +// } else { +// const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1 +// [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +// } +// } +// } break; +// case GGML_OP_GET_ROWS: +// { +// id pipeline = nil; +// +// switch (src0->type) { +// case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break; +// case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break; +// case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break; +// case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break; +// case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break; +// case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break; +// case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_0 ].pipeline; break; +// case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break; +// case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break; +// case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break; +// case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break; +// case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break; +// case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break; +// case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break; +// case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; +// case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break; +// case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break; +// case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break; +// case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; +// case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break; +// case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN ].pipeline; break; +// case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN ].pipeline; break; +// case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; +// case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; +// case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS ].pipeline; break; +// case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KSS].pipeline; break; +// case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K ].pipeline; break; +// case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KS ].pipeline; break; +// case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_K ].pipeline; break; +// case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K ].pipeline; break; +// case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_K ].pipeline; break; +// case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ6_K ].pipeline; break; +// case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; +// default: GGML_ABORT("not implemented"); +// } +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; +// [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; +// [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; +// [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5]; +// [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6]; +// [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7]; +// [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8]; +// [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9]; +// [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10]; +// +// [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; +// } break; +// case GGML_OP_RMS_NORM: +// { +// GGML_ASSERT(ne00 % 4 == 0); +// GGML_ASSERT(ggml_is_contiguous_1(src0)); +// +// float eps; +// memcpy(&eps, dst->op_params, sizeof(float)); +// +// int nth = 32; // SIMD width +// +// while (nth < ne00/4 && nth < 1024) { +// nth *= 2; +// } +// +// id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; +// [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; +// [encoder setBytes:&eps length:sizeof( float) atIndex:4]; +// [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; +// +// const int64_t nrows = ggml_nrows(src0); +// +// [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; +// } break; +// case GGML_OP_FUSED_RMS_NORM: +// { +// GGML_ASSERT(ne00 % 4 == 0); +// GGML_ASSERT(ggml_is_contiguous_1(src0)); +// GGML_ASSERT(src1->ne[0] == src0->ne[0]); +// GGML_ASSERT(src1->type == GGML_TYPE_F32); +// GGML_ASSERT(ggml_nrows(src1) == 1); +// +// float eps; +// memcpy(&eps, dst->op_params, sizeof(float)); +// +// int nth = 32; // SIMD width +// +// while (nth < ne00/4 && nth < 1024) { +// nth *= 2; +// } +// +// id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM].pipeline; +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; +// [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; +// [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; +// [encoder setBytes:&eps length:sizeof( float) atIndex:5]; +// [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; +// +// const int64_t nrows = ggml_nrows(src0); +// +// [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; +// } break; +// case GGML_OP_GROUP_NORM: +// { +// GGML_ASSERT(ne00 % 4 == 0); +// GGML_ASSERT(ggml_is_contiguous(src0)); +// +// float eps; +// memcpy(&eps, dst->op_params + 1, sizeof(float)); +// +// const int32_t n_groups = ((int32_t *) dst->op_params)[0]; +// +// int nth = 32; // SIMD width +// +// //while (nth < ne00/4 && nth < 1024) { +// // nth *= 2; +// //} +// +// id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; +// [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; +// [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; +// [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5]; +// [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6]; +// [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7]; +// [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8]; +// [encoder setBytes:&eps length:sizeof( float) atIndex:9]; +// [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; +// +// [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; +// } break; +// case GGML_OP_NORM: +// { +// GGML_ASSERT(ggml_is_contiguous_1(src0)); +// +// float eps; +// memcpy(&eps, dst->op_params, sizeof(float)); +// +// const int nth = MIN(256, ne00); +// +// id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline; +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; +// [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; +// [encoder setBytes:&eps length:sizeof( float) atIndex:4]; +// [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0]; +// +// const int64_t nrows = ggml_nrows(src0); +// +// [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; +// } break; +// case GGML_OP_ROPE: +// { +// GGML_ASSERT(ne10 == ne02); +// +// const int nth = MIN(1024, ne00); +// +// const int n_past = ((int32_t *) dst->op_params)[0]; +// const int n_dims = ((int32_t *) dst->op_params)[1]; +// const int mode = ((int32_t *) dst->op_params)[2]; +// // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal +// const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; +// +// float freq_base; +// float freq_scale; +// float ext_factor; +// float attn_factor; +// float beta_fast; +// float beta_slow; +// +// memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); +// memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); +// memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); +// memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); +// memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); +// memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); +// +// const bool is_neox = mode & 2; +// +// id pipeline = nil; +// +// if (!is_neox) { +// switch (src0->type) { +// case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; +// case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; +// default: GGML_ABORT("fatal error"); +// }; +// } else { +// switch (src0->type) { +// case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; +// case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; +// default: GGML_ABORT("fatal error"); +// }; +// } +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; +// if (id_src2 != nil) { +// [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; +// } else { +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; +// } +// [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; +// [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; +// [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; +// [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; +// [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; +// [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8]; +// [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9]; +// [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10]; +// [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11]; +// [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12]; +// [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13]; +// [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14]; +// [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15]; +// [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16]; +// [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17]; +// [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18]; +// [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; +// [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; +// [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; +// [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22]; +// [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; +// [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; +// [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; +// [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; +// [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; +// [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; +// +// [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; +// } break; +// case GGML_OP_IM2COL: +// { +// GGML_ASSERT(src0->type == GGML_TYPE_F16); +// GGML_ASSERT(src1->type == GGML_TYPE_F32); +// GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); +// +// const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; +// const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; +// const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; +// const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; +// const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; +// const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; +// +// const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; +// +// const int32_t N = src1->ne[is_2D ? 3 : 2]; +// const int32_t IC = src1->ne[is_2D ? 2 : 1]; +// const int32_t IH = is_2D ? src1->ne[1] : 1; +// const int32_t IW = src1->ne[0]; +// +// const int32_t KH = is_2D ? src0->ne[1] : 1; +// const int32_t KW = src0->ne[0]; +// +// const int32_t OH = is_2D ? dst->ne[2] : 1; +// const int32_t OW = dst->ne[1]; +// +// const int32_t CHW = IC * KH * KW; +// +// const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; +// const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; +// +// id pipeline = nil; +// +// switch (dst->type) { +// case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break; +// case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break; +// default: GGML_ABORT("fatal error"); +// }; +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2]; +// [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3]; +// [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4]; +// [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5]; +// [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6]; +// [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7]; +// [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8]; +// [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9]; +// [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10]; +// [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11]; +// [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12]; +// +// [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; +// } break; +// case GGML_OP_UPSCALE: +// { +// GGML_ASSERT(src0->type == GGML_TYPE_F32); +// +// const float sf0 = (float)ne0/src0->ne[0]; +// const float sf1 = (float)ne1/src0->ne[1]; +// const float sf2 = (float)ne2/src0->ne[2]; +// const float sf3 = (float)ne3/src0->ne[3]; +// +// const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; +// [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; +// [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; +// [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; +// [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; +// [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; +// [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; +// [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; +// [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; +// [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; +// [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; +// [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; +// [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; +// [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; +// [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; +// [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; +// [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18]; +// [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19]; +// [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20]; +// [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21]; +// +// const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); +// +// [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; +// } break; +// case GGML_OP_PAD: +// { +// GGML_ASSERT(src0->type == GGML_TYPE_F32); +// +// id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; +// [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; +// [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; +// [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; +// [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; +// [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; +// [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; +// [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; +// [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; +// [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; +// [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; +// [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; +// [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; +// [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; +// [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; +// [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; +// +// const int nth = MIN(1024, ne0); +// +// [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; +// } break; +// case GGML_OP_ARANGE: +// { +// GGML_ASSERT(dst->type == GGML_TYPE_F32); +// +// float start; +// float step; +// +// memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float)); +// memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float)); +// +// id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; +// [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1]; +// [encoder setBytes:&start length:sizeof(start) atIndex:2]; +// [encoder setBytes:&step length:sizeof(step) atIndex:3]; +// +// const int nth = MIN(1024, ne0); +// +// [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; +// } break; +// case GGML_OP_TIMESTEP_EMBEDDING: +// { +// GGML_ASSERT(src0->type == GGML_TYPE_F32); +// +// const int dim = dst->op_params[0]; +// const int max_period = dst->op_params[1]; +// +// const int half = dim / 2; +// +// id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2]; +// [encoder setBytes:&dim length:sizeof(dim) atIndex:3]; +// [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4]; +// +// const int nth = MIN(1024, half); +// +// [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; +// } break; +// case GGML_OP_ARGSORT: +// { +// GGML_ASSERT(src0->type == GGML_TYPE_F32); +// GGML_ASSERT( dst->type == GGML_TYPE_I32); +// +// const int nrows = ggml_nrows(src0); +// +// enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; +// +// // bitonic sort requires the number of elements to be power of 2 +// int64_t ne00_padded = 1; +// while (ne00_padded < ne00) { +// ne00_padded *= 2; +// } +// +// // Metal kernels require the buffer size to be multiple of 16 bytes +// // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength +// const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16); +// +// id pipeline = nil; +// +// switch (order) { +// case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break; +// case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break; +// default: GGML_ABORT("fatal error"); +// }; +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; +// [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3]; +// [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; +// +// [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)]; +// } break; +// case GGML_OP_LEAKY_RELU: +// { +// GGML_ASSERT(src0->type == GGML_TYPE_F32); +// +// float slope; +// memcpy(&slope, dst->op_params, sizeof(float)); +// +// id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// [encoder setBytes:&slope length:sizeof(slope) atIndex:2]; +// +// const int64_t n = ggml_nelements(dst); +// +// [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; +// } break; +// case GGML_OP_FLASH_ATTN_EXT: +// { +// GGML_ASSERT(ne00 % 4 == 0); +// GGML_ASSERT(ne11 % 32 == 0); +// +// GGML_ASSERT(src0->type == GGML_TYPE_F32); +// +// GGML_ASSERT(ggml_are_same_shape (src1, src2)); +// +// struct ggml_tensor * src3 = gf->nodes[i]->src[3]; +// +// size_t offs_src3 = 0; +// +// id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; +// +// GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); +// GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && +// "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); +// +// const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); +// //const int64_t ne31 = src3 ? src3->ne[1] : 0; +// const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); +// const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); +// +// const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); +// const uint64_t nb31 = src3 ? src3->nb[1] : 0; +// const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); +// const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); +// +// const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); +// +// float scale; +// float max_bias; +// float softcap; +// +// memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); +// memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); +// memcpy(&softcap, ((int32_t *) dst->op_params) + 2, sizeof(softcap)); +// if (softcap != 0.0f) { +// scale /= softcap; +// } +// +// const uint32_t n_head = src0->ne[2]; +// const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); +// +// const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); +// const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); +// +// id pipeline = nil; +// +// bool use_vec_kernel = false; +// +// if (ne01 >= 4 || (ne00%128 != 0)) { +// switch (ne00) { +// case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; +// case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; +// case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; +// case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; +// case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; +// //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; +// default: +// { +// GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); +// GGML_METAL_LOG_ERROR("add template specialization for this size\n"); +// GGML_ABORT("add template specialization for this size"); +// } +// } +// } else { +// use_vec_kernel = true; +// +// switch (ne00) { +// case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; +// //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; +// default: +// { +// GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); +// GGML_METAL_LOG_ERROR("add template specialization for this size\n"); +// GGML_ABORT("add template specialization for this size"); +// } +// } +// } +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; +// [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; +// if (id_src3) { +// [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; +// } else { +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; +// } +// [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; +// [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; +// [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; +// [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; +// [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; +// [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; +// [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; +// [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; +// [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; +// [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; +// [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; +// [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; +// [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; +// [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17]; +// [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18]; +// [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19]; +// [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20]; +// [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21]; +// [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22]; +// [encoder setBytes:&scale length:sizeof( float) atIndex:23]; +// [encoder setBytes:&max_bias length:sizeof( float) atIndex:24]; +// [encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; +// [encoder setBytes:&m1 length:sizeof(m1) atIndex:26]; +// [encoder setBytes:&softcap length:sizeof(softcap) atIndex:27]; +// [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:28]; +// +// if (!use_vec_kernel) { +// // half8x8 kernel +// const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! +// const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! +// +// GGML_ASSERT(nqptg <= 32); +// GGML_ASSERT(nqptg % 8 == 0); +// GGML_ASSERT(ncpsg % 32 == 0); +// +// int64_t nsgmax = 2; +// +// while (true) { +// const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2); +// if (smem > ctx->device.maxThreadgroupMemoryLength) { +// break; +// } +// nsgmax *= 2; +// } +// nsgmax /= 2; +// +// // simdgroups per threadgroup (a.k.a. warps) +// const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; +// +// const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); +// +// //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); +// GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); +// +// [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; +// +// [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; +// } else { +// // half1x4 kernel +// const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! +// const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! +// +// GGML_ASSERT(nqptg <= 32); +// GGML_ASSERT(nqptg % 1 == 0); +// GGML_ASSERT(ncpsg % 32 == 0); +// +// // simdgroups per threadgroup (a.k.a. warps) +// const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); +// +// int64_t nsg = 1; +// while (nsg <= nsgt) { +// nsg *= 2; +// } +// nsg /= 2; +// +// const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); +// +// //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); +// GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); +// [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; +// +// [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; +// } +// } break; +// case GGML_OP_DUP: +// case GGML_OP_CPY: +// case GGML_OP_CONT: +// { +// GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); +// +// int nth = MIN(1024, ne00/ggml_blck_size(src0->type)); +// +// id pipeline = nil; +// +// switch (src0t) { +// case GGML_TYPE_F32: +// { +// GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0); +// +// switch (dstt) { +// case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; +// case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; +// case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; +// case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; +// case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; +// case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break; +// case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break; +// case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0].pipeline; break; +// case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break; +// default: GGML_ABORT("not implemented"); +// }; +// } break; +// case GGML_TYPE_F16: +// { +// switch (dstt) { +// case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; +// case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; +// default: GGML_ABORT("not implemented"); +// }; +// } break; +// default: GGML_ABORT("not implemented"); +// } +// +// [encoder setComputePipelineState:pipeline]; +// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; +// [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; +// [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; +// [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; +// [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; +// [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; +// [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; +// [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; +// [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; +// [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; +// [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; +// [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; +// [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; +// [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; +// [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; +// [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; +// [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; +// [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; +// +// [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; +// } break; +// default: +// { +// GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); +// GGML_ABORT("fatal error"); +// } +// } +// +// if (should_capture) { +// [encoder popDebugGroup]; +// } +// } +// +// [encoder endEncoding]; +// +// if (cb_idx < 2 || ctx->abort_callback == NULL) { +// [command_buffer commit]; +// } +// }); +// +// // Wait for completion and check status of each command buffer +// // needed to detect if the device ran out-of-memory for example (#1881) +// +// for (int i = 0; i < n_cb; ++i) { +// id command_buffer = command_buffers[i]; +// [command_buffer waitUntilCompleted]; +// +// MTLCommandBufferStatus status = [command_buffer status]; +// if (status != MTLCommandBufferStatusCompleted) { +// GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); +// if (status == MTLCommandBufferStatusError) { +// NSString * error_code = [command_buffer error].localizedDescription; +// GGML_METAL_LOG_INFO("error: %s\n", [error_code UTF8String]); +// } +// +// return GGML_STATUS_FAILED; +// } +// +// id next_buffer = (i + 1 < n_cb ? command_buffers[i + 1] : nil); +// if (!next_buffer) { +// continue; +// } +// +// bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); +// if (next_queued) { +// continue; +// } +// +// if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { +// GGML_METAL_LOG_INFO("%s: command buffer %d aborted", __func__, i); +// return GGML_STATUS_ABORTED; +// } +// +// [next_buffer commit]; +// } +// +// if (should_capture) { +// [[MTLCaptureManager sharedCaptureManager] stopCapture]; +// } +// +// } +// return GGML_STATUS_SUCCESS; +//} + //////////////////////////////////////////////////////////////////////////////// // backend interface @@ -3905,8 +6455,60 @@ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { GGML_ASSERT(ggml_backend_is_metal(backend)); struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; + if (ctx->n_cb != n_cb) { + ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS); + + if (ctx->n_cb > 2) { + GGML_METAL_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb); + } + } + + if (ctx->encode_async) { + Block_release(ctx->encode_async); + } + + ctx->encode_async = Block_copy(^(size_t iter) { + const int cb_idx = iter; + const int n_cb_l = ctx->n_cb; + + const int n_nodes_0 = ctx->n_nodes_0; + const int n_nodes_1 = ctx->n_nodes_1; + + const int n_nodes_per_cb = ctx->n_nodes_per_cb; + + id command_buffer = ctx->command_buffers[cb_idx]; + id encoder = [command_buffer computeCommandEncoder]; + + int node_start = 0; + int node_end = n_nodes_0; + + if (cb_idx < n_cb_l) { + node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb); + node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1)); + } + + const bool should_capture = ctx->capture_next_compute; + + for (int idx = node_start; idx < node_end; ++idx) { + struct ggml_tensor * node = ctx->gf->nodes[idx]; + if (should_capture) { + [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(node) encoding:NSUTF8StringEncoding]]; + } + + ggml_metal_encode_node(ctx, node, encoder); + + if (should_capture) { + [encoder popDebugGroup]; + } + } + + [encoder endEncoding]; + + if (cb_idx < 2 || ctx->abort_callback == NULL) { + [command_buffer commit]; + } + }); - ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS); } void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) { diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 89cd412a..4f7365ad 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1652,13 +1652,18 @@ void kernel_mul_mv_q8_0_f32_impl( yl[i] = yb[i]; } + device const block_q8_0 * xr = x + ib; + for (int row = 0; row < nr; row++) { - device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il; + //device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il; + device const int8_t * qs = xr->qs + NB_Q8_0*il; float sumq = 0.f; for (int iq = 0; iq < NB_Q8_0; ++iq) { sumq += qs[iq] * yl[iq]; } - sumf[row] += sumq*x[ib+row*nb].d; + //sumf[row] += sumq*x[ib+row*nb].d; + sumf[row] += sumq*xr->d; + xr += nb; } yb += NB_Q8_0 * nw; @@ -5746,7 +5751,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( uint tiisg, uint sgitg) { - threadgroup float * shared_values = (threadgroup float *)shared_values_i8; + //threadgroup float * shared_values = (threadgroup float *)shared_values_i8; const int nb = ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; @@ -5766,8 +5771,8 @@ void kernel_mul_mv_iq4_xs_f32_impl( const int ib = it/2; const int il = it%2; - shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; - threadgroup_barrier(mem_flags::mem_threadgroup); + //shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + //threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; float sumf[2]={0.f}, all_sum; @@ -5793,15 +5798,19 @@ void kernel_mul_mv_iq4_xs_f32_impl( aux32[0] = q4[0] & 0x0f0f0f0f; aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + //qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + //qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + qf1 = {kvalues_iq4nl_f[q8[0]], kvalues_iq4nl_f[q8[1]], kvalues_iq4nl_f[q8[2]], kvalues_iq4nl_f[q8[3]]}; + qf2 = {kvalues_iq4nl_f[q8[4]], kvalues_iq4nl_f[q8[5]], kvalues_iq4nl_f[q8[6]], kvalues_iq4nl_f[q8[7]]}; acc1 += yl[0] * qf1; acc2 += yl[1] * qf2; aux32[0] = q4[1] & 0x0f0f0f0f; aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + //qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + //qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + qf1 = {kvalues_iq4nl_f[q8[0]], kvalues_iq4nl_f[q8[1]], kvalues_iq4nl_f[q8[2]], kvalues_iq4nl_f[q8[3]]}; + qf2 = {kvalues_iq4nl_f[q8[4]], kvalues_iq4nl_f[q8[5]], kvalues_iq4nl_f[q8[6]], kvalues_iq4nl_f[q8[7]]}; acc1 += yl[2] * qf1; acc2 += yl[3] * qf2; @@ -7218,13 +7227,43 @@ void dequantize_q6_0(device const block_q6_0 *xb, short il, thread type4x4 & reg template void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { device const int8_t * qs = ((device const int8_t *)xb->qs); - const half d = xb->d; - - for (int i = 0; i < 16; i++) { - reg[i/4][i%4] = (qs[i + 16*il] * d); + if constexpr (is_same_v) { + const half d = xb->d; + for (int i = 0; i < 16; i++) { + reg[i/4][i%4] = (half)qs[i + 16*il] * d; + } + } else { + const float d = xb->d; + for (int i = 0; i < 16; i++) { + reg[i/4][i%4] = qs[i + 16*il] * d; + } } } +//template +//void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { +// device const int8_t * qs = ((device const int8_t *)xb->qs); +// const half d = xb->d; +// +// for (int i = 0; i < 16; i++) { +// reg[i/4][i%4] = (qs[i + 16*il] * d); +// } +//} + +//template +//void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { +// device const int8_t * qs = ((device const int8_t *)xb->qs); +// const float d = xb->d; +// +// float4x4 reg_f; +// +// for (int i = 0; i < 16; i++) { +// reg_f[i/4][i%4] = d * qs[i + 16*il]; +// } +// +// reg = (type4x4)reg_f; +//} + template void dequantize_q2_K(device const block_q2_K * xb, short il, thread type4x4 & reg) { const float d = xb->d; @@ -8246,39 +8285,6 @@ kernel void kernel_mul_mm_id( uint ntg = ntg3.x * ntg3.y * ntg3.z; uint n = nei0*nei1; - //uint npt = (n + ntg - 1) / ntg; - //uint first = tiitg * npt; - //uint last = first + npt <= n ? first + npt : n; - - //uint nhave = 0; - //for (uint i = first; i < last; ++i) { - // uint ii0 = i % nei0; - // uint ii1 = i / nei0; - // int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; - // if (id == i02) ++nhave; - //} - //threadgroup uint * nums = (threadgroup uint *)shared_memory; - //nums[tiitg] = nhave; - //threadgroup_barrier(mem_flags::mem_threadgroup); - - //uint nprev = 0; - //for (uint i = 0; i < tiitg; ++i) nprev += nums[i]; - //int64_t _ne1 = nprev; - //for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i]; - - //threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); - //for (uint i = first; i < last; ++i) { - // uint ii0 = i % nei0; - // uint ii1 = i / nei0; - // int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; - // if (id == i02) rowids[nprev++] = ushort2(ii0, ii1); - //} - - //threadgroup_barrier(mem_flags::mem_threadgroup); - - // - // The following is slightly faster than the commented out version above - // uint nhave = 0; for (uint i = tiitg; i < n; i += ntg) { uint ii0 = i % nei0; @@ -8290,10 +8296,31 @@ kernel void kernel_mul_mm_id( nums[tiitg] = nhave; threadgroup_barrier(mem_flags::mem_threadgroup); - uint nprev = 0; - for (uint i = 0; i < tiitg; ++i) nprev += nums[i]; - int64_t _ne1 = nprev; - for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i]; + uint stride = 1; + while (stride <= ntg/2) { + uint index = (tiitg+1)*stride*2 - 1; // index - stride = 2*tiitg*stride + stride - 1; + if (index < ntg) nums[index] += nums[index-stride]; + stride <<= 1; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + stride = ntg/2; + while (stride > 0) { + uint index = (tiitg+1)*stride*2 - 1; + if (index+stride < ntg) nums[index+stride] += nums[index]; + stride >>= 1; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + int64_t _ne1 = nums[ntg-1]; + if (!_ne1) return; + + uint nprev = tiitg > 0 ? nums[tiitg-1] : 0; + + //uint ncum = 0; + //for (uint i = 0; i < tiitg; ++i) ncum += nums[i]; + //uint nprev = ncum; + //for (uint i = tiitg; i < ntg; ++i) ncum += nums[i]; + //if (!ncum) return; + //int64_t _ne1 = ncum; threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); for (uint i = tiitg; i < n; i += ntg) { @@ -8304,26 +8331,6 @@ kernel void kernel_mul_mm_id( } threadgroup_barrier(mem_flags::mem_threadgroup); - // This is the original version that is ridiculously slow. - //// row indices - //threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); - - //// TODO: parallelize this loop - //int64_t _ne1 = 0; - //for (ushort ii1 = 0; ii1 < nei1; ii1++) { - // for (ushort ii0 = 0; ii0 < nei0; ii0++) { - // int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; - // if (id == i02) { - // //if (tiitg == 0) { - // rowids[_ne1] = ushort2(ii0, ii1); - // //} - // _ne1++; - // } - // } - //} - - //threadgroup_barrier(mem_flags::mem_threadgroup); - kernel_mul_mm_id_impl( src0, src1,