From 144d45671785e34d94785bdfea73ee43cf58fb90 Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Thu, 4 Sep 2025 11:58:16 +0200
Subject: [PATCH] Better CPU SWA (#757)

Co-authored-by: Iwan Kawrakow
---
 ggml/src/ggml.c                 |  2 +-
 ggml/src/iqk/iqk_flash_attn.cpp | 19 ++++++++++++++++---
 ggml/src/iqk/iqk_mul_mat.h      |  2 +-
 3 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index c4227b35..94c1cc78 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -18881,7 +18881,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1],
                 q->data, k->data, v->data, mask->data, sinks ? sinks->data : NULL, scale, softcap,
                 (float *)dst->data,
-                params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth)) return;
+                params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth, dst->op_params[4])) return;
 
 //    if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
 //        //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n",
diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp
index ccd81079..011981fd 100644
--- a/ggml/src/iqk/iqk_flash_attn.cpp
+++ b/ggml/src/iqk/iqk_flash_attn.cpp
@@ -71,10 +71,23 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
                                    float           softcap,   // if > 0, a "soft-cap" operation is applied before softmax
                                    float         * qkv,       // v*softmax(scale*(k*q))
                                    [[maybe_unused]] void * work_buffer_in, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data,
-                                   int ith, int nth) {
+                                   int ith, int nth, int n_swa) {
 
     if (type_q != 0 || type_mask != 1 || max_bias > 0) return false;
 
+    if (n_swa > 0) {
+        constexpr int kMinBatch = 256;
+        int ntokens = std::max(kMinBatch, neq1);
+        int nblock = (ntokens + n_swa + kMinBatch - 1)/kMinBatch;
+        int first = nek1 - nblock*kMinBatch;
+        if (first > 0) {
+            k = (const char *)k + int64_t(first)*stride_k;
+            v = (const char *)v + int64_t(first)*stride_v;
+            mask = (const uint16_t *)mask + first;
+            nek1 -= first;
+        }
+    }
+
     int rk2 = neq2/nek2;
     int rv2 = neq2/nev2;
     int rk3 = neq3/nek3;
@@ -83,7 +96,7 @@
     int first_k = 0, last_k = nek1;
     if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256) {
         // This is a quick hack for SWA models.
-        // Given that the mask is the same for all layers, ideally we should determinbe the
+        // Given that the mask is the same for all layers, ideally we should determine the
         // cache bounds once, and reuse for the whole graph. But even with this simple hack
         // we get non-negligible performance gains for SWA models and long context.
         auto umask = (const uint16_t *)mask;
@@ -339,7 +352,7 @@ bool iqk_flash_attn_noalibi([[maybe_unused]] int type_q, [[maybe_unused]] int ty
                             [[maybe_unused]] float          softcap,   // if > 0, a "soft-cap" operation is applied before softmax
                             [[maybe_unused]] float        * qkv,       // v*softmax(scale*(k*q))
                             [[maybe_unused]] void * work_buffer, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data,
-                            [[maybe_unused]] int ith, [[maybe_unused]] int nth) {
+                            [[maybe_unused]] int ith, [[maybe_unused]] int nth, [[maybe_unused]] int n_swa) {
     return false;
 }
 
diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h
index b131095b..c599281b 100644
--- a/ggml/src/iqk/iqk_mul_mat.h
+++ b/ggml/src/iqk/iqk_mul_mat.h
@@ -63,7 +63,7 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
                                     float           softcap,   // if > 0, a "soft-cap" operation is applied before softmax
                                     float         * qkv,       // v*softmax(scale*(k*q))
                                     void * work_buffer, barrier_t barrier, void * barrier_data,
-                                    int ith, int nth);
+                                    int ith, int nth, int n_swa);
 
 #ifdef __cplusplus
 }
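
Note (not part of the patch): the new n_swa path trims the K/V cache before
the flash attention kernel runs. With a sliding window of n_swa tokens, only
the most recent n_swa cache entries (plus the current batch) can receive a
non-masked attention score, so older entries are skipped outright by
advancing the k/v/mask pointers and shrinking nek1. The skip amount is
rounded so the remaining cache is a whole number of 256-entry blocks
(kMinBatch), presumably to keep the trimmed cache aligned with the kernel's
processing granularity. The standalone sketch below replays that arithmetic;
the formulas and variable names mirror the patch, while the shapes (a
32768-entry cache, a 4096-token window, single-token decode) are made up for
illustration.

// sketch.cpp -- illustration of the SWA trimming arithmetic above;
// the shapes are hypothetical, only the formulas come from the patch.
#include <algorithm>
#include <cstdio>

int main() {
    int nek1  = 32768; // K/V entries currently in the cache (hypothetical)
    int n_swa = 4096;  // sliding-window size of the model (hypothetical)
    int neq1  = 1;     // query tokens in this batch: single-token decode

    // Same computation as the patch: round window + batch up to whole
    // 256-entry blocks, then skip every cache entry older than that.
    constexpr int kMinBatch = 256;
    int ntokens = std::max(kMinBatch, neq1);
    int nblock  = (ntokens + n_swa + kMinBatch - 1)/kMinBatch;
    int first   = nek1 - nblock*kMinBatch;   // entries safe to skip
    if (first > 0) {
        // The patch advances k, v and mask by `first` entries here;
        // for the illustration we only shrink the effective cache size.
        nek1 -= first;
    }
    printf("skipped %d entries, attending to the last %d\n", first, nek1);
    // Prints: skipped 28416 entries, attending to the last 4352
    return 0;
}

With these numbers the kernel touches 4352 instead of 32768 cache entries,
which is where the gains for SWA models at long context come from. The
max(kMinBatch, neq1) term keeps the retained window large enough to cover
every query position when a prompt-processing batch spans more than one
block.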