diff --git a/ggml/src/iqk/fa/iqk_fa_templates.h b/ggml/src/iqk/fa/iqk_fa_templates.h
index 614a2936..b11c2989 100644
--- a/ggml/src/iqk/fa/iqk_fa_templates.h
+++ b/ggml/src/iqk/fa/iqk_fa_templates.h
@@ -1355,7 +1355,11 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
         KQHelper::convert(q_step, stride_q, q, q_f16);
 #endif
         auto mr = mask;
-        for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+        auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
+        int ik = nk1 - k_step;
+        for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
+        ik += k_step;
+        for (int k1 = 0; k1 < ik/k_step; ++k1) {
 #ifdef __aarch64__
             KQHelper::multiply_mask_kq(kh, Dk, stride_m, q_f16, mr, fms);
 #else
@@ -1415,6 +1419,8 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
         HelperQ80::convert(q_step, stride_q, q, q8r);
         auto mr = mask;
         for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+            auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
+            if (Mc[0] != 0) break;
             HelperQ80R8::repack(k_step, kh.block, kh.stride, q8r8);
             KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms);
             fqkv.accumulate_qkv(vh, fms);
@@ -1441,7 +1447,11 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
         perf.accum_nolock(0, t1);
 #endif
         auto mr = mask;
-        for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+        auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
+        int ik = nk1 - k_step;
+        for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
+        ik += k_step;
+        for (int k1 = 0; k1 < ik/k_step; ++k1) {
 #if FA_TIMING
             t1 = Perf::cur_time();
             KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms);
@@ -1959,7 +1969,11 @@ struct FlashAttnBF16 {
         perf.accum_nolock(0, t1);
 #endif
         auto mr = mask;
-        for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+        auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
+        int ik = nk1 - k_step;
+        for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
+        ik += k_step;
+        for (int k1 = 0; k1 < ik/k_step; ++k1) {
 #if FA_TIMING
             //t1 = Perf::cur_time();
             FlashQKbf16::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf);
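
A minimal standalone sketch of the early-exit idea these hunks add, assuming the mask row for the last query of the tile holds fp16 values reinterpreted as uint16_t (0 = visible, any non-zero bit pattern = -inf/masked) and that masked columns form a contiguous tail, as with a causal mask. The helper name trim_k_range() and the main() driver below are hypothetical illustration code, not part of the patch:

// Sketch only: mirrors the backward scan added before the k1 loop above.
#include <cstdint>
#include <cstdio>
#include <vector>

// Returns the number of K/V columns (a multiple of k_step) that still need to
// be processed for this tile. Columns past this point are masked for the last
// query row, and with a causal mask therefore for every row of the tile.
static int trim_k_range(const uint16_t * mask_last_row, int nk1, int k_step) {
    int ik = nk1 - k_step;
    // Walk backwards one k_step at a time while the block start is masked.
    for (; ik >= 0 && mask_last_row[ik] != 0; ik -= k_step) {}
    return ik + k_step;   // first fully-masked block starts here
}

int main() {
    constexpr int nk1 = 32, k_step = 8;
    std::vector<uint16_t> mask(nk1, 0);
    // Pretend the last 12 columns are masked out (0xFC00 is fp16 -inf).
    for (int i = nk1 - 12; i < nk1; ++i) mask[i] = 0xFC00;
    int ik = trim_k_range(mask.data(), nk1, k_step);
    // With this mask the k1 loop runs ik/k_step = 3 times instead of
    // nk1/k_step = 4 times.
    printf("process %d of %d columns (%d of %d blocks)\n",
           ik, nk1, ik/k_step, nk1/k_step);
    return 0;
}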