WIP KQ binary mask: Metal soft_max

I need to redo this with better templates.
This commit is contained in:
Iwan Kawrakow
2024-08-28 13:04:23 +02:00
parent 900a39bec9
commit fe825ecbe4
2 changed files with 204 additions and 3 deletions

View File

@@ -67,6 +67,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32_4,
GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16,
GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4,
GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32,
@@ -578,6 +580,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32, soft_max_u32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32_4, soft_max_u32_4, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16, soft_cap_max_f16, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4, soft_cap_max_f16_4, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32, soft_cap_max_f32, ctx->support_simdgroup_reduction);
@@ -1633,19 +1637,22 @@ static enum ggml_status ggml_metal_graph_compute(
} break;
case GGML_OP_SOFT_MAX:
{
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_I32);
int nth = 32; // SIMD width
id<MTLComputePipelineState> pipeline = nil;
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
const bool use_u32 = (src1 && src1->type == GGML_TYPE_I32);
if (ne00%4 == 0) {
while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
nth *= 2;
}
if (use_f16) {
if (use_u32) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32_4].pipeline;
} else if (use_f16) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
} else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
@@ -1654,7 +1661,9 @@ static enum ggml_status ggml_metal_graph_compute(
while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
nth *= 2;
}
if (use_f16) {
if (use_u32) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32].pipeline;
} else if (use_f16) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
} else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;

View File

@@ -453,6 +453,198 @@ kernel void kernel_sum_rows(
dst_row[0] = row_sum;
}
kernel void kernel_soft_max_u32(
device const char * src0,
device const char * src1,
device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
constant uint32_t & n_head_log2,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
const int64_t i03 = (tgpig) / (ne02*ne01);
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const uint32_t * pmask = (device const uint32_t *) src1 + i01*ne00/32;
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
// parallel max
float lmax = -INFINITY;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
pdst[i00] = pmask[i00 >> 5] & (1u << (i00 & 31)) ? psrc0[i00]*scale : -INFINITY;
lmax = MAX(lmax, pdst[i00]);
}
// find the max value in the block
float max_val = simd_max(lmax);
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = -INFINITY;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = max_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_val = buf[tiisg];
max_val = simd_max(max_val);
}
// parallel sum
float lsum = 0.0f;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
const float exp_psrc0 = exp(pdst[i00] - max_val);
lsum += exp_psrc0;
pdst[i00] = exp_psrc0;
}
// This barrier fixes a failing test
// ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
threadgroup_barrier(mem_flags::mem_none);
float sum = simd_sum(lsum);
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sum = buf[tiisg];
sum = simd_sum(sum);
}
const float inv_sum = 1.0f/sum;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
pdst[i00] *= inv_sum;
}
}
kernel void kernel_soft_max_u32_4(
device const char * src0,
device const char * src1,
device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
constant uint32_t & n_head_log2,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
const int64_t i03 = (tgpig) / (ne02*ne01);
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
device const uint32_t * pmask = (device const uint32_t *) src1 + i01*ne00/32;
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
// parallel max
float4 lmax4 = -INFINITY;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
int idx = 4*i00;
uint8_t m4 = pmask[idx >> 5] >> (idx & 31);
float4 val = psrc4[i00]*scale;
val[0] = m4 & 1 ? val[0] : -INFINITY;
val[1] = m4 & 2 ? val[1] : -INFINITY;
val[2] = m4 & 4 ? val[2] : -INFINITY;
val[3] = m4 & 8 ? val[3] : -INFINITY;
lmax4 = fmax(lmax4, val);
pdst4[i00] = val;
}
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
float max_val = simd_max(lmax);
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = -INFINITY;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = max_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_val = buf[tiisg];
max_val = simd_max(max_val);
}
// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
const float4 exp_psrc4 = exp(pdst4[i00] - max_val);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
// This barrier fixes a failing test
// ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
threadgroup_barrier(mem_flags::mem_none);
float sum = simd_sum(lsum);
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sum = buf[tiisg];
sum = simd_sum(sum);
}
const float inv_sum = 1.0f/sum;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
pdst4[i00] *= inv_sum;
}
}
template<typename T>
kernel void kernel_soft_max(
device const char * src0,