Avoid computing FA chunks where the mask is -infinity also for f16/bf16

This commit is contained in:
Iwan Kawrakow
2025-09-24 17:23:40 +03:00
parent 5a4dfb5aa1
commit 11621fe433

View File

@@ -1355,7 +1355,11 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
KQHelper::convert(q_step, stride_q, q, q_f16);
#endif
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
int ik = nk1 - k_step;
for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
ik += k_step;
for (int k1 = 0; k1 < ik/k_step; ++k1) {
#ifdef __aarch64__
KQHelper::multiply_mask_kq(kh, Dk, stride_m, q_f16, mr, fms);
#else
@@ -1965,7 +1969,11 @@ struct FlashAttnBF16 {
perf.accum_nolock(0, t1);
#endif
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
int ik = nk1 - k_step;
for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
ik += k_step;
for (int k1 = 0; k1 < ik/k_step; ++k1) {
#if FA_TIMING
//t1 = Perf::cur_time();
FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf);