mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-05 03:20:00 +00:00
GGML_UNARY_OP_SWIGLU: minor improvement on Metal
This commit is contained in:
@@ -1609,19 +1609,22 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
uint32_t n_per_row = ne0;
|
||||
uint32_t stride = src0->nb[1]/sizeof(float);
|
||||
|
||||
if (ne0 % 4 == 0) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_4].pipeline;
|
||||
n /= 4;
|
||||
n_per_row /= 4;
|
||||
stride /= 4;
|
||||
} else {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
|
||||
}
|
||||
|
||||
const int64_t stride = src0->nb[1]/sizeof(float);
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:2];
|
||||
[encoder setBytes:&n_per_row length:sizeof(n_per_row) atIndex:2];
|
||||
[encoder setBytes:&stride length:sizeof(stride) atIndex:3];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
|
||||
@@ -401,25 +401,25 @@ kernel void kernel_silu_4(
|
||||
kernel void kernel_swiglu(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & stride,
|
||||
constant uint & ne0,
|
||||
constant uint & stride,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const int64_t row = tpig/ne0;
|
||||
const int64_t idx = tpig%ne0;
|
||||
const int64_t j = row*stride + idx;
|
||||
const uint row = tpig/ne0;
|
||||
const uint idx = tpig%ne0;
|
||||
const uint j = row*stride + idx;
|
||||
dst[tpig] = src0[j] * src0[j + ne0] / (1.0f + exp(-src0[j]));
|
||||
}
|
||||
|
||||
kernel void kernel_swiglu_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & stride,
|
||||
constant uint & ne0,
|
||||
constant uint & stride,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const int64_t row = tpig/(ne0/4);
|
||||
const int64_t idx = tpig%(ne0/4);
|
||||
const int64_t j = row*(stride/4) + idx;
|
||||
dst[tpig] = src0[j] * src0[j + ne0/4] / (1.0f + exp(-src0[j]));
|
||||
const uint row = tpig/ne0;
|
||||
const uint idx = tpig%ne0;
|
||||
const uint j = row*stride + idx;
|
||||
dst[tpig] = src0[j] * src0[j + ne0] / (1.0f + exp(-src0[j]));
|
||||
}
|
||||
|
||||
kernel void kernel_sqr(
|
||||
|
||||
Reference in New Issue
Block a user