From ec4bc75f90b357328609278b4a1e66fda27934df Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 23 Mar 2025 13:29:14 +0200 Subject: [PATCH] Revert the commented out section in iqk_mul_mat.cpp It does have some benefit at long contexts. --- ggml/src/iqk/iqk_mul_mat.cpp | 36 +++++++++++++++++------------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 1b9f8aa2..4d29e2f0 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -17103,25 +17103,23 @@ template inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) { - // Not sure if this actually helps. - // So, let's reduce compilation time by commenting it out for now. - //if (nk1 >= 256) { //4096) { - // if (nq1 >= 64) { - // FlashAttn fa(scale, softcap); - // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); - // return; - // } - // if (nq1 >= 32) { - // FlashAttn fa(scale, softcap); - // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); - // return; - // } - // if (nq1 >= 16) { - // FlashAttn fa(scale, softcap); - // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); - // return; - // } - //} + if (nk1 >= 256) { //4096) { + if (nq1 >= 64) { + FlashAttn fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + return; + } + if (nq1 >= 32) { + FlashAttn fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + return; + } + if (nq1 >= 16) { + FlashAttn fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + return; + } + } if (nq1 >= 8) { FlashAttn fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);