Gigachat: CPU FA (needs 192 x 192 for MLA = 3)

This commit is contained in:
Iwan Kawrakow
2025-11-21 11:44:34 +02:00
parent 360c8c6fd4
commit 0369d2ba44
4 changed files with 52 additions and 0 deletions

View File

@@ -265,6 +265,7 @@ if (GGML_IQK_MUL_MAT)
iqk/iqk_flash_attn.cpp
iqk/fa/iqk_fa_576_512.cpp
iqk/fa/iqk_fa_192_128.cpp
iqk/fa/iqk_fa_192_192.cpp
iqk/fa/iqk_fa_256_256.cpp
iqk/fa/iqk_fa_128_128.cpp
iqk/fa/iqk_fa_96_96.cpp

View File

@@ -0,0 +1,45 @@
#include "iqk/iqk_config.h"
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
#include "iqk/fa/iqk_fa_templates.h"
IQK_FA_CASE(iqk_fa_192_192) {
auto type_k = ggml_type(int_type_k);
auto type_v = ggml_type(int_type_v);
stride_q /= sizeof(float); // q stride as float
auto ck = (const char *)k;
auto cv = (const char *)v;
auto cm = (const char *)mask;
#ifdef __AVX512BF16__
if (type_k == GGML_TYPE_BF16) {
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
if (nk%64 == 0) {
iqk_flash_helper_T<192, 192, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
return true;
}
iqk_flash_helper_T<192, 192, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
return true;
}
#endif
if (nk%128 == 0) {
return iqk_flash_helper_T<192, 192, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}
if (nk%64 == 0) {
return iqk_flash_helper_T<192, 192, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}
return iqk_flash_helper_T<192, 192, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}
#endif

View File

@@ -2235,6 +2235,7 @@ inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
IQK_FA_CASE(iqk_fa_576_512);
IQK_FA_CASE(iqk_fa_192_128);
IQK_FA_CASE(iqk_fa_192_192);
IQK_FA_CASE(iqk_fa_256_256);
IQK_FA_CASE(iqk_fa_128_128);
IQK_FA_CASE(iqk_fa_96_96);

View File

@@ -1349,6 +1349,11 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
}
if (Dk == 192 && Dv == 192) {
return iqk_fa_192_192(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
}
if (Dk == 256 && Dv == 256) {
return iqk_fa_256_256(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);