diff --git a/ggml/src/iqk/fa/iqk_fa_templates.h b/ggml/src/iqk/fa/iqk_fa_templates.h index cd983127..b11c2989 100644 --- a/ggml/src/iqk/fa/iqk_fa_templates.h +++ b/ggml/src/iqk/fa/iqk_fa_templates.h @@ -1355,7 +1355,11 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in KQHelper::convert(q_step, stride_q, q, q_f16); #endif auto mr = mask; - for (int k1 = 0; k1 < nk1/k_step; ++k1) { + auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m); + int ik = nk1 - k_step; + for (; ik >=0 && Mc[ik] != 0; ik -= k_step); + ik += k_step; + for (int k1 = 0; k1 < ik/k_step; ++k1) { #ifdef __aarch64__ KQHelper::multiply_mask_kq(kh, Dk, stride_m, q_f16, mr, fms); #else @@ -1965,7 +1969,11 @@ struct FlashAttnBF16 { perf.accum_nolock(0, t1); #endif auto mr = mask; - for (int k1 = 0; k1 < nk1/k_step; ++k1) { + auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m); + int ik = nk1 - k_step; + for (; ik >=0 && Mc[ik] != 0; ik -= k_step); + ik += k_step; + for (int k1 = 0; k1 < ik/k_step; ++k1) { #if FA_TIMING //t1 = Perf::cur_time(); FlashQKbf16::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf);