From 7efe16f71522b1934055302cd7ca898c157576b9 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 18 Jan 2025 09:44:32 +0200 Subject: [PATCH] FA: don't store sum scaling factor in SIMD registers --- ggml/src/iqk/iqk_mul_mat.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index cf572868..cb24f721 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -12861,7 +12861,7 @@ struct FlashMS { if (smax > M[j]) { if (M[j] > -INFINITY) { float m = expf(M[j] - smax); - vms[j] = F16::set1(m); + vms[j] = m; need_scaling[j] = 1; S[j] *= m; } else { @@ -13043,7 +13043,7 @@ struct FlashMS { cache_t cache[q_step*k_step]; float S[q_step], M[q_step]; int need_scaling[q_step]; - F16::Data vms[q_step]; + float vms[q_step]; const F16::Data vscale; const float softcap; const ggml_half h_inf; @@ -13070,8 +13070,9 @@ struct FlashQKV { std::memset(R, 0, D*sizeof(qkv_cache_t)); } else if (fms.need_scaling[j] == 1) { + auto vms = F16::set1(fms.vms[j]); for (int i = 0; i < D/F16::block_size; ++i) { - F16::store(R + F16::block_size*i, F16::mul(fms.vms[j], F16::load(R + F16::block_size*i))); + F16::store(R + F16::block_size*i, F16::mul(vms, F16::load(R + F16::block_size*i))); } } } @@ -13110,8 +13111,9 @@ struct FlashQKV { std::memset(R, 0, D*sizeof(qkv_cache_t)); } else if (fms.need_scaling[j] == 1) { + auto vms = F16::set1(fms.vms[j]); for (int i = 0; i < D/F16::block_size; ++i) { - F16::store(R + F16::block_size*i, F16::mul(fms.vms[j], F16::load(R + F16::block_size*i))); + F16::store(R + F16::block_size*i, F16::mul(vms, F16::load(R + F16::block_size*i))); } } }