Avoid computing flash-attention (FA) chunks where the mask is -infinity

This commit is contained in:
Iwan Kawrakow
2025-09-24 16:55:25 +03:00
parent 8e497e704e
commit 5a4dfb5aa1

View File

@@ -1415,6 +1415,8 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
HelperQ80::convert<Dk>(q_step, stride_q, q, q8r);
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
if (Mc[0] != 0) break;
HelperQ80R8<Dk>::repack(k_step, kh.block, kh.stride, q8r8);
KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms);
fqkv.accumulate_qkv(vh, fms);
@@ -1441,7 +1443,11 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
perf.accum_nolock(0, t1);
#endif
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
int ik = nk1 - k_step;
for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
ik += k_step;
for (int k1 = 0; k1 < ik/k_step; ++k1) {
#if FA_TIMING
t1 = Perf::cur_time();
KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms);