From 2bf2fa8ba437e6559f204d91a51c77f79061d3c5 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 31 Jan 2026 15:46:16 +0000 Subject: [PATCH] Better CPU FA thread strategy --- ggml/src/ggml.c | 17 ++++++++-- ggml/src/iqk/fa/iqk_fa_576_512.cpp | 17 ++++++---- ggml/src/iqk/fa/iqk_fa_templates.h | 10 +++--- ggml/src/iqk/iqk_flash_attn.cpp | 53 ++++++++++++++++++++++++++++++ 4 files changed, 84 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 5633ac8a..20f9cd38 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -25259,6 +25259,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa if (k->type == GGML_TYPE_Q8_0) { qsize = ggml_nrows(k)*ggml_row_size(k->type, k->ne[0]); } + int nstep_k = k->ne[1]/32; + if (nstep_k >= 4*n_tasks && q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1) { + size_t size_thread = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float); + size_t size = size_thread*n_tasks; + cur = MAX(cur, size+qsize); + } else { if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) { if (k->ne[2] > 1) { int gcd = simple_gcd(k->ne[2], n_tasks); @@ -25279,12 +25285,17 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa nk = 32*nm; } //int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks)); - int nstep_k = k->ne[2]*k->ne[1]/nk; + nstep_k = k->ne[2]*k->ne[1]/nk; size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float); size_t size = nstep_k*result_size; cur = MAX(cur, size+qsize); } else { - int nstep_k = k->ne[1]/32; + nstep_k = k->ne[1]/32; + if (nstep_k >= n_tasks) { + size_t size_thread = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float); + size_t size = size_thread*n_tasks; + cur = MAX(cur, size+qsize); + } else { int gcd_k = simple_gcd(nstep_k, n_tasks); if (gcd_k > 1) { int nth_k = n_tasks/gcd_k; @@ -25298,10 +25309,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa } cur = MAX(cur, size+qsize); } + } } } else { cur = MAX(cur, qsize); } + } #endif } break; case GGML_OP_FLASH_ATTN_BACK: diff --git a/ggml/src/iqk/fa/iqk_fa_576_512.cpp b/ggml/src/iqk/fa/iqk_fa_576_512.cpp index 9517eaa1..eec31cd3 100644 --- a/ggml/src/iqk/fa/iqk_fa_576_512.cpp +++ b/ggml/src/iqk/fa/iqk_fa_576_512.cpp @@ -38,14 +38,17 @@ inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh, fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); if (update(4*n_step)) return; } - if (nq1 >= 2) { - int n_step = nq1/2; - FlashAttn<576, 512, 2, step_k> fa(scale, softcap, sinkf); - fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); - if (update(2*n_step)) return; + if (nq1 == 3) { + FlashAttn<576, 512, 3, step_k> fa(scale, softcap, sinkf); + fa.compute(kh, vh, 3, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + } + else if (nq1 == 2) { + FlashAttn<576, 512, 2, step_k> fa(scale, softcap, sinkf); + fa.compute(kh, vh, 2, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + } else { + FlashAttn<576, 512, 1, step_k> fa(scale, softcap, sinkf); + fa.compute(kh, vh, 1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); } - FlashAttn<576, 512, 1, step_k> fa(scale, softcap, sinkf); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); } template diff --git a/ggml/src/iqk/fa/iqk_fa_templates.h b/ggml/src/iqk/fa/iqk_fa_templates.h index f9534ffa..0bf7557c 100644 --- a/ggml/src/iqk/fa/iqk_fa_templates.h +++ b/ggml/src/iqk/fa/iqk_fa_templates.h @@ -1414,7 +1414,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, auto q8 = (typename KHelper::block_q8 *)qptr; // This optimization fails under certain conditions (see https://github.com/ikawrakow/ik_llama.cpp/issues/1205) // => disabling until I figure out what goes wrong - if constexpr (false && q_step > 1 && std::is_same_v) { + if constexpr (q_step >= 4 && std::is_same_v) { if (nq1 == q_step) { fms.init_qstep(); kh.reset_block(); @@ -1424,9 +1424,11 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, auto q8r = (typename HelperQ80R8::block_q8 *)qptr; HelperQ80::convert(q_step, stride_q, q, q8r); auto mr = mask; - for (int k1 = 0; k1 < nk1/k_step; ++k1) { - auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m); - if (k1 > 0 && Mc[0] != 0) break; + auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m); + int ik = nk1 - k_step; + for (; ik >=0 && Mc[ik] != 0; ik -= k_step); + ik += k_step; + for (int k1 = 0; k1 < ik/k_step; ++k1) { HelperQ80R8::repack(k_step, kh.block, kh.stride, q8r8); KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms); fqkv.accumulate_qkv(vh, fms); diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp index 011981fd..a7cb30a0 100644 --- a/ggml/src/iqk/iqk_flash_attn.cpp +++ b/ggml/src/iqk/iqk_flash_attn.cpp @@ -145,6 +145,57 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float // I think it would also speed up things for GQA, but I'm leaving this for another day. if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nth >= 1 && nek1/32 > 1 && nek2 == 1) { int nstep_k = nek1/32; + //if (ith >= nstep_k && ith >= rk2) return true; + if (nstep_k >= nth) { //4*nth) { + int nstep_k_per_thread = (nstep_k + nth - 1)/nth; + int ith_mid = nth; + int nstep_k_this_thread = nstep_k_per_thread; + if (nstep_k_per_thread*nth > nstep_k) { + ith_mid = nstep_k - nth*(nstep_k_per_thread - 1); + if (ith >= ith_mid) --nstep_k_this_thread; + } + //if (ith == 0) fprintf(stderr, "nstep_k = %d, nstep_k_per_thread = %d, ith_mid = %d\n", nstep_k, nstep_k_per_thread, ith_mid); + nstep_k_per_thread *= 32; + nstep_k_this_thread *= 32; + + auto kv_offset = ith <= ith_mid ? ith*nstep_k_per_thread + : ith_mid*nstep_k_per_thread + (ith - ith_mid)*nstep_k_this_thread; + auto kth = (const char *)k + kv_offset*stride_k; + auto vth = (const char *)v + kv_offset*stride_v; + auto qth = (const char *)q; + auto mth = (const char *)mask + kv_offset*sizeof(uint16_t); // we don't have ggml_half available here + + auto work = (char *)work_buffer; + auto size_thread = (Dv + 16)*rk2*sizeof(float); + auto result_buffer = work; + auto work_this_thread = (float *)(result_buffer + ith*size_thread); + //if (nstep_k_this_thread > 0) { + if (!iqk_flash_attn_impl(int_type_k, int_type_v, + Dk, Dv, rk2, nstep_k_this_thread, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv, + (const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth, nullptr, 0, + scale, softcap, + work_this_thread, work_this_thread + (Dv+0)*rk2, work_this_thread + (Dv+1)*rk2)) return false; + //} + + barrier(barrier_data); + + //int nhave = std::min(nstep_k, nth); + for (int j = ith; j < rk2; j += nth) { + auto Racc = qkv + j*nb1/sizeof(float); + float M = -INFINITY, S = 0; + for (int jth = 0; jth < nth; ++jth) { + //for (int jth = 0; jth < nhave; ++jth) { + auto R = (const float *)(result_buffer + jth*size_thread); + auto Mj = R + Dv*rk2; + auto Sj = Mj + rk2; + R += j*Dv; + accumulate_qkv(Dv, M, S, Mj[j], Sj[j], Racc, R); + } + float norm = S > 0 ? 1/S : 1; + for (int i = 0; i < Dv; ++i) Racc[i] *= norm; + } + return true; + } int gcd_k = simple_gcd(nstep_k, nth); if (gcd_k >= 1) { int nth_k = nth/gcd_k; @@ -312,6 +363,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float if (counter++ % (nth/ntg) == ith/ntg) { int iq1 = (ith%ntg)*neq1g; int this_neq1 = std::min(neq1g, neq1-iq1); + if (this_neq1 > 0) { if (!iqk_flash_attn_impl(int_type_k, int_type_v, Dk, Dv, this_neq1, nek1, stride_q, stride_k, stride_v, stride_m, ne1*nb1/sizeof(float), (const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q), @@ -320,6 +372,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float (const void *)((const char *)mask + iq1*stride_m), sinksf, 1, scale, softcap, (float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false; + } } } }