mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 23:54:10 +00:00
WIP KQ binary mask: Metal soft_max
I need to redo this with better templates.
This commit is contained in:
@@ -67,6 +67,8 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32_4,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32,
|
||||
@@ -578,6 +580,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32, soft_max_u32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32_4, soft_max_u32_4, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16, soft_cap_max_f16, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4, soft_cap_max_f16_4, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32, soft_cap_max_f32, ctx->support_simdgroup_reduction);
|
||||
@@ -1633,19 +1637,22 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
} break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
{
|
||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_I32);
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
||||
const bool use_u32 = (src1 && src1->type == GGML_TYPE_I32);
|
||||
|
||||
if (ne00%4 == 0) {
|
||||
while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
|
||||
nth *= 2;
|
||||
}
|
||||
if (use_f16) {
|
||||
if (use_u32) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32_4].pipeline;
|
||||
} else if (use_f16) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
|
||||
} else {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
|
||||
@@ -1654,7 +1661,9 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
|
||||
nth *= 2;
|
||||
}
|
||||
if (use_f16) {
|
||||
if (use_u32) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32].pipeline;
|
||||
} else if (use_f16) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
|
||||
} else {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
|
||||
|
||||
@@ -453,6 +453,198 @@ kernel void kernel_sum_rows(
|
||||
dst_row[0] = row_sum;
|
||||
}
|
||||
|
||||
kernel void kernel_soft_max_u32(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant float & scale,
|
||||
constant float & max_bias,
|
||||
constant float & m0,
|
||||
constant float & m1,
|
||||
constant uint32_t & n_head_log2,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = (tgpig) / (ne02*ne01);
|
||||
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
||||
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
device const uint32_t * pmask = (device const uint32_t *) src1 + i01*ne00/32;
|
||||
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
|
||||
// parallel max
|
||||
float lmax = -INFINITY;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
pdst[i00] = pmask[i00 >> 5] & (1u << (i00 & 31)) ? psrc0[i00]*scale : -INFINITY;
|
||||
lmax = MAX(lmax, pdst[i00]);
|
||||
}
|
||||
|
||||
// find the max value in the block
|
||||
float max_val = simd_max(lmax);
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = -INFINITY;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = max_val;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
max_val = buf[tiisg];
|
||||
max_val = simd_max(max_val);
|
||||
}
|
||||
|
||||
// parallel sum
|
||||
float lsum = 0.0f;
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
const float exp_psrc0 = exp(pdst[i00] - max_val);
|
||||
lsum += exp_psrc0;
|
||||
pdst[i00] = exp_psrc0;
|
||||
}
|
||||
|
||||
// This barrier fixes a failing test
|
||||
// ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
float sum = simd_sum(lsum);
|
||||
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = sum;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sum = buf[tiisg];
|
||||
sum = simd_sum(sum);
|
||||
}
|
||||
|
||||
const float inv_sum = 1.0f/sum;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
pdst[i00] *= inv_sum;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_soft_max_u32_4(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant float & scale,
|
||||
constant float & max_bias,
|
||||
constant float & m0,
|
||||
constant float & m1,
|
||||
constant uint32_t & n_head_log2,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = (tgpig) / (ne02*ne01);
|
||||
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
||||
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
||||
device const uint32_t * pmask = (device const uint32_t *) src1 + i01*ne00/32;
|
||||
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
||||
|
||||
// parallel max
|
||||
float4 lmax4 = -INFINITY;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
int idx = 4*i00;
|
||||
uint8_t m4 = pmask[idx >> 5] >> (idx & 31);
|
||||
float4 val = psrc4[i00]*scale;
|
||||
val[0] = m4 & 1 ? val[0] : -INFINITY;
|
||||
val[1] = m4 & 2 ? val[1] : -INFINITY;
|
||||
val[2] = m4 & 4 ? val[2] : -INFINITY;
|
||||
val[3] = m4 & 8 ? val[3] : -INFINITY;
|
||||
lmax4 = fmax(lmax4, val);
|
||||
pdst4[i00] = val;
|
||||
}
|
||||
|
||||
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
||||
|
||||
float max_val = simd_max(lmax);
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = -INFINITY;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = max_val;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
max_val = buf[tiisg];
|
||||
max_val = simd_max(max_val);
|
||||
}
|
||||
|
||||
// parallel sum
|
||||
float4 lsum4 = 0.0f;
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
const float4 exp_psrc4 = exp(pdst4[i00] - max_val);
|
||||
lsum4 += exp_psrc4;
|
||||
pdst4[i00] = exp_psrc4;
|
||||
}
|
||||
|
||||
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
||||
|
||||
// This barrier fixes a failing test
|
||||
// ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
float sum = simd_sum(lsum);
|
||||
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = sum;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sum = buf[tiisg];
|
||||
sum = simd_sum(sum);
|
||||
}
|
||||
|
||||
const float inv_sum = 1.0f/sum;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
pdst4[i00] *= inv_sum;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max(
|
||||
device const char * src0,
|
||||
|
||||
Reference in New Issue
Block a user