CPU flash attention

This commit is contained in:
Kawrakow
2026-01-29 16:25:21 +00:00
parent 0f6cdd4aec
commit cb4b0ebb11
4 changed files with 133 additions and 6 deletions

View File

@@ -270,6 +270,7 @@ if (GGML_IQK_MUL_MAT)
iqk/fa/iqk_fa_128_128.cpp
iqk/fa/iqk_fa_96_96.cpp
iqk/fa/iqk_fa_64_64.cpp
iqk/fa/iqk_fa_1088_1024.cpp
iqk/iqk_gemm_floats.cpp
iqk/iqk_gemm_kquants.cpp
iqk/iqk_gemm_ktquants.cpp

View File

@@ -0,0 +1,121 @@
#include "iqk/iqk_config.h"
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
#include "iqk/fa/iqk_fa_templates.h"
namespace {
template <int step_k, typename KHelper, typename VHelper>
inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float scale, float softcap, float * qkv,
const float * sinkf, float * M, float * S) {
auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
nq1 -= n;
if (nq1 == 0) return true;
q += n*stride_q;
mask += n*stride_m;
qkv += n*stride_qkv;
if (M && S) { M += n; S += n; }
return false;
};
if (nq1 >= 16) {
int n_step = nq1/16;
FlashAttn<1088, 1024, 16, step_k> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
if (update(16*n_step)) return;
}
if (nq1 >= 8) {
int n_step = nq1/8;
FlashAttn<1088, 1024, 8, step_k> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
if (update(8*n_step)) return;
}
if (nq1 >= 4) {
int n_step = nq1/4;
FlashAttn<1088, 1024, 4, step_k> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
if (update(4*n_step)) return;
}
if (nq1 >= 2) {
int n_step = nq1/2;
FlashAttn<1088, 1024, 2, step_k> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
if (update(2*n_step)) return;
}
FlashAttn<1088, 1024, 1, step_k> fa(scale, softcap, sinkf);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
}
template <int step_k>
inline bool iqk_deepseek_helper(ggml_type type_k,
int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
const float * q, const char * k, const char * v, const char * mask,
float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) {
if (type_k == GGML_TYPE_Q8_0) {
HelperQ80 kh((const char *)k, stride_k);
HelperQ80 vh((const char *)v, stride_v);
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
return true;
}
if (type_k == GGML_TYPE_Q8_0_R8) {
HelperQ80R8<1088> kh((const char *)k, stride_k);
HelperQ80 vh((const char *)v, stride_v);
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
return true;
}
if (type_k == GGML_TYPE_Q6_0) {
HelperQ60 kh((const char *)k, stride_k);
HelperQ60 vh((const char *)v, stride_v);
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
return true;
}
#if GGML_IQK_FA_ALL_QUANTS
if (type_k == GGML_TYPE_Q8_KV) {
HelperQ8KV<1088> kh((const char *)k, stride_k);
HelperQ8KV<512> vh((const char *)v, stride_v);
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
return true;
}
#endif
if (type_k == GGML_TYPE_F16) {
HelperF16 kh((const char *)k, stride_k);
HelperF16 vh((const char *)v, stride_v);
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
return true;
}
#ifdef __AVX512BF16__
if (type_k == GGML_TYPE_BF16) {
HelperBF16<1088, step_k> kh((const char *)k, stride_k);
HelperBF16<1024, step_k> vh((const char *)v, stride_v);
if (nq1 % 8 == 0) {
FlashAttnBF16<1088, 1024, 8, step_k> fa(scale, softcap, sinkf);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
} else {
FlashAttnBF16<1088, 1024, 1, step_k> fa(scale, softcap, sinkf);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
}
return true;
}
#endif
return false;
}
}
IQK_FA_CASE(iqk_fa_1088_1024) {
auto type_k = ggml_type(int_type_k);
auto type_v = ggml_type(int_type_v);
if (!(type_k == type_v || (type_k == GGML_TYPE_Q8_0_R8 && type_v == GGML_TYPE_Q8_0))) {
return false;
}
stride_q /= sizeof(float); // q stride as float
return iqk_deepseek_helper<32>(type_k, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, sinkf, M, S);
}
#endif

View File

@@ -1231,7 +1231,7 @@ struct FlashQKV {
template <int D, int q_step, int k_step>
struct FlashQKfp32 {
static_assert(D%F16::block_size == 0 && D <= 576);
static_assert(D%F16::block_size == 0 && D <= 1088);
static_assert(k_step%F16::block_size == 0);
static_assert(q_step <= 4 || q_step%4 == 0);
@@ -1523,8 +1523,8 @@ char * get_q_storage(size_t size) {
// q_step-1 versions of these functions for us, which I though was too much with q_step = 8.
template <int Dk, int Dv, int q_step, int k_step>
struct FlashAttn {
static_assert(Dk%F16::block_size == 0 && Dk <= 576);
static_assert(Dv%F16::block_size == 0 && Dv <= 512);
static_assert(Dk%F16::block_size == 0 && Dk <= 1088);
static_assert(Dv%F16::block_size == 0 && Dv <= 1024);
static_assert(k_step%F16::block_size == 0);
static_assert(q_step <= 4 || q_step%4 == 0);
@@ -1635,7 +1635,7 @@ struct HelperBF16 final : public BaseHelper {
template <int D, int q_step, int k_step>
struct FlashQKbf16 {
//static_assert(D%32 == 0 && D <= 256);
static_assert(D%32 == 0 && D <= 576);
static_assert(D%32 == 0 && D <= 1088);
static_assert(k_step%32 == 0);
static_assert(q_step <= 4 || q_step%4 == 0);
@@ -1947,8 +1947,8 @@ template <int Dk, int Dv, int q_step, int k_step>
struct FlashAttnBF16 {
//static_assert(Dk%32 == 0 && Dk <= 256);
//static_assert(Dv%32 == 0 && Dv <= 256);
static_assert(Dk%32 == 0 && Dk <= 576);
static_assert(Dv%32 == 0 && Dv <= 512);
static_assert(Dk%32 == 0 && Dk <= 1088);
static_assert(Dv%32 == 0 && Dv <= 1024);
static_assert(k_step%32 == 0);
static_assert(q_step <= 4 || q_step%4 == 0);
@@ -2240,6 +2240,7 @@ IQK_FA_CASE(iqk_fa_256_256);
IQK_FA_CASE(iqk_fa_128_128);
IQK_FA_CASE(iqk_fa_96_96);
IQK_FA_CASE(iqk_fa_64_64);
IQK_FA_CASE(iqk_fa_1088_1024);
#endif

View File

@@ -1343,6 +1343,10 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
return iqk_fa_576_512(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
}
if (Dk == 1088 && Dv == 1024) {
return iqk_fa_1088_1024(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
}
if (Dk == 192 && Dv == 128) {
return iqk_fa_192_128(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,