Revert the commented out section in iqk_mul_mat.cpp

It does have some benefit at long contexts.
This commit is contained in:
Iwan Kawrakow
2025-03-23 13:29:14 +02:00
parent d12f4a12aa
commit ec4bc75f90

View File

@@ -17103,25 +17103,23 @@ template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper>
inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
// Not sure if this actually helps.
// So, let's reduce compilation time by commenting it out for now.
//if (nk1 >= 256) { //4096) {
// if (nq1 >= 64) {
// FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
// return;
// }
// if (nq1 >= 32) {
// FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
// return;
// }
// if (nq1 >= 16) {
// FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
// return;
// }
//}
if (nk1 >= 256) { //4096) {
if (nq1 >= 64) {
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
return;
}
if (nq1 >= 32) {
FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
return;
}
if (nq1 >= 16) {
FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
return;
}
}
if (nq1 >= 8) {
FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);