Better CPU SWA

This commit is contained in:
Iwan Kawrakow
2025-09-04 11:08:42 +03:00
parent f5e68bf8b6
commit 910a27ab9b
3 changed files with 18 additions and 5 deletions

View File

@@ -18881,7 +18881,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1],
q->data, k->data, v->data, mask->data, sinks ? sinks->data : NULL, q->data, k->data, v->data, mask->data, sinks ? sinks->data : NULL,
scale, softcap, (float *)dst->data, scale, softcap, (float *)dst->data,
params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth)) return; params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth, dst->op_params[4])) return;
// if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) { // if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
// //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n", // //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n",

View File

@@ -71,10 +71,23 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
float softcap, // if > 0, a "soft-cap" operation is applied before softmax float softcap, // if > 0, a "soft-cap" operation is applied before softmax
float * qkv, // v*softmax(scale*(k*q)) float * qkv, // v*softmax(scale*(k*q))
[[maybe_unused]] void * work_buffer_in, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data, [[maybe_unused]] void * work_buffer_in, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data,
int ith, int nth) { int ith, int nth, int n_swa) {
if (type_q != 0 || type_mask != 1 || max_bias > 0) return false; if (type_q != 0 || type_mask != 1 || max_bias > 0) return false;
if (n_swa > 0) {
constexpr int kMinBatch = 256;
int ntokens = std::max(kMinBatch, neq1);
int nblock = (ntokens + n_swa + kMinBatch - 1)/kMinBatch;
int first = nek1 - nblock*kMinBatch;
if (first > 0) {
k = (const char *)k + int64_t(first)*stride_k;
v = (const char *)v + int64_t(first)*stride_v;
mask = (const uint16_t *)mask + first;
nek1 -= first;
}
}
int rk2 = neq2/nek2; int rk2 = neq2/nek2;
int rv2 = neq2/nev2; int rv2 = neq2/nev2;
int rk3 = neq3/nek3; int rk3 = neq3/nek3;
@@ -83,7 +96,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
int first_k = 0, last_k = nek1; int first_k = 0, last_k = nek1;
if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256) { if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256) {
// This is a quick hack for SWA models. // This is a quick hack for SWA models.
// Given that the mask is the same for all layers, ideally we should determinbe the // Given that the mask is the same for all layers, ideally we should determine the
// cache bounds once, and reuse for the whole graph. But even with this simple hack // cache bounds once, and reuse for the whole graph. But even with this simple hack
// we get non-negligible performance gains for SWA models and long context. // we get non-negligible performance gains for SWA models and long context.
auto umask = (const uint16_t *)mask; auto umask = (const uint16_t *)mask;
@@ -339,7 +352,7 @@ bool iqk_flash_attn_noalibi([[maybe_unused]] int type_q, [[maybe_unused]] int ty
[[maybe_unused]] float softcap, // if > 0, a "soft-cap" operation is applied before softmax [[maybe_unused]] float softcap, // if > 0, a "soft-cap" operation is applied before softmax
[[maybe_unused]] float * qkv, // v*softmax(scale*(k*q)) [[maybe_unused]] float * qkv, // v*softmax(scale*(k*q))
[[maybe_unused]] void * work_buffer, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data, [[maybe_unused]] void * work_buffer, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data,
[[maybe_unused]] int ith, [[maybe_unused]] int nth) { [[maybe_unused]] int ith, [[maybe_unused]] int nth, [[maybe_unused]] int n_swa) {
return false; return false;
} }

View File

@@ -63,7 +63,7 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
float softcap, // if > 0, a "soft-cap" operation is applied before softmax float softcap, // if > 0, a "soft-cap" operation is applied before softmax
float * qkv, // v*softmax(scale*(k*q)) float * qkv, // v*softmax(scale*(k*q))
void * work_buffer, barrier_t barrier, void * barrier_data, void * work_buffer, barrier_t barrier, void * barrier_data,
int ith, int nth); int ith, int nth, int n_swa);
#ifdef __cplusplus #ifdef __cplusplus
} }