Metal SwiGLU: pass a 32-bit per-row element count and row stride to the kernels.

The host code now computes n_per_row and stride as uint32_t and, when ne0 is a
multiple of 4, divides both by 4 before binding them. kernel_swiglu_4 therefore
receives values already expressed in float4 units and no longer divides by 4
itself. Both kernels take these two arguments as `constant uint &` instead of
`constant int64_t &`.

NOTE(review): leading indentation was reconstructed from a whitespace-collapsed
copy of this patch. If context lines fail to match the target files, apply with
`git apply --ignore-whitespace`.

diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 20ab3a9d..774314df 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -1609,19 +1609,22 @@ static enum ggml_status ggml_metal_graph_compute(
 
                     id pipeline = nil;
 
+                    uint32_t n_per_row = ne0;
+                    uint32_t stride = src0->nb[1]/sizeof(float);
+
                     if (ne0 % 4 == 0) {
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_4].pipeline;
                         n /= 4;
+                        n_per_row /= 4;
+                        stride /= 4;
                     } else {
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
                     }
 
-                    const int64_t stride = src0->nb[1]/sizeof(float);
-
                     [encoder setComputePipelineState:pipeline];
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-                    [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:2];
+                    [encoder setBytes:&n_per_row length:sizeof(n_per_row) atIndex:2];
                     [encoder setBytes:&stride length:sizeof(stride) atIndex:3];
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 93f4d6c5..c1e11047 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -401,25 +401,25 @@ kernel void kernel_silu_4(
 kernel void kernel_swiglu(
         device const float * src0,
         device float * dst,
-        constant int64_t & ne0,
-        constant int64_t & stride,
+        constant uint & ne0,
+        constant uint & stride,
         uint tpig[[thread_position_in_grid]]) {
-    const int64_t row = tpig/ne0;
-    const int64_t idx = tpig%ne0;
-    const int64_t j = row*stride + idx;
+    const uint row = tpig/ne0;
+    const uint idx = tpig%ne0;
+    const uint j = row*stride + idx;
     dst[tpig] = src0[j] * src0[j + ne0] / (1.0f + exp(-src0[j]));
 }
 
 kernel void kernel_swiglu_4(
         device const float4 * src0,
         device float4 * dst,
-        constant int64_t & ne0,
-        constant int64_t & stride,
+        constant uint & ne0,
+        constant uint & stride,
         uint tpig[[thread_position_in_grid]]) {
-    const int64_t row = tpig/(ne0/4);
-    const int64_t idx = tpig%(ne0/4);
-    const int64_t j = row*(stride/4) + idx;
-    dst[tpig] = src0[j] * src0[j + ne0/4] / (1.0f + exp(-src0[j]));
+    const uint row = tpig/ne0;
+    const uint idx = tpig%ne0;
+    const uint j = row*stride + idx;
+    dst[tpig] = src0[j] * src0[j + ne0] / (1.0f + exp(-src0[j]));
 }
 
 kernel void kernel_sqr(