FA: don't use large Q steps on AVX2 for fp16 K-cache

This commit is contained in:
Iwan Kawrakow
2025-01-15 17:48:14 +02:00
parent 37162d2695
commit 3f4425205a

View File

@@ -13793,16 +13793,23 @@ template <int D, int k_step, typename KHelper, typename VHelper>
inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float scale, float softcap, float * qkv) {
if (nk1 >= 4096) {
if (nq1 >= 32) {
FlashAttn<D, 32, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
return;
}
else if (nq1 >= 8) {
FlashAttn<D, 8, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
return;
#if defined __AVX2__ && !HAVE_FANCY_SIMD
constexpr bool kUseLargeStepsQ = !std::is_same_v<KHelper, HelperF16<D, k_step>>;
#else
constexpr bool kUseLargeStepsQ = true;
#endif
if constexpr (kUseLargeStepsQ) {
if (nk1 >= 4096) {
if (nq1 >= 32) {
FlashAttn<D, 32, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
return;
}
else if (nq1 >= 8) {
FlashAttn<D, 8, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
return;
}
}
}
if (nq1 >= 8) {