From 0a17ff156fdf6ca5881b8f1e161568382b2b9ead Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 12 Sep 2024 14:50:50 +0300 Subject: [PATCH] Tidy up FlashMS --- ggml/src/iqk/iqk_mul_mat.cpp | 187 ++++++++++++++--------------------- 1 file changed, 74 insertions(+), 113 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 58f42db9..bfde26e4 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -7209,8 +7209,51 @@ struct FlashMS { } } + inline void update_M(int j, float smax) { + if (smax == -INFINITY) { + std::memset(cache + k_step*j, 0, k_step*sizeof(float)); + need_scaling[j] = M[j] == -INFINITY ? 2 : 0; + return; + } + need_scaling[j] = 0; + if (smax > M[j]) { + if (M[j] > -INFINITY) { + float m = expf(M[j] - smax); + vms[j] = F16::set1(m); + need_scaling[j] = 1; + S[j] *= m; + } else { + need_scaling[j] = 2; + S[j] = 0; + } + M[j] = smax; + } + } + #ifdef __aarch64__ - inline void update_M_S(int j, float32x4_t * vk) { + inline void update_S(int j, float32x4_t * vk) { + auto vm = vdupq_n_f32(M[j]); + auto vsum = vdupq_n_f32(0); + for (int l = 0; l < k_step/4; ++l) { + vk[l] = v_expf(vsubq_f32(vk[l], vm)); + vsum = vaddq_f32(vsum, vk[l]); + F16::store(cache + k_step*j + 4*l, vk[l]); + } + S[j] += vaddvq_f32(vsum); + } +#else + inline void update_S(int j, F16::Data * vk) { + auto vm = F16::set1(M[j]); + for (int l = 0; l < k_step/F16::block_size; ++l) { + vk[l] = v_expf(F16::sub(vk[l], vm)); + F16::store(cache + k_step*j + F16::block_size*l, vk[l]); + } + S[j] += F16::reduce_add(vk); + } +#endif + +#ifdef __aarch64__ + inline float load_and_scale(int j, float32x4_t * vk) { float32x4_t vmax = vdupq_n_f32(-INFINITY); // Something goes wrong when storing and manipulating K*Q as fp16. // It works for some models (e.g., Gemma-2), but not for others (e.g., LLaMA-3.1-8B). @@ -7248,37 +7291,9 @@ struct FlashMS { vmax = vmaxq_f32(vmax, vk[l]); } } - - float smax = vmaxvq_f32(vmax); - if (smax == -INFINITY) { - std::memset(cache + k_step*j, 0, k_step*sizeof(float)); - need_scaling[j] = M[j] == -INFINITY ? 2 : 0; - return; - } - need_scaling[j] = 0; - if (smax > M[j]) { - if (M[j] > -INFINITY) { - float m = expf(M[j] - smax); - vms[j] = F16::set1(m); - need_scaling[j] = 1; - S[j] *= m; - } else { - need_scaling[j] = 2; - S[j] = 0; - } - M[j] = smax; - } - auto vm = vdupq_n_f32(M[j]); - auto vsum = vdupq_n_f32(0); - for (int l = 0; l < k_step/4; ++l) { - vk[l] = v_expf(vsubq_f32(vk[l], vm)); - vsum = vaddq_f32(vsum, vk[l]); - F16::store(cache + k_step*j + 4*l, vk[l]); - } - S[j] += vaddvq_f32(vsum); + return vmaxvq_f32(vmax); } - inline void update_M_S(int j, float32x4_t * vk, const char * mask) { - { + inline float load_apply_mask_and_scale(int j, float32x4_t * vk, const char * mask) { auto vzero = vdupq_n_f32(0); auto vinf = vdupq_n_f32(-INFINITY); for (int l = 0; l < k_step/8; ++l) { @@ -7291,7 +7306,6 @@ struct FlashMS { vk[2*l+1] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[1]), vm2), vbicq_u32(vinf, vm2))); } - } float32x4_t vmax = vdupq_n_f32(-INFINITY); auto vscale32 = vcvt_f32_f16(vget_low_f16(vscale)); if (softcap <= 0.0f) { @@ -7307,37 +7321,10 @@ struct FlashMS { vmax = vmaxq_f32(vmax, vk[l]); } } - - float smax = vmaxvq_f32(vmax); - if (smax == -INFINITY) { - std::memset(cache + k_step*j, 0, k_step*sizeof(float)); - need_scaling[j] = M[j] == -INFINITY ? 2 : 0; - return; - } - need_scaling[j] = 0; - if (smax > M[j]) { - if (M[j] > -INFINITY) { - float m = expf(M[j] - smax); - vms[j] = F16::set1(m); - need_scaling[j] = 1; - S[j] *= m; - } else { - need_scaling[j] = 2; - S[j] = 0; - } - M[j] = smax; - } - auto vm = vdupq_n_f32(M[j]); - auto vsum = vdupq_n_f32(0); - for (int l = 0; l < k_step/4; ++l) { - vk[l] = v_expf(vsubq_f32(vk[l], vm)); - vsum = vaddq_f32(vsum, vk[l]); - F16::store(cache + k_step*j + 4*l, vk[l]); - } - S[j] += vaddvq_f32(vsum); + return vmaxvq_f32(vmax); } #else - inline void update_M_S(int j, F16::Data * vk) { + inline float load_and_scale(int j, F16::Data * vk) { if (softcap <= 0.0f) { for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l)); } else { @@ -7347,34 +7334,9 @@ struct FlashMS { vk[l] = F16::mul(v_softcap, v_tanh(F16::mul(vscale, val))); } } - - float smax = F16::reduce_max(vk); - if (smax == -INFINITY) { - std::memset(cache + k_step*j, 0, k_step*sizeof(float)); - need_scaling[j] = M[j] == -INFINITY ? 2 : 0; - return; - } - need_scaling[j] = 0; - if (smax > M[j]) { - if (M[j] > -INFINITY) { - float m = expf(M[j] - smax); - vms[j] = F16::set1(m); - need_scaling[j] = 1; - S[j] *= m; - } else { - need_scaling[j] = 2; - S[j] = 0; - } - M[j] = smax; - } - auto vm = F16::set1(M[j]); - for (int l = 0; l < k_step/F16::block_size; ++l) { - vk[l] = v_expf(F16::sub(vk[l], vm)); - F16::store(cache + k_step*j + F16::block_size*l, vk[l]); - } - S[j] += F16::reduce_add(vk); + return F16::reduce_max(vk); } - inline void update_M_S(int j, F16::Data * vk, const char * mask) { + inline float load_apply_mask_and_scale(int j, F16::Data * vk, const char * mask) { #ifdef HAVE_FANCY_SIMD auto vzero = _mm256_set1_epi16(0); auto vinf = _mm512_set1_ps(-INFINITY); @@ -7408,32 +7370,31 @@ struct FlashMS { for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(v_softcap, v_tanh(F16::mul(vscale, vk[l]))); } #endif + return F16::reduce_max(vk); + } +#endif - float smax = F16::reduce_max(vk); - if (smax == -INFINITY) { - std::memset(cache + k_step*j, 0, k_step*sizeof(float)); - need_scaling[j] = M[j] == -INFINITY ? 2 : 0; - return; - } - need_scaling[j] = 0; - if (smax > M[j]) { - if (M[j] > -INFINITY) { - float m = expf(M[j] - smax); - vms[j] = F16::set1(m); - need_scaling[j] = 1; - S[j] *= m; - } else { - need_scaling[j] = 2; - S[j] = 0; - } - M[j] = smax; - } - auto vm = F16::set1(M[j]); - for (int l = 0; l < k_step/F16::block_size; ++l) { - vk[l] = v_expf(F16::sub(vk[l], vm)); - F16::store(cache + k_step*j + F16::block_size*l, vk[l]); - } - S[j] += F16::reduce_add(vk); +#ifdef __aarch64__ + inline void update_M_S(int j, float32x4_t * vk) { + float smax = load_and_scale(j, vk); + update_M(j, smax); + update_S(j, vk); + } + inline void update_M_S(int j, float32x4_t * vk, const char * mask) { + float smax = load_apply_mask_and_scale(j, vk, mask); + update_M(j, smax); + update_S(j, vk); + } +#else + inline void update_M_S(int j, F16::Data * vk) { + float smax = load_and_scale(j, vk); + update_M(j, smax); + update_S(j, vk); + } + inline void update_M_S(int j, F16::Data * vk, const char * mask) { + float smax = load_apply_mask_and_scale(j, vk, mask); + update_M(j, smax); + update_S(j, vk); } #endif