From 511c4592320270bd04babab9642e5bc7af96326b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 28 Aug 2024 09:08:49 +0300 Subject: [PATCH] WIP: play with KQ mask - make it binary Here we get a small speedup: Gemma-2-2b and 32k context is ~4% faster on Zen4. But on Zen4 we can use _mm512_mask_mul_ps(-inifnity, mask, s_after, tanh(x*s_before)) to scale and apply mask in a single op that has the same latency and throughput as _mm512_mul_ps. Combined with reducing memory loads for the mask represented as fp32 (or fp16), this gives us some performance improvement for very large masks (contexts). It will be much more tricky on the other platforms that do not have masked instructions. --- ggml/src/ggml.c | 84 +++++++++++++++++++++++++++++++++++++------------ src/llama.cpp | 66 ++++++++++++++++++++------------------ 2 files changed, 99 insertions(+), 51 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index adff7669..755d34ce 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2075,6 +2075,19 @@ static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y return max; } +static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, float * y) { + GGML_ASSERT(n%16 == 0); + __m512 vmax = _mm512_set1_ps(-INFINITY); + __m512 vinf = _mm512_set1_ps(-INFINITY); + const __mmask16 * mm16 = (const __mmask16 *)x; + for (int j = 0; j < n/16; ++j) { + __m512 v = _mm512_mask_blend_ps(mm16[j], _mm512_loadu_ps(y + 16*j), vinf); + _mm512_storeu_ps(y + 16*j, v); + vmax = _mm512_max_ps(vmax, v); + } + return _mm512_reduce_max_ps(vmax); +} + static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -2646,6 +2659,13 @@ inline static __m512 ggml_v_softcap(__m512 x, __m512 s_before, __m512 s_after) { return _mm512_mul_ps(th, s_after); } +inline static __m512 ggml_v_softcap_mask(__m512 x, __m512 s_before, __m512 s_after, __m512 src, __mmask16 mask) { + const __m512 one = _mm512_set1_ps(1.0f); + const __m512 exp_two_x = ggml_v_expf(_mm512_mul_ps(x, s_before)); + const __m512 th = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one)); + return _mm512_mask_mul_ps(src, mask, th, s_after); +} + inline static __m512 ggml_v_gelu(__m512 x, __m512 c1, __m512 c2) { const __m512 one = _mm512_set1_ps(1.0f); __m512 arg = _mm512_fmadd_ps(x, _mm512_mul_ps(c1, x), one); @@ -2883,6 +2903,20 @@ static void ggml_vec_cpy_softcap_f32(const int n, const float * x, float * y, fl } } +static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) { + const __mmask16 * m16 = (const __mmask16 *)mask; + __m512 vinf = _mm512_set1_ps(-INFINITY); + __m512 vmax = vinf; + __m512 vs_before = _mm512_set1_ps(2.f*s_before); + __m512 vs_after = _mm512_set1_ps(s_after); + for (int i = 0; i < n/16; ++i) { + __m512 v = ggml_v_softcap_mask(_mm512_loadu_ps(x + 16*i), vs_before, vs_after, vinf, m16[i]); + _mm512_storeu_ps(y + 16*i, v); + vmax = _mm512_max_ps(vmax, v); + } + return _mm512_reduce_max_ps(vmax); +} + static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s_after) { int i = 0; #if defined(__AVX512F__) && defined(__AVX512DQ__) @@ -6045,10 +6079,10 @@ static struct ggml_tensor * ggml_softcap_max_impl( GGML_ASSERT(ggml_is_padded_1d(a)); if (mask) { - GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32); + GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32 || mask->type == GGML_TYPE_I32); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_matrix(mask)); - GGML_ASSERT(mask->ne[0] == a->ne[0]); + //GGML_ASSERT(mask->ne[0] == a->ne[0]); GGML_ASSERT(mask->ne[1] >= a->ne[1]); } @@ -13799,6 +13833,7 @@ static void ggml_compute_forward_softcap_max_f32( float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + const bool use_i32 = (src1 && src1->type == GGML_TYPE_I32); for (int i1 = ir0; i1 < ir1; i1++) { // ALiBi @@ -13809,27 +13844,36 @@ static void ggml_compute_forward_softcap_max_f32( float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); // broadcast the mask across rows - ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; - float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; - - ggml_vec_cpy_softcap_f32(nc, sp, wp, values[2], values[0]*values[3]); + const int mask_row = i1%ne01; float max = -INFINITY; - if (mp_f32) { - if (use_f16) { - max = ggml_vec_add_f32_f16(nc, mp_f16, wp, slope); - //for (int i = 0; i < nc; ++i) { - // wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]); - //} - } else { - max = ggml_vec_add_f32_f32(nc, mp_f32, wp, slope); - //for (int i = 0; i < nc; ++i) { - // wp[i] += slope*mp_f32[i]; - //} + if (use_i32) { + int n32 = (ne00 + 31)/32; + const uint32_t * mp_u32 = (const uint32_t *)src1->data + mask_row*n32; + max = ggml_vec_cpy_softcap_mask_f32(nc, sp, wp, mp_u32, values[2], values[0]*values[3]); + } else { + + ggml_vec_cpy_softcap_f32(nc, sp, wp, values[2], values[0]*values[3]); + + if (src1) { + if (use_f16) { + ggml_fp16_t * mp_f16 = (ggml_fp16_t *)((char *) src1->data) + mask_row*ne00; + max = ggml_vec_add_f32_f16(nc, mp_f16, wp, slope); + } else if (use_i32) { + int n32 = (ne00 + 31)/32; + const uint32_t * mp_u32 = (const uint32_t *)src1->data + mask_row*n32; + max = ggml_vec_add_f32_infmask(nc, mp_u32, wp); + } else { + float * mp_f32 = (float *)((char *) src1->data) + mask_row*ne00; + max = ggml_vec_add_f32_f32(nc, mp_f32, wp, slope); + //for (int i = 0; i < nc; ++i) { + // wp[i] += slope*mp_f32[i]; + //} + } + } + else { + ggml_vec_max_f32(nc, &max, wp); } - } - else { - ggml_vec_max_f32(nc, &max, wp); } #ifndef NDEBUG diff --git a/src/llama.cpp b/src/llama.cpp index 76aa3fb8..51b5bbaf 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8687,10 +8687,10 @@ struct llm_build_context { } struct ggml_tensor * build_inp_KQ_mask(bool causal = true) { - auto type = hparams.use_alibi ? GGML_TYPE_F32 : GGML_TYPE_F16; - lctx.inp_KQ_mask = causal - ? ggml_new_tensor_2d(ctx0, type, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) - : ggml_new_tensor_2d(ctx0, type, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + auto type = hparams.use_alibi ? GGML_TYPE_F32 : GGML_TYPE_I32; + auto nx = causal ? n_kv : n_tokens; + if (type == GGML_TYPE_I32) nx = (nx + 31)/32; + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, type, nx, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); cb(lctx.inp_KQ_mask, "KQ_mask", -1); ggml_set_input(lctx.inp_KQ_mask); return flash_attn && type == GGML_TYPE_F32 ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask; @@ -8705,11 +8705,10 @@ struct llm_build_context { struct ggml_tensor * build_inp_KQ_mask_swa(bool causal = true) { GGML_ASSERT(hparams.n_swa > 0); - - auto type = hparams.use_alibi ? GGML_TYPE_F32 : GGML_TYPE_F16; - lctx.inp_KQ_mask_swa = causal - ? ggml_new_tensor_2d(ctx0, type, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) - : ggml_new_tensor_2d(ctx0, type, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + auto type = hparams.use_alibi ? GGML_TYPE_F32 : GGML_TYPE_I32; + auto nx = causal ? n_kv : n_tokens; + if (type == GGML_TYPE_I32) nx = (nx + 31)/32; + lctx.inp_KQ_mask_swa = ggml_new_tensor_2d(ctx0, type, nx, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1); ggml_set_input(lctx.inp_KQ_mask_swa); @@ -14288,40 +14287,45 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer)); } - auto float_type = lctx.inp_KQ_mask ? lctx.inp_KQ_mask->type : lctx.inp_KQ_mask_swa->type; - GGML_ASSERT(float_type == GGML_TYPE_F16 || float_type == GGML_TYPE_F32); + auto mask_type = lctx.inp_KQ_mask ? lctx.inp_KQ_mask->type : lctx.inp_KQ_mask_swa->type; + GGML_ASSERT(mask_type == GGML_TYPE_I32 || mask_type == GGML_TYPE_F32); - if (float_type == GGML_TYPE_F16) { + if (mask_type == GGML_TYPE_I32) { // in order this to be true, we are not using alibi GGML_ASSERT(!hparams.use_alibi); - auto h_zero = ggml_fp32_to_fp16(0.0f); - auto h_inf = ggml_fp32_to_fp16(-INFINITY); - ggml_fp16_t * h_data = lctx.inp_KQ_mask ? (ggml_fp16_t *)lctx.inp_KQ_mask->data : nullptr; - ggml_fp16_t * h_data_swa = lctx.inp_KQ_mask_swa ? (ggml_fp16_t *)lctx.inp_KQ_mask_swa->data : nullptr; + uint32_t * h_data = lctx.inp_KQ_mask ? (uint32_t *)lctx.inp_KQ_mask->data : nullptr; + uint32_t * h_data_swa = lctx.inp_KQ_mask_swa ? (uint32_t *)lctx.inp_KQ_mask_swa->data : nullptr; for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; const llama_seq_id seq_id = batch.seq_id[j][0]; + uint32_t u = 0, u_swa = 0; + uint32_t m = 1; + for (int i = 0; i < n_kv; ++i) { - auto f = lctx.kv_self.cells[i].pos <= pos && lctx.kv_self.cells[i].has_seq_id(seq_id) ? h_zero : h_inf; - if (h_data) h_data[j*n_kv + i] = f; - if (h_data_swa) { - if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) f = h_inf; - h_data_swa[j*n_kv + i] = f; + if (lctx.kv_self.cells[i].pos > pos || !lctx.kv_self.cells[i].has_seq_id(seq_id)) { + u |= m; u_swa |= m; + } + if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) u_swa |= m; + m <<= 1; + if (!m) { + if (h_data) *h_data++ = ~u; + if (h_data_swa) *h_data_swa++ = ~u_swa; + u = u_swa = 0; m = 1; } } + if (m > 1) { + if (h_data) *h_data++ = ~u; + if (h_data_swa) *h_data_swa++ = ~u_swa; + } + } - if (h_data) { - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) h_data[i*n_kv + j] = h_inf; - } - } - - if (h_data_swa) { - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) h_data_swa[i*n_kv + j] = h_inf; - } + auto n_pad = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); + if (n_pad > n_tokens) { + auto n_kv_32 = (n_kv + 31)/32; + if (h_data) std::memset(h_data, 0, (n_pad - n_tokens)*n_kv_32*sizeof(uint32_t)); + if (h_data_swa) std::memset(h_data_swa, 0, (n_pad - n_tokens)*n_kv_32*sizeof(uint32_t)); } }