diff --git a/ggml/src/iqk/fa/iqk_fa_templates.h b/ggml/src/iqk/fa/iqk_fa_templates.h index 614a2936..cd983127 100644 --- a/ggml/src/iqk/fa/iqk_fa_templates.h +++ b/ggml/src/iqk/fa/iqk_fa_templates.h @@ -1415,6 +1415,8 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, HelperQ80::convert(q_step, stride_q, q, q8r); auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { + auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m); + if (Mc[0] != 0) break; HelperQ80R8::repack(k_step, kh.block, kh.stride, q8r8); KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms); fqkv.accumulate_qkv(vh, fms); @@ -1441,7 +1443,11 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, perf.accum_nolock(0, t1); #endif auto mr = mask; - for (int k1 = 0; k1 < nk1/k_step; ++k1) { + auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m); + int ik = nk1 - k_step; + for (; ik >=0 && Mc[ik] != 0; ik -= k_step); + ik += k_step; + for (int k1 = 0; k1 < ik/k_step; ++k1) { #if FA_TIMING t1 = Perf::cur_time(); KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms);