From 11621fe433254d9bd52f55ecbb89c61c85e705d8 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 24 Sep 2025 17:23:40 +0300 Subject: [PATCH] Avoid computing FA chunks where the mask is -infinity also for f16/bf16 --- ggml/src/iqk/fa/iqk_fa_templates.h | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/ggml/src/iqk/fa/iqk_fa_templates.h b/ggml/src/iqk/fa/iqk_fa_templates.h index cd983127..b11c2989 100644 --- a/ggml/src/iqk/fa/iqk_fa_templates.h +++ b/ggml/src/iqk/fa/iqk_fa_templates.h @@ -1355,7 +1355,11 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in KQHelper::convert(q_step, stride_q, q, q_f16); #endif auto mr = mask; - for (int k1 = 0; k1 < nk1/k_step; ++k1) { + auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m); + int ik = nk1 - k_step; + for (; ik >=0 && Mc[ik] != 0; ik -= k_step); + ik += k_step; + for (int k1 = 0; k1 < ik/k_step; ++k1) { #ifdef __aarch64__ KQHelper::multiply_mask_kq(kh, Dk, stride_m, q_f16, mr, fms); #else @@ -1965,7 +1969,11 @@ struct FlashAttnBF16 { perf.accum_nolock(0, t1); #endif auto mr = mask; - for (int k1 = 0; k1 < nk1/k_step; ++k1) { + auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m); + int ik = nk1 - k_step; + for (; ik >=0 && Mc[ik] != 0; ik -= k_step); + ik += k_step; + for (int k1 = 0; k1 < ik/k_step; ++k1) { #if FA_TIMING //t1 = Perf::cur_time(); FlashQKbf16::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf);