diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 832aed05..fae77921 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -13793,16 +13793,23 @@ template inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float scale, float softcap, float * qkv) { - if (nk1 >= 4096) { - if (nq1 >= 32) { - FlashAttn fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); - return; - } - else if (nq1 >= 8) { - FlashAttn fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); - return; +#if defined __AVX2__ && !HAVE_FANCY_SIMD + constexpr bool kUseLargeStepsQ = !std::is_same_v>; +#else + constexpr bool kUseLargeStepsQ = true; +#endif + if constexpr (kUseLargeStepsQ) { + if (nk1 >= 4096) { + if (nq1 >= 32) { + FlashAttn fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + return; + } + else if (nq1 >= 8) { + FlashAttn fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + return; + } } } if (nq1 >= 8) {