From 0369d2ba448941d216c5b35560745fde9d68f0eb Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 21 Nov 2025 11:44:34 +0200 Subject: [PATCH] Gigachat: CPU FA (needs 192 x 192 for MLA = 3) --- ggml/src/CMakeLists.txt | 1 + ggml/src/iqk/fa/iqk_fa_192_192.cpp | 45 ++++++++++++++++++++++++++++++ ggml/src/iqk/fa/iqk_fa_templates.h | 1 + ggml/src/iqk/iqk_mul_mat.cpp | 5 ++++ 4 files changed, 52 insertions(+) create mode 100644 ggml/src/iqk/fa/iqk_fa_192_192.cpp diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index b0bd3778..c9acf1fc 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -265,6 +265,7 @@ if (GGML_IQK_MUL_MAT) iqk/iqk_flash_attn.cpp iqk/fa/iqk_fa_576_512.cpp iqk/fa/iqk_fa_192_128.cpp + iqk/fa/iqk_fa_192_192.cpp iqk/fa/iqk_fa_256_256.cpp iqk/fa/iqk_fa_128_128.cpp iqk/fa/iqk_fa_96_96.cpp diff --git a/ggml/src/iqk/fa/iqk_fa_192_192.cpp b/ggml/src/iqk/fa/iqk_fa_192_192.cpp new file mode 100644 index 00000000..21fe033c --- /dev/null +++ b/ggml/src/iqk/fa/iqk_fa_192_192.cpp @@ -0,0 +1,45 @@ +#include "iqk/iqk_config.h" + +#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION + +#include "iqk/fa/iqk_fa_templates.h" + +IQK_FA_CASE(iqk_fa_192_192) { /* Flash-attention kernel for the Dk == 192, Dv == 192 case; iqk_flash_attn_impl dispatches here (see the iqk_mul_mat.cpp hunk in this patch). Converts the integer type codes to ggml_type, divides stride_q by sizeof(float), and forwards to iqk_flash_helper_T<192, 192, N>, picking N from the divisibility of nk. NOTE(review): N is presumably the K-step / tile size - confirm in iqk_fa_templates.h. */ + + auto type_k = ggml_type(int_type_k); + auto type_v = ggml_type(int_type_v); + + stride_q /= sizeof(float); // q stride as float + auto ck = (const char *)k; + auto cv = (const char *)v; + auto cm = (const char *)mask; + +#ifdef __AVX512BF16__ + if (type_k == GGML_TYPE_BF16) { /* BF16 K cache: only a BF16 V cache is accepted. The helper calls in this branch take no type arguments and their result is not returned, so the branch returns true itself. */ + if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types + if (nk%64 == 0) { /* N = 64 when nk is a multiple of 64, otherwise N = 32 below */ + iqk_flash_helper_T<192, 192, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); + return true; + } + iqk_flash_helper_T<192, 192, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); + return true; + } +#endif + + if (nk%128 == 0) { /* generic path (also taken for BF16 when __AVX512BF16__ is not defined): N = 128 if nk is a multiple of 128, else 64 if a multiple of 64, else 32; the helper result is returned as is */ + return 
iqk_flash_helper_T<192, 192, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); + } + if (nk%64 == 0) { + return iqk_flash_helper_T<192, 192, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); + } + + return iqk_flash_helper_T<192, 192, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); + +} + +#endif diff --git a/ggml/src/iqk/fa/iqk_fa_templates.h b/ggml/src/iqk/fa/iqk_fa_templates.h index 3a0b7248..8e96844e 100644 --- a/ggml/src/iqk/fa/iqk_fa_templates.h +++ b/ggml/src/iqk/fa/iqk_fa_templates.h @@ -2235,6 +2235,7 @@ inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, IQK_FA_CASE(iqk_fa_576_512); IQK_FA_CASE(iqk_fa_192_128); +IQK_FA_CASE(iqk_fa_192_192); IQK_FA_CASE(iqk_fa_256_256); IQK_FA_CASE(iqk_fa_128_128); IQK_FA_CASE(iqk_fa_96_96); diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 7876d199..35573e49 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -1349,6 +1349,11 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k q, k, v, mask, scale, softcap, qkv, sinksf, M, S); } + if (Dk == 192 && Dv == 192) { + return iqk_fa_192_192(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, + q, k, v, mask, scale, softcap, qkv, sinksf, M, S); + } + if (Dk == 256 && Dv == 256) { return iqk_fa_256_256(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, k, v, mask, scale, softcap, qkv, sinksf, M, S);