diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index cf572868..cb24f721 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -12861,7 +12861,7 @@ struct FlashMS { if (smax > M[j]) { if (M[j] > -INFINITY) { float m = expf(M[j] - smax); - vms[j] = F16::set1(m); + vms[j] = m; need_scaling[j] = 1; S[j] *= m; } else { @@ -13043,7 +13043,7 @@ struct FlashMS { cache_t cache[q_step*k_step]; float S[q_step], M[q_step]; int need_scaling[q_step]; - F16::Data vms[q_step]; + float vms[q_step]; const F16::Data vscale; const float softcap; const ggml_half h_inf; @@ -13070,8 +13070,9 @@ struct FlashQKV { std::memset(R, 0, D*sizeof(qkv_cache_t)); } else if (fms.need_scaling[j] == 1) { + auto vms = F16::set1(fms.vms[j]); for (int i = 0; i < D/F16::block_size; ++i) { - F16::store(R + F16::block_size*i, F16::mul(fms.vms[j], F16::load(R + F16::block_size*i))); + F16::store(R + F16::block_size*i, F16::mul(vms, F16::load(R + F16::block_size*i))); } } } @@ -13110,8 +13111,9 @@ struct FlashQKV { std::memset(R, 0, D*sizeof(qkv_cache_t)); } else if (fms.need_scaling[j] == 1) { + auto vms = F16::set1(fms.vms[j]); for (int i = 0; i < D/F16::block_size; ++i) { - F16::store(R + F16::block_size*i, F16::mul(fms.vms[j], F16::load(R + F16::block_size*i))); + F16::store(R + F16::block_size*i, F16::mul(vms, F16::load(R + F16::block_size*i))); } } }