Better CPU SWA (#757)

Author: Kawrakow (committed via GitHub)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Date: 2025-09-04 11:58:16 +02:00
Commit: 144d456717 (parent: 4a6a6f17ee)
3 changed files with 18 additions and 5 deletions

@@ -18881,7 +18881,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1],
q->data, k->data, v->data, mask->data, sinks ? sinks->data : NULL,
scale, softcap, (float *)dst->data,
- params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth)) return;
+ params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth, dst->op_params[4])) return;
// if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
// //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n",
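
The only ggml-side change: the CPU flash-attention kernel call now forwards `dst->op_params[4]` as the sliding-window size `n_swa` (0 when the layer does not use SWA). As a minimal sketch, assuming the graph builder records the value with ggml's internal `ggml_set_op_params_i32()` helper (the producing call site is not part of this diff, so the names here are illustrative):

```cpp
// Hypothetical sketch: store the SWA window size in op_params slot 4 of the
// flash-attention node (0 = SWA disabled); the CPU kernel above reads it
// back as dst->op_params[4].
ggml_tensor * attn = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, max_bias, softcap);
ggml_set_op_params_i32(attn, 4, n_swa);
```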

@@ -71,10 +71,23 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
float * qkv, // v*softmax(scale*(k*q))
[[maybe_unused]] void * work_buffer_in, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data,
- int ith, int nth) {
+ int ith, int nth, int n_swa) {
if (type_q != 0 || type_mask != 1 || max_bias > 0) return false;
+ if (n_swa > 0) {
+     constexpr int kMinBatch = 256;
+     int ntokens = std::max(kMinBatch, neq1);
+     int nblock = (ntokens + n_swa + kMinBatch - 1)/kMinBatch;
+     int first = nek1 - nblock*kMinBatch;
+     if (first > 0) {
+         k = (const char *)k + int64_t(first)*stride_k;
+         v = (const char *)v + int64_t(first)*stride_v;
+         mask = (const uint16_t *)mask + first;
+         nek1 -= first;
+     }
+ }
int rk2 = neq2/nek2;
int rv2 = neq2/nev2;
int rk3 = neq3/nek3;
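
The added block clamps the K/V view before the kernel runs: with a sliding window, only the last `n_swa` positions plus the current batch (rounded up to `kMinBatch`-sized blocks) can ever be unmasked, so all earlier cache entries are skipped wholesale by advancing the `k`, `v`, and `mask` pointers and shrinking `nek1`. A self-contained sketch of the same arithmetic with made-up sizes, for illustration only:

```cpp
// Illustration of the window arithmetic above with hypothetical sizes:
// a 4096-token window over a 32768-entry cache, single-token decode.
#include <algorithm>
#include <cstdio>

int main() {
    constexpr int kMinBatch = 256;
    const int n_swa = 4096, nek1 = 32768, neq1 = 1;
    int ntokens = std::max(kMinBatch, neq1);                   // 256
    int nblock  = (ntokens + n_swa + kMinBatch - 1)/kMinBatch; // ceil(4352/256) = 17
    int first   = nek1 - nblock*kMinBatch;                     // 32768 - 4352 = 28416
    // The kernel never touches the first 28416 K/V rows or mask columns;
    // only the last 4352 cache entries can fall inside the window.
    printf("skip %d, keep %d of %d\n", first, nek1 - first, nek1);
    return 0;
}
```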
@@ -83,7 +96,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
int first_k = 0, last_k = nek1;
if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256) {
// This is a quick hack for SWA models.
- // Given that the mask is the same for all layers, ideally we should determinbe the
+ // Given that the mask is the same for all layers, ideally we should determine the
// cache bounds once, and reuse for the whole graph. But even with this simple hack
// we get non-negligible performance gains for SWA models and long context.
auto umask = (const uint16_t *)mask;
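
For single-token generation (`neq1 == 1`) the hack above narrows `[first_k, last_k)` further by scanning the f16 mask row, which is shared across layers and heads. A sketch of such a bounds scan, assuming masked slots store -inf (raw fp16 bit pattern 0xFC00) and unmasked slots store 0; the committed loop may differ in detail:

```cpp
// Sketch of a mask-bounds scan, assuming masked positions hold -inf
// (raw fp16 bits 0xFC00) and valid positions hold 0.
#include <cstdint>

static void mask_bounds(const uint16_t * umask, int nek1, int & first_k, int & last_k) {
    first_k = 0; last_k = nek1;
    while (first_k < last_k  && umask[first_k]  == 0xFC00) ++first_k; // skip leading -inf
    while (last_k  > first_k && umask[last_k-1] == 0xFC00) --last_k;  // trim trailing -inf
}
```

The attention loop then visits only K/V rows in `[first_k, last_k)`; with SWA and long context that is a small slice of the cache, which is where the non-negligible gains come from.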
@@ -339,7 +352,7 @@ bool iqk_flash_attn_noalibi([[maybe_unused]] int type_q, [[maybe_unused]] int ty
[[maybe_unused]] float softcap, // if > 0, a "soft-cap" operation is applied before softmax
[[maybe_unused]] float * qkv, // v*softmax(scale*(k*q))
[[maybe_unused]] void * work_buffer, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data,
- [[maybe_unused]] int ith, [[maybe_unused]] int nth) {
+ [[maybe_unused]] int ith, [[maybe_unused]] int nth, [[maybe_unused]] int n_swa) {
return false;
}

@@ -63,7 +63,7 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
float * qkv, // v*softmax(scale*(k*q))
void * work_buffer, barrier_t barrier, void * barrier_data,
- int ith, int nth);
+ int ith, int nth, int n_swa);
#ifdef __cplusplus
}