From 589d80f6778c50decb715433877d7eec0b3d2b43 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Mon, 2 Feb 2026 12:39:41 +0200 Subject: [PATCH] Fix CPU FA work buffer size (#1216) --- ggml/src/iqk/iqk_flash_attn.cpp | 63 +++++++++++++++++++++++---------- 1 file changed, 45 insertions(+), 18 deletions(-) diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp index d42cbe78..0c9677b1 100644 --- a/ggml/src/iqk/iqk_flash_attn.cpp +++ b/ggml/src/iqk/iqk_flash_attn.cpp @@ -52,29 +52,56 @@ size_t iqk_fa_work_buffer_size(const struct ggml_tensor * dst, int nth) { auto V = dst->src[2]; int rk2 = Q->ne[2]/K->ne[2]; size_t size = 0; - if (K->type == GGML_TYPE_Q8_0 && (Q->ne[1] >= 8 || (rk2 >= 8 && K->ne[2] > 1))) { + if (Q->ne[1] >= 8 && K->type == GGML_TYPE_Q8_0) { size = ggml_row_size(GGML_TYPE_Q8_0, K->ne[0]) * K->ne[1]*K->ne[2]*K->ne[3]; } - int nstep_k = K->ne[1]/32; - if (nstep_k >= 4*nth) { - auto size_thread = (V->ne[0] + 16)*rk2*sizeof(float); - size += size_thread*nth; - return size; - } - int gcd_k = simple_gcd(nstep_k, nth); - if (gcd_k >= 1) { - int nth_k = nth/gcd_k; - int nq_per_thread = (rk2 + nth_k - 1)/nth_k; - if (nq_per_thread > 1) { - auto size_thread = (V->ne[0] + 16)*nq_per_thread*sizeof(float); + if (Q->ne[1] == 1 && Q->ne[3] == 1 && Q->ne[2]/K->ne[2] > 1 && nth >= 1 && K->ne[1]/32 > 1) { + if (K->ne[2] > 1) { + int gcd = simple_gcd(K->ne[2], nth); + int nth_k = nth/gcd; + int nek2_k = K->ne[2]/gcd; + int nchunk = nek2_k*K->ne[1]/32; + int npt = (nchunk + nth_k - 1)/nth_k; + int nk; + if (npt*nth_k == nchunk) { + nk = 32 * (K->ne[1]*K->ne[2]/(32*nth)); + } else { + //int nm = std::max(1, npt/8); + int nm = 1; + while (true) { + if (nm*4 >= npt) break; + nm *= 2; + } + nk = 32*nm; + } + int nkk = (K->ne[1] + nk - 1)/nk; + int nstep_k = K->ne[2]*nkk; + size_t result_size = (V->ne[0] + 16)*Q->ne[2]/K->ne[2]*sizeof(float); + size += nstep_k*result_size; + return size; + } + int nstep_k = K->ne[1]/32; + if (nstep_k >= 4*nth) { + auto size_thread = (V->ne[0] + 16)*rk2*sizeof(float); size += size_thread*nth; return size; } - } - int rv2 = Q->ne[2] / V->ne[2]; - if (Q->ne[1] == 1 && Q->ne[3] == 1 && rk2 > 1 && rk2 == rv2 && K->ne[1]*K->ne[2] >= 32*nth) { - auto result_size = (V->ne[0] + 16)*rk2*sizeof(float); - size += result_size*nth; + int gcd_k = simple_gcd(nstep_k, nth); + if (gcd_k >= 1) { + int nth_k = nth/gcd_k; + int nq_per_thread = (rk2 + nth_k - 1)/nth_k; + if (nq_per_thread > 1) { + auto size_thread = (V->ne[0] + 16)*nq_per_thread*sizeof(float); + size += size_thread*nth; + return size; + } + } + int rv2 = Q->ne[2] / V->ne[2]; + if (Q->ne[1] == 1 && Q->ne[3] == 1 && rk2 > 1 && rk2 == rv2 && K->ne[1]*K->ne[2] >= 32*nth) { + auto result_size = (V->ne[0] + 16)*rk2*sizeof(float); + size += result_size*nth; + } + return size; } return size; }