mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-22 14:14:32 +00:00
Gigachat: CPU FA (needs 192 x 192 for MLA = 3)
This commit is contained in:
@@ -265,6 +265,7 @@ if (GGML_IQK_MUL_MAT)
|
||||
iqk/iqk_flash_attn.cpp
|
||||
iqk/fa/iqk_fa_576_512.cpp
|
||||
iqk/fa/iqk_fa_192_128.cpp
|
||||
iqk/fa/iqk_fa_192_192.cpp
|
||||
iqk/fa/iqk_fa_256_256.cpp
|
||||
iqk/fa/iqk_fa_128_128.cpp
|
||||
iqk/fa/iqk_fa_96_96.cpp
|
||||
|
||||
45
ggml/src/iqk/fa/iqk_fa_192_192.cpp
Normal file
45
ggml/src/iqk/fa/iqk_fa_192_192.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
#include "iqk/iqk_config.h"
|
||||
|
||||
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
|
||||
|
||||
#include "iqk/fa/iqk_fa_templates.h"
|
||||
|
||||
IQK_FA_CASE(iqk_fa_192_192) {
|
||||
|
||||
auto type_k = ggml_type(int_type_k);
|
||||
auto type_v = ggml_type(int_type_v);
|
||||
|
||||
stride_q /= sizeof(float); // q stride as float
|
||||
auto ck = (const char *)k;
|
||||
auto cv = (const char *)v;
|
||||
auto cm = (const char *)mask;
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
if (type_k == GGML_TYPE_BF16) {
|
||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||
if (nk%64 == 0) {
|
||||
iqk_flash_helper_T<192, 192, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
return true;
|
||||
}
|
||||
iqk_flash_helper_T<192, 192, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (nk%128 == 0) {
|
||||
return iqk_flash_helper_T<192, 192, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
}
|
||||
if (nk%64 == 0) {
|
||||
return iqk_flash_helper_T<192, 192, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
}
|
||||
|
||||
return iqk_flash_helper_T<192, 192, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -2235,6 +2235,7 @@ inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
|
||||
|
||||
IQK_FA_CASE(iqk_fa_576_512);
|
||||
IQK_FA_CASE(iqk_fa_192_128);
|
||||
IQK_FA_CASE(iqk_fa_192_192);
|
||||
IQK_FA_CASE(iqk_fa_256_256);
|
||||
IQK_FA_CASE(iqk_fa_128_128);
|
||||
IQK_FA_CASE(iqk_fa_96_96);
|
||||
|
||||
@@ -1349,6 +1349,11 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
|
||||
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
|
||||
}
|
||||
|
||||
if (Dk == 192 && Dv == 192) {
|
||||
return iqk_fa_192_192(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
|
||||
}
|
||||
|
||||
if (Dk == 256 && Dv == 256) {
|
||||
return iqk_fa_256_256(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
|
||||
|
||||
Reference in New Issue
Block a user