mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-04 19:10:03 +00:00
WIP KQ binary mask: Metal
For now this covers just soft_cap_max. On Gemma2-9b I'm observing a ~2% speedup for a context of 16k tokens.
This commit is contained in:
@@ -71,6 +71,8 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32_4,
|
||||
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
|
||||
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
|
||||
@@ -580,6 +582,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4, soft_cap_max_f16_4, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32, soft_cap_max_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4, soft_cap_max_f32_4, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32, soft_cap_max_u32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32_4, soft_cap_max_u32_4, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
|
||||
@@ -1694,19 +1698,22 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
} break;
|
||||
case GGML_OP_SOFT_CAP_MAX:
|
||||
{
|
||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_I32);
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
||||
const bool use_u32 = (src1 && src1->type == GGML_TYPE_I32);
|
||||
|
||||
if (ne00%4 == 0) {
|
||||
while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
|
||||
nth *= 2;
|
||||
}
|
||||
if (use_f16) {
|
||||
if (use_u32) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32_4].pipeline;
|
||||
} else if (use_f16) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4].pipeline;
|
||||
} else {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4].pipeline;
|
||||
@@ -1715,7 +1722,9 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
|
||||
nth *= 2;
|
||||
}
|
||||
if (use_f16) {
|
||||
if (use_u32) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32].pipeline;
|
||||
} else if (use_f16) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16].pipeline;
|
||||
} else {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32].pipeline;
|
||||
|
||||
@@ -661,6 +661,101 @@ kernel void kernel_soft_max_4(
|
||||
}
|
||||
}
|
||||
|
||||
// Fused "soft-cap softmax" with a packed binary KQ mask (scalar variant).
// src0: F32 logits; each threadgroup handles one row of ne00 values (row picked by tgpig).
// src1: packed bitmask, 1 bit per position (32 positions per uint32_t word); a set bit
//       keeps the position, a cleared bit forces it to -INFINITY before the softmax.
// dst:  softmax(tanh(s_before * x) * scale * s_after), with masked positions ending up 0.
// NOTE(review): max_bias/m0/m1/n_head_log2 are not read in this variant (no ALiBi slope
// applied) — presumably the binary-mask-only path; confirm against the dispatch code.
kernel void kernel_soft_cap_max_u32(
        device const char * src0,
        device const char * src1,
        device char * dst,
        constant int64_t & ne00,
        constant int64_t & ne01,
        constant int64_t & ne02,
        constant float & scale,
        constant float & max_bias,
        constant float & m0,
        constant float & m1,
        constant float & s_before,
        constant float & s_after,
        constant uint32_t & n_head_log2,
        threadgroup float * buf [[threadgroup(0)]],
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint ntg[[threads_per_threadgroup]]) {
    // decompose the flat threadgroup index into (i03, i02, i01) row coordinates
    const int64_t i03 = (tgpig) / (ne02*ne01);
    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

    device const float * psrc0 = (device const float * ) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
    // mask rows are ne00 bits (= ne00/32 words) wide and are indexed by i01 only
    device const uint32_t * pmask = (device const uint32_t *) src1 + i01*ne00/32;
    device float * pdst = (device float * ) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);

    // parallel max: each thread scans a strided subset of the row; the capped/masked
    // value is written to dst so the sum pass can reuse it without recomputing tanh
    float lmax = -INFINITY;

    const float tot_scale = scale * s_after;
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        // bit (i00 & 31) of word (i00 >> 5) decides whether this position survives
        float val = pmask[i00 >> 5] & (1u << (i00 & 31)) ? precise::tanh(s_before*psrc0[i00])*tot_scale : -INFINITY;
        lmax = MAX(lmax, val);
        pdst[i00] = val;
    }

    // find the max value in the block
    float max_val = simd_max(lmax);
    if (ntg > N_SIMDWIDTH) {
        // cross-simdgroup reduction through threadgroup memory
        if (sgitg == 0) {
            buf[tiisg] = -INFINITY;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tiisg == 0) {
            buf[sgitg] = max_val;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        max_val = buf[tiisg];
        max_val = simd_max(max_val);
    }

    // parallel sum of exp(x - max); exp(-INFINITY) == 0, so masked entries drop out
    float lsum = 0.0f;
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        const float exp_psrc0 = exp(pdst[i00] - max_val);
        lsum += exp_psrc0;
        pdst[i00] = exp_psrc0;
    }

    // This barrier fixes a failing test
    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
    threadgroup_barrier(mem_flags::mem_none);

    float sum = simd_sum(lsum);

    if (ntg > N_SIMDWIDTH) {
        // cross-simdgroup reduction of the partial sums, same pattern as the max
        if (sgitg == 0) {
            buf[tiisg] = 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tiisg == 0) {
            buf[sgitg] = sum;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        sum = buf[tiisg];
        sum = simd_sum(sum);
    }

    // normalize in place
    const float inv_sum = 1.0f/sum;

    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        pdst[i00] *= inv_sum;
    }
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_cap_max(
|
||||
device const char * src0,
|
||||
@@ -767,6 +862,116 @@ kernel void kernel_soft_cap_max(
|
||||
}
|
||||
}
|
||||
|
||||
// Fused "soft-cap softmax" with a packed binary KQ mask, float4-vectorized variant
// (dispatched when ne00 % 4 == 0; see the pipeline selection in graph compute).
// src0: F32 logits; each threadgroup handles one row of ne00 values.
// src1: packed bitmask, 1 bit per position (32 per uint32_t word); a cleared bit
//       forces the position to -INFINITY before the softmax.
// dst:  softmax(tanh(s_before * x) * scale * s_after), masked positions end up 0.
// NOTE(review): max_bias/m0/m1/n_head_log2 are not read here (no ALiBi slope) —
// confirm this is intentional for the binary-mask path.
kernel void kernel_soft_cap_max_u32_4(
        device const char * src0,
        device const char * src1,
        device char * dst,
        constant int64_t & ne00,
        constant int64_t & ne01,
        constant int64_t & ne02,
        constant float & scale,
        constant float & max_bias,
        constant float & m0,
        constant float & m1,
        constant float & s_before,
        constant float & s_after,
        constant uint32_t & n_head_log2,
        threadgroup float * buf [[threadgroup(0)]],
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint ntg[[threads_per_threadgroup]]) {
    // decompose the flat threadgroup index into (i03, i02, i01) row coordinates
    const int64_t i03 = (tgpig) / (ne02*ne01);
    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

    device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
    // mask rows are ne00 bits (= ne00/32 words) wide and are indexed by i01 only
    device const uint32_t * pmask = (device const uint32_t *) src1 + i01*ne00/32;
    device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;

    const float tot_scale = scale * s_after;

    // parallel max
    float4 lmax4 = -INFINITY;
    float4 vinf = lmax4;

    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
        float4 val = precise::tanh(s_before*psrc4[i00])*tot_scale;
        // idx is a multiple of 4, so (idx & 31) is in {0, 4, ..., 28} and the 4 mask
        // bits for this float4 never straddle a uint32_t word boundary; narrowing the
        // shifted word to uint8_t keeps the low 4 bits that are tested below.
        int idx = 4*i00;
        uint8_t m = pmask[idx >> 5] >> (idx & 31);
        bool4 m4 = { m & 1 ? true : false, m & 2 ? true : false, m & 4 ? true : false, m & 8 ? true : false };
        //bool4 m4 = ((pmask[idx >> 5] >> (idx & 31)) & 0xf) * 0x01010101;
        // masked-out lanes become -INFINITY
        val = select(vinf, val, m4);
        //uint32_t m = pmask[idx >> 5] >> (idx & 31);
        //val[0] = m & 1 ? val[0] : -INFINITY;
        //val[1] = m & 2 ? val[1] : -INFINITY;
        //val[2] = m & 4 ? val[2] : -INFINITY;
        //val[3] = m & 8 ? val[3] : -INFINITY;
        lmax4 = fmax(lmax4, val);
        pdst4[i00] = val;
    }

    // horizontal max over the 4 lanes, then across the simdgroup
    const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));

    float max_val = simd_max(lmax);
    if (ntg > N_SIMDWIDTH) {
        // cross-simdgroup reduction through threadgroup memory
        if (sgitg == 0) {
            buf[tiisg] = -INFINITY;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tiisg == 0) {
            buf[sgitg] = max_val;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        max_val = buf[tiisg];
        max_val = simd_max(max_val);
    }

    // parallel sum of exp(x - max); exp(-INFINITY) == 0, so masked lanes drop out
    float4 lsum4 = 0.0f;
    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
        const float4 exp_psrc4 = exp(pdst4[i00] - max_val);
        lsum4 += exp_psrc4;
        pdst4[i00] = exp_psrc4;
    }

    const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];

    // This barrier fixes a failing test
    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
    threadgroup_barrier(mem_flags::mem_none);

    float sum = simd_sum(lsum);

    if (ntg > N_SIMDWIDTH) {
        // cross-simdgroup reduction of the partial sums, same pattern as the max
        if (sgitg == 0) {
            buf[tiisg] = 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tiisg == 0) {
            buf[sgitg] = sum;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        sum = buf[tiisg];
        sum = simd_sum(sum);
    }

    // normalize in place
    const float inv_sum = 1.0f/sum;

    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
        pdst4[i00] *= inv_sum;
    }
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_cap_max_4(
|
||||
device const char * src0,
|
||||
|
||||
@@ -2043,6 +2043,7 @@ inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)
|
||||
// Element-wise product: z[i] = x[i] * y[i] for i in [0, n).
inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) {
    for (int i = 0; i < n; ++i) {
        z[i] = x[i]*y[i];
    }
}
// Element-wise quotient: z[i] = x[i] / y[i] for i in [0, n); no guard against y[i] == 0.
inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) {
    for (int i = 0; i < n; ++i) {
        z[i] = x[i]/y[i];
    }
}
|
||||
|
||||
#ifdef __AVX512F__
|
||||
static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float * y, float slope) {
|
||||
__m512 vslope = _mm512_set1_ps(slope);
|
||||
__m512 vmax = _mm512_set1_ps(-INFINITY);
|
||||
@@ -2058,7 +2059,6 @@ static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float
|
||||
}
|
||||
return max;
|
||||
}
|
||||
|
||||
static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y, float slope) {
|
||||
__m512 vslope = _mm512_set1_ps(slope);
|
||||
__m512 vmax = _mm512_set1_ps(-INFINITY);
|
||||
@@ -2074,7 +2074,6 @@ static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y
|
||||
}
|
||||
return max;
|
||||
}
|
||||
|
||||
static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, float * y) {
|
||||
GGML_ASSERT(n%16 == 0);
|
||||
__m512 vmax = _mm512_set1_ps(-INFINITY);
|
||||
@@ -2087,6 +2086,29 @@ static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, fl
|
||||
}
|
||||
return _mm512_reduce_max_ps(vmax);
|
||||
}
|
||||
#else
|
||||
// TODO
|
||||
// Fallback when AVX512F is unavailable: not implemented yet (see TODO above).
// All arguments are ignored and 0.f is returned.
// NOTE(review): the AVX512 path returns a running max over the updated values —
// presumably this stub is never reached on non-AVX512 builds; verify the dispatch.
static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float * y, float slope) {
    (void)n;
    (void)x;
    (void)y;
    (void)slope;
    return 0.f;
}
|
||||
// Fallback when AVX512F is unavailable: not implemented yet.
// All arguments are ignored and 0.f is returned.
// NOTE(review): the AVX512 path returns a running max over the updated values —
// presumably this stub is never reached on non-AVX512 builds; verify the dispatch.
static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y, float slope) {
    (void)n;
    (void)x;
    (void)y;
    (void)slope;
    return 0.f;
}
|
||||
// Fallback when AVX512F is unavailable: not implemented yet.
// All arguments are ignored and 0.f is returned.
// NOTE(review): the AVX512 path applies the packed -inf mask to y and returns the
// reduced max; this stub returning 0.f is only safe if the op is never selected on
// non-AVX512 builds — confirm.
static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, float * y) {
    (void)n;
    (void)x;
    (void)y;
    return 0.f;
}
|
||||
#endif
|
||||
|
||||
static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
@@ -2903,6 +2925,7 @@ static void ggml_vec_cpy_softcap_f32(const int n, const float * x, float * y, fl
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __AVX512__
|
||||
static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) {
|
||||
const __mmask16 * m16 = (const __mmask16 *)mask;
|
||||
__m512 vinf = _mm512_set1_ps(-INFINITY);
|
||||
@@ -2916,6 +2939,17 @@ static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float *
|
||||
}
|
||||
return _mm512_reduce_max_ps(vmax);
|
||||
}
|
||||
#else
|
||||
// Fallback ggml_vec_cpy_softcap_mask_f32: not implemented; arguments are ignored
// and 0.f is returned.
// NOTE(review): the vectorized twin above is guarded by `#ifdef __AVX512__`, while
// the other SIMD blocks in this file use `__AVX512F__`. `__AVX512__` is not a macro
// GCC/Clang define, so this stub is likely what always gets compiled — verify the
// guard spelling and that this path is unreachable at runtime.
static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) {
    (void)n;
    (void)x;
    (void)y;
    (void)mask;
    (void)s_before;
    (void)s_after;
    return 0.f;
}
|
||||
#endif
|
||||
|
||||
static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s_after) {
|
||||
int i = 0;
|
||||
@@ -13788,7 +13822,7 @@ static void ggml_compute_forward_softcap(
|
||||
default:
|
||||
{
|
||||
GGML_ASSERT(false);
|
||||
} break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13920,7 +13954,7 @@ static void ggml_compute_forward_softcap_max(
|
||||
default:
|
||||
{
|
||||
GGML_ASSERT(false);
|
||||
} break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user