diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 8352a3c0..58f42db9 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -7375,20 +7375,9 @@ struct FlashMS { S[j] += F16::reduce_add(vk); } inline void update_M_S(int j, F16::Data * vk, const char * mask) { +#ifdef HAVE_FANCY_SIMD auto vzero = _mm256_set1_epi16(0); auto vinf = _mm512_set1_ps(-INFINITY); - //for (int l = 0; l < k_step/F16::block_size; ++l) { - // auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mask + l), vzero); - // vk[l] = _mm512_mask_blend_ps(m16, vinf, F16::load(cache + k_step*j + F16::block_size*l)); - //} - //if (softcap <= 0.0f) { - // for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, vk[l]); - //} else { - // auto v_softcap = F16::set1(softcap); - // for (int l = 0; l < k_step/F16::block_size; ++l) { - // vk[l] = F16::mul(v_softcap, v_tanh(F16::mul(vscale, vk[l]))); - // } - //} if (softcap <= 0) { for (int l = 0; l < k_step/F16::block_size; ++l) { auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mask + l), vzero); @@ -7401,6 +7390,24 @@ struct FlashMS { vk[l] = _mm512_mask_mul_ps(vinf, m16, v_softcap, v_tanh(F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l)))); } } +#else + auto vzero = _mm_set1_epi16(0); + auto vinf = F16::set1(-INFINITY); + for (int l = 0; l < k_step/F16::block_size; ++l) { + auto m128 = _mm_loadu_si128((const __m128i *)mask + l); + m128 = _mm_cmpeq_epi16(m128, vzero); + auto m256 = _mm256_cvtepi16_epi32(m128); + auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16))); + auto val = _mm256_loadu_ps(cache + k_step*j + F16::block_size*l); + vk[l] = _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf)); + } + if (softcap <= 0) { + for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, vk[l]); + } else { + auto v_softcap = F16::set1(softcap); + for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(v_softcap, v_tanh(F16::mul(vscale, vk[l]))); + } +#endif float smax = F16::reduce_max(vk); if (smax == -INFINITY) {