AVX2 Flash Attention: quantized K*Q for q4_0, q4_1, q8_0

This commit is contained in:
Iwan Kawrakow
2024-09-12 13:14:37 +03:00
parent 3539e4caa2
commit 7c4bc981dc

View File

@@ -7375,20 +7375,9 @@ struct FlashMS {
S[j] += F16::reduce_add<k_step>(vk);
}
inline void update_M_S(int j, F16::Data * vk, const char * mask) {
#ifdef HAVE_FANCY_SIMD
auto vzero = _mm256_set1_epi16(0);
auto vinf = _mm512_set1_ps(-INFINITY);
//for (int l = 0; l < k_step/F16::block_size; ++l) {
// auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mask + l), vzero);
// vk[l] = _mm512_mask_blend_ps(m16, vinf, F16::load(cache + k_step*j + F16::block_size*l));
//}
//if (softcap <= 0.0f) {
// for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, vk[l]);
//} else {
// auto v_softcap = F16::set1(softcap);
// for (int l = 0; l < k_step/F16::block_size; ++l) {
// vk[l] = F16::mul(v_softcap, v_tanh(F16::mul(vscale, vk[l])));
// }
//}
if (softcap <= 0) {
for (int l = 0; l < k_step/F16::block_size; ++l) {
auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mask + l), vzero);
@@ -7401,6 +7390,24 @@ struct FlashMS {
vk[l] = _mm512_mask_mul_ps(vinf, m16, v_softcap, v_tanh(F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l))));
}
}
#else
auto vzero = _mm_set1_epi16(0);
auto vinf = F16::set1(-INFINITY);
for (int l = 0; l < k_step/F16::block_size; ++l) {
auto m128 = _mm_loadu_si128((const __m128i *)mask + l);
m128 = _mm_cmpeq_epi16(m128, vzero);
auto m256 = _mm256_cvtepi16_epi32(m128);
auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16)));
auto val = _mm256_loadu_ps(cache + k_step*j + F16::block_size*l);
vk[l] = _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf));
}
if (softcap <= 0) {
for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, vk[l]);
} else {
auto v_softcap = F16::set1(softcap);
for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(v_softcap, v_tanh(F16::mul(vscale, vk[l])));
}
#endif
float smax = F16::reduce_max<k_step>(vk);
if (smax == -INFINITY) {