WIP KQ binary mask: Metal

For now just soft_cap_max. On Gemma2-9b I'm observing a
~2% speedup for context of 16k tokens.
This commit is contained in:
Iwan Kawrakow
2024-08-28 11:40:26 +02:00
parent 62d6ef2892
commit 900a39bec9
3 changed files with 255 additions and 7 deletions

View File

@@ -71,6 +71,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4,
GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32,
GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4,
GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32,
GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32_4,
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
@@ -580,6 +582,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4, soft_cap_max_f16_4, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32, soft_cap_max_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4, soft_cap_max_f32_4, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32, soft_cap_max_u32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32_4, soft_cap_max_u32_4, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
@@ -1694,19 +1698,22 @@ static enum ggml_status ggml_metal_graph_compute(
} break;
case GGML_OP_SOFT_CAP_MAX:
{
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_I32);
int nth = 32; // SIMD width
id<MTLComputePipelineState> pipeline = nil;
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
const bool use_u32 = (src1 && src1->type == GGML_TYPE_I32);
if (ne00%4 == 0) {
while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
nth *= 2;
}
if (use_f16) {
if (use_u32) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32_4].pipeline;
} else if (use_f16) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4].pipeline;
} else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4].pipeline;
@@ -1715,7 +1722,9 @@ static enum ggml_status ggml_metal_graph_compute(
while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
nth *= 2;
}
if (use_f16) {
if (use_u32) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32].pipeline;
} else if (use_f16) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16].pipeline;
} else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32].pipeline;

View File

@@ -661,6 +661,101 @@ kernel void kernel_soft_max_4(
}
}
kernel void kernel_soft_cap_max_u32(
device const char * src0,
device const char * src1,
device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
constant float & s_before,
constant float & s_after,
constant uint32_t & n_head_log2,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
const int64_t i03 = (tgpig) / (ne02*ne01);
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
device const float * psrc0 = (device const float * ) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const uint32_t * pmask = (device const uint32_t *) src1 + i01*ne00/32;
device float * pdst = (device float * ) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
// parallel max
float lmax = -INFINITY;
const float tot_scale = scale * s_after;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
float val = pmask[i00 >> 5] & (1u << (i00 & 31)) ? precise::tanh(s_before*psrc0[i00])*tot_scale : -INFINITY;
lmax = MAX(lmax, val);
pdst[i00] = val;
}
// find the max value in the block
float max_val = simd_max(lmax);
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = -INFINITY;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = max_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_val = buf[tiisg];
max_val = simd_max(max_val);
}
// parallel sum
float lsum = 0.0f;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
const float exp_psrc0 = exp(pdst[i00] - max_val);
lsum += exp_psrc0;
pdst[i00] = exp_psrc0;
}
// This barrier fixes a failing test
// ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
threadgroup_barrier(mem_flags::mem_none);
float sum = simd_sum(lsum);
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sum = buf[tiisg];
sum = simd_sum(sum);
}
const float inv_sum = 1.0f/sum;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
pdst[i00] *= inv_sum;
}
}
template<typename T>
kernel void kernel_soft_cap_max(
device const char * src0,
@@ -767,6 +862,116 @@ kernel void kernel_soft_cap_max(
}
}
kernel void kernel_soft_cap_max_u32_4(
device const char * src0,
device const char * src1,
device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
constant float & s_before,
constant float & s_after,
constant uint32_t & n_head_log2,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
const int64_t i03 = (tgpig) / (ne02*ne01);
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
device const uint32_t * pmask = (device const uint32_t *) src1 + i01*ne00/32;
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
const float tot_scale = scale * s_after;
// parallel max
float4 lmax4 = -INFINITY;
float4 vinf = lmax4;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
float4 val = precise::tanh(s_before*psrc4[i00])*tot_scale;
int idx = 4*i00;
uint8_t m = pmask[idx >> 5] >> (idx & 31);
bool4 m4 = { m & 1 ? true : false, m & 2 ? true : false, m & 4 ? true : false, m & 8 ? true : false };
//bool4 m4 = ((pmask[idx >> 5] >> (idx & 31)) & 0xf) * 0x01010101;
val = select(vinf, val, m4);
//uint32_t m = pmask[idx >> 5] >> (idx & 31);
//val[0] = m & 1 ? val[0] : -INFINITY;
//val[1] = m & 2 ? val[1] : -INFINITY;
//val[2] = m & 4 ? val[2] : -INFINITY;
//val[3] = m & 8 ? val[3] : -INFINITY;
lmax4 = fmax(lmax4, val);
pdst4[i00] = val;
}
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
float max_val = simd_max(lmax);
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = -INFINITY;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = max_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_val = buf[tiisg];
max_val = simd_max(max_val);
}
// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
const float4 exp_psrc4 = exp(pdst4[i00] - max_val);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
// This barrier fixes a failing test
// ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
threadgroup_barrier(mem_flags::mem_none);
float sum = simd_sum(lsum);
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sum = buf[tiisg];
sum = simd_sum(sum);
}
const float inv_sum = 1.0f/sum;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
pdst4[i00] *= inv_sum;
}
}
template<typename T>
kernel void kernel_soft_cap_max_4(
device const char * src0,

View File

@@ -2043,6 +2043,7 @@ inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)
inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
#ifdef __AVX512F__
static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float * y, float slope) {
__m512 vslope = _mm512_set1_ps(slope);
__m512 vmax = _mm512_set1_ps(-INFINITY);
@@ -2058,7 +2059,6 @@ static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float
}
return max;
}
static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y, float slope) {
__m512 vslope = _mm512_set1_ps(slope);
__m512 vmax = _mm512_set1_ps(-INFINITY);
@@ -2074,7 +2074,6 @@ static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y
}
return max;
}
static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, float * y) {
GGML_ASSERT(n%16 == 0);
__m512 vmax = _mm512_set1_ps(-INFINITY);
@@ -2087,6 +2086,29 @@ static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, fl
}
return _mm512_reduce_max_ps(vmax);
}
#else
// TODO
static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float * y, float slope) {
GGML_UNUSED(n);
GGML_UNUSED(x);
GGML_UNUSED(y);
GGML_UNUSED(slope);
return 0.f;
}
static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y, float slope) {
GGML_UNUSED(n);
GGML_UNUSED(x);
GGML_UNUSED(y);
GGML_UNUSED(slope);
return 0.f;
}
static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, float * y) {
GGML_UNUSED(n);
GGML_UNUSED(x);
GGML_UNUSED(y);
return 0.f;
}
#endif
static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
assert(nrc == 1);
@@ -2903,6 +2925,7 @@ static void ggml_vec_cpy_softcap_f32(const int n, const float * x, float * y, fl
}
}
#ifdef __AVX512__
static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) {
const __mmask16 * m16 = (const __mmask16 *)mask;
__m512 vinf = _mm512_set1_ps(-INFINITY);
@@ -2916,6 +2939,17 @@ static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float *
}
return _mm512_reduce_max_ps(vmax);
}
#else
static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) {
GGML_UNUSED(n);
GGML_UNUSED(x);
GGML_UNUSED(y);
GGML_UNUSED(mask);
GGML_UNUSED(s_before);
GGML_UNUSED(s_after);
return 0.f;
}
#endif
static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s_after) {
int i = 0;
@@ -13788,7 +13822,7 @@ static void ggml_compute_forward_softcap(
default:
{
GGML_ASSERT(false);
} break;
}
}
}
@@ -13920,7 +13954,7 @@ static void ggml_compute_forward_softcap_max(
default:
{
GGML_ASSERT(false);
} break;
}
}
}