From bc34573356bcff0acb4bc9329e1f7c4b95cbc9fd Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 26 Sep 2025 09:00:25 +0200 Subject: [PATCH] CPU: faster FA (#797) * Avoid computing FA chunks where the mask is -infinity * Avoid computing FA chunks where the mask is -infinity also for f16/bf16 --------- Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/fa/iqk_fa_templates.h | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/ggml/src/iqk/fa/iqk_fa_templates.h b/ggml/src/iqk/fa/iqk_fa_templates.h index 614a2936..b11c2989 100644 --- a/ggml/src/iqk/fa/iqk_fa_templates.h +++ b/ggml/src/iqk/fa/iqk_fa_templates.h @@ -1355,7 +1355,11 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in KQHelper::convert(q_step, stride_q, q, q_f16); #endif auto mr = mask; - for (int k1 = 0; k1 < nk1/k_step; ++k1) { + auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m); + int ik = nk1 - k_step; + for (; ik >=0 && Mc[ik] != 0; ik -= k_step); + ik += k_step; + for (int k1 = 0; k1 < ik/k_step; ++k1) { #ifdef __aarch64__ KQHelper::multiply_mask_kq(kh, Dk, stride_m, q_f16, mr, fms); #else @@ -1415,6 +1419,8 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, HelperQ80::convert(q_step, stride_q, q, q8r); auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { + auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m); + if (Mc[0] != 0) break; HelperQ80R8::repack(k_step, kh.block, kh.stride, q8r8); KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms); fqkv.accumulate_qkv(vh, fms); @@ -1441,7 +1447,11 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, perf.accum_nolock(0, t1); #endif auto mr = mask; - for (int k1 = 0; k1 < nk1/k_step; ++k1) { + auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m); + int ik = nk1 - k_step; + for (; ik >=0 && Mc[ik] != 0; ik -= k_step); + ik += k_step; + for (int k1 = 0; k1 < ik/k_step; ++k1) { #if FA_TIMING t1 = Perf::cur_time(); KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms); @@ -1959,7 +1969,11 @@ struct FlashAttnBF16 { perf.accum_nolock(0, t1); #endif auto mr = mask; - for (int k1 = 0; k1 < nk1/k_step; ++k1) { + auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m); + int ik = nk1 - k_step; + for (; ik >=0 && Mc[ik] != 0; ik -= k_step); + ik += k_step; + for (int k1 = 0; k1 < ik/k_step; ++k1) { #if FA_TIMING //t1 = Perf::cur_time(); FlashQKbf16::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf);